# Understanding GPT Architecture - Deep Dive

This notebook breaks down the GPT (Generative Pre-trained Transformer) architecture step by step.

## Key Features of This Implementation

This is a **modern, optimized** GPT implementation with several improvements over the original:

1. **Rotary Embeddings (RoPE)** - Better positional encoding than learned embeddings
2. **QK Normalization** - Stabilizes training
3. **Untied Weights** - Separate weights for token embedding and output layer
4. **ReLU²** - Squared ReLU activation in the MLP (cheaper than GELU, smoother than plain ReLU)
5. **RMSNorm** - Simpler, faster normalization (used here with no learnable parameters)
6. **No Bias** - Linear layers omit bias terms, so there are fewer parameters
7. **Grouped-Query Attention (GQA)** - Fewer key/value heads, which shrinks the KV cache at inference

## What We'll Cover

1. **Configuration** - Model hyperparameters
2. **Normalization** - RMSNorm explained
3. **Rotary Embeddings** - How positional information is encoded
4. **Attention Mechanism** - Multi-head self-attention with GQA
5. **MLP (Feedforward)** - The "thinking" layer
6. **Transformer Block** - Attention + MLP combined
7. **Full GPT Model** - Putting it all together
8. **Forward Pass** - How data flows through the model

## Section 1: Configuration and Setup

In [None]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

# GPT Configuration
@dataclass
class GPTConfig:
    sequence_len: int = 1024  # Maximum sequence length (context window)
    vocab_size: int = 50304   # Vocabulary size (number of unique tokens)
    n_layer: int = 12         # Number of transformer blocks
    n_head: int = 6           # Number of query attention heads
    n_kv_head: int = 6        # Number of key/value heads (for GQA)
    n_embd: int = 768         # Embedding dimension (model width)

# Create a small config for demonstration
config = GPTConfig(
    sequence_len=128,
    vocab_size=512,
    n_layer=4,
    n_head=4,
    n_kv_head=4,  # Same as n_head = standard multi-head attention
    n_embd=256
)

print("GPT Configuration:")
print(f"  Sequence length: {config.sequence_len} tokens")
print(f"  Vocabulary size: {config.vocab_size}")
print(f"  Number of layers: {config.n_layer}")
print(f"  Number of heads: {config.n_head}")
print(f"  Embedding dimension: {config.n_embd}")
print(f"  Head dimension: {config.n_embd // config.n_head}")
print()
print(f"Why these values?")
print(f"  - sequence_len: How much context the model can see")
print(f"  - vocab_size: How many unique tokens (from tokenizer)")
print(f"  - n_layer: Depth of the network (more = more capacity)")
print(f"  - n_head: Parallel attention computations (more = richer representations)")
print(f"  - n_embd: Width of the network (larger = more parameters)")

## Section 2: RMSNorm - Root Mean Square Normalization

**Why normalize?** Neural networks train better when activations stay in a reasonable range. Normalization helps prevent exploding or vanishing gradients.

**Why RMSNorm over LayerNorm?**
- Simpler: it skips LayerNorm's mean subtraction, and this implementation also drops the learnable scale/shift parameters
- Faster: Fewer computations
- Works about as well in practice

**Formula:** `RMSNorm(x) = x / RMS(x)` where `RMS(x) = sqrt(mean(x²))` (implementations add a small epsilon for numerical stability)
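
To make the formula concrete, here is a minimal pure-Python sketch on a single 4-element vector. The helper name `rms_norm` and the `eps` value are assumptions made for illustration; the notebook itself relies on `F.rms_norm`.

```python
import math

def rms_norm(vec, eps=1e-6):
    """Divide a vector by its root mean square (no learnable scale)."""
    # eps is an assumed small constant that guards against division by zero
    rms = math.sqrt(sum(v * v for v in vec) / len(vec) + eps)
    return [v / rms for v in vec]

x = [3.0, -4.0, 0.0, 0.0]   # mean(x²) = 25 / 4, so RMS(x) = 2.5
y = rms_norm(x)
rms_after = math.sqrt(sum(v * v for v in y) / len(y))

print(y)          # close to [1.2, -1.6, 0.0, 0.0]
print(rms_after)  # close to 1.0
```

Scaling the input by any positive constant leaves the output essentially unchanged, which is why the operation keeps activations in a stable range.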

In [None]:
def norm(x):
    """
    Purely functional RMSNorm with no learnable parameters
    (F.rms_norm is available in PyTorch 2.4 and newer)
    """
    return F.rms_norm(x, (x.size(-1),))

# Test it with a sample tensor
x = torch.randn(2, 4, 256)  # (batch, seq_len, embd_dim)
print(f"Input shape: {x.shape}")
print(f"Input mean: {x.mean():.4f}, std: {x.std():.4f}")
print()

x_normed = norm(x)
print(f"After RMSNorm:")
print(f"  Output shape: {x_normed.shape}")
print(f"  Output mean: {x_normed.mean():.4f}, std: {x_normed.std():.4f}")
print()

# Verify RMSNorm formula manually
rms = torch.sqrt((x ** 2).mean(dim=-1, keepdim=True))
x_manual = x / rms
print(f"Manual RMSNorm matches? {torch.allclose(x_normed, x_manual, atol=1e-6)}")
print()
print("Key insight: RMSNorm scales each vector to have RMS = 1")
print("This keeps activations in a stable range throughout the network")

## Section 3: Rotary Positional Embeddings (RoPE)

**The Problem:** Self-attention has no inherent notion of position. Without positional information, "cat sat mat" and "mat sat cat" are just the same set of tokens!

**Old Solution:** Add positional embeddings to token embeddings (GPT-2 style)

**Better Solution:** Rotary Embeddings (RoPE)
- Apply rotation to query and key vectors based on their position
- Encodes **relative** position (distance between tokens) rather than absolute
- Often generalizes better to longer sequences
- No extra parameters to learn!

### How RoPE Works

1. **Frequency Computation:** Different dimensions get different rotation frequencies
2. **Rotation:** Rotate pairs of dimensions by an angle proportional to position
3. **Application:** Apply to both queries and keys in attention

**Key Insight:** The dot product between rotated Q and K encodes their relative distance!
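
The relative-position property can be checked on a single 2D pair. The sketch below is plain Python with arbitrary illustrative values; `rotate`, `dot`, and the frequency `0.1` are assumptions for the demo, not part of the model code.

```python
import math

def rotate(pair, pos, freq):
    """Rotate one (x1, x2) pair by the angle pos * freq."""
    x1, x2 = pair
    c, s = math.cos(pos * freq), math.sin(pos * freq)
    return (x1 * c + x2 * s, -x1 * s + x2 * c)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k = (0.3, -1.2), (0.8, 0.5)   # arbitrary query/key values
freq = 0.1                        # hypothetical frequency for this pair

# Same offset (i - j = 3) at two very different absolute positions
score_near = dot(rotate(q, 5, freq), rotate(k, 2, freq))
score_far = dot(rotate(q, 105, freq), rotate(k, 102, freq))
# A different offset (i - j = 4)
score_other = dot(rotate(q, 6, freq), rotate(k, 2, freq))

print(abs(score_near - score_far) < 1e-9)    # True: only the offset matters
print(abs(score_near - score_other) > 1e-3)  # True: a new offset changes the score
```

Shifting both positions by 100 leaves the score unchanged up to floating-point error, while changing the offset changes it.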

In [None]:
# STEP 1: Precompute Rotary Embeddings

def precompute_rotary_embeddings(seq_len, head_dim, base=10000):
    """
    Precompute cos and sin values for rotary embeddings
    
    Args:
        seq_len: Maximum sequence length
        head_dim: Dimension of each attention head
        base: Base for frequency computation (10000 is standard)
    
    Returns:
        cos, sin: Precomputed rotation matrices
    """
    # Step 1: Compute one inverse frequency per dimension pair
    # head_dim channels form head_dim/2 pairs; apply_rotary_emb pairs channel i with channel i + head_dim/2
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    
    print(f"Step 1: Compute frequencies")
    print(f"  Head dimension: {head_dim}")
    print(f"  Number of dimension pairs: {head_dim // 2}")
    print(f"  Inverse frequencies shape: {inv_freq.shape}")
    print(f"  inv_freq values (first 4): {inv_freq[:4].tolist()}")
    print()
    
    # Step 2: Create position indices
    t = torch.arange(seq_len, dtype=torch.float32)
    
    # Step 3: Compute rotation angles = position * frequency
    # outer product: each position gets paired with each frequency
    freqs = torch.outer(t, inv_freq)  # (seq_len, head_dim/2)
    
    print(f"Step 2: Compute rotation angles")
    print(f"  Position range: 0 to {seq_len-1}")
    print(f"  Frequencies shape: {freqs.shape}")
    print(f"  Frequencies at position 0: {freqs[0, :4].tolist()}")
    print(f"  Frequencies at position 10: {freqs[10, :4].tolist()}")
    print()
    
    # Step 4: Compute cos and sin
    cos, sin = freqs.cos(), freqs.sin()
    
    # Step 5: Reshape for broadcasting in attention
    # (1, seq_len, 1, head_dim/2) broadcasts over inputs laid out as (B, T, H, D)
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    
    print(f"Step 3: Final cos/sin tensors")
    print(f"  cos shape: {cos.shape} (batch, seq_len, head, head_dim/2)")
    print(f"  sin shape: {sin.shape}")
    print()
    
    return cos, sin

# Demo with our config
head_dim = config.n_embd // config.n_head
cos, sin = precompute_rotary_embeddings(config.sequence_len, head_dim)

print(f"✓ Rotary embeddings precomputed for {config.sequence_len} positions")
print(f"  These are computed once and cached, not learned!")

In [None]:
# STEP 2: Apply Rotary Embeddings

def apply_rotary_emb(x, cos, sin):
    """
    Apply rotary embeddings to a query or key tensor of shape (B, T, H, D)
    
    The head dimension is split into two halves, and channel i of the
    first half is paired with channel i + D/2 of the second half:
    - Channels (0, D/2) rotate together
    - Channels (1, D/2 + 1) rotate together
    - etc.
    
    Mathematically, for each pair (x1, x2):
        y1 = x1 * cos + x2 * sin
        y2 = -x1 * sin + x2 * cos
    
    This is a 2D rotation matrix!
    """
    assert x.ndim == 4  # (batch, seq_len, n_head, head_dim), matching the cos/sin layout
    d = x.shape[3] // 2
    
    # Split into two halves (pairs of dimensions)
    x1, x2 = x[..., :d], x[..., d:]
    
    # Rotate each pair
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    
    # Concatenate back
    return torch.cat([y1, y2], dim=3)

# Demonstrate on sample query vectors
# Layout is (batch, seq_len, heads, head_dim) so that cos/sin of shape
# (1, T, 1, D/2) broadcast correctly; attention transposes to (B, H, T, D) later
B, T, H, D = 2, 8, 4, 64
q = torch.randn(B, T, H, D)

print(f"Query tensor shape: {q.shape}")
print(f"  B={B} (batch), T={T} (seq_len), H={H} (heads), D={D} (head_dim)")
print()

# Apply rotary embeddings
q_rotated = apply_rotary_emb(q, cos[:, :T], sin[:, :T])

print(f"After rotation: {q_rotated.shape}")
print(f"  Same shape, but now encodes positional information!")
print()

# Key insight: The dot product between rotated Q and K at positions i and j
# depends on (i - j), giving us relative positional encoding!
print("Why this works:")
print("  - Queries and keys at position i get rotated by angle i*freq")
print("  - The dot product Q[i] · K[j] includes cos((i-j)*freq) terms")
print("  - This encodes the RELATIVE distance (i-j) between positions")
print("  - Model learns: 'pay attention to tokens N positions away'")

## Section 4: Causal Self-Attention Mechanism

**Attention is the core innovation of Transformers.** It allows the model to "look at" and weigh the importance of all previous tokens when processing each token.

### The Attention Formula

```
Attention(Q, K, V) = softmax(Q @ K^T / √d) @ V
```

Where:
- **Q (Query)**: "What am I looking for?"
- **K (Key)**: "What do I contain?"
- **V (Value)**: "What information do I carry?"

### Multi-Head Attention

Instead of one attention, we use multiple "heads" in parallel:
- Each head learns different patterns
- Head 1 might focus on syntax, Head 2 on semantics, etc.
- Outputs are concatenated and projected back

### Grouped-Query Attention (GQA)

**Optimization:** Share K and V across multiple Q heads
- Standard multi-head attention: `n_head` sets of Q, K, V
- GQA: `n_head` Q heads, but only `n_kv_head` K/V heads
- Shrinks the KV cache and the memory traffic it causes during inference
- Example: 8 Q heads sharing 2 K/V heads gives a 4x smaller KV cache
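
The 4x figure follows from simple arithmetic on the cache size. Here is a back-of-the-envelope sketch; the helper `kv_cache_bytes` and the model dimensions are hypothetical, chosen only to illustrate the ratio.

```python
def kv_cache_bytes(n_layer, n_kv_head, head_dim, seq_len, batch_size=1, bytes_per_value=2):
    """Bytes needed to cache K and V for every layer (two tensors per layer)."""
    return 2 * n_layer * batch_size * seq_len * n_kv_head * head_dim * bytes_per_value

# Hypothetical model: 12 layers, 8 query heads of dimension 64,
# a 1024-token context, and 2 bytes per value (bf16)
mha = kv_cache_bytes(n_layer=12, n_kv_head=8, head_dim=64, seq_len=1024)
gqa = kv_cache_bytes(n_layer=12, n_kv_head=2, head_dim=64, seq_len=1024)

print(f"MHA cache: {mha / 1024**2:.1f} MiB")  # MHA cache: 24.0 MiB
print(f"GQA cache: {gqa / 1024**2:.1f} MiB")  # GQA cache: 6.0 MiB
print(f"Reduction: {mha // gqa}x")            # Reduction: 4x
```

The number of query heads does not appear in the formula at all: only K and V are cached, so the saving scales directly with `n_head / n_kv_head`.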

### Causal Masking

**Key constraint:** Token at position `i` can only attend to positions `≤ i`
- Prevents "looking into the future"
- Essential for autoregressive generation

In [None]:
class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with Grouped-Query Attention (GQA) support
    """
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head          # Number of query heads
        self.n_kv_head = config.n_kv_head    # Number of key/value heads
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        
        # Ensure dimensions are compatible
        assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"
        assert self.n_kv_head <= self.n_head, "Can't have more KV heads than Q heads"
        assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
        
        # Projection matrices (no bias!)
        self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
    
    def forward(self, x, cos_sin, kv_cache=None):
        B, T, C = x.size()  # batch, sequence length, embedding dim
        
        # STEP 1: Project input to Q, K, V
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
        
        # STEP 2: Apply rotary embeddings to Q and K
        cos, sin = cos_sin
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        
        # STEP 3: QK Normalization (stabilizes training)
        q, k = norm(q), norm(k)
        
        # STEP 4: Transpose for attention computation
        # (B, T, H, D) -> (B, H, T, D) so that heads act as a batch dimension
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        # STEP 5: Compute attention
        # For simplicity, we call scaled_dot_product_attention directly and ignore kv_cache
        # (the enable_gqa argument requires PyTorch 2.5 or newer)
        enable_gqa = self.n_head != self.n_kv_head
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
        
        # STEP 6: Reshape and project output
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.c_proj(y)
        
        return y

# Create and test attention module
attn = CausalSelfAttention(config, layer_idx=0)

print("✓ Causal Self-Attention Module Created")
print(f"  Query heads: {attn.n_head}")
print(f"  KV heads: {attn.n_kv_head}")
print(f"  Head dimension: {attn.head_dim}")
print(f"  GQA enabled: {attn.n_head != attn.n_kv_head}")
print()
print("Parameters:")
print(f"  c_q: {attn.c_q.weight.shape} ({attn.c_q.weight.numel():,} params)")
print(f"  c_k: {attn.c_k.weight.shape} ({attn.c_k.weight.numel():,} params)")
print(f"  c_v: {attn.c_v.weight.shape} ({attn.c_v.weight.numel():,} params)")
print(f"  c_proj: {attn.c_proj.weight.shape} ({attn.c_proj.weight.numel():,} params)")

# Test forward pass
x_test = torch.randn(2, 8, config.n_embd)
y_test = attn(x_test, (cos[:, :8], sin[:, :8]))
print()
print(f"✓ Forward pass successful!")
print(f"  Input: {x_test.shape}")
print(f"  Output: {y_test.shape}")

### Attention Step-by-Step Visualization

Let's manually compute attention for a tiny example to see what's happening:

In [None]:
# Manual attention computation for educational purposes
# Simplified: 1 head, 4 tokens, dimension 8

seq_len = 4
dim = 8

# Random Q, K, V (normally these come from linear projections)
Q = torch.randn(1, 1, seq_len, dim)  # (batch, heads, seq_len, dim)
K = torch.randn(1, 1, seq_len, dim)
V = torch.randn(1, 1, seq_len, dim)

print("="*60)
print("MANUAL ATTENTION COMPUTATION")
print("="*60)
print()

# Step 1: Compute attention scores (Q @ K^T)
scores = Q @ K.transpose(-2, -1)  # (1, 1, 4, 4)
print("Step 1: Compute Q @ K^T")
print(f"  Scores shape: {scores.shape}")
print(f"  Scores (before scaling):")
print(scores[0, 0].numpy())
print()

# Step 2: Scale by sqrt(dim)
scores = scores / math.sqrt(dim)
print(f"Step 2: Scale by 1/√{dim} = {1/math.sqrt(dim):.3f}")
print(f"  Scaled scores:")
print(scores[0, 0].numpy())
print()

# Step 3: Apply causal mask (prevent looking ahead)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
print("Step 3: Apply causal mask (set future positions to -inf)")
print("  Mask (True = masked):")
print(mask.int().numpy())
print(f"  Masked scores:")
print(scores[0, 0].numpy())
print()

# Step 4: Softmax to get attention weights
attn_weights = F.softmax(scores, dim=-1)
print("Step 4: Softmax (convert to probabilities)")
print("  Attention weights (each row sums to 1):")
print(attn_weights[0, 0].numpy())
print()
print("  Interpretation:")
print("    Row 0: Token 0 attends 100% to itself (can't see future)")
print("    Row 1: Token 1 attends to tokens 0 and 1")
print("    Row 2: Token 2 attends to tokens 0, 1, and 2")
print("    Row 3: Token 3 attends to all tokens 0-3")
print()

# Step 5: Weighted sum of values
output = attn_weights @ V
print("Step 5: Multiply attention weights by V (weighted sum)")
print(f"  Output shape: {output.shape}")
print()

print("="*60)
print("Key Takeaway:")
print("  Each token's output is a weighted combination of all")
print("  previous tokens' values, where weights come from Q·K")
print("="*60)

## Section 5: MLP (Multi-Layer Perceptron) - The Feedforward Network

After attention gathers information from other tokens, the **MLP processes this information**.

### Standard MLP Structure

```
x → Linear(expand 4x) → Activation → Linear(project back) → output
```

### Why 4x Expansion?

- Attention: `n_embd → n_embd` (no change in dimension)
- MLP: `n_embd → 4*n_embd → n_embd`
- The expansion gives the model "room to think"
- 4x is empirically found to work well (from original Transformer paper)

### ReLU² Activation

This implementation uses **ReLU²** instead of GELU:
- `ReLU²(x) = ReLU(x)² = max(0, x)²`
- Smoother than standard ReLU
- Faster to compute than GELU
- Works well in practice
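
"Smoother than ReLU" refers to the derivative: ReLU's slope jumps from 0 to 1 at the origin, while ReLU²'s slope grows continuously from 0. A small finite-difference sketch (the helper names are illustrative, not from the model code) shows this on either side of zero:

```python
def relu(x):
    return max(0.0, x)

def relu2(x):
    return max(0.0, x) ** 2

def slope(f, x, h=1e-6):
    """Central finite-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)

for x in (-0.01, 0.01):
    print(f"x={x:+.2f}  ReLU slope={slope(relu, x):.3f}  ReLU² slope={slope(relu2, x):.3f}")
```

Just left of zero both slopes are 0. Just right of zero the ReLU slope is already 1, whereas the ReLU² slope is only about 0.02 (the derivative is `2x` for positive `x`).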

### Role of MLP

- **Attention**: Gathers information ("what's relevant?")
- **MLP**: Processes information ("what does it mean?")

In [None]:
class MLP(nn.Module):
    """
    Simple MLP with 4x expansion and ReLU² activation
    """
    def __init__(self, config):
        super().__init__()
        # Expand by 4x
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        # Project back to original size
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
    
    def forward(self, x):
        # Expand
        x = self.c_fc(x)
        # Activate with ReLU²
        x = F.relu(x).square()
        # Project back
        x = self.c_proj(x)
        return x

# Create and test MLP
mlp = MLP(config)

print("✓ MLP Module Created")
print()
print("Architecture:")
print(f"  Input:  {config.n_embd} dimensions")
print(f"  Expand: {4 * config.n_embd} dimensions (4x)")
print(f"  Output: {config.n_embd} dimensions")
print()
print("Parameters:")
print(f"  c_fc:   {mlp.c_fc.weight.shape} ({mlp.c_fc.weight.numel():,} params)")
print(f"  c_proj: {mlp.c_proj.weight.shape} ({mlp.c_proj.weight.numel():,} params)")
print(f"  Total:  {mlp.c_fc.weight.numel() + mlp.c_proj.weight.numel():,} params")
print()

# Test forward pass
x_test = torch.randn(2, 8, config.n_embd)
y_test = mlp(x_test)

print(f"✓ Forward pass successful!")
print(f"  Input:  {x_test.shape}")
print(f"  Output: {y_test.shape}")
print()

# Visualize ReLU² activation
x_act = torch.linspace(-2, 2, 100)
relu_act = F.relu(x_act)
relu2_act = F.relu(x_act).square()

print("Comparing activations at x=0, 0.5, 1, 1.5, 2:")
for val in [0, 0.5, 1.0, 1.5, 2.0]:
    x_val = torch.tensor([val])
    relu = F.relu(x_val).item()
    relu2 = F.relu(x_val).square().item()
    print(f"  x={val:.1f}: ReLU={relu:.3f}, ReLU²={relu2:.3f}")

## Section 6: Transformer Block - Combining Attention and MLP

A **Transformer Block** is the fundamental building block, combining attention and MLP with **residual connections** and **normalization** (RMSNorm in this implementation).

### Structure (Pre-Norm Architecture)

```
x → norm → attention → (+) → norm → MLP → (+) → output
    ↓__________________|      ↓______________|
       residual connection    residual connection
```

### Key Design Choices

1. **Pre-Norm** (vs Post-Norm):
   - Normalize BEFORE attention/MLP, not after
   - More stable training, especially for deep networks
   - Modern standard (GPT-3, LLaMA, etc.)

2. **Residual Connections** (`x = x + f(x)`):
   - Allow gradients to flow directly through the network
   - Prevent vanishing gradients in deep networks
   - Model can learn identity function easily (just set f(x)=0)

3. **Why This Works**:
   - Each block makes a small "update" to the representation
   - Information flows through both the residual path (untouched) and the transformation path
   - Deep networks become easier to train
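
The gradient-flow argument can be illustrated with a toy scalar network. This is only a sketch under simplifying assumptions: each "layer" is `f(x) = weight * tanh(x)`, and `chain_gradient` is a hypothetical helper that multiplies the per-layer derivatives by hand.

```python
import math

def chain_gradient(depth, weight, residual, x=0.5):
    """Derivative of the output w.r.t. the input for a stack of scalar layers.

    Each layer computes f(x) = weight * tanh(x); with residual=True the
    layer output is x + f(x). Toy model, purely illustrative.
    """
    grad = 1.0
    for _ in range(depth):
        local = weight * (1.0 - math.tanh(x) ** 2)   # f'(x)
        if residual:
            grad *= 1.0 + local                       # d(x + f(x)) / dx
            x = x + weight * math.tanh(x)
        else:
            grad *= local                             # d f(x) / dx
            x = weight * math.tanh(x)
    return grad

plain = chain_gradient(depth=24, weight=0.5, residual=False)
skip = chain_gradient(depth=24, weight=0.5, residual=True)
print(f"without residuals: {plain:.2e}")
print(f"with residuals:    {skip:.2e}")
```

Without residuals every factor is at most 0.5, so after 24 layers the gradient is below `0.5 ** 24` (under 1e-7). With residuals every factor is `1 + f'(x)`, which is at least 1 in this toy setup, so the gradient cannot shrink toward zero.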

In [None]:
class Block(nn.Module):
    """
    Transformer block: attention + MLP with residual connections
    """
    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)
    
    def forward(self, x, cos_sin, kv_cache=None):
        # Attention block with residual connection
        # Pre-norm: normalize first, then apply attention, then add residual
        x = x + self.attn(norm(x), cos_sin, kv_cache)
        
        # MLP block with residual connection
        # Pre-norm: normalize first, then apply MLP, then add residual
        x = x + self.mlp(norm(x))
        
        return x

# Create and test a single transformer block
block = Block(config, layer_idx=0)

print("✓ Transformer Block Created")
print()
print("Structure:")
print("  1. x = x + attention(norm(x))  ← Attention with residual")
print("  2. x = x + mlp(norm(x))        ← MLP with residual")
print()

# Count parameters
n_params = sum(p.numel() for p in block.parameters())
n_attn = sum(p.numel() for p in block.attn.parameters())
n_mlp = sum(p.numel() for p in block.mlp.parameters())

print("Parameters:")
print(f"  Attention: {n_attn:,}")
print(f"  MLP:       {n_mlp:,}")
print(f"  Total:     {n_params:,}")
print(f"  MLP is {n_mlp/n_attn:.1f}x larger than attention!")
print()

# Test forward pass
x_test = torch.randn(2, 8, config.n_embd)
y_test = block(x_test, (cos[:, :8], sin[:, :8]))

print(f"✓ Forward pass successful!")
print(f"  Input:  {x_test.shape}")
print(f"  Output: {y_test.shape}")
print()

# Demonstrate the residual connection
x_small = torch.randn(1, 4, config.n_embd)
x_input = x_small.clone()

# If we zero out the block's parameters, we get identity function
with torch.no_grad():
    for p in block.parameters():
        p.zero_()

x_output = block(x_small, (cos[:, :4], sin[:, :4]))

print("Residual Connection Demonstration:")
print("  When all parameters are zero:")
print(f"  Input == Output? {torch.allclose(x_input, x_output)}")
print("  This shows the residual connection allows identity mapping!")

## Section 7: Complete GPT Model - Putting It All Together

The full GPT model stacks multiple transformer blocks and adds input/output layers.

### Architecture Overview

```
Token IDs (integers)
       ↓
Token Embedding (wte) → Vectors
       ↓
   RMSNorm (stabilize)
       ↓
 Block 1 (attn + MLP)
       ↓
 Block 2 (attn + MLP)
       ↓
      ...
       ↓
 Block N (attn + MLP)
       ↓
   RMSNorm (final)
       ↓
Language Model Head (lm_head) → Logits
       ↓
   Softmax → Probabilities
```

### Key Components

1. **Token Embedding (wte)**: Maps token IDs to vectors
2. **N Transformer Blocks**: Process and transform representations
3. **Language Model Head (lm_head)**: Maps final vectors to vocab logits
4. **Untied Weights**: `wte` and `lm_head` are separate (not shared)

In [None]:
class GPT(nn.Module):
    """
    Simplified GPT model for demonstration
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Token embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        
        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            Block(config, layer_idx) 
            for layer_idx in range(config.n_layer)
        ])
        
        # Language model head (output projection)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
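        # Precompute rotary embeddings and register them as non-persistent buffers
        # so they follow the model across devices without being saved as weights
        head_dim = config.n_embd // config.n_head
        cos, sin = precompute_rotary_embeddings(config.sequence_len, head_dim)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)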
    
    def forward(self, idx, targets=None):
        B, T = idx.size()  # batch size, sequence length
        assert T <= self.config.sequence_len, "input is longer than the precomputed rotary table"
        
        # Get rotary embeddings for this sequence
        cos_sin = (self.cos[:, :T], self.sin[:, :T])
        
        # Step 1: Token embedding
        x = self.wte(idx)  # (B, T, n_embd)
        
        # Step 2: Normalize embeddings (modern practice)
        x = norm(x)
        
        # Step 3: Pass through all transformer blocks
        for block in self.blocks:
            x = block(x, cos_sin)
        
        # Step 4: Final normalization
        x = norm(x)
        
        # Step 5: Project to vocabulary logits
        logits = self.lm_head(x)  # (B, T, vocab_size)
        
        # If targets provided, compute loss
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), 
                targets.view(-1),
                ignore_index=-1
            )
            return loss
        else:
            return logits

# Create the complete model
model = GPT(config)

print("="*60)
print("COMPLETE GPT MODEL")
print("="*60)
print()
print(f"Configuration:")
print(f"  Vocabulary size: {config.vocab_size:,}")
print(f"  Embedding dimension: {config.n_embd}")
print(f"  Number of layers: {config.n_layer}")
print(f"  Number of heads: {config.n_head}")
print(f"  Sequence length: {config.sequence_len}")
print()

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
n_embed = model.wte.weight.numel()
n_blocks = sum(p.numel() for p in model.blocks.parameters())
n_head = model.lm_head.weight.numel()

print(f"Parameters:")
print(f"  Token embedding: {n_embed:,}")
print(f"  Transformer blocks: {n_blocks:,}")
print(f"  LM head: {n_head:,}")
print(f"  Total: {n_params:,}")
print()
print(f"  Most parameters ({n_blocks/n_params*100:.1f}%) are in the transformer blocks!")
print()

# Test forward pass
batch_size = 2
seq_len = 16
token_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

print(f"Forward pass test:")
print(f"  Input shape: {token_ids.shape} (batch, seq_len)")

logits = model(token_ids)
print(f"  Output shape: {logits.shape} (batch, seq_len, vocab_size)")
print()
print(f"✓ Model successfully produces logits for each token position!")

## Section 8: Complete Forward Pass Trace

Let's trace how data flows through the entire model for a single example.

In [None]:
# Trace a forward pass step by step
print("="*70)
print("FORWARD PASS TRACE")
print("="*70)
print()

# Create a small input (batch=1, seq_len=4 for clarity)
input_ids = torch.tensor([[42, 100, 256, 89]])  # Example token IDs
B, T = input_ids.shape

print(f"Input: {input_ids.tolist()[0]}")
print(f"  Shape: {input_ids.shape} (batch_size=1, seq_len=4)")
print()

# Manually trace through the forward pass
print("Step 1: Token Embedding")
x = model.wte(input_ids)
print(f"  {input_ids.shape} → {x.shape}")
print(f"  Each token ID is mapped to a {config.n_embd}-dimensional vector")
print(f"  Example: token {input_ids[0, 0].item()} → vector of shape {x[0, 0].shape}")
print()

print("Step 2: Initial RMSNorm")
x = norm(x)
print(f"  {x.shape} (shape unchanged)")
print(f"  Normalizes embedding vectors for stable training")
print()

# Through each block
for layer_idx, block in enumerate(model.blocks):
    print(f"Step {3+layer_idx}: Transformer Block {layer_idx+1}")
    
    x_before = x.clone()
    cos_sin = (model.cos[:, :T], model.sin[:, :T])
    x = block(x, cos_sin)
    
    # Check how much the block changed the representation
    change_norm = (x - x_before).norm().item()
    
    print(f"  Input:  {x_before.shape}")
    print(f"  Output: {x.shape}")
    print(f"  Change magnitude: {change_norm:.4f}")
    print(f"  ├─ Attention: Each token attends to previous tokens")
    print(f"  └─ MLP: Process gathered information")
    print()

print(f"Step {3+config.n_layer}: Final RMSNorm")
x = norm(x)
print(f"  {x.shape} (shape unchanged)")
print()

print(f"Step {4+config.n_layer}: Language Model Head")
logits = model.lm_head(x)
print(f"  {x.shape} → {logits.shape}")
print(f"  Projects each position's vector to vocabulary logits")
print()

print(f"Step {5+config.n_layer}: Convert Logits to Probabilities (Softmax)")
probs = F.softmax(logits, dim=-1)
print(f"  {logits.shape} → {probs.shape}")
print(f"  Each position now has a probability distribution over {config.vocab_size} tokens")
print()

# Show predictions for the last position
last_pos_probs = probs[0, -1]  # Probabilities for next token after sequence
top5_probs, top5_indices = torch.topk(last_pos_probs, k=5)

print("Predictions for next token (after last position):")
print("  Top 5 most likely tokens:")
for i, (prob, idx) in enumerate(zip(top5_probs, top5_indices)):
    print(f"    {i+1}. Token {idx.item():3d}: {prob.item()*100:5.2f}%")
print()

print("="*70)
print("Summary: The model transformed token IDs → meaningful representations")
print("         → probability distributions for next token prediction!")
print("="*70)

## Summary: Why This Architecture Works

### Key Insights

1. **Attention is All You Need**
   - Self-attention lets each token gather information from all previous tokens
   - Multi-head attention learns different types of relationships
   - Causal masking ensures autoregressive generation

2. **Positional Encoding Matters**
   - RoPE encodes relative positions through rotation
   - No learnable parameters needed
   - Often generalizes better than absolute position embeddings

3. **Residual Connections Enable Deep Networks**
   - Direct gradient flow prevents vanishing gradients
   - Each layer makes small refinements
   - Model can learn identity if needed

4. **Normalization Stabilizes Training**
   - RMSNorm keeps activations in reasonable range
   - Pre-norm (normalize before, not after) is more stable
   - QK norm in attention prevents training instabilities

5. **MLP Provides Compute**
   - Attention: "What information to gather?"
   - MLP: "How to process that information?"
   - 4x expansion gives model "thinking room"

### Modern Improvements Over GPT-2

| Feature | GPT-2 | This Implementation |
|---------|-------|---------------------|
| Position Encoding | Learned embeddings | Rotary (RoPE) |
| Normalization | LayerNorm | RMSNorm |
| Attention Stability | None | QK Normalization |
| Activation | GELU | ReLU² |
| Weights | Tied wte/lm_head | Untied |
| Bias | Yes | No |
| Attention Heads | Multi-head (one K/V head per Q head) | GQA support (fewer K/V heads) |

### Parameter Scaling

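The split depends heavily on the ratio of `vocab_size` to model size. Because the weights are untied, the embedding layer and the LM head always have the same size (`vocab_size × n_embd` each):

- **Demo config** (`vocab_size=512`, `n_embd=256`, `n_layer=4`): embedding ~4%, transformer blocks ~92%, LM head ~4%
- **Default config** (`vocab_size=50304`, `n_embd=768`, `n_layer=12`): embedding ~24%, transformer blocks ~52%, LM head ~24%
- **Within each block** (when `n_kv_head == n_head`): attention ~33% (`4 × n_embd²`), MLP ~67% (`8 × n_embd²`)

**Key insight**: Within the transformer blocks, about two thirds of the parameters are in the MLP layers!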
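
Proportions like these can be checked by counting weights directly. The sketch below assumes the layer shapes used in this notebook (bias-free linear layers, 4x MLP expansion, untied `wte` and `lm_head`); `count_params` is a helper introduced here for illustration and is not part of the model code.

```python
def count_params(vocab_size, n_layer, n_head, n_kv_head, n_embd):
    """Weight counts for the bias-free, untied-embedding architecture in these notes."""
    head_dim = n_embd // n_head
    attn = (
        n_embd * n_head * head_dim           # c_q
        + 2 * n_embd * n_kv_head * head_dim  # c_k and c_v
        + n_embd * n_embd                    # c_proj
    )
    mlp = 2 * n_embd * 4 * n_embd            # c_fc and c_proj
    embed = vocab_size * n_embd              # wte (lm_head has the same size)
    total = 2 * embed + n_layer * (attn + mlp)
    return {"embed": embed, "lm_head": embed, "attn": n_layer * attn,
            "mlp": n_layer * mlp, "total": total}

demo = count_params(vocab_size=512, n_layer=4, n_head=4, n_kv_head=4, n_embd=256)
default = count_params(vocab_size=50304, n_layer=12, n_head=6, n_kv_head=6, n_embd=768)

for name, p in (("demo", demo), ("default", default)):
    blocks = p["attn"] + p["mlp"]
    print(f"{name}: total={p['total']:,}  blocks={blocks / p['total']:.1%}  "
          f"mlp share of blocks={p['mlp'] / blocks:.1%}")
```

For the demo config this gives 3,407,872 weights with about 92% in the blocks; for the default config it gives 162,201,600 weights with about 52% in the blocks. In both cases the MLP holds two thirds of the block weights. Lowering `n_kv_head` shrinks only the `c_k` and `c_v` terms.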

## Visual Architecture Diagram

```
┌─────────────────────────────────────────────────────────────┐
│                         GPT MODEL                            │
└─────────────────────────────────────────────────────────────┘

Input: Token IDs [42, 100, 256, 89]
         │
         ▼
┌──────────────────────┐
│  Token Embedding     │  vocab_size → n_embd
│  (lookup table)      │  [42] → [0.1, -0.5, ..., 0.3]
└──────────────────────┘
         │
         ▼
┌──────────────────────┐
│     RMSNorm          │  Normalize embeddings
└──────────────────────┘
         │
         ▼
┌──────────────────────────────────────────────┐
│            TRANSFORMER BLOCK 1                │
│  ┌─────────────────────────────────────┐     │
│  │ x → Norm → Attention → (+) ← x      │     │
│  │              ↓                       │     │
│  │         Apply RoPE                   │     │
│  │         QK Norm                      │     │
│  │         Softmax(QK^T/√d)V           │     │
│  └─────────────────────────────────────┘     │
│              │                                │
│  ┌───────────▼───────────────────────────┐   │
│  │ x → Norm → MLP → (+) ← x              │   │
│  │         Expand 4x                      │   │
│  │         ReLU²                          │   │
│  │         Project back                   │   │
│  └────────────────────────────────────────┘   │
└──────────────────────────────────────────────┘
         │
         ▼
┌──────────────────────────────────────────────┐
│         TRANSFORMER BLOCK 2...N               │
│         (same structure)                      │
└──────────────────────────────────────────────┘
         │
         ▼
┌──────────────────────┐
│     RMSNorm          │  Final normalization
└──────────────────────┘
         │
         ▼
┌──────────────────────┐
│   Language Model     │  n_embd → vocab_size
│      Head            │  vectors → logits
└──────────────────────┘
         │
         ▼
┌──────────────────────┐
│     Softmax          │  logits → probabilities
└──────────────────────┘
         │
         ▼
Output: Probabilities over vocabulary
        [0.01, 0.001, ..., 0.15, ...]
        "Token 256 has 15% probability"
```

---

## Next Steps

To dive deeper:

1. **Training**: See how the model learns from data (loss functions, backpropagation)
2. **Generation**: Understand sampling strategies (greedy, top-k, top-p, temperature)
3. **Optimization**: Learn about modern optimizers (AdamW, Muon)
4. **Scaling Laws**: Understand how performance scales with model size and data
5. **Fine-tuning**: Adapt pre-trained models to specific tasks

## References

- Original Transformer: "Attention Is All You Need" (Vaswani et al., 2017)
- GPT-2: "Language Models are Unsupervised Multitask Learners" (Radford et al., 2019)
- RoPE: "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021)
- GQA: "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints" (Ainslie et al., 2023)

---

## Practice Exercise

Try modifying the config and observe changes:
- Increase `n_layer`: Deeper network, more capacity
- Increase `n_embd`: Wider network, more parameters
- Set `n_kv_head < n_head` (with `n_head` divisible by `n_kv_head`): Enable GQA
- Increase `sequence_len`: Longer context window

See how parameter count and computation change!

In [None]:
# Quick Reference: Print model summary
print("="*70)
print("GPT ARCHITECTURE QUICK REFERENCE")
print("="*70)
print()

print("Components:")
print("  1. Token Embedding (wte): token_id → vector")
print("  2. RMSNorm: x / sqrt(mean(x²))")
print("  3. Rotary Embeddings (RoPE): Rotate Q,K by position")
print("  4. Multi-Head Attention: Softmax(QK^T/√d)V")
print("  5. QK Normalization: Normalize Q,K before attention")
print("  6. MLP: Linear → ReLU² → Linear")
print("  7. Residual: x = x + f(x)")
print("  8. Language Model Head: vector → logits")
print()

print("Data Flow:")
print("  tokens → embed → norm → [attn+MLP blocks] → norm → lm_head → logits")
print()

print("Key Formulas:")
print("  • Attention: softmax(Q @ K^T / √d) @ V")
print("  • RMSNorm: x / sqrt(mean(x²))")
print("  • RoPE: rotate(x, angle=position*frequency)")
print("  • ReLU²: max(0, x)²")
print("  • Residual: x_out = x_in + transformation(x_in)")
print()

print("Parameter Breakdown (for our config):")
total_params = sum(p.numel() for p in model.parameters())
print(f"  Total parameters: {total_params:,}")
print(f"  Memory (fp32): {total_params * 4 / 1024**2:.1f} MB")
print(f"  Memory (bf16): {total_params * 2 / 1024**2:.1f} MB")
print()

print("="*70)
print("You now understand the core GPT architecture!")
print("="*70)