# The Problem Attention Solves

**Module 4.2, Lesson 1** — Attention & the Transformer

In this notebook you'll:
- Compute the full attention operation step by step: `softmax(XX^T) X`
- Verify that the raw score matrix `XX^T` is symmetric (and understand why that's a problem)
- Visualize attention heatmaps to see which tokens attend to which
- Demonstrate that separate seeking/offering vectors break symmetry
- (Stretch) Compute attention on real sentences using pretrained GPT-2 embeddings

**For each exercise, PREDICT the output before running the cell.**

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Reproducible results
torch.manual_seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

print(f"PyTorch version: {torch.__version__}")

---
## Exercise 1: Raw Dot-Product Attention on a Tiny Example (Guided)

We'll work with 4 tokens, each with an 8-dimensional embedding. The goal: implement `Attention(X) = softmax(XX^T) X` step by step.

**Before running, predict:** If X has shape `[4, 8]`, what shape will `X @ X.T` be? After applying softmax row-wise, will each row sum to 1?

In [None]:
# Our tiny example: 4 tokens, 8-dim embeddings
torch.manual_seed(42)
n_tokens = 4
embed_dim = 8

# Simulated embeddings for: ["The", "cat", "sat", "here"]
tokens = ["The", "cat", "sat", "here"]
X = torch.randn(n_tokens, embed_dim)

print(f"X shape: {X.shape}")
print(f"\nEmbeddings:")
for i, token in enumerate(tokens):
    print(f"  {token}: {X[i].tolist()[:4]}...")

In [None]:
# Step 1: Compute the similarity score matrix (XX^T)
# Each entry [i, j] is the dot product of embedding i and embedding j

scores = X @ X.T  # shape: [4, 4]

print(f"Score matrix shape: {scores.shape}")
print(f"\nScore matrix (XX^T):")
print(scores.numpy().round(3))

# Verify it's symmetric: score[i,j] == score[j,i]
print(f"\nIs symmetric? {torch.allclose(scores, scores.T, atol=1e-5)}")
print("This is guaranteed: dot(a, b) = dot(b, a) always.")

In [None]:
# Step 2: Apply softmax to each ROW to get attention weights
# Each row becomes a probability distribution (sums to 1)

weights = F.softmax(scores, dim=-1)  # dim=-1 normalizes across columns, so each row sums to 1

print(f"Attention weight matrix:")
print(weights.numpy().round(4))

# Verify each row sums to 1
row_sums = weights.sum(dim=-1)
print(f"\nRow sums: {row_sums.tolist()}")
print("Each row sums to 1 — it's a probability distribution.")

In [None]:
# Step 3: Compute the output — weighted average of embeddings for each token

output = weights @ X  # shape: [4, 8]

print(f"Output shape: {output.shape}")
print(f"\nOriginal embedding for 'cat': {X[1].tolist()[:4]}...")
print(f"Attention output for 'cat':    {output[1].tolist()[:4]}...")
print(f"\nThe output is DIFFERENT from the input — it's a context-dependent blend.")

In [None]:
# All three steps in one function:
def raw_dot_product_attention(X: torch.Tensor) -> torch.Tensor:
    """
    Compute raw dot-product attention (no Q/K/V projections).
    
    Attention(X) = softmax(XX^T) X
    
    Args:
        X: Embedding matrix of shape [n_tokens, embed_dim]
    
    Returns:
        Context-dependent representations of shape [n_tokens, embed_dim]
    """
    scores = X @ X.T                       # [n, n] pairwise similarity
    weights = F.softmax(scores, dim=-1)     # [n, n] normalized weights per row
    output = weights @ X                    # [n, d] weighted average
    return output

# Verify it matches our step-by-step computation
output_oneliner = raw_dot_product_attention(X)
print(f"Matches step-by-step? {torch.allclose(output, output_oneliner, atol=1e-5)}")
print("\nThree matrix operations. That's all attention is.")

---
## Exercise 2: Visualize the Attention Heatmap (Guided)

The attention weight matrix is the core data structure. Let's see it.

**Before running, predict:** Will the diagonal of the attention matrix (each token attending to itself) have the highest weights? Why or why not?

In [None]:
def plot_attention(weights: torch.Tensor, tokens: list[str], title: str = "Attention Weights"):
    """Plot attention weight matrix as a heatmap."""
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(weights.detach().numpy(), cmap='Purples', vmin=0, vmax=1)
    
    # Labels
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, fontsize=11)
    ax.set_yticklabels(tokens, fontsize=11)
    ax.set_xlabel('Attending to', fontsize=12)
    ax.set_ylabel('Token', fontsize=12)
    
    # Add values in cells
    for i in range(len(tokens)):
        for j in range(len(tokens)):
            val = weights[i, j].item()
            color = 'white' if val > 0.5 else 'black'
            ax.text(j, i, f'{val:.3f}', ha='center', va='center', fontsize=10, color=color)
    
    plt.colorbar(im, ax=ax, label='Weight')
    plt.title(title, fontsize=13)
    plt.tight_layout()
    plt.show()

# Plot our attention weights
scores = X @ X.T
weights = F.softmax(scores, dim=-1)
plot_attention(weights, tokens)

In [None]:
# Now plot the RAW SCORES (before softmax) to see the symmetry clearly
fig, axes = plt.subplots(1, 2, figsize=(13, 5))

# Left: raw scores
im0 = axes[0].imshow(scores.detach().numpy(), cmap='RdBu_r', 
                      vmin=-scores.abs().max().item(), vmax=scores.abs().max().item())
axes[0].set_xticks(range(len(tokens)))
axes[0].set_yticks(range(len(tokens)))
axes[0].set_xticklabels(tokens)
axes[0].set_yticklabels(tokens)
for i in range(len(tokens)):
    for j in range(len(tokens)):
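        # Pick a text color that stays readable on both pale (near-zero) and saturated cells
        cell_color = 'white' if abs(scores[i, j].item()) > 0.6 * scores.abs().max().item() else 'black'
        axes[0].text(j, i, f'{scores[i,j]:.2f}', ha='center', va='center', fontsize=10, color=cell_color)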
plt.colorbar(im0, ax=axes[0])
axes[0].set_title('Raw Dot-Product Scores (XX\u1d40)\n(Symmetric!)', fontsize=12)

# Right: after softmax
im1 = axes[1].imshow(weights.detach().numpy(), cmap='Purples', vmin=0, vmax=1)
axes[1].set_xticks(range(len(tokens)))
axes[1].set_yticks(range(len(tokens)))
axes[1].set_xticklabels(tokens)
axes[1].set_yticklabels(tokens)
for i in range(len(tokens)):
    for j in range(len(tokens)):
        axes[1].text(j, i, f'{weights[i,j]:.3f}', ha='center', va='center', fontsize=10,
                     color='white' if weights[i,j] > 0.5 else 'black')
plt.colorbar(im1, ax=axes[1])
axes[1].set_title('Attention Weights (after softmax)\n(Rows sum to 1)', fontsize=12)

plt.tight_layout()
plt.show()

print("Left matrix is SYMMETRIC: score[i,j] == score[j,i].")
print("Right matrix: rows sum to 1, but still driven by symmetric scores.")

**Key observation:** The raw score matrix is perfectly symmetric: the *raw relevance score* between any two tokens is always the same in both directions. As we'll see in Exercise 3, this is a fundamental limitation.
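One subtlety is worth checking by hand: the symmetry lives in the raw scores, not necessarily in the weights after softmax. The sketch below uses a hypothetical two-token example (`X_demo` is an illustrative stand-in, not the notebook's `X`) that is small enough to verify on paper.

```python
import torch
import torch.nn.functional as F

# Hypothetical toy embeddings (not the notebook's X): two tokens pointing in
# the same direction, the second twice as long as the first.
X_demo = torch.tensor([[1.0, 0.0],
                       [2.0, 0.0]])

scores_demo = X_demo @ X_demo.T                # [[1, 2], [2, 4]], symmetric
weights_demo = F.softmax(scores_demo, dim=-1)  # each row is normalized on its own

print("Scores symmetric? ", torch.equal(scores_demo, scores_demo.T))
print("Weights symmetric?", torch.allclose(weights_demo, weights_demo.T, atol=1e-5))
print(weights_demo)
# Row 0 is softmax([1, 2]), roughly [0.269, 0.731]
# Row 1 is softmax([2, 4]), roughly [0.119, 0.881]
```

Row-wise softmax divides each row by its own normalizer, so `weights[i, j]` and `weights[j, i]` can differ even though the scores match. The limitation in this lesson is about the scores: both directions start from the same number.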

---
## Exercise 3: Confirming Symmetry (Guided)

Let's formally verify the symmetry property and understand exactly why it's a problem.

**Before running, predict:** Is `XX^T` guaranteed to be symmetric for ANY matrix X? (Hint: think about `(AB)^T = B^T A^T`.)

In [None]:
# Verify symmetry for multiple random embeddings
for trial in range(5):
    X_rand = torch.randn(6, 32)  # 6 tokens, 32 dims
    S = X_rand @ X_rand.T
    is_sym = torch.allclose(S, S.T, atol=1e-5)
    print(f"Trial {trial+1}: XX^T symmetric? {is_sym}")

print("\nAlways symmetric. This is a mathematical property of XX^T, not a coincidence.")
print("dot(a, b) = sum(a_i * b_i) = sum(b_i * a_i) = dot(b, a)")

In [None]:
# Why symmetry is a problem: "The cat chased the mouse"
# For 'cat', 'chased' is relevant because it tells us what the cat DID.
# For 'chased', 'cat' is relevant because it tells us WHO did the chasing.
# Different reasons, but the score is always the same.

# Let's make this concrete. The embeddings below are random stand-ins, not
# learned meanings; the symmetry holds for any vectors:
torch.manual_seed(0)
tokens_chase = ["The", "cat", "chased", "the", "mouse"]
X_chase = torch.randn(5, 16)

scores_chase = X_chase @ X_chase.T

# Compare score(cat -> chased) vs score(chased -> cat)
cat_idx, chased_idx = 1, 2
print(f"score(cat, chased) = {scores_chase[cat_idx, chased_idx]:.4f}")
print(f"score(chased, cat) = {scores_chase[chased_idx, cat_idx]:.4f}")
print(f"Difference: {abs(scores_chase[cat_idx, chased_idx] - scores_chase[chased_idx, cat_idx]):.8f}")
print(f"\nExactly the same. But 'cat' needs 'chased' for a DIFFERENT reason")
print(f"than 'chased' needs 'cat'. Symmetric scores can't express this.")

In [None]:
# What if tokens had TWO vectors — one for seeking, one for offering?
# Then score(A seeking B) = dot(seek_A, offer_B)
# And score(B seeking A) = dot(seek_B, offer_A)
# These are different because seek_A ≠ offer_A in general.

torch.manual_seed(42)
n = 5
d = 16

# Simulate two separate vector types
seeking_vectors = torch.randn(n, d)   # what each token looks for
offering_vectors = torch.randn(n, d)  # what each token advertises

# Score matrix: seeking[i] dot offering[j]
asymmetric_scores = seeking_vectors @ offering_vectors.T

print(f"Asymmetric scores symmetric? {torch.allclose(asymmetric_scores, asymmetric_scores.T, atol=1e-5)}")
print(f"\nscore(cat seeking chased) = {asymmetric_scores[1, 2]:.4f}")
print(f"score(chased seeking cat) = {asymmetric_scores[2, 1]:.4f}")
print(f"Difference: {abs(asymmetric_scores[1, 2] - asymmetric_scores[2, 1]):.4f}")
print(f"\nNow the scores are DIFFERENT in each direction!")
print(f"This is exactly what the next lesson (projections) will achieve.")

---
## Exercise 4 (Stretch): Attention with Pretrained GPT-2 Embeddings (Guided)

Use real GPT-2 token embeddings to compute raw dot-product attention and see whether "bank" gets different attention patterns in different contexts.

**Before running, predict:** The token "bank" has the same embedding regardless of context. After applying raw dot-product attention, will the *output* for "bank" be different in "steep and muddy" vs. "raised interest rates"?

In [None]:
# Install transformers if needed
try:
    from transformers import GPT2Tokenizer, GPT2Model
except ImportError:
    !pip install transformers -q
    from transformers import GPT2Tokenizer, GPT2Model

In [None]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
wte = model.wte.weight.detach()  # [50257, 768]

def get_embeddings(sentence: str) -> tuple[torch.Tensor, list[str]]:
    """Get token embeddings for a sentence from GPT-2's embedding table."""
    ids = tokenizer.encode(sentence)
    token_strs = [tokenizer.decode([tid]) for tid in ids]
    embeddings = wte[ids]  # [n_tokens, 768]
    return embeddings, token_strs

def compute_and_plot_attention(sentence: str):
    """Compute raw dot-product attention and visualize."""
    X, tok_strs = get_embeddings(sentence)
    
    scores = X @ X.T
    weights = F.softmax(scores, dim=-1)
    output = weights @ X
    
    plot_attention(weights, tok_strs, f'Attention: "{sentence}"')
    return X, weights, output, tok_strs

print("Ready to compute attention on real sentences.")

In [None]:
# Sentence 1: "bank" near terrain words
X1, W1, out1, toks1 = compute_and_plot_attention("The bank was steep and muddy")

In [None]:
# Sentence 2: "bank" near finance words
X2, W2, out2, toks2 = compute_and_plot_attention("The bank raised interest rates")

In [None]:
# Compare the attention output for "bank" in both sentences
# Find the index of " bank" in each
bank_idx_1 = toks1.index(" bank") if " bank" in toks1 else toks1.index("bank")
bank_idx_2 = toks2.index(" bank") if " bank" in toks2 else toks2.index("bank")

bank_input_1 = X1[bank_idx_1]
bank_input_2 = X2[bank_idx_2]

bank_output_1 = out1[bank_idx_1]
bank_output_2 = out2[bank_idx_2]

# Input embeddings should be IDENTICAL (same token)
input_sim = F.cosine_similarity(bank_input_1.unsqueeze(0), bank_input_2.unsqueeze(0)).item()
# Output embeddings should differ (different contexts)
output_sim = F.cosine_similarity(bank_output_1.unsqueeze(0), bank_output_2.unsqueeze(0)).item()

print(f"Input embedding similarity (should be 1.0): {input_sim:.4f}")
print(f"Output embedding similarity (should be < 1): {output_sim:.4f}")
print(f"\nCosine distance between outputs (1 - similarity): {1.0 - output_sim:.4f}")
print(f"\nThe attention output for 'bank' is DIFFERENT in different contexts,")
print(f"even though the input embedding is identical. Context shapes meaning.")
print(f"\nBut with raw dot products, the effect is limited — 'bank' can't")
print(f"*seek* different information in different contexts. The seeking vector")
print(f"is always the same embedding regardless of context.")

In [None]:
# Compare attention patterns for "bank" in both contexts
fig, axes = plt.subplots(1, 2, figsize=(14, 3))

# Sentence 1 attention from "bank"
w1_bank = W1[bank_idx_1].detach().numpy()
axes[0].barh(range(len(toks1)), w1_bank, color='steelblue')
axes[0].set_yticks(range(len(toks1)))
axes[0].set_yticklabels(toks1)
axes[0].set_title('"bank" attention in: steep and muddy')
axes[0].set_xlabel('Attention weight')
axes[0].invert_yaxis()

# Sentence 2 attention from "bank"
w2_bank = W2[bank_idx_2].detach().numpy()
axes[1].barh(range(len(toks2)), w2_bank, color='coral')
axes[1].set_yticks(range(len(toks2)))
axes[1].set_yticklabels(toks2)
axes[1].set_title('"bank" attention in: raised interest rates')
axes[1].set_xlabel('Attention weight')
axes[1].invert_yaxis()

plt.tight_layout()
plt.show()

print("The patterns differ because the context embeddings differ.")
print("But 'bank' uses the SAME seeking vector in both cases.")
print("With separate seeking/offering vectors, 'bank' could seek")
print("terrain words in one context and finance words in another.")

---
## Key Takeaways

1. **Attention is three matrix operations:** `scores = XX^T`, `weights = softmax(scores)`, `output = weights @ X`. Each output token is a weighted average of all input tokens.
2. **The raw score matrix `XX^T` is always symmetric.** `dot(a, b) = dot(b, a)`, or in matrix form `(XX^T)^T = (X^T)^T X^T = XX^T`, so the relevance score between any two tokens is identical in both directions. This is a mathematical fact, not a coincidence.
3. **Symmetry is a limitation.** "Cat" needs "chased" (to know what it did) and "chased" needs "cat" (to know who did it) — different reasons, but the score is forced to be the same.
4. **Separate seeking/offering vectors break symmetry.** When `score(A, B) = dot(seek_A, offer_B)`, the scores become asymmetric because `seek_A` differs from `offer_A`. This is exactly what Q/K projections achieve in the next lesson.
5. **Attention already creates context-dependent representations** — even raw dot-product attention changes "bank" based on surrounding words. But without asymmetric scores, its expressiveness is limited.
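
As a rough preview of takeaway 4, the sketch below derives seeking and offering vectors from the *same* embeddings using two separate linear maps. It rests on assumptions: the names `to_seeking` and `to_offering` are illustrative, the maps are plain bias-free `nn.Linear` layers, and the weights are random rather than trained. The next lesson may set this up differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

n_tokens, embed_dim = 5, 16
X_demo = torch.randn(n_tokens, embed_dim)  # hypothetical stand-in embeddings

# Two separate linear maps applied to the SAME embeddings.
# Illustrative names; weights are randomly initialized, not trained.
to_seeking = nn.Linear(embed_dim, embed_dim, bias=False)
to_offering = nn.Linear(embed_dim, embed_dim, bias=False)

with torch.no_grad():
    seeking = to_seeking(X_demo)         # [n, d] what each token looks for
    offering = to_offering(X_demo)       # [n, d] what each token advertises
    scores_proj = seeking @ offering.T   # [n, n] seeking[i] dot offering[j]

print(f"Projected scores symmetric? {torch.allclose(scores_proj, scores_proj.T, atol=1e-5)}")
print(f"score(token 1 seeking token 2) = {scores_proj[1, 2]:.4f}")
print(f"score(token 2 seeking token 1) = {scores_proj[2, 1]:.4f}")
```

Unlike the fully random seeking/offering vectors in Exercise 3, both sets here are computed from one embedding per token, so each token keeps a single identity while still scoring differently in each direction.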