In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import gym
import matplotlib.pyplot as plt

In [None]:
# Training hyperparameters used by the final training cell.
EPISODES = 1000   # number of episodes handed to the agent's learn()
BATCH_SIZE = 64   # minibatch size drawn from the replay buffer

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Utils

In [None]:
class AverageMeter:
    """Keeps the most recent value together with a running sum, count and mean.

    Attributes:
        val:   last value passed to ``update``.
        sum:   weighted sum of all values seen since the last ``reset``.
        count: total weight (number of samples) seen since the last ``reset``.
        avg:   ``sum / count``; 0 until the first ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as having been observed ``n`` times and refresh the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

In [None]:
class PytorchWrapper(gym.Wrapper):
    """Gym wrapper that emits preprocessed torch observations and ends the
    episode as soon as the agent loses a life.

    ``step`` returns a 3-tuple ``(obs, reward, done)`` (the info dict is
    dropped) and ``reset`` returns only the observation. Both rely on the
    wrapped env's ``step`` returning a 4-tuple.

    Args:
        env: environment to wrap; must expose ``unwrapped.ale.lives()``.
        start_lives: number of lives the game starts with. Dropping below
            this value is treated as the end of the episode.
    """

    def __init__(self, env, start_lives=5):
        super().__init__(env)
        self.start_lives = start_lives

    def _episodic(self):
        """
        Episode terminates when agent loses one life
        """
        return self.env.unwrapped.ale.lives() < self.start_lives

    def _fire_to_start(self):
        """
        Env requires agent to perform FIRE to start
        """
        obs, _, _, _ = self.env.step(1)
        return obs

    def rgb2gray(self, rgb):
        """Collapse the trailing RGB axis into a single luminance channel."""
        # np.asarray materialises lazy frame containers so that the
        # ellipsis indexing below is always applied to a real ndarray.
        rgb = np.asarray(rgb)
        return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])

    def _preprocess_obs(self, obs):
        """Convert a raw observation to a cropped float tensor scaled by 1/255."""
        obs = self.rgb2gray(obs)
        obs = torch.tensor(obs, dtype=torch.float)
        # Crop along axis 1 to remove the score board on top.
        # NOTE(review): assumes axis 0 is the stacked-frame axis - confirm.
        obs = obs[:, 25:200, :]
        obs = obs / 255
        return obs

    def step(self, action):
        """Advance the env by one action and return ``(obs, reward, done)``."""
        obs, reward, done, _ = self.env.step(action)
        # Keep the env's own termination signal instead of overwriting it,
        # so an episode ended by the env without a lost life still stops.
        done = bool(done) or self._episodic()
        obs = self._preprocess_obs(obs)
        return obs, reward, done

    def reset(self):
        """Reset the env, press FIRE once, and return the first processed obs."""
        self.env.reset()
        obs = self._fire_to_start()
        return self._preprocess_obs(obs)

In [None]:
def make_env(env):
    """Wrap ``env`` with 4-frame stacking, then with ``PytorchWrapper``."""
    stacked = gym.wrappers.FrameStack(env, 4)  # stack 4 frames together
    return PytorchWrapper(stacked)

# Policy Network - CNN

In [None]:
class Policy(nn.Module):
    """Convolutional Q-network: four strided 5x5 convolutions, global average
    pooling over the spatial axes, then a two-layer prediction head.

    Args:
        action_space: number of outputs of the head (one value per action).
        input_channels: channel count expected by the first convolution.
        dropout: dropout probability used after the first three conv layers
            and inside the head.
    """

    def __init__(self, action_space, input_channels=3, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, stride=2)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=5, stride=2)

        self.pred_head = nn.Sequential(
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, action_space),
        )

    def forward(self, x):
        """Map a batch of images ``(N, C, H, W)`` to ``(N, action_space)`` values."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.leaky_relu(conv(x))
            # F.dropout defaults to training=True; tie it to the module's mode
            # so that .eval() really disables dropout.
            x = F.dropout(x, self.dropout, training=self.training)
        x = F.leaky_relu(self.conv4(x))
        x = x.mean((-1, -2))  # global average pool over height and width
        return self.pred_head(x)

# Experience Replay (Replay Buffer)

In [None]:
class ReplayBuffer:
    """Fixed-capacity store of ``(state, action, next_state, reward)`` transitions.

    Once ``buffer_size`` transitions are held, each new transition overwrites
    the oldest one (ring buffer), so inserts never copy the lists.

    Args:
        buffer_size: maximum number of transitions kept; must be >= 1.

    Raises:
        ValueError: if ``buffer_size`` is smaller than 1.
    """

    def __init__(self, buffer_size=100000):
        if buffer_size < 1:
            raise ValueError("buffer_size must be at least 1")
        self.state = []
        self.action = []
        self.next_state = []
        self.reward = []
        self.buffer_size = buffer_size
        # Slot holding the oldest transition; only used once the buffer is full.
        self._cursor = 0

    def store(self, state, action, next_state, reward):
        """Add one transition, evicting the oldest when the buffer is full."""
        if len(self.state) < self.buffer_size:
            self.state.append(state)
            self.action.append(action)
            self.next_state.append(next_state)
            self.reward.append(reward)
            return

        slot = self._cursor
        self.state[slot] = state
        self.action[slot] = action
        self.next_state[slot] = next_state
        self.reward[slot] = reward
        self._cursor = (slot + 1) % self.buffer_size

    def sample_batch(self, batch_size):
        """Draw ``batch_size`` transitions uniformly at random, with replacement.

        Returns:
            Tuple ``(state, action, next_state, reward)`` of tensors whose
            leading dimension is ``batch_size``; ``action`` is long and
            ``reward`` is float.

        Raises:
            ValueError: if the buffer is empty.
        """
        size = len(self.state)
        if size == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        idxs = np.random.choice(size, batch_size)
        # Gather only the sampled transitions instead of stacking the whole
        # buffer on every call.
        state = torch.stack([self.state[i] for i in idxs])
        action = torch.tensor([self.action[i] for i in idxs], dtype=torch.long)
        next_state = torch.stack([self.next_state[i] for i in idxs])
        reward = torch.tensor([self.reward[i] for i in idxs], dtype=torch.float)
        return (state, action, next_state, reward)

    def __len__(self):
        return len(self.state)

# Q-learning Agent

In [None]:
class DQN:
    """Double-DQN agent with an online ``policy`` network and a periodically
    synchronised ``target`` network.

    Args:
        action_size: number of discrete actions.
        device: torch device (or device string) both networks live on.
        gamma: discount factor used in the TD target.
        lr: Adam learning rate for the policy network.
        model_path: file the target weights are saved to on every sync;
            defaults to ``"Double_DQN.bin"``.

    Note:
        Transitions are stored without a terminal flag, so the TD target in
        ``optimize_policy`` also bootstraps from ``next_state`` on the last
        step of an episode.
    """

    def __init__(self, action_size, device, gamma=0.99, lr=0.001, model_path=None):
        self.target = Policy(input_channels=4, action_space=action_size).to(device)
        self.target.eval()
        self.policy = Policy(input_channels=4, action_space=action_size).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.device = device
        self.buffer = ReplayBuffer()
        self.gamma = gamma
        self.action_size = action_size
        self.model_path = model_path if model_path else "Double_DQN.bin"

    def loss_fct(self, target, pred):
        """Smooth-L1 (Huber) loss between predicted and target Q-values."""
        return F.smooth_l1_loss(pred, target)

    def forward(self, policy, obs, grad=False):
        """Run ``policy`` on a batch of observations.

        Returns:
            ``(q_values, action)`` where ``action`` is the arg-max over dim 1.
            ``q_values`` carries no autograd graph unless ``grad`` is True.
        """
        obs = obs.to(self.device)
        # Skip building the autograd graph altogether when it is not needed.
        with torch.set_grad_enabled(grad):
            q_values = policy(obs)
        action = torch.argmax(q_values, 1)
        return q_values, action

    def optimize_policy(self, batch):
        """Take one gradient step on the policy network; return the scalar loss."""
        state, action, next_state, reward = batch
        action = action.unsqueeze(1).to(self.device)
        reward = reward.to(self.device)

        q_values, _ = self.forward(self.policy, state, grad=True)
        _, next_action = self.forward(self.policy, next_state)
        next_q, _ = self.forward(self.target, next_state)
        ## Double DQN: the policy net picks the next action, the target net
        ## evaluates it. Reduces overestimation of Q-values.
        # squeeze(-1) rather than squeeze() so a batch of one keeps its batch dim.
        q_next = next_q.gather(1, next_action.unsqueeze(-1)).squeeze(-1)
        target = reward + self.gamma * q_next
        q_taken = q_values.gather(1, action).squeeze(-1)
        loss = self.loss_fct(target, q_taken)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def update_target(self):
        """Copy the policy weights into the target net and save them to disk."""
        self.policy.eval()
        self.target.load_state_dict(self.policy.state_dict())
        torch.save(self.target.state_dict(), self.model_path)
        self.policy.train()

    def load_policy(self, path=None):
        """Load saved weights into both networks.

        Args:
            path: checkpoint file; defaults to ``self.model_path``.
        """
        if path is None:
            path = self.model_path

        # map_location lets a checkpoint saved on another device be loaded here.
        weights = torch.load(path, map_location=self.device)
        self.target.load_state_dict(weights)
        # Also restore the online net: learn() begins by copying policy -> target,
        # which would otherwise discard the weights just loaded.
        self.policy.load_state_dict(weights)
        print("Successfully loaded")

    def evaluate_policy(self, env):
        """Play one rendered, greedy episode with the target network.

        Returns:
            Sum of the rewards collected during the episode.
        """
        obs = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            env.render()
            # The network expects a leading batch dimension.
            _, action = self.forward(self.target, obs.unsqueeze(0))
            obs, reward, done = env.step(action.item())
            total_reward += reward
        return total_reward

    def learn(self, env, episodes, batch_size, epsilon=0.3, max_steps=10000,
              target_update_every=500):
        """Train the policy network with epsilon-greedy exploration.

        Args:
            env: environment whose ``reset()`` returns an observation tensor
                and whose ``step(action)`` returns ``(obs, reward, done)``.
            episodes: number of episodes to run.
            batch_size: minibatch size; training starts once the buffer holds
                at least this many transitions.
            epsilon: probability of taking a uniformly random action.
            max_steps: upper bound on the number of steps per episode.
            target_update_every: number of environment steps between target
                network synchronisations.
        """
        writer = SummaryWriter()
        env_steps = 0
        grad_steps = 0
        print("--- Training Agent ---")
        self.update_target()
        try:
            for eps in range(episodes):
                obs = env.reset()
                loss_tracker = AverageMeter()
                for t in range(max_steps):
                    env_steps += 1
                    if np.random.rand() <= epsilon:  ## Epsilon greedy
                        action = np.random.randint(self.action_size)
                    else:
                        _, action = self.forward(self.policy, obs.unsqueeze(0))
                        action = action.item()
                    next_obs, reward, done = env.step(action)
                    self.buffer.store(obs, action, next_obs, reward)

                    if len(self.buffer) >= batch_size:
                        batch = self.buffer.sample_batch(batch_size)
                        loss = self.optimize_policy(batch)
                        loss_tracker.update(loss)
                        writer.add_scalar('Loss', loss, grad_steps)
                        grad_steps += 1
                        ## Delayed update of target. Keeps the bootstrapped
                        ## targets fixed between syncs, which stabilises training.
                        if env_steps % target_update_every == 0:
                            self.update_target()

                    if (t + 1) % 10 == 0:
                        print(f"Episode: {eps+1}/{episodes}, step: {t+1}/{max_steps}, loss: {loss_tracker.avg}")

                    if done:
                        break

                    obs = next_obs
        finally:
            # Flush pending events even if training is interrupted.
            writer.close()
            

# Train Agent on Atari Env

In [None]:
## If gym.make fails, perform 'pip install gym[atari]'
raw_env = gym.make('Breakout-v0')
action_size = raw_env.action_space.n  # number of discrete actions
env = make_env(raw_env)               # frame stacking + tensor preprocessing

In [None]:
# Build the agent with default gamma, learning rate and checkpoint path.
agent = DQN(action_size=action_size, device=device)

In [None]:
## Load the TensorBoard notebook extension and open it on the `runs` log
## directory to visualize the loss
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
# Kick off training with the hyperparameters from the config cell.
agent.learn(env=env, episodes=EPISODES, batch_size=BATCH_SIZE)