In [None]:
import os
import gym
import time
import copy
import random
import numpy as np

import torch
import torchvision
import torch.nn as nn

from collections import deque
from skimage.color import rgb2grey
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

In [None]:
class DeepQNetwork(nn.Module):
    """Convolutional Q-network: a stack of greyscale frames -> one Q-value per action.

    Architecture: four 3x3 / stride-2 / padding-1 convolutions (32, 64, 128, 256
    channels, ReLU after each), then a 512-unit ReLU fully connected layer and a
    linear output layer with ``num_actions`` units.

    Args:
        num_frames: number of stacked frames, i.e. input channels.
        num_actions: number of discrete actions, i.e. output width.
        frame_size: (height, width) of each input frame. The default (160, 160)
            reproduces the previously hard-coded ``in_features=25600``
            (256 channels x 10 x 10 after four halvings).

    Input ``x`` is expected as (batch, num_frames, height, width); ``forward``
    returns raw Q-values of shape (batch, num_actions). No softmax is applied:
    Q-values are unbounded action values (needed for Bellman targets), and
    argmax-based action selection is unaffected by dropping the normalisation.
    """

    def __init__(self, num_frames, num_actions, frame_size=(160, 160)):
        super(DeepQNetwork, self).__init__()

        # Layers. Each conv has kernel 3, stride 2, padding 1, so every spatial
        # dimension n becomes floor((n - 1) / 2) + 1 == (n + 1) // 2.
        self.conv1 = nn.Conv2d(in_channels=num_frames, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)

        # Derive the flattened conv output size instead of hard-coding it.
        height, width = frame_size
        for _ in range(4):
            height, width = (height + 1) // 2, (width + 1) // 2

        self.fc1 = nn.Linear(in_features=256 * height * width, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_actions)

        # Activations
        self.relu = nn.ReLU()

    def flatten(self, x):
        """Flatten every axis except the leading batch axis."""
        # Keeping dim 0 lets a replay minibatch pass through fc1 as (batch, features).
        return x.flatten(start_dim=1)

    def forward(self, x):
        """Return Q-values of shape (batch, num_actions) for x of shape (batch, num_frames, H, W)."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        # Linear output: raw action values, not probabilities.
        return self.fc2(x)

In [None]:
class Agent:
    """ε-greedy agent wrapping a Q-network, a bounded replay memory and an ε schedule.

    Args:
        DQN: the Q-network (a ``torch.nn.Module``) queried by ``act``.
        memory_depth: maximum number of stored transitions; the oldest are evicted first.
        epsilon_i: exploration rate at step 0.
        epsilon_f: exploration rate from step ``anneal_time`` onwards.
        anneal_time: number of steps over which ε decays linearly from epsilon_i to epsilon_f.
    """

    def __init__(self, DQN, memory_depth, epsilon_i, epsilon_f, anneal_time):
        self.DQN = DQN
        self.memory_depth = memory_depth
        self.e_i = epsilon_i
        self.e_f = epsilon_f
        self.anneal_time = anneal_time

        # deque(maxlen=...) silently drops the oldest transition once full.
        self.memory = deque(maxlen=memory_depth)

    def clone(self, model):
        """Return an independent deep copy of ``model``."""
        return copy.deepcopy(model)

    def remember(self, state, action, reward, terminal, next_state):
        """Store one transition as [state, action, reward, terminal, next_state]."""
        self.memory.append([state, action, reward, terminal, next_state])

    def retrieve(self, batch_size):
        """Sample up to ``batch_size`` distinct transitions uniformly, without replacement.

        If fewer than ``batch_size`` transitions are stored, all of them are returned
        (in random order).
        """
        return random.sample(self.memory, min(batch_size, self.memories))

    @property
    def memories(self):
        """Number of transitions currently stored."""
        return len(self.memory)

    def _device(self):
        """Device holding the Q-network's parameters (CPU if it has none)."""
        param = next(self.DQN.parameters(), None)
        return param.device if param is not None else torch.device('cpu')

    def act(self, state):
        """Return the index (a Python int) of the highest Q-value for ``state``.

        Inference only: no autograd graph is built. ``.item()`` copies the result to
        the host, so this works whether the network lives on CPU or CUDA (calling
        ``.numpy()`` directly on a CUDA tensor raises).
        """
        with torch.no_grad():
            q_values = self.DQN(state)
        # argmax over the flattened tensor, matching np.argmax on the flattened array.
        return int(torch.argmax(q_values).item())

    def process(self, state):
        """Turn a raw RGB frame into a 4-D float32 tensor (1, 1, H, W) on the network's device.

        Crops rows 35:195, converts to greyscale and adds leading batch and channel axes.
        """
        grey = rgb2grey(state[35:195, :, :])
        grey = grey[np.newaxis, np.newaxis, :, :]
        # Follow the model's device rather than the notebook-global `to_tensor`, whose
        # device is chosen independently of where the network's weights live.
        return torch.as_tensor(grey, dtype=torch.float32, device=self._device())

    def exploration_rate(self, t):
        """Linearly annealed ε: epsilon_i at t=0, epsilon_f from t=anneal_time onwards."""
        if t >= self.anneal_time:
            return self.e_f
        return self.e_i - t * (self.e_i - self.e_f) / self.anneal_time

In [None]:
# Hyperparameters

# -- Environment / input --
num_actions = 4       # width of the Q-network output; also the range for random actions
num_frames = 4        # frames stacked into one state (network input channels)

# -- Run length --
episodes = 100
update_interval = 40  # steps between calls of the (currently empty) periodic-update hook

# -- Replay memory --
memory_depth = 10 ** 5  # 100000 transitions

# -- ε-greedy exploration --
# NOTE(review): epsilon_i == epsilon_f == 0.0 disables exploration entirely (the agent is
# always greedy and anneal_time has no effect) — confirm this is intended.
epsilon_i = 0.0
epsilon_f = 0.0
anneal_time = 10000

# -- Learning --
gamma = 0.9  # presumably the TD discount factor; not used yet

In [None]:
# Build the Q-network on the GPU when one is available, so its weights sit on the same
# device as the CUDA tensors that `to_tensor` produces in that case.
model = DeepQNetwork(num_frames, num_actions).to('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# The agent bundles the Q-network with its replay memory and ε-greedy schedule.
agent = Agent(
    DQN=model,
    memory_depth=memory_depth,
    epsilon_i=epsilon_i,
    epsilon_f=epsilon_f,
    anneal_time=anneal_time,
)

In [None]:
# Tensor constructor: CUDA float tensors when a GPU is visible, CPU float tensors otherwise.
cuda = bool(torch.cuda.is_available())
if cuda:
    to_tensor = torch.cuda.FloatTensor
else:
    to_tensor = torch.FloatTensor

In [None]:
# Atari Breakout environment used by q_iteration.
env = gym.make(id='Breakout-v0')

In [None]:
def _initial_frame_stack():
    """Reset `env` and take random actions until `num_frames` processed frames are stacked.

    Returns (state, done). `done` is True if the episode already ended during warm-up;
    in that case `state` may hold fewer than `num_frames` frames.
    """
    state = agent.process(env.reset())
    done = False
    # Stop early on `done` so a finished episode is never stepped again.
    while state.size(1) < num_frames and not done:
        action = np.random.choice(num_actions)
        new_frame, reward, done, info = env.step(action)
        state = torch.cat([state, agent.process(new_frame)], 1)
    return state, done


def q_iteration(episodes, render=True, verbose=False):
    """Play `episodes` episodes with the ε-greedy agent, storing every transition.

    Relies on the notebook globals `env`, `agent`, `num_frames`, `num_actions` and
    `update_interval`. No learning step happens yet: the periodic hook is a placeholder.
    NOTE(review): assumes the classic gym API (reset() -> obs, step() -> 4-tuple).

    Args:
        episodes: number of episodes to play.
        render: call `env.render()` every step (the original behaviour); pass False
            for faster, headless runs.
        verbose: print the chosen action and (reward, done) every step, like the
            earlier debugging prints.
    """
    # Total steps across all episodes: the ε schedule anneals over this, not per episode.
    global_step = 0

    for episode in tqdm(range(episodes)):
        state, done = _initial_frame_stack()
        t = 0

        while not done:
            if render:
                env.render()

            if np.random.uniform() < agent.exploration_rate(global_step):
                action = np.random.choice(num_actions)
            else:
                action = agent.act(state)
            if verbose:
                print(action)

            new_frame, reward, done, info = env.step(action)
            new_frame = agent.process(new_frame)
            if verbose:
                print(reward, done)

            # Slide the frame window: append the newest frame, drop the oldest.
            new_state = torch.cat([state, new_frame], 1)[:, 1:, :, :]
            agent.remember(state, action, reward, done, new_state)

            state = new_state
            t += 1
            global_step += 1

            if t % update_interval == 0:
                pass  # TODO: periodic learning / network update goes here

        print("Episode {}: Episode completed after {} timesteps".format(episode, t))

In [None]:
# Use the configured episode count instead of repeating the literal.
q_iteration(episodes)