In [None]:
# Module 7a: Self-Attention Mechanism
# This notebook walks through the self-attention mechanism step by step,
# building from simple dot-product attention to full multi-head attention.

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Our input sequence: "Your journey starts with one step"
# One row per token; each row is that token's 3-dimensional embedding x^i.

inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

print(f"Input shape: {inputs.shape}")  # [seq_len, d_model] = [6, 3]
print(f"Sequence length: {inputs.size(0)}, Embedding dim: {inputs.size(1)}")

In [4]:
query = inputs[1]  # 2nd input token is the query
print('Query - ', query)

# Score each token against the query with a plain dot product
# (1-D vectors, so no transpose is involved).
attn_scores_2 = torch.stack([torch.dot(x_i, query) for x_i in inputs])

print(attn_scores_2)

Query -  tensor([0.5500, 0.8700, 0.6600])
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [6]:
torch.tensor([0.43, 0.15, 0.89]) @ torch.tensor([0.55, 0.87, 0.66])  # x^1 · x^2 by hand (1-D @ 1-D is a dot product)

tensor(0.9544)

### Step 2: Normalizing the attention scores (softmax)

In [7]:
# Softmax maps the raw scores to positive weights that sum to 1;
# equivalent to torch.exp(x) / torch.exp(x).sum(dim=0), but numerically safer.
attn_weights_2 = attn_scores_2.softmax(dim=0)

print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())


Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [None]:
# Step 3: Context Vector
# Blend ALL input vectors, each scaled by its attention weight from the
# previous step: z^2 = sum_i a_i * x^i. The result is a new representation
# of token 2 ("journey") that carries information from every token.

context_vec_2 = torch.zeros(inputs.shape[1])
for a_i, x_i in zip(attn_weights_2, inputs):
    context_vec_2 += a_i * x_i

print("Context vector for token 2 ('journey'):")
print(context_vec_2)
print(f"\nOriginal token 2: {inputs[1]}")
print(f"Context vector:   {context_vec_2}")
print("\nThe context vector is a blend of all tokens, weighted by attention!")

In [None]:
# Let's compute ALL context vectors at once (for all tokens, not just token 2).
# Matrix form of the three steps: scores = X X^T, weights = row-softmax(scores),
# context = weights X.

# Step 1: every pairwise dot product in a single matmul -> [6, 6]
attn_scores = torch.matmul(inputs, inputs.T)
print("All attention scores (6x6 matrix):")
print(attn_scores)

# Step 2: softmax along the last dimension, so each row sums to 1
attn_weights = attn_scores.softmax(dim=-1)
print("\nAttention weights (after softmax):")
print(attn_weights)
print(f"\nRow sums: {attn_weights.sum(dim=-1)}")  # all 1.0

# Step 3: each row of weights mixes the input rows: [6, 6] @ [6, 3] = [6, 3]
context_vecs = torch.matmul(attn_weights, inputs)
print(f"\nAll context vectors shape: {context_vecs.shape}")
print(context_vecs)
print(f"\nVerify: context_vecs[1] matches our earlier result: {torch.allclose(context_vecs[1], context_vec_2)}")

## Visualizing Attention Weights

Let's see which tokens attend to which. This is the key insight of self-attention: each token can "look at" every other token with different intensities.

In [None]:
tokens = ["Your", "journey", "starts", "with", "one", "step"]

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(attn_weights.numpy(), cmap='Blues')

# Rows are queries (the attending token), columns are keys (the attended token)
ax.set_xticks(range(len(tokens)))
ax.set_xticklabels(tokens, fontsize=12)
ax.set_xlabel("Key (attended to)", fontsize=12)
ax.set_yticks(range(len(tokens)))
ax.set_yticklabels(tokens, fontsize=12)
ax.set_ylabel("Query (attending)", fontsize=12)
ax.set_title("Simple Self-Attention Weights", fontsize=14)

# Write each weight into its cell: x position = key index, y position = query index
for (i, j), value in np.ndenumerate(attn_weights.numpy()):
    ax.text(j, i, f"{value:.2f}", ha="center", va="center", fontsize=10)

fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.show()

---

## Query, Key, Value (Q, K, V) Projections

The simple attention above uses the raw input vectors for everything. In practice, transformers use **learnable weight matrices** to project inputs into three different spaces:

- **Query (Q)**: "What am I looking for?"
- **Key (K)**: "What do I contain?"
- **Value (V)**: "What information do I provide?"

This separation allows the model to learn different representations for matching (Q, K) vs. information transfer (V).

```
Q = X @ W_q    (what to search for)
K = X @ W_k    (what to match against)  
V = X @ W_v    (what to return)

Attention(Q, K, V) = softmax(Q @ K^T) @ V
```

In [None]:
torch.manual_seed(42)

d_model = inputs.shape[1]  # input embedding dimension (3 here)
d_k = 2  # output size of each Q/K/V projection; need not equal d_model

# Learnable projection matrices. They are drawn in q, k, v order right after
# seeding, so the random values are reproducible.
W_q = nn.Parameter(torch.randn(d_model, d_k))
W_k = nn.Parameter(torch.randn(d_model, d_k))
W_v = nn.Parameter(torch.randn(d_model, d_k))

print(f"Input dim: {d_model}, Projection dim: {d_k}")
print(f"W_q shape: {W_q.shape}")
print(f"W_k shape: {W_k.shape}")
print(f"W_v shape: {W_v.shape}")

# Map every token into query, key and value space: [6, 3] @ [3, 2] = [6, 2]
Q, K, V = (torch.matmul(inputs, W) for W in (W_q, W_k, W_v))

print(f"\nQ shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")

print(f"\nQuery for token 2 ('journey'): {Q[1]}")
print(f"Key for token 2 ('journey'):   {K[1]}")
print(f"Value for token 2 ('journey'): {V[1]}")

In [None]:
# Attention with learned Q, K, V, step by step.

# Step 1: compare every query with every key: [6, 2] @ [2, 6] = [6, 6]
attn_scores_qk = torch.matmul(Q, K.T)
print("Attention scores (Q @ K^T):")
print(attn_scores_qk)

# Step 2: row-wise softmax turns the scores into weights
attn_weights_qk = attn_scores_qk.softmax(dim=-1)
print("\nAttention weights:")
print(attn_weights_qk)

# Step 3: blend the VALUE vectors (not the raw inputs!): [6, 6] @ [6, 2] = [6, 2]
context_vecs_qk = torch.matmul(attn_weights_qk, V)
print(f"\nContext vectors shape: {context_vecs_qk.shape}")
print("Context vectors:")
print(context_vecs_qk)

---

## Scaled Dot-Product Attention

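There's a problem with the raw dot product: its magnitude grows with the dimension `d_k`. If the entries of a query $q$ and a key $k$ have zero mean and unit variance, then $q \cdot k$ has variance $d_k$. For large `d_k` the scores become large, which pushes softmax into saturated regions where its gradients are extremely small.

**Solution:** Divide the scores by `sqrt(d_k)`. This brings their variance back to about 1, regardless of `d_k`.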

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

This is the attention formula from "Attention Is All You Need" (Vaswani et al., 2017).

In [None]:
import math

# Scores are divided by sqrt(d_k) before the softmax
scale = math.sqrt(d_k)
print(f"Scale factor: sqrt({d_k}) = {scale:.4f}")

scaled_attn_scores = torch.matmul(Q, K.T) / scale
print(f"\nUnscaled scores (sample): {attn_scores_qk[1].detach()}")
print(f"Scaled scores (sample):   {scaled_attn_scores[1].detach()}")

# Smaller-magnitude scores -> a flatter softmax distribution
scaled_attn_weights = scaled_attn_scores.softmax(dim=-1)
print(f"\nUnscaled weights: {attn_weights_qk[1].detach()}")
print(f"Scaled weights:   {scaled_attn_weights[1].detach()}")
print("\nNotice: scaled weights are more uniform (less peaky) = better gradients!")

# Final context vectors, computed from the scaled weights
scaled_context = torch.matmul(scaled_attn_weights, V)
print(f"\nScaled context vectors shape: {scaled_context.shape}")

## Exercise 1: Implement Scaled Dot-Product Attention as a Function

**Your Task:** Combine everything above into a clean function that performs scaled dot-product attention.

**Steps:**
1. Compute Q, K, V from inputs using weight matrices
2. Compute attention scores: Q @ K^T
3. Scale by sqrt(d_k)
4. Apply softmax
5. Compute context vectors: weights @ V

**Hints:**
- Input shape: `[seq_len, d_model]`
- Output shape: `[seq_len, d_v]` where d_v is the value projection dimension

In [None]:
def scaled_dot_product_attention(X, W_q, W_k, W_v):
    """
    Compute scaled dot-product attention (exercise skeleton).

    Args:
        X: input tensor [seq_len, d_model]
        W_q: query weight matrix [d_model, d_k]
        W_k: key weight matrix [d_model, d_k]
        W_v: value weight matrix [d_model, d_v]

    Returns:
        context_vectors: [seq_len, d_v]
        attention_weights: [seq_len, seq_len]
    """
    # TODO: Implement scaled dot-product attention.
    # Every `None` below is a placeholder for your code.

    # Step 1 - project X into query, key and value space
    Q = None  # Your code
    K = None  # Your code
    V = None  # Your code

    # Step 2 - scaled attention scores
    d_k = None          # Your code: size of the key dimension
    attn_scores = None  # Your code: Q @ K^T / sqrt(d_k)

    # Step 3 - normalize each row of scores
    attn_weights = None  # Your code

    # Step 4 - weighted sum of the values
    context = None  # Your code

    return context, attn_weights

# Test your implementation.
# Until the TODOs are filled in the function returns (None, None); report that
# instead of raising AttributeError, so "Restart & Run All" keeps going.
context, weights = scaled_dot_product_attention(inputs, W_q, W_k, W_v)
if context is None or weights is None:
    print("scaled_dot_product_attention is not implemented yet - complete the TODOs above.")
else:
    print(f"Context shape: {context.shape}")
    print(f"Weights shape: {weights.shape}")

### Solution for Exercise 1

In [None]:
def scaled_dot_product_attention(X, W_q, W_k, W_v):
    """
    Compute scaled dot-product attention - SOLUTION

    Implements softmax(Q K^T / sqrt(d_k)) V with Q = X W_q, K = X W_k, V = X W_v.

    Args:
        X: input tensor [seq_len, d_model]; extra leading batch dimensions
           (e.g. [batch, seq_len, d_model]) are also supported
        W_q: query weight matrix [d_model, d_k]
        W_k: key weight matrix [d_model, d_k]
        W_v: value weight matrix [d_model, d_v]

    Returns:
        context_vectors: [..., seq_len, d_v]
        attention_weights: [..., seq_len, seq_len] (each row sums to 1)
    """
    # 1. Project inputs to Q, K, V
    Q = X @ W_q
    K = X @ W_k
    V = X @ W_v

    # 2. Compute scaled attention scores.
    # transpose(-2, -1) swaps only the last two dims; `.T` would reverse *all*
    # dims of a batched (>2-D) tensor, which is wrong here and deprecated.
    d_k = K.shape[-1]
    attn_scores = (Q @ K.transpose(-2, -1)) / math.sqrt(d_k)

    # 3. Softmax over the key dimension
    attn_weights = torch.softmax(attn_scores, dim=-1)

    # 4. Weighted sum of the values
    context = attn_weights @ V

    return context, attn_weights

# Test
context, weights = scaled_dot_product_attention(inputs, W_q, W_k, W_v)
print(f"Context shape: {context.shape}")  # [6, 2]
print(f"Weights shape: {weights.shape}")  # [6, 6]
print(f"\nContext vectors:\n{context.detach()}")
print(f"\nAttention weights:\n{weights.detach()}")

# The function must reproduce the step-by-step computation from the
# "Scaled Dot-Product Attention" section (same inputs, same weight matrices).
assert context.shape == (inputs.shape[0], W_v.shape[1])
assert weights.shape == (inputs.shape[0], inputs.shape[0])
assert torch.allclose(weights, scaled_attn_weights)
assert torch.allclose(context, scaled_context)
print("\nMatches the step-by-step scaled attention computed earlier.")

---

## Causal Masking (for Autoregressive Models)

In models like GPT, each token should only attend to **previous tokens** (and itself), not future tokens. This is because during generation, future tokens don't exist yet.

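We achieve this by setting the attention scores for future positions to `-inf` before the softmax. Since $e^{-\infty} = 0$, their attention weights become exactly 0. The diagonal is never masked, so every row keeps at least one finite score. This matters because a row that was entirely `-inf` would make softmax return NaN.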

```
Causal mask for sequence length 6 (✓ = may attend, ✗ = masked):

         Your    journey starts  with    one     step
Your      ✓       ✗       ✗       ✗       ✗       ✗
journey   ✓       ✓       ✗       ✗       ✗       ✗
starts    ✓       ✓       ✓       ✗       ✗       ✗
with      ✓       ✓       ✓       ✓       ✗       ✗
one       ✓       ✓       ✓       ✓       ✓       ✗
step      ✓       ✓       ✓       ✓       ✓       ✓
```

In [None]:
seq_len = inputs.shape[0]

# Block every key that lies strictly above the diagonal (key index > query index):
# True marks a future position that must not be attended to.
causal_mask = torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
print("Causal mask (True = masked/blocked):")
print(causal_mask.int())

# Same projections and scaled scores as in the previous sections
Q, K, V = inputs @ W_q, inputs @ W_k, inputs @ W_v
attn_scores_causal = torch.matmul(Q, K.T) / math.sqrt(d_k)

print(f"\nScores before masking:\n{attn_scores_causal.detach()}")

# Masked scores become -inf, which softmax turns into exp(-inf) = 0
attn_scores_causal = attn_scores_causal.masked_fill(causal_mask, float('-inf'))
print(f"\nScores after masking:\n{attn_scores_causal.detach()}")

causal_weights = attn_scores_causal.softmax(dim=-1)
print("\nCausal attention weights:")
print(causal_weights.detach())

# Each token now only mixes values from itself and earlier tokens
causal_context = torch.matmul(causal_weights, V)
print(f"\nCausal context vectors:\n{causal_context.detach()}")

In [None]:
# Visualize: bidirectional vs causal attention weights.
# Both panels use the full 0-1 color range. Causal rows concentrate their weight
# on few keys (the first row is exactly 1.0), so a lower vmax would clip those
# cells and render e.g. 0.55 and 1.0 as the same color.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, w, title in zip(axes,
                        [weights.detach(), causal_weights.detach()],
                        ["Bidirectional (BERT-style)", "Causal (GPT-style)"]):
    ax.imshow(w.numpy(), cmap='Blues', vmin=0, vmax=1)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, fontsize=10)
    ax.set_yticklabels(tokens, fontsize=10)
    ax.set_xlabel("Key")
    ax.set_ylabel("Query")
    ax.set_title(title, fontsize=13)
    # Annotate each cell with its exact weight (row = query i, column = key j)
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            ax.text(j, i, f"{w[i, j]:.2f}", ha="center", va="center", fontsize=9)

plt.suptitle("Bidirectional vs Causal Attention", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

---

## Multi-Head Attention

Instead of performing one attention computation, we perform **multiple attention computations in parallel** (called "heads"), each with its own Q, K, V projections.

**Why multiple heads?**
- Different heads can learn to attend to different types of relationships
- One head might focus on syntactic relationships, another on semantic ones
- It's like having multiple "perspectives" on the same input

**How it works:**
1. Split the model dimension into `n_heads` smaller dimensions
2. Each head performs independent scaled dot-product attention
3. Concatenate all head outputs
4. Project back to the original dimension with a final linear layer

```
MultiHead(X) = Concat(head_1, ..., head_h) @ W_o

where head_i = Attention(X @ W_q_i, X @ W_k_i, X @ W_v_i)
```

In [None]:
torch.manual_seed(42)

# Multi-head attention needs a model dimension that splits evenly across heads
d_model_mh = 6    # model dimension (must be divisible by n_heads)
n_heads = 2       # number of attention heads
d_head = d_model_mh // n_heads  # dimension per head = 3

print(f"d_model: {d_model_mh}, n_heads: {n_heads}, d_head: {d_head}")

# Widen the 3-dim inputs to d_model=6 by tiling them twice along the feature axis
inputs_mh = inputs.repeat(1, 2)  # [6, 6]
print(f"Input shape: {inputs_mh.shape}")

# One d_model x d_model projection per role serves all heads at once; the heads
# are carved out of its output afterwards. The layers are created in this order
# right after seeding, which fixes their random initialization.
W_q_mh = nn.Linear(d_model_mh, d_model_mh, bias=False)
W_k_mh = nn.Linear(d_model_mh, d_model_mh, bias=False)
W_v_mh = nn.Linear(d_model_mh, d_model_mh, bias=False)
W_o = nn.Linear(d_model_mh, d_model_mh, bias=False)  # output projection

# Project all tokens at once: each result is [6, 6]
Q_mh, K_mh, V_mh = W_q_mh(inputs_mh), W_k_mh(inputs_mh), W_v_mh(inputs_mh)

print(f"\nQ shape (before split): {Q_mh.shape}")

# Split the feature axis into heads, then move the head axis to the front:
# [seq_len, d_model] -> [seq_len, n_heads, d_head] -> [n_heads, seq_len, d_head]
Q_heads = Q_mh.view(seq_len, n_heads, d_head).transpose(0, 1)  # [2, 6, 3]
K_heads = K_mh.view(seq_len, n_heads, d_head).transpose(0, 1)  # [2, 6, 3]
V_heads = V_mh.view(seq_len, n_heads, d_head).transpose(0, 1)  # [2, 6, 3]

print(f"Q per head shape: {Q_heads.shape}")
print(f"Head 0 Q:\n{Q_heads[0].detach()}")
print(f"Head 1 Q:\n{Q_heads[1].detach()}")

In [None]:
# Scores for every head in one batched matmul (the head axis acts as the batch):
# [n_heads, seq_len, d_head] @ [n_heads, d_head, seq_len] -> [n_heads, seq_len, seq_len]
attn_scores_mh = torch.bmm(Q_heads, K_heads.transpose(1, 2)) / math.sqrt(d_head)
attn_weights_mh = attn_scores_mh.softmax(dim=-1)

print(f"Attention scores shape: {attn_scores_mh.shape}")
print(f"Attention weights shape: {attn_weights_mh.shape}")

# Each head blends its own values:
# [n_heads, seq_len, seq_len] @ [n_heads, seq_len, d_head] -> [n_heads, seq_len, d_head]
head_outputs = torch.bmm(attn_weights_mh, V_heads)
print(f"\nPer-head output shape: {head_outputs.shape}")

# Put the heads side by side for each position, then flatten them back into d_model:
# [n_heads, seq_len, d_head] -> [seq_len, n_heads, d_head] -> [seq_len, d_model]
head_outputs = head_outputs.transpose(0, 1)
concat_output = head_outputs.reshape(seq_len, d_model_mh)
print(f"Concatenated shape: {concat_output.shape}")

# Mix information across heads with the output projection
mh_output = W_o(concat_output)
print(f"Final output shape: {mh_output.shape}")
print(f"\nMulti-head attention output:\n{mh_output.detach()}")

In [None]:
# Visualize attention patterns for each head.
# squeeze=False keeps `axes` 2-D even when n_heads == 1 (plt.subplots would
# otherwise return a single Axes, and axes[h] would fail).
fig, axes = plt.subplots(1, n_heads, figsize=(12, 5), squeeze=False)

for h in range(n_heads):
    ax = axes[0, h]
    w = attn_weights_mh[h].detach().numpy()
    ax.imshow(w, cmap='Blues', vmin=0, vmax=0.5)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, fontsize=10)
    ax.set_yticklabels(tokens, fontsize=10)
    ax.set_title(f"Head {h+1}", fontsize=13)
    ax.set_xlabel("Key")
    ax.set_ylabel("Query")
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            ax.text(j, i, f"{w[i,j]:.2f}", ha="center", va="center", fontsize=9)

# The projections above are randomly initialized (untrained), so the title and
# note must not claim the heads have *learned* anything yet.
plt.suptitle("Multi-Head Attention: Each Head Has Its Own Attention Pattern", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
print("Each head uses its own Q/K projections, so its attention pattern differs.\n"
      "These weights are random (untrained); trained heads develop more structured patterns.")

## Exercise 2: Implement Multi-Head Attention as an nn.Module

**Your Task:** Package multi-head attention into a reusable PyTorch module.

**Steps:**
1. Initialize weight matrices W_q, W_k, W_v, W_o as `nn.Linear` layers
2. In `forward()`: project → split heads → compute attention → concat → output project
3. Support optional causal masking

**Hints:**
- Use `view()` and `permute()` to reshape between `[seq_len, d_model]` and `[n_heads, seq_len, d_head]`
- Use `torch.bmm()` for batched matrix multiplication
- Use `masked_fill()` for causal masking

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention skeleton for Exercise 2 - fill in the TODOs."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        # Every head receives an equal slice of the model dimension
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model, self.n_heads = d_model, n_heads
        self.d_head = d_model // n_heads

        # TODO: Create linear layers for Q, K, V projections and output
        # (each one: nn.Linear(d_model, d_model, bias=False))
        self.W_q = None  # Your code
        self.W_k = None  # Your code
        self.W_v = None  # Your code
        self.W_o = None  # Your code

    def forward(self, x, causal=False):
        """
        Args:
            x: input tensor [seq_len, d_model]
            causal: whether to apply causal masking
        Returns:
            output: [seq_len, d_model]
            attn_weights: [n_heads, seq_len, seq_len]
        """
        seq_len = x.shape[0]

        # TODO 1: project x to Q, K, V (each [seq_len, d_model])
        Q = K = V = None  # Your code

        # TODO 2: reshape each to [n_heads, seq_len, d_head] (view, then permute)
        Q = K = V = None  # Your code

        # TODO 3: scaled attention scores
        attn_scores = None  # Your code

        # TODO 4: causal mask (create it, then apply with masked_fill)
        if causal:
            pass  # Your code

        # TODO 5: softmax over the key dimension
        attn_weights = None  # Your code

        # TODO 6: apply the attention weights to the values
        head_outputs = None  # Your code

        # TODO 7: concatenate the heads and apply the output projection
        output = None  # Your code

        return output, attn_weights

# Test your implementation.
# Until the TODOs are filled in, forward() returns (None, None); report that
# instead of raising AttributeError, so "Restart & Run All" keeps going.
mha = MultiHeadAttention(d_model=6, n_heads=2)
out, attn = mha(inputs_mh, causal=False)
if out is None or attn is None:
    print("MultiHeadAttention is not implemented yet - complete the TODOs above.")
else:
    print(f"Output shape: {out.shape}")
    print(f"Attention shape: {attn.shape}")

### Solution for Exercise 2

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-Head Attention - SOLUTION

    Projects the input to queries, keys and values, splits each into
    ``n_heads`` slices of size ``d_model // n_heads``, runs scaled dot-product
    attention independently per head, concatenates the head outputs and
    applies a final output projection.

    Args:
        d_model: model (embedding) dimension; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, causal=False):
        """
        Args:
            x: input tensor [seq_len, d_model]
            causal: if True, position i may only attend to positions <= i
        Returns:
            output: [seq_len, d_model]
            attn_weights: [n_heads, seq_len, seq_len]
        """
        seq_len = x.shape[0]

        # 1. Project to Q, K, V: each [seq_len, d_model]
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # 2. Reshape: [seq_len, d_model] -> [seq_len, n_heads, d_head] -> [n_heads, seq_len, d_head]
        Q = Q.view(seq_len, self.n_heads, self.d_head).permute(1, 0, 2)
        K = K.view(seq_len, self.n_heads, self.d_head).permute(1, 0, 2)
        V = V.view(seq_len, self.n_heads, self.d_head).permute(1, 0, 2)

        # 3. Scaled dot-product scores for all heads at once: [n_heads, seq_len, seq_len]
        attn_scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(self.d_head)

        # 4. Causal mask: True strictly above the diagonal = future positions.
        # Build it on the scores' device so masked_fill also works for GPU inputs
        # (a CPU mask cannot be applied to a CUDA tensor).
        if causal:
            mask = torch.triu(
                torch.ones(seq_len, seq_len, device=attn_scores.device), diagonal=1
            ).bool()
            # unsqueeze(0) broadcasts the same mask over every head
            attn_scores = attn_scores.masked_fill(mask.unsqueeze(0), float('-inf'))

        # 5. Softmax over the key dimension
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # 6. Weighted sum of values
        head_outputs = torch.bmm(attn_weights, V)  # [n_heads, seq_len, d_head]

        # 7. Concat and project: [n_heads, seq_len, d_head] -> [seq_len, d_model]
        head_outputs = head_outputs.permute(1, 0, 2)  # [seq_len, n_heads, d_head]
        concat = head_outputs.reshape(seq_len, self.d_model)
        output = self.W_o(concat)

        return output, attn_weights

# Test
torch.manual_seed(42)
mha = MultiHeadAttention(d_model=6, n_heads=2)

# Bidirectional
out_bi, attn_bi = mha(inputs_mh, causal=False)
print(f"Bidirectional output shape: {out_bi.shape}")

# Causal
out_causal, attn_causal = mha(inputs_mh, causal=True)
print(f"Causal output shape: {out_causal.shape}")

# Sanity checks: every row of weights is a probability distribution, and the
# causal weights put zero mass on future keys (strict upper triangle of each head).
row_sums = attn_bi.sum(dim=-1)
assert torch.allclose(row_sums, torch.ones_like(row_sums))
assert torch.all(attn_causal.triu(diagonal=1) == 0)

# Visualize both: one row per head, one column per masking mode.
# The grid follows mha.n_heads (squeeze=False keeps `axes` 2-D even for a single
# head), and vmax=1 because causal rows can put all of their weight on one key.
fig, axes = plt.subplots(mha.n_heads, 2, figsize=(12, 5 * mha.n_heads), squeeze=False)
for col, (attn_w, title) in enumerate([(attn_bi, "Bidirectional"), (attn_causal, "Causal")]):
    for h in range(mha.n_heads):
        ax = axes[h][col]
        w = attn_w[h].detach().numpy()
        ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
        ax.set_xticks(range(len(tokens)))
        ax.set_yticks(range(len(tokens)))
        ax.set_xticklabels(tokens, fontsize=9)
        ax.set_yticklabels(tokens, fontsize=9)
        ax.set_xlabel("Key")
        ax.set_ylabel("Query")
        ax.set_title(f"{title} - Head {h+1}", fontsize=11)
        for i in range(len(tokens)):
            for j in range(len(tokens)):
                ax.text(j, i, f"{w[i,j]:.2f}", ha="center", va="center", fontsize=8)

plt.suptitle("Multi-Head Attention Module", fontsize=14)
plt.tight_layout()
plt.show()

---

## Comparison with PyTorch's Built-in Multi-Head Attention

PyTorch provides `nn.MultiheadAttention` which does everything we just built. Let's verify our understanding matches the official implementation.

In [None]:
# PyTorch's built-in MultiheadAttention
# Note: with batch_first=False, PyTorch expects input as [seq_len, batch_size, d_model]
pytorch_mha = nn.MultiheadAttention(embed_dim=mha.d_model, num_heads=mha.n_heads,
                                    batch_first=False, bias=False)

# Add batch dimension: [seq_len, d_model] -> [seq_len, 1, d_model]
x_batched = inputs_mh.unsqueeze(1)

with torch.no_grad():
    # Give PyTorch's module exactly the weights of our `mha`, so any difference
    # in output would come from the implementation, not from random init.
    # PyTorch stacks W_q, W_k, W_v row-wise into one [3*d_model, d_model] matrix.
    pytorch_mha.in_proj_weight.copy_(
        torch.cat([mha.W_q.weight, mha.W_k.weight, mha.W_v.weight], dim=0)
    )
    pytorch_mha.out_proj.weight.copy_(mha.W_o.weight)

    # For self-attention: query = key = value = input.
    # average_attn_weights=False returns one weight matrix per head instead of their mean.
    pt_output, pt_weights = pytorch_mha(x_batched, x_batched, x_batched,
                                        average_attn_weights=False)
    our_output, our_weights = mha(inputs_mh)

print(f"PyTorch MHA output shape: {pt_output.shape}")  # [6, 1, 6]
print(f"PyTorch attention weights shape: {pt_weights.shape}")  # [1, 2, 6, 6]

outputs_match = torch.allclose(pt_output.squeeze(1), our_output, atol=1e-5)
weights_match = torch.allclose(pt_weights.squeeze(0), our_weights, atol=1e-5)
print(f"\nOutputs match PyTorch:           {outputs_match}")
print(f"Attention weights match PyTorch: {weights_match}")
assert outputs_match and weights_match, "our MultiHeadAttention disagrees with nn.MultiheadAttention"
print("\nWith identical weights, our implementation reproduces nn.MultiheadAttention.")

---

## Summary

In this notebook, we built self-attention from the ground up:

1. **Simple Attention**: dot product between input vectors → softmax → weighted sum
2. **Context Vectors**: each token gets a new representation that blends information from all tokens
3. **Q/K/V Projections**: learnable weight matrices separate "what to search for" (Q), "what to match against" (K), and "what to return" (V)
4. **Scaled Dot-Product Attention**: divide by sqrt(d_k) for stable gradients
5. **Causal Masking**: prevent attention to future tokens (for autoregressive models like GPT)
6. **Multi-Head Attention**: multiple parallel attention computations capture different relationships

### Key Formula

$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W^O$$

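$$\text{where } \text{head}_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i, \quad Q_i = X W^Q_i,\; K_i = X W^K_i,\; V_i = X W^V_i$$

Here $d_k$ is the per-head key dimension (`d_head` in the code).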

### Next Steps
- **Module 7b**: Build a complete Transformer block using this attention mechanism
- Add positional encoding, layer normalization, feed-forward networks, and residual connections

### References
- Paper: Vaswani et al. "Attention Is All You Need" (2017)
- Blog: Jay Alammar "The Illustrated Transformer"
- Book: Sebastian Raschka "Build a Large Language Model (From Scratch)" (Ch. 3)