# DQN with Pytorch

You will use PyTorch to create a DQN (Deep Q-Network) for the CartPole gym environment.

Let's import the necessary libraries

# Google Colab
### Rendering Dependencies
If you don't have a GPU, you can use Google Colab instead, which keeps your own CPU from overheating during training:

1. Open this notebook in Colab https://colab.research.google.com/
2. Run the following snippet there


In [0]:
# Run this for google colab
# Installs gym and pyvirtualdisplay via pip, plus xvfb (a virtual X server),
# OpenGL bindings and ffmpeg via apt, then starts a hidden 1400x900 virtual
# display - presumably so gym can render frames without a physical monitor.
# NOTE(review): the rest of this notebook uses the older gym API
# (gym.wrappers.Monitor, 4-tuple env.step results, env.render(mode=...));
# newer gym releases changed these, so an older gym may need to be pinned
# here - confirm which version works.
# NOTE: the name `display` is rebound later by `from IPython import display`
# when the matplotlib backend is inline, so don't rely on it referring to the
# virtual display after that cell has run.
!pip install gym pyvirtualdisplay
!apt-get install -y xvfb python-opengl ffmpeg

from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

# Imports and Helper functions


In [0]:
import gym
from gym import logger as gymlogger
from gym.wrappers import Monitor
gymlogger.set_level(40) #error only
import numpy as np
import random

# We will use matplot to plot our progress during training
import matplotlib 
import matplotlib.pyplot as plt
%matplotlib inline
import math
import glob
import io
import base64
import collections

from IPython.display import HTML

from IPython import display as ipythondisplay

from itertools import count
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T    

In [0]:
# True when matplotlib draws inline (presumably: we are inside a Jupyter/Colab notebook).
is_ipython = matplotlib.get_backend().find('inline') != -1
if is_ipython:
    # IPython's display module provides clear_output(), which plot() uses to
    # refresh the training figure in place.
    from IPython import display

# Neural Network
The neural network in this case is pretty simple: it uses only fully connected layers.
Of course it could be replaced by something more complex like a CNN.

In [0]:
class DQN(nn.Module):
    """Small fully connected Q-network mapping a screen image to one Q-value per action.

    Args:
        img_height: Height in pixels of the input screens.
        img_width: Width in pixels of the input screens.
        num_actions: Number of discrete actions, i.e. number of output Q-values.
            Defaults to 2, the size of CartPole's action space.
        hidden_sizes: Pair ``(h1, h2)`` giving the widths of the two hidden
            layers. Defaults to ``(24, 32)``.
        in_channels: Number of colour channels in the input. Defaults to 3 (RGB).
    """

    def __init__(self, img_height, img_width, num_actions=2,
                 hidden_sizes=(24, 32), in_channels=3):
        super().__init__()
        hidden1, hidden2 = hidden_sizes

        # Layers are created in the same order as before (fc1, fc2, out), so the
        # random initialisation and state_dict keys are unchanged for the defaults.
        self.fc1 = nn.Linear(in_features=img_height * img_width * in_channels,
                             out_features=hidden1)
        self.fc2 = nn.Linear(in_features=hidden1, out_features=hidden2)
        self.out = nn.Linear(in_features=hidden2, out_features=num_actions)

    def forward(self, t):
        """Return Q-values of shape ``(batch, num_actions)``.

        ``t`` is flattened from dimension 1 onward, so it must contain
        ``in_channels * img_height * img_width`` values per batch element.
        """
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        return self.out(t)

# Replay Memory
We will use a deque as our replay memory. When a deque has a maximum length and is full, appending a new Experience to the end automatically removes the oldest element from the front.

In [0]:
# One transition stored in replay memory. `done` records whether the episode
# ended after taking this step; the training loop uses it to drop the
# bootstrapped next-state value for terminal transitions.
Experience = collections.namedtuple(
    'Experience', ['state', 'action', 'next_state', 'reward', 'done']
)

class ReplayMemory:
    """Fixed-capacity buffer of past experiences for experience replay.

    Backed by a ``collections.deque`` bounded to ``capacity`` items: once it is
    full, every new append evicts the oldest stored experience.
    """

    def __init__(self, capacity):
        self.memory = collections.deque([], capacity)

    def append(self, experience):
        """Store one experience, evicting the oldest one if the buffer is full."""
        self.memory.append(experience)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True once at least ``batch_size`` experiences are stored."""
        return not len(self.memory) < batch_size

# Screen
These are some helper functions to obtain a processed version of the current screen from the environment.

`get_screen()` returns a 4D tensor of shape (Batch, Color-Channel, Height, Width), which is the layout PyTorch expects.

Another reason for the processing is to make training faster: we crop away most of the empty background around the cart, which gives smaller images.

In [0]:
# Preprocessing pipeline applied by get_screen():
# tensor -> PIL image -> resize so the smaller edge is 40 px (bicubic) -> tensor.
# T.InterpolationMode.BICUBIC replaces the PIL constant `Image.CUBIC`, which was
# removed in Pillow 10; passing PIL int constants to torchvision is also deprecated.
resize = T.Compose([T.ToPILImage(),
                    T.Resize(40, interpolation=T.InterpolationMode.BICUBIC),
                    T.ToTensor()])


def get_cart_location(screen_width, environment=None):
    """Return the horizontal pixel coordinate of the middle of the cart.

    Args:
        screen_width: Width in pixels of the rendered screen.
        environment: CartPole environment to read the cart position from.
            Defaults to the notebook-global ``env`` for backward compatibility.

    Returns:
        int: Column of the cart centre in a screen ``screen_width`` pixels wide.
    """
    if environment is None:
        environment = env
    # The visible world spans [-x_threshold, x_threshold].
    world_width = environment.x_threshold * 2
    scale = screen_width / world_width
    return int(environment.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART


def get_screen(env, device):
    """Render ``env`` and return a processed screen as a (1, C, H, W) tensor on ``device``.

    The frame is cropped to rows 40%-80% of its height and to a window 60% of
    its width around the cart (clamped at the screen edges). It is then scaled
    by 1/255 to float32, passed through the module-level ``resize`` transform,
    and given a leading batch dimension.
    """
    # Returned screen requested by gym is 400x600x3, but is sometimes larger
    # such as 800x1200x3. Transpose it into torch order (CHW).
    screen = env.render(mode='rgb_array').transpose((2, 0, 1))
    # Cart is in the lower half, so strip off the top and bottom of the screen
    _, screen_height, screen_width = screen.shape
    screen = screen[:, int(screen_height * 0.4):int(screen_height * 0.8)]
    view_width = int(screen_width * 0.6)
    # Use the environment we were given, not whatever the global `env` is.
    cart_location = get_cart_location(screen_width, env)
    if cart_location < view_width // 2:
        slice_range = slice(view_width)
    elif cart_location > (screen_width - view_width // 2):
        slice_range = slice(-view_width, None)
    else:
        slice_range = slice(cart_location - view_width // 2,
                            cart_location + view_width // 2)
    # Keep a view_width-wide window of columns around the cart
    # (shifted inward when the cart is near either edge).
    screen = screen[:, :, slice_range]
    # Convert to float, rescale, convert to torch tensor
    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
    screen = torch.from_numpy(screen)
    # Resize, and add a batch dimension (BCHW)
    return resize(screen).unsqueeze(0).to(device)

Next, we'll define some helper functions for plotting our progress.

In [0]:
def plot(values, moving_avg_period):
    """Redraw figure 2 with the episode durations and their moving average.

    Also prints the episode count and latest moving-average value and, when
    running inline, clears the previous cell output so the plot refreshes.

    Args:
        values: Durations of the episodes finished so far.
        moving_avg_period: Window length of the moving average.
    """
    figure_number = 2
    plt.figure(figure_number)
    plt.clf()
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(values)

    smoothed = get_moving_average(moving_avg_period, values)
    plt.plot(smoothed)
    plt.pause(0.001)

    episode_count = len(values)
    latest_avg = smoothed[-1]
    print("Episode", episode_count, "\n",
          moving_avg_period, "episode moving avg:", latest_avg)
    if is_ipython:
        display.clear_output(wait=True)

def get_moving_average(period, values):
    """Return a NumPy array, as long as ``values``, of trailing ``period``-step means.

    The first ``period - 1`` entries, where a full window is not yet available,
    are 0. If there are fewer than ``period`` values, the result is all zeros.
    """
    series = torch.tensor(values, dtype=torch.float)
    if len(series) < period:
        return torch.zeros(len(series)).numpy()
    window_means = series.unfold(dimension=0, size=period, step=1).mean(dim=1)
    leading_zeros = torch.zeros(period - 1)
    return torch.cat((leading_zeros, window_means)).numpy()

# Hyperparameters
We set up the hyperparameters. 

Notice that the number of episodes `num_episodes` is very low because training takes some time.

If you have a GPU and some time at home you can try it out with a higher number like `1000`.

In [0]:
# --- Replay memory / optimisation ---
memory_size = 100000   # maximum number of experiences kept in replay memory
batch_size = 256       # experiences sampled from replay memory per update
lr = 0.001             # learning rate for the Adam optimizer
gamma = 0.999          # discount applied to the next state's max Q-value

# --- Epsilon-greedy exploration (used to compute epsilon in the training loop) ---
eps_start = 1
eps_end = 0.01
eps_decay = 0.001

# --- Schedule ---
target_update = 10     # target_net is refreshed when episode % target_update == 0
num_episodes = 120     # Set to a higher number like 1000 at home when you have more time

# Initialization


We set up the device which will be used by pytorch during training. If you have a CUDA GPU then it will be used.

In [0]:
# Run on the CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [0]:
# CartPole environment. `.unwrapped` gives the underlying environment object,
# whose attributes (x_threshold, state) get_cart_location reads directly.
# NOTE(review): this presumably also bypasses gym's episode time limit - confirm.
env = gym.make('CartPole-v0')
env = env.unwrapped
env.reset()

Let's use our `get_screen()` method to have a look at our cart pole environment

In [0]:
# Show the raw frame rendered by gym, then the processed tensor that
# get_screen() actually feeds the network (previously computed and discarded).
processed_screen = get_screen(env, device)
screen = env.render('rgb_array')

plt.figure()
plt.imshow(screen)
plt.title('Non-processed screen example')
plt.show()

plt.figure()
# get_screen returns (1, C, H, W); imshow wants (H, W, C) on the CPU.
plt.imshow(processed_screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')
plt.title('Processed screen example')
plt.show()

In [0]:
# Replay buffer holding at most `memory_size` experiences; the oldest are evicted first.
memory = ReplayMemory(capacity=memory_size)

We need the screen height and width to initialize both of our networks.

In [0]:
# Capture one processed frame only to read off the spatial size of the network input.
screen = get_screen(env, device)
# get_screen returns a 4D (batch, channel, height, width) tensor; keep only H and W.
_, _, screen_height, screen_width = screen.shape

Next we set up the policy network and a target network as copy of the policy network. To do that we need the height and width of a single screen.

In [0]:
# Two networks with identical architecture (policy_net is created first):
# policy_net is trained, and target_net supplies the bootstrap targets and is
# refreshed from policy_net during training.
policy_net, target_net = [DQN(screen_height, screen_width).to(device) for _ in range(2)]

Using Pytorch the learnable parameters like weights and biases are under the model's parameters.

A state_dict is a dictionary object that maps each parameter name (e.g. `fc1.weight`) to its tensor. We copy the state_dict of the policy_net to the target_net.

In [0]:
# Start both networks from identical weights by copying policy_net's parameters into target_net.
target_net.load_state_dict(policy_net.state_dict())

We put the target_net into evaluation mode. We only train the policy_net; the target_net is updated manually later by copying the policy_net's weights:

In [0]:
# Put target_net in evaluation mode. DQN contains only Linear layers, so this
# mainly documents intent: target_net is never optimised directly, and its
# weights change only through load_state_dict.
target_net.eval()

We will use Adam for the optimization process. It combines AdaGrad and RMSProp:

In [0]:
# Only policy_net's parameters are optimised; target_net is updated by copying weights.
optimizer = optim.Adam(policy_net.parameters(), lr=lr)

# Some python particularities
Before we start training, let's look at some Python-specific features we will use.

## Use * to pass arguments to a function
Putting an asterisk `*` in front of an iterable (a list, for example) passes each element of the iterable to a function as a separate positional argument. This is very useful when you don't know in advance how many elements the list has.

Have a look at the following example:

In [0]:
test = "hi what's up?".split()
print(test)

# Both calls below print the same line: `*test` passes each list element
# to print() as its own positional argument.
print(test[0], test[1], test[2])
print(*test)

## zip
`zip` takes several iterables (for example, lists) and returns a zip object, which is an iterator of tuples. Elements that share the same index in the different iterables are grouped together into one tuple.

In [0]:
a = list("abc")
b = list("123")

# zip pairs elements by position into a lazy iterator of tuples:
# ("a", "1"), ("b", "2"), ("c", "3")
x = zip(a, b)

To unpack the content of the zip object, we can use the asterisk again.

In [0]:
# Pass every tuple in the zip iterator to print() as a separate argument.
# zip objects are one-shot iterators: re-running this cell without recreating
# `x` prints an empty line.
print(*x)

In [0]:
# A toy batch of three experiences, each ordered like Experience:
# (state, action, next_state, reward, done)
a = [
    ([1, 2, 3], 2, [6, 5, 4], 4, False),
    ([5, 4, 3], 0, [6, 5, 4], 4, False),
    ([2, 1, 2], 0, [5, 10, 9], 4, True),
]

# zip(*a) transposes the batch into one tuple per field (all states, all
# actions, ...); Experience(*...) then gives each of those tuples its field name.
e = Experience(*zip(*a))
print("experience of batch-arrays", e)

states, actions, next_states, rewards, dones = e
print("states", states)

Notice how the states of all batch elements are now grouped together in the single tuple `states`. The same is true for the actions, next_states, rewards and dones.

We will use that later in our algorithm.

# Difference of the last two frames

To make it easier for the network to differentiate between states we will create a state as the difference of the last two frames later. Let's have a look at the following example:

In [0]:
frame1 = get_screen(env, device)
frame2 = get_screen(env, device)

# No env.step() happened between the two captures, so the frames cancel out.
difference = frame1 - frame2

plt.figure()
# Take batch element 0 and move channels last, (C, H, W) -> (H, W, C), for imshow.
plt.imshow(difference[0].cpu().permute(1, 2, 0), interpolation='none')
plt.title('Starting state example')
plt.show()

In this case we get a black screen, because frame1 and frame2 are the same.

Let's reset the environment in-between. This should get us a non black image.

In [0]:
frame1 = get_screen(env, device)
# Start a fresh episode between the two captures, so the cart/pole position
# presumably changes and the difference is no longer all zeros.
env.reset()
frame2 = get_screen(env, device)

difference = frame1 - frame2

plt.figure()
# Take batch element 0 and move channels last, (C, H, W) -> (H, W, C), for imshow.
plt.imshow(difference[0].cpu().permute(1, 2, 0), interpolation='none')
plt.title('Starting state example')
plt.show()

Now we can see only the parts that are truly different. This will make it easier for the network to take the movement into account. Those states are what we feed the network with.

# Training

1. Have a look at the training algorithm first
2. Try to implement the epsilon-greedy algorithm to select an action

In [0]:
# used to calculate the epsilon value
total_steps = 0

# necessary for plotting
episode_durations = []

for episode in range(num_episodes):
    env.reset()

    # In the beginning the screen is black
    current_screen = get_screen(env, device)
    black_screen = torch.zeros_like(current_screen)

    state = black_screen # start with black screen

    for timestep in count():
      # take action depending on epsilon
      epsilon = ... # calculate epsilon
      total_steps += 1
      if epsilon > random.random():
        action = random.randrange(env.action_space.n)
        action = torch.tensor([action]).to(device) # explore
      else:
          with torch.no_grad():
            action = policy_net(state).argmax(dim=1).to(device) # exploit

      _, reward, done, _ = env.step(action.item())
      reward = torch.tensor([reward], device=device)

      next_state = None
      if not done:
        s1 = current_screen
        s2 = get_screen(env, device)
        current_screen = s2

        # The next state is the difference of the frames s2 and s1
        next_state = s2 - s1
      else:
        # black screen if we are done
        next_state = black_screen
      
      is_done = torch.tensor([done], dtype=torch.bool, device=device)
      memory.append(Experience(state, action, next_state, reward, is_done))
      state = next_state

      if memory.can_provide_sample(batch_size):
        experiences = memory.sample(batch_size)

        states, actions, next_states, rewards, dones = Experience(*zip(*(experiences)))
        states = torch.cat(states)
        actions = torch.cat(actions)
        next_states = torch.cat(next_states)
        rewards = torch.cat(rewards)
        done_mask = torch.cat(dones)
        
        # calculate Q(s,a)
        q_values = policy_net(states).gather(dim=1, index=actions.unsqueeze(-1))

        # calculate Q_target(s_a)
        next_q_values = target_net(next_states).detach() # detach --> no gradient will be backproped for next_q_values
        max_next_q_values = next_q_values.max(dim=1).values
        # target = r if episode ended, that's why we set to 0 for states after which episode ended
        max_next_q_values[done_mask] = 0.0 #  next_states after which are done should NOT be considered
        
        # target = r + y * maxQ_target(s', a')
        target = rewards + gamma * max_next_q_values  

        loss = F.mse_loss(q_values, target.unsqueeze(1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
      if done:
        episode_durations.append(timestep)
        plot(episode_durations, 100)
        break
      if episode % target_update == 0:
        target_net.load_state_dict(policy_net.state_dict())
      env.close()

# References: 
Pytorch Documentation: https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html

DeepLizard: https://deeplizard.com/

Deep Reinforcement Learning Hands-On (Packt), Chapter 6: https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On/tree/master/Chapter06