In [1]:
import numpy as np
import torch
import torch.nn as nn
import gym
import time



In [2]:
# Environment and problem dimensions used by the policy network.
env = gym.make('CartPole-v1')
state_dim = 4  # length of the observation vector fed to the policy network
action_n = 2   # number of discrete actions the policy chooses between

In [3]:
class CEM(nn.Module):
    """Cross-entropy-method agent backed by a small MLP policy.

    The network maps a state vector of length ``state_dim`` to ``action_n``
    logits. Actions are sampled from the softmax of those logits, and the
    policy is trained by cross-entropy on the (state, action) pairs taken
    from elite trajectories.

    Args:
        state_dim: Size of the state vector given to the first linear layer.
        action_n: Number of discrete actions (size of the output layer).
    """

    def __init__(self, state_dim, action_n):
        super().__init__()
        self.state_dim = state_dim
        self.action_n = action_n
        self.network = nn.Sequential(
            nn.Linear(self.state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, self.action_n)
        )
        # Normalise over the last (action) axis explicitly. Omitting ``dim``
        # relies on a deprecated implicit choice and emits a UserWarning on
        # every call.
        self.softmax = nn.Softmax(dim=-1)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, _input):
        """Return the raw action logits for ``_input``."""
        return self.network(_input)

    def get_action(self, state):
        """Sample an action index for a single state.

        Args:
            state: One state vector (array-like of length ``state_dim``).

        Returns:
            An integer in ``[0, action_n)`` drawn with ``np.random.choice``
            according to the policy's softmax probabilities.
        """
        state = torch.FloatTensor(state)
        # Pure inference: no autograd graph is needed to sample an action.
        with torch.no_grad():
            logits = self.forward(state)
            probs = self.softmax(logits).numpy()
        action = np.random.choice(self.action_n, p=probs)
        return action

    def update_policy(self, elite_trajectories):
        """Run one optimizer step on the elite (state, action) pairs.

        Args:
            elite_trajectories: Iterable of dicts with keys ``'states'`` and
                ``'actions'``; within each dict the two lists must have the
                same length. If no samples are supplied the call is a no-op.
        """
        elite_states = []
        elite_actions = []
        for trajectory in elite_trajectories:
            elite_states.extend(trajectory['states'])
            elite_actions.extend(trajectory['actions'])

        # Nothing to learn from; an empty batch would otherwise fail in the
        # forward pass.
        if not elite_states:
            return

        elite_states = torch.FloatTensor(np.array(elite_states))
        elite_actions = torch.LongTensor(np.array(elite_actions))

        # Clear gradients first so the step never includes stale values.
        self.optimizer.zero_grad()
        pred_actions = self.forward(elite_states)
        loss = self.loss(pred_actions, elite_actions)
        loss.backward()
        self.optimizer.step()

In [4]:
def get_trajectory(env, agent, max_steps, visualize=False):
    """Roll out one episode and record the visited states and chosen actions.

    Each state is stored together with the action selected in it, so
    ``states[i]`` always pairs with ``actions[i]`` and both lists have the
    same length whether the episode ends on ``done`` or on the step limit.
    The observation returned by the final step is not recorded, since no
    action is taken in it.

    Args:
        env: Environment whose ``reset()`` returns a state and whose
            ``step(action)`` returns ``(state, reward, done, info)``.
        agent: Object exposing ``get_action(state)``.
        max_steps: Maximum number of environment steps to take.
        visualize: If true, render the environment after each non-terminal
            step and pause briefly.

    Returns:
        Dict with keys ``'states'``, ``'actions'`` and ``'total_reward'``
        (the sum of the rewards received).
    """
    trajectory = {'states': [], 'actions': [], 'total_reward': 0}
    state = env.reset()

    for _ in range(max_steps):
        # Record the state only once an action is about to be taken in it;
        # this keeps the two lists aligned even if the step limit is hit.
        trajectory['states'].append(state)
        action = agent.get_action(state)
        trajectory['actions'].append(action)

        state, reward, done, _info = env.step(action)
        trajectory['total_reward'] += reward

        if done:
            break

        if visualize:
            env.render()
            time.sleep(0.01)

    return trajectory

In [5]:
def get_elite_trajectories(trajectories, q_param):
    """Select the trajectories whose return is strictly above a quantile.

    Args:
        trajectories: Sequence of dicts, each holding a ``'total_reward'``.
        q_param: Quantile level passed to ``np.quantile`` as ``q``.

    Returns:
        List of the input trajectories (original order preserved) whose
        ``'total_reward'`` is strictly greater than the ``q_param`` quantile
        of all total rewards.
    """
    rewards = []
    for item in trajectories:
        rewards.append(item['total_reward'])

    threshold = np.quantile(rewards, q=q_param)

    elite = []
    for item in trajectories:
        # Strict comparison: trajectories exactly at the threshold are dropped.
        if item['total_reward'] > threshold:
            elite.append(item)
    return elite

In [6]:
# --- Environment and problem dimensions ---
env = gym.make('CartPole-v1')
state_dim, action_n = 4, 2

# --- Training hyperparameters ---
episode_n, trajectory_n = 50, 20  # training iterations / rollouts per iteration
trajectory_len = 500              # step cap handed to get_trajectory
q_param = 0.8                     # quantile level used to pick elite rollouts

agent = CEM(state_dim, action_n)

# --- Cross-entropy-method training loop ---
for episode in range(episode_n):
    # Sample a fresh batch of rollouts with the current policy.
    trajectories = [
        get_trajectory(env, agent, trajectory_len)
        for _ in range(trajectory_n)
    ]

    mean_total_reward = np.mean(
        [trajectory['total_reward'] for trajectory in trajectories]
    )
    print(f'episode: {episode}, mean_total_reward = {mean_total_reward}')

    elite_trajectories = get_elite_trajectories(trajectories, q_param)

    # Skip the update when no rollout beat the quantile threshold.
    if not elite_trajectories:
        continue
    agent.update_policy(elite_trajectories)

  return self._call_impl(*args, **kwargs)


episode: 0, mean_total_reward = 23.55
episode: 1, mean_total_reward = 29.9
episode: 2, mean_total_reward = 29.25
episode: 3, mean_total_reward = 38.0
episode: 4, mean_total_reward = 40.85
episode: 5, mean_total_reward = 35.45
episode: 6, mean_total_reward = 42.1
episode: 7, mean_total_reward = 45.25
episode: 8, mean_total_reward = 44.25
episode: 9, mean_total_reward = 42.7
episode: 10, mean_total_reward = 52.5
episode: 11, mean_total_reward = 49.2
episode: 12, mean_total_reward = 59.15
episode: 13, mean_total_reward = 53.55
episode: 14, mean_total_reward = 65.8
episode: 15, mean_total_reward = 80.25
episode: 16, mean_total_reward = 65.0
episode: 17, mean_total_reward = 74.55
episode: 18, mean_total_reward = 87.25
episode: 19, mean_total_reward = 107.15
episode: 20, mean_total_reward = 89.75
episode: 21, mean_total_reward = 107.9
episode: 22, mean_total_reward = 109.45
episode: 23, mean_total_reward = 101.9
episode: 24, mean_total_reward = 133.0
episode: 25, mean_total_reward = 147.45
e