# Lesson 4: Multi-Head Attention and Transformer Block

Welcome to Lesson 4! We’re taking a big step forward by upgrading from a single attention head (Lesson 3) to **multi-head attention** and building a full **transformer block**, the backbone of models like Llama 2. Don’t worry if this sounds complex: we’ll break it down into simple pieces with lots of examples and visuals. Since you’re new to LLMs, we’ll focus on understanding *why* these ideas matter and *how* they work, using PyTorch and our Chinese text from `data.csv`.

## What You’ll Learn
- Why **multi-head attention** is better than single-head attention.
- How to code multi-head attention with multiple attention heads running in parallel.
- What a **transformer block** is and how it combines attention with other layers.
- How to visualize attention patterns across multiple heads using our Chinese data.

## Our Data
We’re still using `data.csv` with the `head` column, featuring casual Hong Kong-style Chinese text (e.g., from LIHKG). Our example sentence is: *法國紅酒慢煮阿根廷牛舌 配 煙肉洋蔥炒著仔* (French red wine slow-cooked Argentine beef tongue, served with potatoes stir-fried with bacon and onion). We’ll explore how multiple heads can attend to different parts of this sentence!

## Prerequisites
- PyTorch installed (from Lesson 0).
- `data.csv` with columns `user`, `title`, `head`.
- Basic Python knowledge and familiarity with Lesson 3’s single-head attention.
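The lesson never shows the example sentence coming out of `data.csv`. Here is a minimal sketch, assuming the file has the `user`, `title`, `head` columns listed above. The two rows are made-up stand-ins (read from memory with `io.StringIO`) so the snippet runs without the real file; in the notebook you would call `pd.read_csv("data.csv")` instead.

```python
import io

import pandas as pd

# Hypothetical stand-in for data.csv: same columns (user, title, head), made-up rows.
csv_text = """user,title,head
user_a,晚餐,法國紅酒慢煮阿根廷牛舌 配 煙肉洋蔥炒著仔
user_b,宵夜,今晚食乜好
"""

# With the real file: df = pd.read_csv("data.csv")
df = pd.read_csv(io.StringIO(csv_text))

# Pick the example sentence from the `head` column.
sentence = df["head"].iloc[0]
print(sentence)

# The six tokens used in this lesson are written out by hand (no tokenizer involved).
tokens = ["法國", "紅酒", "慢煮", "阿根廷", "牛舌", "配"]
print(all(tok in sentence for tok in tokens))  # True: every token appears in the sentence
```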

## Step 1: Import Libraries

Let’s load the tools we need. These are familiar from Lesson 3, with some extras for visualization.

In [1]:
import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set random seed for consistent results
torch.manual_seed(42)

# Check PyTorch version
print("PyTorch version:", torch.__version__)

PyTorch version: 2.6.0+cu118


## Step 2: Recap Single-Head Attention

In Lesson 3, we built a single attention head. It took our input (e.g., `法國`, `紅酒`, `牛舌`), turned each token into **query (Q)**, **key (K)**, and **value (V)** vectors, and computed:
- **Scores**: How much each token relates to others (e.g., `牛舌` focusing on `慢煮`).
- **Weights**: Probabilities from softmax (e.g., 0.45 for `慢煮`).
- **Output**: A new representation mixing relevant info.
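As a refresher, here is a minimal single-head sketch. The function name `single_head_attention` is ours, not necessarily what Lesson 3 used, and Q, K, V are random stand-ins. It also checks the result against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later), which uses the same `1/sqrt(d)` scaling by default.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def single_head_attention(Q, K, V):
    """Scaled dot-product attention for one head.

    Q, K, V: [batch, seq_len, d]. Returns (output, weights).
    """
    d = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d ** 0.5  # [batch, seq_len, seq_len]
    weights = torch.softmax(scores, dim=-1)       # each row sums to 1
    return weights @ V, weights                   # output: [batch, seq_len, d]

# 6 tokens with 16-dimensional random Q/K/V vectors
Q, K, V = (torch.randn(1, 6, 16) for _ in range(3))
out, weights = single_head_attention(Q, K, V)

# PyTorch 2.x ships the same computation as a fused function
ref = F.scaled_dot_product_attention(Q, K, V)
print(out.shape, torch.allclose(out, ref, atol=1e-5))
```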

But a single head gives each token only *one* set of attention weights, so every relationship it cares about has to share that one distribution. What if `牛舌` needs to focus on *both* its cooking method (`慢煮`) *and* its origin (`阿根廷`)? That’s where **multi-head attention** comes in: it lets the model track several relationships simultaneously!

## Step 3: Understand Multi-Head Attention

### Why Multi-Head?
Imagine you’re reading our sentence: *法國紅酒慢煮阿根廷牛舌 配 煙肉洋蔥炒著仔*. To understand `牛舌` (beef tongue):
- One part of your brain might focus on **how it’s cooked** (`慢煮`, `紅酒`).
- Another part might focus on **where it’s from** (`阿根廷`).
- A third part might check **what it’s paired with** (`配`, `煙肉`).

A single attention head has to squeeze these into one set of weights. Multi-head attention runs several heads in parallel (e.g., 4 heads), each with its own learned Q, K, V projections, so each can learn a different focus. Their outputs are then combined, giving the model several complementary views of every token.

### How It Works
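1. **Split Into Heads**: Instead of one set of Q, K, V, we create one set per head (e.g., 4 heads give 4 Qs, 4 Ks, and 4 Vs). In practice, we compute one full-size Q, K, V and slice each into equal per-head chunks.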
2. **Run Attention Separately**: Each head computes its own attention weights and output.
3. **Combine Results**: Concatenate the outputs from all heads and transform them back to the original size.
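Steps 1 and 3 are just reshapes. This small sketch uses hypothetical helper names `split_heads` and `merge_heads` (they mirror the `view`/`transpose` lines in the `MultiHeadAttention` class of Step 5). It shows that each head owns a 16-feature slice of the full 64-dim projection, and that merging restores the original layout exactly.

```python
import torch

batch_size, seq_length, embed_dim, num_heads = 1, 6, 64, 4
head_dim = embed_dim // num_heads  # 16

def split_heads(t):
    """[batch, seq, embed_dim] -> [batch, num_heads, seq, head_dim]"""
    return t.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)

def merge_heads(t):
    """[batch, num_heads, seq, head_dim] -> [batch, seq, embed_dim]"""
    return t.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)

Q = torch.randn(batch_size, seq_length, embed_dim)
Q_heads = split_heads(Q)
print(Q_heads.shape)  # torch.Size([1, 4, 6, 16])

# Head h simply owns features h*16 .. h*16+15 of every token's query vector
h = 2
print(torch.equal(Q_heads[:, h], Q[:, :, h * head_dim:(h + 1) * head_dim]))  # True

# Merging undoes the split exactly
print(torch.equal(merge_heads(Q_heads), Q))  # True
```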

Let’s build it step-by-step.

## Step 4: Simulate Input Data

Like Lesson 3, we’ll simulate tokenized embeddings for 6 tokens from our sentence: `法國`, `紅酒`, `慢煮`, `阿根廷`, `牛舌`, `配`. In a real model, these come from a tokenizer and embedding layer.
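A minimal sketch of where `X` would come from, assuming a toy word-level vocabulary. The `vocab` dict and its IDs are made up for illustration, and real tokenizers split text differently. `nn.Embedding` just looks up one learnable 64-dim row per token ID.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

# Hypothetical toy vocabulary; a real tokenizer would produce these IDs.
vocab = {"法國": 0, "紅酒": 1, "慢煮": 2, "阿根廷": 3, "牛舌": 4, "配": 5, "煙肉": 6, "洋蔥": 7}
tokens = ["法國", "紅酒", "慢煮", "阿根廷", "牛舌", "配"]

token_ids = torch.tensor([[vocab[t] for t in tokens]])  # [1, 6]
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=64)

X = embedding(token_ids)  # one learnable 64-dim row per token
print(token_ids.shape, X.shape)  # torch.Size([1, 6]) torch.Size([1, 6, 64])
```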

In [2]:
# Define sizes
batch_size = 1    # One sentence
seq_length = 6    # 6 tokens: 法國, 紅酒, 慢煮, 阿根廷, 牛舌, 配
embed_dim = 64    # Embedding size per token

# Simulate embeddings
X = torch.randn(batch_size, seq_length, embed_dim)
print("Input shape:", X.shape)  # [1, 6, 64]
print("First token’s embedding (partial):", X[0, 0, :5])

Input shape: torch.Size([1, 6, 64])
First token’s embedding (partial): tensor([ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784])


## Step 5: Build Multi-Head Attention

Let’s create a `MultiHeadAttention` class. We’ll use 4 heads, each with a smaller dimension (e.g., 16 instead of 64), so the total size matches the input after combining.

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # Each head’s size

        # Ensure embed_dim is divisible by num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Linear layers for Q, K, V across all heads
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)  # Combine heads

        self.scale = self.head_dim ** -0.5  # Scaling factor

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()

        # Step 1: Compute Q, K, V
        Q = self.query(x)  # [1, 6, 64]
        K = self.key(x)    # [1, 6, 64]
        V = self.value(x)  # [1, 6, 64]

        # Step 2: Split into multiple heads
        Q = Q.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # [1, 4, 6, 16]
        K = K.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # [1, 4, 6, 16]
        V = V.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)  # [1, 4, 6, 16]

        # Step 3: Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # [1, 4, 6, 6]

        # Step 4: Apply softmax for weights
        attn_weights = torch.softmax(scores, dim=-1)  # [1, 4, 6, 6]

        # Step 5: Combine with values
        out = torch.matmul(attn_weights, V)  # [1, 4, 6, 16]

        # Step 6: Reshape and combine heads
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_length, embed_dim)  # [1, 6, 64]
        out = self.out(out)  # Final linear layer: [1, 6, 64]

        return out, attn_weights

# Test it
num_heads = 4
mha = MultiHeadAttention(embed_dim, num_heads)
output, attn_weights = mha(X)
print("Output shape:", output.shape)  # [1, 6, 64]
print("Attention weights shape:", attn_weights.shape)  # [1, 4, 6, 6]

### Breaking Down the Code
- **Sizes**: `embed_dim = 64`, `num_heads = 4`, so each head’s `head_dim = 16` (64 ÷ 4).
- **Q, K, V**: Initially computed for all heads together (`[1, 6, 64]`), then split into 4 heads (`[1, 4, 6, 16]`).
- **Scores**: Each head computes its own `[6, 6]` score matrix, so we get `[1, 4, 6, 6]`.
- **Output**: After attention, heads are combined back to `[1, 6, 64]`.
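To check that the four heads really are independent, this sketch computes attention the batched way used in `forward` and again with an explicit loop over 16-feature slices. It uses random stand-ins for the projected Q, K, V (no learned weights). The two results match, so the `view`/`transpose` trick is just a faster way to run four single-head attentions side by side.

```python
import torch

torch.manual_seed(0)
batch_size, seq_length, embed_dim, num_heads = 1, 6, 64, 4
head_dim = embed_dim // num_heads
scale = head_dim ** -0.5

# Stand-ins for the projected Q, K, V (what self.query/key/value would return)
Q, K, V = (torch.randn(batch_size, seq_length, embed_dim) for _ in range(3))

# Batched version, as in MultiHeadAttention.forward
def split(t):
    return t.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)

w = torch.softmax(split(Q) @ split(K).transpose(-2, -1) * scale, dim=-1)  # [1, 4, 6, 6]
batched = (w @ split(V)).transpose(1, 2).reshape(batch_size, seq_length, embed_dim)

# Loop version: single-head attention on each 16-feature slice, then concatenate
outputs = []
for h in range(num_heads):
    sl = slice(h * head_dim, (h + 1) * head_dim)
    q, k, v = Q[..., sl], K[..., sl], V[..., sl]
    weights_h = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)  # [1, 6, 6]
    outputs.append(weights_h @ v)                                        # [1, 6, 16]
looped = torch.cat(outputs, dim=-1)                                      # [1, 6, 64]

print(torch.allclose(batched, looped, atol=1e-5))  # True
```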

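Once trained, different heads can learn different focuses, e.g., Head 1 on cooking (`慢煮`) and Head 2 on origin (`阿根廷`). Our weights are still random, so any patterns right now come from initialization, not from language.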

## Step 6: Build a Transformer Block

### What’s a Transformer Block?
A transformer block is a complete unit in a transformer model. It has two main parts:
1. **Multi-Head Attention**: Captures relationships between tokens (what we just built).
2. **Feed-Forward Network (FFN)**: A small two-layer network applied to each token separately (with the same weights at every position) to further transform its representation.

It also includes:
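- **Layer Normalization**: Rescales each token’s vector to zero mean and unit variance, then applies a learned scale and shift, keeping the numbers in a stable range (like adjusting the volume so everything stays clear).
- **Residual Connections**: Add each sublayer’s input back to its output, so the original information is preserved and gradients flow easily through many stacked blocks.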
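Both helpers boil down to a few lines of arithmetic. This sketch uses random stand-in tensors: the residual is a plain addition, and a freshly initialized `nn.LayerNorm` (learned scale 1, shift 0, default `eps=1e-5`) matches a hand-written per-token normalization over the 64 features.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 6, 64)             # block input: 6 tokens, 64 features each
sublayer_out = torch.randn(1, 6, 64)  # stand-in for the attention output

# Residual connection: add the sublayer's input back to its output
h = x + sublayer_out

# Layer normalization by hand: per token, over its 64 features
mean = h.mean(dim=-1, keepdim=True)
var = h.var(dim=-1, keepdim=True, unbiased=False)  # LayerNorm uses the biased variance
manual = (h - mean) / torch.sqrt(var + 1e-5)        # 1e-5 is nn.LayerNorm's default eps

norm = nn.LayerNorm(64)  # learned scale starts at 1, shift at 0
print(torch.allclose(norm(h), manual, atol=1e-5))  # True
```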

### Why These Parts?
- **Attention**: “Hey, `牛舌`, look at `慢煮`!”
- **FFN**: “Let’s think more about what `牛舌` means now.”
- **Norm & Residual**: Keeps everything balanced and prevents forgetting the original input.
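The FFN line says the network thinks about `牛舌` on its own. A quick sketch with a freshly initialized FFN of the same sizes as the block below: running it on the whole sequence equals running it one token at a time, and changing another token leaves `牛舌`'s output untouched. Mixing information across tokens is attention's job, not the FFN's.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ffn = nn.Sequential(nn.Linear(64, 256), nn.ReLU(), nn.Linear(256, 64))
x = torch.randn(1, 6, 64)  # 6 tokens

all_at_once = ffn(x)  # nn.Linear acts on the last dimension only
one_by_one = torch.stack([ffn(x[:, t]) for t in range(6)], dim=1)  # [1, 6, 64]
print(torch.allclose(all_at_once, one_by_one, atol=1e-5))  # True

# Perturbing 法國 (position 0) does not change 牛舌's (position 4) FFN output
x2 = x.clone()
x2[0, 0] += 10.0
print(torch.allclose(ffn(x2)[0, 4], all_at_once[0, 4], atol=1e-6))  # True
```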

Let’s code it!

In [4]:
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()
        # Multi-head attention
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim)
        )
        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        # Dropout to prevent overfitting
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Multi-head attention with residual connection
        attn_output, attn_weights = self.attention(x)
        x = self.norm1(x + self.dropout(attn_output))  # Add & norm

        # Feed-forward with residual connection
        ff_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ff_output))  # Add & norm

        return x, attn_weights

# Test it
ff_hidden_dim = 256  # Hidden layer size in FFN
transformer_block = TransformerBlock(embed_dim, num_heads, ff_hidden_dim)
output, attn_weights = transformer_block(X)
print("Transformer output shape:", output.shape)  # [1, 6, 64]
print("Attention weights shape:", attn_weights.shape)  # [1, 4, 6, 6]

### Breaking Down the Transformer Block
- **Attention**: Computes relationships, outputs `[1, 6, 64]` and weights `[1, 4, 6, 6]`.
- **Residual**: `x + attn_output` keeps the original input’s info.
- **Norm1**: Normalizes after attention.
- **FFN**: Expands to 256 dims, applies ReLU, then shrinks back to 64.
- **Norm2**: Normalizes after FFN.
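- **Dropout**: During training, randomly zeroes 10% of the values (and scales up the rest) to avoid over-reliance on any single feature. It is switched off in evaluation mode (`transformer_block.eval()`).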

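This block refines `牛舌` by first connecting it to other tokens, then processing it further on its own. The "add, then normalize" order used here is the post-norm layout of the original Transformer. Llama 2 instead normalizes *before* each sublayer (pre-norm) with RMSNorm and uses a SwiGLU feed-forward network, but the core recipe of attention plus FFN, each wrapped in a residual connection, is the same.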
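Because PyTorch modules start in training mode, the test cell in Step 6 runs with dropout active, so its output changes each time the cell runs. A small sketch of `nn.Dropout` on its own shows both modes; calling `eval()` on the whole block switches every submodule, dropout included, to the deterministic behavior.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.1)
x = torch.ones(1, 6, 64)

dropout.train()  # modules start in training mode
y_train = dropout(x)
zeroed = (y_train == 0).float().mean().item()
print(f"fraction zeroed: {zeroed:.2f}")  # close to 0.1, varies with the seed
print(y_train[y_train != 0].unique())    # survivors are scaled by 1 / (1 - 0.1)

dropout.eval()   # what transformer_block.eval() does for every submodule
y_eval = dropout(x)
print(torch.equal(y_eval, x))  # True: dropout is a no-op in eval mode
```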

## Step 7: Visualize Multi-Head Attention Weights

Let’s plot the attention weights from all 4 heads to see their different focuses. We’ll use a heatmap for each head.

In [5]:
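# Set fonts for Chinese characters (matplotlib tries them in order)
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft JhengHei', 'Noto Sans CJK TC', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False  # Render minus signs correctly with CJK fonts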

# Tokens for labeling
tokens = ['法國', '紅酒', '慢煮', '阿根廷', '牛舌', '配']

# Convert weights to numpy
attn_weights_np = attn_weights[0].detach().numpy()  # [4, 6, 6]

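# Plot heatmaps for all heads (2 per row, so any num_heads works)
n_heads = attn_weights_np.shape[0]
n_cols = 2
n_rows = (n_heads + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), squeeze=False)
axes = axes.flatten()

for i in range(n_heads):
    sns.heatmap(attn_weights_np[i], annot=True, cmap='Blues', fmt='.2f',
                xticklabels=tokens, yticklabels=tokens, ax=axes[i])
    axes[i].set_title(f'Head {i+1} 注意力權重 (Attention Weights)')
    axes[i].set_xlabel('被關注的詞 (Keys)')
    axes[i].set_ylabel('關注的詞 (Queries)')

# Hide any unused panels
for j in range(n_heads, len(axes)):
    axes[j].axis('off')

plt.tight_layout()
plt.show()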

print("Note: if the Chinese labels do not render, install a CJK font such as SimHei (注意: 如果中文未顯示，請確保系統有 SimHei 字體。)")

### What Do We See?
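Our block has not been trained, so these heatmaps show patterns produced by random initial weights. After training, heads can specialize, for example:
- **Head 1**: cooking (`牛舌` → `慢煮`, `紅酒`).
- **Head 2**: origin (`牛舌` → `阿根廷`).
- **Head 3**: pairing (`牛舌` → `配`).
- **Head 4**: other patterns (e.g., `法國` → `紅酒`).

Each head’s 6×6 heatmap shows how strongly every token (row, query) attends to every token (column, key); each row sums to 1. High values (dark blue) mean strong focus. With 6 tokens, perfectly uniform attention would put about 0.17 in every cell. This diversity across heads is what makes multi-head attention powerful!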
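To answer questions like "which head looks at `慢煮`?" without eyeballing the plots, a small helper can pick each head's strongest key for a given query token. `top_key_per_head` is a hypothetical name, and the random tensor below stands in for the block's `attn_weights`; in the notebook you would pass the real one.

```python
import torch

def top_key_per_head(attn_weights, tokens, query_idx):
    """For one query token, return (key token, weight) with the highest weight per head.

    attn_weights: [batch, num_heads, seq_len, seq_len]; batch index 0 is used.
    """
    rows = attn_weights[0, :, query_idx, :]  # [num_heads, seq_len]
    best = rows.argmax(dim=-1)               # strongest key index per head
    return [(tokens[j], rows[h, j].item()) for h, j in enumerate(best.tolist())]

tokens = ['法國', '紅酒', '慢煮', '阿根廷', '牛舌', '配']
torch.manual_seed(0)
# Random stand-in for the block's attention weights; softmax makes each row sum to 1
attn_weights = torch.softmax(torch.randn(1, 4, 6, 6), dim=-1)

for h, (tok, w) in enumerate(top_key_per_head(attn_weights, tokens, query_idx=4)):
    print(f"Head {h + 1}: 牛舌 → {tok} ({w:.2f})")
print("uniform baseline:", round(1 / len(tokens), 2))
```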

## Step 8: Exercises for Practice

Try these to solidify your understanding:
1. **Change Number of Heads**: Set `num_heads` to 2 or 8. Re-run and check the output shape and heatmaps. How does the number of heads affect patterns?
2. **Adjust FFN Size**: Change `ff_hidden_dim` to 128 or 512. Does the output shape change? Why or why not?
3. **Visualize One Token**: Print `attn_weights[0, :, 4, :]` (weights for `牛舌`) for all heads. Which head focuses most on `慢煮`?
4. **Add Tokens**: Increase `seq_length` to 8 and add `煙肉` and `洋蔥` to the `tokens` list so the heatmap labels still match. Re-run from Step 4 and observe the heatmaps.

Write your code below and experiment!

In [None]:
# Your exercise code here



## Summary

You’ve conquered a lot today! Here’s what we did:
- **Multi-Head Attention**: Built a system where multiple heads (e.g., 4) can learn different relationships in our sentence, such as cooking, origin, and pairing for `牛舌`.
- **Transformer Block**: Added a feed-forward network, normalization, and residuals to refine token meanings.
- **Visualization**: Plotted attention weights to see each head’s unique focus.

Next up in Lesson 5, we’ll stack multiple transformer blocks to build a mini transformer model and generate text. You’re getting closer to understanding LLMs like the pros!