# Transformer implementation

The four main steps in the transformer architecture are the following:
1. Embedding - map each token/feature/word to a vector and stack these vectors so that they form the **input**.
2. Attention mechanism - the core of the approach; it refines the meaning of each token/feature/word based on its context.
3. MLP (Multi-Layer Perceptron), or Feed-Forward Layer - stores facts and some of the relations between them.
4. Unembedding - transforms the output of the last layer back into token/feature/word space.

## Attention mechanism
### Notations

Symbol     | Meaning 
---------  | ------- 
$d$        | The model size, or embedding/positional encoding size
$d_k, d_v$ | The per-head key/query dimension $d_k$ and value dimension $d_v$
$N$        | The sequence length of an input sequence (context window)
$H$        | The number of heads in multi-head attention layer
$\mathbf{X} \in \mathbb{R}^{N \times d}$ | The input sequence after mapping each element into an embedding vector
$\mathbf{W}_k \in \mathbb{R}^{d \times d_k}$ | The key weight matrix
$\mathbf{W}_q \in \mathbb{R}^{d \times d_k}$ | The query weight matrix
$\mathbf{W}_v \in \mathbb{R}^{d \times d_v}$ | The value weight matrix
$\mathbf{W}_o \in \mathbb{R}^{d_v \times d}$ | The output weight matrix
$\mathbf{K} = \mathbf{X} \mathbf{W}_k,~\mathbf{K} \in \mathbb{R}^{N \times d_k}$ | The key embedding input, maps input to a key
$\mathbf{Q} = \mathbf{X} \mathbf{W}_q,~\mathbf{Q} \in \mathbb{R}^{N \times d_k}$ | The query embedding input, maps input to a query
$\mathbf{V} = \mathbf{X} \mathbf{W}_v,~\mathbf{V} \in \mathbb{R}^{N \times d_v}$ | The value embedding input, maps input to a value
$\mathbf{A} = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} \right),~\mathbf{A} \in \mathbb{R}^{N \times N}$ | The self-attention matrix
$a_{ij} \in \mathbb{R}$ | Entry $(i, j)$ of $\mathbf{A}$: the softmax-normalized, scaled dot product between query $\mathbf{Q}_i$ and key $\mathbf{K}_j$


Internally, a single attention head uses the following learned parameters to process the input:
- Key weight matrix $\mathbf{W}_k \in \mathbb{R}^{d \times d_k}$.
- Query weight matrix $\mathbf{W}_q \in \mathbb{R}^{d \times d_k}$.
- Value weight matrix $\mathbf{W}_v \in \mathbb{R}^{d \times d_v}$.

**Note**: most DL code uses tensors shaped $(B, N, d)$, where $B$ is the batch size, $N$ is the sequence length and $d$ is the dimensionality of an embedding vector. Even if you drop the batch dimension, the features stay in the last dimension.

*Why is it a common approach?*
- First, it matches `nn.Linear`, which expects features in the last dimension.
- Second, batching and the multi-head layout are cleaner: queries and keys of shape $(B, H, N, d_k)$ produce attention scores of shape $(B, H, N, N)$.
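
A minimal sketch of this layout, assuming PyTorch. The sizes `B`, `N`, `d`, `H` and `d_k` are arbitrary illustration values, and the two projection layers are hypothetical stand-ins for the query and key weights:

```python
import torch
from torch import nn

B, N, d = 2, 5, 16      # batch size, sequence length, model size (illustrative)
H, d_k = 4, 8           # number of heads, per-head key/query size (illustrative)

X = torch.randn(B, N, d)

# nn.Linear acts on the last dimension: (B, N, d) -> (B, N, H * d_k)
proj_q = nn.Linear(d, H * d_k)
proj_k = nn.Linear(d, H * d_k)

# Split the feature dimension into heads and move the head axis next to the batch
Q = proj_q(X).view(B, N, H, d_k).transpose(1, 2)   # (B, H, N, d_k)
K = proj_k(X).view(B, N, H, d_k).transpose(1, 2)   # (B, H, N, d_k)

# Batched matmul over the last two dimensions: one score matrix per head
scores = Q @ K.transpose(-2, -1)                   # (B, H, N, N)
print(Q.shape, scores.shape)
```

Because the batch and head axes come first, the same matrix product handles every sample and every head without an explicit loop.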

The processing pipeline looks as follows:
1. Take the stacked input $\mathbf{X} \in \mathbb{R}^{N \times d}$.
2. For each head, calculate the key $\mathbf{K}_i$ and the query $\mathbf{Q}_i$ of every input token $\mathbf{x}_i$ (the $i$-th row of $\mathbf{X}$) as:
  - $\mathbf{K}_i = \mathbf{x}_i \mathbf{W}_k,~\mathbf{K}_i \in \mathbb{R}^{1 \times d_k}$
  - $\mathbf{Q}_i = \mathbf{x}_i \mathbf{W}_q,~\mathbf{Q}_i \in \mathbb{R}^{1 \times d_k}$
3. Calculate the dot product between each query $\mathbf{Q}_i$ and each key $\mathbf{K}_j$. The dot product measures how similar the query and the key are. By learning $\mathbf{W}_q$ and $\mathbf{W}_k$, the attention mechanism learns how different tokens are related.
4. All those dot products are kept in a single matrix $\mathbf{Q}\mathbf{K}^T \in \mathbb{R}^{N \times N}$.
5. In the case of causal (masked) self-attention, the entries above the main diagonal are masked out, typically by setting them to $-\infty$ before the softmax so that they receive zero weight. Masking acts as a switch that prevents the current token from seeing "future" tokens. All entries are also scaled by dividing them by $\sqrt{d_k}$. This works like a temperature in the softmax: it counteracts the growth of the dot-product magnitude with $d_k$:
$$
\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}}
$$
6. Apply a row-wise softmax to turn the scores into probabilities (this acts as a normalization as well), and then multiply by $\mathbf{V}$ to get the corrections to our input (these corrections still live in the per-head value space):
$$
\mathbf{z} = \mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}} \right) \mathbf{V},~\mathbf{z} \in \mathbb{R}^{N \times d_v}.
$$
7. Then apply the output matrix $\mathbf{W}_o \in \mathbb{R}^{d_v \times d}$, which projects back to the model dimensionality. With multiple heads its shape is $(H d_v) \times d$, since the head outputs are concatenated first, $(\mathbf{z}_1, \dots, \mathbf{z}_H) \to \mathbf{z}_{cat} \in \mathbb{R}^{N \times (H \cdot d_v)}$:
$$
\Delta \mathbf{I} = \mathbf{z} \mathbf{W}_o,~\Delta \mathbf{I} \in \mathbb{R}^{N \times d}
$$

These steps complete the attention mechanism. $\Delta \mathbf{I}$ is then added to the input of the layer (residual connection), and the sum is normalized.
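
The seven steps can be sketched for a single head as follows. This is a minimal illustration assuming PyTorch, with random weights and small made-up sizes; the function name `single_head_attention` and the `causal` flag are introduced here for the example only:

```python
import math
import torch

def single_head_attention(X, W_q, W_k, W_v, W_o, causal=False):
    """X: (N, d); W_q, W_k: (d, d_k); W_v: (d, d_v); W_o: (d_v, d)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v              # steps 1-2: project the input
    scores = Q @ K.T / math.sqrt(K.shape[-1])        # steps 3-5: scaled dot products, (N, N)
    if causal:
        N = X.shape[0]
        # True strictly above the diagonal, i.e. at "future" positions
        future = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
    A = torch.softmax(scores, dim=-1)                # step 6: each row sums to 1
    z = A @ V                                        # (N, d_v)
    return z @ W_o, A                                # step 7: back to model size, (N, d)

torch.manual_seed(0)
N, d, d_k, d_v = 4, 8, 3, 5                          # illustrative sizes
X = torch.randn(N, d)
W_q, W_k = torch.randn(d, d_k), torch.randn(d, d_k)
W_v, W_o = torch.randn(d, d_v), torch.randn(d_v, d)

delta, A = single_head_attention(X, W_q, W_k, W_v, W_o, causal=True)
print(delta.shape)   # torch.Size([4, 8])
print(A)             # lower-triangular attention weights
```

With `causal=True` the first token can only attend to itself, so the first row of `A` has a single non-zero entry equal to 1.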

**Note:** in real implementations, the $\mathbf{W}_q$, $\mathbf{W}_k$ and $\mathbf{W}_v$ matrices are often stacked for efficiency. When they share the same output dimension ($d_k = d_v$), all three projections can be computed in a single pass by projecting the input through `nn.Linear(d, 3 * d_k)` and splitting the result. $\mathbf{W}_o$ is then also implemented as an `nn.Linear` layer.
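
A minimal sketch of such a fused projection for a single head, assuming PyTorch and $d_v = d_k$. The layer names `qkv_proj` and `out_proj` and all sizes are illustrative choices, and biases are disabled to match the bias-free formulas above:

```python
import torch
from torch import nn

torch.manual_seed(0)
B, N, d, d_k = 2, 5, 16, 8                       # illustrative sizes, single head
X = torch.randn(B, N, d)

qkv_proj = nn.Linear(d, 3 * d_k, bias=False)     # W_q, W_k, W_v stacked along the output axis
out_proj = nn.Linear(d_k, d, bias=False)         # W_o (here d_v == d_k)

Q, K, V = qkv_proj(X).chunk(3, dim=-1)           # each of shape (B, N, d_k)
A = torch.softmax(Q @ K.transpose(-2, -1) / d_k ** 0.5, dim=-1)
Y = out_proj(A @ V)                              # (B, N, d)

# nn.Linear stores its weight as (out_features, in_features),
# so the first d_k rows play the role of W_q^T
W_q = qkv_proj.weight[:d_k].T                    # (d, d_k)
print(torch.allclose(Q, X @ W_q, atol=1e-5))
```

The fused layer performs one larger matrix multiplication instead of three smaller ones, which is usually friendlier to the hardware.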

## Multi-layer Perceptron, or Fully-connected Feed Forward Network
After the input has been processed by the attention sub-layer, the next step is to pass it through the MLP layer.
This step is much more straightforward and consists of three sequential operations:
1. Apply the first linear transformation: $\mathbf{y}_1 = \mathbf{x} \mathbf{W}_1 + \mathbf{b}_1$
2. Use the ReLU activation: $\mathbf{a} = \max(0, \mathbf{y}_1)$
3. Apply the second linear transformation: $\mathbf{y}_2 = \mathbf{a} \mathbf{W}_2 + \mathbf{b}_2$.

The inner-layer dimensionality $d_{ff}$ can be chosen freely. In the case of GPT-3 it is $4 \times 12288 = 49152$, four times the number of dimensions in the embedding space. The same weights are applied to every token, and their shapes are:
- $\mathbf{W}_1 \in \mathbb{R}^{d \times d_{ff}}$, $\mathbf{b}_1 \in \mathbb{R}^{d_{ff}}$
- $\mathbf{W}_2 \in \mathbb{R}^{d_{ff} \times d}$, $\mathbf{b}_2 \in \mathbb{R}^d$
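
A small sketch of this block, assuming PyTorch and toy sizes ($d = 16$, $d_{ff} = 4d = 64$, both chosen only for illustration). It also counts the parameters implied by the shapes above:

```python
import torch
from torch import nn

torch.manual_seed(0)
d, d_ff = 16, 64                     # illustrative sizes, d_ff = 4 * d

mlp = nn.Sequential(
    nn.Linear(d, d_ff),              # W_1, b_1
    nn.ReLU(),
    nn.Linear(d_ff, d),              # W_2, b_2
)

x = torch.randn(2, 5, d)             # (B, N, d)
y = mlp(x)                           # applied independently to each token, (B, N, d)

# d * d_ff + d_ff + d_ff * d + d parameters in total
n_params = sum(p.numel() for p in mlp.parameters())
print(y.shape, n_params)             # torch.Size([2, 5, 16]) 2128
```

Since `nn.Linear` only touches the last dimension, passing a single token through `mlp` gives the same result as slicing that token out of the batched output.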

## Multihead Attention Implementation

```python
import torch
from torch import nn

class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, q_dim, k_dim, v_dim):
        super().__init__()
        # Q @ K^T requires queries and keys to share the same per-head dimension
        assert q_dim == k_dim, "q_dim and k_dim must match"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.q_dim = q_dim
        self.k_dim = k_dim
        self.v_dim = v_dim

        # One projection matrix per head, stored as (H, d, d_k) or (H, d, d_v)
        self.Wq = nn.Parameter(torch.empty(self.num_heads, self.embed_dim, self.q_dim))
        self.Wk = nn.Parameter(torch.empty(self.num_heads, self.embed_dim, self.k_dim))
        self.Wv = nn.Parameter(torch.empty(self.num_heads, self.embed_dim, self.v_dim))
        self.Wo = nn.Parameter(torch.empty(self.num_heads * self.v_dim, self.embed_dim))
        nn.init.xavier_uniform_(self.Wq)
        nn.init.xavier_uniform_(self.Wk)
        nn.init.xavier_uniform_(self.Wv)
        nn.init.xavier_uniform_(self.Wo)

    def forward(self, X, mask=None, return_attention=False):
        # X: (B, N, d) -> Q, K, V: (B, H, N, d_k) or (B, H, N, d_v)
        Q = torch.einsum('bne,hev->bhnv', X, self.Wq)
        K = torch.einsum('bne,hev->bhnv', X, self.Wk)
        V = torch.einsum('bne,hev->bhnv', X, self.Wv)

        # Scaled dot-product scores: (B, H, N, N)
        logits = Q @ K.permute(0, 1, 3, 2) / (self.k_dim ** 0.5)
        if mask is not None:
            # Assumed convention: mask broadcasts to (B, H, N, N) and 0 marks blocked positions
            logits = logits.masked_fill(mask == 0, float('-inf'))
        A = torch.softmax(logits, dim=-1)

        Z = A @ V  # (B, H, N, d_v)
        # Concatenating the heads and applying Wo is expressed as a sum over h and v
        Y = torch.einsum("bhnv, hve -> bne",
                         Z,
                         self.Wo.view(self.num_heads, self.v_dim, self.embed_dim))

        if return_attention:
            return Y, A
        else:
            return Y

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, X):
        # Normalize over the feature (last) dimension
        mean = X.mean(dim=-1, keepdim=True)
        var = X.var(dim=-1, unbiased=False, keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + self.eps)
        return self.gamma * X_hat + self.beta

class EncoderBlock(nn.Module):
    def __init__(self, input_dim, num_heads, q_dim, k_dim, v_dim, dim_feedforward, dropout_rate):
        super().__init__()
        self.mha = MultiheadAttention(input_dim, num_heads, q_dim, k_dim, v_dim)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, dim_feedforward),
            nn.Dropout(dropout_rate),
            nn.ReLU(),
            nn.Linear(dim_feedforward, input_dim),
            nn.Dropout(dropout_rate)
        )
        self.norm1 = LayerNorm(input_dim)
        self.norm2 = LayerNorm(input_dim)

    def forward(self, X):
        # Multi-head attention
        attn_output = self.mha(X)
        X = self.norm1(X + attn_output)  # Add & Norm

        # Feed-forward network
        ffn_output = self.ffn(X)
        X = self.norm2(X + ffn_output)   # Add & Norm

        return X

class TransformerEncoder(nn.Module):
    def __init__(self, num_layers, input_dim, num_heads, q_dim, k_dim, v_dim, dim_feedforward, dropout_rate):
        super().__init__()
        self.layers = nn.ModuleList([
            EncoderBlock(input_dim, num_heads, q_dim, k_dim, v_dim, dim_feedforward, dropout_rate)
            for _ in range(num_layers)
        ])

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)
        return X

def sinusoidal_positions_encoder(L, d):
    pos = torch.arange(L).float().unsqueeze(1)        # [L,1]
    i = torch.arange(d//2).float().unsqueeze(0)       # [1,d/2]
    denom = torch.pow(10000, (2*i)/d)                 # [1,d/2]
    angles = pos / denom                              # [L,d/2]
    pe = torch.zeros(L, d)
    pe[:, 0::2] = torch.sin(angles)                   # even feature indices
    pe[:, 1::2] = torch.cos(angles)                   # odd feature indices
    return pe  # [L,d]
```


## Experiments
The experiments block was borrowed from [this tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html).

### Sequence to Sequence
A sequence-to-sequence task represents a class of tasks where an input sequence is given and the goal is to produce an output sequence, possibly of a different length. Examples are machine translation and summarization.

**Problem**: given a sequence of $N$ integers between 1 and $M$ (0 is reserved for padding), the task is to reverse it.
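
As a toy illustration of one input/target pair, assuming PyTorch. The values `M = 9` and `N = 6` are arbitrary, and the sketch draws values starting from 1 so that 0 stays free for padding:

```python
import torch

torch.manual_seed(0)
M, N = 9, 6                              # illustrative value range and length
seq = torch.randint(1, M + 1, (N,))      # N integers in [1, M]
target = torch.flip(seq, dims=[0])       # the label is the reversed sequence
print(seq.tolist(), "->", target.tolist())
```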

As usual, the first thing is to create a dataset.

```python
from functools import partial

from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split

class Seq2SeqDataset(Dataset):
    def __init__(self, num_samples, min_length, max_length, num_classes):
        self.num_samples = num_samples
        self.num_classes = num_classes

        self.data = []
        self.targets = []

        for _ in range(num_samples):
            # Random length for each sample
            seq_len = torch.randint(min_length, max_length + 1, (1,)).item()
            # Values start at 1, so 0 can be used as the padding value
            seq = torch.randint(1, num_classes, (seq_len,))
            target = torch.flip(seq, dims=[0])

            self.data.append(seq)
            self.targets.append(target)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

def collate_fn(batch):
    sequences, targets = zip(*batch)

    # Pad sequences to the same length in this batch
    max_len = max(len(seq) for seq in sequences)

    padded_seqs = torch.zeros(len(sequences), max_len, dtype=torch.long)
    padded_targets = torch.zeros(len(targets), max_len, dtype=torch.long)

    for i, (seq, target) in enumerate(zip(sequences, targets)):
        padded_seqs[i, :len(seq)] = seq
        padded_targets[i, :len(target)] = target

    return padded_seqs, padded_targets

class Model(nn.Module):
    def __init__(self, num_classes, embed_dim, num_encoder_layers, num_heads, q_dim, k_dim, v_dim,
                  dim_feedforward, dropout_rate):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embed_dim)

        # Precompute positional encodings once; a buffer moves with the model but is not trained
        self.max_seq_length = 5000
        pe = sinusoidal_positions_encoder(self.max_seq_length, embed_dim)
        self.register_buffer('positional_encoding', pe)

        self.transformer_encoder = TransformerEncoder(
            num_layers=num_encoder_layers,
            input_dim=embed_dim,
            num_heads=num_heads,
            q_dim=q_dim,
            k_dim=k_dim,
            v_dim=v_dim,
            dim_feedforward=dim_feedforward,
            dropout_rate=dropout_rate
        )
        self.output_net = nn.Sequential(
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        seq_len = x.size(1)

        if seq_len > self.max_seq_length:
            raise ValueError(f"Sequence length {seq_len} exceeds maximum {self.max_seq_length}")

        x = self.embedding(x)                                        # (B, N) -> (B, N, d)
        x = x + self.positional_encoding[:seq_len, :].unsqueeze(0)
        x = self.transformer_encoder(x)
        x = self.output_net(x)                                       # (B, N, num_classes)
        return x
```


```python
from torch import optim
import torch.nn.functional as F

def train(model, data_loader, epochs, loss_function, optimizer, test_loader=None):
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch, labels in data_loader:
            optimizer.zero_grad()

            preds = model(batch)
            loss = loss_function(preds.view(-1, preds.size(-1)), labels.view(-1))

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if epoch % 3 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss/len(data_loader):.4f}")

        if test_loader is not None and epoch % 3 == 0:
            model.eval()
            with torch.no_grad():
                total_test_loss = 0
                for batch, labels in test_loader:
                    preds = model(batch)
                    loss = loss_function(preds.view(-1, preds.size(-1)), labels.view(-1))
                    total_test_loss += loss.item()
                print(f"Test Loss: {total_test_loss/len(test_loader):.4f}")

torch.manual_seed(15)

dataset = partial(Seq2SeqDataset, min_length=3, max_length=15, num_classes=20)
train_loader = DataLoader(dataset(40000), batch_size=128, shuffle=True, drop_last=True, collate_fn=collate_fn)
test_loader = DataLoader(dataset(1000), batch_size=128, shuffle=False, drop_last=True, collate_fn=collate_fn)

model = Model(num_classes=20, embed_dim=16, num_encoder_layers=4, num_heads=4,
              q_dim=16, k_dim=16, v_dim=16, dim_feedforward=512,
              dropout_rate=0.0)

optimizer = optim.Adam(model.parameters(), lr=0.0005)

train(model, train_loader, 15, F.cross_entropy, optimizer, test_loader)
```

```
Epoch 0, Loss: 1.9103
Test Loss: 1.6215
Epoch 3, Loss: 1.4074
Test Loss: 1.3452
```


```python
test_data = next(iter(test_loader))
model.eval()
with torch.no_grad():  # inference only, no gradients needed
    test_preds = model(test_data[0])
test_pred_classes = test_preds.argmax(dim=-1)
print("Sample predictions:")
for i in range(5):
    print(f"Input: {test_data[0][i].tolist()}")
    print(f"Target: {test_data[1][i].tolist()}")
    print(f"Predicted: {test_pred_classes[i].tolist()}")
    print()
```

```
Sample predictions:
Input: [1, 11, 17, 10, 3, 6, 6, 9, 1, 7, 1, 3, 17, 16, 0]
Target: [16, 17, 3, 1, 7, 1, 9, 6, 6, 3, 10, 17, 11, 1, 0]
Predicted: [14, 17, 3, 1, 7, 1, 9, 6, 6, 3, 10, 17, 11, 1, 0]

Input: [14, 10, 6, 8, 13, 13, 15, 18, 8, 12, 12, 7, 14, 0, 0]
Target: [14, 7, 12, 12, 8, 18, 15, 13, 13, 8, 6, 10, 14, 0, 0]
Predicted: [14, 14, 12, 12, 8, 18, 15, 13, 13, 8, 6, 10, 14, 0, 0]

Input: [3, 7, 12, 9, 4, 11, 6, 14, 18, 9, 0, 0, 0, 0, 0]
Target: [9, 18, 14, 6, 11, 4, 9, 12, 7, 3, 0, 0, 0, 0, 0]
Predicted: [9, 8, 14, 6, 11, 4, 9, 12, 7, 3, 0, 0, 0, 0, 0]

Input: [7, 1, 17, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Target: [17, 1, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Predicted: [17, 1, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Input: [14, 14, 4, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Target: [19, 4, 14, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Predicted: [19, 4, 14, 14, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
```



## Gotchas
A few things that I encountered while implementing transformers from scratch:
- Initialization - without `xavier` initialization, the network never converged. This is probably due to vanishing gradients or weak signals.
- When the batch contains sequences of different lengths, the default collate function fails, since stacking requires equal shapes. We need to tell PyTorch explicitly how to combine them with a custom `collate_fn`. The proper solution, though, is to implement both padding and masking for the sequences.
- A high learning rate combined with the absence of gradient clipping may lead to gradient explosion. During backpropagation gradients can grow from layer to layer, so clipping their norm helps here.
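
The last two points can be sketched in a single training step. This is a hedged illustration assuming PyTorch, not the training loop used above: the tiny stand-in model, the padding index `PAD = 0` and `max_norm=1.0` are all assumptions made for the example.

```python
import torch
from torch import nn
import torch.nn.functional as F

PAD = 0                                   # assumed padding index
torch.manual_seed(0)

# Tiny stand-in model (hypothetical), only to make the step runnable
model = nn.Sequential(nn.Embedding(20, 16, padding_idx=PAD), nn.Linear(16, 20))
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

batch = torch.tensor([[3, 7, 12, 0, 0], [5, 1, 9, 4, 2]])     # padded inputs
labels = torch.tensor([[12, 7, 3, 0, 0], [2, 4, 9, 1, 5]])    # padded targets

optimizer.zero_grad()
logits = model(batch)                                          # (B, N, num_classes)

# ignore_index removes padded positions from the loss and from the gradient
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1),
                       ignore_index=PAD)
loss.backward()

# Rescale all gradients so that their global L2 norm is at most max_norm;
# the returned value is the norm measured before clipping
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Key-padding mask: True for real tokens, shaped to broadcast over (B, H, N, N)
key_mask = (batch != PAD)[:, None, None, :]
print(loss.item(), norm_before.item(), key_mask.shape)
```

A mask like `key_mask` could be passed to an attention layer so that padded keys get zero weight, provided the layer treats `False` (or 0) as a blocked position.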