# Transformer Model From Scratch

**References**
- *Coding a Transformer from scratch on PyTorch, with full explanation, training and inference: [Youtube Video](https://youtu.be/ISNdQcPhsts?si=M80AnG5chc6sKzk5), [Code](https://github.com/hkproj/pytorch-transformer)*
- *Attention in transformers, visually explained | Chapter 6, Deep Learning: [Youtube Video](https://youtu.be/eMlx5fFNoYc?si=_GGTHKDsoqi9KuVE)*
- *Illustrated Guide to Transformers Neural Network: A step by step explanation: [Youtube Video](https://youtu.be/4Bdc55j80l8?si=SbXl6rr0Sh4_3IuS)*
- *An Introduction to Transformers: [Paper](https://arxiv.org/abs/2304.10557)*
- *The Transformer Family Version 2.0: [Blog](https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/)*
- *A survey of transformers: [Paper](https://www.sciencedirect.com/science/article/pii/S2666651022000146)*



## Imports

In [None]:
# Torch
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter

# General
import math
import warnings
from tqdm import tqdm
import os
from pathlib import Path
from typing import Dict, Tuple

# Huggingface
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# lightning ai
import torchmetrics

# My python files
from config import Config
from helpers import latest_weights_file_path, get_weights_file_path, load_pkl_files, save_pkl_files

# Transformer Model

## Input Embeddings

- Input matrix of shape (sequence, d_model): if each word is represented by, say, 512 numbers, the input matrix has shape (x, 512), where x is the number of words.

![Input Embeddings](images/InputEmbeddings.png)

- Embedding isn't fixed, it is learned by the model.

*Why are we multiplying the embedding output by the square root of the dimension?*
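
The paper does not spell out the reason for this scaling. One common reading, offered here only as an assumption, is that it keeps the token embeddings from being dwarfed by the positional encodings when the embedding weights start out small. The sketch below is plain Python rather than PyTorch so the arithmetic stays visible. It assumes a hypothetical initialisation with standard deviation `d_model ** -0.5`, which is not the default of `nn.Embedding` (a standard normal).

```python
import math
import random

random.seed(0)
d_model = 512

# Hypothetical embedding row, initialised with std = d_model ** -0.5 (an assumption;
# nn.Embedding itself defaults to a standard normal initialisation).
embedding_row = [random.gauss(0.0, d_model ** -0.5) for _ in range(d_model)]

def l2_norm(vec):
    """Euclidean length of a vector."""
    return math.sqrt(sum(v * v for v in vec))

def sinusoidal_row(position, d_model):
    """Sinusoidal positional encoding for a single position."""
    row = []
    for pair in range(d_model // 2):
        angle = position / (10000 ** (2 * pair / d_model))
        row.extend([math.sin(angle), math.cos(angle)])
    return row

raw_norm = l2_norm(embedding_row)                                        # close to 1
scaled_norm = l2_norm([v * math.sqrt(d_model) for v in embedding_row])   # close to sqrt(d_model)
pe_norm = l2_norm(sinusoidal_row(position=7, d_model=d_model))           # sqrt(d_model / 2)

print(f"raw={raw_norm:.2f} scaled={scaled_norm:.2f} positional={pe_norm:.2f}")
```

Under this assumption the unscaled embedding is much shorter than the positional vector, while the scaled one is of comparable length. Each sine/cosine pair contributes exactly 1 to the squared norm, which is why the positional norm is `sqrt(d_model / 2)` at every position.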


In [None]:
class InputEmbeddings(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        """Input Embedding Module.
        Initializes the Embedder module. This module is used to embed the input tokens into a vector space.
        
        Args:
            d_model (int): Dimension of Input Embedding
            vocab_size (int): Vocabulary Size
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embeds the input tensor into a vector space. 
        This is done by looking up the embedding for each token in the input tensor.
        - Scales output by sqrt(dimensions) as per Vaswani et al. (2017).

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, sequence_length).

        Returns:
            torch.Tensor: Output tensor in R^d (batch_size, sequence_length, dimensions).
        """
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)


## Positional Encoding

- Because the self-attention operation is permutation invariant, it is important to use a proper positional encoding to provide order information to the model.
- Positional Encoding conveys information about the position of the word in a sentence.
- We want the model to treat words that appear close to each other as "close" and words that are distant as "distant"

![Positional Embedding](images/PositionalEmbedding.png)

- Unlike the token embeddings, the sinusoidal positional encodings aren't learned. We compute them once and reuse them for every sentence. *Is this why we need a large dataset? Because we aren't learning positional encoding to tell the model appropriately what word can be closer to the other?*

**Sinusoidal Positional Encoding**

Sinusoidal positional encoding is defined as follows, given the token position $i = 1, ..., L$ and the dimension $\delta = 1, ..., d$:

$$
PE(i, \delta) = \begin{cases}
   \sin\left(\frac{i}{10000^{2 \delta' / d}}\right) & \text{if } \delta = 2 \delta' \\
   \cos\left(\frac{i}{10000^{2 \delta' / d}}\right) & \text{if } \delta = 2 \delta' + 1
\end{cases}
$$

In this way each pair of dimensions of the positional encoding corresponds to a sinusoid of a different wavelength, with the wavelengths forming a geometric progression from $2 \pi$ to $10000 \cdot 2 \pi$.

![Positional Embedding Vector](images/PositionalEmbeddingVector.png)

- *Why is there an odd and even position equation?*
- *Why do we want the positional encoding to represent a pattern that can be learned by the model? Is that why we use sine and cosine?*
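
A minimal sketch, in plain Python, of two facts that bear on the questions above. First, the exp/log expression used for `div_term` in the code is the same quantity as the power in the equation. Second, because sine and cosine come in pairs, the encoding at position `p + k` is a fixed rotation of the encoding at position `p`, so a relative offset is a linear function of the encoding. The helper names `encode` and `shift` are hypothetical.

```python
import math

d_model = 8

# Frequencies written two ways: the power form from the equation and the
# exp/log form used in the PositionalEncoding code.
power_form = [1.0 / (10000 ** (2 * j / d_model)) for j in range(d_model // 2)]
exp_log_form = [math.exp(2 * j * (-math.log(10000.0) / d_model)) for j in range(d_model // 2)]

def encode(position):
    """Return one (sin, cos) pair per frequency for the given position."""
    return [(math.sin(position * w), math.cos(position * w)) for w in power_form]

def shift(pairs, offset):
    """Rotate each (sin, cos) pair by offset * frequency, without knowing the position."""
    shifted = []
    for (s, c), w in zip(pairs, power_form):
        shifted.append((s * math.cos(offset * w) + c * math.sin(offset * w),
                        c * math.cos(offset * w) - s * math.sin(offset * w)))
    return shifted

direct = encode(13)             # encoding of position 13 computed directly
rotated = shift(encode(10), 3)  # encoding of position 10, rotated by an offset of 3
max_gap = max(abs(a - b) for p, q in zip(direct, rotated) for a, b in zip(p, q))

wavelengths = [2 * math.pi / w for w in power_form]  # grows geometrically from 2 * pi
```

The rotation in `shift` depends only on the offset, not on the absolute position. This is the usual argument for why a model can pick up relative positions from this encoding.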

In [None]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, seq_len: int, dropout: float = Config.dropout_constant) -> None:
        """Initializes the PositionalEncoder module. 
        This module is used to encode the position of the input tokens into the input embeddings.
        Create a matrix of shape (max_sequence_length, dimensions) to store the positional encodings.
        For each position in the sequence, compute the positional encoding for each dimension.
        Store the positional encodings in the matrix.
        Implement the positional encodings as per Vaswani et al. (2017).

        Args:
            d_model (int): The number of dimensions to embed the input tokens into.
            seq_len (int): The maximum length of the sentence since we need to create one vector for each position.
            dropout (float): The dropout rate to apply to the PEs.
        """
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Create a matrix of shape(seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len, 1)
        # torch.arange returns a 1D tensor of size ceil((end - start) / step) with values from the
        # interval [start, end), taken with common difference step beginning from start.
        # unsqueeze(1) turns shape (seq_len,) into (seq_len, 1) so that it broadcasts against div_term.
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # arange(0, d_model, 2) enumerates the even indices 2i, so
        # exp(2i * (-log(10000) / d_model)) = 10000 ** (-2i / d_model), i.e. 1 / 10000 ** (2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (d_model / 2)
        # The exp/log form is an equivalent way of writing the power above; it is commonly used for numerical stability.
        # Apply sine to the even dimensions and cosine to the odd dimensions
        pe[:, 0::2] = torch.sin(position * div_term)  # sin(position / 10000 ** (2i / d_model))
        pe[:, 1::2] = torch.cos(position * div_term)  # cos(position / 10000 ** (2i / d_model))
        # Adding a new dimension to account for the batch of sentences, we use unsqueeze to do this.
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)

        # Register the positional encoding as a buffer
        # Why: to keep inside the module, not as a learned param, but to be saved along with the state of the model
        self.register_buffer('pe', pe)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Adds the positional encodings to the input tensor. Applies dropout to the output tensor.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to add positional encodings to.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # Since we are not learning this tensor, we add requires grad as False
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)  # (batch, seq_len, d_model)
        return self.dropout(x)


## Multi-Head Attention

### Attention

Attention is a mechanism in neural networks by which a model learns to make predictions by selectively attending to a given set of data. The amount of attention is quantified by learned weights, so the output is usually formed as a weighted average.

The output of the first stage of the transformer block is another $D \times N$ array, $Y^{(m)}$. The output is produced by aggregating information across the sequence independently for each feature using an operation called attention.

Specifically, the output vector at location n, denoted $y_n^{(m)}$, is produced by a simple weighted average of the input features at location $n' = 1 ... N$, denoted $x_{n'}^{(m-1)}$, that is

$$y_n^{(m)} = \sum_{n'=1}^N x_{n'}^{(m-1)} A_{n', n}^{(m)}$$

**Relationship to Convolutional Neural Networks**
- The attention mechanism can recover convolutional filtering as a special case e.g if $x_n^{(0)}$ is a 1D regularly sampled time-series and $A_{n', n}^{(m)} = A_{n' - n}^{(m)}$ then the attention mechanism above becomes a convolution. Unlike normal CNNs, these filters have full temporal support. Later we will see that the filters themselves dynamically depend on the input, another difference from standard CNNs. We will also see a similarity: transformers will use multiple attention maps in each layer in the same way that CNNs use multiple filters (though typically transformers have fewer attention maps than CNNs have channels).

Here the weighting is given by a so-called attention matrix $A_{n', n}^{(m)}$ which is of size $N \times N$ and normalises over its columns $\sum_{n'=1}^{N} A_{n', n}^{(m)} = 1$. Intuitively speaking $A_{n', n}^{(m)}$ will take a high value for locations in the sequence n' which are of high relevance for location n.  For irrelevant locations, it will take the value 0. For example, all patches of a visual scene coming from a single object might have high corresponding attention values. We can compactly write the relationship as a matrix multiplication,

$$Y^{(m)} = X^{(m-1)} A^{(m)}$$

and we illustrate it below:

![Element of the Attention Mechanism](images/ElementoftheAttentionMechanism.png)
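
The relationship $Y = XA$ can be sketched with nested lists in plain Python. The numbers are made up for illustration. The only property taken from the text is that each column of `A` sums to 1, so every output column is a weighted average of the input columns.

```python
# X is D x N (features by sequence positions), A is N x N with columns summing to 1.
X = [[1.0, 2.0, 3.0],
     [0.0, 10.0, 20.0]]          # D = 2 features, N = 3 positions
A = [[0.5, 0.0, 0.2],
     [0.5, 1.0, 0.3],
     [0.0, 0.0, 0.5]]            # A[n_prime][n]: weight of input n_prime in output n

D, N = len(X), len(X[0])
column_sums = [sum(A[n_prime][n] for n_prime in range(N)) for n in range(N)]

# Y[d][n] = sum over n_prime of X[d][n_prime] * A[n_prime][n]
Y = [[sum(X[d][n_prime] * A[n_prime][n] for n_prime in range(N)) for n in range(N)]
     for d in range(D)]

for row in Y:
    print(row)
```

Output position 1 copies input position 1 unchanged, because the middle column of `A` puts all of its weight there. The other two positions mix their neighbours.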



### Self Attention

**Where does the attention matrix come from?**

A simple way of generating the attention matrix from the input would be to measure the similarity between two locations by the dot product between the features at those two locations and then use a softmax function to handle the normalisation. However, this naïve approach entangles information about the similarity between locations in the sequence with the content of the sequence itself. An alternative is to perform the same operation on a linear transformation of the sequence, $Ux_n$, so that

$$A_{n,n'} = \frac{exp(x_n^{\top} U^{\top} Ux_{n'})}{ \sum_{n''=1}^{N} exp(x_{n''}^{\top} U^{\top} Ux_{n'})}$$

Typically, U will project to a lower dimensional space i.e. U is $K \times D$ dimensional with $K < D$. In this way only some of the features in the input sequence need be used to compute the similarity, the others being projected out, thereby decoupling the attention computation from the content. However, the numerator in this construction is symmetric. This could be a disadvantage. For example, we might want the word ‘caulking iron’ to be strongly associated with the word ‘tool’ (as it is a type of tool), but have the word ‘tool’ more weakly associated with the word ‘caulking iron’ (because most of us rarely encounter it).

Fortunately, it is simple to generalise the attention mechanism above to be asymmetric by applying two different linear transformations to the original sequence

$$A_{n,n'} = \frac{exp(x_n^{\top} U_k^{\top} U_q x_{n'})}{ \sum_{n''=1}^{N} exp(x_{n''}^{\top} U_k^{\top} U_q x_{n'})}$$

The two quantities that are dot-producted together here $q_n = U_q x_n$ and $k_n = U_k x_n$ are typically known as the *queries* and the *keys*, respectively.
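
As a rough illustration of the asymmetric construction, the sketch below builds $A$ from two different projections in plain Python. The input vectors and the matrices `U_q` and `U_k` are arbitrary made-up values, not learned ones. The normalisation runs over the first index, as in the equation above.

```python
import math

def matvec(M, v):
    """Multiply matrix M (list of rows) by vector v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Three hypothetical input vectors x_n with D = 3 features.
xs = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
# Hypothetical K x D projections with K = 2 < D.
U_q = [[1.0, 0.0, 0.5], [0.0, 1.0, -0.5]]
U_k = [[0.5, 1.0, 0.0], [1.0, -1.0, 0.5]]

queries = [matvec(U_q, x) for x in xs]
keys = [matvec(U_k, x) for x in xs]
N = len(xs)

# A[n][n_prime] = exp(k_n . q_n') / sum over n'' of exp(k_n'' . q_n')
A = [[0.0] * N for _ in range(N)]
for n_prime in range(N):
    logits = [dot(keys[n], queries[n_prime]) for n in range(N)]
    shift = max(logits)                        # subtract the max for numerical stability
    exps = [math.exp(l - shift) for l in logits]
    total = sum(exps)
    for n in range(N):
        A[n][n_prime] = exps[n] / total

column_sums = [sum(A[n][n_prime] for n in range(N)) for n_prime in range(N)]
```

With two distinct projections, `A[0][1]` and `A[1][0]` generally differ, which is the asymmetry the text asks for.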

**Relationship to Recurrent Neural Networks**
- It is illuminating to compare the temporal processing in the transformer to that of RNNs, which recursively update a hidden state feature representation $(x_n^{(1)})$ based on the current observation $(x_n^{(0)})$ and the previous hidden state $(x_{n-1}^{(1)})$, so that $x_n^{(1)} = f(x_{n-1}^{(1)} ; x_{n}^{(0)}) = f(f(x_{n-2}^{(1)} ; x_{n-1}^{(0)}); x_n^{(0)})$. Here we’ve unrolled the RNN one step to show that observations which are nearby to the hidden state (e.g. $x_n^{(0)}$) are treated differently from observations that are further away (e.g. $x_{n-1}^{(0)}$), as information is propagated by recurrent application of the function $f(\cdot)$. In contrast, in the transformer, self-attention treats all observations at all time-points in an identical manner, no matter how far away they are. This is one reason why transformers find it simpler to learn long-range relationships.


### How to compute Self Attention

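- Self-Attention is permutation equivariant (often loosely described as permutation invariant): reordering the input words reorders the outputs in the same way, so without positional encodings the model has no notion of word order.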

![Self Attention Part 1](images/SelfAttentionPart1.png)

![Self Attention Part 2](images/SelfAttentionPart2.png)

- Self-Attention requires no parameters. Up to now the interaction between words has been driven by their embedding and the positional encodings. This will change later.
- We expect values along the diagonal to be the highest.
- If we don’t want some positions to interact, we can always set their values to –∞ before applying the softmax in this matrix and the model will not learn those interactions. We will use this in the decoder.
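
A small plain-Python sketch of the masking idea in the last bullet, using a causal mask as the example. The scores are made-up numbers. Setting a score to negative infinity makes its softmax weight exactly zero, while the remaining weights still sum to 1.

```python
import math

def softmax(row):
    """Turn a list of scores into weights that sum to 1."""
    shift = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

scores = [[0.9, 0.1, 0.3],
          [0.2, 0.8, 0.5],
          [0.4, 0.6, 0.7]]   # hypothetical query-key scores, one row per query

# Causal mask: position i may only attend to positions j <= i.
masked = [[s if j <= i else float("-inf") for j, s in enumerate(row)]
          for i, row in enumerate(scores)]
weights = [softmax(row) for row in masked]

for row in weights:
    print([round(w, 3) for w in row])
```

The first query can only see itself, so its weights come out as `[1.0, 0.0, 0.0]`.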

**What are queries, keys and values?**
- Query (Q): This represents the current word or part of the sequence the model is focusing on.
- Key (K): This represents all the words or parts in the input sequence.
- Value (V): This also represents all the words or parts in the input sequence, but it contains the actual information associated with each key.

**Where are Q and K coming from?**
- Training the transformer learns the weight matrices $W_Q$ and $W_K$. The calculation proceeds as follows, where $X$ is the sequence of position-encoded word embedding vectors that represents an input sentence.
    1. Pick <span style="color:yellow">a word vector (position encoded)</span> from the input sentence and project it into the vector space Q. This becomes the query. $Q = X \cdot W_Q^T$
    2. Pick <span style="color:yellow">all the words in the sentence</span> and project them into the vector space K. Each of them becomes a key. $K = X \cdot W_K^T$
    3. For each (q, k) pair, the strength of their relation is calculated with a dot product. $\text{qk similarity scores} = \text{matmul}(Q, K^T)$
    4. The weight matrices $W_Q$ and $W_K$ are trained via backpropagation during Transformer training.

**Steps Involved:**
- Encoding Inputs: Each word in the input sequence is converted into a vector representation.
- Calculating Compatibility Scores: The model calculates a score for each key-query pair. This score represents how relevant a particular word (key) is to the current word being processed (query). There are different ways to calculate this score, but they typically involve measuring the dot product or cosine similarity between the query and key vectors.
- Softmax Attention: A softmax function is applied to the compatibility scores, converting them into attention weights. These weights represent the importance of each word in the sequence relative to the current word.
- Weighted Sum: Each value vector is multiplied by its attention weight. This amplifies the information from the relevant parts of the sequence (based on the weights) and weakens the contribution of less relevant parts.
- Attention Output: The weighted sum of the value vectors is the attention output. This output incorporates information from the entire sequence, but with a focus on the parts most relevant to the current word.
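
The steps above can be sketched in plain Python as follows. This is a dependency-free illustration, not the PyTorch implementation used later in the notebook. It uses the scaled dot product as the compatibility score, and the toy `Q`, `K` and `V` are made-up values.

```python
import math

def softmax(row):
    """Turn a list of scores into weights that sum to 1."""
    shift = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention for lists of query, key and value vectors."""
    d_k = len(Q[0])
    outputs, all_weights = [], []
    for q in Q:
        # Compatibility score of this query with every key, scaled by sqrt(d_k)
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        # Softmax turns the scores into attention weights
        weights = softmax(scores)
        # Weighted sum of the value vectors
        output = [sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))]
        outputs.append(output)
        all_weights.append(weights)
    return outputs, all_weights

# Hypothetical toy inputs: 2 positions, d_k = 2
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
outputs, weights = attention(Q, K, V)
```

The first query matches the first key best, so the first output leans toward the first value vector while still mixing in some of the second.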


### Multi-Head Attention

In the self-attention mechanism there is one attention matrix which describes the similarity of two locations within the sequence. This can act as a bottleneck in the architecture; it would be useful for pairs of points to be similar in some ‘dimensions’ and different in others. If attention matrices are viewed as a data-driven version of filters in a CNN, then the need for more filters/channels is clear. Typical choices for the number of heads $H$ are 8 or 16, lower than typical numbers of channels in a CNN.

In order to increase capacity of the first self-attention stage, the transformer block applies H sets of self-attention in parallel and then linearly projects the results down to the $D \times N$ array required for further processing. This slight generalisation is called multi-head self-attention. 

The computational cost of multi-head self-attention is usually dominated by the matrix multiplication involving the attention matrix and is therefore $O(HDN^2)$.

![Multi-Head Attention Equation](images/MultiHeadAttentionEquation.png)

Here the $H$ matrices $V_h^{(m)}$, each of size $D \times D$, project the $H$ self-attention stages down to the required output dimensionality $D$.

The product of the matrices $V_h^{(m)} X^{(m-1)}$ is related to the so-called values which are normally introduced in descriptions of self-attention alongside queries and keys. In the usual presentation, there is a redundancy between the linear transform used to compute the values and the linear projection at the end of the multi-head self-attention, so we have not explicitly introduced them here. The standard presentation can be recovered by setting $V_h$ to be a low-rank matrix $V_h = U_h U_{v, h}$ where $U_h$ is $D \times K$ and $U_{v, h}$ is $K \times D$. Typically $K$ is set to $K = D/H$ so that changing the number of heads leads to models with similar numbers of parameters and computational demands.
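
The claim about $K = D/H$ can be checked with a little arithmetic. The sketch below counts only the bias-free projection matrices named in the text ($U_q$, $U_k$, $U_{v,h}$ and $U_h$ for each head). The function name is hypothetical.

```python
def multi_head_parameter_count(D, H):
    """Count projection parameters for H heads of width K = D / H (no biases)."""
    assert D % H == 0, "D must be divisible by H"
    K = D // H
    per_head = 3 * K * D + D * K   # U_q, U_k, U_{v,h} are K x D; U_h is D x K
    return H * per_head

counts = {H: multi_head_parameter_count(512, H) for H in (1, 8, 16)}
print(counts)
```

Every entry equals $4D^2$ (1,048,576 for $D = 512$), which matches four bias-free `d_model x d_model` linear layers.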

The addition of the matrices $V_h^{(m)}$, and the fact that retaining just the diagonal elements of the attention matrix $A^{(m)}$ will interact the signal instantaneously with itself, does mean there is some cross-feature processing in multi-head self-attention, as opposed to it containing purely cross-sequence processing. However, the stage has limited capacity for this type of processing and it is the job of the second stage to address this.

![Multi-Head Attention](images/MultiHeadAttention.png)

**Here's how it works:**
- Multiple Attention Heads: Instead of having a single set of query (Q), key (K), and value (V) vectors, the model creates multiple sets (often 4, 8, or 16). These are called attention heads.
- Linear Projections: Each input word embedding is projected independently for each attention head using different weight matrices. This creates different query, key, and value vectors for each head, allowing them to focus on different aspects of the relationships between words.
- Independent Attention: Each attention head then performs the standard attention mechanism steps (compatibility score calculation, softmax, weighted sum) independently. This results in multiple attention outputs, each capturing a different aspect of the context.
- Concatenation: Finally, the outputs from all the attention heads are concatenated to form a single final output. This combined output incorporates information from all the heads, providing a richer representation of the context.

In [None]:
class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        """This module is used to calculate the attention scores for the input tensor.

        Args:
            d_model (int): Embedding vector size.
            h (int): Number of heads.
            dropout (float): The dropout rate to apply to the attention scores.
        """
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not divisible by h"

        # For integer division use //; for floating point division use /
        self.d_k = d_model // h  # Dimension of vector seen by each head

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)

        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    # staticmethod: because we can call this without an instance of the above class
    @staticmethod
    def attention(
        query: torch.Tensor, 
        key: torch.Tensor, 
        value: torch.Tensor, 
        mask: torch.Tensor, 
        dropout: nn.Dropout
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculates the attention scores for the input tensor.

        Args:
            query (torch.Tensor): The query tensor to calculate the attention scores for.
            key (torch.Tensor): The key tensor to calculate the attention scores for.
            value (torch.Tensor): The value tensor to calculate the attention scores for.
            mask (torch.Tensor): The mask tensor to apply to the attention scores.
            dropout (nn.Dropout): The dropout layer to apply to the attention weights.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The attention output of shape (batch_size, h, sequence_length, d_k)
            and the attention weights of shape (batch_size, h, sequence_length, sequence_length).
        """
        d_k = query.shape[-1]

        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        # @ matrix multiplication in pytorch
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
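        # Masking replaces the scores at masked positions with a very large negative value,
        # which becomes (almost exactly) zero after applying softmax
        if mask is not None:
            # Replace all the values where mask == 0 with -1e9.
            # masked_fill_ works in place; the out-of-place masked_fill would return a new tensor
            # that is discarded here, leaving the scores unmasked.
            attention_scores.masked_fill_(mask == 0, -1e9)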
        attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)

        if dropout is not None:
            attention_scores = dropout(attention_scores)
        
        # The second value in the tuple is used for visualization
        return (attention_scores @ value), attention_scores

    
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Projects the input tensor into the query, key, and value tensors.
        Splits the query, key, and value tensors into multiple heads.
        Calculates the attention scores for the input tensor.
        Concatenates the attention scores for the multiple heads.
        Projects the attention scores back into the original dimension.

        Args:
            q (torch.Tensor): The query tensor to calculate the attention scores for.
            k (torch.Tensor): The key tensor to calculate the attention scores for.
            v (torch.Tensor): The value tensor to calculate the attention scores for.
            mask (torch.Tensor, optional): The mask tensor to apply to the attention scores. Defaults to None.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # for q, k, v: (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)
    
        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        # Using view: keep the batch and sequence dims, since we split the embedding, not the sentence
        # torch.Tensor.view: returns a new tensor with the same data as the self tensor but of a different shape.
        # Why transpose: we prefer h as the second dim, so each head sees the whole sentence as (seq_len, d_k)
        # view() never copies the underlying data (it raises an error if the memory layout does not allow a view),
        # whereas reshape() may silently fall back to a copy
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Getting attention scores
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        
        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)



## General Layers

### Layer Normalization

**Batch Normalization (BN):**
- Normalization Scope: BN normalizes activations across a mini-batch of training data. It calculates the mean and standard deviation of the activations for each feature (channel) within the mini-batch and then uses these statistics to normalize the individual activations.
- Impact: This helps to address the problem of internal covariate shift, where the distribution of activations can change significantly between mini-batches during training. By normalizing each mini-batch, BN ensures that the gradients used for updating the model's weights are more stable and less prone to exploding or vanishing gradients.

*Limitations*:
- Reliance on Batch Size: BN's effectiveness depends on the mini-batch size. Smaller batch sizes can lead to less accurate estimates of the mean and standard deviation, potentially reducing its effectiveness.
- Increased Memory Consumption: BN requires storing the mean and standard deviation for each feature across the mini-batch, which can increase memory usage.

![Layer Normalization](images/LayerNormalization.png)

**Layer Normalization (LN):**
- Normalization Scope: LN normalizes activations independently for each sample in the mini-batch. It calculates the mean and standard deviation across the features of a single sample (in a transformer, across the `d_model` values of each token) and uses them to normalize that sample's activations.
- Impact: LN also addresses internal covariate shift, but at a different level. It ensures that the activations within a layer have a consistent distribution for each sample, regardless of the mini-batch or other samples in the training data. This can improve the stability of gradients and learning.

*Benefits*:
- Less Sensitive to Batch Size: LN is less sensitive to the mini-batch size compared to BN. This makes it potentially more robust for various training scenarios, including settings with small batch sizes.
- Lower Memory Footprint: LN doesn't require storing statistics across the entire mini-batch, reducing memory consumption.

*Potential Drawbacks*:
- No Batch-Level Statistics: LN normalizes each sample in isolation, so it cannot use statistics shared across the mini-batch the way BN does.
- Less Flexibility: LN applies the same normalization across all features within a layer, while BN allows for different normalizations for each feature.
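
The difference in normalization scope can be sketched in plain Python. The activations are made-up numbers, and the learnable scale and shift are left out. The helper uses the population variance with `eps` inside the square root, which is an assumption of this sketch. The `LayerNormalization` class in the notebook uses `x.std` (unbiased by default) and adds `eps` to the standard deviation instead.

```python
import math

def normalise(values, eps=1e-6):
    """Shift to zero mean and scale to (approximately) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

# Hypothetical activations: 3 samples (rows) x 4 features (columns).
batch = [[1.0, 2.0, 3.0, 4.0],
         [10.0, 20.0, 30.0, 40.0],
         [-1.0, 0.0, 1.0, 2.0]]

# Layer norm: statistics per sample, taken across its features (one row at a time).
layer_normed = [normalise(row) for row in batch]

# Batch norm (training-time statistics only): per feature, across the samples in the batch.
columns = [list(col) for col in zip(*batch)]
batch_normed = [list(row) for row in zip(*[normalise(col) for col in columns])]
```

The first two rows differ only by a factor of 10, so layer normalization maps them to almost the same vector regardless of what else is in the batch. Batch normalization keeps them apart, because its statistics are taken per feature across the rows.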



In [None]:
class LayerNormalization(nn.Module):
    def __init__(self, features: int, eps: float = 10**-6) -> None:
        """Initializes the LayerNormalization module. This module is used to normalize the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            eps (float, optional): Needed to prevent div by 0 error. Defaults to 10**-6.
        """
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # Multiplied
        self.bias = nn.Parameter(torch.zeros(features))  # Added
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalizes the input tensor, then scales it by alpha and shifts it by bias.
        - The mean and standard deviation are calculated over the last (feature) dimension,
        separately for every position of every sequence in the batch.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to normalize.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        self.mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        self.std = x.std(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        return self.alpha * (x - self.mean) / (self.std + self.eps) + self.bias
    

### Feed Forward Network

The transformer also includes this point-wise feed-forward network in both the encoder and decoder:

![Feed Forward Network](images/FeedForwardNetwork.png)
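
In Vaswani et al. (2017) this network is $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, applied to each position separately. A minimal plain-Python sketch with made-up weights (`d_model = 2`, `d_ff = 3`) and no dropout:

```python
def linear(x, W, b):
    """x: d_in floats, W: d_in x d_out nested list, b: d_out floats."""
    return [sum(xi * W[i][j] for i, xi in enumerate(x)) + b[j] for j in range(len(b))]

def feed_forward(x, W1, b1, W2, b2):
    """Expand to d_ff, apply ReLU, project back to d_model."""
    hidden = [max(0.0, h) for h in linear(x, W1, b1)]   # ReLU
    return linear(hidden, W2, b2)

# Hypothetical weights: d_model = 2, d_ff = 3.
W1 = [[1.0, -1.0, 0.5], [0.0, 2.0, -0.5]]
b1 = [0.0, 0.0, 0.0]
W2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.0, 0.0]

# "Point-wise": the same weights are applied to every position independently.
sequence = [[1.0, 2.0], [3.0, -1.0]]
outputs = [feed_forward(token, W1, b1, W2, b2) for token in sequence]
print(outputs)  # [[1.0, 3.0], [5.0, 2.0]]
```

No information moves between positions in this stage. Mixing across the sequence is left to the attention stage.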


In [None]:
class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        """This module is used to project the input tensor into a higher dimension 
        and then back into the original dimension.

        Args:
            d_model (int): The number of dimensions of the input tensor.
            d_ff (int): The number of dimensions to project the input tensor into.
            dropout (float): The dropout rate to apply to the output tensor.
        """
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # W1 and B1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # W2 and B2
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the input tensor into a higher dimension using a linear layer and a ReLU activation function.
        Applies dropout to the hidden activations.
        Projects the result back into the original dimension using a second linear layer.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to project and then project back.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))


### Residual Connection

The use of residual connections is widespread across machine learning as they make initialisation simple, have a sensible inductive bias towards simple functions, and stabilise learning. Instead of directly specifying a function $x^{(m)} = f_{\theta} (x^{(m-1)})$, the idea is to parameterise it in terms of an identity mapping and a residual term
$$x^{(m)} = x^{(m-1)} + \text{res}_{\theta} (x^{(m-1)})$$

Equivalently, this can be viewed as modelling the differences between the representation $x^{(m)} - x^{(m-1)} = \text{res}_{\theta} (x^{(m-1)})$ and will work well when the function that is being modelled is close to identity. This type of parameterisation is used for both the Multi-Head Self-Attention and MLP stages in the transformer, with the idea that each applies a mild non-linear transformation to the representation. Over many layers, these mild non-linear transformations compose to form large transformations.
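
A toy plain-Python sketch of this parameterisation. The residual functions below are hypothetical stand-ins for attention or MLP stages. With a zero residual the block is exactly the identity, and many mild residual steps add up to a large change.

```python
import math

def residual_block(x, res):
    """x^(m) = x^(m-1) + res(x^(m-1)), applied element by element."""
    return [xi + ri for xi, ri in zip(x, res(x))]

def zero_residual(x):
    """Residual branch that outputs nothing: the block becomes the identity."""
    return [0.0] * len(x)

def mild_residual(x):
    """Small non-linear update; each step changes an entry by at most 0.1."""
    return [0.1 * math.tanh(xi) for xi in x]

x = [0.5, -1.0, 2.0]
identity_out = residual_block(x, zero_residual)
one_step = residual_block(x, mild_residual)

h = x
for _ in range(20):          # stack 20 mild blocks
    h = residual_block(h, mild_residual)
```

After one block the representation has barely moved. After twenty blocks it has moved by far more than any single block can manage.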


In [None]:
class ResidualConnection(nn.Module):
    def __init__(self, features: int, dropout: float = Config.dropout_constant) -> None:
        """This module is used to add the input tensor to the output tensor.

        Args:
            features (int): The number of features in the input tensor.
            dropout (float, optional): The dropout rate to apply to the output tensor. Defaults to Config.dropout_constant.
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)
    
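    def forward(self, x: torch.Tensor, sublayer: torch.nn.Module) -> torch.Tensor:
        """Normalizes the input, applies the sublayer and dropout, then adds the result to the input.
        Note that the normalization is applied before the sublayer (pre-norm), whereas
        Vaswani et al. (2017) normalize after the addition (post-norm).

        Args:
            x (torch.Tensor): The input tensor to add to the output tensor.
            sublayer (torch.nn.Module): The network layer (any callable mapping a tensor to a tensor) to apply to the normalized input.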

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        return x + self.dropout(sublayer(self.norm(x)))

## Encoder Block

![Encoder Block](images/EncoderBlock.png)


In [None]:
class EncoderBlock(nn.Module):
    def __init__(
            self,
            features: int,
            self_attention_block: MultiHeadAttentionBlock,
            feed_forward_block: FeedForwardBlock,
            dropout: float = Config.dropout_constant
        ) -> None:
        """Initializes the EncoderBlock module. This module is used to encode the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            self_attention_block (MultiHeadAttentionBlock): The self attention layer to apply to the input tensor.
            feed_forward_block (FeedForwardBlock): The feed forward layer to apply to the input tensor.
            dropout (float, optional): The dropout rate to apply to the output tensor. Defaults to Config.dropout_constant.
        """
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])
    
    def forward(self, x: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Applies the self attention layer to the input tensor.
        Adds the input tensor to the output tensor.
        Applies the feed forward layer to the output tensor.
        Adds the input tensor to the output tensor.

        Args:
            x (torch.Tensor): The input tensor to encode.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        x = self.residual_connection[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connection[1](x, self.feed_forward_block)
        return x

### Encoder

![Encoder](images/Encoder.png)


In [None]:
class Encoder(nn.Module):
    def __init__(self, features:int, layers: nn.ModuleList) -> None:
        """Initializes the Encoder module. This module is used to encode the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            layers (nn.ModuleList): The layers to apply to the input tensor.
        """
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Applies the layers to the input tensor. Applies the layer normalization to the output tensor.

        Args:
            x (torch.Tensor): The input tensor to encode.
            mask (torch.Tensor): The mask tensor to apply to the attention scores.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

## Decoder Block

![Decoder Block](images/DecoderBlock.png)

**Masked Multi-Head Attention**
- Our goal is to make the model causal: the output at a certain position can only depend on the words at the previous positions. The model must not be able to see future words.

![Masked Multi-Head Attention](images/MaskedMultiHeadAttention.png)

- All the values above the diagonal are replaced with -∞ before applying the softmax, which turns them into zeros.
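
The masking step can be sketched on a small matrix of attention scores. This is an illustration with random scores and a sequence length of 4, not the attention block used in this notebook:

```python
import torch

torch.manual_seed(0)
seq_len = 4
scores = torch.randn(seq_len, seq_len)  # raw attention scores (query x key)

# True strictly above the diagonal, i.e. at "future" key positions
future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Replace future positions with -inf, then normalise each row
masked_scores = scores.masked_fill(future, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)
```

Every row of `weights` still sums to 1, but the entries above the diagonal are exactly zero, so the first position can only attend to itself.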

In [None]:
class DecoderBlock(nn.Module):
    def __init__(
            self, 
            features: int, 
            self_attention_block: MultiHeadAttentionBlock,
            cross_attention_block: MultiHeadAttentionBlock,
            feed_forward_block: FeedForwardBlock,
            dropout: float = Config.dropout_constant
        ) -> None:
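        """Initializes the DecoderBlock module. This module is used to decode the target tensor using the encoder output.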

        Args:
            features (int): The number of features in the input tensor.
            self_attention_block (MultiHeadAttentionBlock): The self attention layer to apply to the input tensor.
            cross_attention_block (MultiHeadAttentionBlock): The cross attention layer that attends to the encoder output.
            feed_forward_block (FeedForwardBlock): The feed forward layer to apply to the input tensor.
            dropout (float, optional): The dropout rate to apply to the output tensor. Defaults to Config.dropout_constant.
        """
        super().__init__()

        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connection = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(
        self,
        x: torch.Tensor, 
        encoder_output: torch.Tensor, 
        src_mask: torch.Tensor, 
        tgt_mask: torch.Tensor
    ) -> torch.Tensor:
        """Applies the self attention layer to the input tensor.
        Adds the input tensor to the output tensor.
        Applies the cross attention layer to the output tensor.
        Adds the input tensor to the output tensor.
        Applies the feed forward layer to the output tensor.
        Adds the input tensor to the output tensor.

        Args:
            x (torch.Tensor): The input tensor to decode.
            encoder_output (torch.Tensor): The output tensor from the encoder.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores for the source tensor.
            tgt_mask (torch.Tensor): The mask tensor to apply to the attention scores for the target tensor.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        x = self.residual_connection[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connection[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connection[2](x, self.feed_forward_block)
        return x


### Decoder

![Decoder](images/Decoder.png)


In [None]:
class Decoder(nn.Module):
    def __init__(self, features:int, layers: nn.ModuleList) -> None:
        """Initializes the Decoder module. This module is used to decode the input tensor.

        Args:
            features (int): The number of features in the input tensor.
            layers (nn.ModuleList): The layers to apply to the input tensor.
        """
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)
    
    def forward(
        self, 
        x: torch.Tensor, 
        encoder_output: torch.Tensor, 
        src_mask: torch.Tensor, 
        tgt_mask: torch.Tensor
    ) -> torch.Tensor:
        """Applies the layers to the input tensor. Applies the layer normalization to the output tensor.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to decode.
            encoder_output (torch.Tensor): The output tensor from the encoder.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores for the source tensor.
            tgt_mask (torch.Tensor): The mask tensor to apply to the attention scores for the target tensor.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)


## Projection Layer

In [None]:
class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int) -> None:
        """Initializes the ProjectionLayer module. 
        This module is used to project the decoder output onto the vocabulary.

        Args:
            d_model (int): The number of dimensions of the input tensor.
            vocab_size (int): The number of unique tokens in the input.
        """
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the input tensor onto the vocabulary using a linear layer and returns log-probabilities.

        Args:
            x (torch.Tensor): The input tensor (batch_size, sequence_length, dimensions) to project.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, vocabulary_size).
        """
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)
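
A minimal sketch of what the log-softmax output looks like, using a standalone `nn.Linear` with made-up sizes instead of the `ProjectionLayer` class itself:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, vocab_size = 8, 11              # illustrative sizes
proj = nn.Linear(d_model, vocab_size)

x = torch.randn(2, 5, d_model)           # (batch, seq_len, d_model)
log_probs = torch.log_softmax(proj(x), dim=-1)   # (batch, seq_len, vocab_size)

probs = log_probs.exp()                  # back to probabilities, each row sums to 1
next_token = log_probs.argmax(dim=-1)    # (batch, seq_len): most likely token id per position

# Log-softmax is idempotent, so cross entropy on log-probabilities equals NLL loss
target = torch.randint(0, vocab_size, (2, 5))
nll = nn.NLLLoss()(log_probs.reshape(-1, vocab_size), target.reshape(-1))
ce = nn.CrossEntropyLoss()(log_probs.reshape(-1, vocab_size), target.reshape(-1))
```

Because the layer already returns log-probabilities, either loss gives the same value here; with raw logits only `nn.CrossEntropyLoss` would be appropriate.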
    

## Transformer Block

![Transformer Block](images/TransformerBlock.png)

The representation of the input sequence will be produced by iteratively applying a transformer block 
$$X^{(m)} = \text{transformer-block} (X^{(m-1)})$$

The block itself comprises two stages: 
- one operating across the sequence and 
- one operating across the features.

The first stage refines each feature independently according to relationships between tokens across the sequence, e.g. how much a word in a sequence at position $n$ depends on previous words at position $n'$, or how much two different patches from an image are related to one another. This stage acts horizontally across rows of $X^{(m-1)}$.

The second stage refines the features representing each token. This stage acts vertically across a column of $X^{(m-1)}$. By repeatedly applying the transformer block, the representation at token $n$ and feature $d$ can be shaped by information at token $n'$ and feature $d'$.
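
The difference between the two stages can be checked numerically. The sketch below is a simplification under stated assumptions: stage one is single-head attention without learned projections, stage two is a small per-token MLP, and the data is stored as $(N, D)$ (the transpose of the $D \times N$ convention used in the text). The helper name `mix_across_sequence` is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, D = 4, 8                      # tokens, features
X = torch.randn(N, D)            # stored as (N, D)


def mix_across_sequence(X: torch.Tensor) -> torch.Tensor:
    """Stage 1 (simplified): each token becomes a weighted average of all tokens."""
    weights = torch.softmax(X @ X.T / D ** 0.5, dim=-1)   # (N, N)
    return weights @ X                                    # (N, D)


# Stage 2: the same MLP is applied to every token separately
mlp = nn.Sequential(nn.Linear(D, 16), nn.ReLU(), nn.Linear(16, D))

# Perturb only the last token and look at what happens to token 0
X_perturbed = X.clone()
X_perturbed[-1] += 5.0

stage1_token0_changed = not torch.allclose(
    mix_across_sequence(X)[0], mix_across_sequence(X_perturbed)[0]
)
stage2_token0_unchanged = torch.allclose(mlp(X)[0], mlp(X_perturbed)[0])
```

Only the sequence-mixing stage lets a change in one token reach another token; the feature stage never does on its own, which is why the two are alternated.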



In [None]:
class Transformer(nn.Module):
    def __init__(
            self, 
            encoder: Encoder,
            decoder: Decoder,
            src_embed: InputEmbeddings,
            tgt_embed: InputEmbeddings,
            src_pos: PositionalEncoding,
            tgt_pos: PositionalEncoding,
            projection_layer: ProjectionLayer
        ) -> None:
        """Initializes the Transformer module. This module is used to encode and decode the input tensor.

        Args:
            encoder (Encoder): The encoder to encode the input tensor.
            decoder (Decoder): The decoder to decode the input tensor.
            src_embed (InputEmbeddings): The input embedder to embed the input tensor.
            tgt_embed (InputEmbeddings): The target embedder to embed the target tensor.
            src_pos (PositionalEncoding): The positional encoder to encode the input tensor.
            tgt_pos (PositionalEncoding): The positional encoder to encode the target tensor.
            projection_layer (ProjectionLayer): The layer that projects the decoder output onto the target vocabulary.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer
    
    def encode(self, src: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """Embeds the input tensor. Encodes the input tensor.

        Args:
            src (torch.Tensor): The input tensor (batch_size, sequence_length) to encode.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode(
        self,
        encoder_output: torch.Tensor, 
        src_mask: torch.Tensor, 
        tgt: torch.Tensor, 
        tgt_mask: torch.Tensor
    ) -> torch.Tensor:
        """Decodes the input tensor.

        Args:
            encoder_output (torch.Tensor): The output tensor from the encoder.
            src_mask (torch.Tensor): The mask tensor to apply to the attention scores for the source tensor.
            tgt (torch.Tensor): The target tensor.
            tgt_mask (torch.Tensor): The mask tensor to apply to the attention scores for the target tensor.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, dimensions).
        """
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    
    def project(self, x: torch.Tensor) -> torch.Tensor:
        """Projects the decoder output onto the target vocabulary.

        Args:
            x (torch.Tensor): The input tensor to project.

        Returns:
            torch.Tensor: An output tensor of shape (batch_size, sequence_length, vocabulary_size).
        """
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)
    

## Build Transformer

In [None]:
def build_transformer(
    src_vocab_size: int,
    tgt_vocab_size: int,
    src_seq_len: int,
    tgt_seq_len: int,
    d_model: int,
    num_layers: int,
    num_heads: int,
    dropout: float,
    d_ff: int
) -> Transformer:
    """Builds the transformer model.

    Args:
        src_vocab_size (int): The number of unique tokens in the input.
        tgt_vocab_size (int): The number of unique tokens in the target.
        src_seq_len (int): The maximum length of the input sequence.
        tgt_seq_len (int): The maximum length of the target sequence.
        d_model (int): The number of dimensions to embed the input tokens into.
        num_layers (int): The number of layers to apply to the input tensor.
        num_heads (int): The number of heads to split the input tensor into.
        dropout (float): The dropout rate to apply to the output tensor.
        d_ff (int): The number of dimensions to project the input tensor into.

    Returns:
        Transformer: Returns a transformer model.
    """
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(num_layers):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, num_heads, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)
    
    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(num_layers):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, num_heads, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, num_heads, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(
            d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout
        )
        decoder_blocks.append(decoder_block)
    
    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
        
    return transformer
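
The initialisation loop at the end of `build_transformer` can be tried on a single layer. This sketch assumes a standalone `nn.Linear` with illustrative sizes; it checks that the weights fall inside the Xavier (Glorot) uniform bound $\sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 512, 2048
layer = nn.Linear(fan_in, fan_out)   # weight has shape (fan_out, fan_in)

for p in layer.parameters():
    if p.dim() > 1:                  # only the weight matrix; the 1-D bias keeps its default init
        nn.init.xavier_uniform_(p)

bound = math.sqrt(6.0 / (fan_in + fan_out))   # Xavier uniform bound for gain = 1
max_abs_weight = layer.weight.abs().max().item()
```

The `p.dim() > 1` check matters because Xavier initialisation needs a fan-in and a fan-out, which 1-D parameters such as biases do not have.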


# Training a Translation Model

In order to apply a transformer, data must be converted into a set or sequence of $N$ tokens $x_n$ of dimension $D$. The tokens can be collected into a matrix $X^{(0)}$ which is $D \times N$. To give two concrete examples:
1. a passage of text can be broken up into a sequence of words or sub-words, with each word being represented by a single unique vector,
2. an image can be broken up into a set of patches and each patch can be mapped into a vector.

The embeddings can be fixed or they can be learned with the rest of the parameters of the model e.g. the vectors representing words can be optimised or a learned linear transform can be used to embed image patches.

![Encoding an Image](images/EncodinganImage.png)

A sequence of tokens is a generic representation to use as an input: many different types of data can be “tokenised”, and transformers are then immediately applicable rather than requiring bespoke architectures for each modality as was previously the case (CNNs for images, RNNs for sequences, deepsets for sets etc.). Moreover, this means that you don’t need bespoke handcrafted architectures for mixing data of different modalities; you can just throw them all into a big set of tokens.

The transformer will ingest the input data $X^{(0)}$ and return a representation of the sequence in terms of another matrix $X^{(M)}$ which is also of size $D \times N$. The slice $x_n = X_{:,n}^{(M)}$ will be a vector of features representing the sequence at the location of token $n$. These representations can be used for auto-regressive prediction of the next $(n+1)$th token, global classification of the entire sequence (by pooling across the whole representation), sequence-to-sequence or image-to-image prediction problems, etc. Here $M$ denotes the number of layers in the transformer.
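
As a small illustration of the $D \times N$ convention, the sketch below builds $X^{(0)}$ from made-up token ids with a learned embedding table and then slices and pools it. The sizes and ids are assumptions; the same slicing and pooling would apply to the final representation $X^{(M)}$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, D = 10, 6                      # illustrative sizes
token_ids = torch.tensor([3, 1, 4, 1, 5])  # N = 5 tokens (made-up ids)
embed = nn.Embedding(vocab_size, D)        # learned embedding table

X0 = embed(token_ids).T                    # (D, N), matching the D x N convention of the text
x_last = X0[:, -1]                         # features at the last token position
pooled = X0.mean(dim=1)                    # one vector summarising the whole sequence
```

Note that PyTorch modules, including the ones in this notebook, store the same data transposed, as `(sequence_length, dimensions)` with a leading batch dimension.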



## Tokenizer

In [None]:
def get_all_sentences(ds: Dataset, lang: str):
    """A generator to get all sentences in a dataset.

    Args:
        ds (Dataset): The dataset.
        lang (str): The language.

    Yields:
        str: The next sentence of the dataset in the given language.
    """
    for item in ds:
        yield item['translation'][lang]

In [None]:
def get_or_build_tokenizer(config: Config, ds: Dataset, lang: str) -> Tokenizer:
    """Gets the tokenizer if it exists, otherwise builds it and saves it.

    Args:
        config (Config): The configuration.
        ds (Dataset): The dataset.
        lang (str): The language.

    Returns:
        Tokenizer: The tokenizer.
    """
    tokenizer_path = Path(config.tokenizer_file.format(lang))
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer


## Load data

In [None]:
def casual_mask(size: int) -> torch.Tensor:
    """Create a causal mask for the decoder attention."""
    # torch.triu: returns the upper triangular part of a matrix (2-D tensor) or batch of matrices
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


In [None]:
class BilingualDataset(Dataset):
    def __init__(
            self,
            ds: Dataset, 
            tokenizer_src: Tokenizer, 
            tokenizer_tgt: Tokenizer, 
            src_lang: str, 
            tgt_lang: str, 
            seq_len: int
        ) -> None:
        """Constructor for the BilingualDataset class.

        Args:
            ds (Dataset): The dataset to use.
            tokenizer_src (Tokenizer): The tokenizer for the source language.
            tokenizer_tgt (Tokenizer): The tokenizer for the target language.
            src_lang (str): The source language.
            tgt_lang (str): The target language.
            seq_len (int): The sequence length to use.
        """
        
        super().__init__()
        self.seq_len = seq_len
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Start of Sentence, End of Sentence and Padding Tokens
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        """Get the length of the dataset."""
        return len(self.ds)
    
    def truncate_sequence(self, sequence, max_len: int):
        """Truncates a sequence of tokens to the specified maximum length.

        Args:
            sequence: A list of integer tokens representing the sentence.
            max_len: The maximum allowed length for the sequence.

        Returns:
            A list of truncated tokens and the original length of the sequence.
        """
        if len(sequence) <= max_len:
            return sequence, len(sequence)
        else:
            # Keep the first max_len tokens, i.e. drop tokens from the end of the sequence
            truncated_sequence = sequence[:max_len]
            return truncated_sequence, len(sequence)


    def __getitem__(self, idx) -> Dict:
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into tokens & Truncate sentences if necessary
        # Note that for sentences > seq_len, by adding -2 and -1, I am choosing to not have any padding, 
        # cause enc_num_padding_tokens and dec_num_padding_tokens will turn out to be = 0
        enc_input_tokens, orig_enc_len = self.truncate_sequence(self.tokenizer_src.encode(src_text).ids, self.seq_len - 2)
        dec_input_tokens, orig_dec_len = self.truncate_sequence(self.tokenizer_tgt.encode(tgt_text).ids, self.seq_len - 1)

        # Add sos, eos and padding to each sentence
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # -2 cause we will add <s> and </s>
        # -1 because the decoder input only gets <s> (SOS); </s> (EOS) is only added to the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1


        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long.
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError(
                f"Original sentence length exceeded allowed limit of seq_len - {self.seq_len}. "
                f"Original lengths: source - {orig_enc_len}, target - {orig_dec_len}"
            )

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0,
        )
        
        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
            ],
            dim=0
        )
        
        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len, f"encoder_input size: {encoder_input.size(0)} and seq_len: {self.seq_len}"
        assert decoder_input.size(0) == self.seq_len, f"decoder_input size: {decoder_input.size(0)} and seq_len: {self.seq_len}"
        assert label.size(0) == self.seq_len, f"label size: {label.size(0)} and seq_len: {self.seq_len}"

        return {
            "encoder_input": encoder_input,   # (seq_len)
            "decoder_input": decoder_input,   # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            # (1, seq_len) & (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & casual_mask(decoder_input.size(0)),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text
        }
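
The layout of the decoder input, the label and the decoder mask can be sketched with a short made-up sentence. The special token ids below are assumptions (they follow the order in which the special tokens are passed to the trainer), and the token ids of the sentence are invented:

```python
import torch

seq_len = 6
SOS, EOS, PAD = 2, 3, 1                 # assumed special token ids
tgt_tokens = [7, 8, 9]                  # a tokenised target sentence (made-up ids)
n_pad = seq_len - len(tgt_tokens) - 1

# The decoder input starts with SOS; the label is the same sentence shifted left and ends with EOS
decoder_input = torch.tensor([SOS] + tgt_tokens + [PAD] * n_pad)
label = torch.tensor(tgt_tokens + [EOS] + [PAD] * n_pad)

# Padding mask (1, seq_len) combined with the causal mask (1, seq_len, seq_len)
padding_mask = (decoder_input != PAD).unsqueeze(0)
causal = torch.triu(torch.ones(1, seq_len, seq_len), diagonal=1) == 0
decoder_mask = padding_mask & causal    # broadcasts to (1, seq_len, seq_len)
```

At each position the label holds the token the decoder should predict next, and row `i` of `decoder_mask` allows attention only to non-padding positions up to and including `i`.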


In [None]:
def get_ds(config: Config) -> Tuple[DataLoader, DataLoader, Tokenizer, Tokenizer]:
    """Loads the dataset, builds the tokenizers and returns the dataloaders.

    Args:
        config (Config): The configuration.

    Returns:
        Tuple[DataLoader, DataLoader, Tokenizer, Tokenizer]: The training and validation dataloaders, followed by the source and target tokenizers.
    """
    # It only has the train split, so we divide it ourselves
    ds_raw = load_dataset(f"{config.datasource}", f"{config.lang_src}-{config.lang_tgt}", split="train")

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config.lang_src)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config.lang_tgt)

    # Keep 90% for training, 10% for validation
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    # Need to create the tensors the model will use
    train_ds = BilingualDataset(
        train_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len
    )
    val_ds = BilingualDataset(
        val_ds_raw, tokenizer_src, tokenizer_tgt, config.lang_src, config.lang_tgt, config.seq_len
    )

    # Find the maximum sentence length (in tokens) for the source and target languages
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config.lang_src]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config.lang_tgt]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))
    
    print(f"Item: {item}")
    print(f"Encoder: {tokenizer_src.encode(item['translation'][config.lang_src])}")
    print(f"Source Encoder ids: {tokenizer_src.encode(item['translation'][config.lang_src]).ids}")
    print(f"Source Encoder tokens: {tokenizer_src.encode(item['translation'][config.lang_src]).tokens}")
    print(f"Max length of source sentence: {max_len_src}")
    print(f"Max length of target sentence: {max_len_tgt}")

    train_dataloader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt


## Train Model

![Training Model](images/TrainingModel.png)


In [None]:
def greedy_decode(
        model: Transformer, 
        source: torch.Tensor, 
        source_mask: torch.Tensor, 
        tokenizer_src: Tokenizer, 
        tokenizer_tgt: Tokenizer, 
        max_len: int, 
        device: torch.device
    ):
    """
    Greedy Decode:
    - Finds the most likely next token at each step and appends it to the decoder input.
    - Simply picks the token with the maximum probability at each step as the next token.
    - May not give the ideal output, but it's fast and simple. For better results, use beam search.
    
    Parameters:
    - model: The transformer model
    - source: The token ids of the input sentence
    - source_mask: The mask for the input sentence
    - tokenizer_src: The tokenizer for the source language
    - tokenizer_tgt: The tokenizer for the target language
    - max_len: The maximum length of the output sentence
    - device: The device to run the model on
    
    Shapes:
    - source: (1, sequence_length)
    - source_mask: (1, 1, 1, sequence_length)
    
    Returns: 
    - The token ids of the output sentence, starting with the [SOS] token
    """
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)

    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # build mask for target
        decoder_mask = casual_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break
    
    return decoder_input.squeeze(0)
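
The control flow of greedy decoding can be seen in isolation with a toy stand-in for the model. Everything here is hypothetical: the score table `next_token_scores` replaces the decoder and projection layer, and the token ids are invented.

```python
import torch

SOS, EOS = 0, 1                      # made-up special token ids
# Hypothetical stand-in for decoder + projection: row i holds the scores of
# every possible next token given that the previous token was i.
next_token_scores = torch.tensor([
    [0.0, 0.1, 2.0, 0.3],   # after SOS (0) the best next token is 2
    [0.0, 0.0, 0.0, 0.0],   # after EOS (1): never used
    [0.1, 0.2, 0.3, 1.5],   # after 2 the best next token is 3
    [0.2, 3.0, 0.1, 0.4],   # after 3 the best next token is EOS (1)
])


def toy_greedy_decode(max_len: int) -> torch.Tensor:
    decoded = torch.tensor([[SOS]])                    # (1, 1)
    while decoded.size(1) < max_len:
        scores = next_token_scores[decoded[:, -1]]     # (1, vocab)
        _, next_word = torch.max(scores, dim=1)        # greedy: take the highest score
        decoded = torch.cat([decoded, next_word.unsqueeze(0)], dim=1)
        if next_word.item() == EOS:
            break
    return decoded.squeeze(0)


full = toy_greedy_decode(max_len=10)      # stops when EOS is produced
clipped = toy_greedy_decode(max_len=2)    # stops at the length limit
```

Unlike this toy table, the real model conditions on the whole decoded prefix, which is why `greedy_decode` feeds the full `decoder_input` back into the decoder at every step.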



In [None]:
def get_model(config: Config, vocab_src_len: int, vocab_tgt_len: int) -> Transformer:
    """Builds the transformer model

    Args:
        config (Config): The configuration.
        vocab_src_len (int): The length of the source vocabulary.
        vocab_tgt_len (int): The length of the target vocabulary.

    Returns:
        Transformer: The transformer model.
    """
    model = build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        config.seq_len,
        config.seq_len,
        config.d_model,
        config.num_layers,
        config.num_heads,
        config.dropout_constant,
        config.d_ff,
    )
    return model


In [None]:
def run_validation(
    model: Transformer,
    validation_ds: DataLoader,
    tokenizer_src: Tokenizer,
    tokenizer_tgt: Tokenizer,
    max_len: int,
    device: torch.device,
    print_msg,
    global_step: int,
    writer,
    num_examples: int = 2
):
    """
    Run Validation:
    - Runs the model on the validation dataset and prints the source, target and predicted sentences.
    - If a writer is given, logs the character error rate, word error rate and BLEU score.
    
    Parameters:
    - model: The transformer model
    - validation_ds: The validation dataloader (the batch size must be 1)
    - tokenizer_src: The tokenizer for the source language
    - tokenizer_tgt: The tokenizer for the target language
    - max_len: The maximum length of the output sentence
    - device: The device to run the model on
    - print_msg: A function to print messages
    - global_step: The global step, used as the step value when logging metrics
    - writer: The summary writer used to log the metrics (metrics are skipped if it is None)
    - num_examples: The number of examples to print
    
    Shapes:
    - encoder_input: (1, sequence_length)
    - encoder_mask: (1, 1, 1, sequence_length)
    
    Returns:
    - None
    """
    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    try:
        # get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except Exception:
        # If we can't get the console width, use 80 as default
        console_width = 80
    
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch['encoder_input'].to(device)  # (b, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(0)==1, "Batch size must be 1 for validation."

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-'*console_width)
            print_msg(f"{'SOURCE: ':>12}{source_text}")
            print_msg(f"{'TARGET: ':>12}{target_text}")
            print_msg(f"{'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break
    
    if writer:
        # Compute the char error rate
        metric = torchmetrics.CharErrorRate()
        cer = metric(predicted, expected)
        writer.add_scalar('validation cer', cer, global_step)
        writer.flush()

        # Compute the word error rate
        metric = torchmetrics.WordErrorRate()
        wer = metric(predicted, expected)
        writer.add_scalar('validation wer', wer, global_step)
        writer.flush()

        # Compute the BLEU metric
        metric = torchmetrics.BLEUScore()
        bleu = metric(predicted, expected)
        writer.add_scalar('validation BLEU', bleu, global_step)
        writer.flush()


In [None]:
def train_model(config: Config):
    """Trains the transformer model.

    Args:
        config (Config): The configuration.
    
    Steps:
        - Get the dataset
        - Get the model
        - Define the optimizer
        - Define the loss function
        - For each epoch:
            - For each batch:
                - Run the tensors through the encoder, decoder and the projection layer
                - Compare the output with the label
                - Compute the loss using a simple cross entropy
                - Backpropagate the loss
                - Update the weights
            - Run validation at the end of every epoch
            - Save the model at the end of every epoch
    """
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    if (device == "cuda"):
        # `device` is still a plain string here, so look up the current CUDA device index explicitly
        cuda_index = torch.cuda.current_device()
        print(f"Device name: {torch.cuda.get_device_name(cuda_index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(cuda_index).total_memory / 1024 ** 3:.2f} GB")
    elif (device == "mps"):
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
    
    device = torch.device(device)

    # Make sure the weights folder exists
    Path(f"{config.datasource}_{config.model_folder}").mkdir(parents=True, exist_ok=True)

    # check pre-existing saved data and either load or continue to getting the dataset
    datafolder_path = Path(config.data_folder)
    filenames = ["train_dataloader.pkl", "val_dataloader.pkl", "tokenizer_src.pkl", "tokenizer_tgt.pkl"]
    file_existences = [os.path.exists(f"{datafolder_path}/{filename}") for filename in filenames]
    if all(file_existences):
        train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = load_pkl_files(datafolder_path, filenames)
        print("Pickle files loaded.")
    else:
        train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
        save_pkl_files(datafolder_path, filenames, [train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt])
    
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config.experiment_name)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config.preload
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    if model_filename:
        print(f"Preloading model {model_filename}")
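        # map_location lets a checkpoint saved on one device be loaded on another (e.g. GPU -> CPU)
        state = torch.load(model_filename, map_location=device)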
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print("No model to preload, starting from scratch.")
    
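    # The labels are target-language token ids, so the padding id to ignore comes from the target tokenizer
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)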

    for epoch in range(initial_epoch, config.num_epochs):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")

        for batch in batch_iterator:
            encoder_input = batch['encoder_input'].to(device)  # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask)  # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (B, seq_len, d_model)
            proj_output = model.project(decoder_output)  # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device)  # (B, seq_len)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss" : f"{loss.item():6.3f}"})

            # Log the loss
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1
        
        # Run validation at the end of every epoch
        run_validation(
            model,
            val_dataloader,
            tokenizer_src,
            tokenizer_tgt,
            config.seq_len,
            device,
            lambda msg: batch_iterator.write(msg),
            global_step,
            writer
        )

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "global_step": global_step
        }, model_filename)



# Main

## Training

In [None]:
config = Config()
warnings.filterwarnings("ignore")
train_model(config)

## Validation

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
datafolder_path = Path(config.data_folder)
filenames = ["train_dataloader.pkl", "val_dataloader.pkl", "tokenizer_src.pkl", "tokenizer_tgt.pkl"]
file_existences = [os.path.exists(f"{datafolder_path}/{filename}") for filename in filenames]
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = load_pkl_files(datafolder_path, filenames)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

# Load the pretrained weights
model_filename = latest_weights_file_path(config)
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])


In [None]:
run_validation(
    model, 
    val_dataloader, 
    tokenizer_src, 
    tokenizer_tgt, 
    config.seq_len, 
    device, 
    lambda msg: print(msg), 
    0, 
    None, 
    num_examples=5
)

## Inference



![Inference Model Time Step 1](images/InferenceModelTimeStep1.png)

![Inference Model Time Step 4](images/InferenceModelTimeStep4.png)

- At every step we selected the word with the maximum softmax value. This strategy is called *greedy* decoding and usually does not perform very well.
- A better strategy is to keep, at each step, the top B most probable partial sequences: expand each of them with every possible next word, score the resulting candidates, and again keep only the top B. This is the *beam search* strategy and it generally performs better.
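
The difference between the two strategies can be sketched in a few lines. This is a minimal illustration, not the decoding code used in this notebook: `next_log_probs` is a hypothetical callable standing in for the decoder plus projection layer, and `toy_next_log_probs` is a made-up three-step model chosen so that the greedy choice is not the best overall sequence.

```python
import math
from typing import Callable, Dict, List, Tuple

Sequence = Tuple[str, ...]


def beam_search(
    next_log_probs: Callable[[Sequence], Dict[str, float]],
    beam_size: int,
    max_len: int,
    sos: str = "[SOS]",
    eos: str = "[EOS]",
) -> List[Tuple[Sequence, float]]:
    """Keep the `beam_size` highest-scoring partial sequences at every step."""
    beams: List[Tuple[Sequence, float]] = [((sos,), 0.0)]
    for _ in range(max_len):
        candidates: List[Tuple[Sequence, float]] = []
        for seq, score in beams:
            if seq[-1] == eos:
                # Finished hypotheses are carried over unchanged
                candidates.append((seq, score))
                continue
            for token, logp in next_log_probs(seq).items():
                candidates.append((seq + (token,), score + logp))
        # Keep only the top `beam_size` candidates by total log-probability
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(seq[-1] == eos for seq, _ in beams):
            break
    return beams


def toy_next_log_probs(prefix: Sequence) -> Dict[str, float]:
    """Hypothetical toy model: the locally best first token leads to a worse sentence."""
    table = {
        "[SOS]": {"a": 0.6, "b": 0.4},
        "a": {"x": 0.5, "y": 0.5},
        "b": {"z": 1.0},
    }
    probs = table.get(prefix[-1], {"[EOS]": 1.0})
    return {token: math.log(p) for token, p in probs.items()}


greedy_seq, greedy_score = beam_search(toy_next_log_probs, beam_size=1, max_len=5)[0]
beam_seq, beam_score = beam_search(toy_next_log_probs, beam_size=2, max_len=5)[0]
print(greedy_seq, math.exp(greedy_score))
print(beam_seq, math.exp(beam_score))
```

With `beam_size=1` the procedure reduces to greedy decoding. In the toy model a beam of two recovers a sequence whose total probability is higher than the greedy one.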




In [None]:
def translate(config: Config, sentence: str):
    # Define the device, tokenizers, and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer_src = Tokenizer.from_file(str(Path(config.tokenizer_file.format(config.lang_src))))
    tokenizer_tgt = Tokenizer.from_file(str(Path(config.tokenizer_file.format(config.lang_tgt))))
    model = build_transformer(
        tokenizer_src.get_vocab_size(),
        tokenizer_tgt.get_vocab_size(),
        config.seq_len,
        config.seq_len,
        config.d_model
    ).to(device)

    # Load the pretrained weights
    model_filename = latest_weights_file_path(config)
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # translate sentence
    seq_len = config.seq_len
    model.eval()
    with torch.no_grad():
        # Precompute the encoder output and reuse it for every generation step
        source = tokenizer_src.encode(sentence)
        source = torch.cat(
            [
                torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
                torch.tensor(source.ids, dtype=torch.int64),
                torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
                torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (seq_len - len(source.ids) - 2), dtype=torch.int64)
            ], dim=0
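        ).unsqueeze(0).to(device)  # add the batch dimension used during training: (1, seq_len)
        source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(1).unsqueeze(1).int()  # (1, 1, 1, seq_len)
        encoder_output = model.encode(source, source_mask)  # (1, seq_len, d_model)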

        # Initialize the decoder input with the sos token
        decoder_input = torch.empty(1, 1).fill_(tokenizer_tgt.token_to_id('[SOS]')).type_as(source).to(device)

        # Generate the translation word by word
        while decoder_input.size(1) < seq_len:
            # build mask for target and calculate output
            # causal mask: 1 where attention is allowed (current and earlier positions), 0 for future positions
            decoder_mask = torch.tril(
                torch.ones((1, decoder_input.size(1), decoder_input.size(1)))
            ).type(torch.int).type_as(source_mask).to(device)
            out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

            # project next token
            prob = model.project(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            decoder_input = torch.cat(
                [
                    decoder_input,
                    torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)
                ], dim=1
            )

            # print the translated word
            print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')

            # break if we predict the end of sentence token
            if next_word == tokenizer_tgt.token_to_id('[EOS]'):
                break
    
    # convert ids to tokens
    return tokenizer_tgt.decode(decoder_input[0].tolist())


In [None]:
translate(config, "Good morning transformer.")

# Transformer Alternatives

Due to the success of Transformers, a variety of variants (a.k.a. X-formers) have been proposed over the past few years. These X-formers improve the vanilla Transformer from different perspectives.

1. *Model Efficiency*. A key challenge of applying the Transformer is its inefficiency at processing long sequences, mainly due to the computation and memory complexity of the self-attention module. The improvement methods include lightweight attention (e.g., sparse attention variants) and divide-and-conquer methods (e.g., recurrent and hierarchical mechanisms).
2. *Model Generalization*. Since the transformer is a flexible architecture and makes few assumptions on the structural bias of input data, it is hard to train on small-scale data. The improvement methods include introducing structural bias or regularization, pre-training on large-scale unlabeled data, etc.
3. *Model Adaptation*. This line of work aims to adapt the Transformer to specific downstream tasks and applications.

**Model usage**
Generally, the Transformer architecture can be used in three different ways:
- *Encoder–Decoder.* The full Transformer architecture is used. This is typically used in sequence-to-sequence modeling (e.g., neural machine translation).
- *Encoder only*. Only the encoder is used and the outputs of the encoder are utilized as a representation for the input sequence. This is often used for Natural Language Understanding (NLU) tasks (e.g., text classification and sequence labeling).
- *Decoder only*. Only the decoder is used, where the encoder-decoder cross-attention module is also removed. This is typically used for sequence generation (e.g., language modeling).

**Taxonomy of Transformers**

![Taxonomy of Transformers](images/TaxonomyofTransformers.jpg)

**Attention**
The improvements on attention mechanism can be divided into several directions:
1. *Sparse Attention*. This line of work introduces sparsity bias into the attention mechanism, leading to reduced complexity.

    ![Sparse Attention Variants](images/SparseAttentionVariants.jpg)

2. *Linearized Attention*. This line of work disentangles the attention matrix with kernel feature maps. The attention is then computed in reversed order to achieve linear complexity.

    ![Linearized Attention](images/LinearizedAttention.jpg)

3. *Prototype and Memory Compression*. This class of methods reduces the number of queries or key–value memory pairs to reduce the size of the attention matrix.

    ![Prototype and Memory Compression](images/PrototypeandMemoryCompression.jpg)

4. *Low-rank Self-Attention*. This line of work captures the low-rank property of self-attention.

5. *Attention with Prior*. This line of research explores supplementing or substituting standard attention with prior attention distributions.

6. *Improved Multi-Head Mechanism*. This line of studies explores alternative multi-head mechanisms.
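
To make the "reversed order" in linearized attention concrete, here is a minimal sketch in plain Python. It assumes the feature map `elu(x) + 1`, which is one common choice and is not specified in the text above; the function names and the random toy inputs are illustrative only. Both functions compute the same result, but the second never builds the $L \times L$ matrix.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def feature_map(x: float) -> float:
    """elu(x) + 1: a positive feature map (an assumed choice for this sketch)."""
    return x + 1.0 if x > 0 else math.exp(x)


def quadratic_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Computes (phi(Q) phi(K)^T) V: builds all L x L similarities first."""
    fq = [[feature_map(v) for v in row] for row in Q]
    fk = [[feature_map(v) for v in row] for row in K]
    out = []
    for q in fq:
        sims = [sum(a * b for a, b in zip(q, k)) for k in fk]
        norm = sum(sims)
        out.append([sum(s * v[c] for s, v in zip(sims, V)) / norm for c in range(len(V[0]))])
    return out


def linear_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Computes phi(Q) (phi(K)^T V): the d x d_v summary is built once, so cost is linear in L."""
    fq = [[feature_map(v) for v in row] for row in Q]
    fk = [[feature_map(v) for v in row] for row in K]
    d, d_v, length = len(K[0]), len(V[0]), len(K)
    kv = [[sum(fk[j][a] * V[j][c] for j in range(length)) for c in range(d_v)] for a in range(d)]
    z = [sum(fk[j][a] for j in range(length)) for a in range(d)]  # normalizer summary
    out = []
    for q in fq:
        norm = sum(q[a] * z[a] for a in range(d))
        out.append([sum(q[a] * kv[a][c] for a in range(d)) / norm for c in range(d_v)])
    return out


rng = random.Random(0)
Q = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]
K = [[rng.uniform(-1, 1) for _ in range(3)] for _ in range(5)]
V = [[rng.uniform(-1, 1) for _ in range(2)] for _ in range(5)]
full = quadratic_attention(Q, K, V)
fast = linear_attention(Q, K, V)
max_diff = max(abs(a - b) for ra, rb in zip(full, fast) for a, b in zip(ra, rb))
print(max_diff)  # difference is at floating-point rounding level
```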



## Positional Encoding

### Learned Positional Encoding

Learned positional encoding assigns each element a learned column vector that encodes its absolute position [(Gehring, et al. 2017)](https://arxiv.org/abs/1705.03122); furthermore, this encoding can be learned differently per layer [(Al-Rfou et al. 2018)](https://arxiv.org/abs/1808.04444).

### Relative Position Encoding

[Shaw et al. (2018)](https://arxiv.org/abs/1803.02155) incorporated relative positional information into $W^k$ and $W^v$. The relative position is clipped to a maximum absolute value of $k$, and this clipping enables the model to generalize to unseen sequence lengths. Therefore, $2k + 1$ unique edge labels are considered; let us denote $P^k, P^v \in \mathbb{R}^{2k + 1}$ as the learnable relative position representations.
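
The clipping step is easy to check by hand. The sketch below maps every (query position, key position) pair to one of the $2k + 1$ labels; the function names are hypothetical, and the resulting index would be used to look up rows of the learnable tables $P^k$ and $P^v$.

```python
from typing import List


def relative_position_index(i: int, j: int, k: int) -> int:
    """Clip the offset j - i to [-k, k] and shift it to a table index in [0, 2k]."""
    return max(-k, min(k, j - i)) + k


def relative_index_matrix(seq_len: int, k: int) -> List[List[int]]:
    """Index matrix for all (i, j) pairs of a sequence of length seq_len."""
    return [[relative_position_index(i, j, k) for j in range(seq_len)] for i in range(seq_len)]


index_matrix = relative_index_matrix(seq_len=5, k=2)
for row in index_matrix:
    print(row)
num_labels = len({idx for row in index_matrix for idx in row})
```

Because offsets beyond $\pm k$ share a label, a longer sequence reuses the same table rows instead of requiring new parameters.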

**Transformer-XL** [(Dai et al., 2019)](https://arxiv.org/abs/1901.02860) proposed a type of relative positional encoding based on a reparametrization of the dot product of keys and queries. To keep positional information flowing coherently across segments, Transformer-XL encodes the relative position instead, since knowing the position offset $i - j$ between a key vector $k_{\tau, j}$ and its query $q_{\tau, i}$ can be sufficient for making good predictions.

### Rotary Position Embedding

Rotary position embedding [(RoPE; Su et al. 2021)](https://arxiv.org/abs/2104.09864) encodes the absolute position with a rotation matrix and multiplies the query and key matrices of every attention layer with it to inject relative positional information at every layer.

When encoding relative positional information into the inner product of the i-th key and the j-th query, we would like to formulate the function so that the inner product depends only on the relative position $i - j$. Rotary Position Embedding (RoPE) makes use of the rotation operation in Euclidean space and frames the relative position embedding as simply rotating the feature matrix by an angle proportional to its position index.
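
The relative-position property can be verified numerically. The sketch below rotates consecutive pairs of features by an angle proportional to the position; the frequency schedule `base ** (-2 * pair / d)` with `base = 10000` follows the usual RoPE convention and is an assumption here, as are the function names and toy vectors.

```python
import math
from typing import List


def rope(x: List[float], position: int, base: float = 10000.0) -> List[float]:
    """Rotate each consecutive feature pair of x by position * theta_pair."""
    d = len(x)
    assert d % 2 == 0, "feature dimension must be even"
    out: List[float] = []
    for pair in range(d // 2):
        angle = position * base ** (-2 * pair / d)
        a, b = x[2 * pair], x[2 * pair + 1]
        out.append(a * math.cos(angle) - b * math.sin(angle))
        out.append(a * math.sin(angle) + b * math.cos(angle))
    return out


def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Same offset (3) at different absolute positions gives the same attention score
score_near = dot(rope(q, 5), rope(k, 2))
score_far = dot(rope(q, 13), rope(k, 10))
print(score_near, score_far)
```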



## Longer Context

The length of an input sequence for transformer models at inference time is upper-bounded by the context length used for training. Naively increasing context length leads to high consumption in both time ($O(L^2d)$) and memory ($O(L^2)$) and may not be supported due to hardware constraints.

### Context Memory

The vanilla Transformer has a fixed and limited attention span. The model can only attend to other elements in the same segment during each update step, and no information can flow across separated fixed-length segments. This context segmentation causes several issues:
- The model cannot capture very long term dependencies.
- It is hard to predict the first few tokens in each segment given no or thin context.
- The evaluation is expensive. Whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens.

**Transformer-XL** ([Dai et al., 2019](https://arxiv.org/abs/1901.02860); “XL” means “extra long”) modifies the architecture to reuse hidden states between segments with an additional memory. The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments.

**Compressive Transformer** ([Rae et al. 2019](https://arxiv.org/abs/1911.05507)) extends Transformer-XL by compressing past memories to support longer sequences. It explicitly adds memory slots of size $m_m$ per layer for storing past activations of this layer to preserve long context. When some past activations become old enough, they are compressed and saved in an additional compressed memory of size $m_{cm}$ per layer.
- Compressive transformer has two additional training losses:
    - Auto-encoding loss (lossless compression objective) measures how well we can reconstruct the original memories from compressed memories
    - Attention-reconstruction loss (lossy objective) reconstructs content-based attention over memory vs compressed memory and minimizes the difference

### Non-Differentiable External Memory

**kNN-LM** ([Khandelwal et al. 2020](https://arxiv.org/abs/1911.00172)) enhances a pretrained LM with a separate kNN model by linearly interpolating the next token probabilities predicted by both models. The kNN model is built upon an external key-value store which can hold any large pre-training dataset or a new out-of-distribution (OOD) dataset. This datastore is preprocessed to save a large number of pairs (LM embedding representation of context, next token), and the nearest neighbor retrieval happens in the LM embedding space. Because the datastore can be gigantic, we need to rely on libraries for fast dense vector search such as [FAISS](https://github.com/facebookresearch/faiss) or [ScaNN](https://github.com/google-research/google-research/tree/master/scann). The indexing process only happens once and parallelism is easy to implement at inference time.
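
A minimal sketch of the interpolation step, under stated assumptions: the datastore is a tiny in-memory list searched by brute force (a real system would use FAISS or ScaNN), neighbors are weighted by `exp(-squared distance)`, and the names `knn_distribution`, `interpolate` and the mixing weight `lam` are illustrative rather than taken from the paper's code.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple

Datastore = List[Tuple[List[float], str]]  # (context embedding, next token)


def sq_dist(u: List[float], v: List[float]) -> float:
    return sum((a - b) ** 2 for a, b in zip(u, v))


def knn_distribution(query: List[float], datastore: Datastore, k: int) -> Dict[str, float]:
    """Next-token distribution from the k nearest stored contexts (brute-force search)."""
    nearest = sorted(((sq_dist(query, key), token) for key, token in datastore), key=lambda t: t[0])[:k]
    weights: Dict[str, float] = defaultdict(float)
    for dist, token in nearest:
        weights[token] += math.exp(-dist)  # closer neighbors count more
    total = sum(weights.values())
    return {token: w / total for token, w in weights.items()}


def interpolate(p_lm: Dict[str, float], p_knn: Dict[str, float], lam: float) -> Dict[str, float]:
    """p(y) = lam * p_kNN(y) + (1 - lam) * p_LM(y)."""
    vocab = set(p_lm) | set(p_knn)
    return {tok: lam * p_knn.get(tok, 0.0) + (1 - lam) * p_lm.get(tok, 0.0) for tok in vocab}


datastore: Datastore = [([0.0, 0.0], "cat"), ([0.0, 1.0], "cat"), ([5.0, 5.0], "dog"), ([6.0, 5.0], "dog")]
p_knn = knn_distribution([0.1, 0.2], datastore, k=2)
p_final = interpolate({"cat": 0.3, "dog": 0.7}, p_knn, lam=0.25)
print(p_knn, p_final)
```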

**SPALM** (Adaptive semiparametric language models; [Yogatama et al. 2021](https://arxiv.org/abs/2102.02557)) incorporates both (1) Transformer-XL style memory for hidden states from external context as short-term memory and (2) a kNN-LM style key-value store as long-term memory. During training, the key representations in the long-term memory stay constant, produced by a pretrained LM, but the value encoder, i.e. the word embedding matrix, gets updated.

**Memorizing Transformer** [(Wu et al. 2022)](https://arxiv.org/abs/2203.08913) adds a kNN-augmented attention layer near the top of the stack of a decoder-only Transformer. This special layer maintains a Transformer-XL style FIFO cache of past key-value pairs. The same QKV values are used for both local attention and kNN mechanisms. The kNN lookup returns the top-k (key, value) pairs for each query in the input sequence, and they are then processed through the self-attention stack to compute a weighted average of retrieved values. The two types of attention are combined with a learnable per-head gating parameter. To prevent large distributional shifts in value magnitude, both keys and values in the cache are normalized.

### Distance-Enhanced Attention Scores

**Distance Aware Transformer** (DA-Transformer; [Wu, et al. 2021](https://arxiv.org/abs/2010.06925)) and **Attention with Linear Biases** (ALiBi; [Press et al. 2022](https://arxiv.org/abs/2108.12409)) are motivated by similar ideas: to encourage the model to extrapolate to contexts longer than those it was trained on, we can explicitly attach positional information to every pairwise attention score based on the distance between the key and query tokens.

Note that the default positional encoding in the vanilla Transformer only adds positional information to the input sequence, while later improved encoding mechanisms, such as rotary position embedding, alter the attention scores of every layer and take on a form very similar to distance-enhanced attention scores.

DA-Transformer (Wu, et al. 2021) multiplies attention scores at each layer by a learnable bias that is formulated as a function of the distance between key and query. Different attention heads use different parameters to distinguish diverse preferences to short-term vs long-term context.
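
As a concrete illustration of a distance-based term, the sketch below builds the additive penalty used by ALiBi, which (unlike DA-Transformer's multiplicative rescaling) adds a fixed, non-learned bias to the pre-softmax scores. The geometric slope schedule `2 ** (-8 * h / num_heads)` is the one described in the ALiBi paper for a power-of-two number of heads; the function names are hypothetical, and future positions are set to `-inf` here to fold in the causal mask.

```python
from typing import List


def alibi_slopes(num_heads: int) -> List[float]:
    """One slope per head: 2^(-8/n), 2^(-16/n), ... (assumes num_heads is a power of two)."""
    return [2 ** (-8 * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_bias(seq_len: int, slope: float) -> List[List[float]]:
    """bias[i][j] = -slope * (i - j) for keys at or before the query, -inf for future keys."""
    return [
        [-slope * (i - j) if j <= i else float("-inf") for j in range(seq_len)]
        for i in range(seq_len)
    ]


slopes = alibi_slopes(8)
bias = alibi_bias(seq_len=4, slope=slopes[0])
for row in bias:
    print(row)
```

Heads with a steep slope focus on nearby tokens, while heads with a shallow slope can still look far back, which mirrors the short-term vs long-term preference mentioned above.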

### Make it Recurrent

**Universal Transformer** [(Dehghani, et al. 2019)](https://arxiv.org/abs/1807.03819) combines self-attention in the Transformer with the recurrent mechanism in RNNs, aiming to benefit from both the long-term global receptive field of the Transformer and the learned inductive biases of RNNs. Rather than going through a fixed number of layers, Universal Transformer dynamically adjusts the number of steps using adaptive computation time. If we fix the number of steps, a Universal Transformer is equivalent to a multi-layer Transformer with shared parameters across layers.

On a high level, the universal transformer can be viewed as a recurrent function for learning the hidden state representation per token. The recurrent function evolves in parallel across token positions and the information between positions is shared through self-attention.




## Adaptive Modeling

Adaptive modeling refers to a mechanism that can adjust the amount of computation according to different inputs. For example, some tokens may only need local information and thus demand a shorter attention span, while other tokens are relatively easy to predict and do not need to be processed through the entire attention stack.

### Adaptive Attention Span

If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support a longer maximum context size in the model. This is the motivation for Adaptive Attention Span. [Sukhbaatar et al. (2019)](https://arxiv.org/abs/1905.07799) proposed a self-attention mechanism that seeks an optimal attention span. They hypothesized that different attention heads might assign scores differently within the same context window, and thus the optimal span would be trained separately per head.
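
A sketch of how a span can be made learnable: the paper uses a soft mask $m_z(x) = \min(\max((R + z - x)/R, 0), 1)$, where $x$ is the distance to the attended token, $z$ is the learned span of a head and $R$ controls the softness of the ramp. The helper names below are hypothetical, and `z` is passed in as a plain number rather than trained.

```python
from typing import List


def soft_span_mask(distance: int, z: float, ramp: float) -> float:
    """1 inside the span, linearly decaying to 0 over `ramp` positions beyond it."""
    return min(max((ramp + z - distance) / ramp, 0.0), 1.0)


def apply_span(weights: List[float], z: float, ramp: float) -> List[float]:
    """weights[d] is the attention weight on the token d steps back; mask, then re-normalize."""
    masked = [w * soft_span_mask(d, z, ramp) for d, w in enumerate(weights)]
    total = sum(masked)
    return [m / total for m in masked]


mask_values = [soft_span_mask(d, z=4, ramp=2) for d in range(8)]
narrowed = apply_span([0.25, 0.25, 0.25, 0.25], z=1, ramp=2)
print(mask_values)
print(narrowed)
```

Because the mask is piecewise linear in `z`, a head can shrink or grow its span by gradient descent, and positions whose mask is zero need not be computed or stored.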

### Depth-Adaptive Transformer

At inference time, it is natural to assume that some tokens are easier to predict and thus do not require as much computation as others. Therefore we may process their predictions through only a limited number of layers to achieve a good balance between speed and performance. Both **Depth-Adaptive Transformer** [(Elbayad et al. 2020)](https://arxiv.org/abs/1910.10073) and **Confident Adaptive Language Model** (CALM; [Schuster et al. 2022](https://arxiv.org/abs/2207.07061)) are motivated by this idea and learn to predict the optimal number of layers needed for different input tokens.
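
The general idea of exiting early can be sketched with a simple confidence threshold. This is only an illustration of the principle, not the halting mechanism of either paper: the layers, the classifier and the rule "stop once the top softmax probability reaches `threshold`" are all assumptions made for the example.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def softmax(logits: Vector) -> Vector:
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def early_exit_predict(
    hidden: Vector,
    layers: Sequence[Callable[[Vector], Vector]],
    classifier: Callable[[Vector], Vector],
    threshold: float,
) -> Tuple[int, int]:
    """Return (predicted class, number of layers actually used). Requires at least one layer."""
    probs: Vector = []
    for depth, layer in enumerate(layers, start=1):
        hidden = layer(hidden)
        probs = softmax(classifier(hidden))
        if max(probs) >= threshold:
            return probs.index(max(probs)), depth  # confident enough: skip remaining layers
    return probs.index(max(probs)), len(layers)


# Toy setup: every "layer" sharpens the representation, the classifier reads it as logits
toy_layers = [lambda h: [2 * v for v in h]] * 6
identity_classifier = lambda h: h

easy = early_exit_predict([2.0, 0.0], toy_layers, identity_classifier, threshold=0.9)
hard = early_exit_predict([0.2, 0.0], toy_layers, identity_classifier, threshold=0.9)
print(easy, hard)
```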


## Efficient Attention

The computation and memory cost of the vanilla Transformer grows quadratically with sequence length, so it is hard to apply to very long sequences. Many efficiency improvements for the Transformer architecture target the self-attention module, making it cheaper, smaller or faster to run. See the survey paper on Efficient Transformers [(Tay et al. 2020)](https://arxiv.org/abs/2009.06732).

### Sparse Attention Patterns

#### Fixed Local Context

A simple alteration to make self-attention less expensive is to restrict the attention span of each token to a local context only, so that the cost of self-attention grows linearly with the sequence length.

The idea was introduced by **Image Transformer** [(Parmar, et al 2018)](https://arxiv.org/abs/1802.05751), which formulates image generation as sequence modeling using an encoder-decoder transformer architecture:
- The encoder generates a contextualized, per-pixel-channel representation of the source image;
- Then the decoder autoregressively generates an output image, one channel per pixel at each time step.
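
The linear growth from a fixed local context can be seen by counting the allowed query-key pairs. A minimal sketch, assuming a causal window in which each token attends to itself and the `window - 1` tokens before it (the function name is hypothetical):

```python
from typing import List


def local_causal_mask(seq_len: int, window: int) -> List[List[int]]:
    """mask[i][j] = 1 if query i may attend to key j, i.e. 0 <= i - j < window."""
    return [[1 if 0 <= i - j < window else 0 for j in range(seq_len)] for i in range(seq_len)]


def allowed_pairs(mask: List[List[int]]) -> int:
    return sum(sum(row) for row in mask)


small = allowed_pairs(local_causal_mask(seq_len=6, window=3))
full_causal = allowed_pairs(local_causal_mask(seq_len=6, window=6))
long_seq = allowed_pairs(local_causal_mask(seq_len=600, window=3))
print(small, full_causal, long_seq)
```

With a fixed window the number of scored pairs is bounded by `seq_len * window`, whereas the full causal mask grows with the square of the sequence length.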

#### Strided Context

**Sparse Transformer** [(Child et al., 2019)](https://arxiv.org/abs/1904.10509) introduced factorized self-attention, through sparse matrix factorization. It also proposed a set of changes so as to train the Transformer up to hundreds of layers, including gradient checkpointing, recomputing attention & FF layers during the backward pass, mixed precision training, efficient block-sparse implementation, etc. Please check the paper for more details or the blog post ([techniques for scaling up model training](https://lilianweng.github.io/posts/2021-09-25-train-large/)).

**Blockwise Attention** [(Qiu et al. 2019)](https://arxiv.org/abs/1911.02972) introduces a sparse block matrix to only allow each token to attend to a small set of other tokens. 

#### Combination of Local and Global Context

**ETC** (Extended Transformer Construction; [Ainslie et al. 2020](https://arxiv.org/abs/2004.08483)), **Longformer** [(Beltagy et al. 2020)](https://arxiv.org/abs/2004.05150) and **Big Bird** [(Zaheer et al. 2020)](https://arxiv.org/abs/2007.14062) models combine both local and global context when building an attention matrix. All these models can be initialized from existing pretrained models.

One more update in ETC is to incorporate a CPC (contrastive predictive coding) task using NCE loss into the pretraining stage, besides the MLM task: The representation of one sentence should be similar to the representation of context around it when this sentence is masked.

Attention pattern in Longformer contains three components:
1. Local attention: Similar to ETC, local attention is controlled by a sliding window of fixed size $w$.
2. Global attention of preselected tokens: Longformer assigns a global attention span to a few preselected tokens (e.g. the [CLS] token), that is, they attend to all other tokens in the input sequence.
3. Dilated attention: A dilated sliding window of fixed size $r$ with gaps of dilation size $d$, similar to Sparse Transformer.

Big Bird is quite similar to Longformer, equipped with both local attention and a few preselected tokens with global attention span, but Big Bird replaces dilated attention with a new mechanism where all tokens attend to a set of random tokens. The design is motivated by the fact that an attention pattern can be viewed as a directed graph, and a random graph has the property that information is able to flow rapidly between any pair of nodes.
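
The combined patterns can be sketched as a boolean mask. This is an illustrative construction rather than either model's implementation: `sparse_attention_mask` and its parameters are hypothetical names, the window is taken symmetric around each token, and dilation is left out for brevity.

```python
import random
from typing import List, Sequence


def sparse_attention_mask(
    seq_len: int,
    window: int,
    global_tokens: Sequence[int],
    num_random: int = 0,
    seed: int = 0,
) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j."""
    rng = random.Random(seed)
    half = window // 2
    mask = [[False] * seq_len for _ in range(seq_len)]
    for i in range(seq_len):
        # Local sliding window around position i
        for j in range(max(0, i - half), min(seq_len, i + half + 1)):
            mask[i][j] = True
        # Global tokens attend everywhere and are attended to by every token
        for g in global_tokens:
            mask[i][g] = True
            mask[g][i] = True
        # Big Bird style random connections (none when num_random is 0)
        for j in rng.sample(range(seq_len), num_random):
            mask[i][j] = True
    return mask


longformer_like = sparse_attention_mask(seq_len=8, window=3, global_tokens=[0])
bigbird_like = sparse_attention_mask(seq_len=8, window=3, global_tokens=[0], num_random=2)
for row in longformer_like:
    print("".join("x" if allowed else "." for allowed in row))
```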

### Content-based Attention

The improvements proposed by **Reformer** [(Kitaev, et al. 2020)](https://arxiv.org/abs/2001.04451) aim to solve the following pain points in vanilla Transformer:
- Quadratic time and memory complexity within self-attention module.
- Memory in a model with N layers is N-times larger than in a single-layer model because we need to store activations for back-propagation.
- The intermediate FF layers are often quite large.

Reformer proposed two main changes:
1. Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $O(L^2)$ to $O(L \log L)$.
2. Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of N times (i.e. proportional to the number of layers).

In order to find nearest neighbors quickly in high-dimensional space, Reformer incorporates Locality-Sensitive Hashing (LSH) into its attention mechanism.
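
A minimal sketch of the angular hashing scheme described in the Reformer paper: project the vector with a random matrix $R$ of shape $[d, b/2]$ and take $\arg\max$ over the concatenation $[xR; -xR]$, so that vectors pointing in similar directions tend to share a bucket. The factory function and its seed parameter are hypothetical conveniences.

```python
import random
from typing import Callable, List


def make_lsh_hash(dim: int, n_buckets: int, seed: int = 0) -> Callable[[List[float]], int]:
    """Return a function mapping a vector to one of n_buckets buckets."""
    assert n_buckets % 2 == 0, "n_buckets must be even"
    rng = random.Random(seed)
    half = n_buckets // 2
    projection = [[rng.gauss(0.0, 1.0) for _ in range(half)] for _ in range(dim)]

    def bucket(x: List[float]) -> int:
        proj = [sum(x[d] * projection[d][b] for d in range(dim)) for b in range(half)]
        scores = proj + [-p for p in proj]  # [xR ; -xR]
        return scores.index(max(scores))

    return bucket


bucket = make_lsh_hash(dim=4, n_buckets=8, seed=0)
x = [0.5, -1.0, 2.0, 0.25]
same_direction = [2 * v for v in x]
opposite_direction = [-v for v in x]
print(bucket(x), bucket(same_direction), bucket(opposite_direction))
```

Attention is then computed only among queries and keys that land in the same bucket, which is what avoids scoring all $L^2$ pairs.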

Another improvement by Reformer is to use reversible residual layers [(Gomez et al. 2017)](https://arxiv.org/abs/1707.04585). The motivation for reversible residual network is to design the architecture in a way that activations at any given layer can be recovered from the activations at the following layer, using only the model parameters. Hence, we can save memory by recomputing the activation during backprop rather than storing all the activations.
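
The recovery of activations can be shown with scalars. In a reversible block the input is split into two halves $(x_1, x_2)$ with $y_1 = x_1 + F(x_2)$ and $y_2 = x_2 + G(y_1)$; in Reformer $F$ is the attention layer and $G$ the feed-forward layer, while the toy functions below are placeholders chosen only for the demonstration.

```python
import math
from typing import Callable, Tuple

Fn = Callable[[float], float]


def rev_forward(x1: float, x2: float, F: Fn, G: Fn) -> Tuple[float, float]:
    y1 = x1 + F(x2)
    y2 = x2 + G(y1)
    return y1, y2


def rev_inverse(y1: float, y2: float, F: Fn, G: Fn) -> Tuple[float, float]:
    """Recover the block inputs from its outputs by undoing the two additions in reverse order."""
    x2 = y2 - G(y1)
    x1 = y1 - F(x2)
    return x1, x2


F = math.tanh                      # stand-in for the attention sub-layer
G = lambda v: 0.5 * v * v          # stand-in for the feed-forward sub-layer

y1, y2 = rev_forward(0.7, -1.3, F, G)
x1, x2 = rev_inverse(y1, y2, F, G)
print(x1, x2)
```

Neither `F` nor `G` has to be invertible; only the additive structure is inverted, at the price of recomputing `F` and `G` during the backward pass.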

**Routing Transformer** [(Roy et al. 2021)](https://arxiv.org/abs/2003.05997) is also built on content-based clustering of keys and queries. Instead of using a static hashing function like LSH, it utilizes online k-means clustering and combines it with local, temporal sparse attention to reduce the attention complexity from $O(L^2)$ to $O(L^{1.5})$.

### Low-Rank Attention

**Linformer** [(Wang et al. 2020)](https://arxiv.org/abs/2006.04768) approximates the full attention matrix with a low-rank matrix, reducing the time and space complexity to linear in the sequence length.

Additional techniques can be applied to further improve efficiency of Linformer:
- Parameter sharing between projection layers, such as head-wise, key-value and layer-wise (across all layers) sharing.
- Use different $k$ at different layers, as heads in higher layers tend to have a more skewed distribution (lower rank) and thus we can use smaller $k$ at higher layers.
- Use different types of projections; e.g. mean/max pooling, convolution layer with kernel and stride $L/k$.
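
The shapes involved in the low-rank approximation are easiest to see in code. The sketch below follows the Linformer formulation in which projection matrices $E, F \in \mathbb{R}^{k \times L}$ compress the keys and values along the sequence axis before attention; here the projections are random rather than learned, and all helper names are hypothetical.

```python
import math
import random
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(scores: Matrix) -> Matrix:
    out = []
    for row in scores:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def linformer_attention(Q: Matrix, K: Matrix, V: Matrix, E: Matrix, F: Matrix) -> Tuple[Matrix, Matrix]:
    """Q, K, V: (L, d); E, F: (k, L). The attention weights are (L, k) instead of (L, L)."""
    d = len(Q[0])
    k_proj = matmul(E, K)  # (k, d)
    v_proj = matmul(F, V)  # (k, d)
    scores = [[s / math.sqrt(d) for s in row] for row in matmul(Q, transpose(k_proj))]
    weights = softmax_rows(scores)
    return matmul(weights, v_proj), weights


rng = random.Random(0)
L, d, k = 8, 4, 2


def rand(rows: int, cols: int) -> Matrix:
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


output, weights = linformer_attention(rand(L, d), rand(L, d), rand(L, d), rand(k, L), rand(k, L))
print(len(weights), len(weights[0]), len(output), len(output[0]))
```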

**Random Feature Attention** (RFA; [Peng et al. 2021](https://arxiv.org/abs/2103.02143)) relies on random feature methods (Rahimi & Recht, 2007) to approximate the softmax operation in self-attention with low-rank feature maps in order to achieve linear time and space complexity. **Performer** [(Choromanski et al. 2021)](https://arxiv.org/abs/2009.14794) also adopts random feature attention, with improvements to the kernel construction that further reduce the kernel approximation error.


## Transformers for Reinforcement Learning

The self-attention mechanism avoids compressing the whole past into a fixed-size hidden state and does not suffer from vanishing or exploding gradients as much as RNNs. Reinforcement Learning tasks can benefit from these traits. However, it is quite difficult to train a Transformer even in supervised learning, let alone in the RL context. After all, it can be quite challenging to stabilize and train even an LSTM agent.

The **Gated Transformer-XL** (GTrXL; [Parisotto, et al. 2019](https://arxiv.org/abs/1910.06764)) is one attempt to use Transformer for RL. GTrXL succeeded in stabilizing training with two changes on top of Transformer-XL:
1. The layer normalization is only applied on the input stream in a residual module, but NOT on the shortcut stream. A key benefit to this reordering is to allow the original input to flow from the first to last layer.
2. The residual connection is replaced with a GRU-style (Gated Recurrent Unit; Chung et al., 2014) gating mechanism.

**Decision Transformer** (DT; [Chen et al. 2021](https://arxiv.org/abs/2106.01345)) formulates Reinforcement Learning problems as a process of conditional sequence modeling, outputting the optimal actions conditioned on the desired return, past states and actions. It therefore becomes straightforward to use the Transformer architecture. Decision Transformer is designed for offline (off-policy) RL, where the model only has access to a fixed collection of trajectories collected by other policies.
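
The conditioning on the desired return is implemented through returns-to-go: the sum of rewards from the current step to the end of the trajectory. The sketch below shows how a logged trajectory could be turned into the interleaved (return-to-go, state, action) sequence that the model is trained on; the tuple-based token representation and function names are illustrative assumptions.

```python
from typing import Any, List, Sequence, Tuple


def returns_to_go(rewards: Sequence[float]) -> List[float]:
    """rtg[t] = rewards[t] + rewards[t + 1] + ... (undiscounted)."""
    out: List[float] = []
    total = 0.0
    for r in reversed(rewards):
        total += r
        out.append(total)
    return out[::-1]


def build_sequence(states: Sequence[Any], actions: Sequence[Any], rewards: Sequence[float]) -> List[Tuple[str, Any]]:
    """Interleave tokens as (return-to-go, state, action) for every time step."""
    sequence: List[Tuple[str, Any]] = []
    for g, s, a in zip(returns_to_go(rewards), states, actions):
        sequence.extend([("return", g), ("state", s), ("action", a)])
    return sequence


rtg = returns_to_go([1, 0, 2])
tokens = build_sequence(states=["s0", "s1", "s2"], actions=["left", "right", "left"], rewards=[1, 0, 2])
print(rtg)
print(tokens[:3])
```

At test time the first return token is set to the desired total return and is decreased by each reward actually received.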
