In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class QuoridorNNet(nn.Module):
    """
    Transformer policy/value network for Quoridor.

    Input: game - object exposing getActionSize() and getBoardSize(); used to size the output heads
           nwalltoken - number of distinct token ids in the wall section of the input (wall embedding vocabulary)
           npawntoken - number of distinct token ids in the square section of the input (square embedding vocabulary)
           d_model - MUST BE EVEN, dimension of the model; each square and its wall cell get d_model / 2
                     features each and are concatenated into a single d_model-wide token
           nhead - number of heads in the multi-head attention layers
           d_hid - dimension of the feedforward network in each nn.TransformerEncoderLayer
           nlayers - number of nn.TransformerEncoderLayer in nn.TransformerEncoder
           dropout - dropout probability (Default=0.5) [OPTIONAL]
    Description: Initialize the transformer model, creating the embedding, encoder and output layers
    Output: None
    Raises: ValueError if d_model is odd
    """
    def __init__(self, game, nwalltoken: int = 3, npawntoken: int = 3, d_model: int = 200, nhead: int = 2, d_hid: int = 2048, nlayers: int = 2, dropout: float = 0.5, ):
        # nn.Module must be initialised before any attribute is assigned on it.
        super().__init__()
        if d_model % 2 != 0:
            raise ValueError(f"d_model must be even, got {d_model}")

        # game params
        self.action_size = game.getActionSize()
        # NOTE(review): used as in_features of the output heads, so this must equal the
        # encoder sequence length produced in forward() (2 + number of squares) — confirm.
        self.input_size = game.getBoardSize()

        self.model_type = 'Transformer'
        half = d_model // 2
        # Squares and walls each embed into half of d_model; forward() concatenates them.
        self.embeddingSquares = nn.Embedding(npawntoken, half)
        self.embeddingWalls = nn.Embedding(nwalltoken, half)
        # 11 ids: wall-count values 0..10.
        self.embeddingNWalls = nn.Embedding(11, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.d_model = d_model
        self.softmax = nn.Softmax(dim=1) #Softmax activation layer
        self.gelu = nn.GELU() #GELU activation layer
        self.flatten = nn.Flatten(start_dim=1) #Flatten layer
        self.decoder = nn.Linear(d_model,1) #Per-token projection to a single score
        self.v_output = nn.Linear(self.input_size, 3) #Value head
        self.p_output = nn.Linear(self.input_size, self.action_size) #Policy head

        self.init_weights()

    def init_weights(self) -> None:
        """Uniformly initialise the embeddings and decoder weights in [-0.1, 0.1]; zero the decoder bias."""
        initrange = 0.1
        self.embeddingSquares.weight.data.uniform_(-initrange, initrange)
        self.embeddingWalls.weight.data.uniform_(-initrange, initrange)
        self.embeddingNWalls.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        """
        Arguments:
            src: Tensor, shape ``[2 + 81 + n_wall_cells, batch_size]`` of integer-valued token ids.
                 Rows 0-1 index embeddingNWalls (values 0..10), rows 2-82 index embeddingSquares,
                 and the remaining rows index embeddingWalls. The wall section must have the same
                 length as the square section (81), because each square is paired with one wall cell.

        Returns:
            (v, p): v has shape ``[batch_size, 3]`` and p has shape ``[batch_size, action_size]``;
            each is a softmax over dim 1.

        Raises:
            ValueError if the wall section and the square section differ in shape.
        """
        tokens = src.long()  # nn.Embedding needs integer indices
        n_walls_tok = tokens[:2]
        squares_tok = tokens[2: 2 + 81]
        walls_tok = tokens[2 + 81:]
        if walls_tok.shape != squares_tok.shape:
            raise ValueError(
                f"expected wall section shape {tuple(squares_tok.shape)}, got {tuple(walls_tok.shape)}"
            )

        # Pair each square with its wall cell: two d_model/2 halves -> one d_model token.
        cells = torch.cat((self.embeddingSquares(squares_tok), self.embeddingWalls(walls_tok)), dim=-1)
        # Prepend the two wall-count tokens along the sequence dimension: [2 + 81, batch, d_model].
        x = torch.cat((self.embeddingNWalls(n_walls_tok), cells), dim=0)
        x = x * math.sqrt(self.d_model) # scale embeddings before adding positional encodings
        x = self.pos_encoder(x) # Apply positional encodings
        x = self.transformer_encoder(x) # sequence-first (batch_first=False): [seq, batch, d_model]
        x = self.gelu(x)
        x = self.decoder(x) # [seq, batch, 1]
        x = self.gelu(x)
        # Move batch to the front so each sample's per-token scores form one feature row: [batch, seq].
        x = self.flatten(x.permute(1, 0, 2))
        v = self.softmax(self.v_output(x)) #Value output
        p = self.softmax(self.p_output(x)) #Policy output

        return v,p
    
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal positional encodings to a sequence-first tensor, then apply dropout.

    Arguments:
        d_model: embedding dimension of the inputs (odd values are supported)
        dropout: dropout probability applied after adding the encodings
        max_len: number of positions covered by the precomputed encoding table
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column, so trim div_term.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as a buffer rather than a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]`` with seq_len <= max_len

        Returns:
            Tensor of the same shape: x plus the encodings of the first seq_len positions
            (broadcast over the batch), after dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
