## Transformer101

A basic Transformer implementation in PyTorch.

In [3]:
import torch
import torch.nn as nn

### 1. Self-Attention

The intuition behind ``self-attention`` is that replacing the fixed embedding of each token with a weighted average of all the token embeddings in the sequence enables the model to capture how words relate to each other in the input. In practice, the weights of that average (the attention weights) reflect the syntactic and contextual structure of the sentence, so each token ends up with a representation that depends on its context.

The most common way to implement a self-attention layer is based on ``scaled dot-product attention``, which involves:
1. ``Linear projection`` of each token embedding into three vectors: ``query (q)``, ``key (k)``, ``value (v)``.
2. Compute ``scaled attention scores``: we measure the similarity between ``q`` and ``k`` with a dot product. Since the magnitude of these dot products tends to grow with the dimensionality of ``k``, they are divided by the square root of that dimensionality.
3. Normalize the ``attention scores`` into ``attention weights`` by applying softmax (this ensures that the weights for each query are non-negative and sum to 1).
4. ``Update the token embeddings`` by multiplying the attention weights by the value vectors, so that each output is a weighted sum of the values.
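
To see why step 2 needs a scaling factor, here is a minimal sketch that measures the spread of dot products for a few values of the key dimensionality. It assumes the entries of ``q`` and ``k`` are independent standard normal samples, which is an illustrative assumption and not something a trained model guarantees; the sample count and the chosen dimensions are arbitrary.

```python
import torch

torch.manual_seed(0)

# Assumption: q and k entries are i.i.d. standard normal (illustration only).
num_samples = 10_000
stats = {}

for dim_k in (4, 64, 1024):
    q = torch.randn(num_samples, dim_k)
    k = torch.randn(num_samples, dim_k)

    # One dot product per (q, k) pair.
    raw_scores = (q * k).sum(dim=-1)
    # Same scores divided by sqrt(dim_k), as in scaled dot-product attention.
    scaled_scores = raw_scores / dim_k ** 0.5

    stats[dim_k] = (raw_scores.std().item(), scaled_scores.std().item())
    print(f"dim_k={dim_k:5d}  raw std={stats[dim_k][0]:6.2f}  "
          f"scaled std={stats[dim_k][1]:.2f}")
```

Under this assumption the standard deviation of the raw scores grows roughly like the square root of ``dim_k``, while the scaled scores stay close to 1. Keeping the scores in that range stops the softmax in step 3 from saturating on a single token.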

In [14]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # ----------------------------------------------------------------------------------
        # step 1: linear projection of token embeddings into query, key, and value vectors
        # nn.Linear(in_features, out_features, bias=True) creates a linear transformation.
        # self.x is an instance of nn.Linear, this allows the transformation to be applied 
        # by calling self.x(input).
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    def scaled_dot_product_attention(self, q, k, v):
        dim_k = torch.tensor(k.size(-1), dtype=torch.float32)
        # ----------------------------------------------------------------------------------
        # step 2: calculate the scaled attention scores
        # torch.bmm performs a batch matrix-matrix product of q and the transpose of k,
        # which yields one dot product per (query, key) pair.
        # we then apply the scaling factor 1/sqrt(dim_k) to said dot products.
        scaled_attention_scores = torch.bmm(q, k.transpose(1, 2)) / torch.sqrt(dim_k)
        # ----------------------------------------------------------------------------------
        # step 3: apply a softmax function to obtain the attention weights
        attention_weights = torch.softmax(scaled_attention_scores, dim=-1)
        # ----------------------------------------------------------------------------------
        # step 4: update token embeddings by applying attention weights to the value vectors
        output = torch.bmm(attention_weights, v)
        return output

    def forward(self, x):
        attn_outputs = self.scaled_dot_product_attention(self.q(x), self.k(x), self.v(x))
        return attn_outputs
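
To make the tensor shapes at each step explicit, the four steps can be traced on plain tensors outside the class. This is a sketch under assumed toy sizes (``batch_size=2``, ``seq_len=5``, ``embed_dim=16``, ``head_dim=4``); the names ``q_proj``, ``k_proj`` and ``v_proj`` are illustrative stand-ins for the ``self.q``, ``self.k`` and ``self.v`` layers.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes chosen for illustration only.
batch_size, seq_len, embed_dim, head_dim = 2, 5, 16, 4
x = torch.rand(batch_size, seq_len, embed_dim)

# step 1: project every token embedding into query, key and value vectors
q_proj = nn.Linear(embed_dim, head_dim)
k_proj = nn.Linear(embed_dim, head_dim)
v_proj = nn.Linear(embed_dim, head_dim)
q, k, v = q_proj(x), k_proj(x), v_proj(x)      # each: (batch, seq_len, head_dim)

# step 2: one scaled score per (query token, key token) pair
scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)  # (batch, seq_len, seq_len)

# step 3: softmax over the key axis turns each row of scores into weights
weights = torch.softmax(scores, dim=-1)        # (batch, seq_len, seq_len)

# step 4: each output token is a weighted sum of the value vectors
output = torch.bmm(weights, v)                 # (batch, seq_len, head_dim)

print(q.shape, scores.shape, output.shape)
print(weights.sum(dim=-1))                     # every entry is 1 up to float error
```

Note that a single head maps ``embed_dim`` down to ``head_dim``, so its output is narrower than its input.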

### 2. Multi-Headed Attention

With a single attention head, the softmax tends to concentrate on one specific aspect of similarity, potentially overlooking other relevant features in the input. By combining multiple attention heads, each with its own query, key, and value projections, the model gains the ability to simultaneously attend to various aspects of the input data such as:
- semantic meaning of words
- grammatical relationships
- tone or sentiment
- intended modality
- idiomatic expressions
- [...]

In [15]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # ----------------------------------------------------------------------------------
        # step 1: initialize the attention heads
        # E.g. BERT has 12 attention heads whereas the embedding dimension is 768,
        # resulting in 768/12 = 64 as the head dimension
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList([AttentionHead(embed_dim, head_dim) for _ in range(num_heads)])
        # ----------------------------------------------------------------------------------
        # step 2: prepare the output linear projection applied to the concatenated heads
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        # ----------------------------------------------------------------------------------
        # step 3: concatenate attention heads
        attn_outputs = torch.cat([head(hidden_state) for head in self.heads], dim=-1)
        # ----------------------------------------------------------------------------------
        # step 4: linear projection of concatenated attention heads
        outputs = self.output_linear(attn_outputs)
        return outputs
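
Because ``head_dim`` is computed with floor division, the class above silently assumes that ``embed_dim`` is divisible by ``num_heads``. If it is not, the concatenated heads have width ``num_heads * head_dim``, which is smaller than ``embed_dim``, and ``output_linear`` fails with a shape mismatch at the first forward pass. One way to surface the problem earlier is an explicit check; ``split_head_dim`` below is a hypothetical helper written for this illustration, not part of the notebook or of PyTorch.

```python
def split_head_dim(embed_dim: int, num_heads: int) -> int:
    """Hypothetical helper: return the per-head dimension or fail early."""
    if num_heads <= 0:
        raise ValueError(f"num_heads must be positive, got {num_heads}")
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )
    return embed_dim // num_heads


print(split_head_dim(768, 12))  # 64, the BERT-base head dimension

try:
    split_head_dim(10, 3)       # 3 heads of width 3 would only cover 9 of 10 dims
except ValueError as err:
    print(err)
```

Calling such a helper in ``__init__`` in place of the bare ``embed_dim // num_heads`` would be one option; whether to raise or to handle the remainder differently is a design choice.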

In [16]:
multihead_attn = MultiHeadAttention(embed_dim=768, num_heads=12)
attn_outputs = multihead_attn(torch.rand(1, 10, 768))
attn_outputs.size()

torch.Size([1, 10, 768])
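
As a point of comparison, PyTorch ships its own ``nn.MultiheadAttention`` module. The sketch below runs it in a self-attention setting on assumed toy sizes (``embed_dim=64``, ``num_heads=4``). It is a different implementation from the class above (for example, it takes the query, key, and value inputs as separate arguments and also returns the attention weights), so it illustrates the same idea and is not a drop-in equivalent.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes chosen for illustration only.
embed_dim, num_heads = 64, 4
builtin_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.rand(2, 5, embed_dim)                # (batch, seq_len, embed_dim)

# Self-attention: the same tensor serves as query, key and value.
attn_output, attn_weights = builtin_attn(x, x, x)

print(attn_output.size())   # torch.Size([2, 5, 64])
print(attn_weights.size())  # torch.Size([2, 5, 5]), averaged over heads by default
```

As with the hand-written module, the output keeps the input shape, which is what allows attention layers to be stacked.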