In [7]:
import gym
import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [19]:
env = gym.make('CartPole-v1')
state_dim = 4
action_n = 2


class CEM(nn.Module):
    def __init__(self, state_dim, action_n):
        super().__init__()
        self.state_dim = state_dim
        self.action_n = action_n
        self.network = nn.Sequential(nn.Linear(self.state_dim, 128), 
                                     nn.ReLU(),
                                     nn.Linear(128, self.action_n)
                                    )
        self.softmax = nn.Softmax()
        self.optimazer = torch.optim.Adam(self.parameters(), lr=1e-2)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, _input):
        return self.network(_input)
    
    def get_action(self, state):
        state = torch.FloatTensor(state)
        logits = self.forward(state)
        probs = self.softmax(logits).data.numpy()
        action = np.random.choice(self.action_n, p=probs)
        return action

    def fit(self, elite_trajectories):
        elite_states = []
        elite_actions = []
        for trajectory in elite_trajectories:
            for state, action in zip(trajectory['states'], trajectory['actions']):
                elite_states.append(state)
                elite_actions.append(action)
        elite_states = torch.FloatTensor(elite_states)
        elite_actions = torch.LongTensor(elite_actions)
        pred_actions = self.forward(elite_states)

        loss = self.loss(pred_actions, elite_actions)
        loss.backward()
        self.optimazer.step()
        self.optimazer.zero_grad()
        

def get_trajectory(env, agent, max_len=1000, visualize=False):
    trajectory = {'states': [], 'actions': [], 'rewards': []}

    state = env.reset()

    for _ in range(max_len):
        trajectory['states'].append(state)
        
        action = agent.get_action(state)
        trajectory['actions'].append(action)
        
        state, reward, done, _ = env.step(action)
        trajectory['rewards'].append(reward)

        if visualize:
            time.sleep(0.5)
            env.render()

        if done:
            break
    
    return trajectory


agent = CEM(state_dim, action_n)
q_param = 0.8
iteration_n = 100
trajectory_n = 20
trajectory_len = 500

for iteration in range(iteration_n):

    #policy evaluation
    trajectories = [get_trajectory(env, agent) for _ in range(trajectory_n)]
    total_rewards = [np.sum(trajectory['rewards']) for trajectory in trajectories]
    print('iteration:', iteration, 'mean total reward:', np.mean(total_rewards))

    #policy improvement
    quantile = np.quantile(total_rewards, q_param)
    elite_trajectories = []
    for trajectory in trajectories:
        total_reward = np.sum(trajectory['rewards'])
        if total_reward > quantile:
            elite_trajectories.append(trajectory)
    
    if len(elite_trajectories) > 0:
        agent.fit(elite_trajectories)

trajectory = get_trajectory(env, agent, max_len=500, visualize=False)
print('total reward:', sum(trajectory['rewards']))

  probs = self.softmax(logits).data.numpy()


iteration: 0 mean total reward: 26.6
iteration: 1 mean total reward: 27.4
iteration: 2 mean total reward: 24.8
iteration: 3 mean total reward: 28.7
iteration: 4 mean total reward: 30.45
iteration: 5 mean total reward: 30.2
iteration: 6 mean total reward: 41.4
iteration: 7 mean total reward: 46.7
iteration: 8 mean total reward: 54.6
iteration: 9 mean total reward: 57.55
iteration: 10 mean total reward: 44.4
iteration: 11 mean total reward: 50.5
iteration: 12 mean total reward: 51.9
iteration: 13 mean total reward: 66.8
iteration: 14 mean total reward: 51.25
iteration: 15 mean total reward: 66.35
iteration: 16 mean total reward: 64.85
iteration: 17 mean total reward: 77.9
iteration: 18 mean total reward: 70.65
iteration: 19 mean total reward: 64.95
iteration: 20 mean total reward: 61.05
iteration: 21 mean total reward: 80.9
iteration: 22 mean total reward: 74.35
iteration: 23 mean total reward: 91.25
iteration: 24 mean total reward: 100.1
iteration: 25 mean total reward: 113.25
iteration