## Transformer101

Vanilla Transformer implementation in PyTorch based on the paper [Attention Is All You Need, 2017](https://arxiv.org/pdf/1706.03762.pdf) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin.

> Scaled Dot-Product Attention | Multi-Head Attention | Absolute Positional Encodings | Learned Positional Encodings | Dropout | Layer Normalization | Residual Connection | Linear Layer | Position-Wise Feed-Forward Layer | GELU | Softmax | Transformer

### 1. Background

Sequence modeling and transduction tasks, such as language modeling and machine translation, were typically addressed with RNNs and CNNs. However, these architectures are limited by: (i) ``long training times``, because the sequential nature of RNNs prevents parallelization within a training example, which becomes critical for longer sequences, where memory constraints also limit batching across examples; and (ii) ``difficulty in learning dependencies between distant positions``: CNNs are much less sequential than RNNs, but the number of operations needed to relate two positions grows with the distance between them (linearly for models like ConvS2S and logarithmically for ByteNet).

### 2. Technical Approach

This paper introduces the Transformer, ``a stacked encoder-decoder architecture that relies on self-attention instead of recurrence and convolution to compute representations of its input and output``. Both the encoder and the decoder are stacks of six identical layers. Each encoder layer has two sub-layers: a multi-head self-attention sub-layer, which allows the model to focus on different parts of the input sequence, and a position-wise fully connected feed-forward sub-layer. Each decoder layer has the same two sub-layers plus a third one that performs multi-head attention over the output of the encoder stack.

At its core, the ``self-attention mechanism`` lets the model weigh the relationships between tokens at different positions, connecting any two positions with a constant number of operations, which makes long-range dependencies easier to learn. Additionally, by using ``multiple attention heads``, the model can attend to different aspects of the input simultaneously.
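For comparison, PyTorch ships a built-in `nn.Transformer` module. The sketch below instantiates it with the hyperparameters of the paper's base model (6 encoder and 6 decoder layers, $d_{model} = 512$, 8 heads, feed-forward size 2048, dropout 0.1) and runs it on random, already-embedded tensors; the batch size and sequence lengths are arbitrary illustrative values. It is only a reference point for shapes and configuration, not the from-scratch implementation developed in this notebook.

```python
import torch
import torch.nn as nn

# Base configuration reported in the paper
model = nn.Transformer(
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,
    dropout=0.1,
    batch_first=True,  # tensors are (batch, seq_len, d_model)
)

# Illustrative inputs: already-embedded source and target sequences
src = torch.rand(2, 10, 512)  # 2 source sequences of 10 tokens
tgt = torch.rand(2, 7, 512)   # 2 target sequences of 7 tokens

# Causal mask: each target position may only attend to itself and earlier positions
tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)  # torch.Size([2, 7, 512])
```

By default `nn.Transformer` applies layer normalization after each residual connection, as in the paper.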

In the paper's base model, the input and output tokens are converted to 512-dimensional embeddings, to which ``positional encodings`` are added so that the model can use information about the order of the sequence.
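A minimal sketch of this input pipeline, assuming a hypothetical vocabulary of 1,000 tokens and a 10-token input: token ids are embedded into 512 dimensions, multiplied by $\sqrt{d_{model}}$ as described in the paper, and summed with sinusoidal positional encodings (covered in more detail in section 3.4).

```python
import math
import torch
import torch.nn as nn

d_model = 512
vocab_size = 1000  # hypothetical vocabulary size, for illustration only
seq_len = 10

embedding = nn.Embedding(vocab_size, d_model)
tokens = torch.randint(0, vocab_size, (1, seq_len))  # (batch, seq_len) token ids

# Sinusoidal positional encodings, shape (seq_len, d_model)
pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))
pe = torch.zeros(seq_len, d_model)
pe[:, 0::2] = torch.sin(pos * div_term)
pe[:, 1::2] = torch.cos(pos * div_term)

# Scale the embeddings by sqrt(d_model), then add the encodings (broadcast over the batch)
x = embedding(tokens) * math.sqrt(d_model) + pe
print(x.shape)  # torch.Size([1, 10, 512])
```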

### 3. Implementation

In [2]:
import torch
import torch.nn as nn

#### 3.1 Self-Attention

The intuition behind ``self-attention`` is that, instead of using a fixed embedding for each token, the model computes each token's representation as a weighted average of the embeddings of all tokens in the sequence. This lets the representation of a word depend on its context. In practice, these weights (the attention weights) often reflect syntactic and contextual relationships within the sentence, which yields richer, context-aware representations.

The most common way to implement a self-attention layer is ``scaled dot-product attention``, $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, which involves:
1. ``Linear projection`` of each token embedding into three vectors: ``query (q)``, ``key (k)``, and ``value (v)``.
2. Computing ``scaled attention scores``: the similarity between each ``q`` and each ``k`` is measured with the ``dot product``. For large key dimensions $d_k$ these dot products grow large in magnitude and push the softmax into regions with extremely small gradients, so the scores are divided by $\sqrt{d_k}$. This scaling helps stabilize gradients during training.
3. Normalizing the ``attention scores`` into ``attention weights`` with the ``softmax`` function (so that the weights for each query are non-negative and sum to 1).
4. ``Updating the token embeddings`` by multiplying the attention weights by the value vectors, i.e. each output is a weighted sum of the value vectors.
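A minimal sketch of steps 2 to 4 on random tensors; the sizes, the seed, and the 10,000-sample variance check are illustrative choices, not part of the notebook. The last lines show why the $\sqrt{d_k}$ scaling is used: for independent, unit-variance components, the dot product $q \cdot k$ has variance close to $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1.

```python
import torch

torch.manual_seed(0)
batch, seq_len, d_k = 1, 4, 64  # illustrative sizes

q = torch.randn(batch, seq_len, d_k)
k = torch.randn(batch, seq_len, d_k)
v = torch.randn(batch, seq_len, d_k)

# step 2: dot-product similarity, scaled by sqrt(d_k) -> (batch, seq_len, seq_len)
scores = q @ k.transpose(-2, -1) / d_k ** 0.5
# step 3: softmax over the key dimension, so each row sums to 1
weights = torch.softmax(scores, dim=-1)
# step 4: each output row is a weighted sum of the value vectors -> (batch, seq_len, d_k)
output = weights @ v

print(weights.sum(dim=-1))  # all ones, up to floating-point error

# Variance of unscaled vs. scaled dot products of unit-variance vectors
qs, ks = torch.randn(10000, d_k), torch.randn(10000, d_k)
dots = (qs * ks).sum(dim=-1)
print(dots.var().item(), (dots / d_k ** 0.5).var().item())  # roughly 64 and 1
```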

> A causal (look-ahead) mask is used only in the decoder's self-attention. Its purpose is to prevent the decoder from accessing future tokens of the sequence it is generating. In practice, it is a binary mask that designates which positions may be attended to (non-zero weights) and which must be ignored (zero weights). In our function, the scores of the masked positions (the entries above the diagonal for a causal mask) are set to negative infinity, which guarantees that their attention weights become zero after the softmax, since $e^{-\infty} = 0$. This matches tasks such as translation, summarization, or text generation, where the output sequence is generated one element at a time and each prediction may depend only on the previously generated elements.
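The masking trick can be checked in isolation. The sketch below uses random stand-in scores for an assumed sequence length of 4, builds a lower-triangular (causal) mask with `torch.tril`, and applies the same `masked_fill` and softmax sequence used in `scaled_dot_product_attention`.

```python
import torch

seq_len = 4
scores = torch.randn(1, seq_len, seq_len)  # stand-in attention scores

# Lower-triangular mask: 1 = may attend, 0 = future position (blocked)
mask = torch.tril(torch.ones(seq_len, seq_len))

masked_scores = scores.masked_fill(mask == 0, float('-inf'))
weights = torch.softmax(masked_scores, dim=-1)

print(mask)
# Entries above the diagonal are exactly 0: row i only mixes positions 0..i
print(weights)
```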

In [3]:
class AttentionHead(nn.Module):
    """
    Represents a single attention head within a multi-head attention mechanism.
    
    Parameters:
    n_embd (int): The size of the input feature dimension.
    head_dim (int): The size of the output feature dimension for this attention head.
    """
    def __init__(self, n_embd, n_headd):
        super().__init__()
        # step_1: linear projections to query (q), key (k), and value (v) vectors
        self.q = nn.Linear(n_embd, n_headd)
        self.k = nn.Linear(n_embd, n_headd)
        self.v = nn.Linear(n_embd, n_headd)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        dim_k = torch.tensor(k.size(-1), dtype=torch.float32)
        # step_2: calculate similarity with the dot product, and scale attention scores
        attn_scores = torch.bmm(q, k.transpose(1, 2)) / torch.sqrt(dim_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        # step_3: normalize the attention scores with the softmax function
        attn_weights = torch.softmax(attn_scores, axis=-1)
        # step_4: update the token embeddings by multiplying attention weights by the value vector
        output = torch.bmm(attn_weights, v)
        return output

    def forward(self, hidden_state):
        attn_outputs = self.scaled_dot_product_attention(self.q(hidden_state),
                                                         self.k(hidden_state), 
                                                         self.v(hidden_state))
        return attn_outputs

#### 3.2 Multi-Headed Attention

In a standard attention mechanism, the ``softmax`` of a single head tends to concentrate on a specific aspect of similarity, potentially overlooking other relevant features in the input. By integrating multiple attention heads, the model gains the ability to simultaneously attend to various aspects of the input data.

The basic approach to implement Multi-Headed Attention comprises:

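1. Initialize the ``attention heads``. Each head projects the embeddings to a smaller head dimension, typically the embedding dimension divided by the number of heads. For example, BERT-base uses 12 attention heads and an embedding dimension of 768, resulting in a head dimension of 768 / 12 = 64.
2. ``Concatenate the attention heads`` to combine the outputs of all heads into a single vector with the same dimensionality as the embeddings.
3. Apply a final ``linear projection`` that mixes the information coming from the different heads.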
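The notebook implements these steps with one `AttentionHead` module per head. A common alternative, sketched here as a swapped-in technique rather than this notebook's approach, computes q, k and v for all heads with a single fused `nn.Linear` and splits the embedding dimension into heads with a reshape. The class name `BatchedMultiHeadAttention` is hypothetical, and masking is omitted for brevity.

```python
import torch
import torch.nn as nn

class BatchedMultiHeadAttention(nn.Module):
    """Hypothetical variant: all heads computed with one projection plus a reshape."""
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head, self.head_dim = n_head, n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd)  # q, k, v for every head at once
        self.out = nn.Linear(n_embd, n_embd)      # step 3: final linear projection

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        # step 1: (B, T, C) -> (B, n_head, T, head_dim), one slice of the embedding per head
        q, k, v = (t.view(B, T, self.n_head, self.head_dim).transpose(1, 2) for t in (q, k, v))
        weights = torch.softmax(q @ k.transpose(-2, -1) / self.head_dim ** 0.5, dim=-1)
        heads = weights @ v  # (B, n_head, T, head_dim)
        # step 2: concatenate the heads back to (B, T, C), then project
        return self.out(heads.transpose(1, 2).reshape(B, T, C))

mha = BatchedMultiHeadAttention(n_embd=768, n_head=12)
print(mha.head_dim)                       # 64
print(mha(torch.rand(1, 10, 768)).shape)  # torch.Size([1, 10, 768])
```

Both layouts compute the same kind of function; the fused version simply replaces the Python loop over heads with batched tensor operations.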

> Note that the softmax function outputs a probability distribution which, within a single attention head, tends to amplify the highest scores while suppressing the others. As a result, each head tends to focus on a specific aspect of similarity.
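A quick numeric illustration of this amplification with arbitrary scores: the largest of three evenly spaced scores receives about two thirds of the weight, and doubling the scores makes the distribution even more peaked.

```python
import torch

scores = torch.tensor([1.0, 2.0, 3.0])  # arbitrary example scores
p = torch.softmax(scores, dim=-1)
print(p)  # ≈ tensor([0.0900, 0.2447, 0.6652])

# Larger gaps between scores -> an even more peaked distribution
print(torch.softmax(2 * scores, dim=-1))
```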

In [4]:
class MultiHeadAttention(nn.Module):
    """
    Implements the Multi-Head Attention mechanism.

    Parameters:
    n_embd (int): The size of the input feature dimension.
    n_head (int): The number of attention heads.
    """
    def __init__(self, n_embd, n_head):
        if n_embd <= 0 or n_head <= 0:
            raise ValueError("Embedding dimension and number of heads must be greater than 0")
        
        super().__init__()
        # step_1: initialize the attention heads
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        head_dim = n_embd // n_head
        self.heads = nn.ModuleList([AttentionHead(n_embd, head_dim) for _ in range(n_head)])
       
        self.output_linear = nn.Linear(n_embd, n_embd)

    def forward(self, hidden_state):
        # step_2: concatenate attention heads
        attn_outputs = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        # step_3: apply linear projection
        outputs = self.output_linear(attn_outputs)
        return outputs

In [5]:
multihead_attn = MultiHeadAttention(n_embd=768, n_head=12)
attn_outputs = multihead_attn(torch.rand(1, 10, 768))
attn_outputs.size()

torch.Size([1, 10, 768])

#### 3.3 Position-Wise Feed-Forward Layer

Apart from the softmax inside attention, the Transformer is built mostly from linear operations such as ``dot products`` and ``linear projections``; the ``Position-Wise Feed-Forward Layer`` adds the non-linearity the model needs to capture complex patterns and relationships in the data. The layer consists of two linear transformations with a ``non-linear activation function (such as ReLU or GELU)`` in between. Each layer of the ``Encoder`` and ``Decoder`` includes one of these feed-forward networks, allowing the model to build increasingly abstract representations of the input as it passes through successive layers. Since the same network is applied to each embedding independently, the computation can be fully parallelized across positions.

In summary, this layer comprises:
- A ``first linear transformation`` that expands each embedding to a larger hidden dimension (2048 for a 512-dimensional model in the paper; 3072 for 768 in the cell below).
- A non-linear ``activation function`` that allows the model to learn more complex patterns.
- A ``second linear transformation`` that projects the hidden representation back to the embedding dimension.
- ``Dropout``, a regularization technique used to prevent overfitting. During training, it randomly zeroes some elements of the input tensor with a given probability.
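To illustrate what "position-wise" means, the sketch below uses the ReLU formulation from the paper, $\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$, with the paper's base sizes ($d_{model} = 512$, $d_{ff} = 2048$) and without dropout. It checks that running the network on the whole sequence gives the same result as running it on each position separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 512, 2048  # base-model sizes from the paper

# FFN(x) = max(0, x W1 + b1) W2 + b2
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.rand(1, 10, d_model)  # (batch, seq_len, d_model)
full = ffn(x)

# The same weights applied to each position on its own give the same output
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)
print(torch.allclose(full, per_position, atol=1e-5))  # True
```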

> Note that ``ReLU`` is a cheap function that passes its input through only when it is positive and outputs zero otherwise, which can lead to sparse activations (sometimes desirable). ``GELU``, introduced later, is a smoother alternative that weights each input by the probability that a standard Gaussian variable is smaller than it, i.e. $\text{GELU}(x) = x\,\Phi(x)$, which acts as a soft, probabilistic gate. The original Transformer used ``ReLU``, whereas ``GELU`` became the preferred choice in BERT and GPT models, which is why it is used in the implementation below.
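A small comparison of the two activations on a handful of inputs. It also checks that PyTorch's exact `F.gelu` (the default `approximate='none'` mode) matches the closed form $x\,\Phi(x)$, with the standard normal CDF $\Phi$ written via the error function.

```python
import math
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)  # -3, -2, -1, 0, 1, 2, 3

relu = F.relu(x)  # max(0, x): exactly zero for every negative input
gelu = F.gelu(x)  # x * Phi(x): smooth, small negative outputs for negative inputs

# Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
gelu_manual = x * 0.5 * (1 + torch.erf(x / math.sqrt(2)))

print(relu)
print(gelu)
print(torch.allclose(gelu, gelu_manual, atol=1e-6))  # True
```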

In [20]:
class PositionWiseFeedForward(nn.Module):
    """
    Implements the PositionWiseFeedForward layer.

    Parameters:
    n_embd (int): The size of the input feature dimension.
    ff_dim (int): The size of the hidden layer dimension.
    """
    def __init__(self, n_embd, ff_dim):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(n_embd, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, n_embd),
            nn.Dropout(0.1)
        )

    def forward(self, hidden_state):
        return self.ff(hidden_state)

In [21]:
feed_forward = PositionWiseFeedForward(n_embd=768, ff_dim=3072)
feed_forward_outputs = feed_forward(torch.rand(1, 10, 768))
feed_forward_outputs.size()

torch.Size([1, 10, 768])

#### 3.4 Positional Encoding

Since the Transformer contains no recurrence and no convolution, self-attention by itself has no notion of token order: shuffling the input tokens simply shuffles the outputs in the same way. By adding ``positional encodings`` to the input embeddings, the model can differentiate tokens based on their position in the sequence, which is essential for tasks such as language modeling and machine translation. In practice, ``positional encodings`` are added to the input embeddings at the bottom of the ``encoder`` and ``decoder`` stacks.
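The order-blindness of self-attention can be checked directly. For brevity, this sketch uses PyTorch's built-in `nn.MultiheadAttention` in place of the notebook's own `MultiHeadAttention`, with small illustrative sizes: permuting the input tokens permutes the outputs identically, so without positional encodings the model cannot tell the two orderings apart.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=2, batch_first=True)
attn.eval()  # no dropout

x = torch.rand(1, 5, 16)              # (batch, seq_len, embed_dim), no positional encoding
perm = torch.tensor([4, 2, 0, 1, 3])  # a shuffled token order
x_perm = x[:, perm]

out, _ = attn(x, x, x)
out_perm, _ = attn(x_perm, x_perm, x_perm)

# Shuffling the input just shuffles the output the same way
print(torch.allclose(out_perm, out[:, perm], atol=1e-5))  # True
```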

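> As reported in the original [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf) paper, ``Sinusoidal Positional Encoding`` and ``Learned Positional Encoding`` produce nearly identical results. The authors chose the sinusoidal version because it may allow the model to extrapolate to sequence lengths longer than those seen during training.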
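For reference, the sinusoidal encodings defined in the paper are given below, where $pos$ is the position and $i$ indexes pairs of embedding dimensions. The `div_term` in `SinusoidalPositionalEncoding` computes $10000^{-2i/d_{model}}$ as $\exp\left(-2i \ln(10000) / d_{model}\right)$.

$$
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$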



In [8]:
class SinusoidalPositionalEncoding(nn.Module):
    """
    Implements Sinusoidal Positional Encoding.

    Parameters:
    embed_size (int): The size of the input feature dimension.
    """
    def __init__(self, n_embd):
        super().__init__()
        self.n_embd = n_embd

    def forward(self, max_seq_len):
        pos = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.n_embd, 2) * -(torch.log(torch.tensor(10000.0)) / self.n_embd))
        
        pe = torch.zeros(max_seq_len, self.n_embd)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)

        return pe

In [9]:
pos_encoder = SinusoidalPositionalEncoding(n_embd=768)
pos_encoding = pos_encoder(max_seq_len=100)

pos_encoding.shape  # Should be [max_seq_len, n_embd]

torch.Size([100, 768])

In [10]:
class LearnedPositionalEncoding(nn.Module):
    """
    Implements the LearnedPositionalEncoding layer.

    Parameters:
    max_seq_len (int): The maximum sequence length.
    n_embd (int): The size of the input feature dimension.
    """
    def __init__(self, max_seq_len, n_embd):
        super().__init__()
        self.position_embeddings = nn.Embedding(max_seq_len, n_embd)
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_state):
        embeddings = hidden_state + self.position_embeddings(torch.arange(hidden_state.size(1), 
                                                             device=hidden_state.device))
        return self.dropout(embeddings)

In [11]:
encoding_layer = LearnedPositionalEncoding(max_seq_len=10, n_embd=768)
encoding_outputs = encoding_layer(torch.rand(1, 10, 768))
encoding_outputs.size()

torch.Size([1, 10, 768])

#### 3.5 Encoder

Each of the six encoder layers is composed of two sub-layers: a ``multi-head self-attention`` sub-layer, which, as explained above, allows the model to focus on different parts of the input sequence, and a ``position-wise fully connected feed-forward`` sub-layer. In the paper, a ``residual connection`` is employed around each of the two sub-layers, followed by ``layer normalization`` (post-layer normalization). Here, we instead implement pre-layer normalization, which normalizes the input of each sub-layer, together with ``Dropout`` regularization, to favour training stability and prevent overfitting, respectively.

> ``Layer normalization`` normalizes each token's embedding to zero mean and unit variance across its features, followed by a learnable scale and shift. This helps stabilize the learning process and can reduce the number of training steps needed.

> A ``residual connection`` (or ``skip connection``) helps alleviate the vanishing-gradient problem by passing a tensor to the next layer without processing it and adding it to the processed tensor. In the paper (post-layer normalization), the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself; with the pre-layer normalization used below, it becomes x + Sublayer(LayerNorm(x)).
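A minimal sketch contrasting the two arrangements, with a single `nn.Linear` standing in for the attention or feed-forward sub-layer (an illustrative assumption). It also checks that `nn.LayerNorm`, with its default scale of 1 and shift of 0, gives each token approximately zero mean and unit variance across its features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 768
x = torch.rand(1, 10, n_embd)
norm = nn.LayerNorm(n_embd)
sublayer = nn.Linear(n_embd, n_embd)  # stand-in for attention or feed-forward

# Post-LN (paper): normalize after the residual addition
post_ln = norm(x + sublayer(x))
# Pre-LN (this notebook's EncoderLayer): normalize only the sub-layer input,
# leaving the residual path x untouched
pre_ln = x + sublayer(norm(x))

# Per-token statistics over the n_embd features
print(post_ln.mean(dim=-1).abs().max().item())            # close to 0
print(post_ln.var(dim=-1, unbiased=False).mean().item())  # close to 1
```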

In [12]:
class EncoderLayer(nn.Module):
    """
    Implements a single encoder layer.

    Parameters:
    n_embd (int): The size of the input feature dimension.
    n_head (int): The number of attention heads.
    ff_dim (int): The size of the hidden layer dimension.
    """
    def __init__(self, n_embd, n_head, ff_dim):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(n_embd)
        self.multihead_attn = MultiHeadAttention(n_embd, n_head)

        self.layer_norm_2 = nn.LayerNorm(n_embd)
        self.feed_forward = PositionWiseFeedForward(n_embd, ff_dim)
        
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, hidden_state):
        # pre-LN: normalize, attend, then add the residual
        attn_outputs = self.multihead_attn(self.layer_norm_1(hidden_state))
        hidden_state = hidden_state + self.dropout(attn_outputs)
        # PositionWiseFeedForward already ends with Dropout, so no extra dropout here
        output = hidden_state + self.feed_forward(self.layer_norm_2(hidden_state))
        return output

In [13]:
encoder_layer = EncoderLayer(n_embd=768, n_head=12, ff_dim=3072)
encoder_layer(torch.rand(1, 10, 768)).size()

torch.Size([1, 10, 768])

In [18]:
class Encoder(nn.Module):
    """
    Implements the Encoder.

    Parameters:
    n_embd (int): The size of the input feature dimension.
    n_head (int): The number of attention heads.
    ff_dim (int): The size of the hidden layer dimension.
    """
    def __init__(self, n_embd, n_head, ff_dim, n_layer=6):
        super().__init__()
        self.positional_encoding = LearnedPositionalEncoding(100, n_embd)
        self.layers = nn.ModuleList([EncoderLayer(n_embd, n_head, ff_dim) for _ in range(n_layer)])

    def forward(self, hidden_state):
        hidden_state = self.positional_encoding(hidden_state)
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state

In [19]:
encoder = Encoder(n_embd=768, n_head=12, ff_dim=3072, n_layer=6)
encoder(torch.rand(1, 10, 768)).size()

torch.Size([1, 10, 768])

#### 3.6 Decoder

[...]

#### 3.7 Transformer

[...]