In [578]:
from collections import deque

import attr
import gym
from IPython.display import clear_output
from matplotlib import pyplot as plt
import numpy as np
from pytorch_monitor import init_experiment, monitor_module
# from smooth import smooth  # timeseries smoothing function
import torch
from torch import nn
import torch.nn.functional as F


cartpole = gym.make('CartPole-v1')
lunarlander = gym.make('LunarLander-v2')
plt.style.use('seaborn-white')
monitor_config = {
    'title':'Test Monitor',
    'log_dir':'test',
    'random_seed': 0
}
writer, _ = init_experiment(monitor_config)

if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'  

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m
[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


# TODO: add device spec for every tensor

In [608]:
@attr.s
class Memory(deque):
    """ Experience Replay Memory class. """
    size = attr.ib()
    minibatch_size = attr.ib()

    def append(self, thing):
        if len(self) > self.size - 1:
            self.popleft()
        return super().append(thing)

    def sample(self):
        batch_size = min(len(self), self.minibatch_size)
        data = random.sample(self, batch_size)
        states = torch.stack([record[0] for record in data])
        actions = torch.tensor([record[1] for record in data], dtype=torch.long)
        rewards = torch.tensor([record[2] for record in data], dtype=torch.float)
        states_ = torch.stack([record[3] for record in data])
        dones = torch.tensor([record[4] for record in data], dtype=torch.long)
        return (states, actions, rewards, states_, dones)


@attr.s
class ValueDistribution(torch.nn.Module):
    state_shape = attr.ib()
    action_shape = attr.ib()
    vmin = attr.ib()
    vmax = attr.ib()
    num_atoms = attr.ib(default=51)
    num_hidden1_units = attr.ib(default=128)
    num_hidden2_units = attr.ib(default=128)
    
    def __attrs_post_init__(self):
        super().__init__()
        self.atoms = torch.linspace(self.vmin, self.vmax, self.num_atoms)
        self.linear1 = nn.Linear(self.state_shape, self.num_hidden1_units)
        self.linear2 = nn.Linear(self.num_hidden1_units, self.num_hidden2_units)
        self.out = nn.Linear(self.num_hidden2_units, self.action_shape * self.num_atoms)
        
    def forward(self, x):
        """ Return (actions x atoms). """
        x1 = F.relu(self.linear1(x))
        x2 = F.relu(self.linear2(x1))
        x3 = F.relu(self.out(x2))
        x3 = x3.reshape(-1, self.action_shape, self.num_atoms)
        out = F.softmax(x3, dim=2)  # (actions x atoms)
        if x.dim() == 1:
            batch_size = 1
        else:
            batch_size = x.size(0)
        assert out.size() == torch.Size((batch_size, self.action_shape, self.num_atoms))
        if hasattr(self, 'monitor'):
            self.monitor('x1', x1, track_data=True, track_grad=True)
            self.monitor('x2', x2, track_data=True, track_grad=True)
            self.monitor('x3', x3, track_data=True, track_grad=True)
            self.monitor('out', out, track_data=True, track_grad=True)
        return out
    
    def predict_action_values(self, states):
        """ Return (batch-size x actions). """
        distribution = self.forward(states)
        weighted_distribution = distribution * self.atoms
        out = weighted_distribution.sum(dim=2).squeeze()  # (batch-size x actions)
        dims = states.dim()
        if dims > 1:
            assert out.size() == torch.Size((self.action_shape,))
        else:
            assert out.size() == torch.Size((self.action_shape,))
        return out
    
    def batch_actions(self, states):
        values = self.predict_action_values(states)
        actions = values.argmax(1)
        assert actions.size(0) == states.size(0)
        return actions
    
    def single_action(self, state):
        values = self.predict_action_values(state)
        action = values.argmax()
        return action

In [589]:
online_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net = ValueDistribution(state_shape=8, action_shape=action_shape, vmin=vmin, vmax=vmax, num_atoms=num_atoms, num_hidden1_units=7, num_hidden2_units=7)
target_net.load_state_dict(online_net.state_dict())

In [628]:
def categorical_loss(online_net, target_net, transitions, discount): 
    states, actions, rewards, states_, dones = transitions
    not_dones = (1 - dones).type(torch.FloatTensor)
    atoms = target_net.atoms
    probabilities = target_net.forward(states_)
    Q_x_ = (probabilities * atoms).sum(2)
    batch_size = states.shape[0]
    assert Q_x_.shape == torch.Size((batch_size, target_net.action_shape)), f'Got: {Q_x_.shape}, expected: {(batch_size, target_net.action_shape)}'
    a_star = Q_x_.argmax(dim=1) 
    assert a_star.shape == torch.Size((batch_size,)), f'Got {a_star.shape}, expected: ((batch_size,))'
    
    # compute the projected probability:
    delta_z = (target_net.vmax - target_net.vmin)/(target_net.num_atoms - 1)    
    # select only the probabilities distributions for the a_star actions:
    probabilities = probabilities[range(batch_size), a_star]
    T_zj = rewards.unsqueeze(1) + discount * atoms * not_dones.unsqueeze(1)
    b_j = (T_zj.clamp(target_net.vmin, target_net.vmax) - target_net.vmin) / delta_z  # correct
    lo = b_j.floor()
    hi = b_j.ceil()
    m = torch.zeros(batch_size, target_net.num_atoms, dtype=torch.float)
    row_idxs = torch.range(0, batch_size-1, dtype=torch.long).unsqueeze(1).expand_as(m)
    m[row_idxs, lo.type(torch.long)] += (probabilities * (hi - b_j))
    m[row_idxs, hi.type(torch.long)] += (probabilities * (b_j - lo))

    # x enthropy is Sigma <true> log <unnatural>, so for us is: target log(online)
    online_distribution = online_net.forward(states)[range(batch_size), actions]
    return -(m * online_distribution.log()).sum(1).mean()

In [632]:
@attr.s
class CategoricalAgent:
    env = attr.ib()
    discount = attr.ib(default=0.99)
    epsilon_max = attr.ib(default=1.0)
    epsilon_min = attr.ib(default=0.01)
    annealing_const = attr.ib(default=.001)  # aka Lambda
    minibatch_size = attr.ib(default=32)
    memory_size = attr.ib(default=int(1e6))
    num_episodes = attr.ib(default=1000)  # num of episodes in a training epoch
    render_every = attr.ib(default=20)  # set to zero to turn off rendering
    update_target_every = attr.ib(default=200)
    vmin = attr.ib(default=-10)
    vmax = attr.ib(default=10)
    num_atoms = attr.ib(default=51)
    learning_rate = attr.ib(default=0.000001)
    
    def __attrs_post_init__(self):
        self.steps = 0
        state_shape = self.env.observation_space.shape[0]
        self.memory = Memory(self.memory_size, self.minibatch_size)
        self.optimizer = torch.optim.Adam(online_net.parameters(), lr=self.learning_rate)
        self.action_shape = self.env.action_space.n
        self.online_net = ValueDistribution(state_shape=state_shape, action_shape=action_shape, vmin=self.vmin, vmax=self.vmax, num_atoms=self.num_atoms)
        self.target_net = ValueDistribution(state_shape=state_shape, action_shape=action_shape, vmin=self.vmin, vmax=self.vmax, num_atoms=self.num_atoms)
        self.target_net.load_state_dict(self.online_net.state_dict())
        self.reset_data_recorders()

    def reset_data_recorders(self):
        self.episode_rewards = []
        self.episode_losses = []
        self.td_errors = []
        self.online_net_q_values = []
        self.target_net_q_values = []
        self.w1_gradient = []
        self.w2_gradient = []
                
    def render(self, episode):
        if self.render_every and episode % self.render_every == 0:
            self.env.render()

    def training_progress_report(self, episode):
        last_ep = self.episode_rewards[-1]
        ten_ep_mean = sum(self.episode_rewards[-10:])/len(self.episode_rewards[-10:])
        hundred_ep_mean = sum(self.episode_rewards[-100:])/len(self.episode_rewards[-100:])
        return f'Ep: {episode} // steps: {self.steps} // last ep reward: {last_ep:.2f} // {min(10, len(self.episode_rewards[-10:]))}-ep mean: {ten_ep_mean:.2f} // {min(100, len(self.episode_rewards[-100:]))}-ep mean: {hundred_ep_mean:.2f}'
                        
    def replay(self):
        batch = self.memory.sample()
        loss = categorical_loss(self.online_net, self.target_net, batch, self.discount)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()/self.minibatch_size
        
    def train(self):
        steps = 0
        for episode in range(self.num_episodes):
            episode_done = False
            episode_reward = 0
            episode_td_errors = 0
            episode_loss = 0
            state = torch.tensor(self.env.reset(), dtype=torch.float)
            self.target_net_q_values.append(self.target_net.predict_action_values(state).max().item())
            while not episode_done:
                epsilon = self.epsilon_min + (self.epsilon_max - self.epsilon_min) * math.exp(-self.annealing_const * self.steps)
                self.steps += 1
                if random.random() < epsilon:
                    action = random.randint(0, self.action_shape-1)
                else:
                    action = self.online_net.single_action(state).item()
                self.render(episode)
                state_, reward, episode_done, _ = self.env.step(action)
                state_ = torch.tensor(state_, dtype=torch.float)
                episode_reward += reward
                self.memory.append((state, action, reward, state_, episode_done))
                episode_loss += self.replay()
                state = state_
                if steps % self.update_target_every == 0:
                    self.target_net.load_state_dict(self.online_net.state_dict())
                if episode_done:
                    self.episode_rewards.append(episode_reward)
                    print(self.training_progress_report(episode), end='\r', flush=True)
        self.env.close()
        
    def test(self):
        pass

In [633]:
agent = CategoricalAgent(cartpole)
agent.train()

Ep: 20 // steps: 602 // last ep reward: 32.00 // 10-ep mean: 28.30 // 21-ep mean: 28.67

AssertionError: 4 (<class 'int'>) invalid

In [638]:
test_state = torch.tensor([-0.1579, -1.4070,  0.1360,  2.0100], dtype=torch.float)
agent.online_net.predict_action_values(test_state)
agent.online_net.single_action(test_state)
agent.on

tensor([ 0.0209,  0.0280, -0.1268,  0.0086,  0.0294])

tensor(4)