In [1]:
from collections import namedtuple
import gym
from itertools import count
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

%matplotlib inline

In [6]:
# Set up display: when matplotlib's backend name contains 'inline' we are
# running inside IPython/Jupyter, so import IPython's `display` module for
# use by later cells.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

Define the deep Q-network

In [5]:
class DQN(nn.Module):
    """Fully connected deep Q-network producing one Q-value per action.

    Each input sample is flattened and passed through two ReLU hidden layers
    (24 and 32 units) followed by a linear output layer.

    Args:
        img_height: Height of the input image in pixels.
        img_width: Width of the input image in pixels.
        num_actions: Number of Q-values produced per sample (one per action).
            Defaults to 2, the previously hard-coded output size.
    """

    def __init__(self, img_height, img_width, num_actions=2):
        super().__init__()

        # Each sample must contain img_height * img_width * 3 elements
        # (presumably a 3-channel image — TODO confirm the preprocessing).
        self.fc1 = nn.Linear(in_features=img_height*img_width*3, out_features=24)
        self.fc2 = nn.Linear(in_features=24, out_features=32)
        # in_features must equal fc2's out_features (32), otherwise forward()
        # fails with a matrix shape mismatch.
        self.out = nn.Linear(in_features=32, out_features=num_actions)

    def forward(self, t):
        """Return Q-values of shape (batch, num_actions) for the batch ``t``.

        All dimensions after the first (batch) dimension are flattened, so
        only the per-sample element count matters, not the exact layout.
        """
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t

Experiences from replay memory will be used to train the network

In [7]:
# Immutable record of one stored transition: the state, the action taken,
# the resulting next state, and the reward received.
Experience = namedtuple('Experience', 'state action next_state reward')

Replay memory stores past experiences in a fixed-capacity buffer; once it is full, each new experience overwrites the oldest one.

In [9]:
class ReplayMemory():
    """Fixed-capacity buffer of past experiences for experience replay.

    Experiences are appended until ``capacity`` is reached; after that each
    new experience overwrites the oldest stored one (ring-buffer behaviour).

    Args:
        capacity: Maximum number of experiences kept. Must be >= 1.

    Raises:
        ValueError: If ``capacity`` is less than 1.
    """

    def __init__(self, capacity):
        if capacity < 1:
            # A non-positive capacity would otherwise make push() fail later
            # with ZeroDivisionError (capacity == 0) or IndexError (< 0).
            raise ValueError(f"capacity must be >= 1, got {capacity}")
        self.capacity = capacity
        self.memory = []
        # Total number of pushes ever made; selects the slot to overwrite.
        self.push_count = 0

    def push(self, experience):
        """Store ``experience``, overwriting the oldest entry once full."""
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            # Once full, push_count % capacity cycles 0, 1, ..., capacity-1,
            # which is always the index of the oldest stored experience.
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size):
        """Return ``batch_size`` stored experiences drawn without replacement.

        Raises:
            ValueError: (from ``random.sample``) if ``batch_size`` exceeds the
                number of stored experiences; check ``can_provide_sample``
                first.
        """
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True if at least ``batch_size`` experiences are stored."""
        return len(self.memory) >= batch_size

The epsilon-greedy strategy determines whether the agent explores (takes a random action) or exploits (takes the action the network rates highest) at each training step.

In [10]:
class EpsilonGreedyStrategy():
    """Exponentially decaying exploration rate for epsilon-greedy selection.

    The rate at step ``n`` is ``end + (start - end) * exp(-decay * n)``: it
    equals ``start`` at step 0 and, for positive ``decay``, approaches ``end``
    as the step count grows.
    """

    def __init__(self, start, end, decay):
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        """Return the exploration probability to use at ``current_step``."""
        exponent = -1. * current_step * self.decay
        span = self.start - self.end
        return self.end + span * math.exp(exponent)

Reinforcement learning agent

In [12]:
class Agent():
    """Reinforcement-learning agent that picks actions epsilon-greedily.

    Args:
        strategy: Object exposing ``get_exploration_rate(step)`` (e.g. an
            ``EpsilonGreedyStrategy``) that gives the probability of exploring.
        num_actions: Number of discrete actions; random actions are drawn
            from ``range(num_actions)``.
    """

    def __init__(self, strategy, num_actions):
        self.strategy = strategy
        self.num_actions = num_actions
        # Number of actions selected so far; drives the exploration schedule.
        self.current_step = 0

    def select_action(self, state, policy_net):
        """Return an action index for ``state`` and advance the step counter.

        With probability given by the strategy, a uniformly random action is
        returned (explore); otherwise the argmax of ``policy_net(state)`` along
        dim 1 is returned (exploit). ``.item()`` requires that argmax to hold a
        single element, so ``state`` presumably is a batch of one — TODO confirm.
        """
        epsilon = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if epsilon > random.random():
            # Explore: uniformly random action.
            return random.randrange(self.num_actions)

        # Exploit: inference only, so skip building an autograd graph.
        with torch.no_grad():
            q_values = policy_net(state)
            best_action = q_values.argmax(dim=1)
        return best_action.item()