In [None]:
%matplotlib notebook

import gym
from logger import Plotter
import numpy as np
from replay import ReplayBuffer
import torch
from torch import nn
from torch.distributions import Categorical


# Live training plot with one series per name listed here. The values are
# supplied in this same order (after the leading epoch argument) by the
# plotter.update(...) call at the end of the training loop.
# NOTE(review): Plotter comes from the local `logger` module; how it renders
# is not visible from this notebook.
plotter = Plotter('Return',
                  #'Length',
                  'Explained Variance',
                  'Loss',
                  'Epsilon')

# create environment
env = gym.make('CartPole-v1')
obs_dim = env.observation_space.shape[0]  # length of one observation vector
act_dim = env.action_space.n              # number of discrete actions

BUFFER_SIZE = 10 ** 6  # first ReplayBuffer argument — presumably capacity in transitions; confirm
BATCH_SIZE = 100       # number of transitions requested from the buffer at every training step
# NOTE(review): the third argument (1) is presumably the width of a stored
# action (a single discrete action index) — confirm against replay.ReplayBuffer.
replay_buffer = ReplayBuffer(BUFFER_SIZE, obs_dim, 1)

HIDDEN_SIZE = 64  # default width of both hidden layers built by mlp()


def mlp(in_dim, act_dim, activation=nn.ReLU, output_activation=nn.Identity, hidden_size=None):
    """Build a fully connected network with two equally sized hidden layers.

    Args:
        in_dim: Number of input features.
        act_dim: Number of output features (the width of the last linear layer).
        activation: Module class instantiated after each hidden linear layer.
        output_activation: Module class instantiated after the output layer.
        hidden_size: Width of the two hidden layers. ``None`` (the default)
            uses the module-level ``HIDDEN_SIZE`` as read at call time, which
            is exactly what the function did before this parameter existed.

    Returns:
        An ``nn.Sequential`` of three linear layers interleaved with the
        requested activations (six modules in total).
    """
    if hidden_size is None:
        hidden_size = HIDDEN_SIZE
    # Layers are created in input-to-output order so that weight initialisation
    # consumes the random number generator in the same order as before.
    return nn.Sequential(
        nn.Linear(in_dim, hidden_size),
        activation(),
        nn.Linear(hidden_size, hidden_size),
        activation(),
        nn.Linear(hidden_size, act_dim),
        output_activation(),
    )

# Support of the categorical return distribution: the consecutive integers
# V_MIN, V_MIN + 1, ..., V_MAX - 1 (V_MAX itself is excluded), stored as an
# int64 tensor with one entry per atom.
V_MIN = 0
V_MAX = 102
ATOMS = torch.arange(V_MIN, V_MAX)

class EpsGreedyActor(nn.Module):
    """Categorical Q-network with epsilon-greedy action selection.

    For every action the network produces one logit per entry of the
    module-level ``ATOMS`` tensor. A softmax over those logits is treated as a
    probability distribution over the atom values, and its expectation is used
    as the value of that action.
    """

    def __init__(self, obs_dim, act_dim):
        """Create the underlying MLP.

        Args:
            obs_dim: Number of input features.
            act_dim: Number of discrete actions.
        """
        super().__init__()
        self.act_dim = act_dim
        # A single flat output of act_dim * len(ATOMS) logits; forward()
        # splits it into one row of atom logits per action.
        self.q = mlp(obs_dim, act_dim * len(ATOMS))

    def forward(self, state):
        """Return atom logits reshaped to ``(-1, act_dim, len(ATOMS))``."""
        flat_logits = self.q(state)
        return flat_logits.reshape(-1, self.act_dim, len(ATOMS))

    def probs(self, state):
        """Return per-action atom probabilities (softmax over the last axis)."""
        return self.forward(state).softmax(dim=-1)

    def log_probs(self, state):
        """Return per-action atom log-probabilities (log-softmax over the last axis)."""
        return self.forward(state).log_softmax(dim=-1)

    def act(self, state, epsilon=0):
        """Choose an action epsilon-greedily.

        With probability ``epsilon`` an action index is drawn with
        ``np.random.choice(self.act_dim)``; otherwise the action with the
        largest expected atom value is returned. The greedy branch ends in
        ``.item()``, so it only works when ``state`` yields a single row.
        """
        explore = np.random.random() < epsilon
        if explore:
            return np.random.choice(self.act_dim)
        # Expected value per action: probability-weighted sum of atom values.
        expected_values = (self.probs(state) * ATOMS).sum(dim=-1)
        greedy = expected_values.max(dim=-1)
        return greedy.indices.item()
    
# Online network: the one trained by the optimiser and used to pick actions.
q_net = EpsGreedyActor(obs_dim, act_dim)
optimiser = torch.optim.Adam(q_net.q.parameters(), lr=5e-4)
# Exploration probability; the training loop lowers it by a small amount on
# every environment step until it reaches its floor.
epsilon = 1

# Target network: same architecture, excluded from gradient computation.
target_net = EpsGreedyActor(obs_dim, act_dim)
target_net.q.requires_grad_(False)
    
def update_target_net(source_net, target_net):
    """Copy the weights of ``source_net.q`` into ``target_net.q``.

    Both arguments only need a ``q`` attribute that is an ``nn.Module`` with a
    compatible state dict. The target's tensors are overwritten in place;
    nothing is returned.
    """
    source_weights = source_net.q.state_dict()
    with torch.no_grad():
        target_net.q.load_state_dict(source_weights)
# Start with the target network holding an exact copy of the online network's
# weights (the two were initialised independently above).
update_target_net(q_net, target_net)

def preprocess(observations, actions):
    """Turn an episode history into the network input for the current step.

    Only the most recent observation is used; ``actions`` is accepted but not
    read.

    Args:
        observations: Non-empty sequence of observations, oldest first.
        actions: Sequence of actions taken so far (unused).

    Returns:
        The last observation as a ``torch.float32`` tensor.
    """
    latest_observation = observations[-1]
    return torch.as_tensor(latest_observation, dtype=torch.float32)

# Extra network with the same architecture; the training loop copies q_net's
# weights into it right before every optimiser step. It is not read anywhere
# else in the visible notebook.
last_q_net = EpsGreedyActor(obs_dim, act_dim)

GAMMA = 0.99       # discount factor used in the training target and in the evaluation returns-to-go
EPISODES = 100000  # number of training episodes to run
epoch = -1         # x-value passed to the plotter; set to the episode index at each evaluation
def cumulative(values, discount):
    """Return the discounted reverse running sum of a 1-D tensor.

    The last element is kept as is and every earlier element becomes
    ``values[i] + discount * result[i + 1]``. The computation runs on a
    detached copy, so the input tensor is not modified.
    """
    c = values.clone().detach()
    for i in reversed(range(len(c) - 1)):
        c[i] += discount * c[i + 1]
    return c


NUM_TEST_EPISODES = 8  # greedy evaluation episodes per reporting point
num_atoms = len(ATOMS)

# Main loop: play one episode per iteration, take one gradient step per
# environment step, refresh the target network every 10 episodes and evaluate
# and plot every 50 episodes.
for episode in range(EPISODES):
    observations = []
    actions = []

    obs = env.reset()
    observations.append(obs.copy())
    state = preprocess(observations, actions)

    done = False
    while not done:
        # Decay exploration by a fixed amount per step until it drops to 0.02.
        if epsilon > 0.02:
            epsilon -= 1e-6
        with torch.no_grad():
            action = q_net.act(state, epsilon)
        obs, reward, done, _ = env.step(action)

        observations.append(obs.copy())
        actions.append(action)

        last_state, state = state, preprocess(observations, actions)
        transition = (last_state, action, reward, state, done)
        replay_buffer.store(transition)

        # calculate loss
        # NOTE(review): the shape handling below assumes the buffer returns
        # tensors with actions shaped (batch, 1) and rewards/terminals shaped
        # (batch,) — confirm against replay.ReplayBuffer.
        batch = replay_buffer.sample(BATCH_SIZE)
        batch_states, batch_actions, batch_rewards, batch_next_states, batch_terminals = batch
        batch_actions = torch.as_tensor(batch_actions, dtype=torch.int64)

        # Log-probabilities of every atom for the action that was taken.
        # squeeze(1) removes only the gathered action axis, so a batch with a
        # single row keeps its batch dimension.
        taken_index = batch_actions.repeat(1, num_atoms).unsqueeze(1)
        log_probs = q_net.log_probs(batch_states).gather(dim=1, index=taken_index).squeeze(1)

        # Target distribution: atom probabilities of the target network's
        # greedy action (largest expected atom value) in the next state.
        next_probs = target_net.probs(batch_next_states)
        next_actions = (next_probs * ATOMS).sum(dim=-1).max(dim=-1).indices
        next_index = next_actions.unsqueeze(1).repeat(1, num_atoms).unsqueeze(1)
        next_action_probs = next_probs.gather(dim=1, index=next_index).squeeze(1)

        # Shift and shrink every atom by the Bellman update, then split its
        # probability mass between the two neighbouring atoms.
        r = batch_rewards.repeat(num_atoms, 1).T
        t = batch_terminals.repeat(num_atoms, 1).T
        # ATOMS covers V_MIN .. V_MAX - 1, so the projected value is clamped
        # to the last atom rather than to V_MAX; otherwise the indices below
        # could point past the end of the support.
        T_z = (r + (1 - t) * GAMMA * ATOMS).clamp(V_MIN, V_MAX - 1)
        b_j = (T_z - V_MIN) / 1  # atom spacing is 1
        l = b_j.floor().type(torch.int64)
        u = l + 1
        lower_weight = u - b_j
        upper_weight = b_j - l
        # When b_j sits exactly on the last atom the upper weight is zero, so
        # pointing the upper index at the last atom changes nothing but keeps
        # the gather in bounds.
        u = u.clamp(max=num_atoms - 1)
        losses = -next_action_probs * (lower_weight * log_probs.gather(dim=1, index=l) +
                                       upper_weight * log_probs.gather(dim=1, index=u))
        loss = losses.sum(dim=1).mean()

        # Snapshot of the online weights as they were before this step.
        update_target_net(q_net, last_q_net)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    if episode % 10 == 0:
        update_target_net(q_net, target_net)

    # Evaluate and plot only on every 50th episode.
    if episode % 50:
        continue

    epoch = episode

    returns = []
    explained_variances = []
    for _ in range(NUM_TEST_EPISODES):
        vvalues = []
        rewards = []
        obs = env.reset()
        done = False
        while not done:
            obs = torch.as_tensor(obs, dtype=torch.float32)
            with torch.no_grad():
                q = (q_net.probs(obs) * ATOMS).sum(dim=-1).max(dim=-1)
            # The value estimate of a state is its best action's expected value.
            vvalues.append(q.values.item())
            action = q.indices.item()

            obs, reward, done, _ = env.step(action)
            rewards.append(reward)

        # Explained variance: 1 - Var(prediction - actual) / Var(actual),
        # where "actual" is the discounted reward-to-go of this episode.
        test_rewards_to_go = cumulative(torch.as_tensor(rewards), GAMMA)
        actual = np.array(test_rewards_to_go)
        explained_variance = 1 - (np.array(vvalues) - actual).var() / actual.var()
        explained_variances.append(explained_variance)

        returns.append(sum(rewards))

    # Report the mean over all test episodes, not just the last episode's value.
    plotter.update(epoch,
                   (np.mean(returns), min(returns), max(returns)),
                   np.mean(explained_variances),
                   loss.item(),
                   epsilon)

In [None]:
from matplotlib import pyplot as plt

# Live view: render the environment while the target network acts greedily,
# and draw one line per action showing that action's probability over ATOMS.
# The loop runs until it is interrupted or an error occurs.
ax = plt.axes(label='b')
ax.set_xlim(V_MIN, V_MAX)
m = 0.0  # largest probability seen so far; used as the y-axis upper limit

obs = env.reset()
try:
    while True:
        env.render()
        obs_tensor = torch.as_tensor(obs, dtype=torch.float32)
        with torch.no_grad():
            # Drop only the leading batch axis so one row per action always
            # remains, whatever the number of actions.
            probs = target_net.probs(obs_tensor).squeeze(0)
            action = target_net.act(obs_tensor)

        # Keep the running maximum as a plain float rather than a tensor.
        m = max(m, probs.max().item())
        ax.set_ylim(0, m)
        if ax.lines:
            # Reuse the lines created on the first frame.
            for line, d in zip(ax.lines, probs):
                line.set_data(ATOMS, d)
        else:
            for d in probs:
                ax.plot(ATOMS, d)
        ax.figure.canvas.draw()

        obs, _, done, _ = env.step(action)

        if done:
            obs = env.reset()
finally:
    # Always release the render window, including on KeyboardInterrupt; the
    # original exception keeps propagating with its traceback intact.
    env.close()