In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import random
from collections import namedtuple
from itertools import count
from time import time

In [None]:
import argparse
# from test import test
from environment import Environment

def parse():
    """Build the command-line parser for the HW3 train/test entry point.

    Registers the environment name, the train/test switches for policy
    gradient and DQN, the video output directory and the render flag.  When a
    local ``argument`` module can be imported, its ``add_arguments(parser)``
    hook is applied so that extra options can be registered.

    Returns:
        argparse.ArgumentParser: the configured parser (arguments are not
        parsed here).
    """
    parser = argparse.ArgumentParser(description="MLDS&ADL HW3")
    parser.add_argument('--env_name', default=None, help='environment name')
    parser.add_argument('--train_pg', action='store_true', help='whether train policy gradient')
    parser.add_argument('--train_dqn', action='store_true', help='whether train DQN')
    parser.add_argument('--test_pg', action='store_true', help='whether test policy gradient')
    parser.add_argument('--test_dqn', action='store_true', help='whether test DQN')
    parser.add_argument('--video_dir', default=None, help='output video directory')
    parser.add_argument('--do_render', action='store_true', help='whether render environment')
    try:
        from argument import add_arguments
    except ImportError:
        # The extension module is optional: fall back to the base options.
        return parser
    # Called outside the try block on purpose: a failure inside the hook is a
    # real error and must surface instead of being silently discarded.
    return add_arguments(parser)

# Build the parser and parse a fixed argument list (instead of sys.argv),
# which switches on the --train_dqn flag.
parser = parse()
args = parser.parse_args(['--train_dqn'])
# Fall back to Breakout when --env_name was not supplied.
env_name = args.env_name or 'BreakoutNoFrameskip-v4'
# Environment comes from the project-local `environment` module;
# atari_wrapper=True presumably enables its Atari preprocessing -- confirm there.
env = Environment(env_name, args, atari_wrapper=True)

In [None]:
# Device switch, currently pinned to the CPU.  Replace the literal with
# torch.cuda.is_available() to pick up a GPU automatically.
USE_CUDA = False

# Resolve the namespace holding the tensor constructors once, then alias the
# three tensor types the rest of the notebook builds data with.
_tensor_ns = torch.cuda if USE_CUDA else torch
FloatTensor = _tensor_ns.FloatTensor
LongTensor = _tensor_ns.LongTensor
ByteTensor = _tensor_ns.ByteTensor

In [None]:
class Net(nn.Module):
    """Convolutional Q-network: three conv + batch-norm stages, then two linear layers.

    Args:
        num_actions: width of the output layer, i.e. how many Q-values are
            produced per input.  Defaults to 4, the previously hard-coded size.
        in_channels: number of channels of the input tensor.  Defaults to 4,
            the previously hard-coded size.

    The first linear layer expects 3136 (= 64 * 7 * 7) flattened features,
    which is what the conv stack yields for e.g. 84 x 84 inputs.
    """

    def __init__(self, num_actions=4, in_channels=4):
        super(Net, self).__init__()
        # TODO (original author): decide on an explicit weight initialisation;
        # PyTorch's per-layer defaults are used for now.
        # Layer names and creation order are kept stable so existing
        # state_dicts and seeded initialisation are unaffected.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(3136, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def forward(self, x):
        """Map a batch of inputs to one row of Q-values per sample.

        Args:
            x: tensor of shape (N, in_channels, H, W) whose spatial size makes
                the conv stack produce 3136 features per sample.

        Returns:
            Tensor of shape (N, num_actions); no activation is applied to it.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten everything but the batch dimension for the linear layers.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
# Replay-buffer capacity in transitions; DQN.can_learn() also waits until the
# buffer holds this many entries before training starts.
MEMORY_SIZE = 10000
# The training loop requests a target-network sync on steps divisible by this.
TARGET_UPDATE_FREQ = 1000
# One learn() call every EVAL_UPDATE_FREQ environment steps.
EVAL_UPDATE_FREQ = 4
# Number of transitions sampled from the replay buffer per learn() call.
BATCH_SIZE = 32
# Discount factor applied to the target network's max Q-value.
GAMMA = 0.99
# Random-action threshold: interpolated linearly from START towards END as
# `progress` grows, and clamped so it never drops below END.
EPSILON_START = 0.99
EPSILON_END = 0.05
# The training loop saves eval_net whenever the step counter is a multiple of this.
SAVE_EVERY = 10000
# Disabled by the author; would presumably read the action count from the environment.
# NUM_ACTIONS = env.get_action_space().n

In [None]:
# One stored experience: state, action, next state, reward.
Transition = namedtuple('Transition', ('s', 'a', 's_', 'r'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` tuples with uniform sampling.

    Attributes:
        memory: list holding the stored transitions (at most ``capacity``).
        index: slot that the next overwrite will target once the list is full.
        capacity: maximum number of transitions kept.
    """

    def __init__(self, capacity):
        self.memory = []
        self.index = 0
        self.capacity = capacity

    def get_batch(self, batch_size):
        """Return ``batch_size`` distinct transitions drawn uniformly at random.

        Raises:
            ValueError: from ``random.sample`` when fewer than ``batch_size``
                transitions are stored.
        """
        return random.sample(self.memory, batch_size)

    def add_transition(self, *args):
        """Store one transition, overwriting the oldest entry once full.

        Args:
            *args: the four ``Transition`` fields, in order (s, a, s_, r).

        Raises:
            TypeError: if the number of fields does not match ``Transition``.
                The buffer is left untouched in that case.
        """
        # Build the tuple before touching the list: a failed construction must
        # not leave a placeholder entry behind in the buffer.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.index] = transition
        self.index = (self.index + 1) % self.capacity

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
def to_var(state):
    """Wrap a single 3-D array as a batch-of-one float Variable.

    The last axis of ``state`` is moved to the front (presumably
    height-width-channel to channel-first for the conv net -- confirm against
    the environment), the values are cast to float, and a leading batch axis
    of size 1 is added.
    """
    channels_first = state.transpose(2, 0, 1).astype(float)
    return Variable(FloatTensor(channels_first)).unsqueeze(0)

In [None]:
class DQN(object):
    """DQN agent: an online ("eval") Q-network, a target Q-network, a replay
    memory, and the optimizer/loss used to fit the online network.

    Args:
        target_update_freq: stored as ``self.target_update_freq``.  Nothing in
            this class reads it; the caller decides when to sync by passing
            ``update_target`` to :meth:`learn`.
    """

    def __init__(self, target_update_freq):
        self.eval_net = Net()
        self.target_net = Net()
        # Start the target network as an exact copy of the online network;
        # otherwise the first updates bootstrap from an unrelated random net.
        self.target_net.load_state_dict(self.eval_net.state_dict())
        if USE_CUDA:
            self.eval_net.cuda()
            self.target_net.cuda()

        # Action count taken from the width of the Q-head so random
        # exploration always matches what the network can output.
        self.num_actions = self.eval_net.fc2.out_features

        self.memory = ReplayMemory(MEMORY_SIZE)
        self.optimizer = torch.optim.RMSprop(self.eval_net.parameters(), lr=1e-4)
        self.loss = nn.SmoothL1Loss()

        self.target_update_freq = target_update_freq

    def get_action(self, state, progress):
        """Pick an action epsilon-greedily.

        Args:
            state: a single observation in the format ``to_var`` accepts.
            progress: annealing progress; 0 yields EPSILON_START and values
                of 1 or more yield EPSILON_END.

        Returns:
            int: index of the chosen action.
        """
        threshold = EPSILON_START + (EPSILON_END - EPSILON_START) * progress
        threshold = max(threshold, EPSILON_END)
        use_model = random.random() > threshold
        if use_model:
            q_values = self.eval_net(to_var(state))
            # int() makes both branches return a plain Python int, whatever
            # element type tensor indexing yields on the installed PyTorch.
            return int(q_values.data.max(1)[1][0])
        return random.randrange(self.num_actions)

    def can_learn(self):
        """True once the replay memory holds at least MEMORY_SIZE transitions."""
        return len(self.memory) >= MEMORY_SIZE

    def learn(self, update_target):
        """Run one optimisation step on a random minibatch from replay memory.

        Args:
            update_target: when truthy, copy the online network's weights into
                the target network after the optimizer step.
        """
        transitions = self.memory.get_batch(BATCH_SIZE)
        # Transpose: list of Transition tuples -> one Transition of per-field tuples.
        batch = Transition(*zip(*transitions))

        # Each field is concatenated along dim 0 into one batch tensor.
        b_s  = Variable(torch.cat(batch.s))
        b_a  = Variable(torch.cat(batch.a))
        b_s_ = Variable(torch.cat(batch.s_))
        b_r  = Variable(torch.cat(batch.r)).unsqueeze(1)

        # Q-value of the action actually taken, from the online network.
        q_eval = self.eval_net(b_s).gather(1, b_a)
        # Next-state values come from the target network, cut off from autograd.
        q_next = self.target_net(b_s_).detach()
        # NOTE(review): transitions carry no terminal flag, so the bootstrap
        # term is added even for the final step of an episode -- consider
        # storing `done` and masking q_next for those samples.
        q_target = b_r + GAMMA * q_next.max(1)[0].unsqueeze(1)
        loss = self.loss(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if update_target:
            self.target_net.load_state_dict(self.eval_net.state_dict())

In [None]:
# The single agent instance that the training loop below drives.
dqn = DQN(TARGET_UPDATE_FREQ)

In [None]:
# Per-episode total rewards, in episode order.
rewards = []
seed = 11037
env.seed(seed)
# Epsilon-greedy exploration and replay sampling draw from Python's `random`,
# so seed it too; NumPy and torch are seeded as well in case the environment
# wrappers or later cells draw from their global generators.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
total_steps = 10**7
# NOTE(review): machine-specific absolute path -- point this at a configurable
# directory before running elsewhere.
model_path = '/mnt/disk0/kevin1kevin1k/models/'
# Epsilon is annealed over the first 10% of the step budget.
anneal_steps = total_steps * 0.1
steps = 0
start = time()
for ep in count(1):
    s = env.reset()
    done = False
    episode_reward = 0.0

    for j in count():
        # int() so the environment and the stored action always get a plain
        # Python int, regardless of what get_action hands back.
        a = int(dqn.get_action(s, steps / anneal_steps))
        s_, r, done, info = env.step(a)

        # States are stored channel-first with a leading batch axis of 1 so
        # that learn() can torch.cat them straight into a batch.
        dqn.memory.add_transition(
            FloatTensor(np.expand_dims(s.transpose(2, 0, 1).astype(float), 0)),
            LongTensor([[a]]),
            FloatTensor(np.expand_dims(s_.transpose(2, 0, 1).astype(float), 0)),
            # float() accepts both NumPy scalars and plain Python numbers.
            FloatTensor([float(r)]),
        )

        if dqn.can_learn() and steps % EVAL_UPDATE_FREQ == 0:
            update_target = steps % TARGET_UPDATE_FREQ == 0
            dqn.learn(update_target)

        episode_reward += r
        steps += 1

        if steps % SAVE_EVERY == 0:
            # Floor division keeps the rotating checkpoint suffix an integer
            # (0-9); true division would produce names like 'eval_net1.0.pt'.
            torch.save(dqn.eval_net, model_path + 'eval_net{}.pt'.format((steps // SAVE_EVERY) % 10))

        if done:
            break

        s = s_

    print('Episode: {}, steps: {}, reward: {}, time: {}'.format(ep, steps, episode_reward, time() - start))
    # Record the reward before the budget check so the final episode is kept.
    rewards.append(episode_reward)
    if steps > total_steps:
        break