In [None]:
import os
import math
import time
import random
import numpy as np
from collections import namedtuple
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
import torch.autograd as autograd
from torch.distributions import Categorical

from utils.minipacman import MiniPacman
from utils.multiprocessing_env import SubprocVecEnv

import matplotlib.pyplot as plt
from IPython import display
from IPython.core.debugger import set_trace


# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# One replay-buffer record.
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward"])

In [None]:
class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records for experience replay.

    The buffer grows until it holds ``capacity`` records; after that each push
    overwrites the oldest slot.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot the next push writes to

    def push(self, *args):
        """Store one transition built from ``args``, overwriting the oldest once full."""
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            self.memory.append(None)
        slot = self.position
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen uniformly at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network mapping a batch of observations to one Q-value per action.

    Args:
        in_shape: shape of a single observation, (channels, height, width); the channel
            count feeds the first conv layer and the full shape sizes the linear layer.
        n_actions: number of discrete actions (width of the output layer).
    """

    def __init__(self, in_shape, n_actions):
        super(DQN, self).__init__()
        self.in_shape = in_shape

        self.features = nn.Sequential(
            nn.Conv2d(in_shape[0], 16, kernel_size=3, stride=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size(), 256),
            nn.ReLU(),
        )

        self.head = nn.Linear(256, n_actions)

    def forward(self, x):
        """Return Q-values of shape (batch, n_actions) for ``x`` of shape (batch, *in_shape)."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return self.head(x)

    def act(self, x):
        """Return the greedy (argmax-Q) action per row of ``x`` as a LongTensor of shape (batch, 1).

        Runs in whatever train/eval mode the module is currently in. In train mode the
        BatchNorm layers normalise with the statistics of ``x`` itself.
        """
        return self.forward(x).max(1)[1].view(-1, 1)

    def feature_size(self):
        """Number of flattened features ``self.features`` produces for one ``in_shape`` input.

        The probe forward pass runs in eval mode and without autograd so that it does not
        update the BatchNorm running statistics. The previous train/eval mode of
        ``self.features`` is restored afterwards. The dummy input is created on the
        parameters' device, so this also works after the module was moved to a GPU.
        """
        device = next(self.features.parameters()).device
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                out = self.features(torch.zeros(1, *self.in_shape, device=device))
        finally:
            self.features.train(was_training)
        return out.view(1, -1).size(1)

In [None]:
def select_action(state, policy_net, num_actions, num_envs, epilson=0.9):
    """Pick one action per environment, either greedily or uniformly at random.

    NOTE: despite its name, ``epilson`` is the probability of acting *greedily*.
    Random actions are taken with probability ``1 - epilson``. One coin flip decides
    for the whole batch, so either every env acts greedily or every env acts randomly.

    Returns:
        On the random branch, a LongTensor of shape (num_envs, 1) on ``DEVICE``.
        On the greedy branch, whatever ``policy_net.act(state)`` returns (evaluated
        with autograd disabled).
    """
    greedy = random.random() < epilson
    if not greedy:
        random_actions = np.random.randint(num_actions, size=(num_envs, 1))
        return torch.from_numpy(random_actions).long().to(DEVICE)
    with torch.no_grad():
        return policy_net.act(state)

In [None]:
def optimize_model(policy_net, target_net, memory, batch_size=128, gamma=0.999, memory_sample_prob=0.25,
                   optimizer=None):
    """Run one DQN gradient step on a minibatch sampled from replay memory.

    Args:
        policy_net: online Q-network being trained.
        target_net: Q-network used (in eval mode, without autograd) to compute the
            bootstrap targets ``max_a' Q_target(s', a')``.
        memory: replay buffer exposing ``__len__`` and ``sample(batch_size)``. Each sampled
            record is a 4-tuple ``(state, action, next_state, reward)``:
            ``state`` is a tensor with a leading batch dim of 1, ``action`` a (1, 1) long
            tensor, ``reward`` a 1-element tensor, and ``next_state`` an array-like
            observation or ``None`` for terminal transitions.
        batch_size: number of transitions per update.
        gamma: discount factor.
        memory_sample_prob: unused; kept for backward compatibility.
        optimizer: optimizer over ``policy_net``'s parameters. If omitted, falls back to a
            global named ``optimizer`` (the original behaviour).

    Returns:
        The loss as a Python float, or ``None`` if memory holds fewer than
        ``batch_size`` transitions (no update is made).
    """
    if len(memory) < batch_size:
        return None

    if optimizer is None:
        optimizer = globals().get("optimizer")
        if optimizer is None:
            raise NameError("optimize_model needs an optimizer: pass optimizer=... or define a global `optimizer`")

    states, actions, next_states, rewards = zip(*memory.sample(batch_size))

    state_batch = torch.cat(states)
    device = state_batch.device
    action_batch = torch.cat(actions).to(device)
    # float32 so the target matches the dtype of the network's Q-values
    # (a float64 reward would otherwise promote the target to float64).
    reward_batch = torch.cat(rewards).to(device=device, dtype=torch.float32)

    # Q(s, a) for the actions that were actually taken.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Bootstrap value max_a' Q_target(s', a'); stays 0 for terminal transitions.
    next_state_values = torch.zeros(batch_size, device=device)
    non_final = [s for s in next_states if s is not None]
    if non_final:  # an all-terminal batch would make stacking an empty list fail
        non_final_mask = torch.tensor([s is not None for s in next_states], device=device, dtype=torch.bool)
        non_final_next_states = torch.as_tensor(
            np.stack([np.asarray(s, dtype=np.float32) for s in non_final]), device=device)
        target_net.eval()
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]

    # For Double DQN, select a' with policy_net and evaluate it with target_net instead:
    #   next_actions = policy_net.act(non_final_next_states)
    #   target_net(non_final_next_states).gather(1, next_actions).squeeze(1)

    expected_state_action_values = reward_batch + gamma * next_state_values

    # Mean squared TD error (F.smooth_l1_loss would give the Huber variant).
    loss = F.mse_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Element-wise clamp of every gradient to [-1, 1]; parameters without a grad are skipped.
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()
    return loss.item()

In [None]:
# MiniPacman game variant and number of parallel worker environments.
mode = "regular"
num_envs = 16


def make_env():
    """Return a zero-argument factory that builds one ``MiniPacman`` environment.

    SubprocVecEnv is handed one such factory per worker rather than env instances.
    """
    def _build():
        # NOTE(review): meaning of the 1000 argument is not visible here (step limit?) — confirm.
        return MiniPacman(mode, 1000)

    return _build


envs = SubprocVecEnv([make_env() for _ in range(num_envs)])

state_shape = envs.observation_space.shape  # per-observation shape fed to DQN
num_actions = envs.action_space.n

In [None]:
# Online network (trained) and target network (a periodically synced copy used for the
# bootstrap targets). Created in this order so the random initialisation is unchanged.
policy_net = DQN(state_shape, num_actions).to(DEVICE)
target_net = DQN(state_shape, num_actions).to(DEVICE)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
policy_net.train()

# RMSprop hyperparameters: learning rate, numerical epsilon, smoothing constant.
lr, eps, alpha = 7e-4, 1e-5, 0.99
optimizer = optim.RMSprop(policy_net.parameters(), lr, eps=eps, alpha=alpha)

# Replay buffer capacity, counted in single-environment transitions.
# policy_net, target_net, optimizer and memory are the state to persist in a checkpoint.
memory = ReplayMemory(10000)

In [None]:
num_frames = int(1e5)   # number of outer updates; each runs 10 vectorised env steps
target_update = 1000    # hard-sync target_net every this many optimize_model calls
batch_size = 256
backprops_freq = 0      # counts optimize_model calls; drives the target sync
cur_best_reward = 0
tau = 1e-3              # only used by the commented-out soft-update variant below

all_rewards = []

# Per-env episode bookkeeping stays on the CPU; only network inputs live on DEVICE.
episode_rewards = torch.zeros(num_envs, 1)
final_rewards   = torch.zeros(num_envs, 1)

state = torch.FloatTensor(np.float32(envs.reset())).to(DEVICE)

for i_update in range(num_frames):
    for t in range(10):
        # Select and perform an action in every environment.
        action = select_action(state, policy_net, num_actions, num_envs)
        next_state, reward, done, _ = envs.step(action.squeeze(1).cpu().numpy())

        # float32 rewards so the replay targets match the float32 Q-values.
        reward_cpu = torch.from_numpy(np.asarray(reward, dtype=np.float32)).unsqueeze(1)
        done = np.asarray(done, dtype=bool)

        # Accumulate per-env returns; when an env finishes, latch its return into
        # final_rewards and reset its running total.
        episode_rewards += reward_cpu
        masks = torch.from_numpy(1.0 - done.astype(np.float32)).unsqueeze(1)
        final_rewards *= masks
        final_rewards += (1 - masks) * episode_rewards
        episode_rewards *= masks

        # Store every env's transition. Terminal transitions get next_state=None so
        # optimize_model does not bootstrap from the observation reported after `done`.
        reward = reward_cpu.to(DEVICE)
        for i in range(num_envs):
            memory.push(state[i].unsqueeze(0), action[i].unsqueeze(0),
                        None if done[i] else next_state[i], reward[i])

        state = torch.FloatTensor(np.float32(next_state)).to(DEVICE)

        # One gradient step on policy_net.
        optimize_model(policy_net, target_net, memory, batch_size=batch_size)

        # Hard update: copy all weights and biases into the target network.
        backprops_freq += 1
        if backprops_freq % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

        # Soft update alternative:
        # for target_param, local_param in zip(target_net.parameters(), policy_net.parameters()):
        #     target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data)

    if i_update % 10 == 0:
        all_rewards.append(final_rewards.mean().item())
        display.clear_output(True)
        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(all_rewards)
        ax.set_title(f"epoch {i_update}. reward: {np.mean(all_rewards[-10:])}")
        ax.set_xlabel("Environmental Steps x10")
        ax.set_ylabel("Rewards")
        plt.show()
        plt.close(fig)

In [None]:
# env = MiniPacman(mode, 1000)
# state = env.reset()
# done = False
# total_reward = 0
# step = 1

# policy_net.eval()

# while not done:
#     current_state = torch.FloatTensor(state).unsqueeze(0).to(DEVICE)
# #     action = target_net.act(current_state)
#     action = policy_net.act(current_state)
#     next_state, reward, done, _ = env.step(action.data[0, 0])
#     total_reward += reward
#     state = next_state
    
#     plt.imshow(state.transpose([1, 2, 0]))
#     plt.axis('off')
#     plt.title(f"steps: {step}, reward: {total_reward}")
    
#     display.display(plt.gcf())
#     display.clear_output(wait=True)
#     time.sleep(0.1)
    
#     step += 1