# **Implementing Transformer by Replicating *Attention Is All You Need***

## **I. Components of the Transformer ->**

### 1. Input Embeddings -> 

In [23]:
# %%writefile modules/_01_inputEmbeddings.py
import torch
import torch.nn as nn
import math

class InputEmbeddings(nn.Module):
    """Token-id -> vector lookup, scaled by sqrt(d_model)."""

    def __init__(self, d_model:int, vocab_size:int ):
        super().__init__()
        # Size of each embedding vector (e.g. 512).
        self.d_model = d_model
        # Number of distinct token ids the lookup table holds.
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self,x):
        # Look the ids up first, then apply the sqrt(d_model) scaling factor.
        scale = math.sqrt(self.d_model)
        token_vectors = self.embedding(x)
        return token_vectors * scale
        

### 2. Positional Encoding -> 

\begin{align}
PE(pos, 2i)   &= \sin \left( \frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}} \right) \\
PE(pos, 2i+1) &= \cos \left( \frac{pos}{10000^{\frac{2i}{d_{\text{model}}}}} \right)
\end{align}


In [24]:
# %%writefile modules/_02_positionalEncoding.py 

import torch
import torch.nn as nn
import math
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to the input and applies dropout.

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: Size of the last dimension of the input.
        seq_len: Number of positions for which the table is precomputed.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self,d_model:int, seq_len:int, dropout:float ):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        # forward() calls self.dropout, so it must be a module and not the raw probability.
        self.dropout = nn.Dropout(p=dropout)

        # Table of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)

        # Column vector of positions [0, 1, ..., seq_len - 1] -> shape (seq_len, 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)

        # 1 / 10000^(2i / d_model), computed in log space. The scaling has to happen
        # INSIDE exp(): exp(2i * -ln(10000) / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # sin on even feature indices, cos on odd feature indices.
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        pe = pe.unsqueeze(dim=0)  # (1, seq_len, d_model)

        # A buffer is saved in the state_dict and moved by .to(device),
        # but it is not a parameter and is never updated by the optimizer.
        self.register_buffer('pe', pe)

    def forward(self,x):
        # Add the encodings of the first x.shape[1] positions; x.shape[1] must not exceed seq_len.
        x = x + self.pe[:, :x.shape[1], :].requires_grad_(False)

        # Dropout layer
        return self.dropout(x)
        

### 3. Layer Normalisation -> 


\begin{align}
\large \hat{x}_j=\frac{x_j-\mu_j}{\sqrt{\sigma_j^2+\epsilon}}
\end{align}




In [22]:
import torch.nn as nn
# Scratch check: a tensor wrapped in nn.Parameter is printed with requires_grad=True.
alpha = nn.Parameter(data=torch.ones(1))
print(alpha)

Parameter containing:
tensor([1.], requires_grad=True)


In [25]:
# %%writefile modules/_03_layerNormalisation.py
import torch
import torch.nn as nn

class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension.

    x_hat = (x - mean) / sqrt(var + eps), followed by a learnable scale
    (alpha) and shift (bias).

    Args:
        features: Size of the last dimension; alpha and bias get this many
            entries. The default of 1 gives a single shared scale/shift.
        eps: Small constant added to the variance for numerical stability
            and to avoid division by zero.
    """

    def __init__(self, features: int = 1, eps: float = 1e-6):
        super().__init__()

        # The other modules in this file construct the layer as
        # LayerNormalization(features). Without a leading `features` argument
        # that value would be taken as eps. A float in first position is
        # still accepted as eps for callers using the older (eps) form.
        if isinstance(features, float):
            features, eps = 1, features

        self.eps = eps

        self.alpha = nn.Parameter(torch.ones(features))    # Multiplied (scale)
        self.bias = nn.Parameter(torch.zeros(features))    # Added (shift)

    def forward(self,x):
        # keepdim=True keeps the reduced axis so the statistics broadcast against x.
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        # The square root is taken of (variance + eps), as in the formula above.
        return self.alpha * (x - mean) / torch.sqrt(var + self.eps) + self.bias

### 4. Feed-Forward Network -> 

The Feed Forward Network (FFN) in the Transformer model is represented as:


\begin{align}
\large \operatorname{FFN}(x)=\max \left(0, x W_1+b_1\right) W_2+b_2
\end{align}




In [26]:
# %%writefile modules/_04_feedForwardNetwork.py

import torch
import torch.nn as nn

class FeedForwardNetwork(nn.Module):
    """Two linear layers with a ReLU and dropout in between: FFN(x) = max(0, xW1 + b1)W2 + b2."""

    def __init__(self, d_model: int , d_ff:int , dropout:float):
        super().__init__()

        # W1, b1: d_model -> d_ff
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)

        # W2, b2: d_ff -> d_model
        self.linear_2 = nn.Linear(in_features=d_ff, out_features=d_model)

    def forward(self, x):
        # Last dimension goes d_model -> d_ff -> d_model; leading dimensions are untouched.
        hidden = self.linear_1(x).relu()
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

### 5. Multi-Head Attention -> 


<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        .centered-image {
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 90%; /* Adjust the width as needed */
        }
    </style>
</head>
<body>
    <img src="https://i.ibb.co/Y0mbNbH/image.png" alt="Transformer Encoder Block" class="centered-image">
</body>
</html>


In [27]:
# %%writefile modules/_05_multiHeadAttention.py

import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Size of the last dimension of q, k and v.
        h: Number of heads; d_model must be divisible by h.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self,d_model:int,h:int,dropout:float):
        super().__init__()

        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, 'Embedding dimension must be divisible by number of heads'

        self.d_k = d_model // h  # Size of each head

        self.w_q = nn.Linear(d_model, d_model)  # Wq
        self.w_k = nn.Linear(d_model, d_model)  # Wk
        self.w_v = nn.Linear(d_model, d_model)  # Wv

        self.w_o = nn.Linear(d_model, d_model)  # Wo (d_v equals d_k, and d_k * h = d_model)

        self.dropout = nn.Dropout(p=dropout)

    @staticmethod
    def attention(query,key,value,mask,dropout:nn.Dropout):
        """Return (softmax(Q K^T / sqrt(d_k)) V, attention weights).

        Positions where `mask == 0` are excluded. `mask` and `dropout` may be None.
        """
        d_k = query.shape[-1]
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # -1e9 stands in for -infinity: these entries become 0 after softmax.
            attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

        attention_scores = attention_scores.softmax(dim=-1)  # (Batch, h, Seq_len_q, Seq_len_k)

        if dropout is not None:
            attention_scores = dropout(attention_scores)

        return (attention_scores @ value), attention_scores

    def _split_heads(self, x):
        # (Batch, Seq_len, d_model) -> (Batch, Seq_len, h, d_k) -> (Batch, h, Seq_len, d_k)
        return x.reshape(x.shape[0], x.shape[1], self.h, self.d_k).transpose(1, 2)

    def forward(self,q,k,v,mask):
        """Project q, k, v, attend per head, and merge the heads again.

        The attention weights of the last call are kept in `self.attention_score`.
        """
        # Q', K', V': (Batch, Seq_len, d_model) -> (Batch, Seq_len, d_model)
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        # Each tensor is split from ITSELF. Building `value` from `key` would
        # discard the value projection altogether.
        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        x, self.attention_score = MultiHeadAttention.attention(query, key, value, mask, self.dropout)
        # x: (Batch, h, Seq_len, d_k); attention_score: (Batch, h, Seq_len_q, Seq_len_k)

        # Merge the heads back: (Batch, h, Seq_len, d_k) -> (Batch, Seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.d_k * self.h)

        return self.w_o(x)  # (Batch, Seq_len, d_model)
                   

### 6. Residual Connection -> 

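Used throughout the model: the input x is normalised, passed through a sub-layer (attention or feed-forward), and the sub-layer's output is then added back to the original x (a skip connection).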

In [28]:
# %%writefile modules/_06_residualConnection.py

import torch
import torch.nn as nn
import math
from modules._03_layerNormalisation import LayerNormalization

class ResidualConnection(nn.Module):
    """Skip connection around a sub-layer: x + dropout(sublayer(norm(x)))."""

    def __init__(self,features,dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # The normalisation layer is constructed with `features`.
        self.norm = LayerNormalization(features)

    def forward(self,x,sublayer):
        # Normalise first, run the sub-layer, apply dropout, then add the untouched input back.
        normalised = self.norm(x)
        transformed = self.dropout(sublayer(normalised))
        return x + transformed

### 7. Encoder Block -> 


<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        .centered-image {
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 50%; /* Adjust the width as needed */
        }
    </style>
</head>
<body>
    <img src="https://raw.githubusercontent.com/hyunwoongko/transformer/master/image/model.png" alt="Transformer Encoder Block" class="centered-image">
</body>
</html>


In [29]:
# %%writefile modules/_07_encoder.py


import torch
import torch.nn as nn
from modules._03_layerNormalisation import LayerNormalization
from modules._04_feedForwardNetwork import FeedForwardNetwork
from modules._05_multiHeadAttention import MultiHeadAttention
from modules._06_residualConnection import ResidualConnection

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual connection."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardNetwork, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Two residual wrappers: index 0 for self-attention, index 1 for feed-forward.
        wrappers = [ResidualConnection(features=features, dropout=dropout) for _ in range(2)]
        self.residual_connections = nn.ModuleList(wrappers)

    def forward(self, x, src_mask):
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(normed):
            # Query, key and value all come from the same (normalised) input.
            return self.self_attention_block(normed, normed, normed, src_mask)

        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)
    
class Encoder(nn.Module):
    """Runs the input through every layer in order, then applies a final normalisation."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        hidden = x
        for block in self.layers:
            # Every layer receives the same mask.
            hidden = block(hidden, mask)
        return self.norm(hidden)

### 8. Decoder -> 

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        .centered-image {
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 50%; /* Adjust the width as needed */
        }
    </style>
</head>
<body>
    <img src="https://raw.githubusercontent.com/hyunwoongko/transformer/master/image/model.png" alt="Transformer Encoder Block" class="centered-image">
</body>
</html>


In [30]:
# %%writefile modules/_08_decoder.py 
import torch
import torch.nn as nn

from modules._03_layerNormalisation import LayerNormalization
from modules._04_feedForwardNetwork import FeedForwardNetwork
from modules._05_multiHeadAttention import MultiHeadAttention
from modules._06_residualConnection import ResidualConnection


class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder output, then
    feed-forward, each inside a residual connection."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttention, cross_attention_block: MultiHeadAttention, feed_forward_block: FeedForwardNetwork, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # Three residual wrappers: 0 self-attention, 1 cross-attention, 2 feed-forward.
        wrappers = [ResidualConnection(features=features, dropout=dropout) for _ in range(3)]
        self.residual_connections = nn.ModuleList(wrappers)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        self_residual, cross_residual, feed_forward_residual = self.residual_connections

        def self_attend(normed):
            # Query, key and value all come from the decoder stream; uses the target mask.
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def cross_attend(normed):
            # Query from the decoder stream, key and value from the encoder output; uses the source mask.
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        x = self_residual(x, self_attend)
        x = cross_residual(x, cross_attend)
        return feed_forward_residual(x, self.feed_forward_block)
    
class Decoder(nn.Module):
    """Runs the target through every decoder layer in order, then applies a final normalisation."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            # Every layer receives the same encoder output and the same two masks.
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

### 9. Linear / Projection Layer -> 

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <style>
        .centered-image {
            display: block;
            margin-left: auto;
            margin-right: auto;
            width: 50%; /* Adjust the width as needed */
        }
    </style>
</head>
<body>
    <img src="https://raw.githubusercontent.com/hyunwoongko/transformer/master/image/model.png" alt="Transformer Encoder Block" class="centered-image">
</body>
</html>


In [31]:

# %%writefile modules/_09_projectionLayer.py
import torch
import torch.nn as nn
class ProjectionLayer(nn.Module):
    """Maps d_model-sized vectors to log-probabilities over the vocabulary."""

    def __init__(self , d_model:int , vocab_size : int):
        super().__init__()

        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        logits = self.proj(x)
        # Log-softmax over the last (vocabulary) dimension; the log form is used for numerical stability.
        return logits.log_softmax(dim=-1)

## **II. Defining & Building the Transformer ->**

In [32]:
# %%writefile modules/_10_transformer.py
import torch
import torch.nn as nn

from modules._01_inputEmbeddings import InputEmbeddings
from modules._02_positionalEncoding import PositionalEncoding
from modules._03_layerNormalisation import LayerNormalization
from modules._04_feedForwardNetwork import FeedForwardNetwork
from modules._05_multiHeadAttention import MultiHeadAttention
from modules._06_residualConnection import ResidualConnection
from modules._07_encoder import Encoder
from modules._08_decoder import Decoder
from modules._09_projectionLayer import ProjectionLayer


class Transformer(nn.Module):
    """Container that wires embeddings, positional encodings, encoder, decoder and projection together.

    The three steps are exposed separately as encode(), decode() and project().
    """

    def __init__(self,
                 encoder : Encoder,
                 decoder : Decoder,
                 src_embed : InputEmbeddings,
                 tgt_embed : InputEmbeddings,
                 src_pos : PositionalEncoding,
                 tgt_pos : PositionalEncoding, 
                 projection_layer : ProjectionLayer):
        super().__init__()

        # Encoder / decoder stacks
        self.encoder = encoder
        self.decoder = decoder
        # Token embeddings for the source and target side
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        # Positional encodings for the source and target side
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        # Final projection
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        # Embed, add positions, then run the encoder (argument order matches Encoder.forward).
        embedded = self.src_pos(self.src_embed(src))
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output, src_mask, tgt, tgt_mask):
        # Embed, add positions, then run the decoder (argument order matches Decoder.forward).
        embedded = self.tgt_pos(self.tgt_embed(tgt))
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self,x):
        return self.projection_layer(x)
        

#### **We have defined all the individual parts. Now we write a function that assembles them into a Transformer ->**

In [33]:
# %%writefile modules/_11_buildTransformer.py
import torch
import torch.nn as nn
from modules._01_inputEmbeddings import InputEmbeddings
from modules._02_positionalEncoding import PositionalEncoding
from modules._03_layerNormalisation import LayerNormalization
from modules._04_feedForwardNetwork import FeedForwardNetwork
from modules._05_multiHeadAttention import MultiHeadAttention
from modules._06_residualConnection import ResidualConnection
from modules._07_encoder import Encoder , EncoderBlock
from modules._08_decoder import Decoder , DecoderBlock
from modules._09_projectionLayer import ProjectionLayer
from modules._10_transformer import Transformer

def build_transformer(src_vocab_size: int ,
                      tgt_vocab_size: int ,
                      src_seq_len: int ,
                      tgt_seq_len: int,
                      d_model : int = 512, 
                      N:int = 6, #According to the paper, the number of encoder and decoder blocks is 6.
                      h:int = 8, #According to the paper, the number of heads is 8.
                      dropout:float = 0.1, #A probability, so a float rather than an int.
                      d_ff:int  = 2048 #Inner dimension of the feed-forward network is 2048
                      ):
    """Assemble a Transformer from its parts and initialise its weights.

    Args:
        src_vocab_size: Vocabulary size of the source embedding.
        tgt_vocab_size: Vocabulary size of the target embedding and of the projection layer.
        src_seq_len: Sequence length passed to the source positional encoding.
        tgt_seq_len: Sequence length passed to the target positional encoding.
        d_model: Model dimension shared by all components.
        N: Number of encoder blocks and number of decoder blocks.
        h: Number of attention heads in every attention block.
        dropout: Dropout probability passed to every component that takes one.
        d_ff: Inner dimension of the feed-forward networks.

    Returns:
        The assembled Transformer.
    """

    #Creating the embedding layers -> 
    src_embed = InputEmbeddings(d_model=d_model, vocab_size=src_vocab_size)
    tgt_embed = InputEmbeddings(d_model=d_model, vocab_size=tgt_vocab_size)

    #Creating the positional encoding layers -> 
    src_pos = PositionalEncoding(d_model=d_model, seq_len=src_seq_len, dropout=dropout)
    tgt_pos = PositionalEncoding(d_model=d_model, seq_len=tgt_seq_len, dropout=dropout)

    #Creating the encoder blocks -> 
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttention(d_model=d_model, h=h, dropout=dropout)
        encoder_feed_forward_block = FeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout=dropout)

        encoder_blocks.append(
            EncoderBlock(features=d_model,
                         self_attention_block=encoder_self_attention_block,
                         feed_forward_block=encoder_feed_forward_block,
                         dropout=dropout)
        )

    #Creating the decoder blocks -> 
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttention(d_model=d_model, h=h, dropout=dropout)
        decoder_cross_attention_block = MultiHeadAttention(d_model=d_model, h=h, dropout=dropout)
        decoder_feed_forward_block = FeedForwardNetwork(d_model=d_model, d_ff=d_ff, dropout=dropout)

        decoder_blocks.append(
            DecoderBlock(features=d_model,
                         self_attention_block=decoder_self_attention_block,
                         cross_attention_block=decoder_cross_attention_block,
                         feed_forward_block=decoder_feed_forward_block,
                         dropout=dropout)
        )

    #Creating the encoder and decoder -> 
    encoder = Encoder(layers=nn.ModuleList(encoder_blocks), features=d_model)
    decoder = Decoder(layers=nn.ModuleList(decoder_blocks), features=d_model)

    #Creating the projection layer -> 
    projection_layer = ProjectionLayer(d_model=d_model, vocab_size=tgt_vocab_size)

    #CREATING THE TRANSFORMER ----> 
    transformer = Transformer(encoder=encoder,
                              decoder=decoder,
                              src_embed=src_embed,
                              tgt_embed=tgt_embed,
                              src_pos=src_pos,
                              tgt_pos=tgt_pos,
                              projection_layer=projection_layer)

    #Initialising the parameters with the Xavier uniform distribution ->
    #Only parameters with more than one dimension are re-initialised; 1-D parameters keep their defaults.
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

## **III. Data Preparation & Model Initiation->**

- Tokenizer 
- Dataset
- Config

### 1 & 2. Tokenizer & Dataset -> 

A **tokenizer** converts a sequence of text into smaller, manageable units called tokens. These tokens can be words, subwords, characters, or even sentences, depending on the specific tokenizer and its configuration.

We also have special tokens: **[SOS]** (start of sentence), **[EOS]** (end of sentence) and **[PAD]** (padding).

1. **Text Normalization**: The tokenizer first normalizes the text by converting it to lowercase, removing punctuation, and handling special characters. This step ensures consistency and reduces the complexity of the text.

2. **Splitting**: The normalized text is then split into tokens based on predefined rules. For example, a word-level tokenizer might split text at spaces, while a character-level tokenizer would split text into individual characters.

3. **Mapping to IDs**: Each token is mapped to a unique identifier (ID) from a vocabulary. This step converts the text into a numerical format that can be processed by machine learning models.

4. **Padding**: The tokens are padded to a uniform length. This step ensures that all sequences in the batch have the same length.


Consider the sentence: "Tokenization is essential for NLP."

- **Normalized Text**: "tokenization is essential for nlp"
- **Tokens**: ["tokenization", "is", "essential", "for", "nlp"]
- **Token IDs**: [1012, 2003, 3722, 2005, 17953] (IDs are hypothetical)




In [34]:
# %%writefile modules/_12_bilingualDataset.py

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader 

class BilingualDataset(Dataset):
    """Wraps a translation dataset and yields fixed-length encoder/decoder tensors and masks."""

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang , seq_len ):
        super().__init__()

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # Special-token ids are looked up in the TARGET tokenizer and kept as 1-element long tensors.
        def special(name):
            return torch.tensor([tokenizer_tgt.token_to_id(name)], dtype=torch.long)

        self.sos_token = special("[SOS]")
        self.eos_token = special("[EOS]")
        self.pad_token = special("[PAD]")

    def __len__(self):
        return len(self.ds)

    def causal_mask(self,size):
        # (size, size) boolean matrix: True on and below the diagonal, False above it.
        above_diagonal = torch.triu(torch.ones(size, size), diagonal=1)
        return ~above_diagonal.bool()

    def __getitem__(self, index):
        # Pull the sentence pair out of the 'translation' entry.
        pair = self.ds[index]['translation']
        src_text = pair[self.src_lang]
        tgt_text = pair[self.tgt_lang]

        src_ids = self.tokenizer_src.encode(src_text).ids
        tgt_ids = self.tokenizer_tgt.encode(tgt_text).ids

        # Encoder input carries SOS and EOS (2 extra tokens); decoder input and label carry one each.
        src_pad_count = self.seq_len - len(src_ids) - 2
        tgt_pad_count = self.seq_len - len(tgt_ids) - 1

        if min(src_pad_count, tgt_pad_count) < 0:
            raise ValueError("The input sentence is too long!")

        src_tensor = torch.tensor(src_ids, dtype=torch.long)
        tgt_tensor = torch.tensor(tgt_ids, dtype=torch.long)
        src_padding = self.pad_token.repeat(src_pad_count)
        tgt_padding = self.pad_token.repeat(tgt_pad_count)

        # [SOS] source tokens [EOS] [PAD]...
        encoder_input = torch.cat([self.sos_token, src_tensor, self.eos_token, src_padding])

        # [SOS] target tokens [PAD]...
        decoder_input = torch.cat([self.sos_token, tgt_tensor, tgt_padding])

        # target tokens [EOS] [PAD]...  (what the decoder is expected to output)
        label = torch.cat([tgt_tensor, self.eos_token, tgt_padding])

        # All three must come out at exactly seq_len.
        for sequence in (encoder_input, decoder_input, label):
            assert sequence.size(0) == self.seq_len

        # 1 where the token is not padding.
        encoder_mask = (encoder_input != self.pad_token)[None, None, :].long()          # (1, 1, seq_len)
        decoder_padding_mask = (decoder_input != self.pad_token)[None, None, :].long()  # (1, 1, seq_len)
        # Broadcasting against the (seq_len, seq_len) causal mask gives (1, seq_len, seq_len).
        decoder_mask = decoder_padding_mask & self.causal_mask(decoder_input.size(0))

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": encoder_mask,
            "decoder_mask": decoder_mask,
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text
        }


In [35]:
# %%writefile modules/_13_buildTokenizer_DataLoader_and_Transformer.py
import torch
import torch.nn as nn

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from pathlib import Path
from torch.utils.data import DataLoader, Dataset, random_split, Subset
from modules._10_transformer import Transformer
from modules._11_buildTransformer import build_transformer
from modules._12_bilingualDataset import BilingualDataset

def get_all_sentences(ds, lang):
    """Lazily yield the `lang` side of every item's 'translation' entry."""
    yield from (entry['translation'][lang] for entry in ds)

def get_or_build_tokenizer(config, ds, lang):
    """Load the tokenizer for `lang` from disk, or train and save one if the file is missing.

    The file path is config['tokenizer_file'] formatted with `lang`.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    # Fast path: a tokenizer file already exists.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Otherwise train a whitespace-split word-level tokenizer on the `lang` sentences of `ds`.
    word_level_model = WordLevel(unk_token="[UNK]")
    tokenizer = Tokenizer(word_level_model)
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)

    sentences = get_all_sentences(ds, lang)
    tokenizer.train_from_iterator(sentences, trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    """Load the corpus, build both tokenizers and return the train/test data loaders.

    Reads from `config`: 'lang_src', 'lang_tgt', 'seq_len', 'batch_size',
    'tokenizer_file' (through get_or_build_tokenizer) and, when present,
    'dataset_config' (falls back to 'default').

    Returns:
        (train_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt)
    """
    # Use the dataset configuration named in the config instead of a hard-coded literal.
    dataset_config = config.get("dataset_config", "default")
    ds_raw = load_dataset('cfilt/iitb-english-hindi', dataset_config, split="train[:1%]")

    # Building the tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config["lang_tgt"])

    # Splitting the data into 80% training and 20% testing
    train_ds_size = int(0.8 * len(ds_raw))
    test_ds_size = len(ds_raw) - train_ds_size

    # Seed right before the split so the split is reproducible.
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    train_ds_raw, test_ds_raw = random_split(ds_raw, [train_ds_size, test_ds_size])

    train_ds = BilingualDataset(ds=train_ds_raw, tokenizer_src=tokenizer_src, tokenizer_tgt=tokenizer_tgt, src_lang=config["lang_src"], tgt_lang=config["lang_tgt"], seq_len=config["seq_len"])
    test_ds = BilingualDataset(ds=test_ds_raw, tokenizer_src=tokenizer_src, tokenizer_tgt=tokenizer_tgt, src_lang=config["lang_src"], tgt_lang=config["lang_tgt"], seq_len=config["seq_len"])

    # NOTE: a former loop tokenized every sentence to track the maximum source/target
    # length, but the result was never used or reported; it has been removed.

    train_dataloader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    test_dataloader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=True)

    return train_dataloader, test_dataloader , tokenizer_src, tokenizer_tgt


def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the Transformer described by `config`.

    Uses config['seq_len'] for both the source and target sequence length and
    config['d_model'] for the model dimension. The optional keys 'N' (number of
    blocks) and 'h' (number of heads) are forwarded when present; otherwise the
    values 6 and 8 are used.
    """
    model = build_transformer(
        src_vocab_size=vocab_src_len,
        tgt_vocab_size=vocab_tgt_len,
        src_seq_len=config["seq_len"],
        tgt_seq_len=config["seq_len"],
        d_model=config["d_model"],
        # Without forwarding these, editing 'N' or 'h' in the config has no effect.
        N=config.get("N", 6),
        h=config.get("h", 8),
    )

    return model


### 3. Config File -> 


In [36]:
# %%writefile modules/_14_config.py
from pathlib import Path

def get_config():
    """Return a fresh dictionary with the training configuration."""
    return dict(
        # Optimisation
        batch_size=8,
        num_epochs=1,
        lr=0.001,
        # Model / data dimensions
        seq_len=500,
        d_model=512,
        # Language pair and dataset
        lang_src="en",
        lang_tgt="hi",
        dataset_config="default",
        # Checkpointing
        model_folder="weights",
        model_basename="transformerModel",
        preload="latest",
        # Architecture
        N=6,
        h=8,
        # Files and logging
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/transformerModel",
    )

def get_weights_file_path(config, epoch: str):
    """Return './<model_folder>/<model_basename><epoch>.pth' as a string."""
    folder = "{}".format(config['model_folder'])
    filename = "{}{}.pth".format(config['model_basename'], epoch)
    return str(Path('.').joinpath(folder, filename))

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """Return the last (in sorted path order) file in the model folder whose name starts
    with the model basename, as a string, or None when there is none."""
    folder = Path(f"{config['model_folder']}")
    pattern = f"{config['model_basename']}*"
    candidates = sorted(folder.glob(pattern))
    return str(candidates[-1]) if candidates else None

## **IV. Training and Testing Loops ->** 

#### **Training Loop ->** 

In [37]:
# %%writefile modules/_15_train.py
import torch
import torch.nn as nn
from pathlib import Path
from tqdm import tqdm
from modules._13_buildTokenizer_DataLoader_and_Transformer import get_or_build_tokenizer, get_model, get_ds
from modules._14_config import get_config, get_weights_file_path



def latest_weights_file_path(config):
    """Return the most recently modified '*.pth' file in the model folder as a Path,
    or None (after printing a notice) when the folder has none."""
    folder = Path(config['model_folder'])

    newest = None
    newest_mtime = None
    for candidate in list(folder.glob('*.pth')):
        modified = candidate.stat().st_mtime
        # Strict comparison: on a tie the first file encountered is kept.
        if newest is None or modified > newest_mtime:
            newest, newest_mtime = candidate, modified

    if newest is None:
        print(f"No model files found in {folder}")
        return None
    return newest

def train_model(config):
    """Train the Transformer described by `config`, saving a checkpoint after every epoch.

    Reads from `config`: 'model_folder', 'lr', 'preload', 'num_epochs', plus whatever
    get_ds, get_model and get_weights_file_path read.

    'preload' selects the checkpoint to resume from: 'latest' picks the file returned by
    latest_weights_file_path, any other truthy value is treated as an epoch suffix, and a
    falsy value starts from scratch.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config=config)

    model = get_model(config=config, vocab_src_len=tokenizer_src.get_vocab_size(), vocab_tgt_len=tokenizer_tgt.get_vocab_size()).to(device)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=config['lr'], eps=1e-9)

    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config=config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename and Path(model_filename).exists():
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on a GPU be resumed on a CPU-only machine.
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No valid model to preload, starting from scratch')

    # Labels are TARGET-side token ids, so the padding id to ignore comes from the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    print(f"Starting training from epoch {initial_epoch}")
    print(initial_epoch)
    print(config['num_epochs'])
    for epoch in range(initial_epoch, config['num_epochs']):

        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch:02d}')

        total_loss =  0
        for batch in batch_iterator:
            encoder_input = batch["encoder_input"].to(device)
            decoder_input = batch["decoder_input"].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            decoder_mask = batch['decoder_mask'].to(device)

            # Forward pass: encode -> decode -> project onto the target vocabulary
            encoder_output = model.encode(src=encoder_input, src_mask=encoder_mask)
            decoder_output = model.decode(tgt=decoder_input, encoder_output=encoder_output, src_mask=encoder_mask, tgt_mask=decoder_mask)
            project_output = model.projection_layer(decoder_output)

            label = batch['label'].to(device)

            # Flatten to (N, vocab) vs (N,) as expected by CrossEntropyLoss
            loss = loss_fn(project_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))

            total_loss += loss.item()
            batch_iterator.set_postfix({"Loss": f"{loss.item():6.3f}"})

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            global_step += 1

        # max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
        average_loss = total_loss / max(len(train_dataloader), 1)
        print(f"Epoch {epoch} completed. Average Loss: {average_loss:.4f}")

        # Checkpoint everything needed to resume: weights, optimizer state and counters.
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step,
            "loss": average_loss
        }, model_filename)
        print(f"Model saved: {model_filename}")

# Script entry point: only runs when the file is executed directly, not when it is imported.
if __name__ == "__main__":
    config = get_config()
    train_model(config)

KeyboardInterrupt: 

#### **Validation Loop ->**

In [51]:
def causal_mask(size):
    """Return a (1, size, size) boolean mask: True on and below the diagonal, False above it."""
    above_diagonal = torch.triu(torch.ones((1, size, size)), diagonal=1)
    return ~above_diagonal.bool()

def greedy_decode(model, source, source_mask , tokenizer_src, tokenizer_tgt, max_len , device):
    """Greedily generate target token ids for a single source sequence.

    The encoder output is computed once and reused for every decoding step.
    At each step the highest-scoring next token is appended, until [EOS] is
    produced or the sequence reaches `max_len` tokens.

    `tokenizer_src` is accepted for interface symmetry but not used here.

    Returns:
        1-D tensor of token ids, starting with the [SOS] id.
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # model.encode() embeds the token ids and adds positions before running the encoder;
    # calling model.encoder directly would feed it raw ids.
    encoder_output = model.encode(source, source_mask)

    #Initialising the decoder input with the SOS token -> 
    decoder_input = torch.full((1, 1), sos_idx).type_as(source).to(device)

    # `<` rather than `==` so the loop also stops if the length is already past max_len.
    while decoder_input.size(1) < max_len:

        #Building a mask for the target (decoder input) -> 
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        #Calculating the output of the decoder. model.decode() takes
        #(encoder_output, src_mask, tgt, tgt_mask) and embeds tgt itself -> 
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        #Scores for the next token come from the last position only -> 
        prob = model.project(out[:, -1])

        #Greedy search: pick the token with the maximum score.
        _, next_word = torch.max(prob, dim=1)
        next_id = next_word.item()

        next_token = torch.full((1, 1), next_id).type_as(source).to(device)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_id == eos_idx:
            break

    return decoder_input.squeeze(0)
    
    



def run_validation(model,validation_ds,tokenizer_src, tokenizer_tgt, max_len, device, print_msg = lambda msg: print(msg), num_examples=1,):
    """Decode batches from `validation_ds` greedily and report source, target and
    prediction through `print_msg`, stopping once `num_examples` batches have been shown.

    Every batch must have batch size 1.
    """
    model.eval()

    #Width of the separator line printed before each example -> 
    console_window = 80
    separator = "-" * console_window

    with torch.inference_mode():
        for shown, batch in enumerate(validation_ds, start=1):
            encoder_input = batch['encoder_input'].to(device)
            encoder_mask = batch['encoder_mask'].to(device)
            assert encoder_input.size(0)==1 ,"Batch size should be 1 for Validation."

            predicted_ids = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch['src_text'][0]
            target_text = batch['tgt_text'][0]
            # Token ids are moved to the CPU and converted to a NumPy array before decoding to text.
            predicted_text = tokenizer_tgt.decode(predicted_ids.detach().cpu().numpy())

            for line in (
                separator,
                f"SOURCE : {source_text}",
                f"TARGET : {target_text}",
                f"PREDICTED : {predicted_text}",
            ):
                print_msg(line)

            if shown == num_examples:
                break
    
    
    