In [573]:
from collections import deque
import random

import attr
import gym
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
from pytorch_monitor import init_experiment, monitor_module
# from smooth import smooth  # timeseries smoothing function
import torch
from torch import nn
import torch.nn.functional as F

# pin rng seeds:
random.seed(0)
np.random.seed(0)  
cartpole = gym.make('CartPole-v1')
lunarlander = gym.make('LunarLander-v2')
plt.style.use('seaborn-white')
dtype = torch.float
cuda = 'cpu'

monitor_config = {
    'title':'Test Monitor',
    'log_dir':'test',
    'random_seed': 0
}
writer, _ = init_experiment(monitor_config)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [546]:
@attr.s
class Memory(deque):
    """ Experience Replay Memory class. """
    size = attr.ib()
    minibatch_size = attr.ib()

    def append(self, thing):
        if len(self) > self.size - 1:
            self.popleft()
        return super().append(thing)

    def sample(self):
        batch_size = min(len(self), self.minibatch_size)
        return random.sample(self, batch_size)


@attr.s
class ValueDistribution(torch.nn.Module):
    state_shape = attr.ib()
    action_shape = attr.ib()
    vmin = attr.ib()
    vmax = attr.ib()
    num_atoms = attr.ib(default=51)
    num_hidden1_units = attr.ib(default=128)
    num_hidden2_units = attr.ib(default=128)
    
    def __attrs_post_init__(self):
        super().__init__()
        self.atoms = torch.linspace(self.vmin, self.vmax, self.num_atoms)
        self.linear1 = nn.Linear(self.state_shape, self.num_hidden1_units)
        self.linear2 = nn.Linear(self.num_hidden1_units, self.num_hidden2_units)
        self.out = nn.Linear(self.num_hidden2_units, self.action_shape * self.num_atoms)
        
    def forward(self, x):
        """ Return (actions x atoms). """
        x1 = F.relu(self.linear1(x))
        x2 = F.relu(self.linear2(x1))
        x3 = F.relu(self.out(x2))
        x3 = x3.reshape(-1, self.action_shape, self.num_atoms)
        out = F.softmax(x3, dim=2)  # (actions x atoms)
        batch_size = x.size(0)
        assert out.size() == torch.Size((batch_size, self.action_shape, self.num_atoms))
        if hasattr(self, 'monitor'):
            self.monitor('x1', x1, track_data=True, track_grad=True)
            self.monitor('x2', x2, track_data=True, track_grad=True)
            self.monitor('x3', x3, track_data=True, track_grad=True)
            self.monitor('out', out, track_data=True, track_grad=True)
        return out
    
    def predict_action_values(self, states):
        """ Return (batch-size x actions). """
        distribution = self.forward(states)
        weighted_distribution = distribution * self.atoms
        out = weighted_distribution.sum(dim=2).squeeze()  # (batch-size x actions)
        dims = states.dim()
        if dims > 1:
            assert out.size() == torch.Size((batch_size, self.action_shape))
        else:
            assert out.size() == torch.Size((self.action_shape,))
        return out
    
    def batch_actions(self, states):
        values = self.predict_action_values(states)
        actions = values.argmax(1)
        assert actions.size(0) == states.size(0)
        return actions
    
    def single_action(self, state):
        values = self.predict_action_values(state)
        action = values.argmax()
        return action
        

In [542]:
states = torch.tensor([    
    [1.0, 2, 3, 4, 5, 6, 7, 8],
    [-4, -3, -1, -5, -5,  3,  1,  3],
    [ 4,  1, -3, -5,  1, -4, -2, -4],
])
actions = torch.tensor([0, 4, 1]) # num actions: 5
rewards = torch.tensor([-7.0,  4,  3,])
dones = torch.tensor([1, 0, 0], dtype=torch.long)
states_ = states + 1
many_transition = (states, actions, rewards, states_, dones, )

num_atoms = 11
vmin = -10
vmax = 10 
action_shape = 5

random.seed(0)
np.random.seed(0)  
online_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net.load_state_dict(online_net.state_dict())

'\none_state = torch.tensor([[1.0, 2, 3, 4, 5, 6, 7, 8]])\none_action = torch.tensor([3.0])\none_reward = torch.tensor([12.0])\none_done = torch.tensor([0.0])\none_state_ = one_state + 1\none_transition = (one_state, one_action, one_reward, one_state_, one_done)\n'

In [494]:
def categorical_loss(online_net, target_net, transitions, discount):
    states, actions, rewards, states_, dones = transitions    
    not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
    atoms = target_net.atoms
    probabilities = target_net.forward(states_)
    Q_x_ = (probabilities * atoms).sum(2)
    batch_size = states.shape[0]
    assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
    a_star = Q_x_.argmax(dim=1) 
    assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'
    m = torch.zeros(batch_size, target_net.num_atoms, dtype=torch.float)

    delta_z = (target_net.vmax - target_net.vmin)/(target_net.num_atoms - 1)    
    # select only the probabilities distributions for the a_star actions:
    probabilities = probabilities[range(batch_size), a_star]
    #print(f'probabilities: {probabilities}')
    for idx, atom in enumerate(atoms):
        T_zj = (rewards + discount * atom * not_dones).clamp(min=target_net.vmin, max=target_net.vmax)        
        b_j = (T_zj - target_net.vmin)/delta_z
        lo = b_j.floor()
        hi = b_j.ceil()
        print(f'{idx}: {lo}, {hi}')        
        m[range(batch_size), lo.to(torch.long)] += probabilities[:, idx] * (hi - b_j)
        m[range(batch_size), hi.to(torch.long)] += probabilities[:, idx] * (b_j - lo)
        # print(f'{idx}: b_j: {b_j}, m: {m}')
    return m
    # return torch.sum(-m * (online_net.forward(states)[:, actions]).log())

In [513]:
def categorical_loss_vectorized(online_net, target_net, transitions, discount):
    # states, actions, rewards, states_, dones = many_transition    
    states, actions, rewards, states_, dones = transitions
    # not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
    not_dones = (1 - dones).type(torch.FloatTensor)
    atoms = target_net.atoms
    probabilities = target_net.forward(states_)
    Q_x_ = (probabilities * atoms).sum(2)
    batch_size = states.shape[0]
    assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
    a_star = Q_x_.argmax(dim=1) 
    assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'

    delta_z = (target_net.vmax - target_net.vmin)/(target_net.num_atoms - 1)    
    # select only the probabilities distributions for the a_star actions:
    probabilities = probabilities[range(batch_size), a_star]
    T_zj = rewards.unsqueeze(1) + discount * atoms * not_dones.unsqueeze(1)
    b_j = (T_zj.clamp(target_net.vmin, target_net.vmax) - target_net.vmin) / delta_z  # correct
    lo = b_j.floor()
    hi = b_j.ceil()
    m = torch.zeros(batch_size, target_net.num_atoms, dtype=torch.float)
    row_idxs = torch.range(0, batch_size-1, dtype=torch.long).unsqueeze(1).expand_as(m)
    m[row_idxs, lo.type(torch.long)] += (probabilities * (hi - b_j))
    m[row_idxs, hi.type(torch.long)] += (probabilities * (b_j - lo))
    return (-(target_net.forward(states)[range(3), actions]).log() * m).mean()

In [476]:
# Speed and sanity checks:

discount = 0.95
%timeit mine = categorical_loss(online_net, target_net, many_transition, discount)
'--'
mine
'---- Vect:'

%timeit mine_vectorize = categorical_loss_vectorized(online_net, target_net, many_transition, discount)
'--'
mine_vectorize

# Higgs:

%timeit projection_distribution(target_net, states, rewards, dones.type(torch.float), discount)

# '--'
'Hengyuan-hu:'
agent = DistributionalDQNAgent(online_net, False, action_shape, num_atoms, online_net.vmin, online_net.vmax)
not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
agent.compute_targets(rewards, states_, not_dones, discount)

# ----
"Floringo's:"
%timeit m = floringo_agent._get_categorical(states_, rewards.unsqueeze(1), (1 - dones).unsqueeze(1))
m

1.7 ms ± 92.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


'--'

tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0604,  0.0929,  0.0914,  0.0855,  0.0846,
          0.0101,  0.1228,  0.0888,  0.0925],
        [ 0.0000,  0.0215,  0.0896,  0.0882,  0.0882,  0.1026,  0.1061,
          0.0963,  0.0946,  0.0882,  0.0294]])

'---- Vect:'

415 µs ± 214 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


'--'

tensor([[ 0.0000,  0.0539,  0.0539,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0604,  0.0929,  0.0914,  0.0855,  0.0846,
          0.0060,  0.1228,  0.0888,  0.0000],
        [ 0.0000,  0.0215,  0.0896,  0.0882,  0.0882,  0.1026,  0.1061,
          0.0963,  0.0946,  0.0882,  0.0000]])

276 µs ± 4.57 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


'Hengyuan-hu:'



tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0604,  0.0929,  0.0914,  0.0855,  0.0846,
          0.1053,  0.1228,  0.0888,  0.2682],
        [ 0.0000,  0.0215,  0.0896,  0.0882,  0.0882,  0.1026,  0.1061,
          0.0963,  0.0946,  0.0882,  0.2247]])