In [2]:
import numpy as np
import torch
import torch.nn as nn
import gym
import time



In [3]:
env = gym.make('CartPole-v1')
state_dim = 4
acdtion_n = 2

In [4]:
class CEM(nn.Module):
    def __init__(self, state_dim, action_n):
        super().__init__()
        self.state_dim = state_dim
        self.action_n = action_n
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, action_n)
        )
        self.softmax = nn.Softmax()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, _input):
        return self.network(_input)

    def get_action(self, state):
        state = torch.FloatTensor(state)
        logits = self.forward(state)
        probs = self.softmax(logits).data.numpy()
        action = np.random.choice(self.action_n, p=probs)
        return action

    def fit(self, elite_trajectories):
        elite_states = []
        elite_actions = []
        for trajectory in elite_trajectories:
            for state, actions in zip(trajectory['states'], trajectory['actions']):
                elite_states.append(state)
                elite_actions.append(actions)
        elite_states = torch.FloatTensor(np.array(elite_states))
        elite_actions = torch.LongTensor(np.array(elite_actions))
        pred_actions = self.forward(elite_states)
        loss = self.loss(pred_actions, elite_actions)
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

In [5]:
def get_trajectory(env, agent, max_steps, visualize=False):
    state = env.reset()
    trajectory = {'states': [], 'actions': [], 'rewards': []}
    for s in range(max_steps):
        trajectory['states'].append(state)
        action = agent.get_action(state)
        state, reward, done, _ = env.step(action)
        trajectory['actions'].append(action)
        trajectory['rewards'].append(reward)
        if done:
            break
        if visualize:
            env.render()
            time.sleep(0.1)
    return trajectory

In [6]:
n_iterations = 100
max_steps = 500
trajectory_n = 50
q=0.8

In [7]:
agent = CEM(state_dim, acdtion_n)

for i in range(n_iterations):
    start = time.time()
    trajectories = [get_trajectory(env, agent, max_steps=max_steps) for _ in range(trajectory_n)]
    total_rewards = [np.sum(trajectory['rewards']) for trajectory in trajectories]
    if i % 10 == 0:
        print(f'iteration: {i}/{n_iterations}, mean_total_reward: {np.mean(total_rewards)}')

    quantile = np.quantile(total_rewards, q)
    elite_trajectories = []
    for t in trajectories:
        total_reward = np.sum(t['rewards'])
        if total_reward >= quantile:
            elite_trajectories.append(t)
    if len(elite_trajectories) == 0:
        print(total_rewards)
        print(quantile)
    agent.fit(elite_trajectories)
    print(f'fit time: {time.time() - start}')

  return self._call_impl(*args, **kwargs)


iteration: 0/100, mean_total_reward: 23.18
fit time: 0.34783291816711426
fit time: 0.3802661895751953
fit time: 0.36266350746154785
fit time: 0.43755125999450684
fit time: 0.4734065532684326
fit time: 0.5512897968292236
fit time: 0.5538144111633301
fit time: 0.6167526245117188
fit time: 0.6684765815734863
fit time: 0.6773273944854736
iteration: 10/100, mean_total_reward: 59.86
fit time: 0.7562265396118164
fit time: 0.8239850997924805
fit time: 0.7979583740234375
fit time: 0.9094693660736084
fit time: 0.8120217323303223
fit time: 0.9461963176727295
fit time: 0.9251337051391602
fit time: 1.0181291103363037
fit time: 1.2618491649627686
fit time: 1.355663776397705
iteration: 20/100, mean_total_reward: 101.86
fit time: 1.2522380352020264
fit time: 1.4808483123779297
fit time: 1.8177599906921387
fit time: 1.9115183353424072
fit time: 1.8199975490570068
fit time: 2.3346142768859863
fit time: 2.852109909057617



KeyboardInterrupt

