# Training Your First Transformer: Like Learning to Cook

Training a transformer is like learning to cook. You start with a recipe (model), try it out, see how it tastes (check the loss), and then adjust your technique to make it better next time.

## What You'll Learn

1. **What is Training?** - Think of it like practicing piano - getting better with each attempt
2. **The Training Loop** - Make a prediction, see how wrong you are, adjust, repeat
3. **Watching Learning Happen** - See your model get smarter in real-time
4. **When to Stop** - Know when your model has learned enough

By the end, you'll watch your transformer learn to predict text and understand exactly what's happening at each step!

Let's start cooking! 👩‍🍳

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Dict
import time
from tqdm import tqdm
import math

# Set style for better plots
plt.style.use('default')
sns.set_palette("husl")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

print("Environment setup complete!")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

## 1. What is Training? 🎯

**Training = Practice Makes Perfect**

Imagine you're learning to play piano. At first, you hit wrong notes. But with each practice session:
1. You play a song (make a prediction)
2. Your teacher says "that note was wrong" (calculate loss)
3. You adjust your finger placement (update weights)
4. You try again and get better!

**For transformers, it's the same:**
- The transformer predicts the next word
- We check if it's right or wrong (that's the "loss")
- We nudge the model to be slightly better
- Repeat thousands of times until it's good!
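
To make that loop concrete before any transformer is involved, here is a minimal sketch on a made-up problem: a "model" with a single weight `w` that should learn `y = 3 * x`. The data point, the learning rate, and the name `train_toy_model` are illustrative assumptions, and the gradient is written out by hand rather than computed by PyTorch.

```python
def train_toy_model(steps=20, lr=0.1):
    """Fit prediction = w * x to one example using plain gradient descent."""
    w = 0.0                       # start knowing nothing
    x, y_true = 2.0, 6.0          # one made-up training example (y = 3 * x)
    history = []
    for _ in range(steps):
        y_pred = w * x                        # 1. make a prediction
        loss = (y_pred - y_true) ** 2         # 2. measure how wrong it is
        grad = 2 * (y_pred - y_true) * x      # 3. slope of the loss with respect to w
        w -= lr * grad                        #    nudge w against that slope
        history.append(loss)                  # 4. repeat
    return w, history

w, history = train_toy_model()
print(f"learned w = {w:.4f}")
print(f"first loss = {history[0]:.3f}, last loss = {history[-1]:.6f}")
```

A transformer does the same thing, only with many more weights and with the gradients computed automatically by `loss.backward()`.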

Let's see this in action with the simplest possible example:

In [None]:
def demonstrate_learning_concept():
    """Show learning in the simplest possible way."""
    
    print("🎹 LEARNING PIANO ANALOGY")
    print("=" * 30)
    
    # Simulate "learning" with simple numbers
    attempts = ["Wrong note!", "Better!", "Almost!", "Perfect!"]
    scores = [8.5, 6.2, 3.1, 1.0]  # Lower is better (like loss)
    
    print("Piano practice sessions:")
    for i, (attempt, score) in enumerate(zip(attempts, scores)):
        print(f"Practice {i+1}: {attempt:<12} (Mistake level: {score})")
    
    print(f"\nSee how the mistake level goes down? That's learning!")
    print(f"Session 1: {scores[0]} mistakes → Session 4: {scores[-1]} mistakes")
    
    print("\n🤖 TRANSFORMER LEARNING (SAME IDEA)")
    print("=" * 40)
    
    sentences = [
        "The cat sat on the ???",  # What should come next?
        "The cat sat on the mat",  # This is the right answer!
    ]
    
    predictions = ["dog", "table", "floor", "mat"]  # Model guesses
    confidence = [20, 45, 60, 95]  # How sure the model is
    
    print("Transformer trying to predict next word:")
    print(f"Sentence: '{sentences[0]}'")
    print(f"Correct answer: '{sentences[1].split()[-1]}'")
    print()
    
    for i, (pred, conf) in enumerate(zip(predictions, confidence)):
        status = "✅ CORRECT!" if pred == "mat" else "❌ Wrong"
        print(f"Attempt {i+1}: Predicts '{pred}' ({conf}% sure) {status}")
    
    print(f"\n💡 KEY INSIGHT: Each wrong guess teaches the model!")
    print(f"   The model learns: 'When I see \"on the\", try \"mat\" not \"dog\"'")

demonstrate_learning_concept()

## 2. The Training Loop - Just 4 Simple Steps! 

**Think of it like this simple recipe:**
1. 🍳 **Make a prediction** (like cracking an egg)
2. 🧑‍🍳 **Check how wrong you are** (taste it - too salty?)
3. ⚡ **Adjust slightly** (add less salt next time)
4. 🔄 **Repeat** until it tastes perfect!
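
Step 2, "check how wrong you are", is usually measured with cross-entropy: the loss for one prediction is the negative log of the probability the model gave to the correct word. The sketch below works that out by hand for a hypothetical four-word vocabulary (the words and scores are invented for illustration). `F.cross_entropy` in PyTorch computes the same quantity and, by default, averages it over all positions.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)                            # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target_index):
    """Loss for one prediction: -log(probability given to the right answer)."""
    probs = softmax(logits)
    return -math.log(probs[target_index])

# Hypothetical 4-word vocabulary, used only for this illustration
vocab = ["dog", "table", "floor", "mat"]
target = vocab.index("mat")

confident_right = cross_entropy([0.0, 0.0, 0.0, 5.0], target)
unsure = cross_entropy([0.0, 0.0, 0.0, 0.0], target)
confident_wrong = cross_entropy([5.0, 0.0, 0.0, 0.0], target)

print(f"confident and right: {confident_right:.3f}")
print(f"no idea (uniform):   {unsure:.3f}")
print(f"confident and wrong: {confident_wrong:.3f}")
```

Being confidently wrong costs far more than being unsure, which is exactly the pressure that pushes the model toward better guesses.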

Let's see the actual code for these 4 steps:

In [None]:
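# Setup: Create the simplest possible transformer for demonstration
# NOTE: GPTModel, create_model_config, and create_tokenizer are this project's own
# helpers. Import them from the project's src package before running this cell.

# Get a tiny model (easier to understand)
config = create_model_config("tiny")  # Very small for demo
model = GPTModel(**config)

# Tokenizer used by the cells below (same call the later cells make)
tokenizer = create_tokenizer("simple")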

print("🎯 TRAINING SETUP")
print("=" * 20)
print("Our tiny transformer:")
print(f"• Has {model.get_num_params():,} parameters to learn")
print(f"• Knows {config['vocab_size']:,} different words")
print(f"• Can remember {config['max_len']} words at once")
print()

# Simple training data - just one sentence!
text = "The cat sat on the mat"
tokens = tokenizer.encode(text, add_special_tokens=False)
inputs = torch.tensor(tokens[:-1]).unsqueeze(0)   # All but last token
targets = torch.tensor(tokens[1:]).unsqueeze(0)   # All but first token

print(f"📝 Training text: '{text}'")
print("Model will learn to predict each next word:")
for i in range(len(tokens)-1):
    input_word = tokenizer.decode([tokens[i]])
    target_word = tokenizer.decode([tokens[i+1]])
    print(f"   '{input_word}' → predict '{target_word}'")
print()

def simple_training_step(model, inputs, targets, optimizer):
    """The 4 steps of training - clearly separated!"""
    
    print("🔄 TRAINING STEP")
    print("-" * 15)
    
    # Step 1: Make a prediction 🍳
    logits, _ = model(inputs)  # Model outputs a raw score (logit) for every possible next word
    print("1️⃣ Made prediction ✓")
    
    # Step 2: Check how wrong we are 🧑‍🍳  
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    print(f"2️⃣ Calculated loss: {loss.item():.3f} (lower = better)")
    
    # Step 3: Adjust slightly ⚡
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("3️⃣ Updated model weights ✓")
    
    # Step 4 happens when we call this function again! 🔄
    print("4️⃣ Ready for next step!")
    print()
    
    return loss.item()

# Create optimizer (the thing that adjusts the model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print("🚀 Let's do 3 training steps and watch the loss go down!")
print("=" * 50)

# Do just 3 steps so we can see each one clearly
losses = []
for step in range(3):
    print(f"\n📍 STEP {step + 1}/3")
    loss = simple_training_step(model, inputs, targets, optimizer)
    losses.append(loss)
    
    # Show progress
    if step == 0:
        print("👶 First step - model is just guessing randomly")
    elif step == 1:
        print("🧠 Second step - model starts to learn patterns")
    else:
        print("🎯 Third step - model is getting better!")

print(f"\n📊 PROGRESS SUMMARY")
print("=" * 20)
for i, loss in enumerate(losses):
    trend = ""
    if i > 0:
        if loss < losses[i-1]:
            trend = "📉 (Getting better!)"
        else:
            trend = "📈 (Still learning...)"
    print(f"Step {i+1}: Loss = {loss:.3f} {trend}")

print(f"\n🎉 Did the loss go down? That means learning happened!")
if losses[-1] < losses[0]:
    print(f"✅ YES! From {losses[0]:.3f} to {losses[-1]:.3f} - That's improvement!")
else:
    print(f"⏳ Not quite yet, but that's normal! Training takes many steps.")
    
print("\n💡 KEY INSIGHT:")
print("This is exactly what happens during real transformer training,")
print("but repeated millions of times with millions of text examples!")

## 3. Watching Learning Happen - The Magic Moment! ✨

**This is where the magic happens!**

We've seen the 4 simple steps. Now let's repeat them 20 times on a slightly larger dataset and watch the loss fall as our transformer starts to pick up patterns. It's like watching a baby learn to talk.

**What to expect:**
- Loss will start high (model is confused) 
- Loss drops rapidly ("Aha!" moments)
- Generated text gets more coherent
- Model learns patterns in the language
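
How high is "high"? A handy reference point, sketched below, is the loss of a model that spreads its probability evenly over the whole vocabulary: that is `ln(vocab_size)`. A freshly initialized model usually starts near this value, though the exact number depends on initialization. The helper names and the vocabulary sizes are illustrative assumptions, not values read from this notebook's config.

```python
import math

def uniform_baseline_loss(vocab_size):
    """Cross-entropy of a model that gives every word the same probability."""
    return math.log(vocab_size)

def perplexity(loss):
    """exp(loss): roughly how many words the model is still choosing between."""
    return math.exp(loss)

for vocab_size in (200, 50_000):   # illustrative sizes only
    baseline = uniform_baseline_loss(vocab_size)
    print(f"vocab={vocab_size:>6}: starting loss ≈ {baseline:.2f}, "
          f"perplexity ≈ {perplexity(baseline):.0f}")
```

If your first logged loss is far above this baseline, something other than "the model is confused" is probably going on, such as a vocabulary mismatch or a bug in the target shift.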

In [None]:
# Let's create more training data for more realistic learning
training_text = """
The cat sat on the mat.
The dog ran in the park.
The bird flew over the tree.
The fish swam in the pond.
The sun shines in the sky.
The moon glows at night.
The children play outside.
The flowers bloom in spring.
"""

# Create dataset and dataloader
from src.data.dataset import SimpleTextDataset, create_dataloader
dataset = SimpleTextDataset(training_text, tokenizer, block_size=8)
dataloader = create_dataloader(dataset, batch_size=2, shuffle=True)

print("🎓 LEARNING SETUP")
print("=" * 20)
print(f"Training data: {len(training_text)} characters")
print(f"Training samples: {len(dataset)} examples")
print("Sample sentences the model will learn from:")
for line in training_text.strip().split('\n'):
    if line.strip():
        print(f"  • {line.strip()}")

def watch_learning_happen(model, dataloader, optimizer, num_steps=20):
    """Watch the model learn step by step!"""
    
    print("\n🔬 WATCHING LEARNING HAPPEN")
    print("=" * 30)
    
    losses = []
    samples = []
    
    # Test before training
    model.eval()
    with torch.no_grad():
        sample_input = torch.tensor(tokenizer.encode("The cat", add_special_tokens=False)[:2]).unsqueeze(0)
        sample_before = model.generate(sample_input, max_new_tokens=8, temperature=1.0, do_sample=True)
        sample_before_text = tokenizer.decode(sample_before[0].tolist(), skip_special_tokens=True)
    
    print(f"Before training: 'The cat' → '{sample_before_text}'")
    print("(Probably nonsense - model doesn't know anything yet!)\n")
    
    # Train and observe
    step = 0
    data_iter = iter(dataloader)
    
    for _ in range(num_steps):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)  # Reset iterator
            batch = next(data_iter)
        
        # Training step
        model.train()
        input_ids, target_ids = batch
        
        optimizer.zero_grad()
        logits, _ = model(input_ids)
        
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_ids.view(-1))
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
        
        # Generate sample every 5 steps
        if step % 5 == 0:
            model.eval()
            with torch.no_grad():
                sample = model.generate(sample_input, max_new_tokens=8, temperature=1.0, do_sample=True)
                sample_text = tokenizer.decode(sample[0].tolist(), skip_special_tokens=True)
                samples.append((step, sample_text))
                
                # Show progress with interpretation
                if step == 0:
                    interpretation = "🤔 Still very confused"
                elif step <= 5:
                    interpretation = "🧠 Starting to recognize patterns"
                elif step <= 10:
                    interpretation = "💡 Learning word relationships"
                else:
                    interpretation = "🎯 Getting much better!"
                    
                print(f"Step {step:2d}: Loss={loss.item():.3f} | Sample: '{sample_text}' | {interpretation}")
        
        step += 1
    
    # Final comparison
    print("\n📊 LEARNING PROGRESS SUMMARY")
    print("=" * 30)
    print(f"Initial loss: {losses[0]:.3f}")
    print(f"Final loss:   {losses[-1]:.3f}")
    print(f"Change:       {losses[0] - losses[-1]:+.3f} (positive = improvement; single batches are noisy)")
    
    print("\n🎭 GENERATION SAMPLES OVER TIME:")
    for step, sample in samples:
        print(f"Step {step:2d}: '{sample}'")
    
    print("\n✨ Look for words from the training sentences starting to show up.")
    print("After only a few steps the samples may still be rough. That's normal!")
    
    return losses, samples

# Watch the magic happen!
losses, samples = watch_learning_happen(model, dataloader, optimizer, num_steps=20)

# Plot the learning curve
plt.figure(figsize=(10, 6))
plt.plot(losses, 'b-', linewidth=2, alpha=0.8)
plt.title('🧠 Watching the Transformer Get Smarter', fontsize=14, fontweight='bold')
plt.xlabel('Training Step')
plt.ylabel('Loss (Confusion Level)')
plt.grid(True, alpha=0.3)

# Add annotations
if len(losses) > 5:
    plt.annotate('🤔 "What is language?"', xy=(0, losses[0]), 
                xytext=(2, losses[0] + 0.5), fontsize=12,
                arrowprops=dict(arrowstyle='->', color='red', alpha=0.7))
                
    plt.annotate('🎯 "I\'m getting this!"', xy=(len(losses)-1, losses[-1]), 
                xytext=(len(losses)-5, losses[-1] + 0.3), fontsize=12,
                arrowprops=dict(arrowstyle='->', color='green', alpha=0.7))

plt.show()

print("\n🎉 Congratulations! You just watched a transformer learn!")
print("This next-word prediction loop is the core of how models like ChatGPT and GPT-4 are pretrained.")
print("The main differences: vastly more data and compute, plus extra fine-tuning afterwards!")

## 4. When to Stop - "My Model is Trained!" 🏁

**How do you know when your transformer is done learning?**

Think of it like learning to drive:
- **Beginner**: Can't even start the car (high loss, bad predictions)
- **Learning**: Gets better every lesson (loss going down)
- **Competent**: Can drive safely (good loss, sensible predictions)
- **Expert**: Could teach others (very low loss)

**The 3 signals that training is done:**
1. **Loss stops improving** - "I'm not getting better anymore"
2. **Generated text looks good** - "My outputs make sense!"
3. **Validation loss levels off** - "I'm not just memorizing"
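
Signals 1 and 3 can be turned into a simple automatic check. The sketch below is one common way to do it (often called early stopping with patience), assuming you record a validation loss at regular intervals. The function name `should_stop`, its parameters, and the loss values are all made up for illustration.

```python
def should_stop(val_losses, patience=3, min_delta=0.0):
    """Return True once the last `patience` checks failed to beat the earlier best."""
    if len(val_losses) <= patience:
        return False                                  # not enough history yet
    best_before = min(val_losses[:-patience])         # best loss before the recent window
    recent_best = min(val_losses[-patience:])         # best loss inside the window
    return recent_best > best_before - min_delta      # no meaningful improvement

# Made-up validation losses: improving at first, then levelling off
simulated = [3.0, 2.5, 2.2, 2.1, 2.15, 2.12, 2.2]

history = []
stopped_at = None
for check, val_loss in enumerate(simulated):
    history.append(val_loss)
    if should_stop(history, patience=3):
        stopped_at = check
        break

print(f"stopped at check {stopped_at}, best validation loss {min(history)}")
```

In a real run you would also save the model whenever a new best validation loss appears, so that stopping late costs nothing.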

In [None]:
# Let's see what "done training" looks like
def show_training_completion_signals():
    """Show the 3 key signals that training is complete."""
    
    print("🏁 RECOGNIZING WHEN TRAINING IS DONE")
    print("=" * 40)
    
    # Simulate a complete training run
    steps = np.arange(100)
    
    # 1. Loss curve that levels off
    loss_curve = 3.0 * np.exp(-steps/20) + 1.0 + 0.1 * np.sin(steps/5) * np.exp(-steps/30)
    
    # 2. Sample quality over time
    sample_quality = [
        (0, "ghjk qwerty random"),      # Random noise
        (20, "The cat dog tree"),       # Some words
        (40, "The cat sat on"),         # Getting structure
        (60, "The cat sat on the mat"), # Good!
        (80, "The cat sat on the mat."), # Perfect!
        (99, "The cat sat on the mat.") # Stable
    ]
    
    # Plot training completion
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
    
    # Loss curve
    ax1.plot(steps, loss_curve, 'b-', linewidth=2, alpha=0.8)
    ax1.set_title('🎯 Signal 1: Loss Stops Improving', fontweight='bold')
    ax1.set_xlabel('Training Step')
    ax1.set_ylabel('Loss')
    ax1.grid(True, alpha=0.3)
    
    # Annotate key points
    ax1.annotate('📉 Rapid learning', xy=(20, loss_curve[20]), 
                xytext=(30, loss_curve[20] + 0.5), fontsize=11,
                arrowprops=dict(arrowstyle='->', color='red'))
    ax1.annotate('🎯 Converged!', xy=(80, loss_curve[80]), 
                xytext=(60, loss_curve[80] + 0.3), fontsize=11,
                arrowprops=dict(arrowstyle='->', color='green'))
    
    # Sample quality timeline
    ax2.set_xlim(0, 100)
    ax2.set_ylim(-0.5, len(sample_quality) - 0.5)
    ax2.set_title('🎭 Signal 2: Generated Text Gets Good', fontweight='bold')
    ax2.set_xlabel('Training Step')
    ax2.set_ylabel('Sample Quality')
    
    for i, (step, sample) in enumerate(sample_quality):
        color = 'red' if i < 2 else 'orange' if i < 4 else 'green'
        ax2.scatter(step, i, s=100, color=color, alpha=0.7)
        ax2.text(step + 2, i, f'"{sample}"', fontsize=10, 
                verticalalignment='center', color=color, fontweight='bold')
    
    ax2.set_yticks(range(len(sample_quality)))
    ax2.set_yticklabels(['Terrible', 'Bad', 'Getting there', 'Good', 'Great!', 'Perfect'])
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("🔍 HOW TO RECOGNIZE EACH SIGNAL:")
    print("-" * 35)
    print("1. 📉 LOSS STOPS IMPROVING:")
    print("   • Loss curve flattens out")
    print("   • No significant decrease for many steps")
    print("   • Loss might slightly oscillate but doesn't go down")
    print()
    print("2. 🎭 GENERATED TEXT LOOKS GOOD:")
    print("   • Output makes grammatical sense")
    print("   • Follows patterns from training data")
    print("   • No more random gibberish")
    print()
    print("3. 🚧 VALIDATION LOSS LEVELS OFF:")
    print("   • Test on unseen data")
    print("   • If validation loss stops decreasing: you're done!")
    print("   • If validation loss increases: you're overfitting!")
    
    print("\n⚠️  IMPORTANT WARNING:")
    print("Don't train too long! Overtraining makes models worse.")
    print("It's like studying for a test by memorizing instead of understanding.")

show_training_completion_signals()

def demonstrate_simple_stopping_criteria():
    """Show simple rules for when to stop training."""
    
    print("\n🛑 SIMPLE STOPPING RULES")
    print("=" * 25)
    
    print("For beginners, use these simple rules:")
    print()
    print("1. 🕐 TIME RULE:")
    print("   • Stop after X hours of training")
    print("   • Good for experiments and learning")
    print()
    print("2. 📉 LOSS RULE:")
    print("   • Stop when loss < some threshold (like 2.0)")
    print("   • Stop when loss stops decreasing for 50+ steps")
    print()
    print("3. 🎯 GENERATION RULE:")
    print("   • Stop when generated samples look reasonable")
    print("   • Test generation every 100 steps")
    print()
    print("4. 💾 PATIENCE RULE:")
    print("   • Save model every time validation improves")
    print("   • Stop if no improvement for 10 saves")
    print("   • Use the best saved model")
    
    print("\n💡 PRO TIP:")
    print("Always save your model regularly!")
    print("You never know when training might crash or when you've hit the sweet spot.")

demonstrate_simple_stopping_criteria()

## Summary: You Now Know How Transformers Learn! 🎓

**Congratulations!** You've just learned the core of how GPT-style language models are trained:

**The Magic Recipe:**
1. **Show the model some text** ("The cat sat on the")
2. **Ask it to predict the next word** ("mat")
3. **Tell it how wrong it was** (loss = wrongness)
4. **Nudge it to be slightly better** (adjust weights)
5. **Repeat millions of times** (practice makes perfect!)
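
One detail in steps 1 and 2 is easy to miss: a single sentence yields one training example per position, and each prediction is made from everything seen so far, not only the previous word. A minimal sketch, using whitespace splitting as a stand-in for a real tokenizer (an assumption made for brevity) and a hypothetical helper name:

```python
def prefix_targets(tokens):
    """For each position, return (everything seen so far, the token to predict)."""
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]

# Whitespace splitting stands in for a real tokenizer here
tokens = "The cat sat on the mat".split()

examples = prefix_targets(tokens)
for context, target in examples:
    print(f"{' '.join(context)!r:<24} -> predict {target!r}")
```

A transformer computes all of these predictions in one forward pass, which is why the inputs and targets in the training code are the same sequence shifted by one position.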

That's it! This simple process is the foundation of ChatGPT, GPT-4, and other GPT-style transformers, which then get extra fine-tuning on top.

In [None]:
def summarize_what_we_learned():
    """Recap the key insights about transformer training."""
    
    print("🎓 WHAT YOU NOW UNDERSTAND")
    print("=" * 30)
    
    key_insights = [
        ("🎯", "Training = next word prediction", "The model learns by guessing the next word millions of times"),
        ("📉", "Loss = wrongness level", "Lower loss means better predictions (less confused)"),
        ("🔄", "Learning = small adjustments", "Each mistake teaches the model to be slightly better"),
        ("⏰", "Time scale matters", "Real models train for days/weeks on massive datasets"),
        ("🎭", "Generated text reveals progress", "Watch outputs to see if the model is learning language")
    ]
    
    for emoji, concept, explanation in key_insights:
        print(f"{emoji} {concept:25} → {explanation}")
    
    print("\n🌟 THE BIG PICTURE")
    print("=" * 20)
    print("This simple process is the starting point for the most advanced AI:")
    print("• ChatGPT: Pretrained on large text corpora, then tuned with human feedback")
    print("• GPT-4: Size and architecture not disclosed, massive compute")
    print("• Claude: Advanced training with human feedback")
    print("• Your model: Same principles, smaller scale!")
    
    print("\n🚀 WHAT'S NEXT?")
    print("=" * 15)
    print("Now that you understand training, you can:")
    print("• Scale up: Use bigger models and more data")
    print("• Experiment: Try different architectures")
    print("• Optimize: Improve training efficiency")
    print("• Deploy: Use your trained model for tasks")
    
    print("\n🎉 YOU'RE NOW A TRANSFORMER EXPERT!")
    print("You understand the core of how language AI works.")
    print("The rest is just engineering and scaling!")

summarize_what_we_learned()

# Show one final learning visualization
def final_learning_demonstration():
    """One last simple demo to cement understanding."""
    
    print("\n🔬 FINAL DEMO: LEARNING IN ACTION")
    print("=" * 35)
    
    # Simple demo of loss going down
    steps = [1, 10, 50, 100, 500]
    losses = [10.5, 8.2, 4.1, 2.8, 1.9]
    examples = [
        "ajsdh qwehjk asdf",
        "The dog cat", 
        "The cat sat",
        "The cat sat on the",
        "The cat sat on the mat"
    ]
    
    print("Watch how loss drops as the model gets smarter:")
    print()
    print("Step | Loss | Sample Output        | Model's Thoughts")
    print("-" * 60)
    
    thoughts = [
        "😵 What even is language??",
        "🤔 These are probably words...",
        "💡 I see patterns emerging!",
        "🧠 I'm understanding structure!",
        "🎯 I've got this!"
    ]
    
    for step, loss, example, thought in zip(steps, losses, examples, thoughts):
        print(f"{step:4} | {loss:4.1f} | {example:20} | {thought}")
    
    print("\n✨ That's machine learning in action!")
    print("From confusion to competence through practice.")

final_learning_demonstration()

In [None]:
def demonstrate_overfitting():
    """Build a tiny dataset that is easy to overfit and list the warning signs."""
    
    print("Overfitting in Language Models")
    print("=" * 35)
    
    # Create a tiny dataset to encourage overfitting
    tiny_text = "The cat sat on the mat. The cat sat on the hat."
    tokenizer = create_tokenizer("simple")
    
    # Encode the text
    tokens = tokenizer.encode(tiny_text, add_special_tokens=False)
    print(f"Training text: '{tiny_text}'")
    print(f"Tokens: {tokens}")
    print(f"Unique tokens: {len(set(tokens))}")
    
    # Create dataset
    dataset = SimpleTextDataset(tiny_text, tokenizer, block_size=8)
    dataloader = create_dataloader(dataset, batch_size=2, shuffle=False)
    
    print(f"Training samples: {len(dataset)}")
    
    # Show samples
    print("\nTraining samples:")
    for i, (input_ids, target_ids) in enumerate(dataloader):
        if i < 3:  # Show first 3 batches
            print(f"Batch {i}: input shape {input_ids.shape}")
            for j in range(input_ids.shape[0]):
                input_text = tokenizer.decode(input_ids[j].tolist(), skip_special_tokens=True)
                target_text = tokenizer.decode(target_ids[j].tolist(), skip_special_tokens=True)
                print(f"  Input:  '{input_text}'")
                print(f"  Target: '{target_text}'")
    
    print("\nSigns of Overfitting:")
    print("• Training loss keeps falling while validation loss rises")
    print("• Model memorizes training data exactly")
    print("• Poor generalization to new text")
    print("• Generated text parrots training sentences or becomes repetitive")
    
    return dataset, dataloader

tiny_dataset, tiny_dataloader = demonstrate_overfitting()

def demonstrate_regularization_techniques():
    """Show different regularization methods."""
    
    print("\nRegularization Techniques for Transformers")
    print("=" * 45)
    
    techniques = {
        "Dropout": {
            "description": "Randomly zero out neurons during training",
            "implementation": "nn.Dropout(p=0.1) in attention and FFN",
            "effect": "Prevents co-adaptation of neurons"
        },
        "Weight Decay": {
            "description": "Shrink weights toward zero at every update (similar to an L2 penalty)",
            "implementation": "weight_decay=0.01 in optimizer (applied as decoupled decay in AdamW)",
            "effect": "Keeps weights small, improves generalization"
        },
        "Gradient Clipping": {
            "description": "Limit gradient magnitude",
            "implementation": "clip_grad_norm_(params, max_norm=1.0)",
            "effect": "Prevents exploding gradients"
        },
        "Early Stopping": {
            "description": "Stop when validation loss stops improving",
            "implementation": "Monitor validation loss, save best model",
            "effect": "Prevents overfitting to training data"
        },
        "Data Augmentation": {
            "description": "Increase effective dataset size",
            "implementation": "Paraphrasing, back-translation, masking",
            "effect": "More diverse training examples"
        }
    }
    
    for technique, info in techniques.items():
        print(f"\n{technique}:")
        print(f"  Description: {info['description']}")
        print(f"  Implementation: {info['implementation']}")
        print(f"  Effect: {info['effect']}")
    
    # Visualize dropout effect
    print("\nDropout Visualization:")
    
    # Simulate dropout on a tensor
    x = torch.ones(4, 8)  # 4x8 tensor of ones
    dropout = nn.Dropout(p=0.3)
    
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Original
    axes[0].imshow(x, cmap='Blues')
    axes[0].set_title('Original Activations')
    axes[0].set_xlabel('Feature')
    axes[0].set_ylabel('Example')
    
    # With dropout (training mode)
    dropout.train()
    x_dropout = dropout(x)
    axes[1].imshow(x_dropout, cmap='Blues')
    axes[1].set_title('With Dropout (Training)')
    axes[1].set_xlabel('Feature')
    
    # Without dropout (eval mode)
    dropout.eval()
    x_eval = dropout(x)
    axes[2].imshow(x_eval, cmap='Blues')
    axes[2].set_title('Without Dropout (Evaluation)')
    axes[2].set_xlabel('Feature')
    
    plt.tight_layout()
    plt.show()
    
    print("Notice how dropout randomly zeros neurons during training (and scales the survivors by 1/(1-p)) but leaves everything unchanged in evaluation!")

demonstrate_regularization_techniques()
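
The scaling of the surviving activations is the part of dropout that plots tend to hide. Here is a minimal pure-Python sketch of the idea (usually called inverted dropout), assuming a flat list of activations. It illustrates the behaviour of `nn.Dropout`, not PyTorch's actual implementation, and the function below is hypothetical.

```python
import random

def dropout(values, p=0.3, training=True, rng=random):
    """Inverted dropout on a list of floats (requires 0 <= p < 1)."""
    if not training or p == 0.0:
        return list(values)               # evaluation: pass everything through unchanged
    keep_scale = 1.0 / (1.0 - p)          # scale survivors so the expected value is preserved
    return [0.0 if rng.random() < p else v * keep_scale for v in values]

rng = random.Random(0)                    # fixed seed so the run is repeatable
activations = [1.0] * 10_000
dropped = dropout(activations, p=0.3, training=True, rng=rng)

zero_fraction = dropped.count(0.0) / len(dropped)
mean_after = sum(dropped) / len(dropped)
print(f"zeroed: {zero_fraction:.1%}, mean after dropout: {mean_after:.3f}")
print("eval mode unchanged:", dropout(activations, training=False) == activations)
```

Because of the `1 / (1 - p)` scaling, the average activation stays close to its original value, so nothing needs to be rescaled when dropout is switched off for evaluation.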

In [None]:
def comprehensive_training_monitoring():
    """Demonstrate comprehensive training monitoring."""
    
    print("Comprehensive Training Monitoring")
    print("=" * 40)
    
    # Create a slightly larger model for realistic monitoring
    config = create_model_config("small")
    config["vocab_size"] = 200  # Must be at least the tokenizer's vocabulary size, or embedding lookups will fail
    config["n_layers"] = 3  # Smaller for faster training
    model = GPTModel(**config)
    
    # More comprehensive training text
    training_text = """
    The transformer architecture revolutionized natural language processing.
    Attention mechanisms allow models to focus on relevant parts of the input.
    Large language models demonstrate emergent capabilities at scale.
    Training requires careful optimization and regularization techniques.
    Deep learning continues to advance the field of artificial intelligence.
    Neural networks learn complex patterns from vast amounts of data.
    Machine learning algorithms can generalize to unseen examples.
    The future of AI depends on responsible development and deployment.
    """
    
    tokenizer = create_tokenizer("simple")
    dataset = SimpleTextDataset(training_text, tokenizer, block_size=24)
    dataloader = create_dataloader(dataset, batch_size=3, shuffle=True)
    
    # Move the model to its device before building the optimizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    # Setup optimizer with proper settings
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01, betas=(0.9, 0.95))
    criterion = nn.CrossEntropyLoss()
    
    # Learning rate scheduler
    total_steps = len(dataloader) * 3  # 3 epochs
    scheduler = CosineWarmupScheduler(optimizer, warmup_steps=10, max_steps=total_steps, base_lr=3e-4)
    
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")
    print(f"Dataset: {len(dataset)} samples")
    print(f"Total training steps: {total_steps}")
    
    # Training loop with comprehensive monitoring
    metrics = {
        'steps': [],
        'losses': [],
        'learning_rates': [],
        'grad_norms': [],
        'weight_norms': [],
        'samples': []
    }
    
    step = 0
    
    for epoch in range(3):
        print(f"\nEpoch {epoch + 1}/3")
        
        for batch_idx, batch in enumerate(dataloader):
            # Training step
            model.train()
            input_ids, target_ids = batch
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)
            
            optimizer.zero_grad()
            logits, _ = model(input_ids)
            
            # Calculate loss
            loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
            loss.backward()
            
            # Gradient clipping. clip_grad_norm_ returns the total gradient norm
            # measured before clipping, so it doubles as our gradient monitor.
            total_grad_norm = float(
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            )
            
            optimizer.step()
            current_lr = scheduler.step()
            
            # Monitor weights
            total_weight_norm = 0
            for p in model.parameters():
                total_weight_norm += p.data.norm(2).item() ** 2
            total_weight_norm = total_weight_norm ** 0.5
            
            # Record metrics
            metrics['steps'].append(step)
            metrics['losses'].append(loss.item())
            metrics['learning_rates'].append(current_lr)
            metrics['grad_norms'].append(total_grad_norm)
            metrics['weight_norms'].append(total_weight_norm)
            
            # Generate sample every 10 steps
            if step % 10 == 0:
                sample = generate_sample(model, tokenizer, "The", max_length=10)
                metrics['samples'].append((step, sample))
                print(f"Step {step:3d}: Loss={loss.item():.3f}, LR={current_lr:.2e}, Sample='{sample}'")
            
            step += 1
    
    return metrics

# Run comprehensive training
metrics = comprehensive_training_monitoring()

# Plot comprehensive metrics
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss
axes[0, 0].plot(metrics['steps'], metrics['losses'], 'b-', alpha=0.7)
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training Loss')
axes[0, 0].grid(True, alpha=0.3)

# Learning rate
axes[0, 1].plot(metrics['steps'], metrics['learning_rates'], 'r-', alpha=0.7)
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Learning Rate')
axes[0, 1].set_title('Learning Rate Schedule')
axes[0, 1].grid(True, alpha=0.3)

# Gradient norms
axes[1, 0].plot(metrics['steps'], metrics['grad_norms'], 'g-', alpha=0.7)
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Gradient Norm')
axes[1, 0].set_title('Gradient Norms')
axes[1, 0].grid(True, alpha=0.3)

# Weight norms
axes[1, 1].plot(metrics['steps'], metrics['weight_norms'], 'm-', alpha=0.7)
axes[1, 1].set_xlabel('Step')
axes[1, 1].set_ylabel('Weight Norm')
axes[1, 1].set_title('Weight Norms')
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nTraining Quality Indicators:")
print(f"• Loss decreased from {metrics['losses'][0]:.3f} to {metrics['losses'][-1]:.3f}")
print(f"• Gradient norms: {np.mean(metrics['grad_norms']):.3f} (should be stable, not too large)")
print(f"• Weight norms growing: {metrics['weight_norms'][-1] > metrics['weight_norms'][0]} (expected during training)")
print(f"• Learning rate properly scheduled: {metrics['learning_rates'][0]:.2e} → {metrics['learning_rates'][-1]:.2e}")
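
The learning-rate plot above should show a short ramp-up followed by a smooth decline. The project's `CosineWarmupScheduler` is not shown in this notebook, so the following is only a sketch of the usual shape, assuming linear warmup followed by cosine decay. The parameter names mirror the call above; `min_lr` and the function itself are assumptions.

```python
import math

def cosine_warmup_lr(step, base_lr=3e-4, warmup_steps=10, max_steps=100, min_lr=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay down to `min_lr`."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps         # ramp up from a small value
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    progress = min(progress, 1.0)                           # hold at min_lr past max_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

schedule = [cosine_warmup_lr(s) for s in range(101)]
for s in (0, 9, 10, 55, 100):
    print(f"step {s:3d}: lr = {schedule[s]:.2e}")
```

Warmup keeps the first updates small while the optimizer's running statistics are still unreliable, and the cosine tail lets training settle instead of bouncing around at the end.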