# Multi-Head Attention: Seeing the World from Multiple Perspectives

In the previous notebook, we built **Causal Attention** — a mechanism where each token can only attend to previous tokens. But there's a limitation: we only have **one set of Q, K, V projections**, which means we can only learn **one type of attention pattern**.

## The Problem with Single-Head Attention

Consider the sentence: `"The cat sat on the mat because it was tired"`

When processing the word "it", the model needs to understand:
- **Grammatical reference**: "it" refers to "cat" (not "mat")
- **Semantic relationship**: "tired" is a property of living things
- **Positional context**: "it" appears after describing an action

**A single attention head produces only one attention distribution per token, so it has to blend all of these aspects into a single pattern!**

---

## The Solution: Multiple Heads

What if we ran **multiple attention mechanisms in parallel**, each with its own Q, K, V weights?

```
                        Input: "The cat sat on the mat because it was tired"
                                              │
                 ┌────────────────────────────┼────────────────────────────┐
                 │                            │                            │
                 ▼                            ▼                            ▼
           ┌──────────┐                ┌──────────┐                ┌──────────┐
           │  Head 1  │                │  Head 2  │                │  Head 3  │
           │          │                │          │                │          │
           │ Q₁,K₁,V₁ │                │ Q₂,K₂,V₂ │                │ Q₃,K₃,V₃ │
           │          │                │          │                │          │
           │ Focuses  │                │ Focuses  │                │ Focuses  │
           │    on    │                │    on    │                │    on    │
           │ grammar  │                │semantics │                │ position │
           └────┬─────┘                └────┬─────┘                └────┬─────┘
                │                           │                           │
                └───────────────────────────┼───────────────────────────┘
                                            │
                                            ▼
                                    ┌───────────────┐
                                    │  Concatenate  │
                                    │   & Project   │
                                    └───────────────┘
                                            │
                                            ▼
                                      Final Output
```

**Each head learns to attend to different things!**

---

## What We'll Learn

1. **Simple Multi-Head Implementation** — Stack multiple CausalAttention modules
2. **Why this is inefficient** — Separate matrix multiplications
3. **Efficient Multi-Head Implementation** — Single large projection + reshape
4. **The reshape trick** — How `view()` and `transpose()` enable parallelism
5. **Output projection** — Combining heads back together

---

## Key Intuition: Divide and Conquer

Instead of one large attention with `d_out = 256`:
```
Single Head: d_out = 256  →  One 256-dim attention pattern
```

We split into multiple smaller heads:
```
8 Heads: head_dim = 256/8 = 32  →  Eight 32-dim attention patterns
                                    (then concatenated back to 256)
```

**Same number of Q, K, V projection parameters, but more diverse attention patterns!**
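
The parameter claim can be checked with a little arithmetic. The sketch below counts only the Q, K, V projection weights (no biases, no output projection). It assumes `d_in = 256`, which the text above does not specify, and `qkv_weight_count` is a hypothetical helper written for this illustration.

```python
def qkv_weight_count(d_in: int, head_dim: int, num_heads: int) -> int:
    """Count Q, K, V projection weights (no biases) across all heads."""
    per_head = 3 * d_in * head_dim  # one W_q, W_k, W_v per head
    return num_heads * per_head


d_in = 256  # assumed input width, not given in the text above

single_head = qkv_weight_count(d_in, head_dim=256, num_heads=1)
eight_heads = qkv_weight_count(d_in, head_dim=32, num_heads=8)

print(single_head, eight_heads)
print(single_head == eight_heads)
```

Both counts come out equal because `num_heads × head_dim` is 256 in either case.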

---

## Install Dependencies

In [1]:
!pip install torch tiktoken transformers



## Step 1: Our Familiar Input Embeddings

Same sentence as before: **"Your journey starts with one step"**

Each word is a 3-dimensional embedding vector. We'll use these to demonstrate multi-head attention.

```
Word        Embedding              
────        ─────────              
"Your"      [0.43, 0.15, 0.89]    
"journey"   [0.55, 0.87, 0.66]    
"starts"    [0.57, 0.85, 0.64]    
"with"      [0.22, 0.58, 0.33]    
"one"       [0.77, 0.25, 0.10]    
"step"      [0.05, 0.80, 0.55]    
```

In [2]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2
print(x_2)
print(d_in)

tensor([0.5500, 0.8700, 0.6600])
3


## Step 2: Recap — CausalAttention from Previous Notebook

Here's the `CausalAttention` class we built. This will be our **building block** for multi-head attention.

**Quick reminder of what it does:**
1. Projects input → Queries, Keys, Values
2. Computes attention scores (Q × Kᵀ)
3. Applies causal mask (hide future tokens)
4. Scale by √d_out, then softmax → attention weights
5. Weighted sum of Values → context vectors

```
Input [batch, seq, d_in]  →  CausalAttention  →  Output [batch, seq, d_out]
```

Each `CausalAttention` instance has its **own learnable W_query, W_key, W_value** matrices.
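
As a quick refresher on steps 3 and 4, here is a minimal sketch of the causal mask on its own. It uses random stand-in scores for a 4-token sequence and skips the scaling step for brevity; the variable names are illustrative, not part of the class.

```python
import torch

torch.manual_seed(0)
num_tokens = 4
# Stand-in attention scores for one sequence (random, for illustration only)
attn_scores = torch.randn(num_tokens, num_tokens)

# Ones above the diagonal mark "future" positions
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()

# Future positions get -inf, so softmax assigns them zero weight
masked_scores = attn_scores.masked_fill(mask, float("-inf"))
attn_weights = torch.softmax(masked_scores, dim=-1)

print(attn_weights)
```

Each row still sums to 1, but all entries above the diagonal are zero, so token `i` only draws on tokens `0..i`.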

In [3]:
import torch.nn as nn
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
           'mask',
           torch.triu(torch.ones(context_length, context_length),
           diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)   
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) 
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

## Step 3: Simple Multi-Head Attention (The Wrapper Approach)

The most intuitive way to implement multi-head attention: **just create multiple CausalAttention modules and concatenate their outputs!**

### How It Works

```
                              Input x
                                 │
           ┌─────────────────────┼─────────────────────┐
           │                     │                     │
           ▼                     ▼                     ▼
    ┌─────────────┐       ┌─────────────┐       ┌─────────────┐
    │   Head 0    │       │   Head 1    │       │   Head 2    │
    │             │       │             │       │             │
    │ CausalAttn  │       │ CausalAttn  │       │ CausalAttn  │
    │ (d_in→d_out)│       │ (d_in→d_out)│       │ (d_in→d_out)│
    └──────┬──────┘       └──────┬──────┘       └──────┬──────┘
           │                     │                     │
           │    [batch,seq,2]    │    [batch,seq,2]    │    [batch,seq,2]
           │                     │                     │
           └─────────────────────┼─────────────────────┘
                                 │
                                 ▼
                        torch.cat(dim=-1)
                                 │
                                 ▼
                      [batch, seq, 2×num_heads]
                      = [batch, seq, 6] if 3 heads
```

### The Code Explained

```python
self.heads = nn.ModuleList([
    CausalAttention(...) for _ in range(num_heads)
])
```
- `nn.ModuleList`: A list that PyTorch recognizes as containing submodules
- Each head has its own separate Q, K, V weight matrices

```python
return torch.cat([head(x) for head in self.heads], dim=-1)
```
- Run input through each head independently
- Concatenate outputs along the last dimension (features)
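
To see why `nn.ModuleList` matters, the sketch below compares it with a plain Python list. `PlainListHeads` and `ModuleListHeads` are hypothetical classes made up for this comparison, and `nn.Linear` stands in for a full attention head.

```python
import torch.nn as nn


class PlainListHeads(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list: PyTorch does not register these layers
        self.heads = [nn.Linear(3, 2, bias=False) for _ in range(2)]


class ModuleListHeads(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule
        self.heads = nn.ModuleList(
            [nn.Linear(3, 2, bias=False) for _ in range(2)]
        )


plain_count = sum(p.numel() for p in PlainListHeads().parameters())
registered_count = sum(p.numel() for p in ModuleListHeads().parameters())

print(plain_count, registered_count)
```

With the plain list, `parameters()` finds nothing, so an optimizer would never update those heads. With `nn.ModuleList`, all 2 × 3 × 2 = 12 weights are visible.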

In [4]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(
                 d_in, d_out, context_length, dropout, qkv_bias
             ) 
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

### Testing the Wrapper

Let's create a batch of inputs (2 identical sequences) to test our multi-head attention wrapper.

In [5]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


Now let's create a multi-head attention with **2 heads**, each outputting 2 dimensions:

```
Input shape:  [2, 6, 3]   →   2 batches, 6 tokens, 3-dim embeddings
                ↓
         2 attention heads
         (each: d_in=3 → d_out=2)
                ↓
Output shape: [2, 6, 4]   →   2 batches, 6 tokens, 2×2=4 dim (concatenated)
```

In [6]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


### Understanding the Output

```
Output shape: [2, 6, 4]
              ↑  ↑  ↑
           batch seq d_out×num_heads
```

- **First 2 columns** `[-0.45, 0.22, ...]`: Output from Head 0
- **Last 2 columns** `[0.48, 0.11, ...]`: Output from Head 1

Both batch items have identical outputs because they have identical inputs (we stacked the same `inputs` twice).
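
The column layout can be verified with a small sketch. It uses random stand-in tensors rather than real head outputs, so only the shapes and the slicing carry over to the wrapper.

```python
import torch

torch.manual_seed(0)
# Stand-ins for two head outputs, each [batch=2, seq=6, d_out=2]
head0_out = torch.randn(2, 6, 2)
head1_out = torch.randn(2, 6, 2)

combined = torch.cat([head0_out, head1_out], dim=-1)  # [2, 6, 4]

# The first d_out columns come from head 0, the rest from head 1
recovered_head0 = combined[..., :2]
recovered_head1 = combined[..., 2:]

print(combined.shape)
print(torch.equal(recovered_head0, head0_out))
print(torch.equal(recovered_head1, head1_out))
```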

### The Problem: This Approach is Inefficient!

While conceptually clear, this wrapper has a major drawback:

```
Wrapper Approach:                    Efficient Approach:
─────────────────                    ───────────────────
Head 0: x @ W_q0, x @ W_k0, x @ W_v0    Single large matrix multiply:
Head 1: x @ W_q1, x @ W_k1, x @ W_v1    x @ W_q (all heads at once)
Head 2: x @ W_q2, x @ W_k2, x @ W_v2    x @ W_k (all heads at once)
...                                      x @ W_v (all heads at once)

= 3 × num_heads matrix multiplies      = 3 matrix multiplies total!
```

**GPUs love large matrix operations!** One big multiply is usually faster than many small ones because it launches fewer kernels and keeps the hardware better utilized.
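
The two approaches compute the same numbers, which a short sketch can confirm. It uses raw weight tensors (`W_q0`, `W_q1`) instead of `nn.Linear` layers and covers only the query projection; both are simplifications made for this illustration.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 6, 3)  # [batch, seq, d_in]

# Two per-head query weights, each [d_in, head_dim]
W_q0 = torch.randn(3, 2)
W_q1 = torch.randn(3, 2)

# Wrapper style: one small matmul per head, then concatenate
per_head = torch.cat([x @ W_q0, x @ W_q1], dim=-1)

# Efficient style: place the weights side by side, multiply once
W_q = torch.cat([W_q0, W_q1], dim=1)  # [d_in, 2 * head_dim]
all_heads = x @ W_q

print(per_head.shape, all_heads.shape)
print(torch.allclose(per_head, all_heads, atol=1e-5))
```

Placing the per-head weight matrices side by side as columns of one larger matrix gives the same result in a single multiply.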

---

## Step 4: Efficient Multi-Head Attention (The Real Implementation)

The key insight: instead of **separate weight matrices per head**, use **one large weight matrix** and then **reshape** the output to split it into heads.

### The Reshape Trick

```
Step 1: Single large projection
────────────────────────────────
Input:  [batch, seq, d_in]       e.g., [2, 6, 3]
W_q:    [d_in, d_out]            e.g., [3, 4]  (d_out = num_heads × head_dim = 2×2)
                ↓
Q:      [batch, seq, d_out]      e.g., [2, 6, 4]

Step 2: Reshape to separate heads
─────────────────────────────────
Q:      [batch, seq, d_out]           [2, 6, 4]
            ↓ view()
Q:      [batch, seq, num_heads, head_dim]   [2, 6, 2, 2]
            ↓ transpose(1,2)
Q:      [batch, num_heads, seq, head_dim]   [2, 2, 6, 2]

Now each head has its own slice of the embedding!
```
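
The reshape can be traced on a tiny tensor. This sketch fills `Q` with consecutive integers (an illustrative choice, not real query values) so it is easy to see which entries land in which head.

```python
import torch

b, seq, num_heads, head_dim = 2, 6, 2, 2
d_out = num_heads * head_dim

# Integer values make it easy to see where each entry ends up
Q = torch.arange(b * seq * d_out).view(b, seq, d_out)

Q_heads = Q.view(b, seq, num_heads, head_dim).transpose(1, 2)
print(Q_heads.shape)

# Head 0 holds the first head_dim columns of Q, head 1 the next head_dim
print(torch.equal(Q_heads[:, 0], Q[:, :, :head_dim]))
print(torch.equal(Q_heads[:, 1], Q[:, :, head_dim:]))
```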

### Why `transpose(1, 2)`?

We want the `num_heads` dimension second so we can do batched matrix multiplication:

```
Before transpose: [batch, seq, heads, head_dim]
                         └─────┘
                      These need to interact for attention

After transpose:  [batch, heads, seq, head_dim]
                         │      └─────────────┘
                         │      These compute attention
                         │
                   This acts like an extra batch dimension!
```

PyTorch's `@` operator handles this directly: it multiplies the last two dimensions as matrices and treats all leading dimensions (here `batch` and `heads`) as batch dimensions.
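
A minimal sketch of this behavior, using random stand-in tensors: the single batched call should match an explicit loop over every batch item and head.

```python
import torch

torch.manual_seed(0)
b, h, s, d = 2, 2, 6, 2
queries = torch.randn(b, h, s, d)
keys = torch.randn(b, h, s, d)

# All (batch, head) pairs are multiplied in one call
attn_scores = queries @ keys.transpose(-2, -1)  # [b, h, s, s]
print(attn_scores.shape)

# Same result as looping over batch items and heads by hand
for i in range(b):
    for j in range(h):
        expected = queries[i, j] @ keys[i, j].T
        assert torch.allclose(attn_scores[i, j], expected, atol=1e-5)
```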

### The Full Process

```
┌─────────────────────────────────────────────────────────────────────────┐
│                    EFFICIENT MULTI-HEAD ATTENTION                       │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│  Input: [batch, seq, d_in]                                              │
│            │                                                             │
│            ├──────────────────┬──────────────────┐                      │
│            ▼                  ▼                  ▼                      │
│     ┌───────────┐      ┌───────────┐      ┌───────────┐                │
│     │  W_query  │      │  W_key    │      │  W_value  │                │
│     │[d_in,d_out]│     │[d_in,d_out]│     │[d_in,d_out]│               │
│     └─────┬─────┘      └─────┬─────┘      └─────┬─────┘                │
│           │                  │                  │                       │
│           ▼                  ▼                  ▼                       │
│    [batch,seq,d_out]  [batch,seq,d_out]  [batch,seq,d_out]             │
│           │                  │                  │                       │
│           ▼                  ▼                  ▼                       │
│    ┌──────────────────────────────────────────────────┐                │
│    │              view + transpose                     │                │
│    │  [batch, seq, d_out] → [batch, heads, seq, head_dim] │            │
│    └──────────────────────────────────────────────────┘                │
│           │                  │                  │                       │
│           ▼                  ▼                  ▼                       │
│        Q: [b,h,s,d]      K: [b,h,s,d]      V: [b,h,s,d]                │
│           │                  │                  │                       │
│           └────────┬─────────┘                  │                       │
│                    ▼                            │                       │
│        Attention: Q @ K.transpose(-2,-1)        │                       │
│           = [b, h, s, s]                        │                       │
│                    │                            │                       │
│                    ▼                            │                       │
│           Mask + Softmax + Dropout              │                       │
│                    │                            │                       │
│                    └────────────┬───────────────┘                       │
│                                 ▼                                       │
│                    Weights @ V = [b, h, s, d]                          │
│                                 │                                       │
│                                 ▼                                       │
│                    transpose + contiguous + view                        │
│                    [b, h, s, d] → [b, s, h×d] = [b, s, d_out]          │
│                                 │                                       │
│                                 ▼                                       │
│                         Output Projection                               │
│                         [d_out, d_out]                                  │
│                                 │                                       │
│                                 ▼                                       │
│                    Output: [batch, seq, d_out]                          │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
```

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, 
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)  
        queries = queries.view(                                             
            b, num_tokens, self.num_heads, self.head_dim                    
        )                                                                   

        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec)
        return context_vec

### Code Walkthrough: Line by Line

```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        
        # IMPORTANT: d_out must be divisible by num_heads!
        # Each head gets d_out/num_heads dimensions
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Each head's dimension
        
        # Single large projections (instead of separate ones per head)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        # NEW: Output projection to mix information across heads
        self.out_proj = nn.Linear(d_out, d_out)
        
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(...))  # Same causal mask

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        
        # Step 1: Project to Q, K, V (single large matrix multiply each)
        keys = self.W_key(x)       # [b, seq, d_out]
        queries = self.W_query(x)  # [b, seq, d_out]
        values = self.W_value(x)   # [b, seq, d_out]

        # Step 2: Reshape to [batch, seq, num_heads, head_dim]
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Step 3: Transpose to [batch, num_heads, seq, head_dim]
        # This puts heads as a "batch" dimension for efficient attention
        keys = keys.transpose(1, 2)     # [b, h, s, d]
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Step 4: Compute attention scores
        # [b, h, s, d] @ [b, h, d, s] = [b, h, s, s]
        attn_scores = queries @ keys.transpose(2, 3)
        
        # Step 5: Apply causal mask (trimmed to the current sequence length)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # Step 6: Softmax (with scaling) + dropout
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Step 7: Weighted sum of values
        # [b, h, s, s] @ [b, h, s, d] = [b, h, s, d]
        context_vec = (attn_weights @ values)
        
        # Step 8: Transpose back to [batch, seq, heads, head_dim]
        context_vec = context_vec.transpose(1, 2)

        # Step 9: Reshape to [batch, seq, d_out] by merging heads
        # .contiguous() copies the transposed tensor into contiguous memory,
        # which view() requires
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        
        # Step 10: Final output projection
        context_vec = self.out_proj(context_vec)
        return context_vec
```
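
Step 9 is the one place where `.contiguous()` is load-bearing. The sketch below uses a random stand-in for `context_vec` to show that `view()` fails on the transposed tensor without it. The `reshape()` alternative is not used in the class above and is shown only for comparison.

```python
import torch

b, h, s, d = 2, 2, 6, 2
context_vec = torch.randn(b, h, s, d).transpose(1, 2)  # [b, s, h, d]

# transpose() only changes strides, it does not move data
print(context_vec.is_contiguous())

try:
    context_vec.view(b, s, h * d)
    view_failed = False
except RuntimeError:
    # view() cannot merge the last two dims with this memory layout
    view_failed = True

merged = context_vec.contiguous().view(b, s, h * d)  # copy, then view
also_merged = context_vec.reshape(b, s, h * d)       # copies when needed

print(view_failed, merged.shape, torch.equal(merged, also_merged))
```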

### Key Differences from the Wrapper

| Wrapper Approach | Efficient Approach |
|-----------------|-------------------|
| Multiple small W matrices | One large W matrix |
| Loop through heads | Single batched operation |
| Simple concatenation | Reshape + transpose magic |
| No output projection | Adds `out_proj` layer |

### Why the Output Projection?

```python
self.out_proj = nn.Linear(d_out, d_out)
```

After concatenating head outputs, the `out_proj` layer:
1. **Mixes information** across heads (each head worked independently)
2. **Adds more learnable parameters** for richer representations
3. **Matches the original Transformer paper** specification

Now let's test the efficient implementation:

In [8]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


### Output Comparison

Notice the output shape difference:

| Implementation | Output Shape | Note |
|---------------|--------------|------|
| Wrapper (2 heads, d_out=2 each) | `[2, 6, 4]` | Concatenated: 2×2=4 |
| Efficient (2 heads, d_out=2 total) | `[2, 6, 2]` | Split: 2÷2=1 dim per head |

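In the efficient version, `d_out` is the total output width that gets split across heads, whereas in the wrapper it is the width of each individual head. Real models usually set `d_out` equal to `d_in` (unlike this toy example, where `d_in=3` and `d_out=2`), which keeps the embedding dimension consistent throughout the model.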
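
The two meanings of `d_out` can be written out as a pair of tiny helpers. Both functions are hypothetical and exist only to make the shape arithmetic explicit.

```python
def wrapper_output_width(d_out: int, num_heads: int) -> int:
    # Wrapper: d_out is the width of EACH head, and heads are concatenated
    return d_out * num_heads


def efficient_output_width(d_out: int, num_heads: int) -> int:
    # Efficient version: d_out is the TOTAL width, split across heads
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
    return d_out


print(wrapper_output_width(d_out=2, num_heads=2))    # 4
print(efficient_output_width(d_out=2, num_heads=2))  # 2
print(efficient_output_width(d_out=4, num_heads=2))  # 4
```

To reproduce the wrapper's 4-dimensional output with the efficient class, one would pass `d_out=4` rather than `d_out=2`.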

---

## Summary: What We Learned

### Multi-Head Attention in One Picture

```
┌─────────────────────────────────────────────────────────────────────────┐
│                        MULTI-HEAD ATTENTION                             │
├─────────────────────────────────────────────────────────────────────────┤
│                                                                          │
│    "The cat sat on the mat because it was tired"                        │
│                           │                                              │
│           ┌───────────────┼───────────────┐                             │
│           ▼               ▼               ▼                             │
│     ┌──────────┐    ┌──────────┐    ┌──────────┐                       │
│     │  Head 1  │    │  Head 2  │    │  Head 3  │    ...                │
│     │          │    │          │    │          │                       │
│     │ "it"→cat │    │ position │    │  syntax  │                       │
│     │(reference)│   │ patterns │    │ patterns │                       │
│     └─────┬────┘    └─────┬────┘    └─────┬────┘                       │
│           │               │               │                             │
│           └───────────────┼───────────────┘                             │
│                           ▼                                              │
│                   Concatenate & Project                                  │
│                           │                                              │
│                           ▼                                              │
│              Combined multi-aspect representation                        │
│                                                                          │
└─────────────────────────────────────────────────────────────────────────┘
```

### Key Takeaways

| Concept | What We Learned |
|---------|-----------------|
| **Why Multiple Heads** | Different heads learn different attention patterns (syntax, semantics, position, etc.) |
| **Wrapper Approach** | Simple but inefficient — uses separate W matrices per head |
| **Efficient Approach** | Single large projection + reshape = GPU-friendly |
| **The Reshape Trick** | `view()` + `transpose()` splits embeddings into heads |
| **Output Projection** | Mixes information across heads after concatenation |
| **head_dim** | `d_out ÷ num_heads` — each head works with a fraction of dimensions |

### Real-World Numbers (GPT-2 Small)

```
d_model = 768          (embedding dimension)
num_heads = 12         (attention heads)
head_dim = 768/12 = 64 (dimensions per head)

Each head: 64-dim queries, keys, values
All heads: 12 × 64 = 768 (back to original dimension)
```

---

## Exercise 3.3: Initializing GPT-2 Size Attention Modules

### The Task

Using the `MultiHeadAttention` class we built, initialize a multi-head attention module that matches the **smallest GPT-2 model** specifications:

| Parameter | GPT-2 Small Value | Description |
|-----------|------------------|-------------|
| `num_heads` | **12** | Number of attention heads |
| `d_in` | **768** | Input embedding dimension |
| `d_out` | **768** | Output embedding dimension |
| `context_length` | **1024** | Maximum sequence length |

### Why These Numbers?

```
GPT-2 Small Architecture:
─────────────────────────
Total parameters: ~124 million
Embedding dim:    768
Attention heads:  12
head_dim:         768 / 12 = 64  ← Each head works with 64 dimensions

The 768-dimensional embedding is split across 12 heads:
┌──────────────────────────────────────────────────────────────────┐
│                     768-dim embedding                             │
├─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┤
│ 64  │ 64  │ 64  │ 64  │ 64  │ 64  │ 64  │ 64  │ 64  │ 64  │ 64  │ 64  │
│head1│head2│head3│head4│head5│head6│head7│head8│head9│h10 │h11 │h12 │
└─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
```

---

### Step-by-Step Solution

#### Step 1: Define the GPT-2 Small Configuration

First, let's define all the hyperparameters matching GPT-2 Small:

In [9]:
# GPT-2 Small configuration
GPT2_CONFIG = {
    "d_in": 768,            # Input embedding dimension
    "d_out": 768,           # Output embedding dimension (same as input)
    "context_length": 1024, # Maximum context window (tokens)
    "num_heads": 12,        # Number of attention heads
    "dropout": 0.1,         # Dropout rate (typical for GPT-2)
    "qkv_bias": True        # GPT-2 uses bias in Q, K, V projections
}

# Let's verify the head dimension calculation
head_dim = GPT2_CONFIG["d_out"] // GPT2_CONFIG["num_heads"]
print(f"GPT-2 Small Configuration:")
print(f"  Embedding dimension: {GPT2_CONFIG['d_in']}")
print(f"  Number of heads:     {GPT2_CONFIG['num_heads']}")
print(f"  Head dimension:      {head_dim}")
print(f"  Context length:      {GPT2_CONFIG['context_length']} tokens")

GPT-2 Small Configuration:
  Embedding dimension: 768
  Number of heads:     12
  Head dimension:      64
  Context length:      1024 tokens


#### Step 2: Initialize the MultiHeadAttention Module

Now let's create our GPT-2 sized attention module:

In [10]:
# Initialize GPT-2 sized Multi-Head Attention
torch.manual_seed(123)

gpt2_mha = MultiHeadAttention(
    d_in=GPT2_CONFIG["d_in"],
    d_out=GPT2_CONFIG["d_out"],
    context_length=GPT2_CONFIG["context_length"],
    dropout=GPT2_CONFIG["dropout"],
    num_heads=GPT2_CONFIG["num_heads"],
    qkv_bias=GPT2_CONFIG["qkv_bias"]
)

print("GPT-2 Multi-Head Attention Module Created!")
print(gpt2_mha)

GPT-2 Multi-Head Attention Module Created!
MultiHeadAttention(
  (W_query): Linear(in_features=768, out_features=768, bias=True)
  (W_key): Linear(in_features=768, out_features=768, bias=True)
  (W_value): Linear(in_features=768, out_features=768, bias=True)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)


#### Step 3: Count the Parameters

Let's see how many learnable parameters are in just this one attention layer:

In [11]:
# Count parameters in each component
def count_parameters(module):
    return sum(p.numel() for p in module.parameters())

# Break down by component
print("Parameter Count Breakdown:")
print("=" * 50)

# W_query, W_key, W_value each: d_in × d_out + d_out (bias)
qkv_params = count_parameters(gpt2_mha.W_query)
print(f"W_query: {qkv_params:,} parameters")
print(f"  - Weight: {GPT2_CONFIG['d_in']} × {GPT2_CONFIG['d_out']} = {GPT2_CONFIG['d_in'] * GPT2_CONFIG['d_out']:,}")
print(f"  - Bias:   {GPT2_CONFIG['d_out']:,}")

print(f"\nW_key:   {count_parameters(gpt2_mha.W_key):,} parameters")
print(f"W_value: {count_parameters(gpt2_mha.W_value):,} parameters")

# Output projection: d_out × d_out + d_out (bias)
out_proj_params = count_parameters(gpt2_mha.out_proj)
print(f"\nout_proj: {out_proj_params:,} parameters")
print(f"  - Weight: {GPT2_CONFIG['d_out']} × {GPT2_CONFIG['d_out']} = {GPT2_CONFIG['d_out'] * GPT2_CONFIG['d_out']:,}")
print(f"  - Bias:   {GPT2_CONFIG['d_out']:,}")

# Total
total_params = count_parameters(gpt2_mha)
print(f"\n{'=' * 50}")
print(f"TOTAL: {total_params:,} parameters")
print(f"\nThat's {total_params / 1e6:.2f} million parameters in ONE attention layer!")

Parameter Count Breakdown:
==================================================
W_query: 590,592 parameters
  - Weight: 768 × 768 = 589,824
  - Bias:   768

W_key:   590,592 parameters
W_value: 590,592 parameters

out_proj: 590,592 parameters
  - Weight: 768 × 768 = 589,824
  - Bias:   768

==================================================
TOTAL: 2,362,368 parameters

That's 2.36 million parameters in ONE attention layer!
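
The same total can be cross-checked without PyTorch using a closed-form count. `mha_param_count` is a hypothetical helper that mirrors the four `nn.Linear` layers in `MultiHeadAttention`, assuming `out_proj` keeps its default bias.

```python
def mha_param_count(d_in: int, d_out: int, qkv_bias: bool = True) -> int:
    """Parameters in W_query, W_key, W_value, and out_proj."""
    qkv_each = d_in * d_out + (d_out if qkv_bias else 0)
    out_proj = d_out * d_out + d_out  # out_proj always has a bias here
    return 3 * qkv_each + out_proj


per_layer = mha_param_count(d_in=768, d_out=768)
print(f"{per_layer:,}")       # 2,362,368
print(f"{12 * per_layer:,}")  # 28,348,416 across 12 such layers
```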


#### Step 4: Test with a Sample Input

Let's create a sample batch that matches GPT-2's expected input format and run it through our attention module:

In [12]:
# Create a sample input batch
# Shape: [batch_size, sequence_length, embedding_dim]
batch_size = 2
seq_length = 128  # Using a shorter sequence for demo (full GPT-2 supports 1024)
embedding_dim = GPT2_CONFIG["d_in"]

# Random embeddings (in practice, these come from the embedding layer)
sample_input = torch.randn(batch_size, seq_length, embedding_dim)

print(f"Input shape: {sample_input.shape}")
print(f"  - Batch size: {batch_size}")
print(f"  - Sequence length: {seq_length}")
print(f"  - Embedding dimension: {embedding_dim}")

Input shape: torch.Size([2, 128, 768])
  - Batch size: 2
  - Sequence length: 128
  - Embedding dimension: 768


#### Step 5: Run the Forward Pass

Now let's pass our sample through the attention layer:

In [None]:
# Set to evaluation mode (disables dropout)
gpt2_mha.eval()

# Run forward pass
with torch.no_grad():
    output = gpt2_mha(sample_input)

print(f"Output shape: {output.shape}")
print(f"  - Batch size: {output.shape[0]}")
print(f"  - Sequence length: {output.shape[1]}")
print(f"  - Embedding dimension: {output.shape[2]}")

print(f"\n✓ Input and output have the same shape!")
print(f"  Input:  {list(sample_input.shape)}")
print(f"  Output: {list(output.shape)}")

### Exercise Summary

We successfully initialized a GPT-2 sized multi-head attention module!

```
┌─────────────────────────────────────────────────────────────────────┐
│                    GPT-2 SMALL ATTENTION LAYER                      │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  Configuration:                                                      │
│  ─────────────                                                       │
│  • Embedding dimension: 768                                          │
│  • Number of heads: 12                                               │
│  • Head dimension: 64                                                │
│  • Context length: 1024 tokens                                       │
│                                                                      │
│  Parameters:                                                         │
│  ───────────                                                         │
│  • W_query: 768 × 768 + 768 = 590,592                               │
│  • W_key:   768 × 768 + 768 = 590,592                               │
│  • W_value: 768 × 768 + 768 = 590,592                               │
│  • out_proj: 768 × 768 + 768 = 590,592                              │
│  ─────────────────────────────────────                               │
│  • TOTAL: ~2.36 million parameters                                   │
│                                                                      │
│  Note: GPT-2 Small has 12 transformer blocks, each with              │
│  an attention layer like this one. That's 12 × 2.36M ≈ 28M          │
│  parameters just for attention (out of 124M total)!                  │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘
```

**Key Takeaways:**
1. Real transformer models use much larger dimensions (768 vs our demo's 3)
2. The head dimension (64) sets the size of each head's query, key, and value vectors, and therefore the `sqrt(head_dim)` factor used to scale attention scores
3. A single attention layer has millions of learnable parameters
4. Input and output shapes match — this is crucial for residual connections!
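
The last point can be illustrated with a short sketch. It uses `nn.Linear` layers as stand-ins for an attention layer and deliberately small sizes; only the shape behavior is meant to carry over.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 4, 8)  # [batch, seq, emb_dim], small sizes for the demo

same_width = nn.Linear(8, 8)  # stand-in for attention with d_out == d_in
narrower = nn.Linear(8, 6)    # stand-in for attention with d_out != d_in

residual_out = x + same_width(x)  # shapes match, so the add works
print(residual_out.shape)

try:
    x + narrower(x)
    add_failed = False
except RuntimeError:
    # Widths 8 and 6 cannot be broadcast against each other
    add_failed = True
print(add_failed)
```

A residual connection adds the layer's input to its output element by element, so the addition only works when the two widths agree.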