In [1]:
import random
import math
import os
import time
import gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch import FloatTensor, LongTensor
import torch.nn.functional as F
import matplotlib.pyplot as plt

from ale_py.roms import Breakout
from ale_py import ALEInterface

In [None]:
# --- Hyperparameters -------------------------------------------------------
EPISODES = 50      # number of episodes the training loop runs
EPS1 = 0.9         # exploration rate at the start of the epsilon schedule
EPS2 = 0.05        # value the exploration rate decays towards
EPS_DECAY = 200    # decay constant (in actions) of the exponential schedule
GAMMA = 0.8        # discount factor applied to the bootstrapped next-state value
LR = 0.001         # learning rate handed to the Adam optimizer
BATCHSIZE = 64     # transitions sampled from replay memory per learning step

# --- Environment -----------------------------------------------------------
ale = ALEInterface()
ale.loadROM(Breakout)
env = gym.make('ALE/Breakout-v5', render_mode='human')

# Gym >= 0.26 returns `(observation, info)` from reset(); older releases
# return the bare observation. Accept both so `state` is always the frame.
_reset_result = env.reset()
if (
    isinstance(_reset_result, tuple)
    and len(_reset_result) == 2
    and isinstance(_reset_result[1], dict)
):
    state = _reset_result[0]
else:
    state = _reset_result
state_shape = state.shape

In [3]:
class ReplayMem:
    """Bounded first-in/first-out store of arbitrary objects.

    Once more than ``cap`` objects are held, the oldest one is discarded.
    """

    def __init__(self, cap):
        """Create an empty store that keeps at most ``cap`` objects."""
        self.mem = []
        self.cap = cap

    def push(self, obj):
        """Append ``obj``; evict the oldest entry if the store overflows."""
        buffer = self.mem
        buffer.append(obj)
        overflowed = len(buffer) > self.cap
        if overflowed:
            buffer.pop(0)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct entries chosen uniformly at random.

        Raises ``ValueError`` (from ``random.sample``) when ``batch_size``
        exceeds the number of stored entries.
        """
        return random.sample(self.mem, k=batch_size)

    def __len__(self):
        """Number of entries currently stored."""
        stored = self.mem
        return len(stored)

class DQNet(nn.Module):
    """Convolutional Q-network producing one value per action (4 outputs).

    Input is a float tensor of shape ``(N, 3, H, W)``. ``H`` and ``W`` must be
    such that the five stride-2 convolutions leave ``32 * h * w == 640``
    features (for example 210x160 inputs give ``32 * 5 * 4``).

    A ReLU follows every convolution and the first linear layer. Without
    them the whole stack is a composition of affine maps and therefore can
    only represent an affine function of the pixels. The activations are
    applied in ``forward`` rather than inserted into ``self.net`` so the
    parameter names (``net.0.weight`` ... ``net.7.bias``) stay the same and
    existing checkpoints still load. Note that weights trained before this
    change were fitted without activations.
    """

    def __init__(self):
        super(DQNet, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2),
            nn.Conv2d(in_channels=16, out_channels=48, kernel_size=3, stride=2),
            nn.Conv2d(in_channels=48, out_channels=64, kernel_size=3, stride=2),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=2),
            nn.Flatten(),
            nn.Linear(640, 64),
            nn.Linear(64, 4)
        )

    def forward(self, x):
        """Return Q-values of shape ``(N, 4)`` for a batch of inputs ``x``."""
        *hidden_layers, output_layer = list(self.net)
        for layer in hidden_layers:
            x = layer(x)
            # Flatten carries no weights and needs no activation.
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                x = F.relu(x)
        # The output layer stays linear: Q-values are unbounded.
        return output_layer(x)

class Agent:
    """Epsilon-greedy DQN agent with a small replay memory.

    Relies on the module-level hyperparameters ``EPS1``, ``EPS2``,
    ``EPS_DECAY``, ``GAMMA``, ``LR`` and ``BATCHSIZE`` and on the
    ``ReplayMem`` and ``DQNet`` classes.
    """

    def __init__(self, chk_dir="/Users/stack/Documents/warudo/dqnchk/"):
        """Build the network, optimizer and replay memory.

        Args:
            chk_dir: directory where ``save_chk`` / ``load_chk`` read and
                write ``<episode>.pth`` files.
        """
        self.replay_mem = ReplayMem(400)
        self.net = DQNet()
        self.optimizer = optim.Adam(self.net.parameters(), LR)
        self.actions_taken = 0
        self.chk_dir = chk_dir

    @staticmethod
    def _to_tensor(observation):
        """Convert one ``(H, W, C)`` observation to a ``(1, C, H, W)`` float tensor."""
        frame = torch.as_tensor(observation, dtype=torch.float32)
        return frame.unsqueeze(0).permute(0, 3, 1, 2)

    @staticmethod
    def _reset_env(env):
        """Reset ``env`` and return the observation for either Gym reset API."""
        result = env.reset()
        # Gym >= 0.26 returns (observation, info); older versions the observation.
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]
        return result

    @staticmethod
    def _step_env(env, action):
        """Step ``env`` and return ``(observation, reward, terminal, done)``.

        ``terminal`` tells the learner not to bootstrap from the next state;
        ``done`` tells the episode loop to stop. With the 4-tuple API the two
        cannot be told apart, so both equal the environment's ``done`` flag.
        """
        result = env.step(action)
        if len(result) == 5:
            observation, reward, terminated, truncated, _ = result
            return observation, reward, bool(terminated), bool(terminated or truncated)
        observation, reward, done, _ = result
        return observation, reward, bool(done), bool(done)

    def next_action(self, state, training=True):
        """Pick an action for ``state`` with an epsilon-greedy policy.

        Args:
            state: tensor of shape ``(1, C, H, W)`` accepted by ``self.net``.
            training: when true, epsilon decays exponentially from ``EPS1``
                towards ``EPS2`` as ``actions_taken`` grows; otherwise a fixed
                epsilon of 0.2 is used and the counter is left alone.

        Returns:
            ``(1, 1)`` long tensor holding the chosen action index.
        """
        if training:
            eps = EPS2 + (EPS1 - EPS2) * math.exp(-1. * self.actions_taken / EPS_DECAY)
            # Advance the schedule; without this epsilon would stay at EPS1.
            self.actions_taken += 1
        else:
            eps = 0.2
        if random.random() > eps:
            # Inference only: do not record an autograd graph.
            with torch.no_grad():
                q_values = self.net(state.float())
            return q_values.max(1)[1].view(1, 1)
        return torch.tensor([[random.randrange(4)]], dtype=torch.long)

    def run_episode(self, episode, env, training=True):
        """Play one episode in ``env``.

        When ``training`` is true every transition is stored, ``learn`` is
        called after each step and a checkpoint named after ``episode`` is
        saved when the episode ends.
        """
        state = self._to_tensor(self._reset_env(env))
        while True:
            action = self.next_action(state, training=training)
            # Hand the environment a plain int rather than a 0-dim tensor.
            observation, reward, terminal, done = self._step_env(env, action.item())
            next_state = self._to_tensor(observation)
            if training:
                self.replay_mem.push((
                    state,
                    action,
                    next_state,
                    torch.tensor([float(reward)]),
                    torch.tensor([float(terminal)]),
                ))
                self.learn()
            state = next_state
            if done:
                print("episode " + str(episode) + " finished")
                if training:
                    self.save_chk(episode)
                break

    def learn(self):
        """Run one optimizer step on a random minibatch from replay memory.

        Does nothing until at least ``BATCHSIZE`` transitions are stored.
        The target is ``reward + GAMMA * max_a Q(next_state, a)``, with the
        bootstrap term zeroed for terminal transitions.
        """
        if len(self.replay_mem) < BATCHSIZE:
            return

        transitions = self.replay_mem.sample(BATCHSIZE)
        states, actions, next_states, rewards, terminals = zip(*transitions)
        batch_state = torch.cat(states)
        batch_action = torch.cat(actions)
        batch_next_state = torch.cat(next_states)
        batch_reward = torch.cat(rewards)
        batch_terminal = torch.cat(terminals)

        # squeeze(1) rather than squeeze(): keep the batch axis for any batch size.
        curr_q_values = self.net(batch_state).gather(1, batch_action).squeeze(1)
        with torch.no_grad():
            max_next_q_values = self.net(batch_next_state).max(1)[0]
        expected_q_values = batch_reward + GAMMA * max_next_q_values * (1.0 - batch_terminal)
        loss = F.smooth_l1_loss(curr_q_values, expected_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _chk_path(self, episode):
        """Path of the checkpoint file for ``episode``."""
        return os.path.join(self.chk_dir, str(episode) + ".pth")

    def save_chk(self, episode):
        """Write the network weights to ``<chk_dir>/<episode>.pth``."""
        os.makedirs(self.chk_dir, exist_ok=True)
        torch.save(self.net.state_dict(), self._chk_path(episode))

    def load_chk(self, episode):
        """Load the network weights from ``<chk_dir>/<episode>.pth``.

        Raises ``FileNotFoundError`` if that checkpoint does not exist.
        """
        self.net.load_state_dict(torch.load(self._chk_path(episode)))

In [None]:
# Checkpoint index to warm-start from, if such a checkpoint exists.
RESUME_EPISODE = 49

agent = Agent()
try:
    agent.load_chk(RESUME_EPISODE)
except FileNotFoundError:
    # No checkpoint on disk (e.g. first run): train from freshly
    # initialised weights instead of aborting the notebook.
    print("checkpoint " + str(RESUME_EPISODE) + " not found, training from scratch")
for episode in range(EPISODES):
    agent.run_episode(episode, env, training=True)