In [13]:
import torch
from torch.nn import functional as F
import math

In [None]:
# Three one-element leaf tensors; requires_grad=True makes autograd record what is done with them
x, w, b = (torch.tensor([value], requires_grad=True) for value in (2.0, 3.0, 1.0))

# Forward computation: a linear function of x, then a square
wx = w * x     # 3 * 2 = 6
y = wx + b     # 6 + 1 = 7
loss = y ** 2  # 7 ** 2 = 49

# grad_fn names the backward node of the operation that produced loss (the power op)
print(loss.grad_fn)


<PowBackward0 object at 0x7fe72c08aa70>


In [12]:
# Backpropagate from loss; afterwards the leaf tensors x, w and b carry their gradients in .grad
loss.backward()

# Chain rule with loss = y**2 and y = w*x + b (so y = 7):
#   d(loss)/dx = 2*y*w = 2*7*3 = 42
#   d(loss)/dw = 2*y*x = 2*7*2 = 28
#   d(loss)/db = 2*y   = 2*7   = 14
for leaf in (x, w, b):
    print(leaf.grad)
print(loss)  # the loss value itself (49) is not changed by backward()

tensor([42.])
tensor([28.])
tensor([14.])
tensor([49.], grad_fn=<PowBackward0>)


In [10]:
w = torch.tensor([3.0, 2.0], requires_grad=True)
squared = w ** 2      # element-wise: [9, 4]
loss = squared.sum()  # 9 + 4 = 13, reduced to a single scalar

# No backward pass has run yet, so there is no gradient: prints None
print(w.grad)

loss.backward()

# d(loss)/dw = 2*w, so this prints tensor([6., 4.])
print(w.grad)

None
tensor([6., 4.])


In [13]:
# Build a GPT model from the project-local minGPT package, starting from its default config.
from mingpt.model import GPT
model_config = GPT.get_default_config()
# NOTE(review): presumably selects the GPT-2 preset sizes (layers/heads/embedding width) — confirm in mingpt.model
model_config.model_type = 'gpt2'
# Number of token IDs; the listing below shows transformer.wte.weight with 50257 rows
model_config.vocab_size = 50257
# Maximum sequence length; the listing below shows transformer.wpe.weight with 512 rows
model_config.block_size = 512
# Constructing the model prints its parameter count (124.05M in the recorded output)
model = GPT(model_config)

number of parameters: 124.05M


In [14]:
# Show every registered parameter: its dotted name, its shape, and whether it receives gradients
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.shape, tensor.requires_grad)

transformer.wte.weight torch.Size([50257, 768]) True
transformer.wpe.weight torch.Size([512, 768]) True
transformer.h.0.ln_1.weight torch.Size([768]) True
transformer.h.0.ln_1.bias torch.Size([768]) True
transformer.h.0.attn.c_attn.weight torch.Size([2304, 768]) True
transformer.h.0.attn.c_attn.bias torch.Size([2304]) True
transformer.h.0.attn.c_proj.weight torch.Size([768, 768]) True
transformer.h.0.attn.c_proj.bias torch.Size([768]) True
transformer.h.0.ln_2.weight torch.Size([768]) True
transformer.h.0.ln_2.bias torch.Size([768]) True
transformer.h.0.mlp.c_fc.weight torch.Size([3072, 768]) True
transformer.h.0.mlp.c_fc.bias torch.Size([3072]) True
transformer.h.0.mlp.c_proj.weight torch.Size([768, 3072]) True
transformer.h.0.mlp.c_proj.bias torch.Size([768]) True
transformer.h.1.ln_1.weight torch.Size([768]) True
transformer.h.1.ln_1.bias torch.Size([768]) True
transformer.h.1.attn.c_attn.weight torch.Size([2304, 768]) True
transformer.h.1.attn.c_attn.bias torch.Size([2304]) True
tr

In [6]:
import torch
import torch.nn.functional as F

# ---- Example 1: inspecting sizes and reshaping with .view() ----
print("=== Example 1: Tensor shapes ===")
x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 2 rows, 3 columns
n_rows, n_cols = x.size(0), x.size(-1)
print(f"Original shape: {x.shape}")  # [2, 3]
print(f"x.size(-1) = {n_cols}")      # 3 (last dimension)
print(f"x.size(0) = {n_rows}")       # 2 (first dimension)

# A single -1 collapses everything into one axis
x_flat = x.view(-1)
print(f"\nx.view(-1) flattens to: {x_flat}")  # [1, 2, 3, 4, 5, 6]
print(f"Shape: {x_flat.shape}")               # [6]

# -1 is inferred as 6 / 3 = 2, which reproduces the original layout
x_rows = x.view(-1, 3)
print(f"\nx.view(-1, 3) reshapes to:\n{x_rows}")
print(f"Shape: {x_rows.shape}")  # [2, 3]

print(x)

# ---- Example 2: the shapes F.cross_entropy expects ----
print("\n=== Example 2: Cross Entropy Requirements ===")

# Stand-in for a model output: one score per vocab entry at every (batch, position)
batch_size, seq_len, vocab_size = 2, 3, 5
logits = torch.randn(batch_size, seq_len, vocab_size)
print(f"logits shape: {logits.shape}")  # [2, 3, 5]

# One correct token index per (batch, position)
targets = torch.tensor([[1, 2, 3], [0, 4, 2]])
print(f"targets shape: {targets.shape}")  # [2, 3]

# cross_entropy takes (N, C) predictions and (N,) class indices,
# so batch and sequence are merged into N = 6 while C = 5 classes stays last
num_classes = logits.size(-1)
logits_flat = logits.view(-1, num_classes)
print(f"\nlogits.view(-1, logits.size(-1)) shape: {logits_flat.shape}")  # [6, 5]

targets_flat = targets.view(-1)
print(f"targets.view(-1) shape: {targets_flat.shape}")  # [6]

loss = F.cross_entropy(logits_flat, targets_flat)
print(f"\nLoss: {loss.item()}")

for shown in (logits, targets, logits_flat, targets_flat):
    print(shown)

# ---- Example 3: skipping positions with ignore_index ----
print("\n=== Example 3: ignore_index ===")

# -1 marks padding positions that must not be scored
targets_with_padding = torch.tensor([[1, 2, -1], [0, -1, 2]])
padded_flat = targets_with_padding.view(-1)

loss_with_ignore = F.cross_entropy(logits_flat, padded_flat, ignore_index=-1)
print(f"Loss (ignoring -1 tokens): {loss_with_ignore.item()}")
print("The -1 positions don't contribute to the loss!")

=== Example 1: Tensor shapes ===
Original shape: torch.Size([2, 3])
x.size(-1) = 3
x.size(0) = 2

x.view(-1) flattens to: tensor([1, 2, 3, 4, 5, 6])
Shape: torch.Size([6])

x.view(-1, 3) reshapes to:
tensor([[1, 2, 3],
        [4, 5, 6]])
Shape: torch.Size([2, 3])
tensor([[1, 2, 3],
        [4, 5, 6]])

=== Example 2: Cross Entropy Requirements ===
logits shape: torch.Size([2, 3, 5])
targets shape: torch.Size([2, 3])

logits.view(-1, logits.size(-1)) shape: torch.Size([6, 5])
targets.view(-1) shape: torch.Size([6])

Loss: 1.7956656217575073
tensor([[[-0.0573, -0.2409,  0.0426, -0.3903, -1.1125],
         [ 1.2622, -0.1017, -0.2237,  0.1274,  1.8104],
         [ 1.9876, -0.2748, -0.8699,  1.4689,  0.1783]],

        [[-1.5124, -1.3039, -0.9486,  0.8336, -0.1846],
         [ 0.0790,  0.5465,  0.7917,  0.4532,  0.3316],
         [ 1.4380, -1.6008,  1.8799, -0.1534, -0.5787]]])
tensor([[1, 2, 3],
        [0, 4, 2]])
tensor([[-0.0573, -0.2409,  0.0426, -0.3903, -1.1125],
        [ 1.2622, -

In [8]:
import torch

# A [2, 3, 5] tensor holding 1..30 in row-major order.
# logits[b, p] is the row of 5 vocab scores for batch item b at sequence position p:
#   batch 0 -> rows 1-5, 6-10, 11-15     batch 1 -> rows 16-20, 21-25, 26-30
logits = torch.arange(1, 31).view(2, 3, 5)
vocab_dim = logits.size(-1)

print(f"Original shape: {logits.shape}")  # [2, 3, 5]
print(f"logits.size(-1) = {vocab_dim}")   # 5

# Merge batch and position into one axis; each row is still one position's vocab scores
flat = logits.view(-1, vocab_dim)
print(f"\nFlattened shape: {flat.shape}")  # [6, 5]
print(f"Flattened tensor:\n{flat}")


Original shape: torch.Size([2, 3, 5])
logits.size(-1) = 5

Flattened shape: torch.Size([6, 5])
Flattened tensor:
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20],
        [21, 22, 23, 24, 25],
        [26, 27, 28, 29, 30]])


In [None]:
x = torch.randn(2, 3, 5)  # 2 * 3 * 5 = 30 elements total

# A -1 entry is inferred from the element count. The first two calls agree on [6, 5];
# the last two give different shapes, so these are NOT all the same result.
view_args = [
    (-1, 5),  # [6, 5]
    (6, -1),  # [6, 5]
    (2, -1),  # [2, 15]
    (-1,),    # [30] - fully flat
]
for args in view_args:
    print(x.view(*args).shape)

In [4]:
x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 2 rows, 3 columns
print(x.size())
# Index 1 on the first axis selects the second row; as the cell's last expression it is displayed
x[1]

torch.Size([2, 3])


tensor([4, 5, 6])

In [16]:
# One prediction over 5 tokens; the right answer is token index 3
logits = torch.tensor([[2.3, -1.5, 0.8, 4.1, -0.2]])
target = torch.tensor([3])
print('logits shape:', logits.shape)
print('target shape:', target.shape)
loss = F.cross_entropy(logits, target)
print(f"Loss: {loss.item()}")

# Reproduce the loss by hand: softmax first, then -log of the correct token's probability
probs = F.softmax(logits, dim=-1)
print('probs: ',probs)
print('probs shape: ', probs.shape)
p_correct = probs[0, 3]
print(f"Probability of correct token (index 3): {p_correct.item():.4f}")
print(f"-log of that probability: {-torch.log(p_correct).item():.4f}")


# Softmax itself, step by step: exponentiate every score, then divide by the total
exp_logits = torch.exp(logits)
print("\nexp(logits):", exp_logits)
for score in (2.3, 4.1):
    print(f"exp({score}) =", math.exp(score))
sum_exp = exp_logits.sum()
print("Sum of all exp:", sum_exp)
probs = exp_logits / sum_exp
print("Probabilities:", probs)


# Same scores with indices 0 and 3 swapped: the model now prefers index 0,
# while the correct token is still index 3
bad_logits = torch.tensor([[4.1, -1.5, 0.8, 2.3, -0.2]])
target = torch.tensor([3])

loss = F.cross_entropy(bad_logits, target)
print(f"\nLoss when model is wrong: {loss.item()}")  # Much higher!

logits shape: torch.Size([1, 5])
target shape: torch.Size([1])
Loss: 0.1983986645936966
probs:  tensor([[0.1356, 0.0030, 0.0302, 0.8200, 0.0111]])
probs shape:  torch.Size([1, 5])
Probability of correct token (index 3): 0.8200
-log of that probability: 0.1984

exp(logits): tensor([[ 9.9742,  0.2231,  2.2255, 60.3403,  0.8187]])
exp(2.3) = 9.974182454814718
exp(4.1) = 60.34028759736195
Sum of all exp: tensor(73.5819)
Probabilities: tensor([[0.1356, 0.0030, 0.0302, 0.8200, 0.0111]])

Loss when model is wrong: 1.9983986616134644


# Complete Summary: Understanding GPT and the Training Pipeline

## Part 1: The Data Pipeline

### From Text to Tokens

**Tokenization** converts text into sequences of integers using Byte Pair Encoding (BPE):
- Text: "Hello world" → Tokens: [15496, 995]
- Each token ID represents a subword unit from a vocabulary of 50,257 possible tokens
- Longer texts produce more tokens (variable length sequences)

### Creating Training Pairs

For next-token prediction, we create input/target pairs:
```python
tokens = [15496, 995, 13, 220, ...]  # Full tokenized sequence

# Create shifted sequences for next-token prediction
x = tokens[:-1]  # Input: all tokens except the last
y = tokens[1:]   # Target: all tokens except the first

# Example alignment:
# x[0] = token_0, y[0] = token_1  (predict token_1 given token_0)
# x[1] = token_1, y[1] = token_2  (predict token_2 given token_1)
# ... and so on
```

**Key insight:** We predict the next token at every position in parallel.

### Batching

The DataLoader stacks multiple sequences together:
- Individual sequence: `[seq_len]`
- Batch of 8 sequences: `[8, seq_len]`
- All sequences in a batch processed in parallel

### Block Size Constraint
```python
block_size = 512  # Maximum sequence length

if len(tokens) > block_size + 1:
    tokens = tokens[:block_size + 1]  # Truncate if too long
```

This is a hard computational limit - the model can only process up to 512 tokens at once.

---

## Part 2: Embeddings - Converting Tokens to Vectors

### Token Embeddings (wte - Word Token Embedding)

**What:** A lookup table that converts token IDs to dense vectors
```python
wte = nn.Embedding(num_embeddings=50257, embedding_dim=768)  # num_embeddings = vocab_size
# Creates a matrix of shape [50257, 768]
# Each row is a learned 768-dimensional vector for one token
```

**How it works:**
```python
idx = torch.tensor([[48, 25, 198]])  # Token IDs, shape [1, 3]
tok_emb = wte(idx)                    # Look up embeddings, shape [1, 3, 768]

# Internally:
# tok_emb[0, 0, :] = wte.weight[48, :]   # Row 48 from the table
# tok_emb[0, 1, :] = wte.weight[25, :]   # Row 25 from the table
# tok_emb[0, 2, :] = wte.weight[198, :]  # Row 198 from the table
```

**Key property:** The same token always gets the same vector from this table, regardless of where it appears in the sequence.

### Positional Embeddings (wpe - Word Position Embedding)

**The problem:** Attention mechanisms are order-blind. Without position information:
- "The dog bit the cat" and "The cat bit the dog" would look identical
- The model needs to know WHERE each token appears

**Solution:** Add position-specific information
```python
wpe = nn.Embedding(num_embeddings=512, embedding_dim=768)  # num_embeddings = block_size
# Creates a matrix of shape [512, 768]
# Each row is a learned 768-dimensional vector for one position

pos = torch.arange(0, seq_len)  # [0, 1, 2, ..., seq_len-1]
pos_emb = wpe(pos)              # Look up position embeddings
```

**Key insight:** Position embeddings are shared across all items in a batch (position 0 is position 0 for everyone).

### Combining Embeddings
```python
x = tok_emb + pos_emb  # Element-wise addition
```

Now each token's vector contains BOTH:
1. **What the token is** (from tok_emb)
2. **Where it appears** (from pos_emb)

Example: "dog" at position 0 gets a different combined vector than "dog" at position 4, even though the token is the same!

### Dropout for Regularization
```python
x = dropout(x)  # Randomly set some values to 0 during training
```

Prevents overfitting by forcing the model not to rely on any single feature.

---

## Part 3: The Forward Pass Through GPT

### Complete Flow
```python
def forward(self, idx, targets=None):
    """Run the GPT forward pass and, when targets are given, compute the next-token loss.

    Args:
        idx: integer token IDs, shape [batch_size, seq_len].
        targets: optional token IDs, shape [batch_size, seq_len]; the token the
            model should predict at each position.

    Returns:
        (logits, loss): logits has shape [batch, seq_len, vocab_size]; loss is the
        mean cross entropy over all positions, or None when targets is None.
    """
    # idx shape: [batch_size, seq_len] - contains token IDs (integers)
    seq_len = idx.size(1)

    # Step 1: Create embeddings
    # Positions 0..seq_len-1, shaped [1, seq_len] so they broadcast over the batch
    pos = torch.arange(seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)
    tok_emb = self.transformer.wte(idx)  # [batch, seq_len, 768]
    pos_emb = self.transformer.wpe(pos)  # [1, seq_len, 768]
    x = self.transformer.drop(tok_emb + pos_emb)  # [batch, seq_len, 768]

    # Step 2: Process through transformer blocks (THE LEARNING HAPPENS HERE)
    for block in self.transformer.h:  # 12 blocks in GPT-2
        x = block(x)  # Each block: attention + feed-forward

    # Step 3: Final processing
    x = self.transformer.ln_f(x)  # Layer normalization for stability
    logits = self.lm_head(x)      # Project to vocabulary size
    # logits shape: [batch, seq_len, vocab_size] (50257 for GPT-2)

    # Step 4: Compute loss if targets provided
    loss = None  # stays None without targets, so the return below is always defined
    if targets is not None:
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),  # Flatten to [batch*seq_len, vocab_size]
            targets.view(-1)                   # Flatten to [batch*seq_len]
        )

    return logits, loss
```

### Understanding Logits

**Logits** are raw, unnormalized scores for each possible next token:
- Shape: `[batch, seq_len, vocab_size]`
- For each position, we have 50,257 scores (one per token in vocabulary)
- Higher score = model thinks that token is more likely to come next

**Not probabilities yet!** Cross entropy will apply softmax internally.

---

## Part 4: Loss Function - Cross Entropy

### Why Cross Entropy?

Language modeling is **multi-class classification** at each position:
- Question: "Which of 50,257 tokens comes next?"
- Answer: One specific token (discrete choice, not continuous)

### How Cross Entropy Works
```python
# For one position:
logits = [2.3, -1.5, 0.8, 4.1, -0.2]  # Raw scores
target = 3  # Correct token is at index 3

# Step 1: Convert logits to probabilities with softmax
probs = softmax(logits)  # [0.14, 0.003, 0.03, 0.82, 0.01]
# Higher logits → higher probabilities (exponentially)

# Step 2: Compute loss
loss = -log(probs[target])  # -log(0.82) = 0.20
```

**Interpretation:**
- Model assigned 82% probability to correct token → low loss (good!)
- If model assigned 10% → high loss (bad!)

### For Multiple Positions
```python
logits_flat = logits.view(-1, 50257)   # [512, 50257] - 512 predictions
targets_flat = targets.view(-1)        # [512] - 512 correct answers

loss = F.cross_entropy(logits_flat, targets_flat)  # Average loss over all 512 positions
```

---

## Part 5: Training Loop Mechanics

### PyTorch's Autograd - The Computational Graph

When you perform operations on tensors with `requires_grad=True`, PyTorch builds a graph:
```python
w = torch.tensor([3.0], requires_grad=True)
x = torch.tensor([2.0])
y = w * x      # y = 6.0
loss = y ** 2  # loss = 36.0

# PyTorch remembers:
# - loss came from y^2
# - y came from w * x
# - Therefore loss depends on w
```

Every tensor has a `.grad_fn` attribute that stores "how was I created?"

### The Training Cycle
```python
# Forward pass
logits, loss = model(x, y)  # Builds computational graph

# Clear old gradients
model.zero_grad()  # Sets all .grad attributes to None/zero

# Backward pass
loss.backward()    # Walks graph backwards, computes ∂loss/∂(every weight)
                   # Stores gradients in weight.grad for each parameter

# Update weights
optimizer.step()   # Plain SGD would do: w = w - learning_rate * w.grad (AdamW uses a more elaborate rule)
```

**Key insight:** Only parameters that were actually used in computing the loss get gradient updates.

### Gradient Clipping
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

If gradients get too large (exploding gradients), scale them down for stability.

### The Optimizer - AdamW
```python
optimizer = torch.optim.AdamW(parameters, lr=learning_rate)
```

AdamW is a sophisticated version of gradient descent that:
- Adapts learning rate per parameter
- Uses momentum (considers past gradients)
- Applies weight decay for regularization

---

## Part 6: Key Concepts Recap

### Iteration vs Epoch

**One iteration:**
- Process one batch
- Compute loss
- Backpropagate
- Update weights once

**One epoch:**
- Process the entire dataset once
- Multiple iterations per epoch (num_iterations = ceil(dataset_size / batch_size))

### Shape Transformations

Understanding tensor shapes is critical:
```python
# Start: [batch, seq_len] - token IDs (integers)
idx = torch.Size([4, 512])

# After token embedding: [batch, seq_len, n_embd]
tok_emb = torch.Size([4, 512, 768])

# After blocks: [batch, seq_len, n_embd] - same shape, refined representations
x = torch.Size([4, 512, 768])

# After lm_head: [batch, seq_len, vocab_size]
logits = torch.Size([4, 512, 50257])

# For loss computation: flatten batch and sequence dimensions
logits_flat = torch.Size([2048, 50257])  # 4 * 512 = 2048 predictions
targets_flat = torch.Size([2048])         # 2048 correct answers
```

### The `view` Operation

Reshaping tensors without changing data:
```python
x = torch.randn(2, 3, 5)  # Shape [2, 3, 5], total 30 elements

x.view(-1)         # [30] - fully flatten
x.view(-1, 5)      # [6, 5] - PyTorch infers first dim: 30/5 = 6
x.view(2, -1)      # [2, 15] - PyTorch infers second dim: 30/2 = 15
x.view(6, 5)       # [6, 5] - explicit reshape
```

The `-1` means "figure this dimension out automatically."

---

## Part 7: What Happens in the Blocks (High-Level)

Before diving into code, understand the purpose:

### Transformer Block Structure
```python
def forward(self, x):
    # Input and output have the same shape: [batch, seq_len, 768]

    # Attention sub-layer: normalize, let tokens communicate, add back onto the input
    attn_out = self.attn(self.ln_1(x))
    x = x + attn_out

    # Feed-forward sub-layer: normalize, process each token's information, add back again
    mlp_out = self.mlpf(self.ln_2(x))
    x = x + mlp_out

    return x
```

### 1. Attention Mechanism

**Purpose:** Let tokens share information with each other

- "The cat sat on the mat"
- Token "sat" can attend to "cat" (who sat?) and "mat" (where?)
- Each token gathers relevant context from other tokens
- This is where GPT builds understanding of relationships

### 2. Feed-Forward Network

**Purpose:** Process the gathered information

- After attention, each token has collected context
- Feed-forward lets each token "think" about what it learned
- Happens independently for each token (no communication here)

### 3. Residual Connections (`x + ...`)

**Purpose:** Help gradients flow during training

- The `+` allows gradients to flow directly backward
- Prevents vanishing gradient problem in deep networks
- Also helps preserve information from earlier layers

### 4. Layer Normalization

**Purpose:** Stabilize training

- Normalizes values to mean=0, std=1
- Prevents values from exploding or vanishing
- Applied before attention and feed-forward

### Stack of 12 Blocks
```python
for block in self.transformer.h:  # Repeat 12 times
    x = block(x)
```

Each pass:
- Refines token representations
- Builds deeper understanding
- Incorporates more context

By block 12, each token has a rich representation incorporating information from the entire sequence.

---

## Summary of the Complete Pipeline

1. **Text → Tokens:** BPE tokenization converts text to integers
2. **Tokens → Embeddings:** Look up learned vectors for tokens and positions
3. **Preprocessing:** Combine embeddings, apply dropout
4. **Transformer Blocks:** 12 layers of attention + feed-forward (THE MAGIC)
5. **Output Projection:** Convert 768-dim vectors to 50,257-dim logit scores
6. **Loss Computation:** Cross entropy measures prediction quality
7. **Backpropagation:** Compute gradients for all parameters
8. **Weight Update:** Optimizer adjusts parameters to reduce loss
9. **Repeat:** Do this for thousands of iterations

**You are here:** About to understand what happens inside those transformer blocks where the intelligence emerges!

In [17]:
# Helper Code: Demonstrating Key Concepts

import torch
import torch.nn as nn
import torch.nn.functional as F


def _banner(title, first=False):
    """Print a section title framed by two 60-character rules.

    Every banner except the first is preceded by a blank line.
    """
    rule = "=" * 60
    print(rule if first else "\n" + rule)
    print(title)
    print(rule)


# ---------------------------------------------------------------- 1
_banner("DEMONSTRATION 1: Embeddings as Lookup Tables", first=True)

# A tiny table: 10 possible tokens, 4 numbers per token
vocab_size, embedding_dim = 10, 4
embedding = nn.Embedding(vocab_size, embedding_dim)
table = embedding.weight

print(f"\nEmbedding table shape: {table.shape}")
print("Embedding table (first 3 rows):")
print(table[:3].detach())

# Token 2 is looked up twice on purpose
token_ids = torch.tensor([2, 5, 2])
embedded = embedding(token_ids)
first_hit, second_hit = embedded[0], embedded[2]

print(f"\nToken IDs: {token_ids}")
print(f"Embedded shape: {embedded.shape}")
print("\nEmbedding for token 2 (first occurrence):")
print(first_hit)
print("Embedding for token 2 (second occurrence):")
print(second_hit)
print("They're identical! Same token = same embedding")

# ---------------------------------------------------------------- 2
_banner("DEMONSTRATION 2: Softmax and Cross Entropy")

# Scores for 5 candidate tokens; index 3 is the correct one
logits = torch.tensor([[2.3, -1.5, 0.8, 4.1, -0.2]])
target = torch.tensor([3])

probs = F.softmax(logits, dim=-1)
print(f"\nLogits: {logits}")
print(f"Probabilities (after softmax): {probs}")
print(f"Sum of probabilities: {probs.sum().item():.4f}")

# Library loss next to the hand-computed -log(probability of the correct token)
loss = F.cross_entropy(logits, target)
target_prob = probs[0, target]
manual_loss = -torch.log(target_prob)

print(f"\nTarget token index: {target.item()}")
print(f"Probability assigned to correct token: {target_prob.item():.4f}")
print(f"Cross entropy loss: {loss.item():.4f}")
print(f"Manual calculation (-log(prob)): {manual_loss.item():.4f}")

# ---------------------------------------------------------------- 3
_banner("DEMONSTRATION 3: Tensor Reshaping with view()")

x = torch.arange(24).reshape(2, 3, 4)
print(f"Original shape: {x.shape}")
print(f"Total elements: {x.numel()}")

# Same 24 elements under four different view() arguments; the first line gets a leading blank line
view_specs = [(-1,), (-1, 4), (2, -1), (6, 4)]
for position, spec in enumerate(view_specs):
    shown_args = ", ".join(str(dim) for dim in spec)
    lead = "\n" if position == 0 else ""
    print(f"{lead}x.view({shown_args}) shape: {x.view(*spec).shape}")

# ---------------------------------------------------------------- 4
_banner("DEMONSTRATION 4: Computational Graph and Gradients")

# loss = (w * x) ** 2, with only w tracked by autograd
w = torch.tensor([3.0], requires_grad=True)
x = torch.tensor([2.0])
y = w * x
loss = y ** 2

print(f"w = {w.item()}")
print(f"x = {x.item()}")
print(f"y = w * x = {y.item()}")
print(f"loss = y^2 = {loss.item()}")

print("\nBefore backward:")
print(f"w.grad = {w.grad}")

# Backpropagation fills in w.grad
loss.backward()

print("\nAfter backward:")
print(f"w.grad = {w.grad.item()}")
print("\nMath check: d(loss)/d(w) = d((w*x)^2)/d(w) = 2*(w*x)*x")
print(f"           = 2 * (3*2) * 2 = {2 * (3*2) * 2}")

# ---------------------------------------------------------------- 5
_banner("DEMONSTRATION 5: Batching and Broadcasting")

batch_size, seq_len, n_embd = 2, 3, 4

tok_emb = torch.randn(batch_size, seq_len, n_embd)
pos_emb = torch.randn(1, seq_len, n_embd)  # leading dim of 1 broadcasts over the batch

print(f"Token embeddings shape: {tok_emb.shape}")
print(f"Position embeddings shape: {pos_emb.shape}")

# The single pos_emb slice is added to every item in the batch
combined = tok_emb + pos_emb

print(f"Combined shape: {combined.shape}")
print("\nPosition embedding is SHARED across batch items!")

_banner("All demonstrations complete!")

DEMONSTRATION 1: Embeddings as Lookup Tables

Embedding table shape: torch.Size([10, 4])
Embedding table (first 3 rows):
tensor([[ 0.5441,  0.4206,  0.3946,  0.1132],
        [ 0.2333,  0.4732,  0.2855,  1.1324],
        [-0.6531,  0.2970, -0.0748, -0.8157]])

Token IDs: tensor([2, 5, 2])
Embedded shape: torch.Size([3, 4])

Embedding for token 2 (first occurrence):
tensor([-0.6531,  0.2970, -0.0748, -0.8157], grad_fn=<SelectBackward0>)
Embedding for token 2 (second occurrence):
tensor([-0.6531,  0.2970, -0.0748, -0.8157], grad_fn=<SelectBackward0>)
They're identical! Same token = same embedding

DEMONSTRATION 2: Softmax and Cross Entropy

Logits: tensor([[ 2.3000, -1.5000,  0.8000,  4.1000, -0.2000]])
Probabilities (after softmax): tensor([[0.1356, 0.0030, 0.0302, 0.8200, 0.0111]])
Sum of probabilities: 1.0000

Target token index: 3
Probability assigned to correct token: 0.8200
Cross entropy loss: 0.1984
Manual calculation (-log(prob)): 0.1984

DEMONSTRATION 3: Tensor Reshaping with vi

# Complete Summary: Deep Dive into Transformer Architecture and Modern Improvements

## Part 8: Inside the Transformer Block - Where Intelligence Emerges

### The Block Structure

Each Transformer block performs two main operations with residual connections:

**Residual Connections:**
- Pattern: `output = input + transformation(input)`
- Purpose: Allows gradients to flow directly backward through the `+` operation
- Effect: Prevents vanishing gradients in deep networks (12 blocks)
- Intuition: We don't replace the representation, we **refine** it incrementally

**Layer Normalization:**
- Applied before each sub-layer (pre-norm architecture)
- Normalizes activations to have mean=0, std=1
- Stabilizes training by preventing activations from exploding or vanishing
- Includes learnable scale (gamma) and shift (beta) parameters

---

## Part 9: The Attention Mechanism - How Tokens Communicate

### Core Purpose

Attention allows each token to:
1. Determine which other tokens are relevant to it
2. Gather information from those relevant tokens
3. Update its representation based on the gathered context

### The Three Projections: Query, Key, Value

From input embeddings `x` of shape `[batch, seq_len, n_embd]`, we create:

**Query (Q):** Represents "what information is this token seeking?"
**Key (K):** Represents "what information does this token advertise?"
**Value (V):** Represents "the actual information to transfer"

All three come from the same input, transformed by different learned weight matrices.

### Mathematical Formulation

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

**Breaking down the computation:**

1. **Similarity Scores:** $QK^T$ computes dot products between all query-key pairs
   - Shape: `[batch, seq_len, seq_len]`
   - Entry `[i, j]` = how much token `i` should attend to token `j`

2. **Scaling:** Divide by $\sqrt{d_k}$ (square root of key dimension)
   - Prevents dot products from getting too large
   - Keeps softmax gradients stable

3. **Attention Weights:** Apply softmax to convert scores to probabilities
   - Each row sums to 1
   - Represents distribution of attention across all tokens

4. **Weighted Aggregation:** Multiply attention weights by values
   - Each token's output = weighted average of all value vectors
   - Weights determined by attention scores

### Causal Masking

For autoregressive language modeling, we must prevent tokens from attending to future positions:

**Mechanism:** Set attention scores to $-\infty$ for all positions $j > i$

After softmax, $e^{-\infty} = 0$, effectively blocking future information.

**Implementation:** Lower triangular mask matrix

$$\text{mask} = \begin{bmatrix}
1 & 0 & 0 & 0 \\
1 & 1 & 0 & 0 \\
1 & 1 & 1 & 0 \\
1 & 1 & 1 & 1
\end{bmatrix}$$

Where 0 positions become $-\infty$ in attention scores.

---

## Part 10: Multi-Head Attention - Parallel Attention Patterns

### Why Multiple Heads?

Single attention lets the model learn one type of relationship. Multiple heads allow the model to simultaneously learn different types of patterns:
- Head 1 might learn subject-verb relationships
- Head 2 might learn adjective-noun relationships
- Head 3 might learn long-range dependencies
- etc.

### Implementation via Tensor Reshaping

Instead of running separate attention mechanisms, we split dimensions:

**Process:**
1. Split embedding dimension into `n_head` groups
   - Original: `[batch, seq_len, 768]`
   - Reshaped: `[batch, seq_len, n_head, head_dim]` where `head_dim = 768/12 = 64`

2. Transpose to make heads the "batch-like" dimension
   - Result: `[batch, n_head, seq_len, head_dim]`

3. Compute attention independently for each head
   - All heads processed in parallel via batched matrix operations

4. Concatenate head outputs back together
   - Transpose: `[batch, seq_len, n_head, head_dim]`
   - Flatten: `[batch, seq_len, 768]`

**Key insight:** This is data parallelism through tensor dimensions, not separate sequential operations. One matrix multiplication computes attention for all heads simultaneously.

---

## Part 11: The Feed-Forward Network (MLP)

### Structure

After attention, each token independently processes its updated representation:

**Standard architecture:**
```
x → Linear(768 → 3072) → GELU → Linear(3072 → 768) → Dropout
```

**Expansion ratio:** 4× the embedding dimension (768 → 3072)

### Purpose

While attention mixes information **between** tokens, the feed-forward network processes information **within** each token:
- Applies non-linear transformations
- Provides computational capacity for complex pattern recognition
- Operates independently on each position (no cross-token communication)

### GELU Activation Function

**Formula:**
$$\text{GELU}(x) = 0.5 \cdot x \cdot \left(1 + \tanh\left(\sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3\right)\right)\right)$$

**Properties:**
- Smooth approximation of ReLU
- Allows small negative values through (unlike ReLU's hard cutoff at 0)
- Better gradient flow during backpropagation
- Prevents "dying neuron" problem

**Why better than ReLU?**
- ReLU: $f(x) = 0$ for $x < 0$ → gradient is 0 → neuron can't recover
- GELU: $f(x) \approx \text{small negative}$ for $x < 0$ → gradient exists → neuron can learn

---

## Part 12: Modern Architectural Improvements

### 1. RMSNorm - Efficient Normalization

**LayerNorm performs two operations:**
1. **Centering:** Subtract mean
2. **Scaling:** Divide by standard deviation

**RMSNorm simplifies to one operation:**
- Skip mean subtraction (centering)
- Only normalize by Root Mean Square (RMS)

**Formula:**

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^n x_i^2 + \epsilon}} \cdot \gamma$$

Where:
- $\gamma$ is a learnable scaling parameter
- $\epsilon$ is a small constant for numerical stability (typically $10^{-8}$)

**Benefits:**
- 7-64% faster than LayerNorm (fewer operations)
- Lower memory usage
- Empirically performs as well or better
- Used in modern LLMs (e.g., LLaMA, Mistral, Gemma); GPT-2 and GPT-3 use standard LayerNorm

**Trade-off:** Slightly less theoretical grounding, but strong empirical results

---

### 2. SwiGLU - Gated Feed-Forward Networks

**Standard MLP:**
```
x → Linear → Activation → Linear
```

**SwiGLU introduces gating:**
```
         ┌→ Linear1 → SiLU ──┐
x → split                     × → Linear3 → output
         └→ Linear2 ─────────┘
```

**Mathematical formulation:**

$$\text{SwiGLU}(x) = (\text{SiLU}(xW_1) \odot xW_2)W_3$$

Where:
- $W_1, W_2$: Parallel linear transformations (768 → 3072)
- $\text{SiLU}(x) = x \cdot \sigma(x)$: Sigmoid-weighted Linear Unit
- $\odot$: Element-wise multiplication (gating)
- $W_3$: Output projection (3072 → 768)

**Key mechanism:**
- One path produces gate values (via SiLU activation)
- Other path produces signal values
- Element-wise multiplication allows adaptive feature selection

**Benefits:**
- Network learns which features to pass through based on context
- Better gradient flow than standard activations
- Improved performance in models like PaLM and LLaMA

**Parameter increase:** Uses 3 linear layers instead of 2, increasing parameters by ~50% for the MLP

---

### 3. RoPE - Rotary Position Embeddings

**Traditional approach:**
- Learn absolute position embeddings
- Add them to token embeddings at the start
- Position info "baked in" to all subsequent operations

**RoPE's innovation:**
- No learned position embeddings
- Apply position information as **rotation** to Q and K vectors
- Rotation happens inside attention mechanism

### The Rotation Concept

**2D Rotation Matrix:**

$$R(\theta) = \begin{bmatrix}
\cos(\theta) & -\sin(\theta) \\
\sin(\theta) & \cos(\theta)
\end{bmatrix}$$

**Properties:**
- Preserves vector magnitude: $||R(\theta)v|| = ||v||$
- Changes direction by angle $\theta$
- Rotation is a multiplicative operation (matrix multiplication)

**Key insight for attention:**

If we rotate query at position $m$ by angle $m\theta$ and key at position $n$ by angle $n\theta$:

$$Q_m^T K_n = (R_{m\theta} q)^T (R_{n\theta} k) = q^T R_{(n-m)\theta} k$$

The dot product naturally depends on $(n-m)$ - the **relative position difference**!

### RoPE Implementation

**Frequency bands:** Different dimension pairs rotate at different speeds

$$\theta_i = \frac{1}{10000^{2i/d}}, \qquad \text{rotation angle at position } m = m\,\theta_i$$

Where:
- $m$ is the position index
- $i$ is the dimension pair index
- $d$ is the head dimension
- Base of 10000 creates frequency spectrum

**Rotation applied to dimension pairs:**

For each pair of dimensions $(x_1, x_2)$ at position $m$:

$$\begin{bmatrix}
x_1' \\
x_2'
\end{bmatrix} = \begin{bmatrix}
\cos(m\theta) & -\sin(m\theta) \\
\sin(m\theta) & \cos(m\theta)
\end{bmatrix} \begin{bmatrix}
x_1 \\
x_2
\end{bmatrix}$$

Simplifies to:
- $x_1' = x_1 \cos(m\theta) - x_2 \sin(m\theta)$
- $x_2' = x_1 \sin(m\theta) + x_2 \cos(m\theta)$

**Benefits:**
- Naturally encodes relative positions
- No learned parameters for position
- Better extrapolation to longer sequences than seen in training
- Used in modern LLMs (LLaMA, GPT-NeoX)

**Parameter reduction:** Removes learned positional embedding table (~0.4M parameters saved)

---

### 4. Learning Rate Warmup

**The problem:** At initialization, parameters are random and far from optimal. Large learning rates can cause:
- Exploding gradients
- Numerical instability
- Divergent training

**Solution:** Start with small learning rate, gradually increase to target value

**Linear warmup formula:**

$$\text{lr}_t = \text{lr}_{\text{base}} \cdot \min\left(1, \frac{t}{T_{\text{warmup}}}\right)$$

Where:
- $t$ is the current step
- $T_{\text{warmup}}$ is the warmup period (e.g., 1000 steps)
- Learning rate scales linearly from 0 to $\text{lr}_{\text{base}}$

**Benefits:**
- Prevents early training instability
- Allows use of higher base learning rates
- Smoother convergence

---

### 5. Cosine Learning Rate Decay

**The problem:** After warmup, maintaining constant learning rate can cause:
- Oscillation around optimal values
- Slower fine-grained convergence

**Solution:** Gradually decrease learning rate following a cosine curve

**Cosine annealing formula:**

$$\text{lr}_t = \text{lr}_{\text{min}} + \frac{1}{2}(\text{lr}_{\text{max}} - \text{lr}_{\text{min}})\left(1 + \cos\left(\frac{\pi t}{T_{\text{max}} - T_{\text{warmup}}}\right)\right)$$

Where:
- $t$ is steps since warmup ended
- $T_{\text{max}}$ is total training steps
- Learning rate smoothly decreases from $\text{lr}_{\text{max}}$ to $\text{lr}_{\text{min}}$

**Combined warmup + cosine schedule:**
1. Steps 0 to $T_{\text{warmup}}$: Linear increase
2. Steps $T_{\text{warmup}}$ to $T_{\text{max}}$: Cosine decay

**Benefits:**
- Smooth transitions (no sudden drops)
- Model takes smaller steps near convergence
- Better final performance
- Standard in modern transformer training

---

## Part 13: Complete Forward Pass with All Improvements

### Data Flow Through Modern Transformer

**Input Processing:**
1. Token IDs → Token embeddings (lookup table)
2. If using RoPE: Skip positional embeddings
3. If not using RoPE: Add learned positional embeddings
4. Apply dropout for regularization

**Through Each Block (×12):**
1. **Layer Normalization** (RMSNorm or LayerNorm)
   - Normalize before attention

2. **Multi-Head Attention**
   - Compute Q, K, V projections
   - If using RoPE: Apply rotary embeddings to Q and K
   - Compute attention scores with causal masking
   - Weighted aggregation of values
   - Output projection

3. **Residual Connection**
   - Add attention output to input

4. **Layer Normalization** (RMSNorm or LayerNorm)
   - Normalize before feed-forward

5. **Feed-Forward Network**
   - If using SwiGLU: Gated feed-forward with 3 linear layers
   - If not: Standard MLP with 2 linear layers + GELU
   - Apply dropout

6. **Residual Connection**
   - Add MLP output to input

**Output Processing:**
1. Final layer normalization
2. Linear projection to vocabulary size (768 → 50,257)
3. Output logits for each token position

**Training:**
1. Compute cross-entropy loss between logits and targets
2. Backpropagate gradients through entire network
3. Update weights using AdamW optimizer
4. Step learning rate scheduler (warmup + cosine decay)

---

## Part 14: Tensor Shape Tracking Example

### Concrete Example: Batch=2, SeqLen=512, Embedding=768, Heads=12

**Input:**
```
Token IDs: [2, 512] (integers)
```

**After Embeddings:**
```
Token embeddings: [2, 512, 768]
Position embeddings (if not RoPE): [1, 512, 768]
Combined: [2, 512, 768]
```

**Inside Attention:**
```
Q, K, V after projection: [2, 512, 768] each

After reshaping for multi-head:
Q: [2, 512, 12, 64] → transpose → [2, 12, 512, 64]
K: [2, 512, 12, 64] → transpose → [2, 12, 512, 64]
V: [2, 512, 12, 64] → transpose → [2, 12, 512, 64]

If using RoPE:
  cos, sin: [512, 32] (half of head dimension)
  Applied element-wise to rotated Q and K pairs

Attention scores: Q @ K^T
  [2, 12, 512, 64] @ [2, 12, 64, 512] = [2, 12, 512, 512]

After causal mask and softmax: [2, 12, 512, 512]

Attention output: scores @ V
  [2, 12, 512, 512] @ [2, 12, 512, 64] = [2, 12, 512, 64]

Concatenate heads:
  transpose → [2, 512, 12, 64]
  reshape → [2, 512, 768]
```

**Through Feed-Forward:**
```
If using SwiGLU:
  Branch 1: [2, 512, 768] → Linear → [2, 512, 3072] → SiLU
  Branch 2: [2, 512, 768] → Linear → [2, 512, 3072]
  Gate: element-wise multiply → [2, 512, 3072]
  Output projection: [2, 512, 3072] → Linear → [2, 512, 768]

If using standard MLP:
  [2, 512, 768] → Linear → [2, 512, 3072] → GELU → Linear → [2, 512, 768]
```

**Final Output:**
```
After 12 blocks: [2, 512, 768]
After final norm: [2, 512, 768]
After lm_head projection: [2, 512, 50257]
```

---

## Part 15: Parameter Count Analysis

### Baseline GPT-2 (124M parameters):

**Embeddings:**
- Token embeddings: 50,257 × 768 = 38.6M
- Position embeddings: 512 × 768 = 0.4M

**Per Transformer Block (×12):**
- Attention Q,K,V projection: 768 × (3 × 768) = 1.8M
- Attention output projection: 768 × 768 = 0.6M
- MLP expand: 768 × 3072 = 2.4M
- MLP contract: 3072 × 768 = 2.4M
- LayerNorm (×2): ~3K parameters (negligible)
- **Total per block: ~7.2M**
- **12 blocks: ~86M**

**Output:**
- Final LayerNorm: ~1.5K
- Language model head: shares weights with token embeddings

**Total: ~124M parameters**

### With All Modifications (152M parameters):

**Changes:**
- Remove position embeddings: -0.4M
- Replace LayerNorm with RMSNorm: No parameter change
- Replace MLP with SwiGLU: +~2.4M per block (third 768 × 3072 linear layer), i.e. ~28M over 12 blocks
- **Net increase: ~28M parameters (mostly from SwiGLU)**

---

## Part 16: Training Dynamics Summary

### One Training Iteration:

1. **Sample batch** from dataset (e.g., batch_size=8, seq_len=512)

2. **Forward pass:**
   - Compute token + position embeddings
   - Pass through 12 transformer blocks
   - Project to vocabulary logits
   - Time complexity: $O(n \cdot d^2 + n^2 \cdot d)$ per block
     - $n$ = sequence length
     - $d$ = model dimension
     - Attention: $O(n^2 \cdot d)$ (quadratic in sequence length!)
     - Feed-forward: $O(n \cdot d^2)$ (linear in sequence length)

3. **Compute loss:**
   - Cross-entropy between logits and targets
   - Average over batch and sequence positions

4. **Backward pass:**
   - Compute gradients via automatic differentiation
   - Gradients flow through residual connections
   - Clip gradient norms to prevent explosions

5. **Update parameters:**
   - AdamW optimizer adjusts all 124-152M parameters
   - Learning rate determined by scheduler
   - Weight decay applied to specific parameter groups

6. **Step scheduler:**
   - During warmup: increase learning rate
   - After warmup: decrease via cosine schedule

**Typical training:** Repeat for millions of iterations over billions of tokens

---

## Key Takeaways

### What Makes Transformers Work:

1. **Attention mechanism:** Allows flexible, context-dependent information routing
2. **Residual connections:** Enable training of very deep networks (12+ layers)
3. **Layer normalization:** Stabilizes training dynamics
4. **Massive scale:** Billions of parameters trained on trillions of tokens

### Modern Improvements Philosophy:

1. **Efficiency:** RMSNorm reduces computation without hurting performance
2. **Expressiveness:** SwiGLU provides richer feature transformations
3. **Generalization:** RoPE improves extrapolation to longer sequences
4. **Stability:** Learning rate schedules prevent training collapse

### The Engineering Reality:

- Many architectural choices are **empirically validated**, not theoretically proven
- Small changes can have significant impacts (e.g., pre-norm vs post-norm)
- Scaling laws emerge: bigger models + more data = better performance
- Implementation details matter enormously for training stability

---

## Final Architecture Comparison

### Baseline GPT-2:
- Learned positional embeddings
- LayerNorm
- Standard MLP with GELU
- Constant learning rate (or simple decay)
- **124M parameters**

### Modern Improved Version:
- Rotary positional embeddings (RoPE)
- RMSNorm
- SwiGLU gated feed-forward
- Warmup + cosine decay schedule
- **152M parameters**
- ~7% faster per step (RMSNorm savings)
- Better performance on downstream tasks
- Better extrapolation to longer sequences

In [1]:
# Code Demonstrations: Understanding Transformer Components

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

print("=" * 70)
print("DEMONSTRATION 1: Attention Score Computation")
print("=" * 70)

# Toy setting: one sequence of three tokens with 4-dimensional vectors
B, T, C = 1, 3, 4

# Hand-picked queries; row t of the inner matrix is the query of token t
q = torch.tensor([[
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
]])

# Hand-picked keys; row t of the inner matrix is the key of token t
k = torch.tensor([[
    [1.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
]])

# Scaled dot product: every query against every key, divided by sqrt(C)
att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(C)
print("\nAttention scores (before masking):")
print(att[0])
print("\nInterpretation:")
print("  Row i, Col j = similarity between token i's query and token j's key")

# Lower-triangular ones: entries above the diagonal (future positions) are 0
# and get replaced by -inf so softmax assigns them zero weight
mask = torch.ones(T, T).tril()
att_masked = att.masked_fill(mask == 0, float('-inf'))
print("\nAfter causal masking:")
print(att_masked[0])

# Normalize each row over the key axis
att_weights = att_masked.softmax(dim=-1)
print("\nAttention weights (after softmax):")
print(att_weights[0])
print("\nNote: Each row sums to 1.0, representing probability distribution")

print("\n" + "=" * 70)
print("DEMONSTRATION 2: Multi-Head Attention Reshaping")
print("=" * 70)

# Two sequences of four tokens, 8 features each, to be shared by two heads
B, T, C = 2, 4, 8
n_head = 2

x = torch.randn(B, T, C)
print(f"\nOriginal shape: {x.shape}")

# Split the feature axis into (heads, per-head features); -1 resolves to C // n_head
x_multihead = x.view(B, T, n_head, -1)
print(f"After view: {x_multihead.shape} [batch, seq, heads, head_dim]")

# Swap the sequence and head axes so heads sit right after the batch axis
x_transposed = x_multihead.permute(0, 2, 1, 3)
print(f"After transpose: {x_transposed.shape} [batch, heads, seq, head_dim]")

print("\nNow we can process each head in parallel via batched operations!")

print("\n" + "=" * 70)
print("DEMONSTRATION 3: RMSNorm vs LayerNorm")
print("=" * 70)

class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is divided by ``sqrt(mean(x**2) + eps)``
    and then multiplied element-wise by the learnable ``weight``. Unlike
    LayerNorm there is no mean subtraction and no bias term.

    Args:
        dim: Size of the last (normalized) dimension; ``weight`` has this shape.
        eps: Constant added to the mean square before the square root so the
            division stays finite when the input is all zeros.
    """

    def __init__(self, dim, eps=1e-8):
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale, initialised to ones
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` by its RMS along the last dimension and rescale.

        Args:
            x: Floating-point tensor whose last dimension has size ``dim``.

        Returns:
            Tensor of the same shape as ``x``; the dtype is whatever
            ``x.dtype`` and ``weight.dtype`` promote to.
        """
        # In float16 the default eps (1e-8) rounds to zero and x**2 overflows
        # for |x| above ~255, which yields NaN/0 outputs. Do the statistics in
        # float32 for half-precision inputs; float32/float64 inputs are used
        # as-is so their results are unchanged.
        needs_upcast = x.dtype in (torch.float16, torch.bfloat16)
        x_calc = x.float() if needs_upcast else x
        # Calculate RMS along last dimension
        rms = torch.sqrt(torch.mean(x_calc ** 2, dim=-1, keepdim=True) + self.eps)
        # Normalize, then return to the caller's dtype
        x_norm = (x_calc / rms).to(x.dtype)
        # Apply learned scaling
        return x_norm * self.weight

# Run the same 1x4 input through both normalization layers
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

# Freshly constructed modules with their default parameters
layernorm, rmsnorm = nn.LayerNorm(4), RMSNorm(4)

ln_out, rms_out = layernorm(x), rmsnorm(x)

print("\nInput:", x)
print("LayerNorm output:", ln_out)
print("RMSNorm output:", rms_out)

print("\nLayerNorm operations: subtract mean, divide by std")
print("RMSNorm operations: divide by RMS (skips mean subtraction)")

print("\n" + "=" * 70)
print("DEMONSTRATION 4: SwiGLU Gating Mechanism")
print("=" * 70)

# One 4-dimensional input row to push through both branches
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

# The two bias-free projections that run side by side
w1 = nn.Linear(4, 4, bias=False)
w2 = nn.Linear(4, 4, bias=False)

# Replace the random init with fixed matrices so the printed numbers are easy
# to follow; no_grad keeps the in-place copies out of autograd
with torch.no_grad():
    w2.weight.copy_(torch.eye(4))       # branch 2: identity
    w1.weight.copy_(2 * torch.eye(4))   # branch 1: doubles the input

hidden1, hidden2 = w1(x), w2(x)

print("\nInput:", x)
print("Path 1 (will be gated):", hidden1)
print("Path 2 (signal):", hidden2)

# SiLU written out by hand on branch 1: value times its own sigmoid
gate = hidden1 * torch.sigmoid(hidden1)
print("\nGate (after SiLU):", gate)

# Combine the branches feature by feature
output = gate * hidden2
print("Final output (gate × signal):", output)

print("\nNotice: Gate values control how much signal passes through")

print("\n" + "=" * 70)
print("DEMONSTRATION 5: 2D Rotation for RoPE")
print("=" * 70)

def rotate_2d(x, y, theta):
    """Rotate the point (x, y) about the origin by ``theta`` radians.

    Applies the 2D rotation matrix [[cos, -sin], [sin, cos]] and returns the
    rotated coordinates as an ``(x_new, y_new)`` tuple.
    """
    c, s = math.cos(theta), math.sin(theta)
    return x * c - y * s, x * s + y * c

# Unit vector along the x-axis is the point being rotated
x, y = 1.0, 0.0
print(f"\nOriginal point: ({x}, {y})")

# The starting length does not depend on the angle, so compute it once
original_mag = math.sqrt(x**2 + y**2)

# Try an eighth, a quarter and a half turn
for angle_deg in (45, 90, 180):
    angle_rad = math.radians(angle_deg)
    x_rot, y_rot = rotate_2d(x, y, angle_rad)
    # Length of the rotated point, printed next to the original length
    rotated_mag = math.sqrt(x_rot**2 + y_rot**2)
    print(f"After {angle_deg}° rotation: ({x_rot:.3f}, {y_rot:.3f})")
    print(f"  Magnitude: {original_mag:.3f} → {rotated_mag:.3f} (preserved!)")

print("\nKey insight: Rotation changes direction but preserves magnitude")

print("\n" + "=" * 70)
print("DEMONSTRATION 6: RoPE Relative Position Encoding")
print("=" * 70)

# A single dimension pair is enough to show the idea
dim = 2
base_freq = 10000

# One inverse frequency per dimension pair: base_freq ** -(2i / dim)
exponents = torch.arange(0, dim, 2).float() / dim
inv_freq = 1.0 / (base_freq ** exponents)
print(f"\nInverse frequency: {inv_freq}")

# Angle table: row = position 0..4, column = dimension pair
positions = torch.arange(5)
freqs = torch.outer(positions, inv_freq)

print("\nRotation angles (position × frequency):")
print(freqs)

# Compare the rotation angles of a query at position 2 and a key at position 4;
# only the angles are computed here, not an actual q·k product
pos_i, pos_j = 2, 4

angle_i, angle_j = freqs[pos_i, 0], freqs[pos_j, 0]
relative_angle = angle_j - angle_i

print(f"\nToken at position {pos_i}: rotation angle = {angle_i:.4f}")
print(f"Token at position {pos_j}: rotation angle = {angle_j:.4f}")
print(f"Relative angle (encodes distance {pos_j - pos_i}): {relative_angle:.4f}")

print("\nThe dot product between rotated vectors naturally depends on")
print("the relative position difference!")

print("\n" + "=" * 70)
print("DEMONSTRATION 7: Learning Rate Schedules")
print("=" * 70)

def linear_warmup(step, warmup_steps, base_lr):
    """Learning rate for ``step`` under a linear ramp followed by a constant.

    While ``step < warmup_steps`` the rate is ``base_lr * (step + 1) / warmup_steps``
    (so step 0 already gets a non-zero rate); afterwards it is ``base_lr``.
    """
    still_warming = step < warmup_steps
    return base_lr * (step + 1) / warmup_steps if still_warming else base_lr

def cosine_decay(step, warmup_steps, total_steps, max_lr, min_lr):
    """Learning rate for ``step`` under linear warmup followed by cosine decay.

    Args:
        step: Zero-based training step.
        warmup_steps: Number of steps over which the rate ramps up linearly.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        max_lr: Peak rate, reached at the end of warmup.
        min_lr: Floor rate, reached at ``total_steps`` and held afterwards.

    Returns:
        ``max_lr * (step + 1) / warmup_steps`` during warmup, a half-cosine
        interpolation from ``max_lr`` down to ``min_lr`` between
        ``warmup_steps`` and ``total_steps``, and ``min_lr`` from then on.
    """
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    decay_steps = total_steps - warmup_steps
    # Past the end of the schedule the cosine would turn upward again and
    # raise the rate; an empty decay window would divide by zero. Both cases
    # hold the rate at the floor instead.
    if decay_steps <= 0 or step >= total_steps:
        return min_lr
    progress = (step - warmup_steps) / decay_steps
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Hyper-parameters of the schedule being tabulated
warmup_steps, total_steps = 100, 1000
base_lr = 0.001

print("\nLearning rate at different steps:")
# Sample points: start, mid-warmup, end of warmup, mid-decay, late decay, last step
for step in (0, 50, 100, 500, 900, 999):
    # Warmup-only schedule: ramps to base_lr and then stays there
    lr_warmup = linear_warmup(step, warmup_steps, base_lr)
    # Warmup + cosine schedule with a floor at 10% of the peak rate
    lr_cosine = cosine_decay(step, warmup_steps, total_steps, base_lr, base_lr * 0.1)
    print(f"Step {step:4d}: warmup={lr_warmup:.6f}, cosine={lr_cosine:.6f}")

print("\nPattern:")
print("  Steps 0-100: LR increases (warmup)")
print("  Steps 100+: LR decreases smoothly (cosine decay)")

print("\n" + "=" * 70)
print("DEMONSTRATION 8: Complete Mini Forward Pass")
print("=" * 70)

# Tiny transformer block simulation: 1 sequence, 3 tokens, 8 features, 2 heads
B, T, C = 1, 3, 8
n_head = 2
head_dim = C // n_head

print(f"\nInput shape: [batch={B}, seq_len={T}, dim={C}]")

# 1. Input (random values stand in for token embeddings)
x = torch.randn(B, T, C)
print(f"After embeddings: {x.shape}")

# 2. Layer norm applied before attention (pre-norm)
ln = nn.LayerNorm(C)
x_norm = ln(x)
print(f"After LayerNorm: {x_norm.shape}")

# 3. Attention (simplified: no output projection, no dropout)
# A single linear layer produces Q, K and V stacked along the feature axis
qkv_proj = nn.Linear(C, 3 * C)
qkv = qkv_proj(x_norm)
q, k, v = qkv.chunk(3, dim=-1)

# Reshape for multi-head: [B, T, C] -> [B, n_head, T, head_dim]
q = q.view(B, T, n_head, head_dim).transpose(1, 2)
k = k.view(B, T, n_head, head_dim).transpose(1, 2)
v = v.view(B, T, n_head, head_dim).transpose(1, 2)
print(f"Q, K, V (multi-head): {q.shape}")

# Attention scores, scaled by sqrt(head_dim)
att = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
print(f"Attention scores: {att.shape}")

# Causal mask: the [T, T] lower-triangular mask broadcasts over batch and heads
mask = torch.tril(torch.ones(T, T))
att = att.masked_fill(mask == 0, float('-inf'))
att = F.softmax(att, dim=-1)

# Aggregate values
out = att @ v
print(f"Attention output: {out.shape}")

# Reshape back: transpose makes the tensor non-contiguous, so .contiguous()
# is required before .view can merge the head axes
out = out.transpose(1, 2).contiguous().view(B, T, C)
print(f"After concatenating heads: {out.shape}")

# 4. Residual connection
x = x + out
print(f"After residual: {x.shape}")

# 5. MLP
# The pre-MLP norm is a separate module from the pre-attention norm (a block
# has two LayerNorms with independent parameters), so it gets its own instance
# instead of reusing `ln`.
ln_2 = nn.LayerNorm(C)
x_norm2 = ln_2(x)
mlp = nn.Sequential(
    nn.Linear(C, 4 * C),
    nn.GELU(),
    nn.Linear(4 * C, C)
)
mlp_out = mlp(x_norm2)
print(f"After MLP: {mlp_out.shape}")

# 6. Final residual
x = x + mlp_out
print(f"Block output: {x.shape}")

print("\n" + "=" * 70)
print("All demonstrations complete!")
print("=" * 70)

DEMONSTRATION 1: Attention Score Computation

Attention scores (before masking):
tensor([[0.5000, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.5000],
        [0.5000, 0.5000, 0.5000]])

Interpretation:
  Row i, Col j = similarity between token i's query and token j's key

After causal masking:
tensor([[0.5000,   -inf,   -inf],
        [0.0000, 0.0000,   -inf],
        [0.5000, 0.5000, 0.5000]])

Attention weights (after softmax):
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

Note: Each row sums to 1.0, representing probability distribution

DEMONSTRATION 2: Multi-Head Attention Reshaping

Original shape: torch.Size([2, 4, 8])
After view: torch.Size([2, 4, 2, 4]) [batch, seq, heads, head_dim]
After transpose: torch.Size([2, 2, 4, 4]) [batch, heads, seq, head_dim]

Now we can process each head in parallel via batched operations!

DEMONSTRATION 3: RMSNorm vs LayerNorm

Input: tensor([[1., 2., 3., 4.]])
LayerNorm output: tensor([[-1.