# Causal Attention: Teaching Models to Respect Time

In the previous notebook, we learned about **Self-Attention with trainable weights** — how words can learn to pay attention to each other using Query, Key, and Value transformations.

But there's a **critical problem** for language models: in standard self-attention, every word can see every other word, including words that come **after** it!

## Why is This a Problem?

Imagine you're training a model to predict the next word:

```
"Your journey starts with one ____"
```

If the word "one" can see the word "step" (which comes right after it), the model is **cheating**! It already knows what comes next.

**During training:** The model would learn to just copy future words instead of actually learning language patterns.

**During generation:** There ARE no future words yet — the model needs to generate them one by one!

---

## What is Causal Attention?

**Causal** means "respecting cause and effect" — things in the past can affect the future, but not vice versa.

```
Standard Self-Attention:          Causal (Masked) Attention:
                                  
"Your"    sees: Your, journey,    "Your"    sees: Your ✓
                starts, with,                      
                one, step                          
                                  
"journey" sees: Your, journey,    "journey" sees: Your ✓, journey ✓
                starts, with,     
                one, step         
                                  
"starts"  sees: Your, journey,    "starts"  sees: Your ✓, journey ✓, starts ✓
                starts, with,     
                one, step         
                                  
... and so on                     ... each word only sees itself and previous words
```

---

## The Key Idea: Masking

We'll use a **mask** to hide future tokens. The mask is a triangular matrix:

```
         Your  journey  starts  with  one  step
Your     [1      0        0      0     0    0  ]  ← "Your" only sees itself
journey  [1      1        0      0     0    0  ]  ← "journey" sees Your + itself
starts   [1      1        1      0     0    0  ]  ← "starts" sees first 3 words
with     [1      1        1      1     0    0  ]  ← ... and so on
one      [1      1        1      1     1    0  ]
step     [1      1        1      1     1    1  ]  ← "step" sees everything
```

**1 = can see, 0 = cannot see (masked)**

---

## What We'll Learn in This Notebook

1. **Why masking is necessary** for autoregressive language models
2. **Simple masking approach** — multiply by a triangular mask
3. **The problem with simple masking** — broken probability distributions
4. **Efficient masking** — using `-inf` before softmax
5. **Dropout regularization** — preventing overfitting in attention
6. **Batched processing** — handling multiple sequences at once
7. **Complete CausalAttention class** — production-ready implementation

Let's build this step by step!

---

## Install Dependencies

In [1]:
!pip install torch tiktoken transformers

Collecting tiktoken
  Downloading tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (6.7 kB)
Downloading tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl (1.2 MB)
Installing collected packages: tiktoken
Successfully installed tiktoken-0.12.0


## Step 1: Setting Up Our Input Embeddings

We'll use the same sentence from our previous notebooks:

**"Your journey starts with one step"**

Each word is represented as a 3-dimensional embedding vector. Think of these as coordinates in a "meaning space" where similar words are close together.

```
Word      Position    Embedding Vector         What it might encode
────      ────────    ────────────────         ────────────────────
"Your"       x¹       [0.43, 0.15, 0.89]      possessive, personal
"journey"    x²       [0.55, 0.87, 0.66]      noun, abstract concept
"starts"     x³       [0.57, 0.85, 0.64]      verb, beginning action
"with"       x⁴       [0.22, 0.58, 0.33]      preposition, connector
"one"        x⁵       [0.77, 0.25, 0.10]      number, singular
"step"       x⁶       [0.05, 0.80, 0.55]      noun, concrete action
```

**Note:** In a real model, these embeddings would be learned, not hand-crafted. But using fixed values helps us trace the computations exactly.

In [3]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2
print(x_2)
print(d_in)


tensor([0.5500, 0.8700, 0.6600])
3


## Step 2: Recap — Self-Attention from Previous Notebook

Before we add masking, let's quickly bring in the `SelfAttention_v2` class we built before. This gives us the foundation to build upon.

**Quick Reminder of How Self-Attention Works:**

```
┌─────────────────────────────────────────────────────────────────────┐
│                     SELF-ATTENTION PIPELINE                         │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│   Input Embeddings                                                   │
│         │                                                            │
│         ├──────────────────┬──────────────────┐                     │
│         ▼                  ▼                  ▼                     │
│    ┌─────────┐        ┌─────────┐        ┌─────────┐               │
│    │× W_query│        │ × W_key │        │× W_value│               │
│    └────┬────┘        └────┬────┘        └────┬────┘               │
│         │                  │                  │                     │
│         ▼                  ▼                  ▼                     │
│      Queries             Keys              Values                   │
│         │                  │                  │                     │
│         └────────┬─────────┘                  │                     │
│                  ▼                            │                     │
│           Q × Kᵀ = Attention Scores           │                     │
│                  │                            │                     │
│                  ▼                            │                     │
│      softmax(scores/√d_k) = Weights           │                     │
│                  │                            │                     │
│                  └────────────┬───────────────┘                     │
│                               ▼                                     │
│                    Weights × Values = Context                       │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘
```

The key insight: **Right now, every word attends to every other word.** We need to fix this for language modeling!

In [4]:
import torch.nn as nn
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [5]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out) 

## Step 3: The Problem — All Words See All Words

Let's compute the attention weights to see the issue. We'll manually run through the self-attention steps:

1. **Transform inputs → Queries, Keys** using learned weight matrices
2. **Compute attention scores** = Queries × Keysᵀ (dot products measuring similarity)
3. **Apply softmax** to get attention weights (probabilities that sum to 1)

Watch the attention weights matrix — notice how **every position has non-zero weights for every other position**. This means each word is "looking at" all other words, including future ones!

In [6]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### Understanding the Attention Weights Matrix

Let's visualize what this matrix means:

```
              ──────────── Keys (what each word offers) ────────────
              Your   journey  starts   with    one    step
Queries   Your    [0.19    0.16     0.17    0.16    0.17   0.15]  ← attends to ALL
(what     journey [0.20    0.17     0.17    0.15    0.17   0.15]  ← attends to ALL
each      starts  [0.20    0.17     0.17    0.15    0.17   0.15]  ← attends to ALL
word      with    [0.19    0.17     0.17    0.16    0.17   0.16]  ← attends to ALL
asks)     one     [0.18    0.17     0.17    0.16    0.17   0.16]  ← attends to ALL
          step    [0.19    0.17     0.17    0.15    0.17   0.15]  ← attends to ALL
```

**The Problem:** Look at row 1 ("Your"): it has non-zero attention to "journey", "starts", "with", "one", and "step". But when the model is at "Your" and predicting the next word, it shouldn't know about ANY future words!

**Each row should only have non-zero values up to and including that position.**
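
The rule above can be turned into a small check. This is a minimal sketch: `is_causal` is a hypothetical helper name (not part of PyTorch), and the weights are random stand-ins rather than the values printed above.

```python
import torch

def is_causal(weights: torch.Tensor) -> bool:
    """Return True if no row puts weight on a later (future) column."""
    # triu(..., diagonal=1) keeps only the entries strictly above the diagonal
    future_part = torch.triu(weights, diagonal=1)
    return bool((future_part == 0).all())

# Unmasked softmax weights are all positive, so the check fails
unmasked = torch.softmax(torch.randn(6, 6), dim=-1)
print(is_causal(unmasked))

# Zeroing the upper triangle makes the check pass
print(is_causal(torch.tril(unmasked)))
```

A check like this is handy as a unit test once the masking code is in place.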

---

## Step 4: First Attempt — Simple Masking with Multiplication

Our first approach: create a **lower triangular matrix** of ones and zeros, then **multiply** it with the attention weights to zero out the upper triangle (future positions).

`torch.tril()` = "triangular lower" — gives us a matrix with 1s on and below the diagonal, 0s above.

In [7]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


### Visualizing the Mask

```
         Your  journey  starts  with  one  step
Your     [1      0        0      0     0    0  ]   ← Can only see itself
journey  [1      1        0      0     0    0  ]   ← Can see Your + itself
starts   [1      1        1      0     0    0  ]   ← Can see first 3 words
with     [1      1        1      1     0    0  ]   ← Can see first 4 words
one      [1      1        1      1     1    0  ]   ← Can see first 5 words
step     [1      1        1      1     1    1  ]   ← Can see everything
```

Now let's multiply: `attention_weights × mask`. Wherever the mask is 0, the attention weight becomes 0.

In [8]:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


### The Problem: Rows Don't Sum to 1 Anymore!

Look at the first row: `[0.19, 0.0, 0.0, 0.0, 0.0, 0.0]`

**This sums to only 0.19, not 1.0!**

Attention weights are supposed to be a **probability distribution** — they must sum to 1 so we compute a proper weighted average of values.

```
Before masking:  0.19 + 0.16 + 0.17 + 0.16 + 0.17 + 0.15 = 1.00 ✓
After masking:   0.19 + 0.00 + 0.00 + 0.00 + 0.00 + 0.00 = 0.19 ✗
```

**Simple fix:** Divide each row by its sum to renormalize.

In [9]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


### Success! Now Rows Sum to 1

```
Row 1 (Your):    [1.00, 0.00, 0.00, 0.00, 0.00, 0.00]  sum = 1.00 ✓
Row 2 (journey): [0.55, 0.45, 0.00, 0.00, 0.00, 0.00]  sum = 1.00 ✓
Row 3 (starts):  [0.38, 0.31, 0.31, 0.00, 0.00, 0.00]  sum = 1.00 ✓
...
```
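
Rather than adding up rows by hand, the row sums can be checked programmatically. The sketch below assumes random stand-in weights instead of the notebook's values; the variable names are illustrative only.

```python
import torch

# Random stand-in attention weights (each row sums to 1 before masking)
weights = torch.softmax(torch.randn(6, 6), dim=-1)
mask_simple = torch.tril(torch.ones(6, 6))

masked = weights * mask_simple                             # rows now sum to less than 1
renormalized = masked / masked.sum(dim=-1, keepdim=True)   # divide each row by its sum

row_sums = renormalized.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones(6)))
```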

**This works, but there's a more elegant approach!**

---

## Step 5: The Better Way — Masking with -∞ Before Softmax

Instead of:
1. Apply softmax → get weights
2. Multiply by mask → zero out future
3. Renormalize → divide by row sum

We can do:
1. Set the attention scores to -∞ where we want to mask
2. Apply softmax → **automatically get zeros!**

**Why does this work?**

```
softmax(x) = e^x / Σe^x

When x = -∞:
- e^(-∞) = 0
- So those positions contribute 0 to the numerator AND denominator
- Result: they naturally become 0, and remaining values sum to 1!
```

**Much cleaner!** Two steps instead of three, and mathematically equivalent.
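
To see the effect of -∞ on a tiny example, the sketch below runs softmax over three made-up scores, one of which is masked. The numbers are arbitrary and chosen only for illustration.

```python
import torch

# Two visible scores and one masked (future) position
scores = torch.tensor([1.0, 2.0, float("-inf")])
weights = torch.softmax(scores, dim=-1)
print(weights)

# Softmax over only the two visible scores gives the same non-zero weights
visible_only = torch.softmax(torch.tensor([1.0, 2.0]), dim=-1)
print(visible_only)
```

The masked position gets weight 0, and the remaining weights match a softmax computed as if that position did not exist.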

### Creating the Upper Triangular Mask

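`torch.triu()` = "triangular upper". Called with `diagonal=1`, it keeps only the entries strictly above the diagonal, so on a matrix of ones it gives us 1s exactly at the positions to mask.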

We then use `masked_fill()` to replace those positions with -∞.

In [10]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


### Understanding the Masked Scores

Look at the output:
```
Row 1: [0.29,  -inf,  -inf,  -inf,  -inf,  -inf]  "Your" can only see itself
Row 2: [0.47,  0.17,  -inf,  -inf,  -inf,  -inf]  "journey" sees Your + itself
Row 3: [0.46,  0.17,  0.17,  -inf,  -inf,  -inf]  "starts" sees first 3
...
```

The -∞ values will become 0 after softmax. Now let's apply softmax to get our final attention weights:

In [11]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### Perfect! The Causal Attention Weights

```
              Your   journey  starts   with    one    step
Your         [1.00    0.00     0.00    0.00   0.00   0.00]  ← only self
journey      [0.55    0.45     0.00    0.00   0.00   0.00]  ← past + self
starts       [0.38    0.31     0.31    0.00   0.00   0.00]  ← past + self
with         [0.28    0.25     0.25    0.23   0.00   0.00]  ← past + self
one          [0.22    0.20     0.20    0.19   0.20   0.00]  ← past + self
step         [0.19    0.17     0.17    0.15   0.17   0.15]  ← all (last word)
```

**Notice:**
- Row 1 is `[1, 0, 0, 0, 0, 0]` — "Your" puts 100% attention on itself
- Each subsequent row has more non-zero entries
- Row 6 is identical to standard self-attention (last word sees everything)
- **All rows sum to 1.0** — proper probability distributions!
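
The two masking approaches should give the same weights. The sketch below checks this on random stand-in scores (not the notebook's values), assuming `d_k = 2` to match `d_out` in this notebook.

```python
import torch

torch.manual_seed(0)
n, d_k = 6, 2
scores = torch.randn(n, n)   # random stand-in attention scores

# Method A: softmax, zero out the future, renormalize
weights = torch.softmax(scores / d_k**0.5, dim=-1)
masked = weights * torch.tril(torch.ones(n, n))
method_a = masked / masked.sum(dim=-1, keepdim=True)

# Method B: fill future positions with -inf, then softmax
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
method_b = torch.softmax(
    scores.masked_fill(future, -torch.inf) / d_k**0.5, dim=-1
)

print(torch.allclose(method_a, method_b, atol=1e-6))
```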

---

## Step 6: Adding Dropout for Regularization

**What is Dropout?**

Dropout is a technique to prevent **overfitting** (when a model memorizes training data instead of learning patterns). During training, we randomly "drop" (set to zero) some values with probability `p`.

**Why use dropout on attention weights?**

Without dropout, a model might learn to rely too heavily on specific attention patterns. Dropout forces the model to be robust — it can't depend on any single connection always being there.

**How it works:**

```
Original:  [0.2, 0.3, 0.1, 0.4]
                  ↓
           Randomly zero some (p=0.5)
                  ↓
Dropped:   [0.2, 0.0, 0.1, 0.0]   ← zeros inserted
                  ↓
           Scale up remaining (×2)
                  ↓
Final:     [0.4, 0.0, 0.2, 0.0]   ← keeps expected value same
```

**The 2× scaling:** When we drop 50% of values, we double the remaining ones. This way, the average output stays the same whether dropout is on (training) or off (inference).

Let's see dropout in action:

In [12]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


### Understanding the Dropout Output

We applied 50% dropout to a matrix of 1s:

```
Original: all 1s
After:    2s and 0s randomly placed
```

**Why 2s instead of 1s?** 

With 50% dropout:
- ~Half the values become 0
- The remaining half are scaled by `1/(1-0.5) = 2`
- This keeps the expected value the same!

```
Expected value before: 1.0 × 1.0 = 1.0
Expected value after:  (0.5 × 0) + (0.5 × 2) = 1.0  ✓
```
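
The expected-value argument can also be checked empirically. This is a rough sketch on a large tensor of ones; the size and seed are arbitrary choices.

```python
import torch

torch.manual_seed(0)
dropout = torch.nn.Dropout(p=0.5)   # a freshly created module is in training mode
ones = torch.ones(1000, 1000)
dropped = dropout(ones)

print(dropped.unique())        # only the values 0 and 2 appear
print(dropped.mean().item())   # close to 1.0, not exactly 1.0
```

The mean is only approximately 1 because the mask is random. The scaling preserves the expected value, not the value of any single forward pass.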

Now let's apply dropout to our causal attention weights:

In [13]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


### Dropout Applied to Attention Weights

Notice:
- Some attention weights are now 0 (dropped)
- Others are doubled (scaled up)
- Row 2 is all zeros! Both of "journey"'s non-zero weights were dropped, so its context vector is zero for this forward pass
- This forces the model to not rely too heavily on any single attention pattern

**Important:** Dropout is only active during training! During inference/generation, we disable it so the model uses all learned patterns.
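
In PyTorch, this switch is made with `.train()` and `.eval()`. The sketch below shows it on a standalone `nn.Dropout` layer with made-up weights. Calling `.eval()` on a full model switches all of its submodules in the same way.

```python
import torch

dropout = torch.nn.Dropout(p=0.5)
weights = torch.tensor([[1.0, 0.0], [0.5, 0.5]])   # made-up attention weights

dropout.eval()                                     # inference mode: dropout does nothing
print(torch.equal(dropout(weights), weights))      # True

dropout.train()                                    # training mode: entries are zeroed or doubled
out = dropout(weights)
print(out)
```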

---

## Step 7: Handling Batched Inputs

In practice, we don't process one sentence at a time. We process **batches** of sentences together for efficiency (GPUs are optimized for parallel operations).

**Batched tensor shape:** `[batch_size, sequence_length, embedding_dim]`

```
Single input:     [6, 3]         ← 6 tokens, 3-dim embeddings
Batched input:    [2, 6, 3]      ← 2 sequences, each with 6 tokens, 3-dim embeddings
                   ↑  ↑  ↑
                batch seq embed
```

Let's create a batch by stacking our input twice:

In [14]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


We now have a batch of 2 sequences, each with 6 tokens of 3-dimensional embeddings.
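
With a batch dimension in front, the score computation needs to transpose only the last two dimensions of the keys. The sketch below illustrates the shapes, assuming random stand-in inputs and freshly initialized projection layers (`W_query` and `W_key` here are local names for illustration).

```python
import torch

torch.manual_seed(0)
batch = torch.rand(2, 6, 3)                      # [batch, seq, d_in], random stand-in values
W_query = torch.nn.Linear(3, 2, bias=False)
W_key = torch.nn.Linear(3, 2, bias=False)

queries = W_query(batch)                         # [2, 6, 2]
keys = W_key(batch)                              # [2, 6, 2]

# Swap the seq and d_out dims of keys; the batch dim stays in place
scores = queries @ keys.transpose(1, 2)
print(scores.shape)                              # torch.Size([2, 6, 6])
```

Each sequence in the batch gets its own 6 × 6 score matrix, computed independently of the others.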

---

## Step 8: The Complete CausalAttention Class

Now let's put everything together into a production-ready PyTorch module!

### Key Components:

1. **`__init__`**: Set up the module
   - Create Q, K, V linear layers
   - Create dropout layer
   - Register the causal mask as a **buffer** (not a parameter — it's not learned)

2. **`forward`**: Compute causal attention
   - Transform inputs → Q, K, V
   - Compute attention scores
   - Apply causal mask (set future positions to -∞)
   - Apply softmax
   - Apply dropout
   - Compute weighted sum of values

### Why `register_buffer` for the mask?

```python
self.register_buffer('mask', torch.triu(...))
```

- The mask is a **constant** — it doesn't change during training
- `register_buffer` makes it part of the module's state (saved/loaded with model)
- It moves to GPU automatically when you call `model.to('cuda')`
- But it's NOT a learnable parameter — no gradients computed for it
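
These properties can be inspected directly. The sketch below uses a hypothetical toy module, `MaskHolder`, that exists only for this illustration.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):   # hypothetical toy module, for illustration only
    def __init__(self, context_length):
        super().__init__()
        self.proj = nn.Linear(3, 2)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

m = MaskHolder(4)
param_names = [name for name, _ in m.named_parameters()]
print(param_names)                # ['proj.weight', 'proj.bias'], no 'mask'
print("mask" in m.state_dict())   # True: saved and loaded with the model
print(m.mask.requires_grad)       # False: no gradients are tracked for it
```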

### Handling Variable Sequence Lengths

```python
self.mask.bool()[:num_tokens, :num_tokens]
```

The mask is created for the maximum context length, but we slice it to match the actual sequence length. This allows the same module to handle sequences of different lengths!
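
Slicing works because the top-left corner of a causal mask is itself a causal mask. A minimal sketch, assuming a maximum context length of 6 and a shorter 3-token input:

```python
import torch

context_length = 6
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()

num_tokens = 3
sub_mask = mask[:num_tokens, :num_tokens]   # top-left 3 × 3 corner
print(sub_mask)

# The slice equals a causal mask built directly for 3 tokens
expected = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
print(torch.equal(sub_mask, expected))      # True
```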

In [15]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
           'mask',
           torch.triu(torch.ones(context_length, context_length),
           diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)   
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) 
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

### Understanding the Code Line by Line

```python
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        
        # The three learnable projections
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(dropout)
        
        # Pre-computed causal mask (upper triangular = positions to mask)
        # register_buffer: part of state but not a learnable parameter
        self.register_buffer(
           'mask',
           torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # Unpack batch dimensions
        
        # Step 1: Project to Q, K, V
        keys = self.W_key(x)       # [batch, seq, d_out]
        queries = self.W_query(x)  # [batch, seq, d_out]
        values = self.W_value(x)   # [batch, seq, d_out]

        # Step 2: Compute attention scores
        # Note: transpose(1,2) swaps seq and d_out dims for batched matmul
        attn_scores = queries @ keys.transpose(1, 2)  # [batch, seq, seq]
        
        # Step 3: Apply causal mask (set future to -inf)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        
        # Step 4: Softmax + scaling
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        
        # Step 5: Apply dropout (only during training)
        attn_weights = self.dropout(attn_weights)

        # Step 6: Weighted sum of values
        context_vec = attn_weights @ values  # [batch, seq, d_out]
        return context_vec
```

Now let's test it with our batched input:

In [16]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


### Output Shape Explained

```
Input:  [2, 6, 3]   →   Output: [2, 6, 2]
         ↑  ↑  ↑                 ↑  ↑  ↑
      batch seq d_in          batch seq d_out
```

- **Batch size (2):** Preserved — we still have 2 sequences
- **Sequence length (6):** Preserved — still 6 positions
- **Embedding dimension:** Changed from `d_in=3` to `d_out=2` (the projection dimension)

Each output vector is a **context-aware representation** that only incorporates information from its own position and previous positions (thanks to causal masking)!
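
One way to test this property is to change a future token and confirm that earlier outputs stay the same. The sketch below uses a hypothetical functional helper, `causal_attention`, with random weight matrices and no dropout. It is a simplified single-sequence stand-in, not the class defined above.

```python
import torch

def causal_attention(x, W_q, W_k, W_v):
    """Single-sequence causal attention; x has shape [seq, d_in]."""
    queries, keys, values = x @ W_q, x @ W_k, x @ W_v
    scores = queries @ keys.T
    n = x.shape[0]
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = scores.masked_fill(future, -torch.inf)
    weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
    return weights @ values

torch.manual_seed(0)
d_in, d_out = 3, 2
W_q, W_k, W_v = (torch.rand(d_in, d_out) for _ in range(3))

x = torch.rand(6, d_in)
x_changed = x.clone()
x_changed[-1] = x[-1] + 1.0            # modify only the last token

out = causal_attention(x, W_q, W_k, W_v)
out_changed = causal_attention(x_changed, W_q, W_k, W_v)

# Outputs for the first five positions are unaffected by the last token
print(torch.allclose(out[:-1], out_changed[:-1]))
# The last position sees the modified token, so its output differs
print(torch.allclose(out[-1], out_changed[-1]))
```

Without the mask, every row of the output would change, because every position would attend to the modified token.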

---

## Summary: What We Learned

### The Journey from Self-Attention to Causal Attention

```
┌─────────────────────────────────────────────────────────────────────┐
│                    CAUSAL ATTENTION PIPELINE                        │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│   Input: [batch, seq_len, d_in]                                     │
│              │                                                       │
│              ├─────────────────┬─────────────────┐                  │
│              ▼                 ▼                 ▼                  │
│         ┌────────┐        ┌────────┐        ┌────────┐             │
│         │W_query │        │ W_key  │        │W_value │             │
│         └───┬────┘        └───┬────┘        └───┬────┘             │
│             │                 │                 │                   │
│             ▼                 ▼                 ▼                   │
│          Queries            Keys             Values                 │
│             │                 │                 │                   │
│             └────────┬────────┘                 │                   │
│                      ▼                          │                   │
│               Q × Kᵀ = Scores                   │                   │
│                      │                          │                   │
│                      ▼                          │                   │
│           ┌──────────────────┐                  │                   │
│           │  CAUSAL MASK     │  ← NEW!          │                   │
│           │  (set future     │                  │                   │
│           │   to -∞)         │                  │                   │
│           └────────┬─────────┘                  │                   │
│                    ▼                            │                   │
│          softmax(Scores / √d_k)                 │                   │
│                    │                            │                   │
│                    ▼                            │                   │
│           ┌──────────────────┐                  │                   │
│           │    DROPOUT       │  ← NEW!          │                   │
│           └────────┬─────────┘                  │                   │
│                    │                            │                   │
│                    └──────────────┬─────────────┘                   │
│                                   ▼                                 │
│                           Weights × Values                          │
│                                   │                                 │
│                                   ▼                                 │
│                   Output: [batch, seq_len, d_out]                   │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘
```

### Key Takeaways

| Concept | What We Learned |
|---------|-----------------|
| **Causal Masking** | Prevents tokens from attending to future positions (essential for autoregressive generation) |
| **-∞ Trick** | Setting scores to -∞ before softmax elegantly produces zeros while maintaining proper probabilities |
| **Dropout** | Randomly drops attention connections during training to prevent overfitting |
| **Batching** | Processing multiple sequences together with shape `[batch, seq, embed]` |
| **register_buffer** | Storing non-learnable constants that travel with the model |

### What's Next?

In the next notebook, we'll learn about **Multi-Head Attention** — running multiple attention "heads" in parallel, each learning to focus on different aspects of the input. This is what makes transformers so powerful!

```
Single Head:     One attention pattern
Multi-Head:      Multiple attention patterns combined
                 → Head 1 might focus on syntax
                 → Head 2 might focus on semantics
                 → Head 3 might focus on position
                 → etc.
```