# Building Transformers

**Author: Ümit Mert Çağlar**

## Introduction

In this notebook, we demonstrate how to use PyTorch to build a Transformer, following the original architecture from the paper [Attention Is All You Need](https://arxiv.org/pdf/1706.03762). This is the solution notebook, in which all code works as intended. Its counterpart is the exercise notebook, which should be completed before this one.

<img src="transformer.png" width="400">


## Imports

We begin by importing the necessary libraries.

These imports may be needed for later parts to run on the Colab servers. If a library or version error occurs while running this notebook, check the accompanying `requirements.txt`, which was frozen when this notebook was prepared. Note that installing everything locally with `pip install -r requirements.txt` is not advised, mainly because of discrepancies between Colab and local machines.

[PyTorch](https://pytorch.org/) (`torch`) is a popular and versatile machine learning framework that allows low-level implementation (as low as it gets with Python, anyway). `torch.nn` is the PyTorch module that provides building blocks for neural networks.

In [1]:
import math
import torch
import torchmetrics
import torch.nn as nn

from tqdm import tqdm
from dataclasses import dataclass
from torch import Tensor
from pathlib import Path
from typing import Optional, Dict, Any, Callable, Iterator
from tokenizers import Tokenizer
from torch.utils.data import Dataset, Subset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer





## Embeddings

#### Instructions
* Specify the PyTorch class that the positional encoder should subclass from.
* Initialize a positional encoding matrix for token positions in sequences up to `max_seq_len`.
* Assign unique position encodings to the matrix pe by alternating the use of sine and cosine functions.
* Update the input embeddings tensor x to add position information about the sequence using the positional encodings matrix.


Remember that we need embeddings to translate human words into a machine-readable format **and** to add position information for the transformer architecture.

<img src="embeddings.png" width="600">

### Input Embeddings

The input tokens passed through the Transformer model are first converted to vectors of dimension $d_{model}$ by a learned **Input Embeddings** layer.

The output of the embedding layer is scaled by $\sqrt{d_{model}}$.
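
As a minimal sketch of the lookup-and-scale step described above, the snippet below uses `nn.Embedding` directly rather than the notebook's `InputEmbeddings` class. The toy sizes and token ids are assumptions chosen only to keep the shapes readable.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, vocab_size = 8, 20  # toy sizes, chosen for illustration only

# Lookup table holding one learnable d_model-dimensional vector per token id
embedding = nn.Embedding(vocab_size, d_model)

token_ids = torch.tensor([[1, 5, 7], [2, 0, 19]])  # (bs=2, seq_len=3)

# Look up the vectors and scale them by sqrt(d_model)
embedded = embedding(token_ids) * math.sqrt(d_model)

print(embedded.shape)  # torch.Size([2, 3, 8])
```

Each token id simply selects a row of `embedding.weight`, so the output gains a trailing dimension of size `d_model`.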


In [2]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        """
        Embedding layer for input tokens

        Args:
            d_model (int): Hidden dimension of the model. The size of the vector
                representations (embeddings / hidden states) used throughout the
                Transformer model.
            vocab_size (int): Size of the vocabulary. Number of unique tokens in the
                input data.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size

        # Input to embedding layer: (*)
        # Output from embedding layer: (*, H), where H is the hidden dim of the model.
        
        # TODO: Create an embedding layer of size (vocab_size, d_model)
        self.embedding = ...

    def forward(self, x: Tensor) -> Tensor:
        """
        Embed input tokens.

        Args:
            x (Tensor): Input tokens of shape `(bs, seq_len)`.

        Returns:
            Tensor: Embedded input of shape `(bs, seq_len, d_model)`.
        """
        # seq_len dimension contains token ids that can be mapped back to unique word
        
        # TODO: Return the result of the embedded `x` tensor
        return ...

### Positional Embeddings

<div>
<img src="https://github.com/caglarmert/DI725/blob/main/src/PE_highlight.png?raw=true" width="300"/>
</div>

To make use of position information, transformers use **Positional Encoding**.

We will use sine waves for the even dimensions and cosine waves for the odd dimensions of the encoding:
$$
\text{PE}_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) 
$$

$$
\text{PE}_{(pos, 2i + 1)} = \cos(pos / 10000^{2i/d_{model}})
$$
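
To make the two formulas concrete, the following pure-Python sketch evaluates them one position at a time. The helper name `positional_encoding` and the toy `d_model = 4` are illustrative assumptions, not part of the notebook's code; the loop variable `i` stands for the even index $2i$ in the formulas.

```python
import math


def positional_encoding(pos: int, d_model: int) -> list[float]:
    """Return the d_model-dimensional encoding of position `pos`."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        # `i` is the even dimension index (2i in the formulas)
        angle = pos / (10000 ** (i / d_model))
        pe[i] = math.sin(angle)          # even dimension
        if i + 1 < d_model:
            pe[i + 1] = math.cos(angle)  # odd dimension
    return pe


pe0 = positional_encoding(pos=0, d_model=4)
pe1 = positional_encoding(pos=1, d_model=4)
print(pe0)  # [0.0, 1.0, 0.0, 1.0]
print([round(v, 4) for v in pe1])
```

Position 0 always encodes to alternating zeros and ones, because $\sin(0) = 0$ and $\cos(0) = 1$; later positions rotate each sine/cosine pair at a frequency that decreases with the dimension index.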

In [3]:
class PositionalEncoding(nn.Module):
    pe: Tensor

    def __init__(self, d_model: int, max_seq_len: int, dropout: float = 0.1):
        """
        Positional encoding / embeddings for input tokens

        Args:
            d_model (int): Hidden dimension of the model. The size of the vector
                representations (embeddings / hidden states) used throughout the
                Transformer model.
            max_seq_len (int): Maximum sequence length.
            dropout (float, optional): Dropout rate. Defaults to 0.1.
        """
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        # TODO: apply dropout
        self.dropout = ...

        # Create positional encodings of shape (max_seq_len, d_model)
        pe = torch.zeros(max_seq_len, d_model)

        # Create tensor of shape (max_seq_len, 1)
        pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        # PE division term => 1 / 10000^(2 * i / d_model)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        
        # TODO: Use sine and cosine functions for even and odd dimensions
        pe[:, 0::2] = ...  # Even dimensions
        pe[:, 1::2] = ...  # Odd dimensions

        # Add batch dimension to positional encodings
        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)

        # Tensor is saved in file when model is saved.
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply positional encoding to input embeddings.

        Args:
            x (Tensor): Input embeddings of shape `(bs, seq_len, d_model)`.

        Returns:
            Tensor: Input embeddings with positional encodings added, of shape
                `(bs, seq_len, d_model)`.
        """

        # Add positional encodings to input embeddings
        seq_len = x.size(1)
        # Slice positional encodings to the first seq_len positions (seq_len <= max_seq_len)
        pe_out = (self.pe[:, :seq_len, :]).requires_grad_(False)
        
        # TODO: Add the positional information onto the input tensor `x`
        x = ...
        
        # TODO: Apply dropout
        return ...

In the `__init__` method, we first initialize the superclass `nn.Module` and then store the model's dimension `d_model` and the maximum sequence length `max_seq_len`. We then create a zero matrix `pe` of size `max_seq_len` by `d_model` to store the positional encodings.

Next, we calculate the positional encodings. We create a tensor `pos` that contains the sequence positions and a tensor `div_term` that contains the division terms $1 / 10000^{2i/d_{model}}$, computed in the equivalent form $\exp(-2i \ln(10000) / d_{model})$. We then apply the sine function to the product of `pos` and `div_term` for the even indices of `pe` and the cosine function for the odd indices.

In the `forward` method, we slice the `pe` matrix to match the sequence length of `x`, add it to the input embeddings tensor `x`, and return the result after applying dropout.
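
The steps described above can be sketched outside the class as follows. This is one possible way to build and apply the table, assuming toy sizes (`d_model = 6`, `max_seq_len = 10`) and omitting dropout; it also checks that the exponential form of `div_term` matches the direct power form.

```python
import math

import torch

d_model, max_seq_len = 6, 10  # toy sizes, chosen for illustration only

pos = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)  # (max_seq_len, 1)
even_idx = torch.arange(0, d_model, 2).float()  # the "2i" values: 0, 2, 4

# exp(-2i * ln(10000) / d_model) equals 1 / 10000^(2i / d_model)
div_term = torch.exp(even_idx * (-math.log(10000.0) / d_model))
direct = 1.0 / (10000.0 ** (even_idx / d_model))
print(torch.allclose(div_term, direct, atol=1e-6))  # True

pe = torch.zeros(max_seq_len, d_model)
pe[:, 0::2] = torch.sin(pos * div_term)  # even dimensions
pe[:, 1::2] = torch.cos(pos * div_term)  # odd dimensions
pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)

# A batch whose sequences are shorter than max_seq_len
x = torch.zeros(2, 4, d_model)  # (bs, seq_len, d_model)
out = x + pe[:, : x.size(1), :]  # slice to seq_len, broadcast over the batch
print(out.shape)  # torch.Size([2, 4, 6])
```

Because `pe` has a leading dimension of size 1, the same encodings are broadcast to every sequence in the batch.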

## Attention Layers

The **Multi-Head Attention block** is where the attention mechanism lives. At its core, it computes scaled dot-product attention.

<div>
<img src="https://github.com/caglarmert/DI725/blob/main/src/multi_headed_attention_highlighted.png?raw=true" width="300"/>
</div>


There are 2 types of attention in the Transformer: **Self-Attention** and **Cross-Attention**. They both use the same multi-head attention mechanism.

The primary differences are:

- **Self-Attention:** Queries, keys, and values come from the same input sequence.

- **Cross-Attention:** Queries come from the decoder’s hidden state, while keys and values come from the encoder’s outputs.

<br>

<img src="attention.png" width="600">

### Scaled Dot Product Attention

$$
\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V
$$

We will be using $d_{\text{model}} = 512$ with $h=8$ parallel attention layers (heads).

Therefore, the dimension of queries, keys, and values will be $d_q=d_k=d_v=d_{model}/h = 64$.
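
The formula above can be sketched as a small standalone function. This is a minimal illustration rather than the notebook's `scaled_dot_product_attention` method: the function name `attention`, the random inputs, and the omission of dropout are assumptions made for brevity.

```python
import math

import torch

torch.manual_seed(0)


def attention(q, k, v, mask=None):
    """Scaled dot-product attention over the last two dimensions."""
    d_k = q.shape[-1]
    # (..., seq_len, d_k) @ (..., d_k, seq_len) -> (..., seq_len, seq_len)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Positions where mask == 0 get zero weight after the softmax
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = scores.softmax(dim=-1)  # each query row sums to 1
    return weights @ v, weights


bs, num_heads, seq_len, d_k = 2, 8, 5, 64  # d_k = d_model / h = 512 / 8
q = torch.randn(bs, num_heads, seq_len, d_k)
k = torch.randn(bs, num_heads, seq_len, d_k)
v = torch.randn(bs, num_heads, seq_len, d_k)

out, weights = attention(q, k, v)
print(out.shape)      # torch.Size([2, 8, 5, 64])
print(weights.shape)  # torch.Size([2, 8, 5, 5])
```

Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the head dimension, which would otherwise push the softmax into regions with very small gradients.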

### Multi-Headed Attention

Each head of attention is computed with the above scaled dot-product attention and then concatenated.

$$
\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \cdots, \text{head}_h) W^O
$$
$$
\text{where head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$
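
In practice the per-head projections are usually implemented as one large linear layer whose output is reshaped into heads. The sketch below shows only that reshaping (split and concatenation) on a random tensor, assuming the sizes used in this notebook; no learned weights are involved.

```python
import torch

torch.manual_seed(0)

bs, seq_len, d_model, num_heads = 2, 5, 512, 8
d_k = d_model // num_heads  # 64

x = torch.randn(bs, seq_len, d_model)

# Split: (bs, seq_len, d_model) -> (bs, seq_len, h, d_k) -> (bs, h, seq_len, d_k)
heads = x.view(bs, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 5, 64])

# Concat: (bs, h, seq_len, d_k) -> (bs, seq_len, h, d_k) -> (bs, seq_len, d_model)
# transpose() returns a non-contiguous view, so contiguous() is needed before view()
merged = heads.transpose(1, 2).contiguous().view(bs, seq_len, d_model)
print(torch.equal(merged, x))  # True
```

Head `i` therefore sees the slice `x[..., i * d_k : (i + 1) * d_k]` of every token vector, and concatenation restores the original layout.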


#### Instructions
* Split the sequence embeddings x across the multiple attention heads.
* Compute dot-product based attention scores between the projected query and key.
* Normalize the attention scores to obtain attention weights.
* Multiply the attention weights by the values and linearly transform the concatenated outputs per head.

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads  # aka. `h`

        assert d_model % num_heads == 0  # Ensure d_model is divisible by num_heads

        # TODO: Calculate the dimension of d_k, d_v
        
        self.d_k = ...

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def scaled_dot_product_attention(
        q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None, dropout: nn.Dropout | None
    ) -> tuple[Tensor, Tensor]:
        """Compute Scaled Dot Product Attention."""

        d_k = q.shape[-1]

        # TODO: Compute attention scores (not applying softmax) with
        # @ operation (matrix multiply) and .transpose()

        
        # (bs, num_heads, seq_len, d_k) -> (bs, num_heads, seq_len, seq_len)
        scores = ...

        if mask is not None:
            # For all values in mask == 0 replace with -inf
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Each row is a query, each column is a key. You want to convert raw scores over keys
        # into a probability distribution. In other words, you want each row / query to have
        # weights that sum to 1.
        
        # TODO: Apply softmax to last dim
        scores = ...  # (bs, num_heads, seq_len, seq_len)

        if dropout is not None:
            scores = dropout(scores)

        # TODO: Multiply by values    
        weights = ...  # (bs, num_heads, seq_len, d_k)

        # We return the scores for visualization
        return weights, scores

    def forward(self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None) -> Tensor:
        """Compute Multi-Headed Attention."""

        query = self.W_q(q)  # (bs, seq_len, d_model) -> (bs, seq_len, d_model)
        key = self.W_k(k)  # (bs, seq_len, d_model) -> (bs, seq_len, d_model)
        value = self.W_v(v)  # (bs, seq_len, d_model) -> (bs, seq_len, d_model)

        # (bs, seq_len, d_model) -> (bs, seq_len, num_heads, d_k)
        query = query.view(query.shape[0], query.shape[1], self.num_heads, self.d_k)
        # (bs, seq_len, num_heads, d_k) -> (bs, num_heads, seq_len, d_k)
        query = query.transpose(1, 2)

        key = key.view(key.shape[0], key.shape[1], self.num_heads, self.d_k)
        key = key.transpose(1, 2)

        value = value.view(value.shape[0], value.shape[1], self.num_heads, self.d_k)
        value = value.transpose(1, 2)

        
        # TODO: apply MultiHeadAttention.scaled_dot_product_attention
        weights, scores = ...

        ### Perform concatenation of the heads ###

        # (bs, num_heads, seq_len, d_k) -> (bs, seq_len, num_heads, d_k)
        weights = weights.transpose(1, 2)

        # (bs, seq_len, num_heads, d_k) -> (bs, seq_len, d_model)
        concat = weights.contiguous().view(
            weights.shape[0], weights.shape[1], self.d_model
        )

        # (bs, seq_len, d_model) -> (bs, seq_len, d_model)
        return self.W_o(concat)

## Feed-Forward Network

<div>
<img src="https://github.com/caglarmert/DI725/blob/main/src/feed_forward_highlighted.png?raw=true" width="300"/>
</div>

The Feed-Forward Network (FFN) in the Transformer architecture is applied independently to each token position after the Multi-Head Self-Attention Mechanism. It consists of two linear layers with a non-linearity:

$$
\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2
$$

The dimensionality of input and output is $d_{model} = 512$, and the inner-layer has dimensionality $d_{ff} = 2048$.
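
The following sketch expresses the formula with `nn.Sequential` and checks the "applied independently to each token position" property. It is a simplified stand-in for `FeedForwardBlock`: dropout is left out on purpose so the comparison is deterministic.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, d_ff = 512, 2048

ffn = nn.Sequential(
    nn.Linear(d_model, d_ff),  # x W_1 + b_1
    nn.ReLU(),
    nn.Linear(d_ff, d_model),  # (...) W_2 + b_2
)

x = torch.randn(2, 5, d_model)  # (bs, seq_len, d_model)
out = ffn(x)
print(out.shape)  # torch.Size([2, 5, 512])

# Position-wise: a token gives the same result alone as inside the sequence
single = ffn(x[:, 2:3, :])
print(torch.allclose(single, out[:, 2:3, :], atol=1e-4))  # True

# Parameter count: weights and biases of both linear layers
n_params = sum(p.numel() for p in ffn.parameters())
print(n_params)  # 2099712
```

The parameter count is $2 \cdot 512 \cdot 2048 + 2048 + 512$, which shows that most of the block's capacity sits in the two weight matrices.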


#### Instructions
* Specify the sizes of the two fully connected linear layers in the `__init__()` method.
* Apply a forward pass through the two linear layers, using the ReLU() activation in between.

In [5]:
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()

        # 'The dimensionality of input and output is d_model = 512, and the inner-layer
        # has dimensionality d_ff = 2048.'
        
        # TODO: Create the two linear transformations between
        self.linear1 = ...
        self.dropout = ...
        self.linear2 = ...

    def forward(self, x: Tensor) -> Tensor:
        """
        Applies linear transformations with ReLU activation function between.
            1. (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)
            2. (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        
        # TODO: Create forward pass. Apply dropout after ReLU.
        x = ...
        x = ...
        x = ...
        x = ...

        return x

## Intermediate Layers

### Layer Normalization

LayerNorm operates independently on each sample within a batch, unlike BatchNorm, which normalizes across the batch dimension. It normalizes the inputs across the feature dimension.

**Purpose:** Mitigates internal covariate shift, thereby improving the training speed, stability, and convergence of the model. It can also improve generalization.
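
As a small sketch of what "normalizing across the feature dimension" means, the snippet below normalizes a random tensor by hand. The learnable scale and shift are omitted, and the input statistics are assumptions chosen so that the effect is visible.

```python
import torch

torch.manual_seed(0)

x = torch.randn(2, 3, 8) * 5 + 10  # (bs, seq_len, d_model), deliberately off-centre
eps = 1e-6

# Statistics over the feature (last) dimension only; keepdim allows broadcasting
mean = x.mean(dim=-1, keepdim=True)  # (bs, seq_len, 1)
std = x.std(dim=-1, keepdim=True)    # (bs, seq_len, 1)

normed = (x - mean) / (std + eps)

# Every token vector now has mean ~0 and standard deviation ~1
print(torch.allclose(normed.mean(dim=-1), torch.zeros(2, 3), atol=1e-5))  # True
print(torch.allclose(normed.std(dim=-1), torch.ones(2, 3), atol=1e-4))    # True
```

Note that `torch.std` applies Bessel's correction by default, whereas the built-in `nn.LayerNorm` uses the biased variance and adds `eps` inside the square root, so the two can differ slightly.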

In [6]:
class LayerNormalization(nn.Module):
    def __init__(self, eps: float = 1e-6):
        """
        LayerNorm operates independently on each sample within a batch, unlike
        BatchNorm, which normalizes across the batch dimension. It normalizes the
        inputs across the feature dimension.

        Purpose: Mitigate internal covariate shift thus improving training speed,
        stability, and convergence of the model. Also, improves generalization.

        Args:
            eps (float, optional): Epsilon value to avoid division by zero.
                Defaults to 1e-6.
        """
        super().__init__()
        self.eps = eps

        # Two learnable parameters
        self.alpha = nn.Parameter(torch.ones(1))  # Scale parameter (Multiplicative)
        self.bias = nn.Parameter(torch.zeros(1))  # Shift parameter (Additive)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply layer norm to last dimension of the input tensor.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        # TODO: Compute mean & std over the last dimension
        mean = ...  # (bs, seq_len, 1)
        std = ... # (bs, seq_len, 1)

        return self.alpha * (x - mean) / (std + self.eps) + self.bias


### Residual Connections

The paper defines the residual connection implementation as
$$
\text{LayerNorm}(x + \text{Sublayer}(x))
$$

However, we will follow [The Annotated Transformer's](https://nlp.seas.harvard.edu/2018/04/03/attention.html) implementation, which normalizes the input of each sub-layer, applies dropout to the sub-layer's output, and then adds the result to the original input: $x + \text{Dropout}(\text{Sublayer}(\text{LayerNorm}(x)))$.
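
A minimal sketch of this norm-first ordering is shown below. It is not the notebook's `ResidualConnection` class: `nn.LayerNorm` stands in for `LayerNormalization`, a single `nn.Linear` is a placeholder sub-layer, and the function name `residual` is an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model = 16
norm = nn.LayerNorm(d_model)            # stand-in for the notebook's LayerNormalization
dropout = nn.Dropout(p=0.1)
sublayer = nn.Linear(d_model, d_model)  # placeholder for attention or feed-forward


def residual(x):
    # Normalize first, run the sub-layer, apply dropout, then add the input back
    return x + dropout(sublayer(norm(x)))


dropout.eval()  # disable dropout so the result is deterministic
x = torch.randn(2, 5, d_model)
out = residual(x)
print(out.shape)  # torch.Size([2, 5, 16])
```

Because the input `x` is added back unchanged, gradients can flow through the addition even when the sub-layer's contribution is small.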


In [7]:
class ResidualConnection(nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization()

    def forward(self, x: Tensor, sublayer: nn.Module) -> Tensor:
        """
        Residual connection with layer normalization.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.
            sublayer (nn.Module): The intermediate layer to wrap w/ residual connection.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        # TODO: Apply dropout to sublayer 
        
        return ...

### Linear Layer

This layer is a projection from $d_{model}$ into log probabilities across the entire vocab.
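
One possible way to obtain log probabilities is a linear projection followed by `torch.log_softmax` over the vocabulary dimension, sketched below with toy sizes that are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

d_model, vocab_size = 16, 10  # toy sizes, chosen for illustration only
proj = nn.Linear(d_model, vocab_size)

x = torch.randn(2, 5, d_model)                  # (bs, seq_len, d_model)
log_probs = torch.log_softmax(proj(x), dim=-1)  # (bs, seq_len, vocab_size)

print(log_probs.shape)  # torch.Size([2, 5, 10])

# Exponentiating recovers a probability distribution over the vocabulary
sums = log_probs.exp().sum(dim=-1)
print(torch.allclose(sums, torch.ones(2, 5), atol=1e-6))  # True
```

Every entry of `log_probs` is non-positive, and the most likely token at each position is the one with the largest value.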

In [8]:
class LinearLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int):
        """
        Linear Layer is a projection layer that converts the embedding into the
        vocabulary.

        Args:
            d_model (int): The size of the model's hidden dimension.
            vocab_size (int): The size of the vocabulary.
        """
        super().__init__()
        
        # TODO: Create a linear layer of size (d_model, vocab_size)
        self.linear = ...

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply projection on embeddings.
        Output will be a log probability distribution over the vocabulary.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.

        Returns:
            Tensor: `(bs, seq_len, vocab_size)`.
        """

        # (bs, seq_len, d_model) -> (bs, seq_len, vocab_size)
        # TODO: return log probabilities
        return ...

## Encoder-Decoder Structure
We can finally put everything together.

<img src="transformer.png" width="400">

### Encoder

In [9]:
### ENCODER ###
class EncoderBlock(nn.Module):
    def __init__(
        self,
        self_attention_block: MultiHeadAttention,
        feed_forward_block: FeedForwardBlock,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(dropout) for _ in range(2)]
        )

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """
        Forward pass through the encoder block.

        Args:
            x (Tensor): `(bs, seq_len, d_model)`.
            src_mask (Tensor): The mask for the source language `(bs, 1, 1, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        # TODO: build self attention blocks and residual connections
        x = self.residual_connections[0](
            x, lambda x: ...
        )
        x = ...
        return x


class Encoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x: Tensor, src_mask: Tensor) -> Tensor:
        """
        Forward pass through the encoder.

        Args:
            x (Tensor): The input to the encoder.
            src_mask (Tensor): The mask for the source language.

        Returns:
            Tensor: A tensor of `(batch_size, seq_len, d_model)` represents a sequence
                of context-rich embeddings that encode the input sequence's semantic and
                positional information.
        """
        for layer in self.layers:
            x = layer(x, src_mask)

        # TODO: Apply a final layer normalization after all encoder blocks
        return ...

### Decoder

<div>
<img src="https://github.com/caglarmert/DI725/blob/main/src/Cross_attention.png?raw=true" width="300"/>
</div>

For **Cross-Attention**, queries come from the decoder, while keys and values come from the encoder’s outputs.
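
The shape consequences of this can be sketched with a single head and no learned projections. The lengths below are assumptions; they are deliberately different so that the source and target axes are easy to tell apart.

```python
import math

import torch

torch.manual_seed(0)

bs, src_len, tgt_len, d_model = 2, 7, 4, 32  # toy sizes, chosen for illustration only

encoder_output = torch.randn(bs, src_len, d_model)  # source of keys and values
decoder_state = torch.randn(bs, tgt_len, d_model)   # source of queries

# Single head, no projections, to keep the shapes easy to follow
scores = decoder_state @ encoder_output.transpose(-2, -1) / math.sqrt(d_model)
weights = scores.softmax(dim=-1)    # (bs, tgt_len, src_len)
context = weights @ encoder_output  # (bs, tgt_len, d_model)

print(weights.shape)  # torch.Size([2, 4, 7])
print(context.shape)  # torch.Size([2, 4, 32])
```

Each target position receives a weighted mix of encoder outputs, so the result keeps the target length even though it is built entirely from source-side values.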

In [10]:
### DECODER ###
class DecoderBlock(nn.Module):
    def __init__(
        self,
        self_attention_block: MultiHeadAttention,
        cross_attention_block: MultiHeadAttention,
        feed_forward_block: FeedForwardBlock,
        dropout: float = 0.1,
    ):
        """
        Decoder block contains:
            1. (Masked Multi-Head Attention) A self-attention block where `qkv` come
                from decoder's input embedding.
            2. (Multi-Head Attention) A cross-attention block where `q` come from
                decoder and `k`,`v` come from encoder outputs.
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(dropout) for _ in range(3)]
        )

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tensor:
        """
        Forward pass through the decoder block.
        Decoder block used for machine-translation to go from source to target lang.

        Args:
            x (Tensor): The decoder input `(bs, seq_len, d_model)`.
            encoder_output (Tensor): `(bs, seq_len, d_model)`.
            src_mask (Tensor): `(bs, 1, 1, seq_len)`.
            tgt_mask (Tensor): `(bs, 1, seq_len, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        x = self.residual_connections[0](
            x, lambda x: self.self_attention_block(x, x, x, tgt_mask)
        )

        # TODO: use encoder output here in cross-attention block
        x = self.residual_connections[1](
            x,
            lambda x: self.cross_attention_block(
                ...
            ),
        )
        # TODO: finish feed forward network and residual connections
        x = ...
        return x


class Decoder(nn.Module):
    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(
        self,
        x: Tensor,
        encoder_output: Tensor,
        src_mask: Tensor,
        tgt_mask: Tensor,
    ) -> Tensor:
        """
        Forward pass through the decoder.

        Args:
            x (Tensor): The input to the decoder block.
            encoder_output (Tensor): The output from the encoder.
            src_mask (Tensor): The mask used for the source language (e.g. English).
            tgt_mask (Tensor): The mask used for the target language (e.g. Turkish).

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        return self.norm(x)

### Causal Mask
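
A causal mask lets each position attend only to itself and to earlier positions. The sketch below shows one way to build such a mask with `torch.triu` and how it combines with a padding mask through broadcasting; the size and the padding pattern are assumptions for illustration.

```python
import torch

size = 4

# The upper triangle above the main diagonal marks the "future" positions
future = torch.triu(torch.ones(1, size, size), diagonal=1)
causal = future == 0  # True where attention is allowed
print(causal[0].int())
# rows: [1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]

# Padding mask for a sequence whose last token is padding: (1, size)
not_pad = torch.tensor([[True, True, True, False]])

# (1, size) & (1, size, size) broadcasts to (1, size, size)
combined = not_pad & causal
print(combined[0].int())
# rows: [1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 0]
```

Row `r` of the mask describes which key positions query `r` may attend to, so the padding column is blanked out in every row.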

In [11]:
def create_causal_mask(size: int) -> Tensor:
    """
    Causal mask used only in decoder to ensure that future is masked.
    https://discuss.huggingface.co/t/difference-between-attention-mask-and-causal-mask/104922
    """

    # Diagonal=1 to get a mask that does not include the main diagonal and only the
    # upper triangular part of the matrix excluding the main diagonal
    ones_matrix = torch.ones(1, size, size)
    # TODO: create a causal mask
    mask = ...

    # The above returns the mask for the upper diagonal which we DON'T want to include
    # in the causal mask. We want it False. Therefore, we use `mask == 0`.
    return mask == 0

## Transformer

In [12]:
### TRANSFORMER ###
class Transformer(nn.Module):
    def __init__(
        self,
        encoder: Encoder,
        decoder: Decoder,
        src_embed: InputEmbeddings,
        tgt_embed: InputEmbeddings,
        src_pos: PositionalEncoding,
        tgt_pos: PositionalEncoding,
        projection_layer: LinearLayer,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """Forward pass through the encoder with input tokens of type int64.

        Args:
            src (Tensor): `(bs, seq_len)`.
            src_mask (Tensor): `(bs, 1, 1, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """

        # Embedding maps token ids to dense vectors of type float32
        # TODO: apply positional embedding and encoder layers to encoder input
        src = ...  # (bs, seq_len) -> (bs, seq_len, d_model)
        src = ...
        return ...

    def decode(
        self, encoder_output: Tensor, src_mask: Tensor, tgt: Tensor, tgt_mask: Tensor
    ) -> Tensor:
        """
        Forward pass through the decoder.
        - Encoder output is used in the cross-attention block and is of type float32.
        - Target tokens are still of type int64 and need to be embedded with input
        embeddings + positional encoding.

        Args:
            encoder_output (Tensor): `(bs, seq_len, d_model)`.
            src_mask (Tensor): `(bs, 1, 1, seq_len)`.
            tgt (Tensor): `(bs, seq_len)`.
            tgt_mask (Tensor): `(bs, 1, seq_len, seq_len)`.

        Returns:
            Tensor: `(bs, seq_len, d_model)`.
        """
        # TODO: Apply positional embedding and decoder layers to decoder input
        tgt = ...  # (bs, seq_len) -> (bs, seq_len, d_model)
        tgt = ...
        return ...

    def project(self, x: Tensor) -> Tensor:
        """
        Project the output of the decoder to the target vocabulary size.

        Args:
            x (Tensor): The output of the decoder `(bs, seq_len, d_model)`.

        Returns:
            Tensor: `(bs, seq_len, vocab_size)`.
        """

        return self.projection_layer(x)


def build_transformer(
    src_vocab_size: int,
    tgt_vocab_size: int,
    src_seq_len: int,
    tgt_seq_len: int,
    d_model: int = 512,  # hidden dimension of the model
    num_blocks: int = 6,  # number of encoder and decoder blocks
    num_heads: int = 8,  # number of attention heads
    d_ff: int = 2048,  # size of the feed-forward layer
    dropout: float = 0.1,
) -> Transformer:
    """Build and return Transformer."""

    # TODO: Create embedding layers
    src_embed = ...
    tgt_embed = ...

    # TODO: Create positional encoding layers
    src_pos = ...
    tgt_pos = ...

    # TODO: Create encoder blocks
    encoder_layers = nn.ModuleList(
        [
            EncoderBlock(
                ...
                ...
                ...
            )
            for _ in range(num_blocks)
        ]
    )

    # TODO: Create decoder blocks
    decoder_layers = nn.ModuleList(
        [
            DecoderBlock(
                ...
                ...
                ...
                ...
            )
            for _ in range(num_blocks)
        ]
    )

    # TODO: Create encoder and decoder
    encoder = ...
    decoder = ...

    # TODO: Create projection layer
    projection_layer = ...

    # TODO: Create transformer
    transformer = Transformer(
        ...
    )

    # Initialize parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer

# Training

## Training Config

In [13]:
@dataclass
class TransformerConfig:
    """
    Transformer training configuration
    """

    batch_size: int = 16
    num_epochs: int = 10
    lr: float = 1e-4
    seq_len: int = 350  # Max sequence length
    d_model: int = 512
    lang_src: str = "en"  # Source language: English
    lang_tgt: str = "it"  # Target language: Italian
    model_folder: str = "weights"
    model_filename: str = "transformer_"  # Base filename for saved weights
    load_from: Optional[str] = None  # Load from this epoch (e.g.)
    tokenizer_file: str = "tokenizer_0.json"
    experiment_name: str = "runs/transformer"

    def get_weights_file_path(self, epoch_name: str) -> str:
        """
        Get the file path for the model weights corresponding to a given epoch.

        Parameters:
            epoch_name (str): The epoch name to get weights file path for.

        Returns:
            str: The complete file path for the weights file.
        """
        return str(
            Path(".") / self.model_folder / f"{self.model_filename}{epoch_name}.pt"
        )


def get_config(overrides: Optional[Dict[str, Any]] = None) -> TransformerConfig:
    """
    Retrieve the default configuration for the Transformer model training.
    Optionally override configuration values by passing in a dictionary.

    Parameters:
        overrides (Optional[Dict[str, Any]]): Dictionary with keys corresponding to the
            TransformerConfig fields that should be overridden.

    Returns:
        TransformerConfig: An instance of TransformerConfig with the default values,
            updated by any overrides.
    """
    config = TransformerConfig()
    if overrides:
        for key, value in overrides.items():
            if hasattr(config, key):
                setattr(config, key, value)
            else:
                raise KeyError(f"Invalid config key: {key}")
    return config



# Test get config
config = get_config({"batch_size": 32, "num_epochs": 5})
epoch = "epoch_1"
weights_path = config.get_weights_file_path(epoch)
print(f"Weights file path for epoch {epoch}: {weights_path}")

Weights file path for epoch epoch_1: weights\transformer_epoch_1.pt


## Dataset
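
The dataset class in this section pads every sequence to `seq_len` and shifts the decoder input and the label by one token relative to each other. The pure-Python sketch below walks through that arithmetic; the special-token ids and the token ids are hypothetical values chosen for illustration.

```python
SOS, EOS, PAD = 1, 2, 0  # hypothetical special-token ids
seq_len = 8

src_ids = [11, 12, 13]  # hypothetical token ids of a source sentence
tgt_ids = [21, 22]      # hypothetical token ids of its translation

# Encoder input carries both [SOS] and [EOS], hence "- 2"
enc_pad = seq_len - len(src_ids) - 2
encoder_input = [SOS] + src_ids + [EOS] + [PAD] * enc_pad

# Decoder input carries only [SOS] and the label only [EOS], hence "- 1"
dec_pad = seq_len - len(tgt_ids) - 1
decoder_input = [SOS] + tgt_ids + [PAD] * dec_pad
label = tgt_ids + [EOS] + [PAD] * dec_pad

print(encoder_input)  # [1, 11, 12, 13, 2, 0, 0, 0]
print(decoder_input)  # [1, 21, 22, 0, 0, 0, 0, 0]
print(label)          # [21, 22, 2, 0, 0, 0, 0, 0]
```

At every position the label holds the token that should follow the decoder input at that position, which is what makes teacher-forced training possible.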

In [14]:
class BilingualDataset(Dataset):
    def __init__(
        self,
        dataset: Subset,
        tokenizer_src: Tokenizer,
        tokenizer_tgt: Tokenizer,
        lang_src: str,
        lang_tgt: str,
        seq_len: int,
    ):
        self.dataset = dataset
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.lang_src = lang_src
        self.lang_tgt = lang_tgt
        self.seq_len = seq_len

        # Vocab can be longer than 32 bits so we use int64
        self.sos_token = torch.tensor(
            [tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64
        )
        self.eos_token = torch.tensor(
            [tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64
        )
        self.pad_token = torch.tensor(
            [tokenizer_src.token_to_id("[PAD]")], dtype=torch.int64
        )

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: Any) -> Dict[str, Any]:
        # Get the source and target text pair
        src_tgt_pair = self.dataset[index]

        # Extract the individual source and target text
        src_text = src_tgt_pair["translation"][self.lang_src]
        tgt_text = src_tgt_pair["translation"][self.lang_tgt]

        # Convert tokens to ids
        enc_input = self.tokenizer_src.encode(src_text).ids
        dec_input = self.tokenizer_tgt.encode(tgt_text).ids

        # -2 to account for the start and end tokens
        enc_num_padding_tokens = self.seq_len - len(enc_input) - 2

        # -1 to account for the end token
        dec_num_padding_tokens = self.seq_len - len(dec_input) - 1

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long.")

        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input, dtype=torch.int64),
                self.eos_token,
                self.pad_token.repeat(enc_num_padding_tokens),  # Pad to reach seq_len
            ]
        )

        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input, dtype=torch.int64),
                self.pad_token.repeat(dec_num_padding_tokens),  # Pad to reach seq_len
            ]
        )

        label = torch.cat(
            [
                torch.tensor(dec_input, dtype=torch.int64),
                self.eos_token,
                self.pad_token.repeat(dec_num_padding_tokens),  # Pad to reach seq_len
            ]
        )

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        encoder_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()

        # Causal mask to prevent attending to future positions
        causal_mask = create_causal_mask(decoder_input.size(0))
        decoder_mask = (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
        decoder_mask = decoder_mask & causal_mask

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": encoder_mask,  # (1, 1, seq_len)
            "decoder_mask": decoder_mask,  # (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }
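
The padding and masking steps in `__getitem__` can be checked in isolation on a toy sentence pair. The following is a minimal sketch, not the notebook's implementation: the special-token ids (`SOS = 2`, `EOS = 3`, `PAD = 1`), the sequence length, the toy token ids, and the helper names `toy_causal_mask` and `build_example` are all assumptions made for illustration. In particular, `toy_causal_mask` only stands in for `create_causal_mask`, which is defined earlier in the notebook and may differ in detail.

```python
import torch

# Hypothetical special-token ids and sequence length, chosen for illustration only
SOS, EOS, PAD = 2, 3, 1
SEQ_LEN = 8


def toy_causal_mask(size: int) -> torch.Tensor:
    """Stand-in for create_causal_mask: 1 where attention is allowed (lower triangle)."""
    return torch.tril(torch.ones(1, size, size, dtype=torch.int))


def build_example(src_ids: list[int], tgt_ids: list[int]) -> dict[str, torch.Tensor]:
    """Pad one sentence pair to SEQ_LEN and build both attention masks."""
    enc_pad = SEQ_LEN - len(src_ids) - 2  # room for [SOS] and [EOS]
    dec_pad = SEQ_LEN - len(tgt_ids) - 1  # room for [SOS] (input) or [EOS] (label)
    if enc_pad < 0 or dec_pad < 0:
        raise ValueError("Sentence is too long.")

    encoder_input = torch.tensor([SOS, *src_ids, EOS] + [PAD] * enc_pad, dtype=torch.int64)
    decoder_input = torch.tensor([SOS, *tgt_ids] + [PAD] * dec_pad, dtype=torch.int64)
    label = torch.tensor([*tgt_ids, EOS] + [PAD] * dec_pad, dtype=torch.int64)

    # Padding masks have shape (1, 1, seq_len); the causal mask is (1, seq_len, seq_len)
    encoder_mask = (encoder_input != PAD).unsqueeze(0).unsqueeze(0).int()
    padding_mask = (decoder_input != PAD).unsqueeze(0).unsqueeze(0).int()
    decoder_mask = padding_mask & toy_causal_mask(SEQ_LEN)  # broadcasts to (1, seq_len, seq_len)

    return {
        "encoder_input": encoder_input,
        "decoder_input": decoder_input,
        "label": label,
        "encoder_mask": encoder_mask,
        "decoder_mask": decoder_mask,
    }


example = build_example(src_ids=[10, 11, 12], tgt_ids=[20, 21])
for name, tensor in example.items():
    print(name, tuple(tensor.shape))
print(example["decoder_mask"][0])
```

Note that `label` is `decoder_input` shifted left by one position: at every step the decoder sees the tokens up to position `i` and is trained to predict the token at position `i + 1`. Each row of the printed decoder mask allows attention only to earlier, non-padding positions.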

## Trainer

In [15]:
class Trainer:
    """Transformer model training and validation"""

    def __init__(self, config: TransformerConfig) -> None:
        self.config = config

        # Define the device
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")

        print(f"Using device: {self.device}")

        # Create model folder
        Path(self.config.model_folder).mkdir(parents=True, exist_ok=True)

        train_dl, val_dl, tokenizer_src, tokenizer_tgt = self._get_dataset()
        self.model = self._get_model(
            tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()
        ).to(self.device)

        self.train_dl, self.val_dl = train_dl, val_dl
        self.tokenizer_src, self.tokenizer_tgt = tokenizer_src, tokenizer_tgt

        # 1. Ignore the padding ([PAD]) tokens in the labels
        # 2. Apply label smoothing: move 10% of the target probability mass
        #    from the correct token to a uniform distribution over the vocabulary
        self.loss_fn = nn.CrossEntropyLoss(
            ignore_index=self.tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1
        ).to(self.device)

        self.writer = SummaryWriter(self.config.experiment_name)  # Tensorboard

        # Create adam optimizer
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=self.config.lr, eps=1e-9
        )

        # Initial training state; overwritten below if a checkpoint is loaded
        self.global_step = 0
        self.initial_epoch = 0
        self.max_len = self.config.seq_len

        # Load existing model if it is specified / exists
        self._load_existing_model()

    # Use word level tokenization
    def _get_all_sentences(self, dataset, lang) -> Iterator[str]:
        """Get all sentences in a given language from the dataset as a generator."""

        for item in dataset:
            yield item["translation"][lang]

    def _build_tokenizer(self, dataset, lang) -> Tokenizer:
        """Build / load a word-level tokenizer."""

        tokenizer_path = Path(self.config.tokenizer_file.format(lang))

        if not tokenizer_path.exists():
            tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
            tokenizer.pre_tokenizer = Whitespace()

            # Train tokenizer with special tokens:
            # Unknown, padding, start of sentence, end of sentence
            trainer = WordLevelTrainer(
                special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2
            )

            tokenizer.train_from_iterator(
                self._get_all_sentences(dataset, lang), trainer=trainer
            )

            # Save the trained tokenizer so that later runs can load it from file
            tokenizer.save(str(tokenizer_path))
        else:
            tokenizer = Tokenizer.from_file(str(tokenizer_path))

        return tokenizer

    def _get_dataset(self) -> tuple[DataLoader, DataLoader, Tokenizer, Tokenizer]:
        """Load the dataset to create dataloaders and tokenizers."""

        ds_raw = load_dataset(
            "opus_books",
            f"{self.config.lang_src}-{self.config.lang_tgt}",
            split="train",
        )

        # Build tokenizers
        tokenizer_src = self._build_tokenizer(ds_raw, self.config.lang_src)
        tokenizer_tgt = self._build_tokenizer(ds_raw, self.config.lang_tgt)

        # Keep 90% for training and 10% for validation
        train_size = int(0.9 * len(ds_raw))
        val_size = len(ds_raw) - train_size
        train_ds_raw, val_ds_raw = random_split(ds_raw, [train_size, val_size])

        # Create datasets
        train_dataset = BilingualDataset(
            train_ds_raw,
            tokenizer_src,
            tokenizer_tgt,
            self.config.lang_src,
            self.config.lang_tgt,
            self.config.seq_len,
        )
        val_dataset = BilingualDataset(
            val_ds_raw,
            tokenizer_src,
            tokenizer_tgt,
            self.config.lang_src,
            self.config.lang_tgt,
            self.config.seq_len,
        )

        # Find the longest sentence (in tokens) in the src and tgt languages
        max_len_src = 0
        max_len_tgt = 0

        for item in ds_raw:
            src_text = item["translation"][self.config.lang_src]
            tgt_text = item["translation"][self.config.lang_tgt]

            src_ids = tokenizer_src.encode(src_text).ids
            tgt_ids = tokenizer_tgt.encode(tgt_text).ids
            max_len_src = max(max_len_src, len(src_ids))
            max_len_tgt = max(max_len_tgt, len(tgt_ids))

        print(f"Max length of source sentence: {max_len_src}")
        print(f"Max length of target sentence: {max_len_tgt}")

        train_dataloader = DataLoader(
            train_dataset, batch_size=self.config.batch_size, shuffle=True
        )
        val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=True)

        return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

    def _get_model(self, vocab_src_len, vocab_tgt_len) -> Transformer:
        """Get the transformer model."""

        model = build_transformer(
            vocab_src_len,
            vocab_tgt_len,
            self.config.seq_len,
            self.config.seq_len,
            self.config.d_model,
        )
        return model

    def _load_existing_model(self) -> None:
        """Load an existing model from a given epoch."""

        if self.config.load_from:
            epoch_name = self.config.load_from
            model_filename = self.config.get_weights_file_path(epoch_name)
            print(f"Preloading model {model_filename}")
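            # Map tensors to the current device so that a checkpoint saved on
            # GPU can also be loaded on a CPU-only machine
            state = torch.load(model_filename, map_location=self.device)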

            self.initial_epoch = state["epoch"] + 1
            self.model.load_state_dict(state["model_state_dict"])
            self.optimizer.load_state_dict(state["optimizer_state_dict"])
            self.global_step = state["global_step"]

    ### TRAINING CODE ###
    def train(self) -> None:
        """Train the transformer model."""

        for epoch in range(self.initial_epoch, self.config.num_epochs):
            batch_iterator = tqdm(self.train_dl, desc=f"Processing epoch {epoch:02d}")

            for batch in batch_iterator:
                self.model.train()

                encoder_input = batch["encoder_input"].to(self.device)  # (bs, seq_len)
                decoder_input = batch["decoder_input"].to(self.device)  # (bs, seq_len)

                # Attention mask (hide padding tokens)
                encoder_mask = batch["encoder_mask"].to(
                    self.device
                )  # (bs, 1, 1, seq_len)

                # Causal mask (hide padding tokens and future positions)
                decoder_mask = batch["decoder_mask"].to(
                    self.device
                )  # (bs, 1, seq_len, seq_len)

                # Input passthrough
                encoder_output = self.model.encode(
                    encoder_input, encoder_mask
                )  # (bs, seq_len, d_model)
                decoder_output = self.model.decode(
                    encoder_output, encoder_mask, decoder_input, decoder_mask
                )  # (bs, seq_len, d_model)
                proj_output = self.model.project(
                    decoder_output
                )  # (bs, seq_len, tgt_vocab_size)

                label = batch["label"].to(self.device)  # (bs, seq_len)

                # (bs, seq_len, tgt_vocab_size) -> (bs * seq_len, tgt_vocab_size)
                pred = proj_output.view(-1, self.tokenizer_tgt.get_vocab_size())
                gt = label.view(-1)  # Ground truth
                loss = self.loss_fn(pred, gt)
                batch_iterator.set_postfix({"Loss": f"{loss.item():6.3f}"})

                # Log the loss in tensorboard
                self.writer.add_scalar("Train Loss", loss.item(), self.global_step)
                self.writer.flush()

                # Backpropagation
                loss.backward()

                # Update weights
                self.optimizer.step()
                self.optimizer.zero_grad()

                self.global_step += 1

            model_filename = self.config.get_weights_file_path(f"{epoch:02d}")
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": self.model.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                    "global_step": self.global_step,
                },
                model_filename,
            )

            # Run validation at every 5th epoch
            if (epoch + 1) % 5 == 0:
                self.validate(lambda msg: batch_iterator.write(msg), self.writer)

    ### VALIDATION CODE ###
    def greedy_decode(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """
        Greedy decode for efficient validation.
        Highest probability token is selected at each step as the next word.
        """

        sos_idx = self.tokenizer_tgt.token_to_id("[SOS]")
        eos_idx = self.tokenizer_tgt.token_to_id("[EOS]")

        # Precompute the encoder output and reuse it for every token we get from decoder
        encoder_output = self.model.encode(src, src_mask)

        # Initialize the decoder input with the SOS token
        decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(src).to(self.device)

        # Keep predicting until we reach EOS or max_len
        while True:
            if decoder_input.size(1) == self.config.seq_len:
                break

            # Build the mask
            decoder_mask = create_causal_mask(decoder_input.size(1))
            decoder_mask = decoder_mask.type_as(src_mask).to(self.device)

            # Calculate the output of the decoder
            out = self.model.decode(
                encoder_output, src_mask, decoder_input, decoder_mask
            )

            # Score the next token from the decoder output at the last position
            prob = self.model.project(out[:, -1])

            # Select the token with the highest probability (because it is greedy search)
            _, next_word = torch.max(prob, dim=1)

            decoder_input = torch.cat(
                [
                    decoder_input,
                    torch.empty(1, 1)
                    .type_as(src)
                    .fill_(next_word.item())
                    .to(self.device),
                ],
                dim=1,
            )

            if next_word == eos_idx:
                break

        return decoder_input.squeeze(0)  # remove batch dimension

    def validate(self, print_msg: Callable, writer: SummaryWriter) -> None:
        """Run transformer model on the validation dataset."""

        self.model.eval()  # Put in eval mode
        count = 0

        source_texts = []
        expected = []
        predicted = []

        CONSOLE_WIDTH = 80

        with torch.no_grad():
            for batch in self.val_dl:
                # Stop validation after 5 examples
                if count >= 5:
                    break

                count += 1

                # (bs, seq_len)
                encoder_input = batch["encoder_input"].to(self.device)

                # (bs, 1, 1, seq_len)
                encoder_mask = batch["encoder_mask"].to(self.device)

                assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

                # Get generation
                model_out = self.greedy_decode(encoder_input, encoder_mask)

                source_text = batch["src_text"][0]
                target_text = batch["tgt_text"][0]

                # Detach model_out from computational graph
                model_out_tensor = model_out.detach().cpu()
                model_out_array = model_out_tensor.numpy()
                model_out_text = self.tokenizer_tgt.decode(model_out_array)

                source_texts.append(source_text)
                expected.append(target_text)
                predicted.append(model_out_text)

                # Print to the console
                print_msg("-" * CONSOLE_WIDTH)
                print_msg(f"SOURCE: {source_text}")
                print_msg(f"TARGET: {target_text}")
                print_msg(f"PREDICTED: {model_out_text}")

        if writer:
            # Compute the char error rate
            metric = torchmetrics.CharErrorRate()
            cer = metric(predicted, expected)
            writer.add_scalar("Validation CER", cer, self.global_step)
            writer.flush()

            # Compute the word error rate
            metric = torchmetrics.WordErrorRate()
            wer = metric(predicted, expected)
            writer.add_scalar("Validation WER", wer, self.global_step)
            writer.flush()

            # Compute the BLEU metric
            metric = torchmetrics.BLEUScore()
            bleu = metric(predicted, expected)
            writer.add_scalar("Validation BLEU", bleu, self.global_step)
            writer.flush()
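
The two comments on the loss function in `Trainer.__init__` (ignoring `[PAD]` labels and applying label smoothing) can be illustrated on toy logits. This is a minimal sketch under stated assumptions: the five-token vocabulary, `PAD_ID = 1`, and the logit values are hypothetical and are not taken from the notebook's tokenizers or model. It checks that a padded position does not change the loss, and that the smoothed loss matches the usual formulation in which a fraction `EPS` of the target distribution is spread uniformly over the vocabulary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

PAD_ID = 1      # hypothetical id of the [PAD] token
VOCAB_SIZE = 5  # hypothetical vocabulary size
EPS = 0.1       # label smoothing factor, as used in the Trainer

# Three positions; the model is confident about one token at each position
logits = torch.zeros(3, VOCAB_SIZE)
logits[0, 2] = 5.0
logits[1, 4] = 5.0
logits[2, 0] = 5.0
labels = torch.tensor([2, 4, PAD_ID])  # the last position is padding

plain = nn.CrossEntropyLoss(ignore_index=PAD_ID)
smoothed = nn.CrossEntropyLoss(ignore_index=PAD_ID, label_smoothing=EPS)

# 1. Padding is ignored: dropping the padded position leaves the loss unchanged
loss_plain = plain(logits, labels)
loss_without_pad = plain(logits[:2], labels[:2])

# 2. Label smoothing on the two real positions, compared with a manual computation:
#    (1 - EPS) * NLL of the target + EPS * mean negative log-probability over all classes
loss_smoothed = smoothed(logits[:2], labels[:2])
log_probs = F.log_softmax(logits[:2], dim=-1)
nll = -log_probs.gather(1, labels[:2].unsqueeze(1)).squeeze(1)
uniform_term = -log_probs.mean(dim=-1)
manual_smoothed = ((1 - EPS) * nll + EPS * uniform_term).mean()

print(f"plain loss (with padding row):    {loss_plain.item():.4f}")
print(f"plain loss (without padding row): {loss_without_pad.item():.4f}")
print(f"smoothed loss (PyTorch):          {loss_smoothed.item():.4f}")
print(f"smoothed loss (manual):           {manual_smoothed.item():.4f}")
```

Because the toy model is already confident in the correct tokens, the smoothed loss comes out larger than the plain loss: label smoothing penalizes overconfident predictions.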

## Conclusion

In [16]:
config = get_config()
trainer = Trainer(config)
#trainer.train()

Using device: cpu
Max length of source sentence: 309
Max length of source sentence: 274


In this notebook, we have observed and practiced how to build a transformer from scratch, including the dataset preparation, training, and validation code around it.

## Resources
References and further tutorials to check out:
1. This notebook is built on top of [Transformers From Scratch](https://github.com/aandyw/TransformerFromScratch).
1. Attention Is All You Need, the paper that introduced the Transformer architecture: https://arxiv.org/abs/1706.03762
1. A solid blog post that explains the Transformer architecture and its subcomponents: https://buomsoo-kim.github.io/attention/2020/04/19/Attention-mechanism-17.md/
1. Training a Transformer model with a real dataset: https://buomsoo-kim.github.io/attention/2020/04/20/Attention-mechanism-18.md/
1. Further explanations that build on the previous posts: https://buomsoo-kim.github.io/attention/2020/04/21/Attention-mechanism-19.md/
1. Another Transformer tutorial: https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
1. The Annotated Transformer, updated for newer versions of PyTorch, an excellent guide that follows the original authors' text and implements everything (just like we did in this tutorial): https://nlp.seas.harvard.edu/annotated-transformer/
1. PyTorch's documentation for Transformers and other helper libraries: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
1. The Dive into Deep Learning tutorial, which uses its own version of the libraries and offers a different take on the Transformer model implementation: https://d2l.ai/chapter_attention-mechanisms-and-transformers/self-attention-and-positional-encoding.html