In [31]:
!pip install torch tiktoken transformers



# Self-Attention with Trainable Weights

In the previous notebook, we computed attention directly from the raw embeddings. The problem? **That mechanism had nothing to learn**: the attention scores were fully determined by the hand-crafted embeddings. In reality, the model needs to **learn** what to pay attention to.

This notebook introduces three **trainable weight matrices** (W_query, W_key, W_value) that project each input into a query, a key, and a value before attention is computed:

```
Query = input @ W_query    "What am I looking for?"
Key   = input @ W_key      "What do I contain?"  
Value = input @ W_value    "What do I provide?"
```

---

## Understanding Q, K, V in Simple Terms

### The Library Analogy

Imagine you're in a library looking for information about "how to bake a cake."

| Concept | Library Analogy | What It Does |
|---------|-----------------|--------------|
| **Query (Q)** | Your question: "I need cake recipes" | What you're **searching for** |
| **Key (K)** | Book titles/labels on the shelf | What each book **advertises** it contains |
| **Value (V)** | The actual content inside books | The **real information** you'll get |

#### How it works step-by-step:

1. **You form a Query** → "I want cake recipes"
2. **You compare your Query against all Keys** → You scan book titles
3. **High match = high attention** → "The Art of Baking" matches well!
4. **You retrieve Values based on matches** → You read from the matching books
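The analogy maps onto a "soft" dictionary lookup. Below is a minimal sketch with made-up 2-dimensional vectors (the book names and all numbers are illustrative assumptions, not data from this notebook): a hard lookup returns only the best-matching book, while attention returns a blend of *all* values, weighted by how well each key matches the query.

```python
import torch

# Hypothetical toy library: each "book" has a 2-dim key (its label) and a 2-dim value (its content)
keys = torch.tensor([[1.0, 0.0],    # "The Art of Baking"
                     [0.0, 1.0],    # "Car Repair Basics"
                     [0.7, 0.3]])   # "Desserts for Beginners"
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [8.0, 2.0]])
query = torch.tensor([1.0, 0.1])    # "I need cake recipes"

# Steps 2-3: compare the query against every key (dot product = match score)
scores = keys @ query                     # shape [3]
weights = torch.softmax(scores, dim=0)    # turn scores into weights that sum to 1

# Hard lookup: take only the single best match
hard_result = values[scores.argmax()]

# Step 4, attention-style: blend all values by their match weights
soft_result = weights @ values            # shape [2]

print("scores: ", scores)
print("weights:", weights)
print("hard:", hard_result, "soft:", soft_result)
```

The soft version is what makes attention trainable: every book contributes a little, so small changes to the query or keys smoothly change the result, which is exactly what gradient descent needs.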

---

### Why Separate Q, K, V? Why Not Just Use the Input Directly?

This is the key insight! Let's think about the words in our sentence:

**Sentence:** `"Your journey starts with one step"`

Consider the word **"journey"** (x²):
- When "journey" is **asking a question** (Query): "What words relate to me? What starts? With what?" → needs to find connections
- When "journey" is **being searched** (Key): "I am a noun, an abstract concept, the main subject" → needs to advertise what it is
- When "journey" **provides information** (Value): "Here's my semantic meaning to contribute" → the actual content to pass forward

**These are THREE DIFFERENT ROLES for the same word!**
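One concrete consequence of separating the roles: if the raw embedding served as both query and key, the score matrix `inputs @ inputs.T` would be symmetric, so "journey" would attend to "starts" exactly as strongly as "starts" attends to "journey". With two *different* matrices for queries and keys (here random and untrained), that constraint disappears. A sketch using the same embeddings as the cells below:

```python
import torch

torch.manual_seed(123)
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],  # Your
   [0.55, 0.87, 0.66],  # journey
   [0.57, 0.85, 0.64],  # starts
   [0.22, 0.58, 0.33],  # with
   [0.77, 0.25, 0.10],  # one
   [0.05, 0.80, 0.55]]  # step
)

# No projections: the same vector plays both the query and the key role
raw_scores = inputs @ inputs.T
print(torch.allclose(raw_scores, raw_scores.T))    # True: score(i, j) == score(j, i)

# Separate (random, untrained) query and key projections
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)
proj_scores = (inputs @ W_query) @ (inputs @ W_key).T
print(torch.allclose(proj_scores, proj_scores.T))  # False: the roles are no longer interchangeable
```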

---

### Visual Example with Our Sentence

```
Input word: "journey" (x²) = [0.55, 0.87, 0.66]

                    ┌──────────────────┐
                    │    "journey"     │
                    │  [0.55, 0.87,    │
                    │       0.66]      │
                    └────────┬─────────┘
                             │
             ┌───────────────┼───────────────┐
             │               │               │
             ▼               ▼               ▼
        ┌─────────┐    ┌─────────┐    ┌─────────┐
        │× W_query│    │× W_key  │    │× W_value│
        └────┬────┘    └────┬────┘    └────┬────┘
             │               │               │
             ▼               ▼               ▼
        ┌─────────┐    ┌─────────┐    ┌─────────┐
        │Query    │    │ Key     │    │Value    │
        │"What am │    │"What I  │    │"What I  │
        │I looking│    │contain" │    │provide" │
        │  for?"  │    │         │    │         │
        │         │    │         │    │         │
        │[0.43,   │    │[0.44,   │    │[0.40,   │
        │ 1.46]   │    │ 1.14]   │    │ 1.00]   │
        └─────────┘    └─────────┘    └─────────┘

Then "journey"'s Query compares against ALL Keys:

        Query₂ @ Keys.T = Attention Scores
        
        "journey" asks: "How relevant is each word to me?"
        
        Your:     1.27  ←─┐
        journey:  1.85  ←─┤
        starts:   1.81  ←─┼── These scores show "journey" 
        with:     1.08  ←─┤   attends most to itself and "starts"
        one:      0.56  ←─┤
        step:     1.54  ←─┘
```

---

### The Math (Simplified)

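```python
import torch

# Placeholder data so the sketch runs: 6 words with 3-dim embeddings,
# and random (untrained) projection matrices of shape [3, 2]
torch.manual_seed(123)
inputs = torch.rand(6, 3)
W_query, W_key, W_value = torch.rand(3, 2), torch.rand(3, 2), torch.rand(3, 2)

# Each word gets transformed THREE different ways:
Q = inputs @ W_query   # Shape: [6, 2] - all 6 words get queries
K = inputs @ W_key     # Shape: [6, 2] - all 6 words get keys
V = inputs @ W_value   # Shape: [6, 2] - all 6 words get values

# Then attention happens:
attention_scores = Q @ K.T      # [6, 6]: "How much does each query match each key?"
d_k = K.shape[-1]               # scaling by sqrt(d_k) is explained in "Scaled Dot-Product Attention"
attention_weights = torch.softmax(attention_scores / d_k**0.5, dim=-1)  # each row sums to 1
output = attention_weights @ V  # [6, 2]: weighted sum of values
```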

---

### Why This is Powerful (The Learning Part)

| Without Q,K,V | With Q,K,V |
|---------------|------------|
| Fixed attention based on raw similarity | **Learnable** attention patterns |
| "journey" always looks the same when asking or answering | "journey" can learn different roles |
| Can't adapt to task | Model learns what's important through training |

**The W matrices are LEARNED during training!** The model discovers:
- What features to look for (W_query)
- What features to advertise (W_key)
- What features to pass forward (W_value)
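To make "learned during training" concrete, here is a minimal sketch of one gradient step. The loss (mean of the squared context-vector entries) is an arbitrary stand-in chosen only to produce gradients; a real model would use a task loss such as next-token cross-entropy. The point is that `.backward()` fills in `.grad` on all three weight matrices, so an optimizer can update them.

```python
import torch

torch.manual_seed(123)
inputs = torch.rand(6, 3)  # placeholder embeddings: 6 tokens, d_in = 3

# Trainable projections (nn.Parameter has requires_grad=True by default)
W_query = torch.nn.Parameter(torch.rand(3, 2))
W_key = torch.nn.Parameter(torch.rand(3, 2))
W_value = torch.nn.Parameter(torch.rand(3, 2))
optimizer = torch.optim.SGD([W_query, W_key, W_value], lr=0.1)

Q, K, V = inputs @ W_query, inputs @ W_key, inputs @ W_value
weights = torch.softmax(Q @ K.T / K.shape[-1] ** 0.5, dim=-1)
context = weights @ V

loss = (context ** 2).mean()   # arbitrary stand-in loss, for illustration only
loss.backward()                # gradients flow back into W_query, W_key and W_value

before = W_value.detach().clone()
optimizer.step()               # one update: W <- W - lr * W.grad

print(W_query.grad.shape)      # torch.Size([3, 2]): one gradient entry per weight
print(torch.equal(before, W_value.detach()))  # False: the weights moved
```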

---

### Super Simple Summary

| | Query | Key | Value |
|---|-------|-----|-------|
| **Question** | "What do I need?" | "What do I have?" | "Here's my content" |
| **Search engine** | Your search terms | Webpage titles | Webpage content |
| **Our example** | "journey" asks: "who relates to me?" | Each word's label: "I'm a verb/noun/etc" | Each word's actual meaning |

The genius is: **the same word plays all three roles**, but with **different learned transformations** for each role!

---

**Why this matters:**
- The model can **learn** different representations for querying vs. being queried
- Attention patterns become **trainable** via backpropagation
- This is the core of how attention works in real transformers like GPT (which add causal masking and multiple attention heads on top)!


## Same Input Embeddings

We use the same hand-crafted embeddings from the previous notebook. Remember: in a real model, these would come from a learned embedding layer.

In [32]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

## Define Dimensions

- `x_2`: The "journey" token embedding (our query token)
- `d_in = 3`: Input dimension (size of each embedding)
- `d_out = 2`: Output dimension (size of Q, K, V vectors)

**Note:** `d_out` can differ from `d_in`! This allows the model to project embeddings into a different dimensional space for attention computation.
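A quick shape check of that note, as a sketch: the same 3-dimensional inputs can be projected to any `d_out` (the values tried below are arbitrary), and the context vectors come out with that size. The attention-weight matrix, by contrast, is always `[num_tokens, num_tokens]`, because it compares tokens with tokens.

```python
import torch

torch.manual_seed(123)
inputs = torch.rand(6, 3)   # placeholder embeddings: 6 tokens, d_in = 3
d_in = inputs.shape[1]

for d_out in (2, 3, 4):
    W_query = torch.rand(d_in, d_out)
    W_key = torch.rand(d_in, d_out)
    W_value = torch.rand(d_in, d_out)
    Q, K, V = inputs @ W_query, inputs @ W_key, inputs @ W_value
    weights = torch.softmax(Q @ K.T / d_out ** 0.5, dim=-1)  # [6, 6] for every d_out
    context = weights @ V                                  # [6, d_out]
    print(d_out, tuple(weights.shape), tuple(context.shape))
```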

In [36]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2
print(x_2)
print(d_in)

tensor([0.5500, 0.8700, 0.6600])
3


## Create the Weight Matrices (W_query, W_key, W_value)

These are the **trainable parameters** that transform inputs into queries, keys, and values.

Each matrix has shape `[d_in, d_out]` = `[3, 2]`:
- Takes a 3-dimensional input embedding
- Produces a 2-dimensional Q, K, or V vector

**Key insight:** These matrices start with **random values** (unlike our hand-crafted embeddings). During training, backpropagation will adjust these weights so the model learns **what to pay attention to**.

`requires_grad=False` is set here only to keep the printed outputs uncluttered for this walkthrough; in real training, the weights need gradients so they can be updated.
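The flag's effect shows up in the results: with `requires_grad=False`, outputs carry no `grad_fn` (nothing is recorded for backprop), while with the default `requires_grad=True`, PyTorch records the operation. A minimal comparison:

```python
import torch

torch.manual_seed(123)
x = torch.tensor([0.55, 0.87, 0.66])  # "journey"

W_frozen = torch.nn.Parameter(torch.rand(3, 2), requires_grad=False)
W_trainable = torch.nn.Parameter(torch.rand(3, 2))  # requires_grad=True by default

out_frozen = x @ W_frozen
out_trainable = x @ W_trainable

print(out_frozen.requires_grad, out_frozen.grad_fn)                    # False None
print(out_trainable.requires_grad, out_trainable.grad_fn is not None)  # True True
```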

In [39]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
print(W_query)
print(W_key)
print(W_value)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


## Compute Query, Key, Value for One Token

Transform "journey" (`x_2`) through each weight matrix:

```
x_2 [1×3] @ W_query [3×2] = query_2 [1×2]
x_2 [1×3] @ W_key   [3×2] = key_2   [1×2]  
x_2 [1×3] @ W_value [3×2] = value_2 [1×2]
```

**Intuition:**
- `query_2`: "What is 'journey' looking for?"
- `key_2`: "What does 'journey' contain that others might want?"
- `value_2`: "What information does 'journey' provide when attended to?"
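It can help to read `x_2 @ W_query` row by row: the query is a weighted combination of the three rows of `W_query`, with the embedding entries of "journey" (0.55, 0.87, 0.66) as the weights. A sketch that checks this against the matrix product; it assumes `torch.manual_seed(123)` reproduces the `W_query` printed above, as it does in the cell that created it.

```python
import torch

torch.manual_seed(123)
W_query = torch.rand(3, 2)              # same first draw as the W_query cell above
x_2 = torch.tensor([0.55, 0.87, 0.66])  # "journey"

query_2 = x_2 @ W_query

# Same result, spelled out: 0.55 * row0 + 0.87 * row1 + 0.66 * row2
manual = x_2[0] * W_query[0] + x_2[1] * W_query[1] + x_2[2] * W_query[2]
print(query_2)                          # tensor([0.4306, 1.4551])
print(torch.allclose(query_2, manual))  # True
```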

In [40]:
query_2 = x_2 @ W_query 
key_2 = x_2 @ W_key 
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


## Compute Keys and Values for ALL Tokens

We need keys and values for every token (so "journey" can compare against all of them).

```
inputs [6×3] @ W_key   [3×2] = keys   [6×2]
inputs [6×3] @ W_value [3×2] = values [6×2]
```

Each token now has its own 2-dimensional key and value vector.
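Projecting all tokens in one matrix product is the same computation as projecting each token on its own: row `i` of `keys` is token `i` times `W_key`. A sketch confirming this, and that row 1 matches the single-token `key_2` from before (weights regenerated with seed 123, drawing `W_query` first to keep the same order):

```python
import torch

torch.manual_seed(123)
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],  # Your
   [0.55, 0.87, 0.66],  # journey
   [0.57, 0.85, 0.64],  # starts
   [0.22, 0.58, 0.33],  # with
   [0.77, 0.25, 0.10],  # one
   [0.05, 0.80, 0.55]]  # step
)
W_query = torch.rand(3, 2)  # drawn first only to keep the RNG order of the cells above
W_key = torch.rand(3, 2)

keys = inputs @ W_key                                  # [6, 2]: all tokens at once
per_token = torch.stack([x @ W_key for x in inputs])  # one token at a time

print(torch.allclose(keys, per_token))  # True
print(keys[1])                          # tensor([0.4433, 1.1419]), same as key_2
```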

In [41]:
keys = inputs @ W_key 
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


## Compute Attention Scores

Now we compute how much "journey" should attend to each token using the **dot product between query and keys**.

First, a single attention score (journey attending to itself):
```
attn_score_22 = query_2 · key_2
```
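Plugging the rounded vectors printed above (`query_2` = [0.4306, 1.4551], `key_2` = [0.4433, 1.1419]) into the dot product by hand reproduces the score up to rounding:

```latex
\omega_{22} = \mathbf{q}^{(2)} \cdot \mathbf{k}^{(2)}
  = 0.4306 \times 0.4433 + 1.4551 \times 1.1419
  \approx 0.1909 + 1.6616
  \approx 1.852
```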

In [42]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


Now compute attention scores for "journey" against ALL tokens at once:

```
query_2 [1×2] @ keys.T [2×6] = attn_scores_2 [1×6]
```

**Compare to previous notebook:** We're doing the same dot product operation, but now using **transformed** Q and K vectors instead of raw embeddings!
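The matrix product simply performs every one of those single dot products at once. A sketch comparing an explicit loop over the keys with `query_2 @ keys.T` (weights regenerated with seed 123 to match the cells above):

```python
import torch

torch.manual_seed(123)
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],  # Your
   [0.55, 0.87, 0.66],  # journey
   [0.57, 0.85, 0.64],  # starts
   [0.22, 0.58, 0.33],  # with
   [0.77, 0.25, 0.10],  # one
   [0.05, 0.80, 0.55]]  # step
)
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)

query_2 = inputs[1] @ W_query
keys = inputs @ W_key

# One dot product per token...
loop_scores = torch.stack([query_2.dot(k) for k in keys])
# ...versus all six at once
matmul_scores = query_2 @ keys.T

print(torch.allclose(loop_scores, matmul_scores))  # True
print(matmul_scores)  # tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
```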

In [43]:
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


## Scaled Dot-Product Attention

**Problem:** When `d_k` (key dimension) is large, dot products can become very large, pushing softmax into regions with tiny gradients.

**Solution:** Scale by `√d_k` before applying softmax:

```
attn_weights = softmax(attn_scores / √d_k)
```

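Here `d_k = 2`, so we divide by `√2 ≈ 1.414`. Why `√d_k`? If the entries of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of `d_k` such products and has variance `d_k`; dividing by `√d_k` brings the variance back to 1, regardless of dimension.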
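A small numerical check of the scaling, assuming query and key entries drawn independently from a standard normal distribution (an idealization; real activations are not distributed exactly like this). With `d_k = 512`, raw dot products have a variance near 512, and dividing by `√d_k` brings it back near 1. The unscaled softmax over a handful of such scores is also much more peaked.

```python
import torch

torch.manual_seed(0)
d_k = 512
n = 10_000

q = torch.randn(n, d_k)   # idealized queries: i.i.d. N(0, 1) entries
k = torch.randn(n, d_k)   # idealized keys

dots = (q * k).sum(dim=-1)                   # n independent dot products
print(dots.var().item())                     # roughly d_k (about 512)
print((dots / d_k ** 0.5).var().item())      # roughly 1

# Effect on softmax over one row of 6 raw scores (each with variance ~d_k)
scores = torch.randn(6, d_k) @ torch.randn(d_k)
print(torch.softmax(scores, dim=-1))                # typically close to one-hot
print(torch.softmax(scores / d_k ** 0.5, dim=-1))   # noticeably smoother
```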

In [44]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


## Compute the Context Vector

Same as before: weighted sum of values using the attention weights.

```
context_vec_2 = attn_weights_2 @ values
```

The output is now 2-dimensional (matching `d_out`) instead of 3-dimensional (the original `d_in`).
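Spelled out, `attn_weights_2 @ values` multiplies each token's value vector by that token's attention weight and adds the results. A sketch that rebuilds the earlier cells (seed 123) and checks that both forms agree and that the weights sum to 1:

```python
import torch

torch.manual_seed(123)
inputs = torch.tensor(
  [[0.43, 0.15, 0.89],  # Your
   [0.55, 0.87, 0.66],  # journey
   [0.57, 0.85, 0.64],  # starts
   [0.22, 0.58, 0.33],  # with
   [0.77, 0.25, 0.10],  # one
   [0.05, 0.80, 0.55]]  # step
)
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)
W_value = torch.rand(3, 2)

query_2 = inputs[1] @ W_query
keys, values = inputs @ W_key, inputs @ W_value
attn_weights_2 = torch.softmax(query_2 @ keys.T / keys.shape[-1] ** 0.5, dim=-1)

# Explicit weighted sum over the six value vectors
manual = sum(w * v for w, v in zip(attn_weights_2, values))
context_vec_2 = attn_weights_2 @ values

print(attn_weights_2.sum())                   # 1.0, up to floating-point rounding
print(torch.allclose(manual, context_vec_2))  # True
print(context_vec_2)                          # tensor([0.3061, 0.8210])
```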

In [45]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


## Wrap It All in a PyTorch Module (v1)

Let's package everything into a reusable `nn.Module` class.

**SelfAttention_v1** uses `nn.Parameter` with raw tensors:
- Weights are initialized with `torch.rand()` (uniform on [0, 1))
- The `@` operator does matrix multiplication
- `grad_fn=<MmBackward0>` in the output shows that PyTorch recorded the computation graph, so gradients can flow back to the weights during backprop!

In [46]:
import torch.nn as nn
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [47]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)
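Row 1 of this output (`[0.3061, 0.8210]`) is the same context vector computed step by step for "journey" earlier. That is expected: with seed 123, the module draws `W_query`, `W_key`, and `W_value` in the same order as the standalone cell. A sketch confirming it by redoing the step-by-step computation with the module's own weights:

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):  # same class as above
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33], [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]]
)
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(3, 2)
out = sa_v1(inputs)

# Step-by-step computation for "journey" (row 1), using the module's weights
with torch.no_grad():
    query_2 = inputs[1] @ sa_v1.W_query
    keys = inputs @ sa_v1.W_key
    values = inputs @ sa_v1.W_value
    weights_2 = torch.softmax(query_2 @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
    context_vec_2 = weights_2 @ values

print(torch.allclose(out[1].detach(), context_vec_2, atol=1e-6))  # True
```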


## Improved Version with nn.Linear (v2)

**SelfAttention_v2** uses `nn.Linear` layers instead of raw `nn.Parameter`:

**Advantages of `nn.Linear`:**
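- Better default weight initialization (a Kaiming-uniform scheme that draws small weights centered on zero, instead of `torch.rand()`'s all-positive [0, 1) values)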
- Optional bias term (`qkv_bias` parameter)
- Cleaner syntax: `self.W_key(x)` instead of `x @ self.W_key`
- More consistent with PyTorch conventions

**Note:** The outputs are different because:
1. Different random seed (789 vs 123)
2. `nn.Linear` uses a different weight initialization scheme
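Two details of `nn.Linear` are easy to check directly. First, its default initialization in current PyTorch draws weights uniformly from (−1/√d_in, 1/√d_in), so with `d_in = 3` every weight lies within about ±0.577. Second, it stores its weight with shape `[d_out, d_in]` and computes `x @ weight.T`, the transpose of the `[d_in, d_out]` layout used in v1. A sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(789)
d_in, d_out = 3, 2
layer = nn.Linear(d_in, d_out, bias=False)

print(layer.weight.shape)  # torch.Size([2, 3]), i.e. [d_out, d_in]

bound = 1 / d_in ** 0.5    # about 0.577 for d_in = 3
print(bool((layer.weight.abs() <= bound + 1e-6).all()))  # True

# Calling the layer is x @ weight.T (plus bias, if enabled)
x = torch.rand(6, d_in)
print(torch.allclose(layer(x), x @ layer.weight.T, atol=1e-6))  # True
```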

In [48]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [49]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
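Because the two classes differ only in how their weights are stored and initialized, copying v2's weights into v1 (transposed, since `nn.Linear` keeps them as `[d_out, d_in]`) should make the outputs match. A sketch of that check, re-declaring both classes so the block runs on its own:

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys, queries, values = x @ self.W_key, x @ self.W_query, x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33], [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]]
)
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(3, 2)
sa_v1 = SelfAttention_v1(3, 2)

# Copy v2's weights into v1, transposing [d_out, d_in] -> [d_in, d_out]
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

print(torch.allclose(sa_v1(inputs), sa_v2(inputs), atol=1e-6))  # True
```

This also confirms that the different outputs of v1 and v2 above come purely from the weights, not from any difference in the attention computation itself.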
