In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import math
import numpy as np

In [None]:
class Transformer(nn.Module):
    """Skeleton of the model: wraps ``nn.Transformer``; the forward pass is not written yet."""

    def __init__(
        self,
        num_tokens,
        dim_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p
    ):
        """Build the underlying ``nn.Transformer``.

        ``num_tokens`` is accepted for signature compatibility but is not used
        by this skeleton.
        """
        super().__init__()

        # Collect the hyper-parameters first, then hand them over in one call.
        transformer_kwargs = {
            "d_model": dim_model,
            "nhead": num_heads,
            "num_encoder_layers": num_encoder_layers,
            "num_decoder_layers": num_decoder_layers,
            "dropout": dropout_p,
        }
        self.transformer = nn.Transformer(**transformer_kwargs)

    def forward(self):
        """Placeholder forward pass; does nothing and returns ``None``."""
        return None

In [None]:
"""
Next, we implement the positional encoding.
It uses the sinusoidal formula from "Attention Is All You Need", adds the result to the
token embeddings (a residual-style addition), and then applies dropout.
"""

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding from "Attention Is All You Need".

    A fixed table of shape ``(max_len, 1, dim_model)`` is precomputed and stored
    as the buffer ``pos_encoding``. ``forward`` adds the first ``seq_len`` rows
    to its input and applies dropout.

    Args:
        dim_model: size of each token embedding vector.
        dropout_p: dropout probability applied after the addition.
        max_len: number of positions to precompute (longest supported sequence).
    """

    def __init__(self, dim_model, dropout_p, max_len): #dim_model = size of each token embedding
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        # shape [max_len, 1]: the positions 0 .. max_len-1 as a column
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
        # 1 / 10000^(2i/dim_model), computed in log space; shape [ceil(dim_model / 2)]
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model
        )
        angles = positions_list * division_term # shape [max_len, ceil(dim_model / 2)]

        pos_encoding = torch.zeros(max_len, dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model)) -> from the paper
        pos_encoding[:, 0::2] = torch.sin(angles)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        # When dim_model is odd there is one fewer odd column than even column,
        # so drop the surplus frequency to keep the shapes aligned.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : dim_model // 2])

        # [max_len, dim_model] -> [max_len, 1, dim_model] so it broadcasts over the batch dim
        pos_encoding = pos_encoding.unsqueeze(1)
        # A buffer (not a parameter): moves with .to(device) but is never trained.
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Add positional information to ``token_embedding`` and apply dropout.

        The position is read from dimension 0, so the input must be
        sequence-first: ``(seq_len, batch_size, dim_model)`` with
        ``seq_len <= max_len``.
        """
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_encoding[:seq_len])


In [None]:
class Transformer(nn.Module):
    """Sequence-to-sequence model built around ``torch.nn.Transformer``.

    Token indices are embedded, scaled by ``sqrt(dim_model)``, combined with
    positional encodings, passed through the encoder/decoder stack and finally
    projected back onto the vocabulary.

    Args:
        num_tokens: vocabulary size (embedding rows and output logits).
        dim_model: embedding / model dimension.
        num_heads: number of attention heads.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        dropout_p: dropout probability used by the positional encoder and the transformer.
        max_len: number of positions precomputed by the positional encoder.
    """

    def __init__(
        self,
        num_tokens,
        dim_model,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
        max_len=5000
    ):
        super().__init__()

        self.model = "Transformer"
        self.dim_model = dim_model

        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=max_len
        )
        self.embedding = nn.Embedding(num_tokens, dim_model) #turns token indices into dense vectors

        #core transformer module from pytorch
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p
        )
        self.out = nn.Linear(dim_model, num_tokens) #projects model outputs to vocabulary logits

    def forward(
        self,
        src, #size is (batch_size, src sequence length)
        tgt, #size is (batch_size, tgt sequence length)
        tgt_mask=None,
        src_pad_mask=None,
        tgt_pad_mask=None
    ):
        """Run the model and return logits of size (tgt sequence length, batch_size, num_tokens).

        Args:
            src: source token indices, size (batch_size, src sequence length).
            tgt: target token indices, size (batch_size, tgt sequence length).
            tgt_mask: optional attention mask for the target, e.g. from ``get_tgt_mask``.
            src_pad_mask: optional padding mask for ``src``, e.g. from ``create_pad_mask``.
            tgt_pad_mask: optional padding mask for ``tgt``, e.g. from ``create_pad_mask``.
        """
        #embed and scale - out size = (batch_size, sequence length, dim_model)
        scale = math.sqrt(self.dim_model)
        src = self.embedding(src) * scale
        tgt = self.embedding(tgt) * scale

        #permute to obtain size (sequence length, batch_size, dim_model)
        #this must happen BEFORE the positional encoder, which reads the position from dim 0
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        #transformer blocks - out size = (tgt sequence length, batch_size, dim_model)
        transformer_out = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask
        )
        #project to the vocabulary - out size = (tgt sequence length, batch_size, num_tokens)
        out = self.out(transformer_out)

        return out

    def get_tgt_mask(self, size) -> torch.Tensor:
        """Return a (size, size) float mask where row i may only attend to positions <= i.

        Allowed positions hold 0.0 and blocked (future) positions hold -inf.
        """
        # Keep -inf strictly above the diagonal; triu zeroes everything else.
        return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)

    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        """Return a boolean mask that is True wherever ``matrix`` equals ``pad_token``."""
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)