In [2]:
import random

import attr
import gym
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
from smooth import smooth  # timeseries smoothing function
import torch
from torch import nn
import torch.nn.functional as F

# pin rng seeds:
random.seed(0)
np.random.seed(0)  
cartpole = gym.make('CartPole-v1')
lunarlander = gym.make('LunarLander-v2')
plt.style.use('seaborn-white')
dtype = torch.float
cuda = 'cpu'

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [3]:
@attr.s
class ValueDistribution(torch.nn.Module):
    state_shape = attr.ib()
    action_shape = attr.ib()
    vmin = attr.ib()
    vmax = attr.ib()
    num_atoms = attr.ib(default=51)
    num_hidden1_units = attr.ib(default=128)
    num_hidden2_units = attr.ib(default=128)
    
    def __attrs_post_init__(self):
        super().__init__()
        self.atoms = torch.linspace(self.vmin, self.vmax, self.num_atoms)
        self.linear1 = nn.Linear(self.state_shape, self.num_hidden1_units)
        self.linear2 = nn.Linear(self.num_hidden1_units, self.num_hidden2_units)
        self.out = nn.Linear(self.num_hidden2_units, self.action_shape * self.num_atoms)
        
    def forward(self, x):
        """ Return (actions x atoms). """
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.out(x))  
        x = x.reshape(-1, self.action_shape, self.num_atoms)   
        out = F.softmax(x, dim=2)  # (actions x atoms) 
        batch_size = x.size(0)
        assert out.size() == torch.Size((batch_size, self.action_shape, self.num_atoms))
        return out
    
    def predict_action_values(self, state):
        """ Return (batch-size x actions). """
        distribution = self.forward(state)
        weighted_distribution = distribution * self.atoms
        out = weighted_distribution.sum(dim=2).squeeze()  # (batch-size x actions)
        batch_size = state.size(0)
        assert out.size() == torch.Size((batch_size, self.action_shape))
        return out

In [317]:
states = torch.tensor([    
    [1.0, 2, 3, 4, 5, 6, 7, 8],
    [-4, -3, -1, -5, -5,  3,  1,  3],
    [ 4,  1, -3, -5,  1, -4, -2, -4],
])
actions = torch.tensor([0, 4, 1]) # num actions: 5
rewards = torch.tensor([-7.0,  4,  3,])
dones = torch.tensor([1, 0, 0], dtype=torch.long)
states_ = states + 1

'''
one_state = torch.tensor([[1.0, 2, 3, 4, 5, 6, 7, 8]])
one_action = torch.tensor([3.0])
one_reward = torch.tensor([12.0])
one_done = torch.tensor([0.0])
one_state_ = one_state + 1
one_transition = (one_state, one_action, one_reward, one_state_, one_done)
'''
many_transition = (states, actions, rewards, states_, dones, )

num_atoms = 11
vmin = -10
vmax = 10 
action_shape = 5

random.seed(0)
np.random.seed(0)  
online_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net.load_state_dict(online_net.state_dict())

'\none_state = torch.tensor([[1.0, 2, 3, 4, 5, 6, 7, 8]])\none_action = torch.tensor([3.0])\none_reward = torch.tensor([12.0])\none_done = torch.tensor([0.0])\none_state_ = one_state + 1\none_transition = (one_state, one_action, one_reward, one_state_, one_done)\n'

In [288]:
target_net.forward(states_).shape

torch.Size([3, 5, 11])

In [455]:
# Mine

def categorical_loss(online_net, target_net, transitions, discount):
    states, actions, rewards, states_, dones = transitions    
    not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
    atoms = target_net.atoms
    probabilities = target_net.forward(states_)
    Q_x_ = (probabilities * atoms).sum(2)
    batch_size = states.shape[0]
    assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
    a_star = Q_x_.argmax(dim=1) 
    assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'
    m = torch.zeros(batch_size, target_net.num_atoms, dtype=torch.float)

    delta_z = (target_net.vmax - target_net.vmin)/(target_net.num_atoms - 1)    
    # select only the probabilities distributions for the a_star actions:
    probabilities = probabilities[range(batch_size), a_star]
    #print(f'probabilities: {probabilities}')
    for idx, atom in enumerate(atoms):
        T_zj = (rewards + discount * atom * not_dones).clamp(min=target_net.vmin, max=target_net.vmax)
        b_j = (T_zj - target_net.vmin)/delta_z

        lo = b_j.floor()
        hi = b_j.ceil()
        m[range(batch_size), lo.to(torch.long)] += probabilities[:, idx] * (hi - b_j)
        m[range(batch_size), hi.to(torch.long)] += probabilities[:, idx] * (b_j - lo)
        # print(f'{idx}: b_j: {b_j}, m: {m}')
    return m
    # return torch.sum(-m * (online_net.forward(states)[:, actions]).log())

In [454]:
def categorical_loss_vectorized(online_net, target_net, transitions, discount):
    # states, actions, rewards, states_, dones = many_transition    
    states, actions, rewards, states_, dones = transitions
    not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
    atoms = target_net.atoms
    probabilities = target_net.forward(states_)
    Q_x_ = (probabilities * atoms).sum(2)
    batch_size = states.shape[0]
    assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
    a_star = Q_x_.argmax(dim=1) 
    assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'

    delta_z = (target_net.vmax - target_net.vmin)/(target_net.num_atoms - 1)    
    # select only the probabilities distributions for the a_star actions:
    probabilities = probabilities[range(batch_size), a_star]
    T_zj = rewards.unsqueeze(1) + discount * atoms * not_dones.unsqueeze(1)
    b_j = (T_zj.clamp(target_net.vmin, target_net.vmax) - target_net.vmin) / delta_z  # correct
    lo = b_j.floor()
    hi = b_j.ceil()
    m = torch.zeros(batch_size, target_net.num_atoms, dtype=torch.float)
    row_idxs = torch.range(0, batch_size-1, dtype=torch.long).unsqueeze(1).expand_as(m)
    m[row_idxs, lo.type(torch.long)] += probabilities * (hi - b_j)
    m[row_idxs, hi.type(torch.long)] += probabilities * (b_j - lo)
    return m

In [456]:
discount = 0.95
mine = categorical_loss(online_net, target_net, many_transition, discount)
'--'
mine
'---- Vect:'

mine_vectorize = categorical_loss_vectorized(online_net, target_net, many_transition, discount)
'--'
mine_vectorize

# '--'
'Hengyuan-hu:'
agent = DistributionalDQNAgent(online_net, False, action_shape, num_atoms, online_net.vmin, online_net.vmax)
not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
agent.compute_targets(rewards, states_, not_dones, discount)

'--'

tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0604,  0.0929,  0.0914,  0.0855,  0.0846,
          0.0101,  0.1228,  0.0888,  0.0925],
        [ 0.0000,  0.0215,  0.0896,  0.0882,  0.0882,  0.1026,  0.1061,
          0.0963,  0.0946,  0.0882,  0.0294]])

'---- Vect:'

'--'

tensor([[ 0.0000,  0.0539,  0.0539,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0604,  0.0929,  0.0914,  0.0855,  0.0846,
          0.0060,  0.1228,  0.0888,  0.0000],
        [ 0.0000,  0.0215,  0.0896,  0.0882,  0.0882,  0.1026,  0.1061,
          0.0963,  0.0946,  0.0882,  0.0000]])

'Hengyuan-hu:'



tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0604,  0.0929,  0.0914,  0.0855,  0.0846,
          0.1053,  0.1228,  0.0888,  0.2682],
        [ 0.0000,  0.0215,  0.0896,  0.0882,  0.0882,  0.1026,  0.1061,
          0.0963,  0.0946,  0.0882,  0.2247]])

In [371]:
# HiggsField's projection implmentation:
# these are mini-batch sized tensors of nexts, rewards, dones:
def projection_distribution(target_model, next_state, rewards, dones, discount):
    batch_size = next_state.size(0)
    Vmax = target_model.vmax
    Vmin = target_model.vmin
    num_atoms = target_model.num_atoms
    delta_z = float(Vmax - Vmin) / (num_atoms - 1)
    support = torch.linspace(Vmin, Vmax, num_atoms)

    next_dist = target_model(next_state).data.cpu() * support
    next_action = next_dist.sum(2).max(1)[1]
    next_action = next_action.unsqueeze(1).unsqueeze(1).expand(next_dist.size(0), 1, next_dist.size(2))
    next_dist = next_dist.gather(1, next_action).squeeze(1)

    rewards = rewards.unsqueeze(1).expand_as(next_dist)
    dones = dones.unsqueeze(1).expand_as(next_dist)
    support = support.unsqueeze(0).expand_as(next_dist)

    Tz = rewards + (1 - dones) * discount * support
    Tz = Tz.clamp(min=Vmin, max=Vmax)
    b = (Tz - Vmin) / delta_z
    l = b.floor().long()
    u = b.ceil().long()
    offset = torch.linspace(
        0,  # start
        (batch_size - 1) * num_atoms,  # end
        batch_size  # steps
    ).long(  # cast to int
    ).unsqueeze(1  # "new tensor with a dim of size one inserted at the specified position."
                   # Basically turns a row vect into a col vect.
                   # (batch_size, 1): [[0], [51], [102], [153]]
    ).expand(batch_size, num_atoms)  # (batch_size, 1) -> (batch_size, num_atoms), copying values.
    proj_dist = torch.zeros(next_dist.size())
    proj_dist.view(-1
        ).index_add_(
        0,
        (l + offset).view(-1),
        (next_dist * (u.float() - b)).view(-1)
    )
    proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))

    return proj_dist  # this is `m` in the paper.

In [372]:
import copy
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np


EPS = 1e-7


def assert_eq(real, expected):
    assert real == expected, '%s (true) vs %s (expected)' % (real, expected)


def one_hot(x, n):
    assert x.dim() == 2
    # one_hot_x = torch.zeros(x.size(0), n).cuda()
    one_hot_x = torch.zeros(x.size(0), n)
    one_hot_x.scatter_(1, x, 1)
    return one_hot_x


class DQNAgent(object):
    def __init__(self, q_net, double_dqn, num_actions):
        self.online_q_net = q_net
        self.target_q_net = copy.deepcopy(q_net)
        self.double_dqn = double_dqn
        self.num_actions = num_actions

    def save_q_net(self, prefix):
        torch.save(self.online_q_net.state_dict(), prefix+'online_q_net.pth')

    def parameters(self):
        return self.online_q_net.parameters()

    def sync_target(self):
        self.target_q_net = copy.deepcopy(self.online_q_net)

    def target_q_values(self, states):
        q_vals = self.target_q_net(Variable(states, volatile=True)).data
        return q_vals

    def online_q_values(self, states):
        q_vals = self.online_q_net(Variable(states, volatile=True)).data
        return q_vals

    def compute_targets(self, rewards, next_states, non_ends, gamma):
        """Compute batch of targets for dqn
        params:
            rewards: Tensor [batch]
            next_states: Tensor [batch, channel, w, h]
            non_ends: Tensor [batch]
            gamma: float
        """
        next_q_vals = self.target_q_values(next_states)

        if self.double_dqn:
            next_actions = self.online_q_values(next_states).max(1, True)[1]
            next_actions = one_hot(next_actions, self.num_actions)
            next_qs = (next_q_vals * next_actions).sum(1)
        else:
            next_qs = next_q_vals.max(1)[0] # max returns a pair

        targets = rewards + gamma * next_qs * non_ends
        return targets

    def loss(self, states, actions, targets):
        """
        params:
            states: Variable [batch, channel, w, h]
            actions: Variable [batch, num_actions] one hot encoding
            targets: Variable [batch]
        """
        assert_eq(actions.size(1), self.num_actions)

        qs = self.online_q_net(states)
        preds = (qs * actions).sum(1)
        err = nn.functional.smooth_l1_loss(preds, targets)
        return err


class DistributionalDQNAgent(DQNAgent):
    def __init__(self, q_net, double_dqn, num_actions, num_atoms, vmin, vmax):
        super(DistributionalDQNAgent, self).__init__(q_net, double_dqn, num_actions)

        self.num_atoms = num_atoms
        self.vmin = float(vmin)
        self.vmax = float(vmax)

        self.delta_z = (self.vmax - self.vmin) / (num_atoms - 1)

        zpoints = np.linspace(vmin, vmax, num_atoms).astype(np.float32)
        # self.zpoints = Variable(torch.from_numpy(zpoints).unsqueeze(0)).cuda()
        self.zpoints = Variable(torch.from_numpy(zpoints).unsqueeze(0))

    def _q_values(self, q_net, states):
        """internal function to compute q_value
        params:
            q_net: self.online_q_net or self.target_q_net
            states: Variable [batch, channel, w, h]
        """
        probs = q_net(states) # [batch, num_actions, num_atoms]
        q_vals = (probs * self.zpoints).sum(2)
        return q_vals, probs

    def target_q_values(self, states):
        states = Variable(states, volatile=True)
        q_vals, _ = self._q_values(self.target_q_net, states)
        return q_vals.data

    def online_q_values(self, states):
        states = Variable(states, volatile=True)
        q_vals, _ = self._q_values(self.online_q_net, states)
        return q_vals.data

    def compute_targets(self, rewards, next_states, non_ends, gamma):
        """Compute batch of targets for distributional dqn
        params:
            rewards: Tensor [batch, 1]
            next_states: Tensor [batch, channel, w, h]
            non_ends: Tensor [batch, 1]
            gamma: float
        """
        assert not self.double_dqn, 'not supported yet'

        # get next distribution
        next_states = Variable(next_states, volatile=True)
        # [batch, num_actions], [batch, num_actions, num_atoms]
        next_q_vals, next_probs = self._q_values(self.target_q_net, next_states)
        next_actions = next_q_vals.data.max(1, True)[1] # [batch, 1]
        next_actions = one_hot(next_actions, self.num_actions).unsqueeze(2)
        next_greedy_probs = (next_actions * next_probs.data).sum(1)

        # transform the distribution
        rewards = rewards.unsqueeze(1)
        non_ends = non_ends.unsqueeze(1)
        proj_zpoints = rewards + gamma * non_ends * self.zpoints.data
        proj_zpoints.clamp_(self.vmin, self.vmax)

        # project onto shared support
        b = (proj_zpoints - self.vmin) / self.delta_z
        lower = b.floor()
        upper = b.ceil()
        # handle corner case where b is integer
        eq = (upper == lower).float()
        lower -= eq
        lt0 = (lower < 0).float()
        lower += lt0
        upper += lt0

        # note: it's faster to do the following on cpu
        ml = (next_greedy_probs * (upper - b)).cpu().numpy()
        mu = (next_greedy_probs * (b - lower)).cpu().numpy()

        lower = lower.cpu().numpy().astype(np.int32)
        upper = upper.cpu().numpy().astype(np.int32)

        batch_size = rewards.size(0)
        mass = np.zeros((batch_size, self.num_atoms), dtype=np.float32)
        brange = range(batch_size)
        for i in range(self.num_atoms):
            mass[brange, lower[brange, i]] += ml[brange, i]
            mass[brange, upper[brange, i]] += mu[brange, i]

        return torch.from_numpy(mass)
        # return torch.from_numpy(mass).cuda()

    def loss(self, states, actions, targets):
        """
        params:
            states: Variable [batch, channel, w, h]
            actions: Variable [batch, num_actions] one hot encoding
            targets: Variable [batch, num_atoms]
        """
        assert_eq(actions.size(1), self.num_actions)

        actions = actions.unsqueeze(2)
        probs = self.online_q_net(states)  # [batch, num_actions, num_atoms]
        probs = (probs * actions).sum(1)  # [batch, num_atoms]
        xent = -(targets * torch.log(probs.clamp(min=EPS))).sum(1)
        xent = xent.mean(0)
        return xent
