# Transformer Model From Scratch

**References**
- *[Coding a Transformer from scratch on PyTorch, with full explanation, training and inference](https://youtu.be/ISNdQcPhsts?si=M80AnG5chc6sKzk5) [Youtube Video]*
- *[An Introduction to Transformers](https://arxiv.org/abs/2304.10557) [Paper]*



In [1]:
# Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

# General
import math
import warnings
from tqdm import tqdm
import os
from pathlib import Path
from typing import Dict, Tuple, List

# Huggingface
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# lightning ai
import torchmetrics

# My python files
from config import Config
from helpers import latest_weights_file_path, get_weights_file_path

# Transformer Model

## Input Embeddings

- Input Matrix (sequence, d_model) - each word is made up of say 512 numbers then the shape of input matrix will be (x, 512) where x is the number of words.

![Input Embeddings](images/InputEmbeddings.png)

- Embedding isn't fixed, it is learned by the model.
- 

*Why are we multiplying the embedding output by the square root of the dimension?*


In [59]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        """Input Embedding Module.
        Initializes the Embedder module. This module is used to embed the input tokens into a vector space.
        
        Args:
            d_model (int): Dimension of Input Embedding
            vocab_size (int): Vocabulary Size
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embeds the input tensor into a vector space. 
        This is done by looking up the embedding for each token in the input tensor.
        - Scales output by sqrt(dimensions) as per Vaswani et al. (2017).

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).

        Returns:
            torch.Tensor: Output tensor in R^d (batch_size, sequence_length, dimensions).
        """
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)


## Positional Encoding

- Conveys information about the position of the word in a sentence.
- We want the model to treat words that appear close to each other as "close" and words that are distant as "distant"
- 

![Positional Embedding](images/PositionalEmbedding.png)

- Unlike embeddings, positional embeddings aren't learned. We just compute this once. *Is this why we need a large dataset? Because we aren't learning positional encoding to tell the model appropriately what word can be closer to the other?*
- 

The formula used to create the positional encoding:
![Positional Embedding Vector](images/PositionalEmbeddingVector.png)

- *Why is there an odd and even position equation?*
- *Why do we want the positional encoding to represent a pattern that can be learned by the model? Is that why we use sine and cosine?*



In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: Config.dropout_constant) -> None:
        """Initializes the PositionalEncoder module. 
        This module is used to encode the position of the input tokens into the input embeddings.
        Create a matrix of shape (max_sequence_length, dimensions) to store the positional encodings.
        For each position in the sequence, compute the positional encoding for each dimension.
        Store the positional encodings in the matrix.
        Implement the positional encodings as per Vaswani et al. (2017).

        Args:
            d_model (int): The number of dimensions to embed the input tokens into.
            seq_len (int): The maximum length of the sentence since we need to create one vector for each position.
            dropout (float): The dropout rate to apply to the PEs.
        """
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Create a matrix of shape(seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len, 1)
        # torch.arange - returns a 1D tensor of size (end-start/step) with values from the interval [start, end) 
        # taken with common difference step beginning from start.
        # ? Why are we unsqueezing here?
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # ? Formula - I am not sure of the arange part here and how it translates to the equation above
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (d_model / 2)
        # ? It seems that you are doing this log this for numerical stability, need to look into why?
        # Apply the sin to even position & cosine to odd positions
        pe[:, 0::2] = torch.sin(position * div_term)  # sin(position * (10000 ** (2i / d_model))
        pe[:, 1::2] = torch.cos(position * div_term)  # cos(position * (10000 ** (2i / d_model))
        # Adding a new dimension to account for the batch of sentences, we use unsqueeze to do this.
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)

        # Register the positional encoding as a buffer
        # Why: to keep inside the module, not as a learned param, but to be saved along with the state of the model
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Adds the positional encodings to the input tensor. Applies dropout to the output tensor.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to add positional encodings to.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # Since we are not learning this tensor, we add requires grad as False
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)  # (batch, seq_len, d_model)
        return self.dropout(x)


## Multi-Head Attention

### Self Attention

- Why self-attention? Each word in a sentence relates to each other.

### How to compute Self Attention

- Self-Attention is permutation invariant.

![Self Attention Part 1](images/SelfAttentionPart1.png)

![Self Attention Part 2](images/SelfAttentionPart2.png)

- Self-Attention requires no parameters. Up to now the interaction between words has been driven by their embedding and the positional encodings. This will change later.
- We expect values along the diagonal to be the highest.
- If we don’t want some positions to interact, we can always set their values to –∞ before applying the softmax in this matrix and the model will not learn those interactions. We will use this in the decoder.

**What is query, keys and values?**
- Query (Q): This represents the current word or part of the sequence the model is focusing on.
- Key (K): This represents all the words or parts in the input sequence.
- Value (V): This also represents all the words or parts in the input sequence, but it contains the actual information associated with each key.

**Steps Involved:**
- Encoding Inputs: Each word in the input sequence is converted into a vector representation.
- Calculating Compatibility Scores: The model calculates a score for each key-query pair. This score represents how relevant a particular word (key) is to the current word being processed (query). There are different ways to calculate this score, but they typically involve measuring the dot product or cosine similarity between the query and key vectors.
- Softmax Attention: A softmax function is applied to the compatibility scores, converting them into attention weights. These weights represent the importance of each word in the sequence relative to the current word.
- Weighted Sum: The attention weights are multiplied element-wise with the value vectors. This essentially amplifies the information from the relevant parts of the sequence (based on the weights) and weakens the contribution of less relevant parts.
- Attention Output: The weighted sum of the value vectors is the attention output. This output incorporates information from the entire sequence, but with a focus on the parts most relevant to the current word.


### Multi-Head Attention

![Multi-Head Attention](images/MultiHeadAttention.png)


**Here's how it works:**
- Multiple Attention Heads: Instead of having a single set of query (Q), key (K), and value (V) vectors, the model creates multiple sets (often 4, 8, or 16). These are called attention heads.
- Linear Projections: Each input word embedding is projected independently for each attention head using different weight matrices. This creates different query, key, and value vectors for each head, allowing them to focus on different aspects of the relationships between words.
- Independent Attention: Each attention head then performs the standard attention mechanism steps (compatibility score calculation, softmax, weighted sum) independently. This results in multiple attention outputs, each capturing a different aspect of the context.
- Concatenation: Finally, the outputs from all the attention heads are concatenated to form a single final output. This combined output incorporates information from all the heads, providing a richer representation of the context.

In [6]:
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        """This module is used to calculate the attention scores for the input tensor.

        Args:
            d_model (int): Embedding vector size.
            h (int): Number of heads.
            dropout (float): The dropout rate to apply to the attention scores.
        """
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h"

        # for integer divions: // for floating point division: /
        self.d_k = d_model // h  # Dimension of vector seen by each head

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)

        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    # staticmethod: because we can call this without an instance of the above class
    @staticmethod
    def attention(
        query: torch.Tensor, 
        key: torch.Tensor, 
        value: torch.Tensor, 
        mask: torch.Tensor = None, 
        dropout: int = Config.dropout_constant
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculates the attention scores for the input tensor.

        Args:
            query (torch.Tensor): The query tensor to calculate the attention scores for.
            key (torch.Tensor): The key tensor to calculate the attention scores for.
            value (torch.Tensor): The value tensor to calculate the attention scores for.
            mask (torch.Tensor, optional): The mask tensor to apply to the attention scores. Defaults to None.
            dropout (int, optional): The dropout rate to apply to the attention scores. Defaults to Config.dropout.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        d_k = query.shape[-1]

        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        # @ matrix multiplication in pytorch
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        # masking is just multiplying by a very small value which then becomes zero after applying softmax
        if mask is not None:
            # replace all the values for mask==0 with -1e9
            attention_scores.masked_fill(mask==0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)

        if dropout is not None:
            attention_scores = nn.Dropout(dropout)(attention_scores)
        
        # The second value in the tuple is used for visualization
        return (attention_scores @ value), attention_scores

    
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Projects the input tensor into the query, key, and value tensors.
        Splits the query, key, and value tensors into multiple heads.
        Calculates the attention scores for the input tensor.
        Concatenates the attention scores for the multiple heads.
        Projects the attention scores back into the original dimension.

        Args:
            q (torch.Tensor): The query tensor to calculate the attention scores for.
            k (torch.Tensor): The key tensor to calculate the attention scores for.
            v (torch.Tensor): The value tensor to calculate the attention scores for.
            mask (torch.Tensor, optional): The mask tensor to apply to the attention scores. Defaults to None.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # for q, k, v: (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
    
        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        # Using view: keep batch dim since we don't want to split the sentence, we want to split the embedding
        # torch.tensor.view: returns a new tensor with the same data as the self tensor but of a different shape.
        # Why transpose: prefer h as the second dim, this way each head will see all the sentences
        # view() is used cause it's typically faster than reshape(), cause it doesn't need to copy the underlying data
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Getting attention scores
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        
        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)



## General Layers

### Layer Normalization

**Batch Normalization (BN):**
- Normalization Scope: BN normalizes activations across a mini-batch of training data. It calculates the mean and standard deviation of the activations for each feature (channel) within the mini-batch and then uses these statistics to normalize the individual activations.
- Impact: This helps to address the problem of internal covariate shift, where the distribution of activations can change significantly between mini-batches during training. By normalizing each mini-batch, BN ensures that the gradients used for updating the model's weights are more stable and less prone to exploding or vanishing gradients.

*Limitations*:
- Reliance on Batch Size: BN's effectiveness depends on the mini-batch size. Smaller batch sizes can lead to less accurate estimates of the mean and standard deviation, potentially reducing its effectiveness.
- Increased Memory Consumption: BN requires storing the mean and standard deviation for each feature across the mini-batch, which can increase memory usage.

![Layer Normalization](images/LayerNormalization.png)

**Layer Normalization (LN):**
- Normalization Scope: LN normalizes activations within each layer, independently for each sample in the mini-batch. It calculates the mean and standard deviation of the activations for each feature (channel) across all elements in a single sample (e.g., across the entire image width and height for a convolutional layer).
- Impact: LN also addresses internal covariate shift, but at a different level. It ensures that the activations within a layer have a consistent distribution for each sample, regardless of the mini-batch or other samples in the training data. This can improve the stability of gradients and learning.

*Benefits*:
- Less Sensitive to Batch Size: LN is less sensitive to the mini-batch size compared to BN. This makes it potentially more robust for various training scenarios, including settings with small batch sizes.
- Lower Memory Footprint: LN doesn't require storing statistics across the entire mini-batch, reducing memory consumption.

*Potential Drawbacks*:
- Limited Context: LN might not capture long-range dependencies between features as effectively as BN, which considers activations across the entire mini-batch.
- Less Flexibility: LN applies the same normalization across all features within a layer, while BN allows for different normalizations for each feature.



In [4]:
class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 10**-6) -> None:
        """Initializes the LayerNormalization module. This module is used to normalize the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            eps (float, optional): Needed to prevent div by 0 error. Defaults to 10**-6.
        """
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # Multiplied
        self.bias = nn.Parameter(torch.zeros(features))  # Added
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalizes the input tensor using the gamma and beta parameters.
        - Assumes input has batch, the mean and standard deviation is calculated over the last dimension 
        leaving the batch dimension.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to normalize.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        self.mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        self.std = x.std(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        return self.alpha * (x - self.mean) / (self.std + self.eps) + self.bias
    

### Feed Forward



In [5]:
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        """This module is used to project the input tensor into a higher dimension 
        and then back into the original dimension.

        Args:
            d_model (int): The number of dimensions of the input tensor.
            d_ff (int): The number of dimensions to project the input tensor into.
            dropout (float): The dropout rate to apply to the output tensor.
        """
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # W1 and B1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # W2 and B2
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the input tensor into a higher dimension using a linear layer and a ReLU activation function.
        Projects the input tensor back into the original dimension using a linear layer.
        Applies dropout to the output tensor.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to project and then project back.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))


### Residual Connection

In [7]:
class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float = Config.dropout_constant) -> None:
        """This module is used to add the input tensor to the output tensor.

        Args:
            features (int): The number of features in the input tensor.
            dropout (float, optional): The dropout rate to apply to the output tensor. Defaults to Config.dropout_constant.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)
    
    def forward(self, x: torch.Tensor, sublayer: torch.nn.Module) -> torch.Tensor:
        """Adds the input tensor to the output tensor. Applies dropout to the output tensor.

        Args:
            x (torch.Tensor): The input tensor to add to the output tensor.
            sublayer (torch.nn.Module): The network layer to apply to the input tensor.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        return x + self.dropout(sublayer(self.norm(x)))

## Encoder Block

![Encoder Block](images/EncoderBlock.png)



In [8]:
class EncoderBlock(nn.Module):
    def __init__(
            self,
            features: int,
            self_attention_block: MultiHeadAttentionBlock,
            feed_forward_block: FeedForwardBlock,
            dropout: float = Config.dropout_constant
        ) -> None:
        """Initializes the EncoderBlock module. This module is used to encode the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            self_attention_block (MultiHeadAttentionBlock): The self attention layer to apply to the input tensor.
            feed_forward_block (FeedForwardBlock): The feed forward layer to apply to the input tensor.
            dropout (float, optional): The dropout rate to apply to the output tensor. Defaults to Config.dropout_constant.
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
    
    def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Applies the self attention layer to the input tensor.
        Adds the input tensor to the output tensor.
        Applies the feed forward layer to the output tensor.
        Adds the input tensor to the output tensor.

        Args:
            x (torch.Tensor): The input tensor to encode.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        x = self.residual_connection[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connection[1](x, self.feed_forward_block)
        return x

In [9]:
class Encoder(nn.Module):
    def __init__(self, features:int, layers: nn.ModuleList) -> None:
        """Initializes the Encoder module. This module is used to encode the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            layers (nn.ModuleList): The layers to apply to the input tensor.
        """
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Applies the layers to the input tensor. Applies the layer normalization to the output tensor.

        Args:
            x (torch.Tensor): The input tensor to encode.
            mask (torch.Tensor): The mask tensor to apply to the attention scores.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

## Decoder Block

![Decoder Block](images/DecoderBlock.png)


**Masked Multi-Head Attention**
- Our goal is to make the model causal: it means the output at a certain position can only depend on the words on the previous positions. The model must not be able to see future words.

![Masked Multi-Head Attention](images/MaskedMultiHeadAttention.png)

- All the values above the diagonal are replace with -∞ before applying the softmax, which will replace them with zero.

In [10]:
class DecoderBlock(nn.Module):
    def __init__(
            self, 
            features: int, 
            self_attention_block: MultiHeadAttentionBlock,
            cross_attention_block: MultiHeadAttentionBlock,
            feed_forward_block: FeedForwardBlock,
            dropout: float = Config.dropout_constant
        ) -> None:
        """_summary_

        Args:
            features (int): The number of features in the input tensor.
            self_attention_block (MultiHeadAttentionBlock): The self attention layer to apply to the input tensor.
            cross_attention_block (MultiHeadAttentionBlock): The source attention layer to apply to the input tensor.
            feed_forward_block (FeedForwardBlock): The feed forward layer to apply to the input tensor.
            dropout (float, optional): The dropout rate to apply to the output tensor. Defaults to Config.dropout_constant.
        """
        super().__init__()

        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_bock = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(
        self,
        x: torch.Tensor, 
        encoder_output: torch.Tensor, 
        src_mask: torch.Tensor, 
        tgt_mask: torch.Tensor
    ) -> torch.Tensor:
        """Applies the self attention layer to the input tensor.
        Adds the input tensor to the output tensor.
        Applies the source attention layer to the output tensor.
        Adds the input tensor to the output tensor.
        Applies the feed forward layer to the output tensor.
        Adds the input tensor to the output tensor.

        Args:
            x (torch.Tensor): The input tensor to decode.
            encoder_output (torch.Tensor): The output tensor from the encoder.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores for the source tensor.
            tgt_mask (torch.Tensor): The mask tensor to apply to the attention scores for the target tensor.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        x = self.residual_connection[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connection[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connection[2](x, self.feed_forward_bock)
        return x


In [11]:
class Decoder(nn.Module):
    def __init__(self, features:int, layers: nn.ModuleList) -> None:
        """Initializes the Decoder module. This module is used to decode the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            layers (nn.ModuleList): The layers to apply to the input tensor.
        """
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    
    def forward(
        self, 
        x: torch.Tensor, 
        encoder_output: torch.Tensor, 
        src_mask: torch.Tensor, 
        tgt_mask: torch.Tensor
    ) -> torch.Tensor:
        """Applies the layers to the input tensor. Applies the layer normalization to the output tensor.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to decode.
            encoder_output (torch.Tensor): The output tensor from the encoder.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores for the source tensor.
            tgt_mask (torch.Tensor): The mask tensor to apply to the attention scores for the target tensor.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
    
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)




## Projection Layer

In [12]:
class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        """Initializes the ProjectionLayer module. 
        This module is used to project the input tensor into a higher dimension.

        Args:
            d_model (int): The number of dimensions of the input tensor.
            vocab_size (int): The number of unique tokens in the input.
        """
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the input tensor into a higher dimension using a linear layer.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to project.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, vocabulary_size).
        """
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)

## Transformer Block

![Transformer Block](images/TransformerBlock.png)

In [13]:
class Transformer(nn.Module):
    def __init__(
            self, 
            encoder: Encoder,
            decoder: Decoder,
            src_embed: InputEmbeddings,
            tgt_embed: InputEmbeddings,
            src_pos: PositionalEncoding,
            tgt_pos: PositionalEncoding,
            projection_layer: ProjectionLayer
        ) -> None:
        """Initializes the Transformer module. This module is used to encode and decode the input tensor.

        Args:
            encoder (Encoder): The encoder to encode the input tensor.
            decoder (Decoder): The decoder to decode the input tensor.
            src_embed (InputEmbeddings): The input embedder to embed the input tensor.
            tgt_embed (InputEmbeddings): The target embedder to embed the input tensor.
            src_pos (PositionalEncoding): The positional encoder to encode the input tensor.
            tgt_pos (PositionalEncoding): The positional encoder to encode the target tensor.
            projection_layer (ProjectionLayer): _description_
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
    
    def encode(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Embeds the input tensor. Encodes the input tensor.

        Args:
            src (torch.Tensor): The input tensor (batch_size, sequence_length) to encode.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode(
        self,
        encoder_output: torch.Tensor, 
        src_mask: torch.Tensor, 
        tgt: torch.Tensor, 
        tgt_mask: torch.Tensor
    ) -> torch.Tensor:
        """Decodes the input tensor.

        Args:
            encoder_output (torch.Tensor): The output tensor from the encoder.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores for the source tensor.
            tgt (torch.Tensor): The target tensor.
            tgt_mask (torch.Tensor): The mask tensor to apply to the attention scores for the target tensor.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    
    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the input tensor into a higher dimension.

        Args:
            x (torch.Tensor): The input tensor to project.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, vocabulary_size).
        """
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)

## Build Transformer

In [14]:
def build_transformer(
    src_vocab_size: int,
    tgt_vocab_size: int,
    src_seq_len: int,
    tgt_seq_len: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    dropout: float,
    d_ff: int
) -> Transformer:
    """Builds the transformer model.

    Args:
        src_vocab_size (int): The number of unique tokens in the input.
        tgt_vocab_size (int): The number of unique tokens in the target.
        src_seq_len (int): The maximum length of the input sequence.
        tgt_seq_len (int): The maximum length of the target sequence.
        d_model (int): The number of dimensions to embed the input tokens into.
        num_layers (int): The number of layers to apply to the input tensor.
        num_heads (int): The number of heads to split the input tensor into.
        dropout (float): The dropout rate to apply to the output tensor.
        d_ff (int): The number of dimensions to project the input tensor into.

    Returns:
        Transformer: Returns a transformer model.
    """
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positioonal encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(num_layers):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, num_heads, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)
    
    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(num_layers):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, num_heads, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, num_heads, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(
            d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout
        )
        decoder_blocks.append(decoder_block)
    
    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        
    return transformer


# Training a Translation Model

## Tokenizer

In [None]:
def get_all_sentences(ds: Dataset, lang: str):
    """A generator to get all sentences in a dataset.

    Args:
        ds (Dataset): The dataset.
        lang (str): The language.

    Yields:
        _type_: A generator.
    """
    for item in ds:
        yield item['translation'][lang]

In [17]:
def get_or_build_tokenizer(config: Config, ds: Dataset, lang: str):
    """Gets the tokenizer if it exists, otherwise builds it and saves it.

    Args:
        config (Config): The configuration.
        ds (Dataset): The dataset.
        lang (str): The language.

    Returns:
        _type_: The tokenizer.
    """
    tokenizer_path = Path(config.tokenizer_file.format(lang))
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer


## Load Dataset

In [18]:
def casual_mask(size: int) -> torch.Tensor:
    """Create a causal mask for the decoder attention."""
    # torch.triu: returns the upper triangular part of a matrix (2-D tensor) or batch of matrices
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


In [36]:
class BilingualDataset(Dataset):
    def __init__(self, ds: Dataset, tokenizer_src, tokenizer_tgt, src_lang: str, tgt_lang: str, seq_len: int) -> None:
        """
        Constructor for the BilingualDataset class:
        - dataset: the dataset to use
        - tokenizer_source: the tokenizer for the source language
        - tokenizer_target: the tokenizer for the target language
        - source_language: the source language
        - target_language: the target language
        - sequence_length: the sequence length to use
        """
        super().__init__()
        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Start of Sentence, End of Sentence and Padding Tokens
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        """Get the length of the dataset."""
        return len(self.ds)
    
    def truncate_sequence(self, sequence, max_len: int):
        """Truncates a sequence of tokens to the specified maximum length.

        Args:
            sequence: A list of integer tokens representing the sentence.
            max_len: The maximum allowed length for the sequence.

        Returns:
            A list of truncated tokens and the original length of the sequence.
        """
        if len(sequence) <= max_len:
            return sequence, len(sequence)
        else:
            # Truncate from the beginning (optional: modify to truncate from the end)
            truncated_sequence = sequence[:max_len]
            return truncated_sequence, len(sequence)


    def __getitem__(self, idx) -> Dict:
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into tokens & Truncate sentences if necessary
        # Note that for sentences > seq_len, by adding -2 and -1, I am choosing to not have any padding, 
        # cause enc_num_padding_tokens and dec_num_padding_tokens will turn out to be = 0
        enc_input_tokens, orig_enc_len = self.truncate_sequence(self.tokenizer_src.encode(src_text).ids, self.seq_len - 2)
        dec_input_tokens, orig_dec_len = self.truncate_sequence(self.tokenizer_tgt.encode(tgt_text).ids, self.seq_len - 1)

        # Add sos, eos and padding to each sentence
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # -2 cause we will add <s> and </s>
        # We will only add <s> (SOS), and </s> (EOS) only on the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1


        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long.
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError(
                f"Original sentence length exceeded allowed limit of seq_len - {self.seq_len}. "
                f"Original lengths: source - {orig_enc_len}, target - {orig_dec_len}"
            )

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0,
        )
        
        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0
        )
        
        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len, f"encoder_input size: {encoder_input.size(0)} and seq_len: {self.seq_len}"
        assert decoder_input.size(0) == self.seq_len, f"decoder_input size: {decoder_input.size(0)} and seq_len: {self.seq_len}"
        assert label.size(0) == self.seq_len, f"label size: {label.size(0)} and seq_len: {self.seq_len}"

        return {
            "encoder_input": encoder_input,   # (seq_len)
            "decoder_input": decoder_input,   # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            # (1, seq_len) & (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & casual_mask(decoder_input.size(0)),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text
        }


In [20]:
def get_ds(config: Config) -> Tuple:
    """Loads the dataset, builds the tokenizers and returns the dataloaders.

    Args:
        config (Config): The configuration.

    Returns:
        Tuple: The training and validation dataloaders.
    """
    # It only has the train split, so we divide it overselves
    ds_raw = load_dataset(f"{config.datasource}", f"{config.lang_src}-{config.lang_tgt}", split="train")

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config.lang_src)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config.lang_tgt)

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    # Need to create the tensors the model will use
    train_ds = BilingualDataset(
        train_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len
    )
    val_ds = BilingualDataset(
        val_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len
    )

    # Find the maximum length of each sentence in the source and target sentence
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config.lang_src]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config.lang_tgt]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))
    
    print(f"Max length of source sentence: {max_len_src}")
    print(f"Max length of target sentence: {max_len_tgt}")

    train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt


## Train Model

![Training Model](images/TrainingModel.png)


In [42]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len: int, device):
    """
    Greedy Decode:
    - Finds the most likely next token at each step and appends it to the decoder input.
    - Just picks the token with the max probabilitiesability at each step as the next token.
    - May not give the ideal output, but it's fast and simple. For better results, use beam search.
    
    Parameters:
    - model: The transformer model
    - source: The input sentence
    - source_mask: The mask for the input sentence
    - source_tokenizer: The tokenizer for the source language
    - target_tokenizer: The tokenizer for the target language
    - max_length: The maximum length of the output sentence
    - device: The device to run the model on
    
    Shapes:
    - source: (1, sequence_length, dimensions)
    - source_mask: (1, 1, sequence_length, sequence_length)
    
    Returns: 
    - The output sentence
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)

    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # build mask for target
        decoder_mask = casual_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break
    
    return decoder_input.squeeze(0)



In [24]:
def get_model(config: Config, vocab_src_len: int, vocab_tgt_len: int) -> Transformer:
    """Builds the transformer model

    Args:
        config (Config): The configuration.
        vocab_src_len (int): The length of the source vocabulary.
        vocab_tgt_len (int): The length of the target vocabulary.

    Returns:
        Transformer: The transformer model.
    """
    model = build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        config.seq_len,
        config.seq_len,
        config.d_model,
        config.num_layers,
        config.num_heads,
        config.dropout_constant,
        config.d_ff,
    )
    return model


In [47]:
def run_validation(
    model,
    validation_ds,
    tokenizer_src,
    tokenizer_tgt,
    max_len: int,
    device,
    print_msg,
    global_step,
    writer,
    num_examples: int = 2
):
    """
    Run Validation:
    - Runs the model on the validation dataset and prints the source, target and predicted sentences.
    
    Parameters:
    - model: The transformer model
    - validation_ds: The validation dataset
    - source_tokenizer: The tokenizer for the source language
    - target_tokenizer: The tokenizer for the target language
    - max_length: The maximum length of the output sentence
    - device: The device to run the model on
    - print_message: A function to print messages
    - global_step: The global step
    - num_examples: The number of examples to print
    
    Shapes:
    - source: (1, sequence_length, dimensions)
    - source_mask: (1, 1, sequence_length, sequence_length)
    
    Returns:
    - Inference
    """
    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    try:
        # get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except Exception:
        # If we can't get the console width, use 80 as default
        console_width = 80
    
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)  # (b, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(0)==1, "Batch size must be 1 for validation."

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-'*console_width)
            print_msg(f"{'SOURCE: ':>12}{source_text}")
            print_msg(f"{'TARGET: ':>12}{target_text}")
            print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break
    
    if writer:
        # Compute the char error rate
        metric = torchmetrics.CharErrorRate()
        cer = metric(predicted, expected)
        writer.add_scalar('validation cer', cer, global_step)
        writer.flush()

        # Compute the word error rate
        metric = torchmetrics.WordErrorRate()
        wer = metric(predicted, expected)
        writer.add_scalar('validation wer', wer, global_step)
        writer.flush()

        # Compute the BLEU metric
        metric = torchmetrics.BLEUScore()
        bleu = metric(predicted, expected)
        writer.add_scalar('validation BLEU', bleu, global_step)
        writer.flush()


In [49]:
def train_model(config: Config):
    """Trains the transformer model.

    Args:
        config (Config): The configuration.
    
    Steps:
        - Get the dataset
        - Get the model
        - Define the optimizer
        - Define the loss function
        - For each epoch:
            - For each batch:
                - Run the tensors through the encoder, decoder and the projection layer
                - Compare the output with the label
                - Compute the loss using a simple cross entropy
                - Backpropagate the loss
                - Update the weights
            - Run validation at the end of every epoch
            - Save the model at the end of every epoch
    """
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    if (device == "cuda"):
        print(f"Device name: {torch.cuda.get_device_name(device.index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
    elif (device == "mps"):
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
    
    device = torch.device(device)

    # Make sure the weights folder exists
    Path(f"{config.datasource}_{config.model_folder}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config.experiment_name)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config.preload
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    if model_filename:
        print(f"Preloading model {model_filename}")
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print("No model to preload, starting from scratch.")
    
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config.num_epochs):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch: 02d}")

        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)  # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask)  # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (B, seq_len, d_model)
            proj_output = model.project(decoder_output)  # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device)  # (B, seq_len)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss" : f"{loss.item():6.3f}"})

            # Log the loss
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1
        
        # Run validation at the end of every epoch
        run_validation(
            model,
            val_dataloader,
            tokenizer_src,
            tokenizer_tgt,
            config.seq_len,
            device,
            lambda msg: batch_iterator.write(msg),
            global_step,
            writer
        )

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step
        }, model_filename)



# Main

In [None]:
# Load the Saved Weights

# Adjust Training Loop : epoch counter

In [50]:
config = Config()
warnings.filterwarnings("ignore")
train_model(config)

Using device: cuda
Device name: NVIDIA GeForce RTX 3080 Ti Laptop GPU
Device memory: 15.99951171875 GB
Max length of source sentence: 471
Max length of target sentence: 482
No model to preload, starting from scratch.


Processing Epoch  0: 100%|██████████| 14297/14297 [1:50:54<00:00,  2.15it/s, loss=1.781] 


--------------------------------------------------------------------------------
    SOURCE: "Well, either that boot comes back before sundown or I'll see the manager and tell him that I go right straight out of this hotel."
    TARGET: – Écoutez-moi : ou bien ce soulier me sera rendu avant ce soir, ou bien je me rends chez le directeur pour lui annoncer que je quitte immédiatement cet hôtel.
 PREDICTED: Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je Je J

Processing Epoch  1: 100%|██████████| 14297/14297 [3:43:50<00:00,  1.06it/s, loss=1.661]      


--------------------------------------------------------------------------------
    SOURCE: See how les petites cheries step out for the credit of their master.
    TARGET: Voyez comme elles vont, les petites chéries, pour faire honneur à leur maître.
 PREDICTED: Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendre

Processing Epoch  2:   2%|▏         | 347/14297 [02:25<1:37:36,  2.38it/s, loss=1.609]


KeyboardInterrupt: 

## Validation

In [51]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

# Load the pretrained weights
model_filename = latest_weights_file_path(config)
state = torch.load(model_filename)
model.load_state_dict(state['model_state_dict'])


Max length of source sentence: 471
Max length of target sentence: 482


<All keys matched successfully>

In [53]:
run_validation(
    model, 
    val_dataloader, 
    tokenizer_src, 
    tokenizer_tgt, 
    config.seq_len, 
    device, 
    lambda msg: print(msg), 
    0, 
    None, 
    num_examples=5
)

--------------------------------------------------------------------------------
    SOURCE: It was she who attacked me--I was in a narrow and shallow bay--the frigate barred my way--and I sank her!"
    TARGET: J'étais resserré dans une baie étroite et peu profonde!... il me fallait passer, et... j'ai passé!»
 PREDICTED: Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi V

## Inference

![Inference Model Time Step 1](images/InferenceModelTimeStep1.png)

![Inference Model Time Step 4](images/InferenceModelTimeStep4.png)

- We selected, at every step, the word with the maximum softmax value. This strategy is called *greedy* and usually does not perform very well.
- A better strategy is to select at each step the top B words and evaluate all the possible next words for each of them and at each step, keeping the top B most probable sequences. This is the *Beam Search strategy* and generally performs better

In [57]:
def translate(config: Config, sentence: str):
    # Define the device, tokenizers, and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer_src = Tokenizer.from_file(str(Path(config.tokenizer_file.format(config.lang_src))))
    tokenizer_tgt = Tokenizer.from_file(str(Path(config.tokenizer_file.format(config.lang_tgt))))
    model = build_transformer(
        tokenizer_src.get_vocab_size(),
        tokenizer_tgt.get_vocab_size(),
        config.seq_len,
        config.seq_len,
        config.d_model
    ).to(device)

    # Load the pretrained weights
    model_filename = latest_weights_file_path(config)
    state = torch.load(model_filename)
    model.load_state_dict(state['model_state_dict'])

    # translate sentence
    seq_len = config.seq_len
    model.eval()
    with torch.no_grad():
        # Precompute the encoder output and reuse it for every generation step
        source = tokenizer_src.encode(sentence)
        source = torch.cat(
            [
                torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
                torch.tensor(source.ids, dtype=torch.int64),
                torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
                torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (seq_len - len(source.ids) - 2), dtype=torch.int64)
            ], dim=0
        ).to(device)
        source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
        encoder_output = model.encode(source, source_mask)

        # Initialize the decoder input with the sos token
        decoder_input = torch.empty(1, 1).fill_(tokenizer_tgt.token_to_id('[SOS]')).type_as(source).to(device)

        # Generate the translation word by word
        while decoder_input.size(1) < seq_len:
            # build mask for target and calculate output
            decoder_mask = torch.triu(
                torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1
            ).type(torch.int).type_as(source_mask).to(device)
            out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

            # project next token
            prob = model.project(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            decoder_input = torch.cat(
                [
                    decoder_input,
                    torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
                ], dim=1
            )

            # print the translated word
            print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')

            # break if we predict the end of sentence token
            if next_word == tokenizer_tgt.token_to_id('[EOS]'):
                break
    
    # convert ids to tokens
    return tokenizer_tgt.decode(decoder_input[0].tolist())


In [58]:
translate(config, "Good morning transformer.")

Using device: cuda
Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi 

'Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi Vendredi 