# Training Transformers: From Random to Intelligent

Training a transformer is like teaching someone to complete sentences. They start by guessing randomly, but through practice and feedback, they learn language patterns.

## The Learning Process
1. **Make a guess** - Predict the next word
2. **Get feedback** - Check if guess was right  
3. **Adjust approach** - Update internal understanding
4. **Try again** - Repeat millions of times

## What You'll Learn
- **Training loop fundamentals** - The 4-step process
- **Loss and optimization** - How models improve
- **Watching learning happen** - See intelligence emerge
- **When to stop** - Recognizing completion

import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Dict
import math

# Consistent plotting look and reproducible randomness for every demo below.
plt.style.use('default')
sns.set_palette("husl")

for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

_device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {_device_label}")

## What is Training?

Demonstrate the core concept through analogy and simple examples.

## Training Loop Fundamentals

Walk through the 4-step training loop that every neural language model is trained with.

In [ ]:
# Learning analogy: practice makes perfect.
# Each practice attempt is paired with how many mistakes it contained.
attempts = ["Wrong note!", "Better!", "Almost!", "Perfect!"]
mistakes = [8.5, 6.2, 3.1, 1.0]

print("🎹 Piano Learning Example:")
for practice_no, (attempt, mistake) in enumerate(zip(attempts, mistakes), start=1):
    print(f"Practice {practice_no}: {attempt:<12} (Mistake level: {mistake})")

first_mistake, *_, last_mistake = mistakes
print(f"Mistake reduction: {first_mistake} → {last_mistake} (Learning!)")

# The same idea for a language model: successive guesses for one blank.
print("\n🤖 Transformer Learning:")
predictions = ["dog", "table", "floor", "mat"]
confidence = [20, 45, 60, 95]

print("Context: 'The cat sat on the ___'")
print("Correct answer: 'mat'")

for attempt_no, (pred, conf) in enumerate(zip(predictions, confidence), start=1):
    mark = "❌" if pred != "mat" else "✅"
    print(f"Attempt {attempt_no}: '{pred}' ({conf}% sure) {mark}")

print("\n💡 Key insight: Each mistake teaches the model!")
print("Model learns: 'After \"on the\", try \"mat\" not \"dog\"'")

In [None]:
## Simple Training Implementation

Build a minimal transformer and demonstrate the 4-step training loop.

class TinyTransformer(nn.Module):
    """Single-block, decoder-style transformer for next-token prediction.

    Pipeline: token embedding + sinusoidal positional encoding -> causal
    multi-head self-attention with a residual connection and LayerNorm ->
    linear projection to per-token vocabulary logits.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding / hidden width (nn.MultiheadAttention requires it
            to be divisible by ``n_heads``).
        n_heads: number of attention heads.
    """

    def __init__(self, vocab_size, d_model=32, n_heads=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

    @staticmethod
    def _positional_encoding(seq_len, d_model, device, dtype):
        """Return the (seq_len, d_model) sinusoidal position table (no parameters, any length)."""
        position = torch.arange(seq_len, device=device, dtype=dtype).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, device=device, dtype=dtype)
            * (-math.log(10000.0) / d_model)
        )
        pe = torch.zeros(seq_len, d_model, device=device, dtype=dtype)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        seq_len = x.size(1)
        h = self.embedding(x)
        # Attention by itself ignores token order; add position information.
        h = h + self._positional_encoding(seq_len, h.size(-1), h.device, h.dtype)
        # Causal mask (True = may NOT attend): position t only sees tokens <= t.
        # In next-token training the target at t is the input at t+1, so without
        # this mask the model could simply copy the answer from the future.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        attn_out, _ = self.attention(h, h, h, attn_mask=causal_mask)
        h = self.norm(h + attn_out)
        return self.output(h)

# Training data: one sentence, tokenised character by character.
text = "The cat sat on the mat"
chars = sorted(set(text))
char_to_idx = dict(zip(chars, range(len(chars))))
vocab_size = len(chars)

# Next-character pairs: targets are the inputs shifted left by one position.
indices = list(map(char_to_idx.__getitem__, text))
inputs = torch.tensor([indices[:-1]])    # batch of one input sequence
targets = torch.tensor([indices[1:]])    # same shape, one step ahead

for _caption, _shown in (
    ("Training text", f"'{text}'"),
    ("Vocabulary", chars),
    ("Vocab size", vocab_size),
    ("Input sequence", inputs[0].tolist()),
    ("Target sequence", targets[0].tolist()),
):
    print(f"{_caption}: {_shown}")

# Model and optimizer shared by all the training cells below.
model = TinyTransformer(vocab_size)
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

def training_step(model, inputs, targets, optimizer):
    """Run one optimisation step, narrating each of the four training phases.

    Args:
        model: module mapping ``inputs`` to per-position logits whose last
            dimension is the number of classes (vocabulary size).
        inputs: tensor of input token ids.
        targets: tensor of target token ids, one per input position.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The cross-entropy loss of this step (measured before the weight
        update) as a Python float.
    """
    # Ensure train-mode behaviour; sampling helpers may have left the model in eval mode.
    model.train()

    # Step 1: Forward pass - make predictions
    logits = model(inputs)
    print("1️⃣ Forward pass: Made predictions ✓")

    # Step 2: Calculate loss - measure mistakes.
    # Class count comes from the logits themselves (not a global), and reshape
    # also handles non-contiguous tensors.
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    print(f"2️⃣ Loss calculation: {loss.item():.3f} (lower = better) ✓")

    # Step 3: Backward pass - calculate adjustments
    optimizer.zero_grad()
    loss.backward()
    print("3️⃣ Backward pass: Calculated gradients ✓")

    # Step 4: Update weights - apply adjustments
    optimizer.step()
    print("4️⃣ Weight update: Applied changes ✓\n")

    return loss.item()

# Run three narrated training steps and record each step's loss.
print("🚀 Training steps demonstration:")
losses = []
for step in range(3):
    print(f"Training Step {step + 1}:")
    loss = training_step(model, inputs, targets, optimizer)
    losses.append(loss)

print("📊 Learning progress:")
for step_no, step_loss in enumerate(losses, start=1):
    if step_no == 1:
        trend = "🏁 Baseline"  # nothing earlier to compare against
    elif step_loss < losses[step_no - 2]:
        trend = "📉 Improving!"
    else:
        trend = "📈"
    print(f"Step {step_no}: Loss = {step_loss:.3f} {trend}")

# Only claim improvement when the recorded loss actually went down.
if len(losses) > 1 and losses[-1] < losses[0]:
    print("\n✅ Training loop complete - model got smarter!")
else:
    print("\n✅ Training loop complete - loss has not improved yet")

In [None]:
## Extended Training Session

Train for more steps and watch the transformer learn to generate text.

def generate_text(model, char_to_idx, idx_to_char, start_text="The", max_length=10):
    """Sample a continuation of ``start_text`` to inspect what the model has learned.

    Characters are drawn one at a time by multinomial sampling from the softmax
    of the model's logits at the last position, so the output is random unless
    the torch RNG is seeded. The whole sequence so far is fed back each step.

    Args:
        model: module mapping a (1, seq_len) tensor of ids to (1, seq_len, V) logits.
        char_to_idx: character -> id mapping; prompt characters missing from it
            are silently encoded as id 0.
        idx_to_char: id -> character mapping; ids missing from it render as '?'.
        start_text: prompt to continue. An empty prompt yields "".
        max_length: total length of the returned string, prompt included. If it
            is not larger than the prompt, the prompt is returned unchanged.

    Returns:
        The prompt followed by the sampled characters.

    The model's previous train/eval mode is restored before returning, so
    calling this inside a training loop does not leave the model in eval mode.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            current = [char_to_idx.get(ch, 0) for ch in start_text]
            if not current:
                return ""

            for _ in range(max_length - len(start_text)):
                x = torch.tensor(current).unsqueeze(0)
                next_char_logits = model(x)[0, -1, :]

                # Sample next character
                probs = F.softmax(next_char_logits, dim=-1)
                current.append(torch.multinomial(probs, 1).item())

        return ''.join(idx_to_char.get(i, '?') for i in current)
    finally:
        model.train(was_training)

idx_to_char = {i: ch for ch, i in char_to_idx.items()}

# Extended training with monitoring
losses = []
samples = []

print("🔬 Watching transformer learn step by step:")
print("Step | Loss  | Generated Sample")
print("-" * 40)

for step in range(50):
    # generate_text() below puts the model into eval mode to sample; switch
    # back at the start of every step so no update is computed in eval mode
    # (this matters as soon as the model has dropout or similar layers).
    model.train()

    # Training step
    optimizer.zero_grad()
    logits = model(inputs)
    # Class count taken from the logits, not the global vocab_size.
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    loss.backward()
    optimizer.step()

    step_loss = loss.item()
    losses.append(step_loss)

    # Generate sample every 10 steps to see progress
    if step % 10 == 0:
        sample = generate_text(model, char_to_idx, idx_to_char, "The", 15)
        samples.append((step, sample))
        print(f"{step:4d} | {step_loss:.3f} | '{sample}'")

# Visualize learning curve
plt.figure(figsize=(10, 6))
plt.plot(losses, 'b-', linewidth=2, alpha=0.8)
plt.title('🧠 Transformer Learning Progress')
plt.xlabel('Training Step')
plt.ylabel('Loss (Lower = Better)')
plt.grid(True, alpha=0.3)

# Add annotations to show learning phases
plt.annotate('🤔 Random guessing', xy=(0, losses[0]),
             xytext=(5, losses[0] + 0.2), fontsize=12,
             arrowprops=dict(arrowstyle='->', color='red'))

plt.annotate('🎯 Learning patterns!', xy=(len(losses)-1, losses[-1]),
             xytext=(len(losses)-10, losses[-1] + 0.2), fontsize=12,
             arrowprops=dict(arrowstyle='->', color='green'))

plt.show()

print("\n🎭 Generated samples over time:")
for step, sample in samples:
    interpretation = "Gibberish" if step == 0 else "Better" if step < 30 else "Much improved!"
    print(f"Step {step:2d}: '{sample}' ({interpretation})")

print(f"\n✨ Loss improved from {losses[0]:.3f} to {losses[-1]:.3f}")
print("The transformer learned to complete sentences!")

In [None]:
## When to Stop Training

Learn to recognize the signals that training is complete.

# Simulated training run: exponential decay toward a floor of 1.0 plus a
# damped oscillation, standing in for a real (noisy) loss curve.
steps = np.arange(100)
_decay = 3.0 * np.exp(-steps/20)
_wobble = 0.1 * np.sin(steps/5) * np.exp(-steps/30)
loss_curve = _decay + 1.0 + _wobble

# Hand-written (step, sample) pairs illustrating output quality over time.
sample_progression = [
    (0, "Random gibberish"),
    (20, "The cat dog tree"),
    (40, "The cat sat on"),
    (60, "The cat sat on the mat"),
    (80, "The cat sat on the mat."),
    (99, "The cat sat on the mat.")
]

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Top panel: the loss curve flattening out.
ax1.plot(steps, loss_curve, 'b-', linewidth=2)
ax1.set(title='🎯 Signal 1: Loss Converges (Stops Decreasing)',
        xlabel='Training Step', ylabel='Loss')
ax1.grid(True, alpha=0.3)

for _label, _at, _text_x, _lift, _arrow_colour in (
    ('📉 Rapid learning', 20, 30, 0.5, 'red'),
    ('🎯 Converged - ready to stop!', 80, 60, 0.3, 'green'),
):
    ax1.annotate(_label, xy=(_at, loss_curve[_at]),
                 xytext=(_text_x, loss_curve[_at] + _lift),
                 arrowprops=dict(arrowstyle='->', color=_arrow_colour))

# Bottom panel: one row per sample, coloured by quality tier.
ax2.set_xlim(0, 100)
ax2.set_ylim(-0.5, len(sample_progression) - 0.5)
ax2.set(title='🎭 Signal 2: Generated Text Quality Improves', xlabel='Training Step')

for rank, (step, sample) in enumerate(sample_progression):
    if rank >= 4:
        colour = 'green'
    elif rank >= 2:
        colour = 'orange'
    else:
        colour = 'red'
    ax2.scatter(step, rank, s=100, color=colour, alpha=0.7)
    ax2.text(step + 2, rank, f'"{sample}"', fontsize=10,
             verticalalignment='center', color=colour, fontweight='bold')

ax2.set_yticks(range(len(sample_progression)))
ax2.set_yticklabels(['Terrible', 'Bad', 'Getting there', 'Good', 'Great!', 'Perfect'])
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

for _line in (
    "🛑 Simple stopping rules:",
    "1. 📉 LOSS RULE: Stop when loss plateaus for 20+ steps",
    "2. 🎯 QUALITY RULE: Stop when generated text looks reasonable",
    "3. ⏰ TIME RULE: Stop after reasonable training time",
    "4. 💾 SAFETY RULE: Save best model, stop if no improvement",
    "\n⚠️ Warning: Don't overtrain!",
    "Overtraining = memorizing instead of understanding",
    "Like cramming vs learning - works on test, fails in real world",
):
    print(_line)

## Summary

You now understand how transformers learn!

**The Training Recipe**:
1. **Forward pass**: Model makes predictions
2. **Loss calculation**: Measure prediction errors
3. **Backward pass**: Calculate weight adjustments  
4. **Weight update**: Apply adjustments and repeat

**Key Insights**:
- **Training = Practice**: Like learning piano through repetition
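- **Loss = Confusion**: Lower loss means better predictions — as long as it is measured on held-out text, because training loss alone also drops when the model merely memorizes
- **Learning = Adjustment**: Many small weight changes add up to learned patterns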
- **Stopping = Recognition**: Know when the model has learned enough

**What's Next**: The same core loop is how ChatGPT, GPT-4, and other large language models are pretrained - just at massive scale, with billions of parameters (followed by additional fine-tuning stages)!

Ready to explore text generation techniques! 🚀

In [None]:
def demonstrate_overfitting():
    """Build and print a tiny dataset that invites overfitting, then list the warning signs.

    Nothing is trained here; the function tokenises a two-sentence text, wraps
    it in a dataset/dataloader, prints the first three batches decoded back to
    text, and prints common symptoms of overfitting.

    Returns:
        ``(dataset, dataloader)`` built from the tiny text.
    """
    # NOTE(review): create_tokenizer, SimpleTextDataset and create_dataloader
    # are not defined or imported anywhere in this file. Presumably they come
    # from the project package put on sys.path via '..'; add the missing import.

    print("Overfitting in Language Models")
    print("=" * 35)

    # Create a tiny dataset to encourage overfitting
    tiny_text = "The cat sat on the mat. The cat sat on the hat."
    tokenizer = create_tokenizer("simple")

    # Encode the text
    tokens = tokenizer.encode(tiny_text, add_special_tokens=False)
    print(f"Training text: '{tiny_text}'")
    print(f"Tokens: {tokens}")
    print(f"Unique tokens: {len(set(tokens))}")

    # Create dataset
    dataset = SimpleTextDataset(tiny_text, tokenizer, block_size=8)
    dataloader = create_dataloader(dataset, batch_size=2, shuffle=False)

    print(f"Training samples: {len(dataset)}")

    # Show samples. Assumes the loader yields (input_ids, target_ids) pairs with
    # the batch on dim 0 — TODO confirm against the dataset implementation.
    print("\nTraining samples:")
    for i, (input_ids, target_ids) in enumerate(dataloader):
        if i >= 3:  # Only the first 3 batches are shown; don't walk the rest.
            break
        print(f"Batch {i}: input shape {input_ids.shape}")
        for j in range(input_ids.shape[0]):
            input_text = tokenizer.decode(input_ids[j].tolist(), skip_special_tokens=True)
            target_text = tokenizer.decode(target_ids[j].tolist(), skip_special_tokens=True)
            print(f"  Input:  '{input_text}'")
            print(f"  Target: '{target_text}'")

    print("\nSigns of Overfitting:")
    print("• Training loss goes to zero but validation loss increases")
    print("• Model memorizes training data exactly")
    print("• Poor generalization to new text")
    print("• Generated text becomes repetitive or nonsensical")

    return dataset, dataloader

tiny_dataset, tiny_dataloader = demonstrate_overfitting()

def demonstrate_regularization_techniques():
    """Print common transformer regularisation methods and visualise dropout.

    The dropout figure shows the same all-ones 4x8 activation tensor three
    times: unchanged, after nn.Dropout in training mode, and after nn.Dropout
    in eval mode.
    """

    print("\nRegularization Techniques for Transformers")
    print("=" * 45)

    techniques = {
        "Dropout": {
            "description": "Randomly zero out neurons during training",
            "implementation": "nn.Dropout(p=0.1) in attention and FFN",
            "effect": "Prevents co-adaptation of neurons"
        },
        "Weight Decay": {
            "description": "Add L2 penalty to weights",
            "implementation": "weight_decay=0.01 in optimizer",
            "effect": "Keeps weights small, improves generalization"
        },
        "Gradient Clipping": {
            "description": "Limit gradient magnitude",
            "implementation": "clip_grad_norm_(params, max_norm=1.0)",
            "effect": "Prevents exploding gradients"
        },
        "Early Stopping": {
            "description": "Stop when validation loss stops improving",
            "implementation": "Monitor validation loss, save best model",
            "effect": "Prevents overfitting to training data"
        },
        "Data Augmentation": {
            "description": "Increase effective dataset size",
            "implementation": "Paraphrasing, back-translation, masking",
            "effect": "More diverse training examples"
        }
    }

    for technique, info in techniques.items():
        print(f"\n{technique}:")
        print(f"  Description: {info['description']}")
        print(f"  Implementation: {info['implementation']}")
        print(f"  Effect: {info['effect']}")

    # Visualize dropout effect
    print("\nDropout Visualization:")

    # Simulate dropout on a tensor
    x = torch.ones(4, 8)  # 4x8 tensor of ones
    dropout = nn.Dropout(p=0.3)

    # In training mode nn.Dropout scales surviving values by 1/(1-p). Use one
    # colour scale for all panels: with per-panel autoscaling a constant
    # all-ones image is mapped to the lowest colour and looks fully dropped.
    colour_scale = dict(cmap='Blues', vmin=0.0, vmax=1.0 / (1.0 - dropout.p))

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Original
    axes[0].imshow(x.numpy(), **colour_scale)
    axes[0].set_title('Original Activations')
    axes[0].set_xlabel('Feature')
    axes[0].set_ylabel('Example')

    # With dropout (training mode)
    dropout.train()
    x_dropout = dropout(x)
    axes[1].imshow(x_dropout.numpy(), **colour_scale)
    axes[1].set_title('With Dropout (Training)')
    axes[1].set_xlabel('Feature')

    # Without dropout (eval mode): identity
    dropout.eval()
    x_eval = dropout(x)
    axes[2].imshow(x_eval.numpy(), **colour_scale)
    axes[2].set_title('Without Dropout (Evaluation)')
    axes[2].set_xlabel('Feature')

    plt.tight_layout()
    plt.show()

    print("Notice how dropout randomly zeros neurons during training but not evaluation!")

demonstrate_regularization_techniques()