In [4]:
import math
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from collections import namedtuple, deque

In [5]:
# One replay record; fields in order: state, action, next_state, reward.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward'],
)


class ReplayMemory:
    """Fixed-capacity store of Transition records.

    The records live in a deque created with ``maxlen=capacity``, so once the
    store is full every append drops the oldest record.
    """

    def __init__(self, capacity):
        # The deque itself enforces the capacity limit.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a Transition from the positional arguments and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` stored records drawn without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of records currently held."""
        return len(self.memory)

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network: state vector in, one Q-value per action out.

    Two hidden ReLU layers are followed by a linear head. The default sizes
    (9 inputs, 128 hidden units, 21 outputs) reproduce the architecture that
    was previously hard-coded.

    Args:
        n_inputs: Length of the state vector fed to the first layer.
        n_hidden: Width of both hidden layers.
        n_actions: Number of outputs, i.e. one Q-value per action index.
    """

    def __init__(self, n_inputs=9, n_hidden=128, n_actions=21):
        super().__init__()
        self.layer1 = nn.Linear(in_features=n_inputs, out_features=n_hidden)
        # The keyword is `in_features`; the earlier misspelling made nn.Linear
        # raise TypeError, so DQN() could not be constructed at all.
        self.layer2 = nn.Linear(in_features=n_hidden, out_features=n_hidden)
        self.head = nn.Linear(in_features=n_hidden, out_features=n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns the raw Q-values; the last dimension of the
    # result has size n_actions.
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.head(x)

In [None]:
def heaps_to_bits(heaps, bits_per_heap=3):
    """Encode heap sizes as a flat binary feature vector.

    Every heap size is written as a fixed-width binary number (most
    significant bit first) and the per-heap codes are concatenated in heap
    order. With three heaps and the default three bits per heap this gives a
    vector of length 9, e.g. ``[1, 3, 5] -> [0,0,1, 0,1,1, 1,0,1]``.

    Heaps are numbered from 1 in moves, so ``heaps[i]`` is the size of
    heap ``i + 1``.

    NOTE(review): the bit order (MSB first) is a convention chosen here; any
    fixed order works as long as it is used consistently.

    Args:
        heaps: Iterable of non-negative integer heap sizes.
        bits_per_heap: Number of bits used for each heap (default 3, which
            covers sizes 0..7).

    Returns:
        1-D ``torch.float32`` tensor of length ``len(heaps) * bits_per_heap``.

    Raises:
        ValueError: If a heap size does not fit in ``bits_per_heap`` bits.
    """
    limit = 1 << bits_per_heap
    bits = []
    for size in heaps:
        size = int(size)
        if not 0 <= size < limit:
            raise ValueError(
                "heap size %d does not fit in %d bits" % (size, bits_per_heap)
            )
        # Emit the bits from most to least significant.
        bits.extend(
            (size >> shift) & 1 for shift in range(bits_per_heap - 1, -1, -1)
        )
    return torch.tensor(bits, dtype=torch.float32)

In [None]:
#Copy pasted so that we remember to create same methods
class NimDQNlearningAgent:
    """DQN agent for Nim with an epsilon-greedy policy.

    A move is ``[heap_number, n_objects]`` with both entries counted from 1.
    The network output index for a move is
    ``(heap_number - 1) * MAX_TAKE + (n_objects - 1)``, which is the inverse
    of the decoding used in ``act``.

    Args:
        alpha: Stored on the instance but not used by this class; the Adam
            learning rate is the fixed ``LEARNING_RATE``.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability of playing a random move in ``act``.
        player: Player index (0 or 1); stored only.
        buffer_size: Capacity of the replay buffer created when
            ``ReplayBuffer`` is not supplied.
        batch_size: Number of transitions sampled per optimisation step.
        update_cycle: Stored only; this class never reads it. Presumably the
            training loop uses it to decide when to call ``update_target``
            (TODO confirm).
        target_network: Existing target network to use; a new ``DQN``
            initialised from the policy weights is created when omitted.
        policy_network: Existing policy network to use; a new ``DQN`` is
            created when omitted.
        ReplayBuffer: Existing replay buffer to use (shared, not copied).
    """

    # Number of consecutive action indices reserved for each heap.
    MAX_TAKE = 7
    # Learning rate handed to Adam.
    LEARNING_RATE = 0.0005

    def __init__(self, alpha=0.1, gamma=0.99, epsilon=0.1, player=0,
                 buffer_size=10000, batch_size=64, update_cycle=500,
                 target_network=None, policy_network=None, ReplayBuffer=None):
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.player = player  # 0 or 1
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.update_cycle = update_cycle

        # Reuse the supplied replay buffer, otherwise create a fresh one.
        if ReplayBuffer is not None:
            self.buffer = ReplayBuffer
        else:
            self.buffer = ReplayMemory(self.buffer_size)

        if policy_network is not None:
            self.policy_network = policy_network
        else:
            self.policy_network = DQN()

        if target_network is not None:
            self.target_network = target_network
        else:
            self.target_network = DQN()
            # Start the target from the policy weights. This must go through
            # the attributes: the constructor arguments are None on this path.
            self.update_target()

        self.optimizer = torch.optim.Adam(
            self.policy_network.parameters(), lr=self.LEARNING_RATE
        )

    def copy(self):
        """Return a new agent with the same hyper-parameters.

        The new agent shares this agent's policy network, target network and
        replay buffer (they are passed by reference, not cloned); only the
        optimizer is created anew.
        """
        return NimDQNlearningAgent(
            alpha=self.alpha,
            gamma=self.gamma,
            epsilon=self.epsilon,
            player=self.player,
            buffer_size=self.buffer_size,
            batch_size=self.batch_size,
            update_cycle=self.update_cycle,
            target_network=self.target_network,
            policy_network=self.policy_network,
            ReplayBuffer=self.buffer,
        )

    @staticmethod
    def _move_to_action(move):
        """Map ``[heap_number, n_objects]`` (both from 1) to an output index."""
        heap_number, n_objects = move
        return (int(heap_number) - 1) * NimDQNlearningAgent.MAX_TAKE + (int(n_objects) - 1)

    @staticmethod
    def _encode(heaps):
        """Encode heaps as a float tensor with a leading batch dimension of 1."""
        # as_tensor/reshape accept whatever sequence type heaps_to_bits
        # returns (list, array or tensor).
        return torch.as_tensor(heaps_to_bits(heaps), dtype=torch.float32).reshape(1, -1)

    def randomMove(self, heaps):
        """Pick a uniformly random non-empty heap and a random amount to take.

        Raises IndexError (from ``random.choice``) when every heap is empty.
        """
        heaps_avail = [i for i, size in enumerate(heaps) if size > 0]
        chosen_heap = random.choice(heaps_avail)
        n_obj = random.randint(1, heaps[chosen_heap])
        return [chosen_heap + 1, n_obj]

    def act(self, heaps):
        """Return a move: random with probability epsilon, else greedy.

        The greedy move is the arg-max over all network outputs; it is not
        filtered against the current heap sizes.
        """
        if random.random() < self.epsilon:
            return self.randomMove(heaps)

        # Inference only: no autograd graph is needed to pick an action.
        with torch.no_grad():
            q_values = self.policy_network(self._encode(heaps))
        best = int(q_values.argmax(dim=1).item())
        # Decode the flat index back into heap number and number of objects.
        return [best // self.MAX_TAKE + 1, best % self.MAX_TAKE + 1]

    def update(self, oldHeaps, move, newHeaps, reward):
        """Store one transition and run one optimisation step if possible.

        Args:
            oldHeaps: Heap sizes before ``move``.
            move: ``[heap_number, n_objects]`` that was played.
            newHeaps: Heap sizes of the next state, or ``None`` when the
                episode ended. A position with no objects left is also
                treated as terminal, since no move exists from it.
            reward: Scalar reward for the transition.
        """
        old_state = self._encode(oldHeaps)
        is_final = newHeaps is None or not any(size > 0 for size in newHeaps)
        new_state = None if is_final else self._encode(newHeaps)
        # Shapes (1, 1) and (1,) so torch.cat yields (batch, 1) and (batch,).
        action = torch.tensor([[self._move_to_action(move)]], dtype=torch.long)
        reward_t = torch.tensor([float(reward)], dtype=torch.float32)

        # push() takes the four fields as separate positional arguments.
        self.buffer.push(old_state, action, new_state, reward_t)

        # Wait until a full batch can be sampled.
        if len(self.buffer) < self.batch_size:
            return
        transitions = self.buffer.sample(self.batch_size)
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
        # detailed explanation). This converts batch-array of Transitions
        # to Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        # True where the transition has a next state to bootstrap from.
        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state], dtype=torch.bool
        )
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Q(s_t, a_t): evaluate all actions, then pick the column of the
        # action that was actually taken.
        state_action_values = self.policy_network(state_batch).gather(1, action_batch)

        # V(s_{t+1}) from the target network; stays 0 for final transitions.
        next_state_values = torch.zeros(self.batch_size)
        if non_final_mask.any():
            # torch.cat fails on an empty list, hence the guard above.
            non_final_next_states = torch.cat(
                [s for s in batch.next_state if s is not None]
            )
            # The target is treated as a constant: no gradients are tracked
            # through the target network.
            with torch.no_grad():
                next_state_values[non_final_mask] = (
                    self.target_network(non_final_next_states).max(1)[0]
                )
        expected_state_action_values = (next_state_values * self.gamma) + reward_batch

        # Huber loss between predicted and bootstrapped values.
        criterion = nn.SmoothL1Loss()
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

        # Optimize the policy network.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target(self):
        """Copy the policy network's weights into the target network."""
        self.target_network.load_state_dict(self.policy_network.state_dict())