In [None]:
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import random
import os
import cv2
from scipy import misc
import matplotlib.pyplot as plt
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

%matplotlib inline

In [None]:
# Gym environment id. The NoFrameskip variant is used because Atari_Wrapper already repeats
# every action `num_stacked_frames` times. "Breakout-v0" additionally applies its own random
# frameskip and sticky actions, which compounds the skip and spaces the stacked frames unevenly.
env_name = 'BreakoutNoFrameskip-v4'

num_stacked_frames = 4              # action repeat count == number of stacked frames per observation
replay_memory_size = 250000         # replay buffer capacity (transitions)
min_replay_size_to_update = 25000   # agent steps collected before gradient updates begin
lr = 0.001                          # Adam learning rate
gamma = 0.97                        # discount factor
minibatch_size = 32
steps_rollout = 16                  # agent steps collected per rollout (between update phases)
start_eps = 1                       # initial exploration rate
final_eps = 0.1                     # exploration rate after annealing
final_eps_frame = 1000000           # agent steps over which epsilon is linearly annealed
total_steps = 5000000               # total agent steps to train for
target_net_update = 625             # rollouts between target-network syncs (625 * 16 = 10000 agent steps)
save_model_steps = 500000           # agent steps between model checkpoints

In [None]:
# "gpu" is not a valid torch device type; use CUDA when available and fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float

In [None]:
class Atari_Wrapper(gym.Wrapper):
    """Gym wrapper that repeats each action k times and stacks the resulting k frames.

    Each frame is cropped, converted to grayscale and resized. The per-step reward is the
    summed raw reward clipped to {-1, 0, 1}. The unclipped episode return is reported in
    `info["return"]` when the episode ends.

    Args:
        env: the gym environment to wrap.
        env_name: unused; kept for interface compatibility.
        k: action-repeat count and number of stacked frames.
        dsize: target size passed to cv2.resize, i.e. (width, height).
    """

    def __init__(self, env, env_name, k, dsize=(84,84)):
        super(Atari_Wrapper, self).__init__(env)
        self.dsize = dsize
        self.k = k
        # Row/column crop applied before resizing (presumably removes score bar and borders).
        self.frame_cutout_h = (31,-16)
        self.frame_cutout_w = (7,-7)

    def reset(self):
        """Reset the env; return k copies of the first preprocessed frame stacked on axis 0."""
        self.Return = 0
        self.last_life_count = 0
        ob = self.preprocess_observation(self.env.reset())
        self.frame_stack = np.stack([ob] * self.k)
        return self.frame_stack

    def step(self, action):
        """Repeat `action` up to k times, stopping early if the episode ends.

        Returns:
            (frame_stack, clipped_reward, done, info) where clipped_reward is 1, 0 or -1.
        """
        reward = 0
        done = False
        frames = []
        for _ in range(self.k):
            ob, r, d, info = self.env.step(action)
            frames.append(self.preprocess_observation(ob))
            reward += r
            if d:
                done = True
                break

        self.step_frame_stack(frames)
        self.Return += reward
        if done:
            # Unclipped episode return, used by the runner for logging.
            info["return"] = self.Return

        # Reward clipping: only the sign of the summed reward is kept.
        if reward > 0:
            reward = 1
        elif reward == 0:
            reward = 0
        else:
            reward = -1

        return self.frame_stack, reward, done, info

    def step_frame_stack(self, frames):
        """Push newly observed frames into self.frame_stack, keeping the newest k frames."""
        num_frames = len(frames)
        if num_frames == 0:
            return
        if num_frames >= self.k:
            # The original code referenced an undefined global `k` here; use self.k.
            self.frame_stack = np.stack(frames[-self.k:])
        else:
            # The episode ended before k frames: drop the oldest frames and append the new ones.
            # A fresh array is built so observations returned earlier are not mutated in place.
            self.frame_stack = np.concatenate(
                [self.frame_stack[num_frames:], np.stack(frames)])

    def preprocess_observation(self, ob):
        """Crop, convert to grayscale and resize a raw frame."""
        cropped = ob[self.frame_cutout_h[0]:self.frame_cutout_h[1],
                     self.frame_cutout_w[0]:self.frame_cutout_w[1]]
        # Gym Atari observations are RGB, so RGB2GRAY gives the intended luminance weights.
        gray = cv2.cvtColor(cropped, cv2.COLOR_RGB2GRAY)
        return cv2.resize(gray, dsize=self.dsize)

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers, then a 512-unit hidden layer, then one output per action.

    The 64*7*7 flatten size assumes an 84x84 input (batch, in_channels, 84, 84).
    """

    def __init__(self, in_channels, num_actions):
        super().__init__()
        feature_layers = (
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4, padding=0), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0), nn.ReLU(),
        )
        head_layers = (
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512), nn.ReLU(),
            nn.Linear(512, num_actions),
        )
        # A single Sequential keeps the state_dict keys as network.<index>.*
        self.network = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x):
        """Return one Q-value per action for each observation in the batch."""
        return self.network(x)

In [None]:
class Agent(nn.Module):
    """Epsilon-greedy policy wrapped around a DQN Q-network."""

    def __init__(self, in_channels, num_actions, epsilon):
        super().__init__()
        self.in_channels = in_channels
        self.num_actions = num_actions
        self.network = DQN(in_channels, num_actions)
        self.eps = epsilon

    def forward(self, x):
        """Return the Q-values produced by the wrapped network for input x."""
        return self.network(x)

    def e_greedy(self, x):
        """Pick an action for a single observation x using epsilon-greedy exploration.

        With probability (1 - eps), returns the argmax of the Q-values. Otherwise returns a
        uniformly random action index in [0, num_actions). torch.argmax flattens its input,
        so x is assumed to hold a single observation (batch size 1).
        """
        if self.eps < torch.rand(1):
            # Acting needs no gradients, and the network is only evaluated when it is used.
            with torch.no_grad():
                q_values = self.forward(x)
            return torch.argmax(q_values)
        return (torch.rand(1) * self.num_actions).type('torch.LongTensor')[0]

    def set_epsilon(self, epsilon):
        """Set the exploration probability used by e_greedy."""
        self.eps = epsilon

In [None]:
class Logger:
    """Minimal line logger writing to `<filename>.csv`.

    The file is created (or truncated) on construction, and its parent directory is created
    if needed. Each `log` call appends one line.
    """

    def __init__(self, filename):
        self.filename = filename
        path = f"{self.filename}.csv"
        directory = os.path.dirname(path)
        if directory:
            # e.g. "training/" for "training/training_info"; open() would fail if it is missing.
            os.makedirs(directory, exist_ok=True)
        with open(path, "w"):
            pass

    def log(self, msg):
        """Append `msg` followed by a newline to the log file."""
        with open(f"{self.filename}.csv", "a") as f:
            f.write(f"{msg}\n")

In [None]:
class Experience_Replay:
    """Fixed-capacity ring buffer of transitions with uniform random sampling."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index the next transition is written to

    def insert(self, transitions):
        """Store each transition, overwriting the oldest entries once capacity is reached."""
        for transition in transitions:
            if len(self.memory) < self.capacity:
                self.memory.append(transition)
            else:
                self.memory[self.position] = transition
            self.position = (self.position + 1) % self.capacity

    def get(self, batch_size):
        """Sample `batch_size` stored transitions uniformly, with replacement.

        Raises:
            ValueError: if the memory is empty.
        """
        if not self.memory:
            raise ValueError("cannot sample from an empty replay memory")
        # randint's upper bound is exclusive, so every index including the last can be drawn.
        # The previous rand()*(len-1) truncation could never select the final slot.
        indexes = np.random.randint(0, len(self.memory), size=batch_size)
        return [self.memory[i] for i in indexes]

    def __len__(self):
        return len(self.memory)

In [None]:
class Env_Runner:
    """Steps an environment with an agent's epsilon-greedy policy and logs episode returns.

    Episode returns are appended to "training/training_info.csv" as "step,return" lines.
    """

    def __init__(self, env, agent):
        self.env = env
        self.agent = agent
        self.logger = Logger("training/training_info")
        self.logger.log("training_step, training_return")
        self.ob = self.env.reset()
        self.total_steps = 0

    def run(self, steps):
        """Collect `steps` environment steps.

        Returns:
            (obs, actions, rewards, dones): parallel lists of length `steps`. obs[i] is the
            observation (as a torch tensor) on which actions[i] was chosen. When an episode
            ends, the env is reset, so the following observation belongs to the new episode.
        """
        obs = []
        actions = []
        rewards = []
        dones = []
        for step in range(steps):
            # torch.tensor copies, so later env steps cannot mutate stored observations.
            self.ob = torch.tensor(self.ob)
            # Action selection needs no gradients; avoid building an autograd graph per step.
            with torch.no_grad():
                action = self.agent.e_greedy(
                    self.ob.to(device).to(dtype).unsqueeze(0) / 255)
            action = action.detach().cpu().numpy()
            obs.append(self.ob)
            actions.append(action)
            self.ob, r, done, info = self.env.step(action)
            if done:
                self.ob = self.env.reset()
                if "return" in info:
                    self.logger.log(f'{self.total_steps+step},{info["return"]}')
            rewards.append(r)
            dones.append(done)
        self.total_steps += steps

        return obs, actions, rewards, dones
    
def make_transitions(obs, actions, rewards, dones):
    """Pair consecutive rollout observations into replay transitions.

    Returns a list of (obs, action, reward, next_obs, not_done) tuples, one per step except
    the last, whose successor observation is not part of the rollout. not_done is 1 when
    the step did not end an episode and 0 when it did.
    """
    last_index = len(obs) - 1
    return [
        (obs[t], actions[t], rewards[t], obs[t + 1], int(not dones[t]))
        for t in range(last_index)
    ]

In [None]:
raw_env = gym.make(env_name)
env = Atari_Wrapper(raw_env, env_name, num_stacked_frames)

in_channels = num_stacked_frames
num_actions = env.action_space.n

# Epsilon is annealed linearly from start_eps to final_eps over final_eps_frame agent steps.
eps_interval = start_eps-final_eps

agent = Agent(in_channels, num_actions, start_eps).to(device)
target_agent = Agent(in_channels, num_actions, start_eps).to(device)
target_agent.load_state_dict(agent.state_dict())

replay = Experience_Replay(replay_memory_size)
runner = Env_Runner(env, agent)
optimizer = optim.Adam(agent.parameters(), lr=lr)
huber_loss = torch.nn.SmoothL1Loss()

print_every_steps = 50000   # agent steps between progress reports
updates_per_rollout = 4     # gradient steps performed after each rollout

num_steps = 0
# NOTE: incremented once per rollout (after updates_per_rollout gradient steps), so
# target_net_update is measured in rollouts rather than individual gradient steps.
num_model_updates = 0

start_time = time.time()
try:
    while num_steps < total_steps:
        new_epsilon = np.maximum(final_eps, start_eps - (eps_interval * num_steps / final_eps_frame))
        agent.set_epsilon(new_epsilon)

        obs, actions, rewards, dones = runner.run(steps_rollout)
        transitions = make_transitions(obs, actions, rewards, dones)
        replay.insert(transitions)

        num_steps += steps_rollout

        if num_steps < min_replay_size_to_update:
            continue

        for update in range(updates_per_rollout):
            optimizer.zero_grad()

            minibatch = replay.get(minibatch_size)

            obs = torch.stack([i[0] for i in minibatch]).to(device).to(dtype) / 255
            actions = np.stack([i[1] for i in minibatch])
            rewards = torch.tensor([i[2] for i in minibatch]).to(device)
            next_obs = torch.stack([i[3] for i in minibatch]).to(device).to(dtype) / 255
            # i[4] is 1 for non-terminal transitions and 0 for terminal ones (a "not done" mask).
            dones = torch.tensor([i[4] for i in minibatch]).to(device)

            obs_Q = agent(obs)[range(minibatch_size), actions]

            # Double-DQN target: the online network selects the next action and the target
            # network evaluates it. No gradients flow through the target.
            with torch.no_grad():
                next_obs_Q_max = torch.max(agent(next_obs), 1)[1]
                target_Q = target_agent(next_obs)[range(minibatch_size), next_obs_Q_max]
                target = rewards + gamma * target_Q * dones

            loss = huber_loss(obs_Q, target)
            loss.backward()
            optimizer.step()

        num_model_updates += 1

        if num_model_updates % target_net_update == 0:
            target_agent.load_state_dict(agent.state_dict())

        if num_steps % print_every_steps < steps_rollout:
            end_time = time.time()
            print(f'##### total steps: {num_steps} | loss: {loss.item():.5f} '
                  f'| time: {end_time - start_time:.1f}s #####')
            start_time = time.time()

        if num_steps % save_model_steps < steps_rollout:
            # Pickles the whole module; loading requires these class definitions to be importable.
            torch.save(agent, f"{env_name}-{num_steps}.pt")
finally:
    # Release the emulator even if training is interrupted or raises.
    env.close()