In [1]:
from collections import namedtuple
import gym
from itertools import count
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

%matplotlib inline

In [6]:
# Display setup: the IPython display helpers are only imported when the
# active matplotlib backend name contains 'inline'.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

Define deep Q-network

In [5]:
class DQN(nn.Module):
    """Fully connected deep Q-network operating on flattened screen images.

    Each sample is flattened to ``img_height * img_width * 3`` features and
    passed through two ReLU hidden layers (24 and 32 units) followed by a
    linear output layer with 2 units (presumably one Q-value per CartPole
    action -- confirm against the environment's action space).

    Args:
        img_height: Height of the input images in pixels.
        img_width: Width of the input images in pixels.
    """

    def __init__(self, img_height, img_width):
        super().__init__()

        self.fc1 = nn.Linear(in_features=img_height*img_width*3, out_features=24)
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        # The output layer consumes fc2's activations, so its in_features
        # must equal fc2's out_features (32); any other value makes forward()
        # fail with a matrix-shape mismatch.
        self.out = nn.Linear(in_features=32, out_features=2)

    def forward(self, t):
        """Return the network output for a batch of images.

        Args:
            t: Batched input tensor; everything after the first (batch)
                dimension is flattened into a single feature vector.

        Returns:
            Tensor with one row per sample and 2 columns.
        """
        t = t.flatten(start_dim=1)  # keep the batch dimension, merge the rest
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)  # raw linear outputs, no activation
        return t

Experiences from replay memory will be used to train the network

In [7]:
# Immutable record type with the fields state, action, next_state and reward
# (in that positional order).
Experience = namedtuple(
    'Experience',
    ['state', 'action', 'next_state', 'reward'],
)

Replay memory will store experiences

In [9]:
class ReplayMemory():
    """Bounded buffer of experiences that recycles slots once it is full."""

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` items."""
        self.capacity = capacity
        self.memory = []
        self.push_count = 0

    def push(self, experience):
        """Store ``experience``, overwriting an existing slot when full."""
        has_room = len(self.memory) < self.capacity
        if has_room:
            self.memory.append(experience)
        else:
            # Buffer is full: wrap around and replace entries starting from
            # the front, so the oldest experiences are overwritten first.
            slot = self.push_count % self.capacity
            self.memory[slot] = experience
        self.push_count += 1

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored experiences chosen at random."""
        stored = self.memory
        return random.sample(stored, batch_size)

    def can_provide_sample(self, batch_size):
        """Tell whether at least ``batch_size`` experiences are stored."""
        stored_count = len(self.memory)
        return stored_count >= batch_size

The epsilon-greedy strategy sets the exploration rate, which determines whether the agent explores or exploits at each step during training

In [10]:
class EpsilonGreedyStrategy():
    """Exploration-rate schedule that decays exponentially with the step count."""

    def __init__(self, start, end, decay):
        """Remember the initial rate, the limiting rate and the decay constant."""
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        """Return ``end + (start - end) * exp(-current_step * decay)``."""
        span = self.start - self.end
        decay_factor = math.exp(-1. * current_step * self.decay)
        return self.end + span * decay_factor

Reinforcement learning agent

In [12]:
class Agent():
    """Chooses actions by mixing random exploration with greedy exploitation."""

    def __init__(self, strategy, num_actions, device):
        """Store the exploration strategy, the action count and the target device."""
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device # i.e. GPU or CPU

    def select_action(self, state, policy_net):
        """Return a one-element action tensor on ``self.device``.

        The exploration rate for the current step comes from the strategy
        (see EpsilonGreedyStrategy); the step counter is advanced on every
        call, whichever branch is taken.
        """
        exploration_rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        explore = exploration_rate > random.random()
        if not explore:
            # Exploit: the network is only used for inference here, so
            # gradient tracking is switched off.
            with torch.no_grad():
                return policy_net(state).argmax(dim=1).to(self.device)

        # Explore: pick any of the available actions uniformly at random.
        random_action = random.randrange(self.num_actions)
        return torch.tensor([random_action]).to(self.device)

Environment manager

In [27]:
class CartPoleEnvManager():
    """Wraps the gym 'CartPole-v0' environment and exposes screen-based states.

    A state is the difference between two consecutive processed screens; an
    all-zero screen is returned at the start of an episode and once the
    episode has ended.
    """

    def __init__(self, device):
        """Create and reset the unwrapped environment.

        Args:
            device: torch device on which reward and screen tensors are placed.
        """
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped
        self.env.reset()
        self.current_screen = None  # None marks "no screen captured yet"
        self.done = False

    def reset(self):
        """Start a new episode and clear all per-episode state."""
        self.env.reset()
        self.current_screen = None
        # Clear the end-of-episode flag as well; otherwise it would keep the
        # previous episode's value until the next call to take_action().
        self.done = False

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def render(self, mode='human'):
        """Render the environment in ``mode`` and return whatever gym returns."""
        return self.env.render(mode)

    def num_actions_available(self):
        """Return the size of the environment's action space."""
        return self.env.action_space.n

    def take_action(self, action):
        """Step the environment with ``action`` and return the reward.

        Args:
            action: Tensor holding a single action; ``action.item()`` is
                passed to the environment.

        Returns:
            One-element tensor containing the reward, on ``self.device``.
            ``self.done`` is updated from the step result as a side effect.
        """
        _, reward, self.done, _ = self.env.step(action.item())
        return torch.tensor([reward], device=self.device)

    def just_starting(self):
        """Return True while no screen has been captured since the last reset."""
        return self.current_screen is None

    def get_state(self):
        """Return the current state as a screen tensor.

        At the start of an episode, or once it is done, the processed screen
        is cached and an all-zero tensor of the same shape is returned.
        Otherwise the result is the new processed screen minus the previously
        cached one.
        """
        if self.just_starting() or self.done:
            self.current_screen = self.get_processed_screen()
            black_screen = torch.zeros_like(self.current_screen)
            return black_screen
        else:
            s1 = self.current_screen
            s2 = self.get_processed_screen()
            self.current_screen = s2
            return s2 - s1

    def get_screen_height(self):
        """Return the height (dimension 2) of a processed screen."""
        screen = self.get_processed_screen()
        return screen.shape[2]

    def get_screen_width(self):
        """Return the width (dimension 3) of a processed screen."""
        screen = self.get_processed_screen()
        return screen.shape[3]

    def get_processed_screen(self):
        """Render an RGB frame, then crop and transform it into a tensor."""
        screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
        screen = self.crop_screen(screen)
        return self.transform_screen_data(screen)

    def crop_screen(self, screen):
        """Keep only the rows between 40% and 80% of the screen height.

        Args:
            screen: Array indexed as (channel, row, column).
        """
        screen_height = screen.shape[1]

        # Strip off top and bottom
        top = int(screen_height * 0.4)
        bottom = int(screen_height * 0.8)
        screen = screen[:, top:bottom, :]
        return screen

    def transform_screen_data(self, screen):
        """Rescale, resize and batch a cropped screen.

        The array is converted to float32 and divided by 255, turned into a
        tensor, resized to 40x90 through a PIL round trip, given a leading
        batch dimension and moved to ``self.device``.
        """
        # Convert to float, rescale, convert to tensor
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
        screen = torch.from_numpy(screen)

        # Use torchvision package to compose image transforms
        resize = T.Compose([
            T.ToPILImage()
            ,T.Resize((40,90))
            ,T.ToTensor()
        ])

        return resize(screen).unsqueeze(0).to(self.device) # add a batch dimension (BCHW)

Non-processed screen example

In [28]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the environment manager, start a fresh episode and grab the raw
# (non-processed) screen from the renderer.
em = CartPoleEnvManager(device)
em.reset()
screen = em.render(mode='rgb_array')

# Plot of the raw screen (currently disabled):
# plt.figure()
# plt.imshow(screen)
# plt.title('Non-processed screen example')
# plt.show()

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/Users/thomasmartin/.local/share/virtualenvs/INM707-deep-learning-Rm3VDUC3/lib/python3.7/site-packages/pyglet/__init__.py", line 329, in __getattr__
    return getattr(self._module, name)
AttributeError: 'NoneType' object has no attribute 'get_script_home'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/thomasmartin/.local/share/virtualenvs/INM707-deep-learning-Rm3VDUC3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-28-977a3fe71549>", line 4, in <module>
    screen = em.render('rgb_array')
  File "<ipython-input-27-72f4e2dd1c78>", line 18, in render
    return self.env.render(mode)
  File "/Users/thomasmartin/.local/share/virtualenvs/INM707-deep-learning-Rm3VDUC3/lib/python3.7/site-packages/gym/envs/classic_control/cartpole.py", line 150, in render
    from g

TypeError: can only concatenate str (not "list") to str