In [41]:
import torch
from torch import nn 
import numpy as np
import gym
import matplotlib.pyplot as plt
import time
import torch.nn.functional as F

In [42]:
class CEM(nn.Module):
    """Cross-entropy-method agent with a small MLP policy over discrete actions.

    The network maps a state vector to one logit per action. Actions are
    sampled from the softmax of those logits mixed with a uniform
    distribution (weight ``eps``), and the policy is fitted to elite
    (state, action) pairs with a cross-entropy loss.

    Args:
        state_dim: size of the state vector fed to the first linear layer.
        action_n: number of discrete actions (size of the output layer).
        lr: learning rate passed to the Adam optimizer.
        eps: weight of the uniform distribution mixed into the policy when
            sampling an action (0 = pure policy, 1 = uniform random).
    """

    def __init__(self, state_dim, action_n, lr=0.01, eps=0.1):
        super().__init__()
        self.state_dim = state_dim
        self.action_n = action_n
        self.lr = lr
        self.eps = eps

        self.network = nn.Sequential(
            nn.Linear(self.state_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, self.action_n)
        )

        # Explicit dim: normalise over the action axis for both single states
        # and batches, and avoid the implicit-dimension deprecation warning.
        self.softmax = nn.Softmax(dim=-1)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, _input):
        """Return the raw action logits for ``_input``."""
        return self.network(_input)

    def get_action(self, state):
        """Sample an action index for a single state.

        The policy probabilities are blended with a uniform distribution using
        ``self.eps`` and renormalised before sampling with ``np.random.choice``.
        """
        state = torch.as_tensor(state, dtype=torch.float32)
        # Sampling needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            logits = self.forward(state)
            action_prob = self.softmax(logits).numpy().astype(np.float64)

        noise = np.ones(self.action_n) / self.action_n
        action_prob_noised = (1 - self.eps) * action_prob + self.eps * noise
        # Renormalise so float rounding cannot trip np.random.choice's sum check.
        action_prob_noised = action_prob_noised / np.sum(action_prob_noised)
        action = np.random.choice(self.action_n, p=action_prob_noised)
        return action

    def update_policy(self, elite_trajectories):
        """Take one Adam step fitting the policy to the elite (state, action) pairs.

        Each trajectory is a dict with ``'states'`` and ``'actions'`` lists.
        If a trajectory carries more states than actions (e.g. a trailing
        state for which no action was taken), the surplus is dropped so that
        ``states[i]`` always pairs with ``actions[i]``. Does nothing when
        there are no pairs to learn from.
        """
        elite_states = []
        elite_actions = []
        for trajectory in elite_trajectories:
            states = trajectory['states']
            actions = trajectory['actions']
            pair_n = min(len(states), len(actions))
            elite_states.extend(states[:pair_n])
            elite_actions.extend(actions[:pair_n])

        if not elite_states:
            return

        elite_states_tensor = torch.as_tensor(np.array(elite_states), dtype=torch.float32)
        elite_actions_tensor = torch.as_tensor(np.array(elite_actions), dtype=torch.long)

        # Clear gradients first so nothing left over from earlier calls leaks in.
        self.optimizer.zero_grad()
        loss = self.loss(self.forward(elite_states_tensor), elite_actions_tensor)
        loss.backward()
        self.optimizer.step()

In [43]:
def get_trajectory(env, agent, trajectory_len, visualize=False):
    """Roll out one episode and record aligned (state, action) pairs.

    Args:
        env: Gym-style environment. Both the 4-tuple ``step`` / bare-observation
            ``reset`` API and the 5-tuple ``step`` / ``(obs, info)`` ``reset``
            API are accepted.
        agent: object exposing ``get_action(state)``.
        trajectory_len: maximum number of steps to take.
        visualize: when True, call ``env.render()`` after each non-final step.

    Returns:
        dict with ``'states'`` and ``'actions'`` (equal-length lists where
        ``states[i]`` is the state in which ``actions[i]`` was chosen) and
        ``'total_reward'`` (sum of rewards received).
    """
    trajectory = {'states': [], 'actions': [], 'total_reward': 0}

    reset_result = env.reset()
    # Newer Gym versions return (observation, info) from reset().
    if (isinstance(reset_result, tuple) and len(reset_result) == 2
            and isinstance(reset_result[1], dict)):
        state = reset_result[0]
    else:
        state = reset_result

    for _ in range(trajectory_len):
        # Record the state together with the action taken in it, so the two
        # lists stay the same length even when the step budget runs out
        # before the episode ends.
        trajectory['states'].append(state)
        action = agent.get_action(state)
        trajectory['actions'].append(action)

        step_result = env.step(action)
        if len(step_result) == 5:
            state, reward, terminated, truncated, _info = step_result
            done = terminated or truncated
        else:
            state, reward, done, _info = step_result
        trajectory['total_reward'] += reward

        if done:
            break

        if visualize:
            env.render()

    return trajectory

def get_elite_trajectories(trajectories, q_param):
    """Keep the trajectories whose total reward is strictly above the q-quantile.

    Args:
        trajectories: list of dicts, each with a ``'total_reward'`` entry.
        q_param: quantile level passed to ``np.quantile`` as ``q``.

    Returns:
        A new list with the selected trajectory dicts, in their original
        order. It is empty when no reward exceeds the threshold, e.g. when
        all rewards are equal.
    """
    rewards = []
    for item in trajectories:
        rewards.append(item['total_reward'])

    threshold = np.quantile(rewards, q=q_param)

    elite = []
    for item in trajectories:
        if item['total_reward'] > threshold:
            elite.append(item)
    return elite

In [44]:
env = gym.make(
    "LunarLander-v2",
    continuous=False,
    gravity=-10.0,
    enable_wind=False,
    wind_power=15.0,
    turbulence_power=1.5)

# Alternative, simpler environment:
# env = gym.make('CartPole-v1')

# Read the network sizes from the environment instead of hard-coding them,
# so swapping the environment above cannot leave stale dimensions behind.
state_dim = env.observation_space.shape[0]
action_n = int(env.action_space.n)

lr = 1e-2
# Initial exploration weight; the training loop reassigns agent.eps every episode.
eps = 1

agent = CEM(state_dim, action_n, lr=lr, eps=eps)
episode_n = 100        # number of CEM iterations (policy updates)
trajectory_n = 200     # trajectories sampled per iteration
trajectory_len = 10000 # per-trajectory step cap
q_param = 0.6          # quantile level used to select elite trajectories

In [45]:
total_start = time.time()
for episode in range(1, episode_n + 1):
    # Exploration weight shrinks with the episode index: 1, 0.707, 0.577, ...
    agent.eps = np.sqrt(1 / episode)

    # Sample a fresh batch of rollouts with the current policy.
    trajectories = [
        get_trajectory(env, agent, trajectory_len)
        for _ in range(trajectory_n)
    ]

    mean_total_reward = np.mean(
        [trajectory['total_reward'] for trajectory in trajectories]
    )
    print(f'episode: {episode}, mean_total_reward = {mean_total_reward}')

    elite_trajectories = get_elite_trajectories(trajectories, q_param)

    # Nothing beat the quantile threshold this round: skip the update.
    if not elite_trajectories:
        continue
    agent.update_policy(elite_trajectories)

# To watch the trained agent, run:
# get_trajectory(env, agent, trajectory_len, visualize=True)
print('total time', time.time() - total_start)

episode: 1, mean_total_reward = -178.51307056848768
episode: 2, mean_total_reward = -187.0018895309032
episode: 3, mean_total_reward = -184.9670750467903
episode: 4, mean_total_reward = -184.6199581170169
episode: 5, mean_total_reward = -175.29465608661428
episode: 6, mean_total_reward = -181.06930676324163
episode: 7, mean_total_reward = -174.50364442442267
episode: 8, mean_total_reward = -170.61394789474093
episode: 9, mean_total_reward = -163.40830295749782
episode: 10, mean_total_reward = -171.87550241168978
episode: 11, mean_total_reward = -153.38376066750382
episode: 12, mean_total_reward = -159.67496335517194
episode: 13, mean_total_reward = -150.7088109254332
episode: 14, mean_total_reward = -149.6710103386742
episode: 15, mean_total_reward = -158.75750766583292
episode: 16, mean_total_reward = -141.18040899351206
episode: 17, mean_total_reward = -146.59531646150225
episode: 18, mean_total_reward = -132.35889582322295
episode: 19, mean_total_reward = -130.1505402274524
episode: