## Transformer101

Basic Transformer implementation in PyTorch.

In [None]:
import torch
import torch.nn as nn

### 1. Self-Attention

The intuition behind ``self-attention`` is that replacing each token's fixed embedding with a weighted average of all the token embeddings in the sequence lets the model capture how words relate to each other in the input. These weights (the attention weights) can reflect the syntactic and contextual structure of the sentence, so the same token can receive a different representation depending on its context. This gives the model a richer, more nuanced understanding of the data.
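The weighted-average idea can be sketched directly on raw embeddings, before any learned projections are introduced. This is a minimal illustration, assuming hypothetical random embeddings `x` for 4 tokens and plain dot products as the similarity measure; the actual layer adds query/key/value projections and a scaling factor.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy input: 4 tokens, each with an 8-dimensional embedding.
x = torch.randn(4, 8)

# Similarity of every token to every other token (no learned projections here).
scores = x @ x.T                          # shape (4, 4)
weights = torch.softmax(scores, dim=-1)   # each row is a probability distribution

# Each new embedding is a weighted average of all the input embeddings.
contextual = weights @ x                  # shape (4, 8)

print(weights.sum(dim=-1))  # every row sums to 1 (up to floating-point rounding)
```

Because the weights depend on the other tokens in the sequence, the same input vector would produce a different `contextual` row if its neighbours changed.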

The most common way to implement a self-attention layer is based on ``scaled dot-product attention``, which involves:
1. ``Linear projection`` of each token embedding into three vectors: ``query (q)``, ``key (k)``, ``value (v)``.
2. Compute ``scaled attention scores``: we measure the similarity between ``q`` and ``k`` with a dot product. Because the magnitude of these dot products tends to grow with the dimensionality of the vectors (which can push the softmax into regions with very small gradients), the scores are divided by ``sqrt(d_k)``, where ``d_k`` is the dimensionality of ``k``.
3. Normalize the ``attention scores`` into ``attention weights`` by applying a softmax over each query's scores, so that the weights for each query sum to 1.
4. ``Update the token embeddings`` by multiplying the attention weights by the value vectors, so that each output is a weighted sum of the values.
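Why step 2 divides by ``sqrt(d_k)`` can be checked numerically. The sketch below assumes query and key components are independent with zero mean and unit variance (an illustrative assumption; learned projections need not behave this way). Under that assumption a dot product of two ``d_k``-dimensional vectors has variance ``d_k``, so its standard deviation grows like ``sqrt(d_k)``, and dividing by ``sqrt(d_k)`` brings it back to roughly 1. `score_std` is a hypothetical helper written for this illustration.

```python
import torch

torch.manual_seed(0)

def score_std(d_k, n=10_000, scale=False):
    # Assumption for illustration: i.i.d. standard-normal query/key components.
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    scores = (q * k).sum(dim=-1)  # n independent dot products q . k
    if scale:
        scores = scores / d_k ** 0.5  # the scaling used in step 2
    return scores.std().item()

for d_k in (16, 64, 256):
    # unscaled std is close to sqrt(d_k); scaled std stays close to 1
    print(d_k, round(score_std(d_k), 1), round(score_std(d_k, scale=True), 2))
```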

In [None]:
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        # step 1: linear projection of token embeddings into query, key, and value vectors
        #
        # nn.Linear(in_features, out_features, bias=True) creates a learnable affine transformation.
        # self.q, self.k and self.v are nn.Linear instances, so each projection is applied
        # later by calling it on an input, e.g. self.q(x).
        self.q = nn.Linear(embed_dim, head_dim)
        self.k = nn.Linear(embed_dim, head_dim)
        self.v = nn.Linear(embed_dim, head_dim)

    @staticmethod
    def scaled_dot_product_attention(q, k, v):
        # q, k, v have shape (batch_size, seq_len, head_dim)
        # step 2: calculate the scaled attention scores
        # torch.bmm performs a batched matrix-matrix product: (B, L, d) @ (B, d, L) -> (B, L, L).
        # We then scale the dot products by 1/sqrt(d_k), where d_k = k.size(-1).
        # k.size(-1) is a Python int, so use ** 0.5 rather than torch.sqrt (which expects a tensor).
        scaled_attention_logits = torch.bmm(q, k.transpose(1, 2)) / k.size(-1) ** 0.5
        # step 3: apply a softmax over the last dimension so each row of weights sums to 1
        attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
        # step 4: update the token embeddings as weighted sums of the value vectors
        return torch.bmm(attention_weights, v)

    def forward(self, x):
        attn_outputs = self.scaled_dot_product_attention(self.q(x), self.k(x), self.v(x))
        return attn_outputs
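One way to sanity-check the computation performed in `forward` is to compare the manual steps 2 to 4 against PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available from PyTorch 2.0), whose default scale is 1/sqrt of the last dimension of the query. This is a self-contained sketch that rebuilds the three projections with hypothetical sizes rather than importing the class above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

batch_size, seq_len, embed_dim, head_dim = 2, 5, 16, 8  # hypothetical sizes

# Same kind of projections as in AttentionHead.__init__ (step 1).
q_proj = nn.Linear(embed_dim, head_dim)
k_proj = nn.Linear(embed_dim, head_dim)
v_proj = nn.Linear(embed_dim, head_dim)

x = torch.randn(batch_size, seq_len, embed_dim)
q, k, v = q_proj(x), k_proj(x), v_proj(x)

# Manual steps 2-4: scaled scores, softmax, weighted sum of values.
scores = torch.bmm(q, k.transpose(1, 2)) / head_dim ** 0.5
weights = torch.softmax(scores, dim=-1)
manual = torch.bmm(weights, v)

# Built-in implementation (PyTorch >= 2.0) with its default scaling.
builtin = F.scaled_dot_product_attention(q, k, v)

print(manual.shape)  # torch.Size([2, 5, 8])
print(torch.allclose(manual, builtin, atol=1e-5))
```

Note that each head maps `embed_dim` inputs to `head_dim` outputs, so the output shape is `(batch_size, seq_len, head_dim)`.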