### Imports

In [18]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import namedtuple, deque
import random
import math
from itertools import count

### Hyperparameters

In [21]:
# Replay / optimisation settings
BUFFER_SIZE = 100_000       # capacity handed to the replay memory
BATCH_SIZE = 128            # transitions sampled per optimisation step
GAMMA = 0.999               # discount factor applied to next-state values
LEARNING_RATE = 5e-4        # learning rate handed to the Adam optimizer

# Epsilon-greedy schedule: eps = EPS_END + (EPS_START - EPS_END) * exp(-t / EPS_DECAY)
EPS_START = 0.95
EPS_END = 0.05
EPS_DECAY = 200

# Training loop settings
TARGET_UPDATE_EVERY = 10    # episodes between target-network syncs
NUM_EPISODES = 20

# Observation / action space
VISIBILITY = 4  # square's half side length
STATE_SIZE = (2 * VISIBILITY + 1) ** 2  # number of cells in the visible square
ACTION_SIZE = 5  # 4 directions + do nothing

### Initializations

In [5]:
# Use the CUDA device when PyTorch reports one as available; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


## Function definitions

### Environment

In [23]:
class Environment():
    """Placeholder environment; every method is a stub that still has to be implemented.

    The methods are instance methods (they take ``self``) so that they can be
    called on an ``Environment`` object, which is how ``DQLAgent.train`` in
    this file uses them.
    """

    def __init__(self):
        """Create the environment. Currently does nothing."""
        # Initialize env
        pass

    def reset(self):
        """Start a new episode. Stub: currently returns ``None``.

        ``DQLAgent.train`` uses the return value as the initial state.
        """
        pass

    def step(self, action):
        """Apply ``action`` to the environment. Stub: currently returns ``None``.

        ``DQLAgent.train`` calls this with ``action.item()`` and unpacks the
        result into four values: ``next_state, reward, done, _``.
        """
        pass

    def render(self):
        """Display the current state of the environment. Stub: does nothing."""
        pass
    

### Replay memory

In [10]:
# One stored experience: the four fields recorded for each environment step.
Transition = namedtuple(
    "Transition",
    ["state", "action", "next_state", "reward"],
)


class ReplayMemory:
    """Bounded buffer of ``Transition`` records with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        # A deque with maxlen drops its oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one transition built from ``args`` (state, action, next_state, reward)."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        stored = self.memory
        return len(stored)

### DQN

In [14]:
class DQN(nn.Module):
    """Fully connected Q-network: three ReLU hidden layers and a linear output head.

    Layer widths are ``state_size -> 64 -> 128 -> 32 -> action_size``.
    """

    def __init__(self, state_size, action_size):
        """Build the layers.

        ``torch.manual_seed(0)`` is called before the layers are created, so
        the global RNG is reseeded on every construction; the returned
        generator is kept in ``self.seed``.
        """
        super().__init__()
        self.seed = torch.manual_seed(0)

        # Hidden layers
        self.fc1 = nn.Linear(state_size, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 32)

        # Output layer: one value per action
        self.head = nn.Linear(32, action_size)

    def forward(self, x):
        """Return the head's output for ``x`` after the three ReLU hidden layers."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        # Flatten everything after the first dimension before the output head.
        flat = hidden.view(hidden.size(0), -1)
        return self.head(flat)

### DQL Agent

In [28]:
class DQLAgent():
    """Deep Q-learning agent: policy network, target network and replay memory.

    The agent picks actions epsilon-greedily with the policy network, stores
    every transition in a ``ReplayMemory`` and fits the policy network on
    sampled mini-batches. The target network is a copy of the policy network
    that is refreshed every ``TARGET_UPDATE_EVERY`` episodes.

    NOTE(review): states are passed straight to the networks and concatenated
    with ``torch.cat``, so this assumes the environment returns states as
    tensors of shape ``(1, state_size)`` on ``device`` -- TODO confirm once
    the environment is implemented.
    """

    def __init__(self, state_size, action_size, env):
        """Build the networks, the optimizer and the replay memory.

        Args:
            state_size: Input width of the Q-networks.
            action_size: Number of discrete actions (output width).
            env: Environment object; ``train`` calls ``reset()`` and
                ``step(action)`` on it.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.env = env

        # DQNs: the target network starts as an exact copy of the policy
        # network and is only used for inference.
        self.policy_net = DQN(state_size, action_size).to(device)
        self.target_net = DQN(state_size, action_size).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
        self.memory = ReplayMemory(BUFFER_SIZE)

        # Number of actions selected so far; drives the epsilon decay.
        self.time_step = 0

    def optimize_model(self):
        """Run one optimisation step on a mini-batch sampled from replay memory.

        Does nothing until the memory holds at least ``BATCH_SIZE`` transitions.
        """
        if len(self.memory) < BATCH_SIZE:
            return
        transitions = self.memory.sample(BATCH_SIZE)
        # Turn a list of Transitions into one Transition of tuples.
        batch = Transition(*zip(*transitions))

        # Terminal transitions are stored with next_state=None.
        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state],
            device=device,
            dtype=torch.bool,
        )
        non_final_next_states = [s for s in batch.next_state if s is not None]
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Q(s, a) for the actions that were actually taken.
        state_action_values = self.policy_net(state_batch).gather(1, action_batch)

        # max_a' Q_target(s', a'); stays 0 for terminal transitions.
        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        if non_final_next_states:
            # torch.cat rejects an empty list, hence the guard.
            next_state_values[non_final_mask] = (
                self.target_net(torch.cat(non_final_next_states)).max(1)[0].detach()
            )
        expected_state_action_values = next_state_values * GAMMA + reward_batch

        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        # Clamp every gradient element to [-1, 1] before the update.
        nn.utils.clip_grad_value_(self.policy_net.parameters(), 1)
        self.optimizer.step()

    def select_action(self, state):
        """Pick an action epsilon-greedily and advance the decay counter.

        Epsilon decays exponentially from ``EPS_START`` towards ``EPS_END``
        as ``self.time_step`` grows; the counter is incremented on every call.

        Args:
            state: Network input for the current state.

        Returns:
            A ``(1, 1)`` long tensor holding the chosen action index.
        """
        sample = random.random()
        eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * self.time_step / EPS_DECAY)
        self.time_step += 1
        if sample > eps_threshold:
            with torch.no_grad():
                # Greedy action: index of the largest Q-value.
                return self.policy_net(state).max(1)[1].view(1, 1)
        return torch.tensor([[random.randrange(self.action_size)]], device=device, dtype=torch.long)

    def train(self):
        """Run ``NUM_EPISODES`` episodes of interaction and learning."""
        for episode in range(NUM_EPISODES):
            state = self.env.reset()  # get initial state
            for t in count():  # The environment is responsible for returning done=True after some time steps
                action = self.select_action(state)
                next_state, reward, done, _ = self.env.step(action.item())
                # Fixed float dtype so the loss does not depend on the Python
                # type of the reward the environment returns.
                reward = torch.tensor([reward], device=device, dtype=torch.float32)
                if done:
                    # Mark the transition as terminal so optimize_model does
                    # not bootstrap from a next-state value.
                    next_state = None

                self.memory.push(state, action, next_state, reward)
                state = next_state

                self.optimize_model()
                if done:
                    # TODO: Plot some statistics etc...
                    break
            if episode % TARGET_UPDATE_EVERY == 0:
                self.target_net.load_state_dict(self.policy_net.state_dict())
        print("Training finished")
        

In [29]:
# Build the environment and the agent that will act in it.
env = Environment()
dqlAgent = DQLAgent(
    state_size=STATE_SIZE,
    action_size=ACTION_SIZE,
    env=env,
)