In [526]:
import numpy as np
import gym
import math
import random
from collections import namedtuple
from itertools import count
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter

%matplotlib inline

In [527]:
# CartPole-v1 environment driven by the training loop further down.
env = gym.make('CartPole-v1')
# Every tensor and network in this notebook is placed on this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Rainbow flags

In [528]:
# Each flag switches one Rainbow extension on top of vanilla DQN.
MULTISTEP = False         # store N_STEP-step discounted returns in replay
DOUBLE_DQN = False        # online net picks the bootstrap action, target net scores it
DUELING = False           # build Dueling_DQN instead of DQN
NOISY = False             # NoisyLinear layers; epsilon-greedy exploration is bypassed
PRIORITY_REPLAY = False   # PrioritizedReplayMemory plus importance-sampling weights
DISTRIBUTIONAL = False    # categorical return distribution over N_ATOMS atoms

## Hyperparameters

In [529]:
# --- DQN core ---
TARGET_UPDATE = 10      # episodes between copies of net -> target_net
GAMMA = 0.999           # per-step discount factor

# --- Replay memory ---
BATCH_SIZE = 128        # transitions per optimisation step
MEMORY_SIZE = 10 ** 5   # replay capacity (also sizes the RangeTree)

# --- Epsilon-greedy exploration (exponential decay in steps taken) ---
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 200         # decay time constant, in environment steps

# --- Multi-step returns ---
N_STEP = 5              # rewards folded into each transition when MULTISTEP is on

# --- Distributional (categorical) head ---
N_ATOMS = 51            # number of support atoms
V_MIN = -10             # lowest support value
V_MAX = 10              # highest support value

# --- Prioritized replay ---
PRIORITY_EPS = 0.005    # offset/floor applied to priorities
PRIORITY_ALPHA = 1      # exponent applied to priorities before sampling

# Importance-sampling exponent schedule
IS_BETA_START = 0.1
IS_BETA_END = 1
IS_BETA_DECAY = 200     # decay time constant, in environment steps

# --- Noisy networks ---
STD_INIT = 0.5          # initial sigma scale for NoisyLinear

## Memory

In [530]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayBuffer:
    """Folds consecutive environment steps into n-step transitions.

    Each emitted transition keeps the first step's state and action, the
    discounted reward sum over the steps it covers, and the next_state of the
    LAST step it covers. Finished transitions are handed to
    ``memory.push_in_memory``.
    """

    def __init__(self, n_step, memory):
        self.n_step = n_step
        self.memory = memory  # owner exposing push_in_memory(transition)
        self.n_step_buffer = list()

    def _discounted_return(self, n):
        """sum_{i < n} GAMMA**i * r_i over the first n buffered steps, as a (1,) float tensor."""
        R = sum(self.n_step_buffer[i].reward.item() * (GAMMA ** i) for i in range(n))
        return torch.tensor([R], dtype=torch.float, device=device)

    def _emit_oldest(self, n):
        """Pop the oldest step and store it as a transition spanning the first n buffered steps."""
        R = self._discounted_return(n)
        # The return covers steps 0..n-1, so bootstrap from the state after step n-1.
        last_next_state = self.n_step_buffer[n - 1].next_state
        first = self.n_step_buffer.pop(0)
        self.memory.push_in_memory(Transition(first.state, first.action, last_next_state, R))

    def push(self, *args):
        """Buffer one (state, action, next_state, reward) step; emit once n_step steps are held."""
        self.n_step_buffer.append(Transition(*args))
        if len(self.n_step_buffer) < self.n_step:
            return
        self._emit_oldest(self.n_step)

    def flush_buffer(self):
        """Emit every remaining step (end of episode) with the shorter returns still available."""
        while self.n_step_buffer:
            self._emit_oldest(len(self.n_step_buffer))
        

In [531]:
class ReplayMemory:
    """Circular replay memory with uniform sampling.

    Raw steps enter through an n-step ReplayBuffer, which calls back into
    push_in_memory once a transition is complete.
    """

    def __init__(self, capacity, n_step):
        self.capacity = capacity
        self.memory = []
        self.position = 0
        self.buffer = ReplayBuffer(n_step, self)

    def push_in_memory(self, t):
        """Store a finished transition, overwriting the oldest slot once full."""
        if len(self.memory) < self.capacity:
            # Still filling: the write cursor always equals the current length.
            self.memory.append(t)
        else:
            self.memory[self.position] = t
        self.position = (self.position + 1) % self.capacity

    def push(self, *args):
        """Forward a raw (state, action, next_state, reward) step to the n-step buffer."""
        self.buffer.push(*args)

    def flush_buffer(self):
        """Drain the n-step buffer at the end of an episode."""
        self.buffer.flush_buffer()

    def get_transitions(self, positions):
        stored = self.memory
        return [stored[i] for i in positions]

    def sample(self, batch_size):
        """Draw ``batch_size`` slot indices uniformly, with replacement."""
        return np.random.choice(len(self.memory), batch_size)

    def __len__(self):
        return len(self.memory)

In [532]:
class RangeTree:
    """Array-backed binary tree over non-negative leaf values.

    Supports O(log n) point updates, the total sum, the maximum leaf and a
    proportional prefix-sum lookup. Node 1 is the root, node i has children
    2i and 2i + 1, and the leaves occupy indices [size, 2 * size).
    """

    def __init__(self, capacity=None):
        """Allocate at least ``capacity`` leaves (default MEMORY_SIZE), rounded up to a power of two."""
        if capacity is None:
            capacity = MEMORY_SIZE
        self.size = 1
        while self.size < capacity:
            self.size *= 2
        self.values = np.zeros(2 * self.size)
        self.max_values = np.zeros(2 * self.size)

    def add(self, pos, x):
        """Set leaf ``pos`` to ``x`` and recompute sums/maxima on the path to the root."""
        pos += self.size
        self.values[pos] = x
        self.max_values[pos] = x
        pos //= 2
        while pos:
            left, right = 2 * pos, 2 * pos + 1
            self.values[pos] = self.values[left] + self.values[right]
            self.max_values[pos] = max(self.max_values[left], self.max_values[right])
            pos //= 2

    def get_max(self):
        """Largest leaf value currently stored."""
        return self.max_values[1]

    def get(self, x):
        """Return the leaf whose cumulative-sum interval contains ``x * total``.

        ``x`` is expected in [0, 1). The descent never enters a subtree with
        zero mass. This guards against float drift near the total, and against
        zero-valued leaves, either of which could otherwise yield an empty slot.
        """
        x *= self.values[1]
        pos = 1
        while pos < self.size:
            left = 2 * pos
            left_sum = self.values[left]
            right_sum = self.values[left + 1]
            if right_sum > 0 and (x > left_sum or left_sum <= 0):
                x -= left_sum
                pos = left + 1
            else:
                pos = left
        return pos - self.size

In [533]:
class PrioritizedReplayMemory:
    """Circular replay memory that samples proportionally to priority.

    Raw priorities p_i live in ``self.priorities``. The RangeTree stores
    p_i ** PRIORITY_ALPHA and drives sampling.
    """

    def __init__(self, capacity, n_step):
        self.capacity = capacity
        self.memory = list()
        self.position = 0
        self.tree = RangeTree()
        # Sized by this memory's capacity so every slot index fits.
        self.priorities = np.zeros((capacity, ))
        self.size = 0  # NOTE(review): not read anywhere in this notebook
        self.buffer = ReplayBuffer(n_step, self)

    def _max_priority(self):
        """Largest raw priority stored, recovered from the tree's max of p ** alpha."""
        scaled = self.tree.get_max()
        if PRIORITY_ALPHA > 0 and scaled > 0:
            return scaled ** (1.0 / PRIORITY_ALPHA)
        return scaled

    def push_in_memory(self, t):
        """Store a finished transition at the write cursor with the current max priority."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = t
        # New transitions get the largest priority seen so far (floored at PRIORITY_EPS).
        self.priorities[self.position] = max(self._max_priority(), PRIORITY_EPS)
        self.tree.add(self.position, self.priorities[self.position] ** PRIORITY_ALPHA)

        self.position += 1
        if self.position == self.capacity:
            self.position = 0

    def push(self, *args):
        """Forward a raw (state, action, next_state, reward) step to the n-step buffer."""
        self.buffer.push(*args)

    def flush_buffer(self):
        """Drain the n-step buffer at the end of an episode."""
        self.buffer.flush_buffer()

    def get_priorities(self, positions):
        return [self.priorities[pos] for pos in positions]

    def get_transitions(self, positions):
        return [self.memory[pos] for pos in positions]

    def update(self, positions, td_errors):
        """Set p_i = |error_i| + PRIORITY_EPS for each sampled position.

        ``td_errors`` may be a torch tensor (any shape with one value per
        position, on any device), a numpy array or a sequence of numbers.
        """
        if torch.is_tensor(td_errors):
            td_errors = td_errors.detach().reshape(-1).cpu().tolist()
        else:
            td_errors = np.asarray(td_errors, dtype=float).reshape(-1).tolist()
        for pos, error in zip(positions, td_errors):
            self.priorities[pos] = abs(error) + PRIORITY_EPS
            self.tree.add(pos, self.priorities[pos] ** PRIORITY_ALPHA)

    def sample(self, batch_size):
        """Stratified proportional sampling: one draw from each of ``batch_size`` equal segments."""
        last = len(self.memory) - 1
        picks = []
        for k in range(batch_size):
            u = np.random.uniform(k / batch_size, (k + 1) / batch_size)
            # Clamp so an index past the filled region is never returned.
            picks.append(min(self.tree.get(u), last))
        return picks

    def __len__(self):
        return len(self.memory)
        

## Network

In [534]:
class NoisyLinear(nn.Module):
    """Linear layer with factorised Gaussian parameter noise (NoisyNet).

    The effective weight is weight_mu + weight_sigma * weight_epsilon, and the
    bias is built the same way. The epsilon buffers are redrawn by reset_noise().
    """

    def __init__(self, in_sz, out_sz, std_init=STD_INIT):
        super(NoisyLinear, self).__init__()
        self.in_sz, self.out_sz, self.std_init = in_sz, out_sz, std_init

        def _param(*shape):
            return nn.Parameter(torch.empty(*shape, device=device))

        # Registration order is kept: weight_mu, weight_sigma, bias_mu, bias_sigma.
        self.weight_mu = _param(out_sz, in_sz)
        self.weight_sigma = _param(out_sz, in_sz)
        self.register_buffer('weight_epsilon', torch.empty(out_sz, in_sz, device=device))

        self.bias_mu = _param(out_sz)
        self.bias_sigma = _param(out_sz)
        self.register_buffer('bias_epsilon', torch.empty(out_sz, device=device))

        bound = 1 / math.sqrt(self.in_sz)
        with torch.no_grad():
            self.weight_mu.uniform_(-bound, bound)
            self.weight_sigma.fill_(self.std_init / math.sqrt(self.in_sz))
            self.bias_mu.uniform_(-bound, bound)
            self.bias_sigma.fill_(self.std_init / math.sqrt(self.out_sz))

        self.reset_noise()

    def reset_noise(self):
        """Redraw the factorised noise: weight eps = outer(f(e_out), f(e_in)), bias eps = f(e_out)."""
        noise_in = self._scale_noise(self.in_sz)
        noise_out = self._scale_noise(self.out_sz)
        # Broadcasted outer product (equivalent to the deprecated Tensor.ger).
        self.weight_epsilon.copy_(noise_out.unsqueeze(1) * noise_in.unsqueeze(0))
        self.bias_epsilon.copy_(noise_out)

    def _scale_noise(self, size):
        """Sample x ~ N(0, 1) of length ``size`` and return sign(x) * sqrt(|x|)."""
        x = torch.randn(size, device=device)
        return x.sign() * x.abs().sqrt()

    def forward(self, input):
        weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(input, weight, bias)

In [535]:
class DQN(nn.Module):
    """Two-layer MLP Q-network.

    Returns (B, outputs * n_atoms) values. When DISTRIBUTIONAL is set, it
    instead returns per-action probabilities of shape (B, outputs, n_atoms),
    from a softmax over the atom dimension. Layers are NoisyLinear when NOISY
    is set.

    ``n_inputs`` (observation size) and ``hidden`` default to the original
    CartPole values 4 and 320.
    """

    def __init__(self, outputs, n_atoms, n_inputs=4, hidden=320):
        super(DQN, self).__init__()
        self.n_atoms = n_atoms
        self.outputs = outputs

        layer = NoisyLinear if NOISY else nn.Linear
        self.first = layer(n_inputs, hidden)
        self.second = layer(hidden, self.outputs * self.n_atoms)

    def forward(self, x):
        x = F.relu(self.first(x))
        x = self.second(x)
        if DISTRIBUTIONAL:
            return F.softmax(x.view(-1, self.outputs, self.n_atoms), dim=2)
        return x

    def reset_noise(self):
        """Redraw NoisyLinear noise; a no-op when NOISY is off."""
        if NOISY:
            self.first.reset_noise()
            self.second.reset_noise()

In [536]:
class Dueling_DQN(nn.Module):
    """Dueling Q-network: Q = V + (A - mean over actions of A).

    Returns (B, outputs * n_atoms) values. When DISTRIBUTIONAL is set, it
    returns per-action probabilities of shape (B, outputs, n_atoms), with the
    softmax taken over atoms after the dueling combine.

    ``n_inputs`` (observation size) and ``hidden`` default to the original
    CartPole values 4 and 320.
    """

    def __init__(self, outputs, n_atoms, n_inputs=4, hidden=320):
        super(Dueling_DQN, self).__init__()
        self.n_atoms = n_atoms
        self.outputs = outputs

        layer = NoisyLinear if NOISY else nn.Linear
        # Value stream: one value (or one atom vector) per state.
        self.val_1 = layer(n_inputs, hidden)
        self.val_2 = layer(hidden, self.n_atoms)
        # Advantage stream: one value (or atom vector) per action.
        self.adv_1 = layer(n_inputs, hidden)
        self.adv_2 = layer(hidden, outputs * self.n_atoms)

    def forward(self, x):
        val = self.val_2(F.relu(self.val_1(x)))
        adv = self.adv_2(F.relu(self.adv_1(x)))

        if DISTRIBUTIONAL:
            val = val.view(-1, 1, self.n_atoms)
            adv = adv.view(-1, self.outputs, self.n_atoms)
            q_atoms = val + (adv - adv.mean(dim=1, keepdim=True))
            return F.softmax(q_atoms, dim=2)

        # Scalar head: val is (B, 1) and broadcasts across the action columns.
        return val + (adv - adv.mean(dim=1, keepdim=True))

    def reset_noise(self):
        """Redraw NoisyLinear noise in both streams; a no-op when NOISY is off."""
        if NOISY:
            self.val_1.reset_noise()
            self.val_2.reset_noise()
            self.adv_1.reset_noise()
            self.adv_2.reset_noise()

## Declare objects

In [537]:
n_actions = env.action_space.n
steps_done = 0           # environment steps taken; advanced by select_action
episode_durations = list()

# One atom per action unless the categorical (distributional) head is enabled.
n_atoms = N_ATOMS if DISTRIBUTIONAL else 1
network_cls = Dueling_DQN if DUELING else DQN
net = network_cls(n_actions, n_atoms).to(device)
target_net = network_cls(n_actions, n_atoms).to(device)
# Start the target as an exact copy of the online network. Otherwise the first
# episode bootstraps from an unrelated, randomly initialised network.
target_net.load_state_dict(net.state_dict())

optimizer = optim.Adam(net.parameters(), lr=0.0001)

# Transitions are folded into N_STEP-step returns only when MULTISTEP is on.
memory_cls = PrioritizedReplayMemory if PRIORITY_REPLAY else ReplayMemory
memory = memory_cls(MEMORY_SIZE, N_STEP if MULTISTEP else 1)

## Plot

In [538]:
def plot_durations():
    """Redraw figure 2 with every episode duration.

    Once at least 100 episodes exist, also plot their trailing 100-episode
    mean, left-padded with 99 zeros so it aligns with the episode axis.
    """
    window = 100
    durations = torch.tensor(episode_durations, dtype=torch.float)

    plt.figure(2)
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations.numpy())

    if len(durations) >= window:
        rolling = durations.unfold(0, window, 1).mean(1).view(-1)
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    plt.pause(0.001)  # give the GUI event loop a moment to refresh the figure

## Agent

In [539]:
# Atom spacing of the categorical support.
delta = (V_MAX - V_MIN) / (N_ATOMS - 1)
# N_ATOMS evenly spaced return values from V_MIN to V_MAX, shaped (1, 1, N_ATOMS)
# to broadcast against (batch, action, atom) tensors.
supports = torch.linspace(V_MIN, V_MAX, N_ATOMS).reshape(1, 1, N_ATOMS).to(device)

In [540]:
def select_action(state):
    """Pick an action for ``state`` and advance the global step counter.

    The exploration threshold decays from EPS_START toward EPS_END as
    exp(-steps_done / EPS_DECAY). With NOISY the parameter noise replaces
    epsilon-greedy, so the greedy branch is always taken. Returns a (1, 1)
    long tensor.
    """
    global steps_done
    draw = random.random()
    threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1

    if not NOISY and draw <= threshold:
        return torch.tensor([[random.randrange(n_actions)]], dtype=torch.long, device=device)

    with torch.no_grad():
        net.reset_noise()
        scores = net(state)
        if DISTRIBUTIONAL:
            # Expected return per action: sum over atoms of p_i * z_i.
            scores = (scores * supports).sum(dim=2)
        return scores.max(1)[1].view(1, 1)

In [541]:
def get_next_action(next_states, network):
    """Greedy bootstrap action under ``network`` for every next state (no grad).

    Returns indices shaped for ``gather(1, ...)``: (B, 1) for scalar Q-values,
    or (B, 1, N_ATOMS) when DISTRIBUTIONAL so they index the atom dimension.
    """
    with torch.no_grad():
        out = network(next_states)
        if not DISTRIBUTIONAL:
            return out.max(1)[1].unsqueeze(1)
        q_means = (out * supports).sum(dim=2)
        best = q_means.max(1)[1]
        return best.view(next_states.size(0), 1, 1).expand(-1, -1, N_ATOMS)

def _importance_weights(positions):
    """Importance-sampling weights for prioritized samples, normalised so the max is 1.

    With P(i) proportional to p_i ** PRIORITY_ALPHA, w_i = (N * P(i)) ** -beta.
    N and the normalising sum are shared by every sample, so they cancel when
    dividing by the largest weight. Only p_i ** PRIORITY_ALPHA is needed.
    """
    # beta moves from IS_BETA_START toward IS_BETA_END as steps_done grows.
    beta = IS_BETA_END - (IS_BETA_END - IS_BETA_START) * math.exp(-1. * steps_done / IS_BETA_DECAY)
    priorities = torch.tensor(memory.get_priorities(positions), dtype=torch.float, device=device)
    weights = (priorities ** PRIORITY_ALPHA) ** (-beta)
    return weights / weights.max()


def _project_distribution(next_probs, rewards, non_final, discount):
    """Project r + non_final * discount * z onto the fixed support (categorical projection).

    next_probs: (B, N_ATOMS) probabilities of the bootstrap action.
    rewards: (B,) returns. non_final: (B,) float mask; 0 collapses the target onto r.
    Returns the (B, N_ATOMS) target distribution m.
    """
    batch_size = next_probs.size(0)
    Tz = rewards.view(-1, 1) + non_final.view(-1, 1) * discount * supports.view(1, -1)
    Tz = Tz.clamp(V_MIN, V_MAX)
    b = (Tz - V_MIN) / delta
    l = b.floor().to(torch.int64)
    u = b.ceil().to(torch.int64)
    # When b falls exactly on an atom (l == u), split l/u so the mass is not dropped.
    l[(u > 0) * (l == u)] -= 1
    u[(l < (N_ATOMS - 1)) * (l == u)] += 1

    # Row offsets into the flattened (B * N_ATOMS) view.
    offset = (torch.arange(batch_size, device=next_probs.device) * N_ATOMS).unsqueeze(1).expand(batch_size, N_ATOMS)
    m = next_probs.new_zeros(batch_size, N_ATOMS)
    m.view(-1).index_add_(0, (l + offset).view(-1), (next_probs * (u.float() - b)).view(-1))  # m_l += p * (u - b)
    m.view(-1).index_add_(0, (u + offset).view(-1), (next_probs * (b - l.float())).view(-1))  # m_u += p * (b - l)
    return m


def optimize_model():
    """Run one gradient step on a batch sampled from ``memory``.

    Does nothing until the memory holds at least BATCH_SIZE transitions.
    Rewards are stored returns, discounted over N_STEP steps when MULTISTEP is
    on, so the bootstrap term is scaled by GAMMA ** N_STEP. A transition whose
    next_state is None is treated as terminal and its bootstrap term is dropped.

    Loss:
      * scalar head: 0.5 * TD error ** 2 per sample;
      * DISTRIBUTIONAL: cross-entropy between the projected target distribution
        and the predicted one.
    Each per-sample loss is scaled by the importance-sampling weight when
    PRIORITY_REPLAY is on. New priorities are |TD error| (scalar head) or the
    per-sample cross-entropy (distributional head, as in Rainbow).

    NOTE(review): transitions flushed at episode end hold fewer than N_STEP
    rewards but are still bootstrapped with GAMMA ** N_STEP.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions_pos = memory.sample(BATCH_SIZE)
    transitions = memory.get_transitions(transitions_pos)

    if PRIORITY_REPLAY:
        weights = _importance_weights(transitions_pos)
    else:
        weights = torch.ones(BATCH_SIZE, dtype=torch.float, device=device)

    batch = Transition(*zip(*transitions))
    batch_state = torch.cat(batch.state)
    batch_action = torch.cat(batch.action)
    batch_reward = torch.cat(batch.reward)
    # Terminal transitions (next_state None) get a zero placeholder so the batch
    # can be stacked; the mask removes their bootstrap term below.
    non_final = torch.tensor([s is not None for s in batch.next_state], dtype=torch.float, device=device)
    batch_next_state = torch.cat([s if s is not None else torch.zeros_like(st)
                                  for s, st in zip(batch.next_state, batch.state)])
    discount = GAMMA ** (N_STEP if MULTISTEP else 1)

    if DISTRIBUTIONAL:
        # Index every atom of the chosen action: (B, 1) -> (B, 1, N_ATOMS).
        batch_action = batch_action.unsqueeze(dim=-1).expand(-1, -1, N_ATOMS)

    net.reset_noise()
    state_action_values = net(batch_state).gather(1, batch_action)

    with torch.no_grad():
        # Double DQN: the online net selects the action and the target net evaluates it.
        next_action = get_next_action(batch_next_state, net if DOUBLE_DQN else target_net)
        target_net.reset_noise()
        next_values = target_net(batch_next_state).gather(1, next_action).squeeze(1)

    if DISTRIBUTIONAL:
        with torch.no_grad():
            m = _project_distribution(next_values, batch_reward, non_final, discount)
        # Clamp avoids log(0) -> -inf (and 0 * -inf = NaN) when softmax underflows.
        log_probs = state_action_values.squeeze(1).clamp(min=1e-8).log()
        loss = -(m * log_probs).sum(-1)  # (B,)
        new_priorities = loss.detach()
    else:
        expected = next_values * discount * non_final + batch_reward  # (B,)
        td_error = expected - state_action_values.squeeze(1)          # (B,)
        # Keep the loss at (B,) so it pairs element-wise with weights (B,);
        # a (B, 1) loss would broadcast to (B, B).
        loss = td_error ** 2 / 2
        new_priorities = td_error.detach().abs()

    optimizer.zero_grad()
    (loss * weights).mean().backward()
    optimizer.step()

    if PRIORITY_REPLAY:
        memory.update(transitions_pos, new_priorities.view(-1).cpu().numpy())

## Training loop

In [542]:
# Episodes per "epoch"; the mean of the last 100 durations is printed after each.
# NOTE(review): rendering is enabled only for the final epoch, which currently
# has 0 episodes, so env.render() never runs.
num_episodes = [100] * 10 + [0] * 1
# Reaching the time limit is a truncation, not a failure. Read the limit from the
# env spec (500 for CartPole-v1) instead of hard-coding it.
max_episode_steps = getattr(env.spec, 'max_episode_steps', None) or 500

for epoch in range(len(num_episodes)):
    render = epoch == len(num_episodes) - 1
    for i_episode in range(num_episodes[epoch]):
        # Build the (1, obs_dim) tensor directly from the ndarray; wrapping the
        # ndarray in a list first is slow in torch.
        state = torch.tensor(env.reset(), dtype=torch.float, device=device).unsqueeze(0)
        for t in count():
            action = select_action(state)
            if render:
                env.render()
            next_state, reward, done, _ = env.step(action.item())
            next_state = torch.tensor(next_state, dtype=torch.float, device=device).unsqueeze(0)
            if done and t + 1 != max_episode_steps:
                # Penalise failure. NOTE(review): the stored transition carries no
                # terminal flag, so this penalty is the only end-of-episode signal.
                reward = -10
            reward = torch.tensor([reward], dtype=torch.float, device=device)

            memory.push(state, action, next_state, reward)
            state = next_state
            optimize_model()

            if done:
                memory.flush_buffer()
                episode_durations.append(t + 1)
                break
        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(net.state_dict())

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print('Complete epoch {}'.format(epoch), np.mean(episode_durations[-100:]))

env.close()
plot_durations()
plt.show()

Complete epoch 0 10.52
Complete epoch 1 9.49
Complete epoch 2 9.4
Complete epoch 3 9.58


KeyboardInterrupt: 