# Training Notebook

Library Imports for the jupyter notebook

In [1]:
import os   # miscellaneous os interfaces
import sys  # configuring python runtime environment
import time # library for time manipulation, and logging

In [2]:
# use `datetime` to control and perceive the environment
# in addition `pandas` also provides date time functionalities
import datetime as dt

In [3]:
from copy import deepcopy      # dataframe is mutable
from tqdm import tqdm     # progress bar for loops
from uuid import uuid4 as UUID # unique identifier for objs

In [4]:
import numpy as np
import matplotlib.pyplot as plt

In [5]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

from torchinfo import summary
from torch.utils.data import DataLoader
from torchsummary import summary

In [7]:
import math

## Modify Pytorch summary for my own use

In [8]:
# import torch
# import torch.nn as nn
# from torch.autograd import Variable

# from collections import OrderedDict
# import numpy as np


# def summary(model, input_size, batch_size=-1, device="cuda"):

#     def register_hook(module):

#         def hook(module, input, output):
#             class_name = str(module.__class__).split(".")[-1].split("'")[0]
#             module_idx = len(summary)

#             m_key = "%s-%i" % (class_name, module_idx + 1)
#             summary[m_key] = OrderedDict()
#             print(11, module)
#             print(111, input[0])
#             summary[m_key]["input_shape"] = list(input[0].size())
#             summary[m_key]["input_shape"][0] = batch_size
#             if isinstance(output, (list, tuple)):
#                 summary[m_key]["output_shape"] = [
#                     [-1] + list(o.size())[1:] for o in output
#                 ]
#             else:
#                 summary[m_key]["output_shape"] = list(output.size())
#                 summary[m_key]["output_shape"][0] = batch_size

#             params = 0
#             if hasattr(module, "weight") and hasattr(module.weight, "size"):
#                 params += torch.prod(torch.LongTensor(list(module.weight.size())))
#                 summary[m_key]["trainable"] = module.weight.requires_grad
#             if hasattr(module, "bias") and hasattr(module.bias, "size"):
#                 params += torch.prod(torch.LongTensor(list(module.bias.size())))
#             summary[m_key]["nb_params"] = params

#         if (
#             not isinstance(module, nn.Sequential)
#             and not isinstance(module, nn.ModuleList)
#             and not (module == model)
#         ):
#             hooks.append(module.register_forward_hook(hook))

#     device = device.lower()
#     assert device in [
#         "cuda",
#         "cpu",
#     ], "Input device is not valid, please specify 'cuda' or 'cpu'"

#     if device == "cuda" and torch.cuda.is_available():
#         dtype = torch.cuda.FloatTensor
#     else:
#         dtype = torch.FloatTensor

#     # multiple inputs to the network
#     if isinstance(input_size, tuple):
#         input_size = [input_size]

#     # batch_size of 2 for batchnorm
#     x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
#     # print(type(x[0]))

#     # create properties
#     summary = OrderedDict()
#     hooks = []

#     # register hook
#     model.apply(register_hook)

#     # make a forward pass
#     # print(x.shape)
#     model(*x)

#     # remove these hooks
#     for h in hooks:
#         h.remove()

#     print("----------------------------------------------------------------")
#     line_new = "{:>20}  {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
#     print(line_new)
#     print("================================================================")
#     total_params = 0
#     total_output = 0
#     trainable_params = 0
#     for layer in summary:
#         # input_shape, output_shape, trainable, nb_params
#         line_new = "{:>20}  {:>25} {:>15}".format(
#             layer,
#             str(summary[layer]["output_shape"]),
#             "{0:,}".format(summary[layer]["nb_params"]),
#         )
#         total_params += summary[layer]["nb_params"]
#         total_output += np.prod(summary[layer]["output_shape"])
#         if "trainable" in summary[layer]:
#             if summary[layer]["trainable"] == True:
#                 trainable_params += summary[layer]["nb_params"]
#         print(line_new)

#     # assume 4 bytes/number (float on cuda).
#     total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
#     total_output_size = abs(2. * total_output * 4. / (1024 ** 2.))  # x2 for gradients
#     total_params_size = abs(total_params.numpy() * 4. / (1024 ** 2.))
#     total_size = total_params_size + total_output_size + total_input_size

#     print("================================================================")
#     print("Total params: {0:,}".format(total_params))
#     print("Trainable params: {0:,}".format(trainable_params))
#     print("Non-trainable params: {0:,}".format(total_params - trainable_params))
#     print("----------------------------------------------------------------")
#     print("Input size (MB): %0.2f" % total_input_size)
#     print("Forward/backward pass size (MB): %0.2f" % total_output_size)
#     print("Params size (MB): %0.2f" % total_params_size)
#     print("Estimated Total Size (MB): %0.2f" % total_size)
#     print("----------------------------------------------------------------")
#     # return summary


## Building model

Create a transformer model from the original [transformer paper](https://arxiv.org/abs/1706.03762)

Let's start by understanding the fundamental building block of the transformer, and then build the entire model up from there.

### Understanding Multi Head Attention (MHA)

In [9]:
class UnoptimizedMultiHeadAttention(nn.Module):
    """
    Loop-based multi-head attention, kept deliberately simple for learning purposes.

    Every head owns its own query/key/value projection and the heads are evaluated
    one after another in a Python loop. Background reading:
    https://jalammar.github.io/illustrated-transformer/
    """
    def __init__(self, dk, dv, h):
        """
        Input Args:

        dk(int): Width of each head's query/key projection
        dv(int): Width of each head's value projection (must equal dk)
        h(int) : Number of heads in MHA
        """
        super().__init__()
        assert dk == dv
        self.dk = dk
        self.dv = dv
        self.h = h
        self.dmodel = self.dk * self.h  # model dimension

        def per_head(out_features):
            # One Linear(dmodel -> out_features) per head; a ModuleList is used so
            # that the parameters of every head are registered with the module.
            return nn.ModuleList(
                [nn.Linear(self.dmodel, out_features) for _ in range(self.h)]
            )

        self.WQ = per_head(self.dk)
        self.WK = per_head(self.dk)
        self.WV = per_head(self.dv)
        # Output projection applied to the concatenated heads: (h*dv) -> dmodel
        self.WO = nn.Linear(self.h * self.dv, self.dmodel)

        self.softmax = nn.Softmax(dim=-1)

    def attention(self, query, key, val):
        """
        Scaled dot-product attention, computed separately for each head.

        query / key / val are projected with Linear(dmodel, dk|dv), so their last
        dimension must be dmodel, e.g. (B, S, dmodel).

        Returns a list with h tensors, one per head, each ending in dv features.
        """
        scale = math.sqrt(self.dk)
        per_head_outputs = []
        for wq, wk, wv in zip(self.WQ, self.WK, self.WV):
            q_i = wq(query)  # (..., S, dk)
            k_i = wk(key)    # (..., S, dk)
            v_i = wv(val)    # (..., S, dv)
            # Similarity of every query position with every key position
            weights = torch.matmul(q_i, k_i.transpose(-2, -1)) / scale
            weights = self.softmax(weights)
            # Weighted sum of the values, no transpose needed
            per_head_outputs.append(torch.matmul(weights, v_i))
        return per_head_outputs

    def forward(self, x):
        """
        Self-attention forward pass: x serves as query, key and value at once.

        x is expected to end in dmodel features; the result has the same shape as x.
        """
        heads = self.attention(x, x, x)
        # Join the h heads along the feature axis: h tensors of dv -> h*dv
        concatenated = torch.cat(heads, dim=-1)
        # Mix the heads back into the model dimension
        return self.WO(concatenated)
        
        
    
    

In [10]:
# Per-head key/value width; the attention classes assert dk == dv
dk = dv = 64
# Number of attention heads; the classes derive dmodel = dk * h (= 512 here)
h = 8

In [11]:
# Build the loop-based attention block and print its module tree
net = UnoptimizedMultiHeadAttention(dk, dv, h)
print(net)
# NOTE(review): `summary` here is torchsummary's; it is imported after torchinfo's
# `summary` in the imports cell and therefore shadows it.
summary(net, (1, 512)) # Input should be 1, dk*h

UnoptimizedMultiHeadAttention(
  (WQ): ModuleList(
    (0-7): 8 x Linear(in_features=512, out_features=64, bias=True)
  )
  (WK): ModuleList(
    (0-7): 8 x Linear(in_features=512, out_features=64, bias=True)
  )
  (WV): ModuleList(
    (0-7): 8 x Linear(in_features=512, out_features=64, bias=True)
  )
  (WO): Linear(in_features=512, out_features=512, bias=True)
  (softmax): Softmax(dim=-1)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 1, 64]          32,832
            Linear-2                [-1, 1, 64]          32,832
            Linear-3                [-1, 1, 64]          32,832
           Softmax-4                 [-1, 1, 1]               0
            Linear-5                [-1, 1, 64]          32,832
            Linear-6                [-1, 1, 64]          32,832
            Linear-7                [-1, 1, 64]          32,832
           Softmax-8       

Next, we optimize multi-head attention by removing the per-head for loop and computing all heads at once with batched matrix multiplications.

We also refer to this [blog](https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a)

In [12]:
class MultiHeadAttention1(nn.Module):
    """
    Batched multi-head self-attention without a per-head Python loop.

    The per-head projections are fused into single dmodel x dmodel linear layers;
    the projected tensors are then reshaped so all heads are processed by one
    batched matrix multiplication. Background reading:
    https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a
    """
    def __init__(self, dk, dv, h):
        """
        Input Args:

        dk(int): Per-head key/query width
        dv(int): Per-head value width (must equal dk)
        h(int) : Number of heads in MHA
        """
        super().__init__()
        assert dk == dv
        self.dk, self.dv, self.h = dk, dv, h
        self.dmodel = self.dk * self.h  # model dimension

        # Fused projections: one (dmodel -> dmodel) layer covers all h heads
        self.WQ = nn.Linear(self.dmodel, self.dmodel)
        self.WK = nn.Linear(self.dmodel, self.dmodel)
        self.WV = nn.Linear(self.dmodel, self.dmodel)
        # Output projection on the re-joined heads: (h*dv) -> dmodel
        self.WO = nn.Linear(self.h * self.dv, self.dmodel)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, projected, batch_size):
        """Reshape (B, S, dmodel) into (B, h, S, dk) so each head sees its own slice."""
        return projected.view(batch_size, -1, self.h, self.dk).transpose(1, 2)

    def forward(self, x):
        """
        Self-attention forward pass for MHA.

        Notation: B - batch size, S - token-sequence length
        x shape = (B, S, dmodel); it is used as query, key and value.
        Returns a tensor of shape (B, S, dmodel).
        """
        # TODO: Define the inputs properly (query/key/val are all x for visualization)
        projected_q = self.WQ(x)  # (B, S, dmodel)
        projected_k = self.WK(x)  # (B, S, dmodel)
        projected_v = self.WV(x)  # (B, S, dmodel)

        n_batch = projected_q.size(0)
        q_heads = self._split_heads(projected_q, n_batch)  # (B, h, S, dk)
        k_heads = self._split_heads(projected_k, n_batch)  # (B, h, S, dk)
        v_heads = self._split_heads(projected_v, n_batch)  # (B, h, S, dk)

        # Scaled dot product between every query and key position: (B, h, S, S)
        attn = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / math.sqrt(self.dk)
        attn = self.softmax(attn)

        # Weighted sum of values: (B, h, S, S) x (B, h, S, dk) -> (B, h, S, dk)
        context = torch.matmul(attn, v_heads)
        # Move the head axis next to dk and merge them back: (B, S, h*dk = dmodel)
        context = context.transpose(1, 2).contiguous()
        context = context.view(n_batch, -1, self.h * self.dk)
        return self.WO(context)
        

In [13]:
# Same hyper-parameters as the loop-based demo: 8 heads of width 64 (dmodel = 512)
dk = dv = 64
h = 8
# Build the batched attention block and print its module tree
net = MultiHeadAttention1(dk, dv, h)
print(net)
# Probe with a sequence of 10 tokens, each of width dmodel
summary(net, (10, 512)) # Input should be S, (dk*h=dmodel)

MultiHeadAttention1(
  (WQ): Linear(in_features=512, out_features=512, bias=True)
  (WK): Linear(in_features=512, out_features=512, bias=True)
  (WV): Linear(in_features=512, out_features=512, bias=True)
  (WO): Linear(in_features=512, out_features=512, bias=True)
  (softmax): Softmax(dim=-1)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 10, 512]         262,656
            Linear-2              [-1, 10, 512]         262,656
            Linear-3              [-1, 10, 512]         262,656
           Softmax-4            [-1, 8, 10, 10]               0
            Linear-5              [-1, 10, 512]         262,656
Total params: 1,050,624
Trainable params: 1,050,624
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 0.16
Params size (MB): 4.01
Estimated Total Size (MB): 4.19

## Understanding Positional Encoding

In [14]:
class Embedding(nn.Module):
    """
    Embedding lookup table which is used by the positional embedding block.
    Embedding lookup table is shared across input and output
    """
    def __init__(self, vocab_size, dmodel):
        """
        Input Args:

        vocab_size(int): Number of rows in the lookup table (one per token id)
        dmodel(int)    : Width of each embedding vector (model dimension)
        """
        super().__init__()
        self.embedding_lookup = nn.Embedding(vocab_size, dmodel)
        self.vocab_size = vocab_size
        self.dmodel = dmodel

    def forward(self, token_ids):
        """
        Look up the embedding vector of every token id in the batch.

        token_ids shape = (batch size, max token sequence length)
        Returns a tensor of shape (batch size, max token sequence length, dmodel).

        As per the paper, the embedding vectors are multiplied by sqrt(dmodel).

        Raises AssertionError when token_ids is not 2-dimensional.
        """
        # Since tokens -> shape -> (batch_size, token)
        assert token_ids.ndim == 2, f'Expected: (batch size, max token sequence length), got {token_ids.shape}'

        # The lookup needs integer indices. Shape-probing tools such as torchsummary
        # feed random float tensors, so floats are truncated to integer ids here
        # instead of replacing the input with a hard-coded debug tensor.
        if token_ids.is_floating_point():
            token_ids = token_ids.long()

        embedding_vector = self.embedding_lookup(token_ids)

        return embedding_vector * math.sqrt(self.dmodel)
              

In [15]:
# Small vocabulary for the demo; dk and h come from the earlier attention demo cell
vocab_size = 100
dmodel = dk*h
net = Embedding(vocab_size, dmodel)
print(net)
# NOTE(review): input_size is a nested list here; presumably torchsummary then
# probes the module with a (2, 2) tensor - confirm against torchsummary's source.
summary(net, input_size=([[2]])) 

Embedding(
  (embedding_lookup): Embedding(100, 512)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Embedding-1               [-1, 2, 512]          51,200
Total params: 51,200
Trainable params: 51,200
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 0.20
Estimated Total Size (MB): 0.20
----------------------------------------------------------------


In [16]:
class PositionalEncoding(nn.Module):
    """
    Adds fixed sinusoidal position information to a batch of token embeddings.

    A (max_seq_length, dmodel) table is precomputed once: even feature columns hold
    sin(position * frequency) and odd columns hold cos(position * frequency).
    """
    def __init__(self, dmodel, max_seq_length = 5000, pdropout = 0.1,):
        """
        dmodel(int): model dimensions
        max_seq_length(int): Maximum input sequence length
        pdropout(float): Dropout probability
        """
        super().__init__()
        self.dropout = nn.Dropout(p = pdropout)

        # Calculate frequencies
        position_ids = torch.arange(0, max_seq_length).unsqueeze(1)
        # -ve sign is added because the exponents are inverted when you multiply position and frequencies
        frequencies = torch.pow(10000, -torch.arange(0, dmodel, 2, dtype = torch.float)/ dmodel)
        angles = position_ids * frequencies  # shape -> (max_seq_length, ceil(dmodel / 2))

        # Create positional encoding table
        positional_encoding_table = torch.zeros(max_seq_length, dmodel)
        # Fill the table with even entries with sin and odd entries with cosine
        positional_encoding_table[:, 0::2] = torch.sin(angles)
        # For an odd dmodel there is one odd column fewer than even columns, so the
        # last frequency is dropped for the cosine half.
        positional_encoding_table[:, 1::2] = torch.cos(angles[:, : dmodel // 2])

        # Registered as a buffer: it is stored in state_dict but is not a
        # trainable parameter.
        self.register_buffer("positional_encoding_table", positional_encoding_table)


    def forward(self, embeddings_batch):
        """
        embeddings_batch shape = (batch size, seq_length, dmodel)
        positional_encoding_table shape = (max_seq_length, dmodel)

        Returns embeddings_batch plus the encodings of its first seq_length
        positions, passed through dropout.

        Raises AssertionError when the batch is not 3-dimensional, when its last
        dimension differs from dmodel, or when seq_length exceeds max_seq_length.
        """
        assert embeddings_batch.ndim == 3, \
        f"Embeddings batch should have dimension of 3 but got {embeddings_batch.ndim}"

        table_dim = self.positional_encoding_table.size(-1)
        assert embeddings_batch.size(-1) == table_dim, \
        f"Embedding batch last dimension should match positional_encoding_table, expected {table_dim} but got {embeddings_batch.size(-1)}"

        seq_length = embeddings_batch.size(1)
        max_seq_length = self.positional_encoding_table.size(0)
        assert seq_length <= max_seq_length, \
        f"Sequence length {seq_length} exceeds max_seq_length {max_seq_length}"

        # Get encodings for the given input sequence length
        pos_encodings = self.positional_encoding_table[:seq_length] # Choose only seq_length out of max_seq_length

        # Final output
        out = embeddings_batch + pos_encodings
        out = self.dropout(out)
        return out

In [17]:
# vocab_size is re-declared for consistency with the other demo cells; it is not
# passed to PositionalEncoding below.
vocab_size = 100
dmodel = dk*h
net = PositionalEncoding(dmodel)
# Only the dropout sub-module is printed: the encoding table is a buffer, not a module
print(net)
# summary(net, input_size=(100, 512))  # Can't visualize since params are none

PositionalEncoding(
  (dropout): Dropout(p=0.1, inplace=False)
)


## Understanding positionwise FeedForward Network

In [18]:
class PositionwiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network applied to the last (feature) dimension.

    dmodel -> dff -> dmodel, with dropout and ReLU on the hidden activations.
    """
    def __init__(self, dmodel, dff, pdropout = 0.1):
        """
        dmodel(int): model dimensions (input and output width)
        dff(int): width of the hidden layer
        pdropout(float): Dropout probability
        """
        super().__init__()

        self.dropout = nn.Dropout(p = pdropout)

        # Expansion to the hidden width, then projection back to dmodel
        self.W1 = nn.Linear(dmodel, dff)
        self.W2 = nn.Linear(dff, dmodel)

        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Run the feed-forward block.

        x shape = (B - batch size, S/T - max token sequence length, D - model dimension).
        The output has the same shape as x.
        """
        hidden = self.W1(x)
        # Dropout is applied to the hidden pre-activations, then ReLU
        hidden = self.dropout(hidden)
        hidden = self.relu(hidden)
        return self.W2(hidden)

In [19]:
# vocab_size is not used by the feed-forward block; kept for parity with other cells
vocab_size = 100
dmodel = dk*h
# Hidden layer is four times wider than the model dimension (512 -> 2048)
dff = dmodel * 4
net = PositionwiseFeedForward(dmodel, dff)
print(net)
# Probe with 2 tokens of width 512
summary(net, input_size=(2, 512))  

PositionwiseFeedForward(
  (dropout): Dropout(p=0.1, inplace=False)
  (W1): Linear(in_features=512, out_features=2048, bias=True)
  (W2): Linear(in_features=2048, out_features=512, bias=True)
  (relu): ReLU()
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 2, 2048]       1,050,624
           Dropout-2              [-1, 2, 2048]               0
              ReLU-3              [-1, 2, 2048]               0
            Linear-4               [-1, 2, 512]       1,049,088
Total params: 2,099,712
Trainable params: 2,099,712
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.10
Params size (MB): 8.01
Estimated Total Size (MB): 8.12
----------------------------------------------------------------


## Understanding Encoder model

In [20]:
class MultiHeadAttention(nn.Module):
    """
    Batched multi-head attention with separate query/key/value inputs and an
    optional mask, as used by the encoder and decoder layers.

    The per-head projections are fused into single dmodel x dmodel linear layers and
    all heads are evaluated with one batched matrix multiplication. Background:
    https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a
    """
    def __init__(self, dk, dv, h):
        """
        Input Args:

        dk(int): Per-head key/query width
        dv(int): Per-head value width (must equal dk)
        h(int) : Number of heads in MHA
        """
        super().__init__()
        assert dk == dv
        self.dk, self.dv, self.h = dk, dv, h
        self.dmodel = self.dk * self.h  # model dimension

        # Fused projections: one (dmodel -> dmodel) layer covers all h heads
        self.WQ = nn.Linear(self.dmodel, self.dmodel)
        self.WK = nn.Linear(self.dmodel, self.dmodel)
        self.WV = nn.Linear(self.dmodel, self.dmodel)
        # Output projection on the re-joined heads: (h*dv) -> dmodel
        self.WO = nn.Linear(self.h * self.dv, self.dmodel)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, projected, batch_size):
        """Reshape (B, L, dmodel) into (B, h, L, dk) so each head sees its own slice."""
        return projected.view(batch_size, -1, self.h, self.dk).transpose(1, 2)

    def forward(self, query, key, val, mask=None):
        """
        Scaled dot-product attention over h heads.

        Notation: B - batch size, S/T - max src/trg token-sequence length
        query shape = (B, T, dmodel)
        key shape   = (B, S, dmodel)
        val shape   = (B, S, dmodel)
        mask        = optional tensor broadcastable to (B, h, T, S); positions where
                      it equals 0 are excluded from attention.

        Returns (token_representation, attention_prob) with shapes
        (B, T, dmodel) and (B, h, T, S).
        """
        projected_q = self.WQ(query)
        projected_k = self.WK(key)
        projected_v = self.WV(val)

        n_batch = projected_q.size(0)
        q_heads = self._split_heads(projected_q, n_batch)  # (B, h, T, dk)
        k_heads = self._split_heads(projected_k, n_batch)  # (B, h, S, dk)
        v_heads = self._split_heads(projected_v, n_batch)  # (B, h, S, dk)

        # Similarity of every query position with every key position: (B, h, T, S)
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / math.sqrt(self.dk)

        # Masked positions get a very large negative score so that softmax
        # assigns them (numerically) zero probability.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)

        attention_prob = self.softmax(scores)

        # Weighted sum of values: (B, h, T, S) x (B, h, S, dk) -> (B, h, T, dk)
        context = torch.matmul(attention_prob, v_heads)
        # Move the head axis next to dk and merge them back: (B, T, h*dk = dmodel)
        context = context.transpose(1, 2).contiguous()
        context = context.view(n_batch, -1, self.h * self.dk)

        token_representation = self.WO(context)
        return token_representation, attention_prob
        

In [21]:
class EncoderLayer(nn.Module):
    """
    One encoder block, made of two sub-layers:
    1. Multi-head self-attention
    2. Position-wise feed-forward network
    Each sub-layer output goes through dropout, is added to the sub-layer input
    (residual connection) and is then layer-normalised.
    """
    def __init__(self, dk, dv, h, dim_multiplier = 4, pdropout=0.1):
        """
        dk(int), dv(int), h(int): per-head key/value widths and number of heads
        dim_multiplier(int): feed-forward hidden width as a multiple of dmodel
        pdropout(float): Dropout probability
        """
        super().__init__()
        # Reference page 5 chapter 3.2.2 Multi-head attention
        model_dim = dk * h
        # Reference page 5 chapter 3.3 positionwise FeedForward
        hidden_dim = model_dim * dim_multiplier

        self.attention = MultiHeadAttention(dk, dv, h)
        self.attn_norm = nn.LayerNorm(model_dim)
        self.ff = PositionwiseFeedForward(model_dim, hidden_dim, pdropout=pdropout)
        self.ff_norm = nn.LayerNorm(model_dim)

        self.dropout = nn.Dropout(p = pdropout)

    def forward(self, src_inputs, src_mask=None):
        """
        Forward pass as per page 3 chapter 3.1

        src_inputs is used as query, key and value of the self-attention.
        Returns (layer output, attention weights of the self-attention).
        """
        attended, attention_wts = self.attention(
            query = src_inputs,
            key = src_inputs,
            val = src_inputs,
            mask = src_mask,
        )

        # Add & Norm around the attention sub-layer; dropout is applied to the
        # sub-layer output before the residual sum (Page 7, Chapter 5.4 "Regularization")
        normed_attention = self.attn_norm(src_inputs + self.dropout(attended))

        # Add & Norm around the feed-forward sub-layer
        fed_forward = self.ff(normed_attention)
        layer_out = self.ff_norm(normed_attention + self.dropout(fed_forward))
        return layer_out, attention_wts
        

In [22]:
# 8 heads of width 64 -> dmodel = 512 (EncoderLayer computes dmodel = dk * h)
dk = 64
dv = 64
h = 8
# dim_multiplier and pdropout are left at their defaults (4 and 0.1)
net = EncoderLayer(dk, dv, h)
print(net)
# summary(net, input_size=(2, 512), device="cpu")  

EncoderLayer(
  (attention): MultiHeadAttention(
    (WQ): Linear(in_features=512, out_features=512, bias=True)
    (WK): Linear(in_features=512, out_features=512, bias=True)
    (WV): Linear(in_features=512, out_features=512, bias=True)
    (WO): Linear(in_features=512, out_features=512, bias=True)
    (softmax): Softmax(dim=-1)
  )
  (attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (ff): PositionwiseFeedForward(
    (dropout): Dropout(p=0.1, inplace=False)
    (W1): Linear(in_features=512, out_features=2048, bias=True)
    (W2): Linear(in_features=2048, out_features=512, bias=True)
    (relu): ReLU()
  )
  (ff_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


In [23]:
class Encoder(nn.Module):
    """
    Stack of `num_encoders` EncoderLayer blocks applied one after another.
    """
    def __init__(self, dk, dv, h, num_encoders, dim_multiplier = 4, pdropout=0.1):
        """
        dk(int), dv(int), h(int): per-head key/value widths and number of heads
        num_encoders(int): number of stacked encoder layers
        dim_multiplier(int): feed-forward hidden width as a multiple of dmodel
        pdropout(float): Dropout probability
        """
        super().__init__()
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(dk,
                         dv,
                         h,
                         dim_multiplier,
                         pdropout) for _ in range(num_encoders)
        ])

    def forward(self, src_inputs, src_mask = None):
        """
        Input from the Embedding layer
        src_inputs = (B - batch size, S/T - max token sequence length, D- model dimension)

        Returns the representation produced by the last encoder layer. The
        attention weights of the last layer are kept on `self.attention_wts`
        (None when the stack is empty).
        """
        src_representation = src_inputs
        # Initialised up front so an empty stack does not hit an unbound local below
        attention_wts = None

        # Forward pass through encoder stack
        for enc in self.encoder_layers:
            src_representation, attention_wts = enc(src_representation, src_mask)

        self.attention_wts = attention_wts
        return src_representation
        

In [24]:
# 8 heads of width 64 -> dmodel = 512
dk = 64
dv = 64
h = 8
# Depth of the encoder stack
num_encoders = 6
# Feed-forward hidden width = dmodel * dim_multiplier
dim_multiplier = 4
pdropout=0.1
net = Encoder(dk, dv, h, num_encoders, dim_multiplier, pdropout)
print(net)
# summary(net, input_size=(2, 512), device="cpu")  

Encoder(
  (encoder_layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (attention): MultiHeadAttention(
        (WQ): Linear(in_features=512, out_features=512, bias=True)
        (WK): Linear(in_features=512, out_features=512, bias=True)
        (WV): Linear(in_features=512, out_features=512, bias=True)
        (WO): Linear(in_features=512, out_features=512, bias=True)
        (softmax): Softmax(dim=-1)
      )
      (attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ff): PositionwiseFeedForward(
        (dropout): Dropout(p=0.1, inplace=False)
        (W1): Linear(in_features=512, out_features=2048, bias=True)
        (W2): Linear(in_features=2048, out_features=512, bias=True)
        (relu): ReLU()
      )
      (ff_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)


## Understanding Decoder Model

In [25]:
class DecoderLayer(nn.Module):
    """
    One decoder block, made of three sub-layers:
    1. Masked multi-head self-attention over the target sequence
    2. Multi-head attention whose queries come from the decoder and whose
       keys/values come from the encoder output
    3. Position-wise feed-forward network
    Each sub-layer output goes through dropout, is added to the sub-layer input
    (residual connection) and is then layer-normalised.
    """
    def __init__(
                self,
                dk,
                dv,
                h,
                dim_multiplier = 4,
                pdropout = 0.1):
        """
        dk(int), dv(int), h(int): per-head key/value widths and number of heads
        dim_multiplier(int): feed-forward hidden width as a multiple of dmodel
        pdropout(float): Dropout probability
        """
        super().__init__()

        # Reference page 5 chapter 3.2.2 Multi-head attention
        model_dim = dk * h
        # Reference page 5 chapter 3.3 positionwise FeedForward
        hidden_dim = model_dim * dim_multiplier

        # Sub-layer 1: masked self-attention
        self.masked_attention = MultiHeadAttention(dk, dv, h)
        self.masked_attn_norm = nn.LayerNorm(model_dim)

        # Sub-layer 2: attention over the encoder output
        self.attention = MultiHeadAttention(dk, dv, h)
        self.attn_norm = nn.LayerNorm(model_dim)

        # Sub-layer 3: position-wise feed-forward network
        self.ff = PositionwiseFeedForward(model_dim, hidden_dim, pdropout=pdropout)
        self.ff_norm = nn.LayerNorm(model_dim)

        self.dropout = nn.Dropout(p = pdropout)

    def forward(self, target_inputs, src_inputs, target_mask, src_mask):
        """
        Input from the Embedding layer
        target_inputs = embedded sequences    (batch_size, trg_seq_length, d_model)
        src_inputs = embedded sequences       (batch_size, src_seq_length, d_model)
        target_mask = mask for the sequences  (batch_size, 1, trg_seq_length, trg_seq_length)
        src_mask = mask for the sequences     (batch_size, 1, 1, src_seq_length)

        Returns (layer output, attention weights of the encoder-decoder attention);
        the weights of the masked self-attention are discarded.
        """
        # Sub-layer 1: the target attends to itself under target_mask
        self_attended, _ = self.masked_attention(
            query = target_inputs,
            key = target_inputs,
            val = target_inputs,
            mask = target_mask,
        )
        # Add & Norm (Page 7, Chapter 5.4 "Regularization")
        hidden = self.masked_attn_norm(target_inputs + self.dropout(self_attended))

        # Sub-layer 2: query = previous decoder sub-layer, key and val = output of
        # the encoder, mask = src_mask. Reference : page 5 chapter 3.2.3 point 1
        cross_attended, attention_wts = self.attention(
            query = hidden,
            key = src_inputs,
            val = src_inputs,
            mask = src_mask,
        )
        hidden = self.attn_norm(hidden + self.dropout(cross_attended))

        # Sub-layer 3: feed-forward followed by the final Add & Norm
        fed_forward = self.ff(hidden)
        layer_out = self.ff_norm(hidden + self.dropout(fed_forward))
        return layer_out, attention_wts

In [26]:
# 8 heads of width 64 -> dmodel = 512 (DecoderLayer computes dmodel = dk * h)
dk = 64
dv = 64
h = 8
# dim_multiplier and pdropout are left at their defaults (4 and 0.1)
net = DecoderLayer(dk, dv, h)
print(net)
# summary(net, input_size=(2, 512), device="cpu")  

DecoderLayer(
  (masked_attention): MultiHeadAttention(
    (WQ): Linear(in_features=512, out_features=512, bias=True)
    (WK): Linear(in_features=512, out_features=512, bias=True)
    (WV): Linear(in_features=512, out_features=512, bias=True)
    (WO): Linear(in_features=512, out_features=512, bias=True)
    (softmax): Softmax(dim=-1)
  )
  (masked_attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (attention): MultiHeadAttention(
    (WQ): Linear(in_features=512, out_features=512, bias=True)
    (WK): Linear(in_features=512, out_features=512, bias=True)
    (WV): Linear(in_features=512, out_features=512, bias=True)
    (WO): Linear(in_features=512, out_features=512, bias=True)
    (softmax): Softmax(dim=-1)
  )
  (attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (ff): PositionwiseFeedForward(
    (dropout): Dropout(p=0.1, inplace=False)
    (W1): Linear(in_features=512, out_features=2048, bias=True)
    (W2): Linear(in_features=2048, out_features=

In [27]:
class Decoder(nn.Module):
    """Stack of ``num_decoders`` ``DecoderLayer`` modules applied one after another.

    Args:
        dk, dv, h: forwarded unchanged to every ``DecoderLayer``.
        num_decoders: number of layers in the stack.
        dim_multiplier: forwarded to every ``DecoderLayer``.
        pdropout: forwarded to every ``DecoderLayer``.
    """

    def __init__(
                self, 
                dk, 
                dv, 
                h, 
                num_decoders, 
                dim_multiplier = 4, 
                pdropout=0.1):
        super().__init__()
        # nn.ModuleList registers every layer so its parameters are tracked.
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(dk, dv, h, dim_multiplier, pdropout)
            for _ in range(num_decoders)
        ])
        
    def forward(self, target_inputs, src_inputs, target_mask, src_mask):
        """
        Run the target sequence through every decoder layer.

        Input from the Embedding layer
        target_inputs = embedded sequences    (batch_size, trg_seq_length, d_model)
        src_inputs = embedded sequences       (batch_size, src_seq_length, d_model)
        target_mask = mask for the sequences  (batch_size, 1, trg_seq_length, trg_seq_length)
        src_mask = mask for the sequences     (batch_size, 1, 1, src_seq_length)

        Returns:
            The output tensor of the last decoder layer (attention weights
            produced by the layers are discarded).
        """
        target_representation = target_inputs
        
        # Forward pass through decoder stack
        for layer in self.decoder_layers:
            # DecoderLayer.forward returns (output, attention_weights); only
            # the output tensor may be fed to the next layer, otherwise the
            # following layer would receive a tuple instead of a tensor.
            target_representation, _ = layer(
                                    target_representation,
                                    src_inputs, 
                                    target_mask,
                                    src_mask)
        return target_representation
        
        

In [28]:
# Smoke test: build the full decoder stack and print its module tree.
dk, dv, h = 64, 64, 8
num_decoders = 6
dim_multiplier = 4
pdropout = 0.1
net = Decoder(
    dk=dk,
    dv=dv,
    h=h,
    num_decoders=num_decoders,
    dim_multiplier=dim_multiplier,
    pdropout=pdropout,
)
print(net)
# summary(net, input_size=([[2, 10, 512], [2, 10, 512], [2, 1, 10, 10], [2, 1, 1, 10]]), device="cpu")

Decoder(
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(
      (masked_attention): MultiHeadAttention(
        (WQ): Linear(in_features=512, out_features=512, bias=True)
        (WK): Linear(in_features=512, out_features=512, bias=True)
        (WV): Linear(in_features=512, out_features=512, bias=True)
        (WO): Linear(in_features=512, out_features=512, bias=True)
        (softmax): Softmax(dim=-1)
      )
      (masked_attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attention): MultiHeadAttention(
        (WQ): Linear(in_features=512, out_features=512, bias=True)
        (WK): Linear(in_features=512, out_features=512, bias=True)
        (WV): Linear(in_features=512, out_features=512, bias=True)
        (WO): Linear(in_features=512, out_features=512, bias=True)
        (softmax): Softmax(dim=-1)
      )
      (attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (ff): PositionwiseFeedForward(
        (dropout): Dropout(p=0.1,

## Putting it all together: the complete Transformer model for language translation

In [29]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer for sequence-to-sequence translation.

    Source and target token ids are embedded, passed through positional
    encoding, run through the encoder and decoder stacks, projected to
    ``target_vocab_size`` and normalised with a softmax over the last axis.

    Args:
        dk: per-head size; the model width is computed as ``dk * h``.
        dv: forwarded to ``Encoder`` and ``Decoder``.
        h: forwarded to ``Encoder`` and ``Decoder``; multiplied with ``dk``
            to obtain the model width.
        src_vocab_size: size of the source embedding table.
        target_vocab_size: size of the target embedding table and of the
            final projection.
        num_encoders: number of encoder layers.
        num_decoders: number of decoder layers.
        dim_multiplier: forwarded to ``Encoder`` and ``Decoder``.
        pdropout: dropout probability forwarded to the positional encodings,
            ``Encoder`` and ``Decoder``.
        device: stored as ``self.device`` for callers; masks are created on
            the device of the input tensors, not on this value.
    """

    def __init__(self,
                dk, 
                dv, 
                h,
                src_vocab_size,
                target_vocab_size,
                num_encoders,
                num_decoders,
                dim_multiplier = 4, 
                pdropout=0.1,
                device = "cpu"
                ):
        super().__init__()
        
        dmodel = dk*h
        
        # Modules required to build Encoder
        # NOTE(review): the positional-encoding length is tied to the
        # vocabulary size (kept from the original); a dedicated maximum
        # sequence length looks more appropriate -- confirm against
        # PositionalEncoding before changing.
        self.src_embeddings = Embedding(src_vocab_size, dmodel)
        self.src_positional_encoding = PositionalEncoding(
                                        dmodel,
                                        max_seq_length = src_vocab_size,
                                        pdropout = pdropout
                                        )
        self.encoder = Encoder(
                                dk, 
                                dv, 
                                h, 
                                num_encoders, 
                                dim_multiplier=dim_multiplier, 
                                pdropout=pdropout)
        
        # Modules required to build Decoder
        self.target_embeddings = Embedding(target_vocab_size, dmodel)
        self.target_positional_encoding = PositionalEncoding(
                                        dmodel,
                                        max_seq_length = target_vocab_size,
                                        pdropout = pdropout
                                        )
        # Forward the constructor arguments instead of hard-coded values so
        # the decoder honours the same configuration as the encoder.
        self.decoder = Decoder(
                                dk, 
                                dv, 
                                h, 
                                num_decoders,  
                                dim_multiplier=dim_multiplier, 
                                pdropout=pdropout)
        
        # Final output 
        self.linear = nn.Linear(dmodel, target_vocab_size)
        self.softmax = nn.Softmax(dim=-1)
        self.device = device
        self.init_params()  
    
    # This part wasn't mentioned in the paper, but it's super important!
    def init_params(self):
        """
        Apply Xavier-uniform initialisation to every parameter with more
        than one dimension (biases and other 1-D parameters are untouched).

        xavier has tremendous impact! I didn't expect
        that the model's perf, with normalization layers, 
        is so dependent on the choice of weight initialization.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
                
    def make_src_mask(self, src, src_pad_idx):
        """
        Build the padding mask for the source sequences.

        Args:
            src: raw sequences with padding        (batch_size, seq_length) 
            src_pad_idx(int): index where the token need not be attended

        Returns:
            src_mask: boolean mask for each sequence    (batch_size, 1, 1, seq_length)
        """
        batch_size = src.shape[0]
        # True for tokens that need to be attended to, False for padding
        # tokens, then add 2 dimensions
        src_mask = (src != src_pad_idx).view(batch_size, 1, 1, -1)
        return src_mask
    
    def make_target_mask(self, target, target_pad_idx):
        """
        Build the combined padding + causal mask for the target sequences.

        Args:
            target:  raw sequences with padding        (batch_size, seq_length)     
            target_pad_idx(int): index where the token need not be attended

        Returns:
            target_mask: boolean mask for each sequence   (batch_size, 1, seq_length, seq_length)
        """

        seq_length = target.shape[1]
        batch_size = target.shape[0]
        
        # True for tokens that need to be attended to, False for padding tokens
        pad_mask = (target != target_pad_idx).view(batch_size, 1, 1, -1) # (batch_size, 1, 1, seq_length)

        # Lower-triangular mask so position i only sees positions <= i.
        # Created on the input's device so it always matches pad_mask.
        causal_mask = torch.tril(
            torch.ones((seq_length, seq_length), device=target.device)
        ).bool() # (seq_length, seq_length)

        # Element-wise "and"; broadcasting yields (batch_size, 1, seq_length, seq_length)
        target_mask = pad_mask & causal_mask

        return target_mask
        
    def forward(
        self, 
        src_token_ids_batch, 
        target_token_ids_batch, 
        src_pad_idx, 
        target_pad_idx):
        """
        Translate a batch of source sequences given the target prefix.

        Args:
            src_token_ids_batch: source token ids     (batch_size, src_seq_length)
            target_token_ids_batch: target token ids  (batch_size, trg_seq_length)
            src_pad_idx(int): padding index in the source vocabulary
            target_pad_idx(int): padding index in the target vocabulary

        Returns:
            Softmax of the final linear projection; the last axis has size
            ``target_vocab_size``.
        """
        
        # create source and target masks     
        src_mask = self.make_src_mask(
                        src_token_ids_batch, 
                        src_pad_idx) # (batch_size, 1, 1, src_seq_length)
        target_mask = self.make_target_mask(
                        target_token_ids_batch, 
                        target_pad_idx) # (batch_size, 1, trg_seq_length, trg_seq_length)
        
        # Create embeddings
        src_representations = self.src_embeddings(src_token_ids_batch)
        src_representations = self.src_positional_encoding(src_representations)
        
        target_representations = self.target_embeddings(target_token_ids_batch)
        target_representations = self.target_positional_encoding(target_representations)
        
        # Encode 
        encoded_src = self.encoder(src_representations, src_mask)
        
        # Decode
        decoded_output = self.decoder(
                                target_representations, 
                                encoded_src, 
                                target_mask, 
                                src_mask)
        
        # Post processing
        out = self.linear(decoded_output)
        # Output 
        # NOTE(review): the result is already softmax-normalised; a loss that
        # expects raw logits (e.g. nn.CrossEntropyLoss) would normalise twice
        # -- confirm against the training loop.
        out = self.softmax(out)
        return out
    
    

In [30]:
# Smoke test: build the complete model with the paper's base sizes and a
# small vocabulary, then print its module tree.
dk, dv, h = 64, 64, 8
src_vocab_size = 1000
target_vocab_size = 1000
num_encoders = 6
num_decoders = 6
dim_multiplier = 4
pdropout = 0.1
device = "cuda" if torch.cuda.is_available() else "cpu"
net = Transformer(
    dk=dk,
    dv=dv,
    h=h,
    src_vocab_size=src_vocab_size,
    target_vocab_size=target_vocab_size,
    num_encoders=num_encoders,
    num_decoders=num_decoders,
    dim_multiplier=dim_multiplier,
    pdropout=pdropout,
    device=device,
)
print(net)

Transformer(
  (src_embeddings): Embedding(
    (embedding_lookup): Embedding(1000, 512)
  )
  (src_positional_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Encoder(
    (encoder_layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (attention): MultiHeadAttention(
          (WQ): Linear(in_features=512, out_features=512, bias=True)
          (WK): Linear(in_features=512, out_features=512, bias=True)
          (WV): Linear(in_features=512, out_features=512, bias=True)
          (WO): Linear(in_features=512, out_features=512, bias=True)
          (softmax): Softmax(dim=-1)
        )
        (attn_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (ff): PositionwiseFeedForward(
          (dropout): Dropout(p=0.1, inplace=False)
          (W1): Linear(in_features=512, out_features=2048, bias=True)
          (W2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
        )
        (ff_norm):