In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

In [None]:
# Pick the tensor constructors once, so the rest of the notebook can build
# tensors without repeating the CPU/GPU decision.
use_cuda = torch.cuda.is_available()
if use_cuda:
    FloatTensor = torch.cuda.FloatTensor
    LongTensor = torch.cuda.LongTensor
    ByteTensor = torch.cuda.ByteTensor
else:
    FloatTensor = torch.FloatTensor
    LongTensor = torch.LongTensor
    ByteTensor = torch.ByteTensor

In [None]:
# W, H, C = agent.env.get_observation_space().shape
# NUM_ACTIONS = agent.env.get_action_space().n

In [None]:
# conv_params = (
#     {'in_channels':  3, 'out_channels': 32, 'kernel_size': 8, 'stride': 4},
#     {'in_channels': 32, 'out_channels': 64, 'kernel_size': 4, 'stride': 2},
#     {'in_channels': 64, 'out_channels': 64, 'kernel_size': 3, 'stride': 1},
# )

# linear_params = (512, NUM_ACTIONS)

class Net(nn.Module):
    """Convolutional Q-network: three conv + batch-norm stages, then two linear layers.

    The size of the flattened convolutional output is derived from
    ``input_size`` instead of being hard-coded, so the network can be built
    for other frame sizes and action counts.  The defaults reproduce the
    original fixed architecture (3 input channels, 22528 flattened features,
    4 outputs).

    Args:
        in_channels (int): number of channels of the input image.
        input_size (tuple): ``(height, width)`` of the input image.
        num_actions (int): number of output Q-values.

    Raises:
        ValueError: if ``input_size`` is too small for the convolution stack.
    """

    # (kernel_size, stride) of conv1, conv2 and conv3; none of them use padding.
    _CONV_GEOMETRY = ((8, 4), (4, 2), (3, 1))

    def __init__(self, in_channels=3, input_size=(210, 160), num_actions=4):
        super(Net, self).__init__()
        # Layers are created in the same order as before so that parameter
        # initialisation consumes the RNG identically for the default config.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=8, stride=4)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Unpadded convolution: out = (in - kernel) // stride + 1 per dimension.
        height, width = input_size
        for kernel, stride in self._CONV_GEOMETRY:
            height = (height - kernel) // stride + 1
            width = (width - kernel) // stride + 1
            if height < 1 or width < 1:
                raise ValueError(
                    'input_size {} is too small for the convolution stack'.format(tuple(input_size)))

        self.fc1 = nn.Linear(64 * height * width, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def forward(self, x):
        """Map a batch of images ``(N, in_channels, H, W)`` to ``(N, num_actions)`` values."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
import random
from collections import namedtuple

MEMORY_SIZE = 10000

# One stored experience: state, action, next state, reward.
Transition = namedtuple('Transition', ['s', 'a', 's_', 'r'])


class ReplayMemory(object):
    """Ring buffer holding at most ``MEMORY_SIZE`` transitions.

    ``memory`` grows until it reaches capacity; after that ``index`` wraps
    around and the oldest entries are overwritten.
    """

    def __init__(self):
        self.index = 0
        self.memory = []

    def __len__(self):
        return len(self.memory)

    def add_transition(self, *args):
        """Store ``Transition(*args)`` in the slot currently pointed to by ``index``."""
        slot = self.index
        still_growing = len(self.memory) < MEMORY_SIZE
        if still_growing:
            # Reserve the slot first; it is filled on the next line.
            self.memory.append(None)
        self.memory[slot] = Transition(*args)
        self.index = (slot + 1) % MEMORY_SIZE

    def get_batch(self, batch_size):
        """Return ``batch_size`` distinct transitions drawn uniformly at random."""
        return random.sample(self.memory, batch_size)

In [None]:
def to_var(state):
    """Turn one observation array into a single-item network input batch.

    The array's axes are reordered with ``transpose(2, 0, 1)`` (last axis
    moved to the front), the values are converted to float, and a leading
    batch dimension of size 1 is added.
    """
    channels_first = state.transpose(2, 0, 1).astype(float)
    return Variable(FloatTensor(channels_first)).unsqueeze(0)

In [1]:
TARGET_UPDATE_FREQ = 1000
EVAL_UPDATE_FREQ = 4
BATCH_SIZE = 32
GAMMA = 0.99
EPSILON_START = 0.99
EPSILON_END = 0.05

def use_model(progress):
    threshold = EPSILON_START + (EPSILON_END - EPSILON_START) * progress
    return random.random() < threshold

class DQN(object):
    """DQN agent: an online network, a periodically synced target network,
    a replay memory and the optimisation step.

    Attributes:
        eval_net: network that is trained and used to pick actions.
        target_net: copy of ``eval_net`` used to compute bootstrap targets;
            refreshed every ``TARGET_UPDATE_FREQ`` learning steps.
        update_count: number of completed learning steps.
        memory: ``ReplayMemory`` holding the collected transitions.
    """

    def __init__(self):
        self.eval_net = Net()
        self.target_net = Net()
        # Both networks must start from the same weights; otherwise the first
        # TARGET_UPDATE_FREQ updates regress towards an unrelated random net.
        self.target_net.load_state_dict(self.eval_net.state_dict())
        # Only move to the GPU when one is available, so the class also works on CPU.
        if use_cuda:
            self.eval_net.cuda()
            self.target_net.cuda()

        self.update_count = 0
        self.memory = ReplayMemory()
        self.optimizer = torch.optim.RMSprop(self.eval_net.parameters(), lr=1e-4)
        self.loss = nn.SmoothL1Loss()

    def get_action(self, state, progress):
        """Choose an action for ``state``.

        Args:
            state: observation array accepted by ``to_var``.
            progress (float): fraction of training completed, forwarded to
                ``use_model`` to pick between the network and a random action.

        Returns:
            int: index of the chosen action.
        """
        if use_model(progress):
            q_values = self.eval_net(to_var(state))
            # max(1) returns (values, indices); keep the index for the single batch row.
            action = int(q_values.data.max(1)[1][0])
        else:
            # The output width of the last layer is the number of actions.
            action = random.randrange(self.eval_net.fc2.out_features)
        return action

    def can_learn(self):
        """Return True once the replay memory is full.

        The memory never holds more than ``MEMORY_SIZE`` items, so the
        comparison has to be ``>=``; a strict ``>`` would never be satisfied.
        """
        return len(self.memory) >= MEMORY_SIZE

    def _on_device(self, tensor):
        """Move ``tensor`` to the GPU when the online network lives there."""
        if next(self.eval_net.parameters()).is_cuda:
            return tensor.cuda()
        return tensor

    def _state_batch(self, states):
        """Stack observation arrays into one float network input.

        Applies to the whole batch the same axis reordering that ``to_var``
        applies to a single observation (last axis moved in front of the two
        spatial axes).
        """
        stacked = np.stack(states).transpose(0, 3, 1, 2).astype(np.float32)
        tensor = torch.from_numpy(np.ascontiguousarray(stacked))
        return Variable(self._on_device(tensor))

    def learn(self):
        """Run one optimisation step on a random mini-batch from the replay memory.

        Transitions are expected in the form the training loop stores them:
        ``s`` and ``s_`` as observation arrays, ``a`` as an action index and
        ``r`` as a scalar reward.  Every ``TARGET_UPDATE_FREQ`` calls the
        target network is refreshed from the online network.
        """
        transitions = self.memory.get_batch(BATCH_SIZE)
        batch = Transition(*zip(*transitions))  # list of tuples -> tuple of lists

        b_s = self._state_batch(batch.s)
        b_s_ = self._state_batch(batch.s_)
        # Column vectors of shape (batch, 1) so they line up with gather() and q_target.
        actions = np.asarray([int(a) for a in batch.a], dtype=np.int64)
        rewards = np.asarray([float(r) for r in batch.r], dtype=np.float32)
        b_a = Variable(self._on_device(torch.from_numpy(actions).unsqueeze(1)))
        b_r = Variable(self._on_device(torch.from_numpy(rewards).unsqueeze(1)))

        # Q(s, a) of the actions that were actually taken.
        q_eval = self.eval_net(b_s).gather(1, b_a)
        # Bootstrap from the target network without back-propagating into it.
        q_next = self.target_net(b_s_).detach()
        q_target = b_r + GAMMA * q_next.max(1)[0].unsqueeze(1)
        loss = self.loss(q_eval, q_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_count += 1
        if self.update_count % TARGET_UPDATE_FREQ == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())

In [None]:
import argparse
from test import test
from environment import Environment

In [None]:
def parse():
    """Build the command-line parser for the homework entry point.

    The base options are always registered.  If an ``argument`` module
    providing ``add_arguments(parser)`` can be imported, it is given the
    chance to extend the parser; when that module is absent the base parser
    is returned unchanged.  Errors raised inside ``add_arguments`` itself are
    not suppressed, so a broken hook is reported instead of silently ignored.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(description="MLDS&ADL HW3")
    parser.add_argument('--env_name', default=None, help='environment name')
    parser.add_argument('--train_pg', action='store_true', help='whether train policy gradient')
    parser.add_argument('--train_dqn', action='store_true', help='whether train DQN')
    parser.add_argument('--test_pg', action='store_true', help='whether test policy gradient')
    parser.add_argument('--test_dqn', action='store_true', help='whether test DQN')
    parser.add_argument('--video_dir', default=None, help='output video directory')
    parser.add_argument('--do_render', action='store_true', help='whether render environment')
    try:
        from argument import add_arguments
    except ImportError:
        # The extension module is optional; fall back to the base options.
        return parser
    return add_arguments(parser)

In [None]:
# Notebook stand-in for the command-line entry point: the argument list is
# fixed to ['--train_dqn'] instead of being read from sys.argv.
parser = parse()
args = parser.parse_args(['--train_dqn'])
# Fall back to Breakout when no --env_name was given.
env_name = args.env_name or 'BreakoutNoFrameskip-v4'
env = Environment(env_name, args)
from agent_dir.agent_dqn import Agent_DQN
# NOTE(review): `agent` is only referenced from commented-out code in the
# visible cells; confirm whether it is still needed.
agent = Agent_DQN(env, args)

In [None]:
# Create the DQN agent that the training loop uses through the name `dqn`.
dqn = DQN()

In [None]:
import os
import random
from itertools import count
from time import time

# Train until `total_steps` environment steps have been taken.  After every
# episode the episode index, its total reward and its wall-clock duration are
# printed, the reward is appended to `rewards`, and the online network is
# checkpointed under ./models.
rewards = []
seed = 11037
env.seed(seed)
# Exploration and replay sampling use the `random` module, so seed it as well.
random.seed(seed)
total_steps = 10**7
steps = 0
# torch.save does not create missing directories.
os.makedirs('./models', exist_ok=True)

for ep in count():
    s = env.reset()
    done = False
    episode_reward = 0.0

    # The network is fed frame differences; the first frame is differenced
    # against an all-zero frame.
    # NOTE(review): if observations are unsigned integers these subtractions
    # wrap around instead of going negative -- confirm the dtype returned by env.
    s0 = np.zeros_like(s)
    start = time()
    for j in count():
        delta_s = s - s0
        a = dqn.get_action(delta_s, steps / total_steps)
        s_, r, done, info = env.step(a)
        # Account for the step before any early exit so that the reward and
        # the step count of the terminal step are not lost.
        episode_reward += r
        steps += 1
        if done:
            delta_s_ = None
            # TODO: add to memory?
        else:
            delta_s_ = s_ - s
            dqn.memory.add_transition(delta_s, a, delta_s_, r)

        # Learn once every EVAL_UPDATE_FREQ steps of the current episode.
        if dqn.can_learn() and j % EVAL_UPDATE_FREQ == EVAL_UPDATE_FREQ - 1:
            dqn.learn()

        if done:
            break

        s0, s = s, s_

    print(ep, episode_reward, time() - start)
    # Record and checkpoint every episode, including the final one.
    # NOTE(review): this writes one full model file per episode; consider
    # saving less often if disk usage becomes a problem.
    rewards.append(episode_reward)
    torch.save(dqn.eval_net, './models/eval_net{}.pt'.format(ep))
    if steps > total_steps:
        break