# The Math of Attention

*Gently, for an old man.*

---

We're going to build up the attention mechanism piece by piece. Each section will have:
1. The intuition (what are we doing and why)
2. The equation (the actual math, in LaTeX)
3. The code (runnable PyTorch, explicit operations)

By the end, you'll have a working attention implementation you built yourself.

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Reproducibility
torch.manual_seed(42)

# We'll work in float32 on CPU for clarity
# (No need for GPU at this scale)

print("Ready to learn.")

Ready to learn.


---

## 1. The Starting Point: Embedded Tokens

We begin with a sequence of tokens that have already been embedded. Each token is a vector of dimension $d$.

If we have $n$ tokens, our input is a matrix $X$ of shape $(n, d)$:

$$X \in \mathbb{R}^{n \times d}$$

Each row is one token's embedding vector.

In [2]:
# Let's make this concrete.
# Tiny example: 4 tokens, 8-dimensional embeddings

n_tokens = 4
d_model = 8

# Random embeddings (pretend these came from an embedding layer)
X = torch.randn(n_tokens, d_model)

print(f"Input X shape: {X.shape}")
print(f"\nX = ")
print(X.numpy().round(2))

Input X shape: torch.Size([4, 8])

X = 
[[ 1.93  1.49  0.9  -2.11  0.68 -1.23 -0.04 -1.6 ]
 [-0.75  1.65 -0.39 -1.4  -0.73 -0.56 -0.77  0.76]
 [ 1.64 -0.16 -0.5   0.44 -0.76  1.08  0.8   1.68]
 [ 1.28  1.3   0.61  1.33 -0.23  0.04 -0.25  0.86]]
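The cell above uses `torch.randn` as a stand-in for an embedding layer. Here is a minimal sketch of what such a layer does. It assumes a hypothetical 10-token vocabulary and made-up token IDs, neither of which comes from the notebook. `torch.nn.Embedding` is a learned lookup table, and embedding a sequence selects one row of its weight matrix per token.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes for illustration: a 10-token vocabulary, 8-dim embeddings
vocab_size, d_model = 10, 8
embedding = torch.nn.Embedding(vocab_size, d_model)

# Hypothetical token IDs for a 4-token sequence (token 1 appears twice)
token_ids = torch.tensor([3, 1, 4, 1])

X = embedding(token_ids)  # one row per token: shape (4, 8)

# The lookup is plain row selection from the embedding's weight matrix
rows = embedding.weight[token_ids]
print(X.shape)                  # torch.Size([4, 8])
print(torch.equal(X, rows))     # True

# The same token ID always maps to the same row
print(torch.equal(X[1], X[3]))  # True
```

In a trained model those rows are learned. Here they are random, which is all the rest of this notebook needs.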


---

## 2. The Three Projections: Q, K, V

Each token plays three roles:
- **Query (Q):** "What am I looking for?"
- **Key (K):** "What do I advertise?"
- **Value (V):** "What do I offer if selected?"

We create these by multiplying $X$ by three learned weight matrices:

$$Q = X W_Q$$
$$K = X W_K$$
$$V = X W_V$$

Where:
- $W_Q, W_K \in \mathbb{R}^{d \times d_k}$ (project to key/query dimension)
- $W_V \in \mathbb{R}^{d \times d_v}$ (project to value dimension)

Often $d_k = d_v = d$ for simplicity, but they don't have to be.

In [3]:
# For our tiny example, let's keep all dimensions the same
d_k = d_model  # query and key dimension
d_v = d_model  # value dimension

# Initialize random projection matrices
# (In a real model, these are learned parameters)
W_Q = torch.randn(d_model, d_k)
W_K = torch.randn(d_model, d_k)
W_V = torch.randn(d_model, d_v)

# Project!
Q = X @ W_Q  # (n, d) @ (d, d_k) = (n, d_k)
K = X @ W_K  # (n, d) @ (d, d_k) = (n, d_k)
V = X @ W_V  # (n, d) @ (d, d_v) = (n, d_v)

print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")

Q shape: torch.Size([4, 8])
K shape: torch.Size([4, 8])
V shape: torch.Size([4, 8])
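Since $d_k$ and $d_v$ need not equal $d$, here is a small sketch with illustrative sizes $d_k = 4$ and $d_v = 6$. These are arbitrary choices, not taken from any particular model. The sketch also checks that each token is projected independently: row $i$ of $Q$ is just row $i$ of $X$ times $W_Q$.

```python
import torch

torch.manual_seed(0)

n, d = 4, 8
d_k, d_v = 4, 6  # illustrative: query/key and value widths need not match d

X = torch.randn(n, d)
W_Q = torch.randn(d, d_k)
W_K = torch.randn(d, d_k)
W_V = torch.randn(d, d_v)

Q, K, V = X @ W_Q, X @ W_K, X @ W_V
print(Q.shape, K.shape, V.shape)  # torch.Size([4, 4]) torch.Size([4, 4]) torch.Size([4, 6])

# Row i of Q depends only on row i of X: each token is projected on its own
q_2 = X[2] @ W_Q
print(torch.allclose(Q[2], q_2, atol=1e-5))  # True

# Queries and keys must share d_k to be dotted together; d_v sets the output width
weights = torch.softmax(Q @ K.T / d_k ** 0.5, dim=-1)  # (4, 4)
out = weights @ V                                       # (4, 6)
print(out.shape)  # torch.Size([4, 6])
```

The only hard constraint is that queries and keys share $d_k$. The value width $d_v$ passes straight through to the output.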


---

## 3. Attention Scores: Queries Meet Keys

Now we compute how much each query "matches" each key. This is a dot product between every query-key pair.

$$\text{scores} = Q K^T$$

Result shape: $(n, n)$. Entry $(i, j)$ tells us how much token $i$'s query matches token $j$'s key.

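**This is the $n^2$ part.** Every token's query is compared with every token's key, its own included, so the score matrix, and the work to compute it, grows quadratically with sequence length.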

In [4]:
# Compute attention scores
scores = Q @ K.T  # (n, d_k) @ (d_k, n) = (n, n)

print(f"Scores shape: {scores.shape}")
print(f"\nScores = ")
print(scores.numpy().round(2))

Scores shape: torch.Size([4, 4])

Scores = 
[[ 48.36  -1.43   7.06  16.17]
 [  1.88  14.59 -10.85 -11.88]
 [-20.9   -3.98  16.85   5.96]
 [  7.22   3.67  49.61  35.63]]
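This sketch checks that $QK^T$ really is every query dotted with every key. It fills the score matrix with explicit loops, using fresh random $Q$ and $K$ rather than the notebook's, and compares the result with the matrix product. The last loop puts rough numbers on the $n^2$ cost. It assumes float32 (4 bytes per entry) and counts a single score matrix.

```python
import torch

torch.manual_seed(0)
n, d_k = 4, 8
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)

# Entry (i, j) is the dot product of query i with key j
scores_loop = torch.empty(n, n)
for i in range(n):
    for j in range(n):
        scores_loop[i, j] = torch.dot(Q[i], K[j])

print(torch.allclose(scores_loop, Q @ K.T, atol=1e-5))  # True

# n * n entries: doubling the sequence length quadruples the score matrix
for length in (1_000, 2_000, 4_000):
    mb = length * length * 4 / 1e6  # float32 = 4 bytes per entry
    print(f"{length} tokens -> {length * length:,} scores, about {mb:.0f} MB")
```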


---

## 4. Scaling: Preventing Exploding Dot Products

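Here's a subtle but important detail. When $d_k$ is large, the dot products grow large too. Each score is a sum of $d_k$ products, so if the query and key components are roughly independent with unit variance, the scores have variance about $d_k$ (standard deviation about $\sqrt{d_k}$). Large inputs push softmax toward a nearly one-hot output, where its gradients become vanishingly small.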

**Solution:** Scale down by $\sqrt{d_k}$.

$$\text{scaled\_scores} = \frac{Q K^T}{\sqrt{d_k}}$$

Under that same unit-variance assumption, dividing by $\sqrt{d_k}$ brings the variance of the scores back to about 1, whatever the dimension.
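Here is the standard back-of-the-envelope argument behind the $\sqrt{d_k}$. It rests on a simplifying assumption: the components $q_m$ and $k_m$ of a query $q$ and key $k$ are independent, with mean 0 and variance 1. Each term $q_m k_m$ then has mean 0 and variance $\mathbb{E}[q_m^2]\,\mathbb{E}[k_m^2] = 1$. Because the terms are independent, their variances add:

$$\operatorname{Var}(q \cdot k) = \operatorname{Var}\left(\sum_{m=1}^{d_k} q_m k_m\right) = \sum_{m=1}^{d_k} \operatorname{Var}(q_m k_m) = d_k$$

$$\operatorname{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{d_k}{(\sqrt{d_k})^2} = 1$$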

In [5]:
# Scale the scores
scale = d_k ** 0.5
scaled_scores = scores / scale

print(f"Scale factor: {scale:.2f}")
print(f"\nScaled scores = ")
print(scaled_scores.numpy().round(2))

Scale factor: 2.83

Scaled scores = 
[[17.1  -0.51  2.5   5.72]
 [ 0.67  5.16 -3.84 -4.2 ]
 [-7.39 -1.41  5.96  2.11]
 [ 2.55  1.3  17.54 12.6 ]]
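The scaled scores above are still in the teens. That is expected here: the random $W_Q$ and $W_K$ in this notebook are not rescaled, so each entry of $Q$ and $K$ is a sum of 8 unit-variance products and has variance around 8 rather than 1. As a result, the softmax in the next section comes out nearly one-hot. This sketch checks the argument itself using queries and keys that *do* have unit-variance components. The sample count and the list of dimensions are arbitrary choices.

```python
import torch

torch.manual_seed(0)
n_samples = 10_000

results = {}
for d_k in (8, 64, 512):
    # Unit-variance random queries and keys: the assumption behind the scaling
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)   # one dot product per sample
    scaled = dots / d_k ** 0.5
    results[d_k] = (dots.var().item(), scaled.var().item())
    print(f"d_k={d_k:4d}  var(raw)={results[d_k][0]:7.1f}  var(scaled)={results[d_k][1]:.2f}")
```

The raw variance should track $d_k$, and the scaled variance should stay close to 1.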


---

## 5. Softmax: Turning Scores into Weights

We need each token's attention to be a probability distribution over all positions: nonnegative weights that sum to 1.

Softmax does this, applied **row-wise** (each token's scores become a distribution):

$$\text{attention}_{ij} = \frac{\exp(\text{scaled\_scores}_{ij})}{\sum_{m=1}^{n} \exp(\text{scaled\_scores}_{im})}$$

Or more compactly:

$$\text{attention} = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right)$$

Each row of the resulting matrix sums to 1.

In [6]:
# Apply softmax row-wise (dim=-1 means last dimension, i.e., across columns)
attention_weights = F.softmax(scaled_scores, dim=-1)

print(f"Attention weights shape: {attention_weights.shape}")
print(f"\nAttention weights = ")
print(attention_weights.numpy().round(3))

# Verify rows sum to 1
print(f"\nRow sums: {attention_weights.sum(dim=-1).numpy().round(6)}")

Attention weights shape: torch.Size([4, 4])

Attention weights = 
[[1.    0.    0.    0.   ]
 [0.011 0.989 0.    0.   ]
 [0.    0.001 0.979 0.021]
 [0.    0.    0.993 0.007]]

Row sums: [1. 1. 1. 1.]
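Softmax is short enough to write by hand. This sketch does that with a helper named `softmax_rows`, a name chosen here for illustration. It subtracts each row's maximum before exponentiating. Since $\exp(s_{ij} - c) / \sum_m \exp(s_{im} - c)$ equals the original ratio for any constant $c$, the result is unchanged, but `exp` no longer overflows on large scores. Library softmax implementations generally apply the same trick internally.

```python
import torch

def softmax_rows(scores: torch.Tensor) -> torch.Tensor:
    # Subtracting each row's max cancels out in the ratio,
    # but keeps exp() from overflowing when scores are large.
    shifted = scores - scores.max(dim=-1, keepdim=True).values
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=-1, keepdim=True)

torch.manual_seed(0)
scores = torch.randn(4, 4) * 5

manual = softmax_rows(scores)
print(torch.allclose(manual, torch.softmax(scores, dim=-1), atol=1e-6))  # True
print(manual.sum(dim=-1))  # each row sums to 1, up to float rounding

# Exponentiating large scores directly overflows float32 to inf...
big = torch.tensor([[1000.0, 999.0, 0.0]])
print(torch.exp(big))
# ...while the shifted version stays finite
print(softmax_rows(big))
```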


---

## 6. The Output: Weighted Sum of Values

Finally, each token gathers information by taking a weighted sum of all the value vectors, using the attention weights.

$$\text{output} = \text{attention} \cdot V$$

Shape: $(n, n) \cdot (n, d_v) = (n, d_v)$

Token $i$'s output is $\sum_j \text{attention}_{ij} \cdot V_j$, where $V_j$ is the $j$-th row of $V$ (token $j$'s value vector).

A blend of all values, weighted by how much token $i$ attended to each position.

In [7]:
# Compute the output
output = attention_weights @ V  # (n, n) @ (n, d_v) = (n, d_v)

print(f"Output shape: {output.shape}")
print(f"\nOutput = ")
print(output.numpy().round(2))

Output shape: torch.Size([4, 8])

Output = 
[[-3.69  0.8   9.47 -2.52 -6.27 -0.84 -3.96 -3.32]
 [-1.78  5.17  3.8   2.56 -3.    1.6   0.38  5.11]
 [-5.22  3.38 -5.24  0.9   3.28 -0.42  3.67 -0.99]
 [-5.21  3.4  -5.28  0.9   3.34 -0.39  3.69 -1.06]]
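The weighted sum $\sum_j \text{attention}_{ij} V_j$ can also be spelled out with loops. This sketch uses made-up random weights and values, not the notebook's, and shows the extreme case as well: a one-hot attention row simply copies one value vector. The output above comes close to that case. Its last two rows put more than 97% of their weight on the third token, so those two output rows are nearly identical.

```python
import torch

torch.manual_seed(0)
n, d_v = 4, 8
attention = torch.softmax(torch.randn(n, n), dim=-1)  # rows sum to 1
V = torch.randn(n, d_v)

# Output row i = sum over j of attention[i, j] * V[j]
output_loop = torch.zeros(n, d_v)
for i in range(n):
    for j in range(n):
        output_loop[i] += attention[i, j] * V[j]

print(torch.allclose(output_loop, attention @ V, atol=1e-5))  # True

# A one-hot attention row just picks out a single value vector
one_hot = torch.tensor([0.0, 0.0, 1.0, 0.0])
print(torch.allclose(one_hot @ V, V[2]))  # True
```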


---

## 7. The Full Equation

Putting it all together, scaled dot-product attention is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

That's it. That's the whole thing.

Let's wrap it in a function:

In [8]:
def scaled_dot_product_attention(Q, K, V):
    """
    Compute scaled dot-product attention.
    
    Args:
        Q: Queries, shape (n, d_k)
        K: Keys, shape (n, d_k)
        V: Values, shape (n, d_v)
    
    Returns:
        output: shape (n, d_v)
        attention_weights: shape (n, n)
    """
    d_k = Q.shape[-1]
    
    # Compute scaled scores
    scores = Q @ K.T / (d_k ** 0.5)
    
    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of values
    output = attention_weights @ V
    
    return output, attention_weights


# Verify it matches what we computed step-by-step
output_check, weights_check = scaled_dot_product_attention(Q, K, V)

print(f"Outputs match: {torch.allclose(output, output_check)}")
print(f"Weights match: {torch.allclose(attention_weights, weights_check)}")

Outputs match: True
Weights match: True
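The function above uses `K.T`, which is meant for 2-D tensors. Here is a minimal sketch of a version that tolerates extra leading dimensions by swapping only the last two axes with `transpose(-2, -1)`. The name `batched_attention` and the leading shape `(2, 3)` are illustrative. As a cross-check, the sketch compares against `torch.nn.functional.scaled_dot_product_attention`, which assumes PyTorch 2.0 or later. With no mask and default arguments, that built-in applies the same $1/\sqrt{d_k}$ scaling.

```python
import torch
import torch.nn.functional as F

def batched_attention(Q, K, V):
    # transpose(-2, -1) swaps only the last two axes, so any leading
    # dimensions pass through untouched (K.T is meant for 2-D tensors).
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    weights = F.softmax(scores, dim=-1)
    return weights @ V

torch.manual_seed(0)
lead, n, d = (2, 3), 4, 8  # two arbitrary leading dimensions
Q = torch.randn(*lead, n, d)
K = torch.randn(*lead, n, d)
V = torch.randn(*lead, n, d)

ours = batched_attention(Q, K, V)
builtin = F.scaled_dot_product_attention(Q, K, V)  # PyTorch 2.0+

print(ours.shape)                                # torch.Size([2, 3, 4, 8])
print(torch.allclose(ours, builtin, atol=1e-5))  # True
```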


---

## What's Next?

We've built **single-head attention**. To get to a full transformer, we still need:

1. **Causal masking** — so tokens can't attend to future positions
2. **Multi-head attention** — running several attention heads in parallel
3. **Feed-forward layer** — the other half of a transformer block
4. **Layer normalization** — keeping activations well-behaved
5. **Positional encoding** — since attention is position-agnostic
6. **The full transformer block** — putting it all together
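The claim that attention is position-agnostic can be checked directly. In this sketch, the helper `attention` and the $1/\sqrt{d}$ scaling of the random weights are illustrative choices. Shuffling the input tokens only shuffles the output rows in the same way. Put technically, self-attention without positional information is permutation-equivariant.

```python
import torch
import torch.nn.functional as F

def attention(X, W_Q, W_K, W_V):
    # Single-head self-attention, as built in this notebook
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    weights = F.softmax(Q @ K.T / Q.shape[-1] ** 0.5, dim=-1)
    return weights @ V

torch.manual_seed(0)
n, d = 4, 8
X = torch.randn(n, d)
W_Q, W_K, W_V = (torch.randn(d, d) / d ** 0.5 for _ in range(3))

perm = torch.tensor([2, 0, 3, 1])  # shuffle the token order

out = attention(X, W_Q, W_K, W_V)
out_shuffled = attention(X[perm], W_Q, W_K, W_V)

# Shuffled inputs give the same output rows, shuffled the same way:
# nothing in the computation knows which position a token came from.
print(torch.allclose(out_shuffled, out[perm], atol=1e-5))  # True
```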

One cell at a time. One concept at a time.

---

*"See one, do one, teach one."*