In [1]:
from collections import deque
import math
import random

import attr
import gym
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
from pytorch_monitor import init_experiment, monitor_module
# from smooth import smooth  # timeseries smoothing function
import torch
from torch import nn
import torch.nn.functional as F


cartpole = gym.make('CartPole-v1')
lunarlander = gym.make('LunarLander-v2')
plt.style.use('seaborn-white')

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'  

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


# TODO: add device spec for every tensor

In [2]:
@attr.s
class Memory(deque):
    """ Experience Replay Memory class. """
    size = attr.ib()
    minibatch_size = attr.ib()

    def append(self, thing):
        if len(self) > self.size - 1:
            self.popleft()
        return super().append(thing)

    def sample(self):
        batch_size = min(len(self), self.minibatch_size)
        data = random.sample(self, batch_size)
        states = torch.stack([record[0] for record in data])
        actions = torch.tensor([record[1] for record in data], dtype=torch.long)
        rewards = torch.tensor([record[2] for record in data], dtype=torch.float)
        states_ = torch.stack([record[3] for record in data])
        dones = torch.tensor([record[4] for record in data], dtype=torch.long)
        return (states, actions, rewards, states_, dones)


class ValueDistribution(torch.nn.Module):
    def __init__(self, state_shape, action_shape, vmin, vmax, num_atoms=51, num_hidden1_units=64, num_hidden2_units=64):
        super().__init__()
        self.state_shape = state_shape
        self.action_shape = action_shape
        self.vmin = vmin
        self.vmax = vmax
        self.num_atoms = num_atoms
        self.atoms = torch.linspace(self.vmin, self.vmax, self.num_atoms)
        self.linear1 = nn.Linear(self.state_shape, num_hidden1_units)
        self.linear2 = nn.Linear(num_hidden1_units, num_hidden2_units)
        self.linear3 = nn.Linear(num_hidden2_units, num_hidden2_units)
        self.linear4 = nn.Linear(num_hidden2_units, self.action_shape * self.num_atoms)
        
    def forward(self, x):
        """ Return (actions x atoms). """
        x1 = F.selu(self.linear1(x))
        x2 = F.selu(self.linear2(x1))
        x3 = F.selu(self.linear3(x2))
        x4 = self.linear4(x3).reshape(-1, self.action_shape, self.num_atoms)
        out = F.softmax(x4, dim=2)  # (actions x atoms)
        if x.dim() == 1:
            batch_size = 1
        else:
            batch_size = x.size(0)
        assert out.size() == torch.Size((batch_size, self.action_shape, self.num_atoms))
        if hasattr(self, 'monitor'):
            self.monitor('x1', x1, track_data=True, track_grad=True)
            self.monitor('x2', x2, track_data=True, track_grad=True)
            self.monitor('x3', x3, track_data=True, track_grad=True)
            self.monitor('x4', x4, track_data=True, track_grad=True)
            self.monitor('out', out, track_data=True, track_grad=True)
        return out
    
    def predict_action_values(self, states):
        """ Return (batch-size x actions). """
        distribution = self.forward(states)
        weighted_distribution = distribution * self.atoms
        out = weighted_distribution.sum(dim=2).squeeze()  # (batch-size x actions)
        dims = states.dim()
        assert out.size() == torch.Size((self.action_shape,))
        return out
        
    def get_action(self, state):        
        values = self.predict_action_values(state)
        action = values.argmax()
        return action


In [81]:
def categorical_vectorized_loss(online_net, target_net, transitions, discount): 
    states, actions, rewards, states_, dones = transitions
    not_dones = (1 - dones).type(torch.FloatTensor)
    atoms = target_net.atoms
    probabilities = target_net.forward(states_)
    Q_x_ = (probabilities * atoms).sum(2)
    batch_size = states.shape[0]
    assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
    a_star = Q_x_.argmax(dim=1) 
    assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'
    
    # compute the projected probability:
    delta_z = (target_net.vmax - target_net.vmin)/(target_net.num_atoms - 1)    
    # select only the probabilities distributions for the a_star actions:
    probabilities = probabilities[range(batch_size), a_star]
    T_zj = rewards.unsqueeze(1) + discount * atoms * not_dones.unsqueeze(1)
    b_j = (T_zj.clamp(target_net.vmin, target_net.vmax) - target_net.vmin) / delta_z  # correct


    
    lo = b_j.floor()        
    hi = b_j.ceil()
    m = torch.zeros(batch_size, target_net.num_atoms, dtype=torch.float)
    upper = target_net.num_atoms * batch_size - target_net.num_atoms
    offset = torch.range(0, upper, target_net.num_atoms).unsqueeze(1).expand_as(m)
    lo_idx = lo + offset
    hi_idx = hi + offset
    lo_component = m.view(-1).index_add(0, lo_idx.view(-1).type(torch.long), (probabilities * (hi - b_j)).view(-1)).resize_as(m)
    hi_component = m.view(-1).index_add(0, hi_idx.view(-1).type(torch.long), (probabilities * (b_j - lo)).view(-1)).resize_as(m)
    m += lo_component + hi_component       
    # return m
    # x enthropy is Sigma <true> log <unnatural>, so for us is: target log(online)
    online_distribution = online_net.forward(states)[range(batch_size), actions]
    return -( m * online_distribution.log() ).sum(1).mean()

In [84]:
@attr.s
class CategoricalAgent:
    env = attr.ib()
    discount = attr.ib(default=0.99)
    epsilon_max = attr.ib(default=1.0)
    epsilon_min = attr.ib(default=0.01)
    annealing_const = attr.ib(default=.001)  # aka Lambda
    minibatch_size = attr.ib(default=32)
    memory_size = attr.ib(default=int(1e6))
    num_episodes = attr.ib(default=1000)  # num of episodes in a training epoch
    render_every = attr.ib(default=20)  # set to zero to turn off rendering
    update_target_every = attr.ib(default=200)
    vmin = attr.ib(default=-10)
    vmax = attr.ib(default=10)
    num_atoms = attr.ib(default=51)
    learning_rate = attr.ib(default=0.000001)
    monitor_every = attr.ib(default=50)
    logger = attr.ib(default=None)
    xavier = attr.ib(default=False)
    
    def __attrs_post_init__(self):
        self.steps = 0
        state_shape = self.env.observation_space.shape[0]
        self.memory = Memory(self.memory_size, self.minibatch_size)
        self.action_shape = self.env.action_space.n
        self.online_net = ValueDistribution(state_shape=state_shape, action_shape=self.action_shape, vmin=self.vmin, vmax=self.vmax, num_atoms=self.num_atoms)
        self.target_net = ValueDistribution(state_shape=state_shape, action_shape=self.action_shape, vmin=self.vmin, vmax=self.vmax, num_atoms=self.num_atoms)
        if self.xavier:
            gain = nn.init.calculate_gain('relu')
            for param in self.online_net.parameters():
                if param.dim() < 2:
                    continue
                nn.init.xavier_normal_(param, gain=gain)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=self.learning_rate)
        self.reset_data_recorders()        

    def reset_data_recorders(self):
        self.episode_rewards = []
        self.episode_losses = []
        self.td_errors = []
        self.online_net_q_values = []
        self.target_net_q_values = []
        self.w1_gradient = []
        self.w2_gradient = []
        self.steps = 0

    def render(self, episode):
        if self.render_every and episode % self.render_every == 0:
            self.env.render()

    def training_progress_report(self, episode):
        last_ep = self.episode_rewards[-1]
        ten_ep_mean = sum(self.episode_rewards[-10:])/len(self.episode_rewards[-10:])
        hundred_ep_mean = sum(self.episode_rewards[-100:])/len(self.episode_rewards[-100:])
        return f'Ep: {episode} // steps: {self.steps} // last ep reward: {last_ep:.2f} // {min(10, len(self.episode_rewards[-10:]))}-ep mean: {ten_ep_mean:.2f} // {min(100, len(self.episode_rewards[-100:]))}-ep mean: {hundred_ep_mean:.2f}'

    def replay(self):
        batch = self.memory.sample()
        loss = categorical_vectorized_loss(self.online_net, self.target_net, batch, self.discount)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()/self.minibatch_size

    def monitor(self):
        if self.steps % self.monitor_every == 0:
            self.target_net.monitoring(True)
        else:
            self.target_net.monitoring(False)

    def train(self):
        for episode in range(self.num_episodes):
            episode_done = False
            episode_reward = 0
            episode_loss = 0
            state = torch.tensor(self.env.reset(), dtype=torch.float)
            self.target_net_q_values.append(self.target_net.predict_action_values(state).max().item())
            if self.steps == 0:
                self.logger.add_graph(self.target_net, state)
            writer.add_scalar('Target net Q values', self.target_net_q_values[-1], self.steps)                
            while not episode_done:
                epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * math.exp(-self.annealing_const * self.steps)
                self.steps += 1                
                if random.random() < epsilon:
                    action = random.randint(0, self.action_shape-1)
                else:
                    action = self.online_net.get_action(state).item()
                self.render(episode)
                self.monitor()
                state_, reward, episode_done, _ = self.env.step(action)
                state_ = torch.tensor(state_, dtype=torch.float)
                episode_reward += reward
                self.memory.append((state, action, reward, state_, episode_done))
                state = state_
                if self.steps < 2:
                    continue
                episode_loss += self.replay()

                if self.steps % self.update_target_every == 0:
                    self.target_net.load_state_dict(self.online_net.state_dict())
                if episode_done:
                    self.episode_rewards.append(episode_reward)
                    print(self.training_progress_report(episode), end='\r', flush=True)
                    writer.add_scalar('train loss', episode_loss, self.steps)
                    writer.add_scalar('episode reward', episode_reward, self.steps)                    
        self.env.close()

    def test(self):
        pass

In [85]:
monitor_config = {
    'title':'Testing Floringo implementation',
    'log_dir':'categorical-dqn',
    'random_seed': 0
}
writer, monitor_config = init_experiment(monitor_config)
monitor_config
agent = CategoricalAgent(
    cartpole, 
    learning_rate=0.000000001, 
    logger=writer,
    monitor_every=10,
    num_episodes=400,
    update_target_every=100,  
    xavier=False
)
monitor_module(
    agent.target_net, writer, 
    track_data=True,
    track_grad=True,
    track_update=True,
    track_update_ratio=True
)
agent.train()

{'log_dir': 'categorical-dqn',
 'random_seed': 0,
 'run_dir': 'categorical-dqn/Jul-19-18@10:08:04-DaydreamNation.local',
 'run_name': 'Jul-19-18@10:08:04-DaydreamNation.local',
 'tag': 'Experiment Config: Testing Floringo implementation :: Jul-19-18@10:08:04\n',
 'title': 'Testing Floringo implementation'}

  .format(op_name, op_name))


Ep: 399 // steps: 4350 // last ep reward: 9.00 // 10-ep mean: 9.30 // 100-ep mean: 9.40123

In [86]:
arm0 = 2.7
arm1 = 5
arm2 = -3.4


class Bandit:
    @staticmethod
    def step(action):
        if action == 0:
            return torch.tensor([0.]), np.random.normal(arm0, .5)
        elif action == 1:
            return torch.tensor([1.]), np.random.normal(arm1, .5)
        elif action == 2:
            return torch.tensor([2.]), np.random.normal(arm2, .5)            
        else:
            raise ValueError('Invalid action', action)
            
    
@attr.s
class BanditAgent:
    num_arms = attr.ib(default=3)
    num_rounds = attr.ib(default=2000)
    discount = attr.ib(default=0.99)
    epsilon_max = attr.ib(default=1.0)
    epsilon_min = attr.ib(default=0.01)
    annealing_const = attr.ib(default=.001)  # aka Lambda
    minibatch_size = attr.ib(default=32)
    memory_size = attr.ib(default=int(1e6))
    vmin = attr.ib(default=-5)
    vmax = attr.ib(default=5)
    num_atoms = attr.ib(default=10)
    learning_rate = attr.ib(default=0.00025)
    monitor_every = attr.ib(default=5)
    logger = attr.ib(default=None)
    update_target_every = attr.ib(default=100)
    xavier = attr.ib(default=False)

    def __attrs_post_init__(self):
        self.online_net = ValueDistribution(state_shape=1, action_shape=self.num_arms, vmin=self.vmin, vmax=self.vmax, num_atoms=self.num_atoms, num_hidden1_units=40, num_hidden2_units=40,)
        self.target_net = ValueDistribution(state_shape=1, action_shape=self.num_arms, vmin=self.vmin, vmax=self.vmax, num_atoms=self.num_atoms, num_hidden1_units=40, num_hidden2_units=40,)        
        if self.xavier:
            for param in bandit_agent.online_net.parameters():
                if param.dim() < 2:
                    continue
                nn.init.xavier_normal_(param)        
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=self.learning_rate)
        self.reset_data_recorders()
        self.memory = Memory(1e6, 32)
        
    def reset_data_recorders(self):
        self.episode_rewards = []
        self.episode_losses = []
        self.q_values = []
        self.steps = 0

    def monitor(self):
        if self.steps % 20 == 0:
            print(f'{self.steps}... ', end='\r', flush=True)

        if self.steps % self.monitor_every == 0:
            self.target_net.monitoring(True)
        else:
            self.target_net.monitoring(False)
            
            
    def replay(self):
        batch = self.memory.sample()
        loss = categorical_vectorized_loss(self.online_net, self.target_net, batch, self.discount)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()/self.minibatch_size
            
    def train(self):
        state = torch.tensor([1], dtype=torch.float)
        for episode in range(self.num_rounds):            
            self.q_values.append(self.target_net.predict_action_values(state).max().item())
            if self.steps == 0:
                self.logger.add_graph(self.target_net, state)
            writer.add_scalar('Target net Q values', self.q_values[-1], self.steps)                

            epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * math.exp(-self.annealing_const * self.steps)
            self.steps += 1                
            if random.random() < epsilon:
                action = random.randint(0, self.num_arms-1)
            else:
                action = self.online_net.get_action(state).item()
            self.monitor()
            state_, reward = Bandit.step(action)
            state_ = torch.tensor(state_, dtype=torch.float)
            episode_reward = reward
            self.memory.append((state, action, reward, state_, torch.tensor([0])))
            episode_loss = self.replay()
            state = state_
            writer.add_scalar('train loss', episode_loss, self.steps)
            writer.add_scalar('episode reward', episode_reward, self.steps)   
            
            if self.steps % self.update_target_every == 0:
                self.target_net.load_state_dict(self.online_net.state_dict())
            

    def test(self):
        pass       


In [87]:


monitor_config = {
    'title':'Bandit test',
    'log_dir':'categorical-dqn',
    'random_seed': 0
}
writer, monitor_config = init_experiment(monitor_config)
bandit_agent = BanditAgent(
    logger=writer, 
    learning_rate=0.000001, 
    num_rounds=2000, 
    monitor_every=10,
    xavier=False,
)
monitor_config
monitor_module(
    bandit_agent.target_net, writer, 
    track_data=True,
    track_grad=True,
    track_update=True,
    track_update_ratio=True
)
bandit_agent.train()

{'log_dir': 'categorical-dqn',
 'random_seed': 0,
 'run_dir': 'categorical-dqn/Jul-19-18@10:11:41-DaydreamNation.local',
 'run_name': 'Jul-19-18@10:11:41-DaydreamNation.local',
 'tag': 'Experiment Config: Bandit test :: Jul-19-18@10:11:41\n',
 'title': 'Bandit test'}

20... 

  .format(op_name, op_name))


2000... 

In [239]:
arm0
arm1
arm2

bandit_agent.target_net.predict_action_values(torch.tensor([0], dtype=torch.float))
bandit_agent.target_net.predict_action_values(torch.tensor([1], dtype=torch.float))
bandit_agent.target_net.predict_action_values(torch.tensor([2], dtype=torch.float))


states, actions, rewards, states_, _ = bandit_agent.memory.sample()
rewards[states_.squeeze() == 0].mean()
rewards[states_.squeeze() == 1].mean()
rewards[states_.squeeze() == 2].mean()

2.7

5

-3.4

tensor([ 0.2444,  0.1126, -0.1673])

tensor([ 0.1686, -0.0084, -0.2471])

tensor([ 0.0947, -0.1132, -0.3562])

tensor(2.7388)

tensor(4.8597)

tensor(-3.7678)

In [210]:
bandit_agent.target_net.linear4.weight
bandit_agent.target_net.linear4.weight.grad.norm()

Parameter containing:
tensor([[-9.9555e-02, -1.2363e-01, -1.2882e-01,  ...,  4.8230e-02,
         -4.9638e-03,  1.2081e-01],
        [ 2.0693e-02, -1.5338e-01, -1.3020e-01,  ..., -1.2003e-01,
         -1.0695e-01, -1.5849e-01],
        [ 1.1170e-01,  1.3800e-01, -5.7310e-02,  ..., -1.1644e-01,
          2.8391e-02,  1.4304e-01],
        ...,
        [-4.2989e-02, -1.1564e-01,  1.4043e-01,  ..., -3.2616e-02,
         -2.2832e-02,  1.5632e-01],
        [ 2.0642e-02,  1.0494e-01,  9.0590e-02,  ..., -7.3867e-02,
          1.2260e-01, -2.3937e-02],
        [ 6.5399e-02, -2.7890e-02, -1.7782e-02,  ...,  4.1532e-02,
         -1.2335e-01, -5.9142e-02]])

tensor(518.3018)

In [6]:
import others_people_categorical as others

# Speed and sanity checks:
states = torch.tensor([    
    [1.0, 2, 3, 4, 5, 6, 7, 8],
    [-4, -3, -1, -5, -5,  3,  1,  3],
    [ 4,  1, -3, -5,  1, -4, -2, -4],
])
actions = torch.tensor([0, 4, 1]) # num actions: 5
rewards = torch.tensor([-7.0,  4,  3,])
dones = torch.tensor([1, 0, 0], dtype=torch.long)
states_ = states + 1
many_transition = (states, actions, rewards, states_, dones, )

num_atoms = 11
vmin = -10
vmax = 10 
action_shape = 5

random.seed(0)
np.random.seed(0)  
online_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net.load_state_dict(online_net.state_dict())
discount = 0.95

In [80]:
# Timing:

# '--Mine unvect--'
# "%timeit mine = categorical_loss(online_net, target_net, many_transition, discount)"

'--Mine vect--'
%timeit mine_vectorize = categorical_vectorized_loss(online_net, target_net, many_transition, discount)

"--Higgs'--"
%timeit others.projection_distribution(target_net, states, rewards, dones.type(torch.float), discount)

"--Hengyuan-hu's--"
agent = others.DistributionalDQNAgent(online_net, False, action_shape, num_atoms, online_net.vmin, online_net.vmax)
not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
%timeit agent.compute_targets(rewards, states_, not_dones, discount)

"--Floringo's--"
floringo_agent = others.CategoricalPolicyImprovement(online_net, target_net, v_min=vmin, v_max=vmax, atoms_no=num_atoms, batch_size=32)
%timeit m = floringo_agent._get_categorical(states_, rewards.unsqueeze(1), (1 - dones).unsqueeze(1))

'--Mine vect--'



636 µs ± 133 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


"--Higgs'--"

281 µs ± 4.31 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


"--Hengyuan-hu's--"

  next_states = Variable(next_states, volatile=True)


1.07 ms ± 7.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


"--Floringo's--"

362 µs ± 5.25 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [78]:


discount = 0.95

mine_vectorize = categorical_vectorized_loss(online_net, target_net, many_transition, discount)
'--Mine vectorized--'
mine_vectorize


"--Higgs'--"
#higgs = others.projection_distribution(target_net, states, rewards, dones.type(torch.float), discount)
#higgs

"--Hengyuan-hu's--"
agent = others.DistributionalDQNAgent(online_net, False, action_shape, num_atoms, online_net.vmin, online_net.vmax)
not_dones = (~dones.type(torch.ByteTensor)).type(torch.FloatTensor)
heng = agent.compute_targets(rewards, states_, not_dones, discount)
heng

"--Floringo's--"
floringo_agent = others.CategoricalPolicyImprovement(online_net, target_net, v_min=vmin, v_max=vmax, atoms_no=num_atoms, batch_size=3)
floringos = floringo_agent._get_categorical(states_, rewards.unsqueeze(1), (1 - dones).unsqueeze(1))
floringos



[21] > [33;01m<ipython-input-77-c020e9f87e77>[00m([36;01m32[00m)categorical_vectorized_loss()
-> return m
(Pdb++) m
tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0332,  0.0760,  0.1341,  0.1070,  0.0724,
          0.0057,  0.0502,  0.0643,  0.0987],
        [ 0.0000,  0.0138,  0.0635,  0.0768,  0.0810,  0.0759,  0.0595,
          0.1071,  0.1537,  0.1395,  0.0426]])
(Pdb++) b_j
tensor([[  1.5000,   1.5000,   1.5000,   1.5000,   1.5000,   1.5000,
           1.5000,   1.5000,   1.5000,   1.5000,   1.5000],
        [  2.2500,   3.2000,   4.1500,   5.1000,   6.0500,   7.0000,
           7.9500,   8.9000,   9.8500,  10.0000,  10.0000],
        [  1.7500,   2.7000,   3.6500,   4.6000,   5.5500,   6.5000,
           7.4500,   8.4000,   9.3500,  10.0000,  10.0000]])
(Pdb++) lo
tensor([[  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.,
           1.],
        [  2.,   3.,   4.,   

'--Mine vectorized--'

tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0332,  0.0760,  0.1341,  0.1070,  0.0724,
          0.0057,  0.0502,  0.0643,  0.0987],
        [ 0.0000,  0.0138,  0.0635,  0.0768,  0.0810,  0.0759,  0.0595,
          0.1071,  0.1537,  0.1395,  0.0426]])

"--Higgs'--"

"--Hengyuan-hu's--"

  next_states = Variable(next_states, volatile=True)


tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0332,  0.0760,  0.1341,  0.1070,  0.0724,
          0.0884,  0.0502,  0.0643,  0.3744],
        [ 0.0000,  0.0138,  0.0635,  0.0768,  0.0810,  0.0759,  0.0595,
          0.1071,  0.1537,  0.1395,  0.2294]])

"--Floringo's--"

tensor([[ 0.0000,  0.5000,  0.5000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0332,  0.0760,  0.1341,  0.1070,  0.0724,
          0.0884,  0.0502,  0.0643,  0.3744],
        [ 0.0000,  0.0138,  0.0635,  0.0768,  0.0810,  0.0759,  0.0595,
          0.1071,  0.1537,  0.1395,  0.2294]])

In [61]:
%debug

> [0;32m<ipython-input-59-bcfd1aab630c>[0m(57)[0;36mcategorical_vectorized_loss2[0;34m()[0m
[0;32m     55 [0;31m    [0mlo[0m [0;34m+=[0m [0moffset[0m[0;34m[0m[0m
[0m[0;32m     56 [0;31m    [0mhi[0m [0;34m+=[0m [0moffset[0m[0;34m[0m[0m
[0m[0;32m---> 57 [0;31m    [0mm[0m [0;34m+=[0m [0mm[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mindex_add[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0mlo[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mtype[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mlong[0m[0;34m)[0m[0;34m,[0m [0;34m([0m[0mprobabilities[0m [0;34m*[0m [0;34m([0m[0mhi[0m [0;34m-[0m [0mb_j[0m[0;34m)[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0m
[0m[0;32m     58 [0;31m    [0mm[0m [0;34m+=[0m [0mm[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0

In [48]:
states, actions, rewards, states_, dones = many_transition

def categorical_loss(online_net, target_net, transitions, discount): 
    states, actions, rewards, states_, dones = transitions
    not_dones = (1 - dones).type(torch.FloatTensor)
    atoms = target_net.atoms
    probabilities = target_net.forward(states_)

    Q_x_ = (probabilities * atoms).sum(2)


    batch_size = states.shape[0]
    assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
    a_star = Q_x_.argmax(dim=1) 
    assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'

    # compute the projected probability:
    delta_z = (target_net.vmax - target_net.vmin)/(target_net.num_atoms - 1)    
    # select only the probabilities distributions for the a_star actions:
    probabilities = probabilities[range(batch_size), a_star]
    m = torch.zeros(batch_size, target_net.num_atoms, dtype=torch.float)
    for idx, atom in enumerate(atoms):
        T_zj = rewards.unsqueeze(1) + discount * atom * not_dones.unsqueeze(1)
        T_zj = T_zj.clamp(target_net.vmin, target_net.vmax)
        b_j = (T_zj - target_net.vmin)/delta_z
        lo = b_j.floor()
        hi = b_j.ceil()
        
        m[:, lo.type(torch.long)] += ((probabilities * (hi - b_j))[:, idx]).unsqueeze(1)
        m[:, hi.type(torch.long)] += ((probabilities * (b_j - lo))[:, idx]).unsqueeze(1)
    return m