# Training Notebook

Library imports for the notebook

In [1]:
import os   # miscellaneous os interfaces
import sys  # configuring python runtime environment
import time # library for time manipulation, and logging

In [2]:
# `datetime` provides date/time types (e.g. for timestamping runs and logs);
# `pandas` also offers date-time functionality
import datetime as dt

In [3]:
from copy import deepcopy      # dataframe is mutable
from tqdm import tqdm     # progress bar for loops
from uuid import uuid4 as UUID # unique identifier for objs

In [4]:
import numpy as np
import matplotlib.pyplot as plt

In [5]:
!pip install torchsummary

Collecting torchsummary
  Downloading torchsummary-1.5.1-py3-none-any.whl (2.8 kB)
Installing collected packages: torchsummary
Successfully installed torchsummary-1.5.1


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

from torchinfo import summary
from torch.utils.data import DataLoader
from torchsummary import summary

In [7]:
import math

## Building the model

Build a transformer model following the original transformer paper, ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762) (Vaswani et al., 2017).

Let's start by understanding the fundamental building block of the transformer, multi-head attention, and then build the entire model from there.

### Understanding Multi Head Attention (MHA)

In [8]:
class UnoptimizedMultiHeadAttention(nn.Module):
    """
    Reference (loop-based) implementation of multi-head attention (MHA) from
    "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).

    Every head owns its own Q/K/V projection and the heads are computed one at a
    time in a Python loop. This is easy to follow but slow; it is kept to make the
    per-head computation explicit. An in-depth walkthrough is available at
    https://jalammar.github.io/illustrated-transformer/
    """
    def __init__(self, dk, dv, h):
        """
        Input Args:

        dk(int): Key/query dimension of each head (output size of each WQ/WK layer)
        dv(int): Value dimension of each head (output size of each WV layer); must equal dk
        h(int) : Number of heads in MHA

        The model dimension is derived as ``dmodel = dk * h``.
        """
        super().__init__()
        assert dk == dv
        self.dk = dk
        self.dv = dv
        self.h = h
        self.dmodel = self.dk * self.h  # model dimension

        # nn.ModuleList (not a plain Python list) so that every per-head layer is
        # registered as a submodule and its parameters are tracked and trained.
        self.WQ = nn.ModuleList([nn.Linear(self.dmodel, self.dk) for _ in range(self.h)])  # h x (dmodel -> dk)
        self.WK = nn.ModuleList([nn.Linear(self.dmodel, self.dk) for _ in range(self.h)])  # h x (dmodel -> dk)
        self.WV = nn.ModuleList([nn.Linear(self.dmodel, self.dv) for _ in range(self.h)])  # h x (dmodel -> dv)
        # Output projection applied to the concatenated heads: (h*dv = dmodel) -> dmodel
        self.WO = nn.Linear(self.h * self.dv, self.dmodel)
        self.softmax = nn.Softmax(dim=-1)

    def attention(self, query, key, val):
        """
        Compute scaled dot-product attention separately for every head.

        Notation: B - batch size, S - sequence length
        query, key, val shape = (B, S, dmodel)

        Returns:
            list of ``h`` tensors, each of shape (B, S, dv)
        """
        scale = math.sqrt(self.dk)  # loop-invariant scaling factor from the paper
        head = []
        for i in range(self.h):
            Q = self.WQ[i](query)  # (B, S, dk)
            K = self.WK[i](key)    # (B, S, dk)
            V = self.WV[i](val)    # (B, S, dv)
            score = torch.matmul(Q, K.transpose(-2, -1)) / scale  # (B, S, S)
            score = self.softmax(score)  # normalize over the key positions
            head.append(torch.matmul(score, V))  # (B, S, S) @ (B, S, dv) -> (B, S, dv)
        return head

    def forward(self, x):
        """
        Self-attention forward pass: the same input is used as query, key and value.

        x shape = (B, S, dmodel); returns a tensor of shape (B, S, dmodel).
        """
        query = key = val = x
        head = self.attention(query, key, val)
        # Concatenate the h heads of shape (B, S, dv) along the feature axis -> (B, S, h*dv)
        out = torch.cat(head, dim=-1)
        token_representation = self.WO(out)  # (B, S, dmodel)
        return token_representation
        
        
    
    

In [9]:
h = 8        # number of attention heads
dk = 64      # per-head key/query width; dmodel = dk * h = 512
dv = dk      # per-head value width (the MHA classes require dv == dk)

In [10]:
net = UnoptimizedMultiHeadAttention(dk, dv, h)
print(net)
# torchsummary's input size excludes the batch dimension: (S=1, dmodel = dk*h).
# device="cpu" matches where `net` lives; torchsummary defaults to "cuda" and would
# feed CUDA inputs to this CPU-resident model whenever a GPU is available.
summary(net, (1, net.dmodel), device="cpu")

UnoptimizedMultiHeadAttention(
  (WQ): ModuleList(
    (0-7): 8 x Linear(in_features=512, out_features=64, bias=True)
  )
  (WK): ModuleList(
    (0-7): 8 x Linear(in_features=512, out_features=64, bias=True)
  )
  (WV): ModuleList(
    (0-7): 8 x Linear(in_features=512, out_features=64, bias=True)
  )
  (WO): Linear(in_features=512, out_features=512, bias=True)
  (softmax): Softmax(dim=-1)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 1, 64]          32,832
            Linear-2                [-1, 1, 64]          32,832
            Linear-3                [-1, 1, 64]          32,832
           Softmax-4                 [-1, 1, 1]               0
            Linear-5                [-1, 1, 64]          32,832
            Linear-6                [-1, 1, 64]          32,832
            Linear-7                [-1, 1, 64]          32,832
           Softmax-8       

Next, we optimize multi-head attention by removing the per-head `for` loop and computing all heads at once with batched matrix multiplications.

This [blog](https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a) explains multi-head attention in more depth.

In [11]:
class MultiHeadAttention(nn.Module):
    """
    Vectorized multi-head attention (MHA) from "Attention Is All You Need".

    In-depth walkthrough:
    https://medium.com/@hunter-j-phillips/multi-head-attention-7924371d477a

    Instead of one small (dmodel -> dk) linear layer per head, a single
    (dmodel -> dmodel) projection is used for each of Q, K and V. Its output is
    split into ``h`` heads of width ``dk`` so that all heads are computed at once
    with batched matrix multiplications rather than a Python loop.
    """
    def __init__(self, dk, dv, h):
        """
        Input Args:

        dk(int): Key/query dimension of each head
        dv(int): Value dimension of each head; must equal dk
        h(int) : Number of heads in MHA

        The model dimension is derived as ``dmodel = dk * h``.
        """
        super().__init__()
        assert dk == dv
        self.dk = dk
        self.dv = dv
        self.h = h
        self.dmodel = self.dk * self.h  # model dimension

        # One fused projection per Q/K/V: equivalent to h per-head (dmodel -> dk)
        # layers stacked side by side along the output axis.
        self.WQ = nn.Linear(self.dmodel, self.dmodel)
        self.WK = nn.Linear(self.dmodel, self.dmodel)
        self.WV = nn.Linear(self.dmodel, self.dmodel)
        # Output projection applied to the concatenated heads: (h*dv = dmodel) -> dmodel
        self.WO = nn.Linear(self.h * self.dv, self.dmodel)
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, t, batch_size, head_dim):
        """Reshape (B, S, h*head_dim) -> (B, h, S, head_dim)."""
        return t.view(batch_size, -1, self.h, head_dim).permute(0, 2, 1, 3)

    def forward(self, x, mask=None):
        """
        Self-attention forward pass: the same input is used as query, key and value.

        Notation: B - batch size, S - sequence length

        Args:
            x: tensor of shape (B, S, dmodel)
            mask: optional tensor broadcastable to (B, h, S, S), e.g. (S, S) or
                (B, 1, 1, S). Score positions where ``mask == 0`` are set to -inf
                before the softmax, so they receive zero attention weight
                (``torch.tril(torch.ones(S, S))`` gives a causal/decoder mask).
                A query row whose keys are all masked yields NaN.

        Returns:
            tensor of shape (B, S, dmodel)
        """
        query = key = val = x
        batch_size = x.size(0)

        # Project, then split the last dim into h heads: (B, S, dmodel) -> (B, h, S, dk)
        Q = self._split_heads(self.WQ(query), batch_size, self.dk)
        K = self._split_heads(self.WK(key), batch_size, self.dk)
        V = self._split_heads(self.WV(val), batch_size, self.dv)

        # Scaled dot-product scores for every head at once: (B, h, S, S)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.dk)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = self.softmax(scores)

        # (B, h, S, S) @ (B, h, S, dv) -> (B, h, S, dv)
        head = torch.matmul(weights, V)
        # Merge heads back: (B, h, S, dv) -> (B, S, h, dv) -> (B, S, h*dv = dmodel).
        # permute() makes the tensor non-contiguous, so .contiguous() is needed before .view().
        head = head.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.h * self.dv)
        token_representation = self.WO(head)
        return token_representation
        

In [12]:
dk = dv = 64
h = 8
seq_len = 10  # number of tokens in the dummy input sequence
net = MultiHeadAttention(dk, dv, h)
print(net)
# torchsummary's input size excludes the batch dimension: (S, dmodel = dk*h).
# device="cpu" matches where `net` lives (torchsummary defaults to "cuda").
summary(net, (seq_len, net.dmodel), device="cpu")

MultiHeadAttention(
  (WQ): Linear(in_features=512, out_features=512, bias=True)
  (WK): Linear(in_features=512, out_features=512, bias=True)
  (WV): Linear(in_features=512, out_features=512, bias=True)
  (WO): Linear(in_features=512, out_features=512, bias=True)
  (softmax): Softmax(dim=-1)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1              [-1, 10, 512]         262,656
            Linear-2              [-1, 10, 512]         262,656
            Linear-3              [-1, 10, 512]         262,656
           Softmax-4            [-1, 8, 10, 10]               0
            Linear-5              [-1, 10, 512]         262,656
Total params: 1,050,624
Trainable params: 1,050,624
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 0.16
Params size (MB): 4.01
Estimated Total Size (MB): 4.19


## Understanding Positional Embedding