In [12]:
import numpy as np
import gym
import math
import random
from collections import namedtuple
from itertools import count
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.autograd import Variable

%matplotlib inline

In [13]:
# Gym environment the agent is trained on.
env = gym.make('CartPole-v1')
# All networks and tensors live on the GPU when CUDA is available, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [14]:
# One step of experience; every field is stored exactly as passed to ``push``.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))


class ReplayMemory:
    """Fixed-capacity ring buffer of ``Transition`` records.

    Once ``capacity`` records are stored, each new record overwrites the
    oldest one.

    Args:
        capacity: maximum number of transitions kept at any time.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = list()
        # Index of the slot the next ``push`` will write to.
        self.position = 0

    def push(self, *args):
        """Store one transition.

        Args:
            *args: the four ``Transition`` fields, in order
                ``(state, action, next_state, reward)``.

        Raises:
            TypeError: if the wrong number of fields is supplied. The buffer
                is left untouched in that case.
        """
        # Build the record *before* touching the buffer: if the arguments are
        # wrong, no ``None`` placeholder is left behind to poison ``sample``.
        transition = Transition(*args)

        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition

        # Advance the write cursor, wrapping around at capacity.
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)
        

In [15]:
class NoisyLinear(nn.Module):
    """Linear layer whose weights and biases carry learnable, factorised noise.

    The effective parameters used by ``forward`` are
    ``mu + sigma * epsilon``, where ``mu`` and ``sigma`` are trainable and
    ``epsilon`` is a noise buffer that stays fixed until ``reset_noise`` is
    called.

    The layer does not depend on any module-level ``device`` global: tensors
    are created on ``device`` (PyTorch's default when ``None``) and noise is
    always sampled on whatever device the layer's parameters currently live
    on, so ``layer.to(...)`` keeps everything consistent.

    Args:
        in_sz: number of input features.
        out_sz: number of output features.
        std_init: scale used to initialise the ``sigma`` parameters.
        device: optional device to allocate parameters and buffers on.
    """

    def __init__(self, in_sz, out_sz, std_init=0.1, device=None):
        super(NoisyLinear, self).__init__()
        self.in_sz = in_sz
        self.out_sz = out_sz
        self.std_init = std_init

        self.weight_mu = nn.Parameter(torch.empty(out_sz, in_sz, device=device))
        self.weight_sigma = nn.Parameter(torch.empty(out_sz, in_sz, device=device))
        # Buffers: saved in state_dict and moved by .to(), but not trained.
        self.register_buffer('weight_epsilon', torch.empty(out_sz, in_sz, device=device))

        self.bias_mu = nn.Parameter(torch.empty(out_sz, device=device))
        self.bias_sigma = nn.Parameter(torch.empty(out_sz, device=device))
        self.register_buffer('bias_epsilon', torch.empty(out_sz, device=device))

        # Means: uniform in +-1/sqrt(in_sz). Sigmas: constant std_init/sqrt(n),
        # with n = in_sz for weights and n = out_sz for biases.
        mu_range = 1 / math.sqrt(self.in_sz)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.std_init / math.sqrt(self.in_sz))
        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.std_init / math.sqrt(self.out_sz))

        self.reset_noise()

    def reset_noise(self):
        """Draw a fresh noise sample into the ``*_epsilon`` buffers (in place)."""
        epsilon_in = self._scale_noise(self.in_sz)
        epsilon_out = self._scale_noise(self.out_sz)
        # Outer product of the two noise vectors, written as a broadcasted
        # multiply instead of the deprecated Tensor.ger.
        self.weight_epsilon.copy_(epsilon_out.unsqueeze(1) * epsilon_in.unsqueeze(0))
        self.bias_epsilon.copy_(epsilon_out)

    def _scale_noise(self, size):
        """Return ``size`` samples of ``sign(x) * sqrt(|x|)`` with ``x ~ N(0, 1)``."""
        # Sample on the layer's own device rather than a notebook global.
        x = torch.randn(size, device=self.weight_mu.device)
        return x.sign().mul_(x.abs().sqrt_())

    def forward(self, input):
        """Apply the affine map using the currently sampled noisy parameters."""
        noisy_weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
        noisy_bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        return F.linear(input, noisy_weight, noisy_bias)

In [16]:
class DQN(nn.Module):
    """Two-layer Q-network built from ``NoisyLinear`` layers.

    Args:
        outputs: number of Q-values produced per input row (one per action).
        inputs: width of each input row. Defaults to 4, the value that was
            previously hard-coded.
        hidden: width of the hidden layer. Defaults to 320, the value that
            was previously hard-coded.
    """

    def __init__(self, outputs, inputs=4, hidden=320):
        super(DQN, self).__init__()
        self.first = NoisyLinear(inputs, hidden)
        self.second = NoisyLinear(hidden, outputs)

    def forward(self, x):
        """Return one Q-value per action for each row of ``x``."""
        x = F.relu(self.first(x))
        x = self.second(x)
        return x

    def reset_noise(self):
        """Resample the noise buffers of both noisy layers."""
        self.first.reset_noise()
        self.second.reset_noise()

In [17]:
# Training hyper-parameters.
BATCH_SIZE = 128        # transitions sampled from replay per optimisation step
GAMMA = 0.999           # discount applied to the bootstrapped next-state value
MEMORY_SIZE = 10 ** 5   # replay buffer capacity (100000 transitions)
TARGET_UPDATE = 10      # copy net -> target_net every this many episodes

In [18]:
# Global training state shared by the functions and the training loop below.
n_actions = env.action_space.n
steps_done = 0                 # total number of select_action calls so far
episode_durations = list()     # length (in steps) of every finished episode

net = DQN(n_actions).to(device)          # online network, trained by the optimizer
target_net = DQN(n_actions).to(device)   # network used to compute TD targets
# Start the target network as an exact copy of the online network. Without
# this, it holds unrelated random weights until the first periodic sync,
# and any optimisation step taken before then would regress towards targets
# from a different network.
target_net.load_state_dict(net.state_dict())
optimizer = optim.Adam(net.parameters(), lr=0.0001)
memory = ReplayMemory(MEMORY_SIZE)

In [19]:
def plot_durations():
    """Plot per-episode durations and, once available, their 100-episode mean.

    Reads the module-level ``episode_durations`` list and draws into
    matplotlib figure number 2, clearing whatever was there before.
    """
    window = 100

    plt.figure(2)
    plt.clf()
    durations = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations.numpy())

    # Rolling mean over `window` episodes, left-padded with zeros so that it
    # lines up with the raw curve on the x axis.
    if len(durations) >= window:
        rolling = durations.unfold(0, window, 1).mean(1).view(-1)
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    # Short pause so that the figure gets a chance to refresh.
    plt.pause(0.001)

In [20]:
def select_action(state):
    """Return the greedy action of ``net`` for ``state`` as a ``(1, 1)`` tensor.

    Exploration comes from the network itself: its noise is resampled before
    every decision. Also increments the global ``steps_done`` counter.
    """
    global steps_done
    steps_done = steps_done + 1

    with torch.no_grad():
        net.reset_noise()
        q_values = net(state)
        # max over dim 1 returns (values, indices); the indices are the actions.
        _, greedy = q_values.max(1)
        return greedy.view(1, 1)

In [21]:
def optimize_model():
    """Run one gradient step on ``net`` from a random replay batch.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    The loss is half the mean squared difference between the TD target and
    the Q-value of the action that was actually taken.
    """
    if len(memory) < BATCH_SIZE:
        return

    sampled = memory.sample(BATCH_SIZE)
    # Transpose: list of Transitions -> one Transition of per-field tuples.
    columns = Transition(*zip(*sampled))

    states = torch.cat(columns.state)
    actions = torch.cat(columns.action)
    rewards = torch.cat(columns.reward)
    next_states = torch.cat(columns.next_state)

    # Fresh noise for the online network, then Q(s, a) of the stored actions.
    net.reset_noise()
    q_taken = net(states).gather(1, actions)

    with torch.no_grad():
        # NOTE(review): the next action is chosen with target_net (using its
        # current noise) and then evaluated with target_net again after a
        # resample. Double DQN would choose it with `net` -- confirm intent.
        _, best_next = target_net(next_states).max(1)
        best_next = best_next.unsqueeze(1)
        target_net.reset_noise()
        q_next = target_net(next_states).gather(1, best_next).squeeze(1)

        # NOTE(review): no terminal-state mask -- the next-state value is
        # bootstrapped for every transition, including the last of an episode.
        td_target = ((q_next * GAMMA) + rewards).unsqueeze(1)

    td_error = td_target - q_taken
    loss = td_error.pow(2).div(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Training schedule: number of episodes per epoch. Episodes of the LAST epoch
# are rendered; with a trailing 0 no episode runs there, so nothing is
# rendered as written -- set the last entry > 0 to watch the trained agent.
num_episodes = [50] * 10 + [0]
last_epoch = len(num_episodes) - 1
# Episodes that end at exactly this many steps are not penalised below
# (presumably the environment's time limit -- confirm against the env spec).
MAX_EPISODE_STEPS = 500

# Assumes the classic Gym API used throughout this notebook: reset() returns
# the observation and step() returns (observation, reward, done, info).
for epoch in range(len(num_episodes)):
    render = epoch == last_epoch  # loop-invariant within an epoch
    for i_episode in range(num_episodes[epoch]):
        # Convert the observation array directly and add the batch dimension;
        # wrapping it in a Python list first takes PyTorch's slow
        # list-of-ndarrays path on every step.
        state = torch.tensor(env.reset(), dtype=torch.float, device=device).unsqueeze(0)
        for t in count():
            action = select_action(state)
            if render:
                env.render()
            next_state, reward, done, _ = env.step(action.item())
            next_state = torch.tensor(next_state, dtype=torch.float, device=device).unsqueeze(0)
            # Penalise episodes that end before the step limit.
            if done and t + 1 != MAX_EPISODE_STEPS:
                reward = -10
            reward = torch.tensor([reward], dtype=torch.float, device=device)

            memory.push(state, action, next_state, reward)
            state = next_state
            optimize_model()

            if done:
                episode_durations.append(t + 1)
                break
        # Periodically sync the target network with the online network.
        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(net.state_dict())

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Progress report: mean duration over (up to) the last 200 episodes.
    print('Complete epoch {}'.format(epoch), np.mean(episode_durations[-200:]))

env.close()
plot_durations()
plt.show()

Complete epoch 0 15.28
Complete epoch 1 23.24
Complete epoch 2 36.98
Complete epoch 3 75.77
