In [14]:
import sys
sys.path.append('../r2d2_algo/')
import numpy as np
from gym import spaces
import torch
from torch import nn
from segment_tree import SumSegmentTree, MinSegmentTree
import random
import torch
from torch import nn
import torch.nn.functional as F
from model import RNNQNetwork, linear_schedule
from storage import ContinuousSequenceReplayBuffer, SequenceReplayBuffer
from envs import make_vec_envs
import torch.optim as optim
import random
import numpy as np
import gym
import gym_nav
import time
%run ../r2d2_algo/r2d2_class.py

    
    
    
def get_action_dim(action_space):
    """
    Return how many scalar components make up one action from ``action_space``.

    - ``spaces.Box``: product of the box shape (every element is one component).
    - ``spaces.Discrete``: 1 (the action is a single integer).
    - ``spaces.MultiDiscrete``: ``len(nvec)`` (one integer per sub-action).
    - ``spaces.MultiBinary``: ``n`` binary flags; ``n`` must be a plain int.

    Raises:
        AssertionError: for a MultiBinary space whose ``n`` is not an int.
        NotImplementedError: for any other space type.
    """
    space = action_space

    if isinstance(space, spaces.Box):
        # Continuous control: flatten the box shape into a single count.
        n_components = np.prod(space.shape)
        return int(n_components)

    if isinstance(space, spaces.Discrete):
        # One integer selects the whole action.
        return 1

    if isinstance(space, spaces.MultiDiscrete):
        # One integer per independent discrete choice.
        return int(len(space.nvec))

    if not isinstance(space, spaces.MultiBinary):
        raise NotImplementedError(f"{action_space} action space is not supported")

    # MultiBinary: only a 1-D (integer n) layout is handled here.
    assert isinstance(
        space.n, int
    ), "Multi-dimensional MultiBinary action space is not supported. You can flatten it instead."
    return int(space.n)

In [38]:
# NavEnv-v0 configuration shared by the standalone env and the agent's envs.
env_kwargs = dict(
    num_objects=0,
    rew_structure='goal',
    task_structure=2,
    wall_colors=4,
    num_rays=12,
    fov=1,
    max_steps=20,
)
env = gym.make('NavEnv-v0', **env_kwargs)
agent = R2D2Agent(
    env_id='NavEnv-v0',
    env_kwargs=env_kwargs,
    verbose=1,
    buffer_size=1000,
    batch_size=256,
    burn_in_length=4,
    n_envs=4,
    dummy_env=True,
    learning_starts=1000,
    train_frequency=8,
)

In [39]:
# Run R2D2Agent.train with argument 1; the recorded output shows it printing
# mean episode length / mean return. NOTE(review): the argument's unit is
# defined in r2d2_class.py — confirm.
agent.train(1)

Mean episode length 22.0, mean return 0.0
Mean episode length 22.0, mean return 0.0


In [17]:
# Call agent.collect(2) — presumably gathers experience into agent.rb;
# confirm the argument's unit in r2d2_class.py.
agent.collect(2)

In [25]:
# Request a sample of size 256 from the agent's replay buffer.
sample = agent.rb.sample(256)

In [29]:
# Replay-buffer position counter (presumably the next write index — confirm in storage.py).
agent.rb.pos

3

In [28]:
# Whether the replay buffer reports itself as full.
agent.rb.full

False

In [27]:
# Number of entries the replay buffer reports via len().
len(agent.rb)

3

In [19]:
# The agent's global step counter (also used as the x-axis for TensorBoard scalars).
agent.global_step

56024

In [286]:
# Request another sample of size 256 from the replay buffer.
sample = agent.rb.sample(256)

In [287]:
# Shape of the sampled observations. NOTE(review): the recorded output has 37
# rows although 256 were requested — presumably sample() returns at most the
# available sequences; confirm in storage.py.
sample['observations'].shape

torch.Size([37, 16, 24])

In [144]:

# Debug: take one action in the agent's own environment by hand.
# NOTE(review): this cell does not write next_obs / next_rnn_hxs / done back onto
# `agent`, so agent.obs, agent.rnn_hxs and agent.masks may now be out of sync with
# agent.env for later agent.collect/train calls — confirm how R2D2Agent tracks them.
# It also rebinds `env`, replacing the env created in the setup cell.
action, q_values, next_rnn_hxs = agent.act(agent.obs, agent.rnn_hxs, masks=agent.masks)
env = agent.env
next_obs, reward, done, info = env.step(action)


In [46]:
# Configured training batch size.
agent.batch_size

256

In [40]:
# One hand-driven R2D2 update step, kept at top level so its intermediates
# (states, hidden_states, training_masks, elementwise_loss, ...) can be
# inspected in the following cells.
sample = agent.rb.sample()
states = sample['observations']
next_states = sample['next_observations']
hidden_states = sample['hidden_states']
next_hidden_states = sample['next_hidden_states']
actions = sample['actions']
rewards = sample['rewards']
dones = sample['dones']
next_dones = sample['next_dones']
# training_masks are given by SequenceReplayBuffer: 1 marks a post-burn-in step
# that should contribute to the loss, 0 a step to ignore.
training_masks = sample['training_masks']
weights = sample['weights']  # PER importance-sampling weights

with torch.no_grad():
    target_q, _, _ = agent.target_network(next_states, next_hidden_states, next_dones)
    target_max, _ = target_q.max(dim=2)
    # `dones` is fed to the online network together with `states`, and `next_dones`
    # to the target network together with `next_states`, so each flag belongs to the
    # observation it is paired with. Assuming a set flag marks that observation as an
    # episode start (the usual recurrent reset-mask convention — confirm in
    # storage.py), the transition states[t] -> next_states[t] is terminal exactly when
    # next_dones[t] is set, so that is the flag that must stop bootstrapping.
    td_target = rewards + agent.gamma * target_max * (1 - next_dones)
old_q, _, _ = agent.q_network(states, hidden_states, dones)
# Squeeze only the gathered action axis; a bare .squeeze() would also drop the
# batch or time axis whenever one of them has size 1.
old_val = old_q.gather(2, actions.long()).squeeze(-1)

# Drop the burn-in prefix: those steps only warm up the recurrent state.
elementwise_loss = F.smooth_l1_loss(old_val[:, agent.burn_in_length:],
                                    td_target[:, agent.burn_in_length:], reduction='none')
# Average over valid steps only. torch.mean over the masked tensor would divide by
# all elements, shrinking the loss in proportion to the amount of padding.
masked_loss = elementwise_loss * weights * training_masks
loss = masked_loss.sum() / training_masks.sum().clamp(min=1)

if agent.writer is not None and agent.global_update_step % 10 == 0:
    agent.writer.add_scalar('losses/td_loss', loss.item(), agent.global_step)
    agent.writer.add_scalar('losses/q_values', old_val.mean().item(), agent.global_step)
    sps = int(agent.global_step / (time.time() - agent.start_time))
    agent.writer.add_scalar('charts/SPS', sps, agent.global_step)

agent.optimizer.zero_grad()
loss.backward()
agent.optimizer.step()

# PER: update priorities with each sequence's mean TD loss over its valid steps,
# so masked steps neither inflate nor dilute a sequence's priority.
valid_steps = training_masks.sum(dim=1).clamp(min=1)
td_priorities = ((elementwise_loss * training_masks).sum(dim=1) / valid_steps).detach().cpu().numpy() + 1e-6
agent.rb.update_priorities(sample['idxs'], td_priorities)

In [41]:
# Per-step loss mask for the post-burn-in steps of each sampled sequence (0 = excluded).
training_masks

tensor([[1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [45]:
# Total number of unmasked (loss-contributing) steps in this batch.
training_masks.sum()

tensor(36.)

In [42]:
# Shape of the sampled observation sequences; presumably (batch, seq_len, obs_dim) — confirm.
states.shape

torch.Size([6, 12, 24])

In [25]:
# Shape of the sampled initial recurrent hidden states.
# NOTE(review): the recorded output comes from an earlier execution (In [25]) than
# states.shape (In [42]), so its batch axis need not match that batch.
hidden_states.shape

torch.Size([1, 4, 64])

In [26]:
# Shape of the hidden states used to unroll the target network over next_states.
# NOTE(review): recorded output is from an earlier execution (In [26]) than the
# training cell that currently defines next_hidden_states.
next_hidden_states.shape

torch.Size([1, 4, 64])

In [21]:
# Per-step smooth-L1 TD loss for the post-burn-in steps (reduction='none').
elementwise_loss

tensor([[1.4553e-07, 1.4515e-07, 1.3384e-02, 5.2076e-02, 1.4448e-07, 1.3384e-02,
         1.3384e-02, 1.4449e-07],
        [3.6777e-03, 4.0718e-09, 4.7054e-02, 1.3871e-07, 3.8636e-03, 4.7342e-02,
         3.9526e-03, 4.7866e-02],
        [1.1588e-07, 1.1506e-07, 1.1452e-07, 1.1419e-07, 1.1399e-07, 1.1389e-07,
         1.1383e-07, 1.1381e-07],
        [4.9114e-02, 2.0491e-07, 2.0749e-07, 2.0899e-07, 2.0985e-07, 2.1033e-07,
         2.1063e-07, 2.1083e-07]], grad_fn=<SmoothL1LossBackward0>)

In [13]:
# Per-sequence maximum of the elementwise TD loss (max over dim 1), detached from the graph.
elementwise_loss.max(dim=1)[0].detach()

tensor([7.2243e-04, 9.7215e-04, 1.3732e-05, 2.8302e-06])

In [357]:
# Inspect the replay buffer's cur_dones array.
agent.rb.cur_dones

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
      dtype=float32)

In [358]:
# Inspect entries 5:10 of the last axis of the buffer's cur_observations.
agent.rb.cur_observations[:, :, 5:10]

array([[[0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        , 0.        , 0.        , 0.        , 0.        ],
        [0.        ,