In [574]:
from collections import deque
import random

import attr
import gym
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
from pytorch_monitor import init_experiment, monitor_module
# from smooth import smooth  # timeseries smoothing function
import torch
from torch import nn
import torch.nn.functional as F


# Gym environments used in this notebook.
cartpole = gym.make('CartPole-v1')
lunarlander = gym.make('LunarLander-v2')

# Global matplotlib look.
plt.style.use('seaborn-white')

# Settings handed to pytorch_monitor's init_experiment; only the first
# returned value (bound to `writer`) is kept.
monitor_config = dict(title='Test Monitor', log_dir='test', random_seed=0)
writer, _ = init_experiment(monitor_config)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [577]:
@attr.s
class Memory(deque):
    """ Experience Replay Memory class.

    A deque that `append` keeps bounded to at most `size` items by evicting
    from the left (oldest first), and from which `sample` draws random
    minibatches of up to `minibatch_size` items.
    """
    size = attr.ib()
    minibatch_size = attr.ib()

    def append(self, thing):
        """ Add `thing` on the right, evicting the oldest items if full.

        A non-positive `size` stores nothing (the same convention as
        `deque(maxlen=0)`) instead of failing on `popleft()` from an empty
        deque.
        """
        if self.size <= 0:
            return None
        # Loop rather than a single `if`: this also restores the bound when
        # the deque was filled past `size` through another method (e.g. extend).
        while len(self) >= self.size:
            self.popleft()
        return super().append(thing)

    def sample(self):
        """ Return a list of min(len(self), minibatch_size) distinct stored items. """
        batch_size = min(len(self), self.minibatch_size)
        # Sample from a list snapshot: random.sample indexes its population,
        # and indexing into the middle of a deque is O(n) per lookup.
        return random.sample(list(self), batch_size)


# eq=False: keep nn.Module's identity-based __eq__/__hash__. A plain @attr.s
# generates __eq__ and sets __hash__ to None, which makes the module unhashable
# and breaks nn.Module internals that keep modules in sets (e.g. parameters()).
@attr.s(eq=False)
class ValueDistribution(torch.nn.Module):
    """ Categorical value-distribution network.

    For every action the network outputs a probability distribution over
    `num_atoms` fixed support points ("atoms") evenly spaced on [vmin, vmax].

    Attributes:
        state_shape: number of input features of a state.
        action_shape: number of discrete actions.
        vmin, vmax: lower / upper end of the atom support.
        num_atoms: number of support points.
        num_hidden1_units, num_hidden2_units: widths of the two hidden layers.
    """
    state_shape = attr.ib()
    action_shape = attr.ib()
    vmin = attr.ib()
    vmax = attr.ib()
    num_atoms = attr.ib(default=51)
    num_hidden1_units = attr.ib(default=128)
    num_hidden2_units = attr.ib(default=128)

    def __attrs_post_init__(self):
        # attrs has already assigned the plain-value fields above; nn.Module
        # must be initialised before any sub-module is attached.
        super().__init__()
        # NOTE(review): `atoms` is a plain tensor attribute, not a registered
        # buffer, so it is not part of state_dict() and is not moved by .to().
        self.atoms = torch.linspace(self.vmin, self.vmax, self.num_atoms)
        self.linear1 = nn.Linear(self.state_shape, self.num_hidden1_units)
        self.linear2 = nn.Linear(self.num_hidden1_units, self.num_hidden2_units)
        self.out = nn.Linear(self.num_hidden2_units, self.action_shape * self.num_atoms)

    def forward(self, x):
        """ Return probabilities of shape (batch-size x actions x atoms).

        `x` is either a batch of states (batch-size x state_shape) or a single
        1-D state, which is treated as a batch of one.
        """
        x1 = F.relu(self.linear1(x))
        x2 = F.relu(self.linear2(x1))
        # NOTE(review): the ReLU makes every logit >= 0 before the softmax;
        # kept as in the original -- confirm this is intended.
        x3 = F.relu(self.out(x2))
        x3 = x3.reshape(-1, self.action_shape, self.num_atoms)
        out = F.softmax(x3, dim=2)  # each (action) row sums to 1 over atoms
        # For a 1-D state, x.size(0) is the feature count, not the batch size.
        batch_size = x.size(0) if x.dim() > 1 else 1
        assert out.size() == torch.Size((batch_size, self.action_shape, self.num_atoms))
        if hasattr(self, 'monitor'):
            self.monitor('x1', x1, track_data=True, track_grad=True)
            self.monitor('x2', x2, track_data=True, track_grad=True)
            self.monitor('x3', x3, track_data=True, track_grad=True)
            self.monitor('out', out, track_data=True, track_grad=True)
        return out

    def predict_action_values(self, states):
        """ Return expected values: (batch-size x actions), or (actions,) for a 1-D state. """
        distribution = self.forward(states)
        # Expected value per action = sum over atoms of probability * atom value.
        out = (distribution * self.atoms).sum(dim=2)  # (batch-size x actions)
        if states.dim() > 1:
            assert out.size() == torch.Size((states.size(0), self.action_shape))
        else:
            # Drop only the artificial batch axis; a bare squeeze() would also
            # collapse a real batch of one or a single-action axis.
            out = out.squeeze(0)
            assert out.size() == torch.Size((self.action_shape,))
        return out

    def batch_actions(self, states):
        """ Return the greedy action index for every state in a 2-D batch. """
        values = self.predict_action_values(states)
        actions = values.argmax(1)
        assert actions.size(0) == states.size(0)
        return actions

    def single_action(self, state):
        """ Return the greedy action index (0-d tensor) for a single state. """
        values = self.predict_action_values(state)
        action = values.argmax()
        return action

In [576]:
# Online network: built from action_shape, vmin, vmax and num_atoms, which
# must already be defined by an earlier cell.
online_net = ValueDistribution(
    state_shape=8,
    action_shape=action_shape,
    vmin=vmin,
    vmax=vmax,
    num_atoms=num_atoms,
    num_hidden1_units=7,
    num_hidden2_units=7,
)
# Target network: same configuration as the online network.
target_net = ValueDistribution(
    state_shape=8,
    action_shape=action_shape,
    vmin=vmin,
    vmax=vmax,
    num_atoms=num_atoms,
    num_hidden1_units=7,
    num_hidden2_units=7,
)
# Start the target network from the online network's weights.
target_net.load_state_dict(online_net.state_dict())

In [513]:
def categorical_loss(online_net, target_net, transitions, discount):
    """ Cross-entropy between the projected target distribution and the online prediction.

    Args:
        online_net: network being trained; its distribution for the taken
            actions is the prediction and receives the gradient.
        target_net: network providing the bootstrap distribution; also the
            source of `atoms`, `vmin`, `vmax`, `num_atoms`, `action_shape`.
        transitions: tuple (states, actions, rewards, states_, dones) of
            batched tensors; actions/rewards/dones are assumed to have shape
            (batch,) -- TODO confirm against the caller.
        discount: discount factor applied to the atom values.

    Returns:
        Scalar tensor: mean over all (batch x atoms) entries of
        -log(p_online) * m, where m is the projected target distribution.
    """
    states, actions, rewards, states_, dones = transitions
    batch_size = states.shape[0]
    num_atoms = target_net.num_atoms
    vmin, vmax = target_net.vmin, target_net.vmax
    atoms = target_net.atoms
    delta_z = (vmax - vmin) / (num_atoms - 1)
    # arange replaces the deprecated, end-inclusive torch.range.
    batch_idx = torch.arange(batch_size, dtype=torch.long, device=states.device)

    # The target is a constant w.r.t. the optimisation: build it without a graph.
    with torch.no_grad():
        next_dist = target_net.forward(states_)
        Q_x_ = (next_dist * atoms).sum(2)
        assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
        a_star = Q_x_.argmax(dim=1)
        assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'

        # Keep only the distribution of the greedy next action.
        probabilities = next_dist[batch_idx, a_star]

        # Cast first so bool / uint8 `dones` both work.
        not_dones = 1.0 - dones.to(atoms.dtype)
        T_zj = rewards.to(atoms.dtype).unsqueeze(1) + discount * atoms * not_dones.unsqueeze(1)
        # Fractional atom index of every shifted atom; the second clamp guards
        # against float round-off pushing an index just outside [0, num_atoms - 1].
        b_j = ((T_zj.clamp(vmin, vmax) - vmin) / delta_z).clamp(0, num_atoms - 1)
        lo = b_j.floor()
        hi = b_j.ceil()
        # Weights written so that an integer b_j (lo == hi) keeps its whole
        # mass on `lo`; (hi - b_j) and (b_j - lo) would both be 0 there.
        upper_weight = b_j - lo
        lower_weight = 1.0 - upper_weight

        m = torch.zeros_like(probabilities)
        # scatter_add_ accumulates when several atoms land on the same index;
        # `m[rows, idx] += ...` keeps only one of the colliding writes.
        m.scatter_add_(1, lo.long(), probabilities * lower_weight)
        m.scatter_add_(1, hi.long(), probabilities * upper_weight)

    # Prediction comes from the ONLINE network, for the actions actually taken.
    predicted = online_net.forward(states)[batch_idx, actions.long()]
    # Clamp so that a probability underflowing to 0 cannot produce 0 * -inf = NaN.
    return (-predicted.clamp(min=1e-8).log() * m).mean()