# Week 1: Introduction to Transformers

## Notebook 02: Multi-Head Attention and Transformer Blocks

This notebook explores Multi-Head Attention (MHA) and how it enables the model to attend to different representation subspaces.

### Learning Objectives
- Understand the motivation for multi-head attention
- Implement and test multi-head attention
- Build a complete transformer block
- Analyze the role of feed-forward networks and residual connections

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from src.llm_journey.models import MultiHeadAttention, TransformerBlock
from src.llm_journey.utils import set_seed, count_parameters

set_seed(42)

## 1. Multi-Head Attention

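Instead of performing a single attention function over the full $d_{\text{model}}$-dimensional vectors, multi-head attention projects the queries, keys, and values $h$ times with different learned linear projections, typically to $d_k = d_{\text{model}} / h$ dimensions each. Attention runs on every projection in parallel, and the results are concatenated and projected once more.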

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)W^O$$

where each head is:

$$\text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)$$

This allows the model to jointly attend to information from different representation subspaces.
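
The two equations above can be sketched in plain PyTorch. This is a minimal illustration, not the `MultiHeadAttention` class imported from `src.llm_journey.models`: the class name `SketchMultiHeadAttention`, the fused projection layers, and the choice to return the attention weights are all assumptions made for the example, and dropout and masking are left out.

```python
import math

import torch
import torch.nn as nn


class SketchMultiHeadAttention(nn.Module):
    """Illustrative multi-head attention (hypothetical, not the course class)."""

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Each fused layer stacks the h per-head projections W_i^Q, W_i^K, W_i^V.
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)  # W^O

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        b, n, _ = t.shape
        return t.view(b, n, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v):
        q = self._split_heads(self.w_q(q))
        k = self._split_heads(self.w_k(k))
        v = self._split_heads(self.w_v(v))
        # Scaled dot-product attention, computed for all heads at once.
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        weights = scores.softmax(dim=-1)          # (batch, heads, seq_q, seq_k)
        heads = weights @ v                       # (batch, heads, seq_q, d_k)
        # Concat(head_1, ..., head_h), then apply W^O.
        b, _, n, _ = heads.shape
        concat = heads.transpose(1, 2).reshape(b, n, self.num_heads * self.d_k)
        return self.w_o(concat), weights


torch.manual_seed(0)
sketch_mha = SketchMultiHeadAttention(d_model=64, num_heads=8)
sketch_x = torch.randn(2, 10, 64)
sketch_out, sketch_weights = sketch_mha(sketch_x, sketch_x, sketch_x)
print(sketch_out.shape, sketch_weights.shape)
```

Each head works on a 64 / 8 = 8-dimensional slice, and every row of `sketch_weights` is a probability distribution over key positions.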

In [None]:
# Create multi-head attention module
d_model = 64
num_heads = 8
batch_size = 2
seq_len = 10

mha = MultiHeadAttention(d_model, num_heads, dropout=0.1)
print(f"Multi-Head Attention parameters: {count_parameters(mha):,}")

# Create input
x = torch.randn(batch_size, seq_len, d_model)
print(f"Input shape: {x.shape}")

# Forward pass
output = mha(x, x, x)
print(f"Output shape: {output.shape}")

## 2. Why Multiple Heads?

Multiple heads allow the model to:
- Attend to different positions simultaneously
- Capture different aspects of the relationships between tokens
- Learn different attention patterns for syntax, semantics, etc.

In [None]:
# Compare parameter counts for 1, 4, and 8 heads at a fixed d_model
configs = [
    {"num_heads": 1, "d_model": 64},
    {"num_heads": 4, "d_model": 64},
    {"num_heads": 8, "d_model": 64},
]

for config in configs:
    mha = MultiHeadAttention(config["d_model"], config["num_heads"])
    params = count_parameters(mha)
    print(f"Heads: {config['num_heads']:2d} | Parameters: {params:,}")
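
For reference, the same comparison can be run against PyTorch's built-in `nn.MultiheadAttention`. Because each head uses $d_k = d_{\text{model}} / h$ dimensions, the built-in module has the same number of parameters for any head count: $4 d_{\text{model}}^2 + 4 d_{\text{model}}$ with biases. Whether the course's `MultiHeadAttention` matches this count depends on its implementation, which is not shown here; the names `d_model_demo` and `builtin_counts` are introduced only for this sketch.

```python
import torch.nn as nn

d_model_demo = 64
builtin_counts = {}

for h in (1, 4, 8):
    ref = nn.MultiheadAttention(embed_dim=d_model_demo, num_heads=h, batch_first=True)
    # Count trainable parameters directly, without relying on course helpers.
    n_params = sum(p.numel() for p in ref.parameters() if p.requires_grad)
    builtin_counts[h] = n_params
    print(f"Heads: {h:2d} | d_k: {d_model_demo // h:2d} | Parameters: {n_params:,}")
```

More heads therefore change how the projection is partitioned, not how large it is.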

## 3. Transformer Block

A transformer block combines:
1. Multi-head self-attention
2. Feed-forward network (two linear layers with activation)
3. Layer normalization
4. Residual connections

$$\text{TransformerBlock}(x) = \text{LayerNorm}(x + \text{FFN}(\text{LayerNorm}(x + \text{MHA}(x))))$$
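
The formula above applies layer normalization after each residual addition (the "post-LN" arrangement). A minimal sketch of that data flow, assuming PyTorch's built-in `nn.MultiheadAttention` for the attention sublayer and a ReLU activation in the feed-forward network; the class name `SketchPostLNBlock` is hypothetical, and the course's `TransformerBlock` may differ in details such as activation or dropout placement.

```python
import torch
import torch.nn as nn


class SketchPostLNBlock(nn.Module):
    """Illustrative post-LN transformer block (hypothetical, not the course class)."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),  # assumed activation
            nn.Linear(d_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inner term: LayerNorm(x + MHA(x))
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = self.norm1(x + self.drop(attn_out))
        # Outer term: LayerNorm(x + FFN(x))
        x = self.norm2(x + self.drop(self.ffn(x)))
        return x


torch.manual_seed(0)
sketch_block = SketchPostLNBlock(d_model=64, num_heads=8, d_ff=256).eval()
sketch_in = torch.randn(2, 10, 64)
with torch.no_grad():
    sketch_block_out = sketch_block(sketch_in)
print(sketch_block_out.shape)
```

The block maps a `(batch, seq_len, d_model)` tensor to one of the same shape, which is what makes stacking possible.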

In [None]:
# Create transformer block
d_model = 64
num_heads = 8
d_ff = 256  # Feed-forward dimension (typically 4x d_model)

block = TransformerBlock(d_model, num_heads, d_ff, dropout=0.1)
print(f"Transformer block parameters: {count_parameters(block):,}")

# Forward pass
x = torch.randn(batch_size, seq_len, d_model)
output = block(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

## 4. Residual Connections and Layer Normalization

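These components are crucial for training deep networks:
- **Residual connections**: Give gradients an identity path around each sublayer, so they reach early layers without vanishing
- **Layer normalization**: Rescales each token's features to zero mean and unit variance (before a learned scale and shift), which stabilizes training and speeds up convergence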
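
The gradient-flow claim can be checked on a toy network. The sketch below is an illustration only, not part of the course code: it pushes an input through 30 small `tanh` layers (an arbitrary depth and width chosen for the demo) with and without skip connections and compares the gradient norm that arrives back at the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 30, 64  # arbitrary toy sizes
toy_layers = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])


def input_grad_norm(use_residual: bool) -> float:
    """Return the norm of d(sum of outputs)/d(input) for the toy stack."""
    inp = torch.randn(4, width, requires_grad=True)
    h = inp
    for layer in toy_layers:
        update = torch.tanh(layer(h))
        # With a residual connection the layer only adds a correction to h.
        h = h + update if use_residual else update
    h.sum().backward()
    return inp.grad.norm().item()


plain_norm = input_grad_norm(use_residual=False)
residual_norm = input_grad_norm(use_residual=True)
print(f"Gradient norm without residuals: {plain_norm:.3e}")
print(f"Gradient norm with residuals:    {residual_norm:.3e}")
```

Without the skip path, the gradient is multiplied by every layer's Jacobian in turn and shrinks toward zero at default initialization; with it, the identity term keeps the signal alive.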

In [None]:
# Demonstrate residual connections
x = torch.randn(1, 5, d_model)
output = block(x)

# The output differs from the input. If the block ends in a LayerNorm (as in the
# formula above), its statistics should sit near mean 0 and std 1 at initialization.
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print(f"Output mean: {output.mean():.4f}, std: {output.std():.4f}")

## 5. Stacking Transformer Blocks

Modern LLMs stack many transformer blocks (12, 24, or even 96 layers). Let's see how the parameter count scales with depth.

In [None]:
# Stack multiple transformer blocks
num_layers_list = [1, 6, 12, 24]

for num_layers in num_layers_list:
    blocks = nn.ModuleList([TransformerBlock(d_model, num_heads, d_ff) for _ in range(num_layers)])
    total_params = sum(count_parameters(block) for block in blocks)
    print(f"Layers: {num_layers:2d} | Total parameters: {total_params:,}")
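
Because every block maps `(batch, seq_len, d_model)` to the same shape, a stack is just function composition, and the parameter count grows linearly with depth. A minimal sketch using PyTorch's built-in `nn.TransformerEncoderLayer` as a stand-in for the course's `TransformerBlock`; the helper names `make_stack` and `n_trainable` are hypothetical.

```python
import torch
import torch.nn as nn


def make_stack(num_layers: int, d_model: int = 64, num_heads: int = 8, d_ff: int = 256) -> nn.Sequential:
    """Build a stack of built-in encoder layers (stand-in for TransformerBlock)."""
    layers = [
        nn.TransformerEncoderLayer(
            d_model, num_heads, dim_feedforward=d_ff, dropout=0.1, batch_first=True
        )
        for _ in range(num_layers)
    ]
    return nn.Sequential(*layers)


def n_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


torch.manual_seed(0)
stack = make_stack(6).eval()  # eval() disables dropout for a deterministic pass
tokens = torch.randn(2, 10, 64)
with torch.no_grad():
    stacked_out = stack(tokens)

per_layer = n_trainable(make_stack(1))
print(f"Output shape: {tuple(stacked_out.shape)}")
print(f"Parameters per layer: {per_layer:,} | 12 layers: {n_trainable(make_stack(12)):,}")
```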

## Exercises

1. Experiment with different numbers of attention heads and observe parameter counts
2. Modify the feed-forward dimension ratio (currently 4x) and analyze the impact
3. Implement a function to visualize attention weights for each head
4. Compare the output distributions with and without layer normalization
5. Stack multiple transformer blocks and pass data through the entire stack

## Next Steps

Continue to Notebook 03 to build a complete language model with positional encoding and learn about training procedures.