## Transformer101

Basic Transformer implementation in PyTorch.

In [3]:
import torch
import torch.nn as nn

### 1. Self-Attention

The intuition behind ``self-attention`` is that, instead of representing each token with a fixed embedding, the model replaces each token embedding with a weighted average of the (projected) embeddings of all tokens in the sequence, where the weights depend on the context. This lets the model capture how words relate to each other in the input. In practice, these weighted relationships (the attention weights) can reflect the syntactic and contextual structure of the sentence, giving the model a richer, context-dependent representation of the data.
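
As a minimal illustration of this weighted-average view (leaving out the learned query/key/value projections described below), the sketch scores every pair of tokens with a raw dot product of their embeddings, turns each row of scores into weights with a softmax, and replaces each embedding with the weighted average of all embeddings. The toy embeddings are random placeholders, not real word vectors.

```python
import torch

torch.manual_seed(0)

# Toy input: 1 sequence of 4 tokens, each a 6-dimensional embedding (random placeholders).
x = torch.randn(1, 4, 6)

# Pairwise similarity between tokens: (1, 4, 6) @ (1, 6, 4) -> (1, 4, 4).
scores = torch.bmm(x, x.transpose(1, 2))

# Each row of weights is a probability distribution over the 4 tokens.
weights = torch.softmax(scores, dim=-1)

# Contextualized embeddings: each output row is a weighted average of all input rows.
contextual = torch.bmm(weights, x)

# The same result for token 0, written out as an explicit weighted sum.
token0 = sum(weights[0, 0, j] * x[0, j] for j in range(4))
print(torch.allclose(contextual[0, 0], token0, atol=1e-6))  # True
```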

The most common way to implement a self-attention layer is based on ``scaled dot-product attention``, which involves:
1. ``Linear projection`` of each token embedding into three vectors: ``query (q)``, ``key (k)``, ``value (v)``.
2. Compute ``scaled attention scores``: we measure the similarity between each ``q`` and every ``k`` with a dot product. Because the magnitude of these dot products grows with the vector dimension, the scores are divided by a scaling factor, the square root of the key dimension (√d_k). This scaling keeps the softmax in the next step from saturating, which stabilizes gradients during training.
3. Normalize the ``attention scores`` into ``attention weights`` by applying a softmax over each row of scores (this ensures that the weights for each query sum to 1).
4. ``Update the token embeddings`` by multiplying the attention weights by the value vectors, so that each output embedding is a weighted sum of the value vectors.
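
A quick numerical check of why the scaling in step 2 matters, assuming query and key entries are independent with zero mean and unit variance: the dot product of two such d_k-dimensional vectors has variance d_k, so without scaling the scores spread out as d_k grows and the softmax tends toward a one-hot distribution with very small gradients. Dividing by √d_k brings the variance back to roughly 1. The sample size and the dimensions tried are arbitrary choices.

```python
import torch

torch.manual_seed(0)
num_pairs = 5_000
variances = {}

for dim_k in (16, 256, 1024):
    # Random query/key pairs whose entries are i.i.d. N(0, 1).
    q = torch.randn(num_pairs, dim_k)
    k = torch.randn(num_pairs, dim_k)
    raw = (q * k).sum(dim=-1)          # raw scores q . k
    scaled = raw / dim_k ** 0.5        # scaled scores q . k / sqrt(d_k)
    variances[dim_k] = (raw.var().item(), scaled.var().item())
    print(f"d_k={dim_k:5d}  var(raw)={variances[dim_k][0]:8.1f}  var(scaled)={variances[dim_k][1]:.2f}")
```

The raw variance tracks d_k, while the scaled variance stays close to 1 regardless of the dimension.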

In [14]:
class AttentionHead(nn.Module):
    """
    Represents a single attention head within a multi-head attention mechanism.
    
    Parameters:
    embed_dim (int): The size of the input feature dimension.
    head_dim (int): The size of the output feature dimension for this attention head.
    """
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # ----------------------------------------------------------------------------------
        # step 1: linear projection of token embeddings into query, key, and value vectors
        # nn.Linear(in_features, out_features, bias=True) creates a linear transformation
        # wherein each input vector of size embed_dim will be transformed into a 
        # lower-dimensional vector of size head_dim.
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def scaled_dot_product_attention(self, q, k, v):
        """
        Computes the scaled dot-product attention.

        Parameters:
        q, k, v (torch.Tensor): Query, Key, and Value tensors.

        Returns:
        torch.Tensor: Output after applying attention mechanism.
        """
        dim_k = torch.tensor(k.size(-1), dtype=torch.float32)
        # ----------------------------------------------------------------------------------
        # step 2: calculate the scaled attention scores
        # torch.bmm performs a batched matrix product of q and k^T, giving a
        # (batch, seq_len, seq_len) matrix of query-key dot products.
        # We then apply the scaling factor 1/sqrt(dim_k) to these dot products.
        scaled_attention_scores = torch.bmm(q, k.transpose(1, 2)) / torch.sqrt(dim_k)
        # ----------------------------------------------------------------------------------
        # step 3: apply a softmax over the last dimension to obtain the attention weights
        attention_weights = torch.softmax(scaled_attention_scores, dim=-1)
        # ----------------------------------------------------------------------------------
        # step 4: update token embeddings by taking the attention-weighted sum of the value vectors
        output = torch.bmm(attention_weights, v)
        return output

    def forward(self, hidden_state):
        """
        Defines the forward pass of the AttentionHead.

        Parameters:
        hidden_state (torch.Tensor): The input tensor.

        Returns:
        torch.Tensor: The output tensor after attention is applied.
        """
        attn_outputs = self.scaled_dot_product_attention(self.q(hidden_state),
                                                         self.k(hidden_state), 
                                                         self.v(hidden_state))
        return attn_outputs
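
To sanity-check the formula used in `scaled_dot_product_attention`, the sketch below recomputes it on random tensors and compares the result with PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`, which (assuming PyTorch 2.0 or later) applies the same default scale of 1/√d_k. The tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, head_dim = 2, 5, 64
q = torch.randn(batch, seq_len, head_dim)
k = torch.randn(batch, seq_len, head_dim)
v = torch.randn(batch, seq_len, head_dim)

# Manual computation, mirroring steps 2-4 of AttentionHead.
scores = torch.bmm(q, k.transpose(1, 2)) / head_dim ** 0.5  # (batch, seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)                      # each row sums to 1
manual = torch.bmm(weights, v)                               # (batch, seq_len, head_dim)

# PyTorch's built-in implementation (available since PyTorch 2.0).
reference = F.scaled_dot_product_attention(q, k, v)

print(manual.shape)                                  # torch.Size([2, 5, 64])
print(torch.allclose(manual, reference, atol=1e-5))  # True
```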

### 2. Multi-Headed Attention

With a single attention head, the softmax tends to concentrate on one aspect of similarity, potentially overlooking other relevant features in the input. Using multiple attention heads, each with its own query, key, and value projections, lets the model attend to several aspects of the input simultaneously, for example:
- semantic meaning of words
- grammatical relationships
- tone or sentiment
- intended modality
- idiomatic expressions
- [...]

In [15]:
class MultiHeadAttention(nn.Module):
    """
    Implements the Multi-Head Attention mechanism.

    Parameters:
    embed_dim (int): The size of the input feature dimension.
    num_heads (int): The number of attention heads.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # ----------------------------------------------------------------------------------
        # step 1: initialize the attention heads
        # E.g. BERT-base has 12 attention heads and an embedding dimension of 768,
        # resulting in 768 / 12 = 64 as the head dimension
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads"
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        # ----------------------------------------------------------------------------------
        # step 2: prepare linear transformation
        # combines the outputs of the attention heads into a single vector while preserving
        # the dimensionality of the embeddings
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        """
        Defines the forward pass of the MultiHeadAttention.

        Parameters:
        hidden_state (torch.Tensor): The input tensor.

        Returns:
        torch.Tensor: The output tensor after multi-head attention is applied.
        """
        # ----------------------------------------------------------------------------------
        # step 3: concatenate attention heads
        attn_outputs = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        # ----------------------------------------------------------------------------------
        # step 4: linear projection of concatenated attention heads
        outputs = self.output_linear(attn_outputs)
        return outputs
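
Looping over a `ModuleList` of heads is easy to read but runs one small matrix product per head. A common alternative, sketched below, packs the per-head query/key/value projections into single `embed_dim → embed_dim` layers and computes all heads at once by reshaping to `(batch, num_heads, seq_len, head_dim)`. `BatchedMultiHeadAttention` is a hypothetical name used only for illustration; it has the same parameter count and output shape as the class above, but its weights are initialized independently, so the two do not produce identical outputs.

```python
import torch
import torch.nn as nn


class BatchedMultiHeadAttention(nn.Module):
    """Hypothetical batched variant: all heads share one projection layer per q/k/v."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # One (embed_dim -> embed_dim) layer holds the weights of all heads side by side.
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def split_heads(self, x):
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
        batch, seq_len, _ = x.shape
        return x.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_state):
        batch, seq_len, embed_dim = hidden_state.shape
        q = self.split_heads(self.q(hidden_state))
        k = self.split_heads(self.k(hidden_state))
        v = self.split_heads(self.v(hidden_state))
        # Scaled dot-product attention for every head in one batched matmul.
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5  # (batch, heads, seq, seq)
        weights = torch.softmax(scores, dim=-1)
        heads = weights @ v                                      # (batch, heads, seq, head_dim)
        # Undo the split: equivalent to concatenating the head outputs along the last dim.
        concat = heads.transpose(1, 2).reshape(batch, seq_len, embed_dim)
        return self.output_linear(concat)


mha = BatchedMultiHeadAttention(embed_dim=768, num_heads=12)
out = mha(torch.rand(1, 10, 768))
print(out.shape)                                 # torch.Size([1, 10, 768])
print(sum(p.numel() for p in mha.parameters()))  # 2362368
```

The trade-off is readability for speed: the per-head loop makes each head explicit, while the reshaped version launches a few large matrix products instead of many small ones.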

In [16]:
multihead_attn = MultiHeadAttention(embed_dim=768, num_heads=12)
attn_outputs = multihead_attn(torch.rand(1, 10, 768))
attn_outputs.size()

torch.Size([1, 10, 768])

### 3. Position-Wise Feed-Forward Layer

Apart from the softmax in the attention layers, the Transformer is built mostly from linear operations such as dot products and linear projections, so it relies on the Position-Wise Feed-Forward Layer to introduce additional non-linearity into the model. This non-linearity enables the model to capture complex patterns and relationships in the data. The layer typically consists of two linear transformations with a non-linear activation function (like ReLU or GELU) in between. Each layer in the Encoder and Decoder includes one of these feed-forward networks, allowing the model to build increasingly abstract representations of the input data as it passes through successive layers.
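
The activation is what keeps the two linear layers from collapsing into one: without it, `linear_2(linear_1(x))` equals a single linear map with weight `W2 @ W1` and bias `W2 @ b1 + b2`. The sketch checks this identity numerically on small, arbitrary layer sizes and shows that inserting GELU breaks it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear_1 = nn.Linear(8, 32)
linear_2 = nn.Linear(32, 8)
x = torch.randn(4, 8)

with torch.no_grad():
    # Two stacked linear layers with no activation in between...
    stacked = linear_2(linear_1(x))
    # ...are exactly one linear layer with merged weight and bias.
    merged = nn.Linear(8, 8)
    merged.weight.copy_(linear_2.weight @ linear_1.weight)
    merged.bias.copy_(linear_2.weight @ linear_1.bias + linear_2.bias)
    collapsed = merged(x)
    # With GELU in between, the composition is no longer linear.
    with_gelu = linear_2(nn.GELU()(linear_1(x)))

print(torch.allclose(stacked, collapsed, atol=1e-5))  # True
print(torch.allclose(stacked, with_gelu, atol=1e-5))  # False
```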

Note that since this layer processes each position's embedding independently, the computation for all positions can be fully parallelized.
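
Because the same two-layer network is applied to every position, running it on the whole `(batch, seq_len, embed_dim)` tensor at once gives the same result as looping over positions one at a time; `nn.Linear` treats all leading dimensions as batch dimensions. The sketch checks this with an illustrative `nn.Sequential` stand-in for the feed-forward layer (dropout omitted so the comparison is deterministic); the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, ff_dim, seq_len = 16, 64, 7
ffn = nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.GELU(), nn.Linear(ff_dim, embed_dim))
x = torch.randn(2, seq_len, embed_dim)

with torch.no_grad():
    # All positions in one call: nn.Linear acts on the last dimension only.
    all_at_once = ffn(x)
    # One position at a time, then stacked back along the sequence dimension.
    one_by_one = torch.stack([ffn(x[:, t]) for t in range(seq_len)], dim=1)

print(torch.allclose(all_at_once, one_by_one, atol=1e-5))  # True
```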

In [19]:
class PositionWiseFeedForward(nn.Module):
    """
    Implements the PositionWiseFeedForward layer.

    Parameters:
    embed_dim (int): The size of the input feature dimension.
    ff_dim (int): The size of the hidden layer dimension.
    """
    def __init__(self, embed_dim, ff_dim):
        super().__init__()
        self.linear_1 = nn.Linear(embed_dim, ff_dim)
        self.linear_2 = nn.Linear(ff_dim, embed_dim)
        self.activation = nn.GELU()
        # Register dropout as a submodule so that model.train() / model.eval() toggle it.
        self.dropout = nn.Dropout(0.1)

    def forward(self, hidden_state):
        """
        Defines the forward pass of the PositionWiseFeedForward layer.

        Parameters:
        hidden_state (torch.Tensor): The input tensor.

        Returns:
        torch.Tensor: The output tensor after applying the feed-forward network.
        """
        # This layer applies a linear transformation to the input tensor, expanding it to ff_dim.
        # It's a fully connected layer where each input is connected to every output by a learned weight.
        x = self.linear_1(hidden_state)
        # This step introduces non-linearity into the model, allowing it to learn more complex patterns.
        x = self.activation(x)
        # This layer applies another linear transformation that projects the hidden representation
        # back to embed_dim, so the output has the same shape as the input.
        x = self.linear_2(x)
        # Dropout is a regularization technique used to prevent overfitting. During training it randomly
        # zeroes elements of the input tensor with probability 0.1; in eval mode it does nothing.
        return self.dropout(x)

        # Equivalent one-liner:
        # return self.dropout(self.linear_2(self.activation(self.linear_1(hidden_state))))
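
`nn.Dropout` only drops activations while its module is in training mode: during training it zeroes each element with probability `p` and rescales the survivors by `1 / (1 - p)`, and after `.eval()` it is the identity. Calling `.train()`/`.eval()` on a parent module only reaches dropout layers registered as submodules (e.g. assigned as attributes in `__init__`); a dropout layer created inside `forward` is a fresh module that stays in training mode. The sketch uses hypothetical `Block` and `InlineBlock` modules and an exaggerated `p=0.5` to make the effect visible.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical module with a registered dropout submodule."""

    def __init__(self, p=0.5):
        super().__init__()
        self.dropout = nn.Dropout(p)  # registered, so .train()/.eval() reach it

    def forward(self, x):
        return self.dropout(x)


class InlineBlock(nn.Module):
    """Hypothetical module that builds a new, unregistered dropout on every call."""

    def forward(self, x):
        return nn.Dropout(0.5)(x)  # fresh module, always in training mode


torch.manual_seed(0)
x = torch.ones(1000)

block = Block(p=0.5)
block.train()
y_train = block(x)
# Surviving elements are scaled to 1 / (1 - 0.5) = 2.0; the rest are 0.
print(sorted(set(y_train.tolist())))  # [0.0, 2.0]

block.eval()
y_eval = block(x)
print(torch.equal(y_eval, x))  # True

inline = InlineBlock().eval()
print(torch.equal(inline(x), x))  # False: dropout still applied despite .eval()
```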

In [20]:
feed_forward = PositionWiseFeedForward(embed_dim=768, ff_dim=3072)
feed_forward_outputs = feed_forward(torch.rand(1, 10, 768))
feed_forward_outputs.size()

torch.Size([1, 10, 768])

### 4. Positional Embeddings

[...]

### 5. Decoder

[...]

### 6. Encoder

[...]