# Debugging Transformer Training

Training transformers is like learning to drive: things can go wrong, but usually in predictable ways. This notebook teaches you to diagnose problems and fix them quickly.

## Why Training Fails

Neural network training usually fails for one of a few fundamental reasons:

1. **Gradient Problems**: Gradients become too large (exploding) or too small (vanishing)
2. **Learning Rate Issues**: Too high causes instability, too low causes stagnation
3. **Data Problems**: Poor quality, wrong size, or insufficient quantity
4. **Architecture Issues**: Model too deep, bad initialization, or wrong configuration

## What You'll Master

1. **Recognize Symptoms**: Learn the warning signs of different failures
2. **Monitor Health**: Track gradient and loss patterns during training
3. **Diagnose Problems**: Systematically identify root causes
4. **Apply Fixes**: Know which fix to try first for each problem

Think of this as becoming a "transformer doctor" - diagnosing symptoms and prescribing cures!

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import warnings
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
import time

from src.model.transformer import GPTModel, create_model_config
from src.data.tokenizer import create_tokenizer
from src.data.dataset import SimpleTextDataset, create_dataloader

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('default')
sns.set_palette("husl")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("Debugging toolkit loaded! 🔧")

## 1. The Big Three Training Failures

### Understanding Why Training Breaks

Most training failures fall into three categories. Understanding the mechanism behind each helps you diagnose problems quickly:

### A. Exploding Gradients
**The Physics**: During backpropagation, the gradient reaching a layer is (by the chain rule) a product of one factor per layer above it. When those factors are consistently larger than 1, the product grows exponentially with depth, like compound interest.

**What Happens**: 
- Loss spikes sharply or becomes infinite
- Model parameters become NaN (Not a Number)
- Training completely breaks

**Root Causes**:
- Learning rate too high (most common)
- Poor weight initialization
- No gradient clipping
- Unstable operations in model
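
The compounding can be seen directly with autograd. This is a toy sketch, not part of the notebook's model: it stacks 20 scalar "layers" that each multiply the signal by the same factor, so by the chain rule the gradient at the input is that factor raised to the 20th power. The helper `input_grad_through_chain` is hypothetical.

```python
import torch

def input_grad_through_chain(scale: float, depth: int) -> float:
    """Backpropagate through `depth` toy layers y = scale * y and return d(output)/d(input)."""
    x = torch.tensor(1.0, requires_grad=True)
    y = x
    for _ in range(depth):
        y = scale * y  # chain rule: each layer multiplies the gradient by `scale`
    y.backward()
    return x.grad.item()

for scale in (1.5, 1.0, 0.5):
    print(f"scale={scale}: gradient after 20 layers = {input_grad_through_chain(scale, 20):.3e}")
```

A factor of 1.5 yields a gradient in the thousands, while 0.5 yields one below 1e-6; that second case is the vanishing-gradient problem described next.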

### B. Vanishing Gradients  
**The Physics**: Gradients become smaller as they travel backward through layers. Eventually they become so small that early layers stop learning entirely.

**What Happens**:
- Loss decreases extremely slowly or plateaus
- Early layers don't learn (parameters barely change)
- Model underperforms despite long training

**Root Causes**:
- Weights initialized too small
- Activation functions that saturate (like sigmoid)
- Too many layers without residual connections
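
Residual connections help because a block computes `h + f(h)` instead of `f(h)`, so its local derivative is `1 + f'(h)` rather than `f'(h)`. A toy sketch with scalar layers whose weight is 0.001 (the same scale as the deliberately bad initialization in Experiment 2); `chain_grad` is a hypothetical helper:

```python
import torch

def chain_grad(weight: float, depth: int, residual: bool) -> float:
    """d(output)/d(input) through `depth` toy layers, with or without skip connections."""
    x = torch.tensor(1.0, requires_grad=True)
    h = x
    for _ in range(depth):
        branch = weight * h                      # the layer itself: local derivative = weight
        h = h + branch if residual else branch   # residual block: local derivative = 1 + weight
    h.backward()
    return x.grad.item()

plain = chain_grad(0.001, depth=8, residual=False)
res = chain_grad(0.001, depth=8, residual=True)
print(f"plain stack:    {plain:.1e}")
print(f"residual stack: {res:.4f}")
```

Without the skip path the gradient is 0.001^8, about 1e-24; with it, the gradient stays close to 1.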

### C. Wrong Learning Rate
**The Physics**: Learning rate controls step size in the loss landscape. Too big and you overshoot the minimum; too small and you barely move.

**What Happens**:
- Too high: Loss oscillates wildly
- Too low: Painfully slow convergence
- Just right: Smooth, steady improvement
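
The step-size picture can be checked on the simplest possible loss, f(x) = x², using plain gradient descent (not Adam). Each update is x ← x - lr·2x = (1 - 2·lr)·x, so the iterate shrinks only when |1 - 2·lr| < 1. The cutoff lr < 1 is specific to this parabola; real losses have different curvature in different directions. The function name is illustrative.

```python
def gradient_descent_on_parabola(lr: float, steps: int = 20, x0: float = 1.0) -> float:
    """Minimize f(x) = x**2 with plain gradient descent and return the final x."""
    x = x0
    for _ in range(steps):
        grad = 2 * x        # f'(x) = 2x
        x = x - lr * grad   # equivalent to x *= (1 - 2 * lr)
    return x

for label, lr in [("too high", 1.1), ("too low", 0.001), ("just right", 0.1)]:
    print(f"{label:>10} (lr={lr}): |x| after 20 steps = {abs(gradient_descent_on_parabola(lr)):.4f}")
```

With lr = 1.1 every step overshoots the minimum and lands farther away on the other side; with lr = 0.001 the iterate has barely moved after 20 steps; lr = 0.1 gets close to the minimum at 0.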

Let's simulate each failure to see its signature:

In [None]:
# Setup: Create a simple training environment to simulate failures
import sys
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import math
from collections import defaultdict

from src.model.transformer import GPTModel

# Simple setup for demonstrations
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create a tiny model for fast demonstrations
config = {
    'vocab_size': 100,
    'd_model': 64,
    'n_heads': 4,
    'n_layers': 2,
    'd_ff': 128,
    'max_seq_len': 32,
    'dropout': 0.1
}

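def create_dummy_batch(batch_size=4, seq_len=16):
    """Create dummy training data for experiments.

    Targets are a fixed function of the inputs (token id + 1, modulo the vocabulary),
    so there is something to learn. With independent random targets the loss could
    never fall much below log(vocab_size), and healthy training would look flat.
    """
    x = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
    targets = (x + 1) % config['vocab_size']
    return x, targets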

def run_training_experiment(model, optimizer, steps=20, experiment_name=""):
    """Run a training experiment and collect loss/gradient data."""
    print(f"\n🧪 {experiment_name}")
    print("-" * 40)
    
    criterion = nn.CrossEntropyLoss()
    losses = []
    grad_norms = []
    
    for step in range(steps):
        x, targets = create_dummy_batch()
        
        optimizer.zero_grad()
        
        try:
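            outputs = model(x)
            # Some model variants return (logits, extra); keep only the logits
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))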
            
            # Check if loss is reasonable
            if not torch.isfinite(loss) or loss.item() > 100:
                print(f"💥 Step {step}: Loss exploded to {loss.item():.2f}")
                break
                
            loss.backward()
            
            # Calculate gradient norm
            grad_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    grad_norm += p.grad.data.norm(2).item() ** 2
            grad_norm = grad_norm ** 0.5
            
            optimizer.step()
            
            losses.append(loss.item())
            grad_norms.append(grad_norm)
            
            if step % 5 == 0:
                print(f"Step {step}: Loss = {loss.item():.4f}, Grad Norm = {grad_norm:.4f}")
                
        except Exception as e:
            print(f"💥 Error at step {step}: {e}")
            break
    
    return losses, grad_norms

print("Experimental setup ready! 🔬")

In [None]:
# Experiment 1: Exploding Gradients
# We'll use an extremely high learning rate to cause gradient explosion

model = GPTModel(config).to(device)
optimizer = optim.Adam(model.parameters(), lr=1.0)  # Dangerously high!

exploding_losses, exploding_grads = run_training_experiment(
    model, optimizer, steps=15, 
    experiment_name="EXPLODING GRADIENTS (LR = 1.0)"
)

print("\n🔍 DIAGNOSIS (expected signs):")
print("✓ Loss spikes or becomes NaN")
print("✓ Gradient norms jump by orders of magnitude")
print("✓ Training completely breaks")
print("\n💊 CURE: Lower learning rate to ~1e-4, add gradient clipping")
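
The gradient-clipping part of this cure can be sketched on a toy layer. `torch.nn.utils.clip_grad_norm_` computes a single L2 norm over all the gradients it is given, scales them in place so that combined norm is at most `max_norm`, and returns the norm measured before clipping. The layer size and the gradient value 10.0 are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)  # 16 weights + 4 biases = 20 gradient entries
for p in layer.parameters():
    p.grad = torch.full_like(p, 10.0)  # pretend every gradient entry spiked to 10

norm_before = torch.nn.utils.clip_grad_norm_(layer.parameters(), max_norm=1.0)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in layer.parameters()))

print(f"norm before clipping: {norm_before.item():.2f}")  # sqrt(20 * 10**2) ≈ 44.72
print(f"norm after clipping:  {norm_after.item():.4f}")
```

Because every gradient is scaled by the same factor, clipping shrinks the step without changing its direction.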

In [None]:
# Experiment 2: Vanishing Gradients  
# We'll create a deep model with poor initialization

# Create deeper model with tiny initialization
deep_config = config.copy()
deep_config['n_layers'] = 8  # Much deeper

model = GPTModel(deep_config).to(device)

# Initialize weights to be extremely small (bad practice!)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.normal_(p, 0, 0.001)  # Way too small

optimizer = optim.Adam(model.parameters(), lr=1e-4)

vanishing_losses, vanishing_grads = run_training_experiment(
    model, optimizer, steps=25,
    experiment_name="VANISHING GRADIENTS (Deep + Tiny Init)"
)

print("\n🔍 DIAGNOSIS (expected signs; Adam can partly mask them by rescaling small gradients):")
print("✓ Loss decreases extremely slowly")
print("✓ Gradient norms are tiny (< 1e-4)")
print("✓ Model barely learns despite many steps")
print("\n💊 CURE: Better initialization (Xavier/He), residual connections, layer norm")
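
One reason this experiment can look less dramatic than expected: Adam divides each parameter's update by a running estimate of that parameter's own gradient magnitude, so on the first step every parameter moves by roughly `lr` no matter how small its gradient is. A minimal sketch with two standalone scalar parameters (not the GPT model) whose gradients differ by nine orders of magnitude:

```python
import torch

# Two standalone scalar parameters whose gradients differ by nine orders of magnitude
small = torch.nn.Parameter(torch.zeros(1))
large = torch.nn.Parameter(torch.zeros(1))
toy_optimizer = torch.optim.Adam([small, large], lr=1e-3)

small.grad = torch.tensor([1e-6])
large.grad = torch.tensor([1e3])
toy_optimizer.step()  # one Adam step starting from zero

print(f"move for grad 1e-6: {abs(small.item()):.2e}")
print(f"move for grad 1e+3: {abs(large.item()):.2e}")
```

Both parameters move by about 1e-3. Plain SGD would instead move them by lr × gradient, i.e. 1e-9 and 1.0, which is why tiny gradient norms alone do not always translate into a stalled loss under Adam.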

In [None]:
# Experiment 3: Learning Rate Too Low
# We'll use an extremely small learning rate

model = GPTModel(config).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-8)  # Way too small!

slow_losses, slow_grads = run_training_experiment(
    model, optimizer, steps=30,
    experiment_name="LEARNING RATE TOO LOW (LR = 1e-8)"
)

print("\n🔍 DIAGNOSIS:")
print("✓ Loss decreases painfully slowly")
print("✓ Gradients are reasonable but updates are tiny")
print("✓ Would take forever to converge")
print("\n💊 CURE: Increase learning rate to 1e-4 to 1e-3 range")

# Experiment 4: Healthy Training (for comparison)
print("\n" + "="*50)
model = GPTModel(config).to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

healthy_losses, healthy_grads = run_training_experiment(
    model, optimizer, steps=25,
    experiment_name="HEALTHY TRAINING (LR = 3e-4 + AdamW)"
)

print("\n🔍 DIAGNOSIS:")
print("✓ Loss decreases smoothly and consistently")
print("✓ Gradient norms are stable (0.1 - 10 range)")
print("✓ No instabilities or plateaus")
print("\n💊 This is what good training looks like!")

In [None]:
# Visualize all failure modes side by side
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Plot loss curves
axes[0, 0].plot(exploding_losses, 'r-', linewidth=3, label='Exploding')
axes[0, 0].plot(vanishing_losses, 'orange', linewidth=2, label='Vanishing')
axes[0, 0].plot(slow_losses, 'b-', linewidth=2, label='Too Slow')
axes[0, 0].plot(healthy_losses, 'g-', linewidth=2, label='Healthy')
axes[0, 0].set_title('Loss Curves: Failure vs Success')
axes[0, 0].set_xlabel('Training Steps')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Plot gradient norms
axes[0, 1].plot(exploding_grads, 'r-', linewidth=3, label='Exploding')
axes[0, 1].plot(vanishing_grads, 'orange', linewidth=2, label='Vanishing')
axes[0, 1].plot(slow_grads, 'b-', linewidth=2, label='Too Slow')
axes[0, 1].plot(healthy_grads, 'g-', linewidth=2, label='Healthy')
axes[0, 1].set_title('Gradient Norms: Different Signatures')
axes[0, 1].set_xlabel('Training Steps')
axes[0, 1].set_ylabel('Gradient Norm')
axes[0, 1].set_yscale('log')
axes[0, 1].grid(True, alpha=0.3)

# Add healthy zones for gradient norms (rule-of-thumb thresholds)
axes[0, 1].axhspan(0.1, 10, alpha=0.2, color='green', label='Healthy Zone')
axes[0, 1].axhspan(10, 1000, alpha=0.2, color='red', label='Danger Zone')
axes[0, 1].axhspan(1e-6, 0.1, alpha=0.2, color='orange', label='Vanishing Zone')
axes[0, 1].legend()  # called after axhspan so the zones appear in the legend

# Create diagnostic summary
axes[1, 0].text(0.1, 0.9, '🚨 EXPLODING GRADIENTS', fontsize=14, weight='bold', color='red', transform=axes[1, 0].transAxes)
axes[1, 0].text(0.1, 0.8, '• Loss increases rapidly', fontsize=12, transform=axes[1, 0].transAxes)
axes[1, 0].text(0.1, 0.7, '• Grad norms > 10', fontsize=12, transform=axes[1, 0].transAxes)
axes[1, 0].text(0.1, 0.6, '• Training breaks', fontsize=12, transform=axes[1, 0].transAxes)
axes[1, 0].text(0.1, 0.5, '💊 Fix: Lower LR, clip gradients', fontsize=12, weight='bold', color='blue', transform=axes[1, 0].transAxes)

axes[1, 0].text(0.1, 0.3, '🐌 VANISHING GRADIENTS', fontsize=14, weight='bold', color='orange', transform=axes[1, 0].transAxes)
axes[1, 0].text(0.1, 0.2, '• Loss decreases slowly', fontsize=12, transform=axes[1, 0].transAxes)
axes[1, 0].text(0.1, 0.1, '• Grad norms < 1e-4', fontsize=12, transform=axes[1, 0].transAxes)
axes[1, 0].text(0.1, 0.0, '💊 Fix: Better init, residuals', fontsize=12, weight='bold', color='blue', transform=axes[1, 0].transAxes)
axes[1, 0].set_xlim(0, 1)
axes[1, 0].set_ylim(0, 1)
axes[1, 0].axis('off')
axes[1, 0].set_title('Quick Diagnosis Guide')

axes[1, 1].text(0.1, 0.9, '🐢 LEARNING RATE TOO LOW', fontsize=14, weight='bold', color='blue', transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.8, '• Painfully slow progress', fontsize=12, transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.7, '• Reasonable gradients', fontsize=12, transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.6, '• Takes forever', fontsize=12, transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.5, '💊 Fix: Increase LR 5-10x', fontsize=12, weight='bold', color='blue', transform=axes[1, 1].transAxes)

axes[1, 1].text(0.1, 0.3, '✅ HEALTHY TRAINING', fontsize=14, weight='bold', color='green', transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.2, '• Smooth loss decrease', fontsize=12, transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.1, '• Stable grad norms (0.1-10)', fontsize=12, transform=axes[1, 1].transAxes)
axes[1, 1].text(0.1, 0.0, '💊 Keep going!', fontsize=12, weight='bold', color='green', transform=axes[1, 1].transAxes)
axes[1, 1].set_xlim(0, 1)
axes[1, 1].set_ylim(0, 1)
axes[1, 1].axis('off')
axes[1, 1].set_title('Treatment Guide')

plt.tight_layout()
plt.show()

print("\n🎯 KEY INSIGHT: Each failure has a distinctive signature!")
print("Learn to recognize these patterns and you can debug the most common training problems.")

## 2. Real-Time Health Monitoring

### Why Monitor During Training?

Training can go wrong at any moment. Problems often start small and then explode. Real-time monitoring catches issues early, like a smoke detector for your model.

### What to Monitor

**Essential Metrics** (always monitor these):
1. **Loss value**: Should decrease smoothly
2. **Gradient norm**: Should stay roughly in the 0.1-10 range (a rule of thumb)
3. **Learning rate**: Should follow your schedule
4. **Training speed**: Should be consistent

**Advanced Metrics** (for deeper insights):
1. **Gradient-to-parameter ratio**: A rough proxy for update size relative to the weights
2. **Layer-wise gradients**: Detects vanishing/exploding by layer
3. **Parameter norms**: Tracks model evolution
4. **Memory usage**: Prevents OOM crashes
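
For the "learning rate should follow your schedule" check, read the LR the optimizer will actually use from `optimizer.param_groups[0]['lr']` at each step. A minimal sketch with a linear warmup built from PyTorch's `LambdaLR`; the 5-step warmup, the toy `nn.Linear` model, and the base LR of 3e-4 are illustrative assumptions, not settings taken from this notebook's model code.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

toy_model = nn.Linear(4, 4)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=3e-4)
warmup_steps = 5
# Linear warmup: multiply the base LR by (step + 1) / warmup_steps, capped at 1.0
toy_scheduler = LambdaLR(toy_optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))

lrs = []
for step in range(8):
    lrs.append(toy_optimizer.param_groups[0]['lr'])  # LR used by this step's update
    loss = toy_model(torch.randn(2, 4)).pow(2).mean()
    toy_optimizer.zero_grad()
    loss.backward()
    toy_optimizer.step()
    toy_scheduler.step()  # advance the schedule after the optimizer step

print([f"{lr:.1e}" for lr in lrs])
```

If the recorded values do not ramp from 6e-5 up to 3e-4 and then stay flat, the scheduler is probably not being stepped where you think it is (for example, `scheduler.step()` is missing or called in the wrong loop).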

### The Physics of Gradient Monitoring

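Gradient norms are among the most informative signals of training health. The thresholds below are rules of thumb used throughout this notebook; typical values shift with model size, loss scale, and whether clipping is on: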
- **Too large (>10)**: Updates are too big, causing instability
- **Too small (<0.001)**: Updates are too tiny, learning is slow
- **Just right (0.1-10)**: Model learns efficiently and stably
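
These thresholds apply to the total norm, which can hide a problem confined to a few layers. A minimal sketch of the layer-wise gradient norms listed under Advanced Metrics: `layerwise_grad_norms` (a hypothetical helper) groups parameters by the first component of their name. It is shown on a toy 6-block sigmoid stack; for `GPTModel` you would need a grouping key that matches its parameter names, which this notebook does not show.

```python
from collections import defaultdict
from typing import Dict

import torch
import torch.nn as nn

def layerwise_grad_norms(model: nn.Module) -> Dict[str, float]:
    """L2 gradient norm per group, grouping parameters by the first part of their name."""
    squared = defaultdict(float)
    for name, p in model.named_parameters():
        if p.grad is not None:
            group = name.split('.')[0]
            squared[group] += p.grad.detach().pow(2).sum().item()
    return {group: total ** 0.5 for group, total in squared.items()}

# Toy model: 6 blocks of Linear + Sigmoid, named "0" ... "5"
toy_model = nn.Sequential(*[nn.Sequential(nn.Linear(16, 16), nn.Sigmoid()) for _ in range(6)])
toy_model(torch.randn(8, 16)).sum().backward()
norms = layerwise_grad_norms(toy_model)
for group, norm in norms.items():
    print(f"block {group}: grad norm = {norm:.2e}")
```

In this sigmoid stack the norms typically shrink toward block 0, the vanishing pattern from section B, because the sigmoid's derivative is at most 0.25.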

Let's build a monitoring system that tracks these metrics:

In [None]:
class HealthMonitor:
    """Real-time training health monitor - your model's vital signs tracker."""
    
    def __init__(self):
        self.reset()
    
    def reset(self):
        """Start fresh monitoring."""
        self.losses = []
        self.grad_norms = []
        self.learning_rates = []
        self.grad_param_ratios = []
        self.steps = []
        
    def check_vitals(self, model, loss, optimizer, step):
        """Check model's vital signs after each training step."""
        
        # Record basic metrics
        self.losses.append(loss.item())
        self.learning_rates.append(optimizer.param_groups[0]['lr'])
        self.steps.append(step)
        
        # Calculate gradient norm
        total_grad_norm = 0.0
        total_param_norm = 0.0
        
        for param in model.parameters():
            if param.grad is not None:
                grad_norm = param.grad.data.norm(2).item()
                param_norm = param.data.norm(2).item()
                
                total_grad_norm += grad_norm ** 2
                total_param_norm += param_norm ** 2
        
        total_grad_norm = total_grad_norm ** 0.5
        total_param_norm = total_param_norm ** 0.5
        
        self.grad_norms.append(total_grad_norm)
        
        # Gradient-to-parameter ratio (key health indicator)
        ratio = total_grad_norm / (total_param_norm + 1e-8)
        self.grad_param_ratios.append(ratio)
        
        return self.diagnose_current_health()
    
    def diagnose_current_health(self):
        """Instant diagnosis of current training health."""
        if not self.grad_norms:
            return "No data yet"
        
        current_grad = self.grad_norms[-1]
        current_ratio = self.grad_param_ratios[-1]
        current_loss = self.losses[-1]
        
        # Check for critical issues
        if not np.isfinite(current_loss):
            return "🚨 CRITICAL: Loss is NaN or infinite"
        
        if current_grad > 10:
            return "🚨 DANGER: Gradient norm too high, risk of explosion"
        
        if current_grad < 1e-5:
            return "🐌 WARNING: Gradient norm too low, vanishing gradients"
        
        if current_ratio > 0.1:
            return "⚠️ CAUTION: Learning rate might be too high"
        
        if current_ratio < 1e-5:
            return "⚠️ CAUTION: Learning rate might be too low"
        
        # Check trend if we have enough data
        if len(self.losses) > 5:
            recent_losses = self.losses[-5:]
            if all(l > recent_losses[0] for l in recent_losses[2:]):
                return "📈 WARNING: Training loss is rising, check the learning rate and data"
        
        return "✅ HEALTHY: All vitals normal"
    
    def emergency_stop_needed(self):
        """Check if training should be stopped immediately."""
        if not self.grad_norms:
            return False
        
        current_grad = self.grad_norms[-1]
        current_loss = self.losses[-1]
        
        # Emergency conditions
        return (not np.isfinite(current_loss) or 
                current_grad > 100 or 
                current_loss > 50)
    
    def plot_dashboard(self):
        """Plot a comprehensive health dashboard."""
        if len(self.steps) < 2:
            print("Need more data for dashboard")
            return
        
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Loss curve with health zones
        axes[0, 0].plot(self.steps, self.losses, 'b-', linewidth=2)
        axes[0, 0].set_title('📉 Loss Curve')
        axes[0, 0].set_xlabel('Steps')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Gradient norms with danger zones
        axes[0, 1].plot(self.steps, self.grad_norms, 'g-', linewidth=2)
        axes[0, 1].axhspan(0.1, 10, alpha=0.2, color='green', label='Healthy Zone')
        axes[0, 1].axhspan(10, 1000, alpha=0.2, color='red', label='Danger Zone')
        axes[0, 1].axhspan(1e-6, 0.1, alpha=0.2, color='orange', label='Vanishing Zone')
        axes[0, 1].set_title('🌡️ Gradient Health')
        axes[0, 1].set_xlabel('Steps')
        axes[0, 1].set_ylabel('Gradient Norm')
        axes[0, 1].set_yscale('log')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # Learning rate schedule
        axes[1, 0].plot(self.steps, self.learning_rates, 'r-', linewidth=2)
        axes[1, 0].set_title('📊 Learning Rate')
        axes[1, 0].set_xlabel('Steps')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Gradient-to-parameter ratio
        axes[1, 1].plot(self.steps, self.grad_param_ratios, 'm-', linewidth=2)
        axes[1, 1].axhspan(1e-5, 0.1, alpha=0.2, color='green', label='Healthy Zone')
        axes[1, 1].axhspan(0.1, 10, alpha=0.2, color='red', label='Too High')
        axes[1, 1].set_title('⚖️ Update Magnitude (Grad/Param Ratio)')
        axes[1, 1].set_xlabel('Steps')
        axes[1, 1].set_ylabel('Ratio')
        axes[1, 1].set_yscale('log')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print summary
        print(f"\n📊 HEALTH SUMMARY (Step {self.steps[-1]}):")
        print(f"   Current Status: {self.diagnose_current_health()}")
        print(f"   Loss: {self.losses[-1]:.4f}")
        print(f"   Grad Norm: {self.grad_norms[-1]:.4f}")
        print(f"   Learning Rate: {self.learning_rates[-1]:.6f}")
        print(f"   Update Ratio: {self.grad_param_ratios[-1]:.6f}")

print("Health monitoring system ready! 🏥")
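
`check_vitals` uses the gradient-to-parameter norm ratio as a stand-in for update size. With Adam or AdamW the actual step is rescaled per parameter, so it can differ a lot from what the raw gradient suggests. A sketch that measures the real relative update, ||Δθ|| / ||θ||, by snapshotting the parameters around `optimizer.step()`; `update_to_param_ratio` is a hypothetical helper and the toy model is only for illustration.

```python
import torch
import torch.nn as nn

def update_to_param_ratio(model: nn.Module, optimizer: torch.optim.Optimizer) -> float:
    """Run optimizer.step() and return ||theta_after - theta_before|| / ||theta_before||."""
    before = [p.detach().clone() for p in model.parameters()]
    optimizer.step()
    delta_sq = sum((p.detach() - b).pow(2).sum() for p, b in zip(model.parameters(), before))
    param_sq = sum(b.pow(2).sum() for b in before)
    return (delta_sq.sqrt() / (param_sq.sqrt() + 1e-12)).item()

toy_model = nn.Linear(8, 8)
toy_opt = torch.optim.AdamW(toy_model.parameters(), lr=1e-3, weight_decay=0.0)
toy_opt.zero_grad()
toy_model(torch.randn(4, 8)).pow(2).mean().backward()
ratio = update_to_param_ratio(toy_model, toy_opt)
print(f"update / parameter norm ratio: {ratio:.2e}")
```

In a training loop this call would replace the plain `optimizer.step()`; plotting it on a log scale next to the monitor's ratio shows whether the optimizer is amplifying or damping what the raw gradients suggest.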

In [None]:
# Demonstrate healthy training with monitoring
print("🏥 DEMONSTRATING HEALTHY TRAINING")
print("=" * 40)

# Create a fresh model and monitor (same constructor call as the experiments above)
model = GPTModel(config).to(device)
monitor = HealthMonitor()

# Healthy training setup
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# Train with monitoring, reusing the dummy batches from the setup cell
for step in range(50):
    input_ids, target_ids = create_dummy_batch()

    optimizer.zero_grad()
    outputs = model(input_ids)
    # Some model variants return (logits, extra); keep only the logits
    logits = outputs[0] if isinstance(outputs, tuple) else outputs
    loss = criterion(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
    loss.backward()

    # Check vitals BEFORE clipping so the monitor sees the raw gradient norm
    status = monitor.check_vitals(model, loss, optimizer, step)
    if monitor.emergency_stop_needed():
        print(f"🛑 Step {step}: stopping early. {status}")
        break

    # Gradient clipping (good practice!)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    if step % 10 == 0:
        print(f"Step {step}: Loss = {loss.item():.4f} | {status}")

# Plot the health dashboard for this run
monitor.plot_dashboard()

print("\n✅ This is what healthy training looks like!")
print("Notice the stable, well-behaved gradients.")

## 3. Loss Curve Interpretation

Understanding what different loss curve patterns mean.
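
Overfitting shows up in validation loss, not training loss: validation loss stops improving and turns upward while training loss keeps falling. A minimal sketch of that check, in the spirit of early stopping; `find_overfitting_start` and its `patience` rule are hypothetical, and the two loss lists are made-up numbers for illustration.

```python
from typing import List, Optional

def find_overfitting_start(train_losses: List[float], val_losses: List[float],
                           patience: int = 3) -> Optional[int]:
    """Return the step of the best validation loss if, `patience` steps later,
    validation loss has not improved while training loss is still lower; else None."""
    best_val, best_step = float("inf"), 0
    for step, val in enumerate(val_losses):
        if val < best_val:
            best_val, best_step = val, step
        elif step - best_step >= patience and train_losses[step] < train_losses[best_step]:
            return best_step
    return None

train_curve = [5.0, 4.0, 3.0, 2.5, 2.0, 1.5, 1.2, 1.0]
val_curve = [5.0, 4.2, 3.5, 3.3, 3.4, 3.6, 3.9, 4.2]
print(find_overfitting_start(train_curve, val_curve))  # -> 3 (validation loss bottomed out at step 3)
```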

In [None]:
class LossCurveAnalyzer:
    """Analyze and interpret loss curves."""
    
    @staticmethod
    def analyze_loss_curve(losses: List[float], window_size: int = 10) -> Dict[str, str]:
        """Analyze loss curve patterns."""
        if len(losses) < window_size:
            return {"status": "Insufficient data"}
        
        diagnosis = {}
        
        # Overall trend
        start_loss = np.mean(losses[:window_size])
        end_loss = np.mean(losses[-window_size:])
        improvement = (start_loss - end_loss) / start_loss
        
        if improvement > 0.1:
            diagnosis["overall_trend"] = "✅ GOOD - Loss decreasing well"
        elif improvement > 0.01:
            diagnosis["overall_trend"] = "⚠️ SLOW - Loss decreasing slowly"
        elif improvement < -0.01:
            diagnosis["overall_trend"] = "🚨 BAD - Loss increasing"
        else:
            diagnosis["overall_trend"] = "📊 FLAT - Loss plateaued"
        
        # Volatility
        recent_losses = losses[-window_size:]
        volatility = np.std(recent_losses) / np.mean(recent_losses)
        
        if volatility > 0.1:
            diagnosis["stability"] = "🌊 NOISY - High variance in loss"
        elif volatility < 0.01:
            diagnosis["stability"] = "📉 SMOOTH - Very stable loss"
        else:
            diagnosis["stability"] = "✅ STABLE - Normal variance"
        
        # Check for oscillations
        if len(losses) > 20:
            recent = losses[-20:]
            # Count direction changes
            direction_changes = 0
            for i in range(1, len(recent) - 1):
                if (recent[i] > recent[i-1] and recent[i] > recent[i+1]) or \
                   (recent[i] < recent[i-1] and recent[i] < recent[i+1]):
                    direction_changes += 1
            
            if direction_changes > len(recent) * 0.3:
                diagnosis["pattern"] = "🌀 OSCILLATING - Loss bouncing up and down"
            else:
                diagnosis["pattern"] = "➡️ MONOTONIC - Smooth trend"
        
        # Check for NaN or infinite values
        if any(not np.isfinite(l) for l in losses[-10:]):
            diagnosis["health"] = "💀 BROKEN - NaN or infinite loss"
        elif max(losses[-10:]) > 100:
            diagnosis["health"] = "🚨 EXPLODING - Loss too high"
        else:
            diagnosis["health"] = "✅ HEALTHY - Normal loss values"
        
        return diagnosis
    
    @staticmethod
    def generate_example_curves():
        """Generate example loss curves for different scenarios."""
        steps = np.arange(100)
        
        curves = {}
        
        # Healthy learning curve
        healthy = 5.0 * np.exp(-steps / 30) + 1.0 + 0.01 * np.random.randn(100)
        curves["Healthy Learning"] = healthy
        
        # Overfitting curve (read this one as validation loss)
        overfitting = 5.0 * np.exp(-steps / 20) + 1.0
        # Add upturn after step 60
        upturn = np.where(steps > 60, 0.02 * (steps - 60), 0)
        overfitting += upturn + 0.05 * np.random.randn(100)
        curves["Overfitting"] = overfitting
        
        # Learning rate too high (oscillating)
        oscillating = 3.0 + 0.5 * np.sin(steps * 0.3) + 0.1 * np.random.randn(100)
        curves["LR Too High (Oscillating)"] = oscillating
        
        # Learning rate too low (plateau)
        plateau = 5.0 * np.exp(-steps / 100) + 3.0 + 0.05 * np.random.randn(100)
        curves["LR Too Low (Plateau)"] = plateau
        
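        # Exploding gradients: loss takes off after step 30
        # (np.clip keeps the base non-negative, avoiding NaN warnings from negative ** 1.5)
        exploding = 2.0 + 0.5 * np.clip(steps - 30, 0, None) ** 1.5 + 0.1 * np.random.randn(100)
        exploding = np.clip(exploding, 0, 150)  # Cap for readability, still above the analyzer's 100 threshold
        curves["Exploding Gradients"] = exploding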
        
        # Vanishing gradients (no learning)
        vanishing = 5.0 + 0.05 * np.random.randn(100)
        curves["Vanishing Gradients"] = vanishing
        
        return curves

# Generate and analyze example curves
analyzer = LossCurveAnalyzer()
example_curves = analyzer.generate_example_curves()

# Plot all example curves
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for i, (name, curve) in enumerate(example_curves.items()):
    if i < len(axes):
        axes[i].plot(curve, linewidth=2)
        axes[i].set_title(f'{name}')
        axes[i].set_xlabel('Training Step')
        axes[i].set_ylabel('Loss')
        axes[i].grid(True, alpha=0.3)
        
        # Add diagnosis
        diagnosis = analyzer.analyze_loss_curve(curve.tolist())
        diagnosis_text = "\n".join([f"{k}: {v}" for k, v in diagnosis.items()])
        axes[i].text(0.02, 0.98, diagnosis_text, transform=axes[i].transAxes,
                    verticalalignment='top', fontsize=8,
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

plt.tight_layout()
plt.show()

print("\n📊 LOSS CURVE INTERPRETATION GUIDE:")
print("=" * 40)
print("✅ Healthy: Smooth decrease, low noise, good convergence")
print("📈 Overfitting: Initial decrease then increase (validation loss)")
print("🌀 Oscillating: Learning rate too high, bouncy loss")
print("📊 Plateau: Learning rate too low, stuck at high loss")
print("🚨 Exploding: Gradients explode, loss shoots up")
print("😴 Vanishing: No learning, loss stays flat")
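
Per-step training loss is noisy, and the analyzer's direction-change count treats every noisy wiggle as a turn, so even a healthy run can be labeled oscillating. A common remedy, not part of `LossCurveAnalyzer` as written, is to smooth the curve with an exponential moving average before reading its shape, at the cost of lagging behind the raw values. `ema_smooth` and `count_turns` are hypothetical helpers; `beta` controls how much history is kept.

```python
from typing import List

import numpy as np

def ema_smooth(losses: List[float], beta: float = 0.9) -> List[float]:
    """s_0 = x_0, then s_t = beta * s_{t-1} + (1 - beta) * x_t."""
    smoothed, s = [], None
    for x in losses:
        s = x if s is None else beta * s + (1 - beta) * x
        smoothed.append(s)
    return smoothed

def count_turns(values: List[float]) -> int:
    """Count local maxima and minima, the same rule LossCurveAnalyzer uses for oscillation."""
    return sum(1 for i in range(1, len(values) - 1)
               if (values[i] > values[i - 1] and values[i] > values[i + 1])
               or (values[i] < values[i - 1] and values[i] < values[i + 1]))

rng = np.random.default_rng(0)
demo_steps = np.arange(100)
noisy = (5.0 * np.exp(-demo_steps / 30) + 1.0 + 0.05 * rng.standard_normal(100)).tolist()
print("turns in raw curve:     ", count_turns(noisy))
print("turns in smoothed curve:", count_turns(ema_smooth(noisy)))
```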

## 4. Comprehensive Troubleshooting Guide

A systematic approach to diagnosing and fixing training problems.
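
When a diagnosis points to NaN or infinite values, the next question is where they appear. A minimal sketch that lists the parameters whose gradients contain non-finite entries; `find_nonfinite_grads` is a hypothetical helper, and the NaN is injected by hand for illustration. To trace the operation that produced a NaN in the backward pass, PyTorch also provides `torch.autograd.set_detect_anomaly(True)`, which slows training noticeably and is best kept to debugging runs.

```python
from typing import List

import torch
import torch.nn as nn

def find_nonfinite_grads(model: nn.Module) -> List[str]:
    """Return names of parameters whose gradients contain NaN or +/-inf."""
    return [name for name, p in model.named_parameters()
            if p.grad is not None and not torch.isfinite(p.grad).all()]

toy_model = nn.Sequential(nn.Linear(3, 3), nn.ReLU(), nn.Linear(3, 1))
toy_model(torch.ones(2, 3)).sum().backward()
toy_model[2].weight.grad[0, 0] = float("nan")  # simulate a corrupted gradient
print(find_nonfinite_grads(toy_model))  # ['2.weight']
```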

In [None]:
class TransformerDoctor:
    """Comprehensive transformer training diagnostics."""
    
    def __init__(self):
        self.symptoms = []
        self.diagnosis = []
        self.treatments = []
    
    def examine_patient(self, model, losses, grad_monitor, data_loader, optimizer):
        """Comprehensive examination of training health."""
        report = {
            "model_health": self._check_model_health(model),
            "loss_health": self._check_loss_health(losses),
            "gradient_health": self._check_gradient_health(grad_monitor),
            "data_health": self._check_data_health(data_loader),
            "optimizer_health": self._check_optimizer_health(optimizer),
            "overall_diagnosis": "",
            "treatment_plan": []
        }
        
        # Generate overall diagnosis
        critical_issues = []
        warnings = []
        
        for category, findings in report.items():
            if isinstance(findings, dict):
                for finding in findings.values():
                    if "🚨" in str(finding) or "💀" in str(finding):
                        critical_issues.append(finding)
                    elif any(marker in str(finding) for marker in ("⚠️", "🐌", "🌀", "🌊", "📈")):
                        warnings.append(finding)
        
        if critical_issues:
            report["overall_diagnosis"] = "🚨 CRITICAL: Immediate attention required"
            report["treatment_plan"] = self._generate_critical_treatment(critical_issues)
        elif warnings:
            report["overall_diagnosis"] = "⚠️ CONCERNING: Optimization needed"
            report["treatment_plan"] = self._generate_warning_treatment(warnings)
        else:
            report["overall_diagnosis"] = "✅ HEALTHY: Training looks good"
            report["treatment_plan"] = ["Continue current training regimen"]
        
        return report
    
    def _check_model_health(self, model):
        """Check model architecture health."""
        health = {}
        
        # Parameter count
        total_params = sum(p.numel() for p in model.parameters())
        if total_params > 10**9:
            health["size"] = "🚨 VERY LARGE - May need distributed training"
        elif total_params > 10**6:
            health["size"] = "⚠️ LARGE - Monitor memory usage"
        else:
            health["size"] = "✅ REASONABLE"
        
        # Check for NaN parameters
        has_nan = any(torch.isnan(p).any() for p in model.parameters())
        if has_nan:
            health["parameters"] = "💀 NaN DETECTED - Model corrupted"
        else:
            health["parameters"] = "✅ CLEAN"
        
        # Parameter initialization check
        param_stds = [p.data.std().item() for p in model.parameters() if p.dim() > 1]
        if param_stds:
            avg_std = np.mean(param_stds)
            if avg_std > 0.5:
                health["initialization"] = "⚠️ LARGE - Parameters might be too large"
            elif avg_std < 0.001:
                health["initialization"] = "⚠️ SMALL - Parameters might be too small"
            else:
                health["initialization"] = "✅ REASONABLE"
        
        return health
    
    def _check_loss_health(self, losses):
        """Check loss curve health."""
        if not losses:
            return {"status": "No loss data"}
        
        analyzer = LossCurveAnalyzer()
        return analyzer.analyze_loss_curve(losses)
    
    def _check_gradient_health(self, grad_monitor):
        """Check gradient health using a HealthMonitor's latest diagnosis."""
        return {"status": grad_monitor.diagnose_current_health()}
    
    def _check_data_health(self, data_loader):
        """Check data health."""
        health = {}
        
        # Check batch size
        batch_size = data_loader.batch_size
        if batch_size < 2:
            health["batch_size"] = "⚠️ VERY SMALL - May cause instability"
        elif batch_size > 64:
            health["batch_size"] = "⚠️ LARGE - Watch memory; gradient accumulation over smaller micro-batches gives the same effective batch"
        else:
            health["batch_size"] = "✅ REASONABLE"
        
        # Check dataset size
        dataset_size = len(data_loader.dataset)
        if dataset_size < 100:
            health["dataset_size"] = "⚠️ TINY - Risk of overfitting"
        elif dataset_size < 1000:
            health["dataset_size"] = "⚠️ SMALL - Limited generalization"
        else:
            health["dataset_size"] = "✅ ADEQUATE"
        
        return health
    
    def _check_optimizer_health(self, optimizer):
        """Check optimizer configuration."""
        health = {}
        
        lr = optimizer.param_groups[0]['lr']
        if lr > 1e-2:
            health["learning_rate"] = "🚨 TOO HIGH - Risk of exploding gradients"
        elif lr < 1e-6:
            health["learning_rate"] = "🐌 TOO LOW - Very slow learning"
        elif lr < 1e-5:
            health["learning_rate"] = "⚠️ LOW - May learn slowly"
        else:
            health["learning_rate"] = "✅ REASONABLE"
        
        optimizer_type = type(optimizer).__name__
        if optimizer_type in ['AdamW', 'Adam']:
            health["optimizer_type"] = "✅ GOOD CHOICE"
        elif optimizer_type == 'SGD':
            health["optimizer_type"] = "⚠️ BASIC - Consider Adam/AdamW"
        else:
            health["optimizer_type"] = "❓ UNKNOWN"
        
        return health
    
    def _generate_critical_treatment(self, issues):
        """Generate treatment for critical issues."""
        treatments = []
        
        issue_text = " ".join(str(issue) for issue in issues)
        
        if "NaN" in issue_text:
            treatments.extend([
                "🚨 IMMEDIATE: Stop training, model is corrupted",
                "💊 Restore from last good checkpoint",
                "🔧 Check for division by zero in loss function",
                "🔧 Add gradient clipping: clip_grad_norm_(params, 1.0)"
            ])
        
        if "exploding" in issue_text.lower() or "too high" in issue_text.lower():
            treatments.extend([
                "📉 Reduce learning rate by 10x",
                "✂️ Add gradient clipping",
                "🔄 Restart with better initialization"
            ])
        
        if "BROKEN" in issue_text:
            treatments.extend([
                "🛑 Stop training immediately",
                "🔍 Debug loss computation",
                "💾 Restore from checkpoint"
            ])
        
        return treatments if treatments else ["🆘 Seek expert help"]
    
    def _generate_warning_treatment(self, warnings):
        """Generate treatment for warning issues."""
        treatments = []
        
        warning_text = " ".join(str(warning) for warning in warnings)
        
        if "SLOW" in warning_text or "TOO LOW" in warning_text:
            treatments.extend([
                "📈 Increase learning rate by 2-5x",
                "⏰ Consider learning rate scheduling",
                "🔧 Try different optimizer (AdamW)"
            ])
        
        if "OSCILLATING" in warning_text or "NOISY" in warning_text:
            treatments.extend([
                "📉 Reduce learning rate by 2-3x",
                "📊 Increase batch size",
                "🎯 Add learning rate decay"
            ])
        
        if "overfitting" in warning_text.lower():
            treatments.extend([
                "🛑 Add early stopping",
                "💧 Increase dropout",
                "📚 Get more training data",
                "⚖️ Add weight decay"
            ])
        
        return treatments if treatments else ["📊 Monitor closely, minor optimizations needed"]
    
    def print_diagnosis(self, report):
        """Print a nicely formatted diagnosis report."""
        print("\n" + "=" * 60)
        print("🏥 TRANSFORMER TRAINING HEALTH REPORT")
        print("=" * 60)
        
        print(f"\n🎯 OVERALL DIAGNOSIS: {report['overall_diagnosis']}")
        
        print("\n📋 DETAILED EXAMINATION:")
        for category, findings in report.items():
            if category not in ['overall_diagnosis', 'treatment_plan'] and isinstance(findings, dict):
                print(f"\n  {category.replace('_', ' ').title()}:")
                for key, value in findings.items():
                    print(f"    • {key}: {value}")
        
        print("\n💊 TREATMENT PLAN:")
        for i, treatment in enumerate(report['treatment_plan'], 1):
            print(f"  {i}. {treatment}")
        
        print("\n" + "=" * 60)

print("Transformer Doctor ready for consultation! 👨‍⚕️")
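The warning treatments above recommend learning rate scheduling and decay but do not show one. The sketch below uses PyTorch's built-in `LambdaLR` with linear warmup followed by cosine decay. This is a common choice for transformers, but it is not necessarily what the rest of this notebook uses. `warmup_cosine_lambda`, the step counts and `min_ratio` are illustrative assumptions, and a toy `nn.Linear` stands in for `GPTModel`:

```python
import math

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR


def warmup_cosine_lambda(warmup_steps, total_steps, min_ratio=0.1):
    """Return a multiplier for LambdaLR (hypothetical helper).

    The multiplier scales the optimizer's base lr: it rises linearly to 1.0 over
    `warmup_steps`, then follows a cosine curve down to `min_ratio` at `total_steps`.
    """
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes 1 -> 0
        return min_ratio + (1.0 - min_ratio) * cosine
    return lr_lambda


# Toy model standing in for GPTModel (assumption)
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine_lambda(warmup_steps=10, total_steps=100))

lrs = []
for step in range(100):
    optimizer.step()   # in real training: forward, loss.backward(), then step
    scheduler.step()   # advance the schedule once per optimizer step
    lrs.append(optimizer.param_groups[0]["lr"])
```

Warmup keeps the first updates small while Adam's moment estimates are still noisy. The cosine tail provides the gradual decay that the "OSCILLATING" treatment asks for.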

In [None]:
# Demonstrate the Transformer Doctor
print("🏥 RUNNING COMPREHENSIVE HEALTH CHECK")
print("=" * 50)

# Create a problematic scenario
model = GPTModel(**config).to(device)
monitor = GradientMonitor(model)

# Problematic setup - learning rate too high
optimizer = optim.Adam(model.parameters(), lr=0.1)  # Too high!
criterion = nn.CrossEntropyLoss()

# Simulate some problematic training
losses = []
data_iter = iter(data_loader)

for step in range(20):
    try:
        batch = next(data_iter)
    except StopIteration:
        data_iter = iter(data_loader)
        batch = next(data_iter)
    
    input_ids, target_ids = batch
    input_ids, target_ids = input_ids.to(device), target_ids.to(device)
    
    optimizer.zero_grad()
    logits, _ = model(input_ids)
    loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))
    
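    loss_value = loss.item()
    # Stop if the loss explodes; NaN/inf must be checked explicitly because NaN > 50 is False
    if not math.isfinite(loss_value) or loss_value > 50:
        print(f"⚠️ Stopping at step {step}, loss exploded to {loss_value:.2f}")
        break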
    
    loss.backward()
    monitor.update(step)
    optimizer.step()
    
    losses.append(loss.item())

# Get diagnosis from the doctor
doctor = TransformerDoctor()
health_report = doctor.examine_patient(model, losses, monitor, data_loader, optimizer)

# Print the diagnosis
doctor.print_diagnosis(health_report)

## Summary: Your Transformer Training Toolkit 🔧

You now have a comprehensive debugging and monitoring system for transformer training!

### 🎯 Key Debugging Skills

**1. Recognizing Failure Modes**
- **Exploding Gradients**: Loss shoots up, grad norms explode → Lower LR, add clipping
- **Vanishing Gradients**: Loss barely changes, tiny grad norms → Better init, check architecture
- **Wrong Learning Rate**: Too high = oscillations, too low = plateau → Lower or raise the LR to match the symptom
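For exploding gradients, clipping belongs between `loss.backward()` and `optimizer.step()`. The sketch below shows one training step. It assumes a toy embedding-plus-linear model in place of `GPTModel`, and `train_step` is a hypothetical helper. `torch.nn.utils.clip_grad_norm_` returns the total norm measured before clipping, so it also gives you the gradient-norm signal to log:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy next-token model standing in for GPTModel (assumption): embedding -> linear head
vocab_size, hidden = 50, 16
model = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, vocab_size))
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


def train_step(model, optimizer, input_ids, target_ids, max_norm=1.0):
    """One optimizer step with gradient clipping; returns (loss, pre-clip grad norm)."""
    optimizer.zero_grad()
    logits = model(input_ids)  # (batch, seq_len, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target_ids.reshape(-1))
    loss.backward()
    # Rescale all gradients in place so their combined L2 norm is at most max_norm;
    # the return value is the total norm measured *before* clipping
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item(), grad_norm.item()


input_ids = torch.randint(0, vocab_size, (4, 12))
target_ids = torch.randint(0, vocab_size, (4, 12))
loss_value, grad_norm = train_step(model, optimizer, input_ids, target_ids)
print(f"loss={loss_value:.3f}  grad norm before clipping={grad_norm:.3f}")
```

Logging the returned pre-clip norm is useful. If it sits above `max_norm` on most steps, clipping is masking an LR that is too high rather than acting as an occasional safety net.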

**2. Gradient Health Monitoring**
- **Healthy range** (rule of thumb): Grad norms between 1e-6 and 10
- **Grad/param ratio** (rule of thumb): Between 1e-6 and 0.1
- **Layer analysis**: All layers should receive reasonable gradients
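The ranges above can be checked per parameter right after `loss.backward()`. The sketch below assumes that the grad/param ratio means the ratio of the two L2 norms; the notebook's `GradientMonitor` may define it differently. `gradient_health` is a hypothetical helper, and its default ranges simply restate the rules of thumb listed here:

```python
import torch
import torch.nn as nn


def gradient_health(model, norm_range=(1e-6, 10.0), ratio_range=(1e-6, 0.1)):
    """Label each parameter "OK" or flagged, using rule-of-thumb ranges.

    Call after loss.backward(). The ratio is ||grad|| / ||param|| (assumed definition).
    """
    report = {}
    for name, param in model.named_parameters():
        if param.grad is None:
            report[name] = "NO GRADIENT (unused parameter or detached graph?)"
            continue
        grad_norm = param.grad.norm().item()
        ratio = grad_norm / (param.detach().norm().item() + 1e-12)  # eps avoids /0
        if not (norm_range[0] <= grad_norm <= norm_range[1]):
            report[name] = f"grad norm {grad_norm:.1e} outside {norm_range}"
        elif not (ratio_range[0] <= ratio <= ratio_range[1]):
            report[name] = f"grad/param ratio {ratio:.1e} outside {ratio_range}"
        else:
            report[name] = "OK"
    return report


# Usage on a toy MLP standing in for the transformer
torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
mlp(torch.randn(32, 16)).pow(2).mean().backward()
for name, status in gradient_health(mlp).items():
    print(f"{name:10s} {status}")
```

A "NO GRADIENT" entry is often the quickest clue to a layer that was accidentally detached from the graph. A run of flagged tiny norms in the earliest layers points toward vanishing gradients.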

**3. Loss Curve Reading**
- **Healthy**: Smooth decrease with low noise
- **Overfitting**: Validation loss decreases, then rises while training loss keeps falling
- **LR issues**: Oscillations (too high) or plateau (too low)
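The overfitting pattern is usually handled with early stopping, which is one of the treatments the doctor prescribes. The sketch below is plain Python. The `EarlyStopping` class and its `patience`/`min_delta` parameters are illustrative, and the validation curve is made up to show the shape rather than measured in this notebook:

```python
class EarlyStopping:
    """Signal a stop once validation loss has failed to improve by more than
    `min_delta` for `patience` consecutive evaluations."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evals = 0

    def step(self, val_loss):
        """Record one validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss      # new best: reset the counter
            self.bad_evals = 0
        else:
            self.bad_evals += 1       # no meaningful improvement this evaluation
        return self.bad_evals >= self.patience


# Illustrative validation losses: falling, then rising (the overfitting shape)
val_curve = [2.9, 2.4, 2.1, 2.0, 2.05, 2.12, 2.2, 2.3]
stopper = EarlyStopping(patience=3)
stop_epoch = None
for epoch, val_loss in enumerate(val_curve):
    if stopper.step(val_loss):
        stop_epoch = epoch
        break
print(f"stopped at epoch {stop_epoch}, best val loss {stopper.best}")
```

In a real loop you would also save a checkpoint whenever `best` improves. That way, the weights from the best epoch can be restored rather than the ones from the epoch where the stop fired.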

**4. Systematic Diagnosis**
- **Check model**: Parameters, initialization, architecture
- **Check data**: Batch size, dataset size, quality
- **Check optimizer**: Learning rate, type, configuration
- **Check gradients**: Norms, ratios, layer distribution

### 🚨 Emergency Procedures

**If training explodes:**
1. Stop immediately
2. Restore from last good checkpoint
3. Reduce learning rate by 10x
4. Add gradient clipping
5. Restart training
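Steps 2 to 4 of this procedure can be automated. The sketch below assumes an in-memory snapshot of the last good model and optimizer state. A real run would snapshot every N steps or to disk, since deep-copying on every step is costly. `guarded_step` and `snapshot` are hypothetical helpers, a toy `nn.Linear` regression stands in for the transformer, and a NaN-filled batch simulates the blow-up:

```python
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def snapshot(model, optimizer):
    """Deep-copy model and optimizer state so later updates can't alias it."""
    return {"model": copy.deepcopy(model.state_dict()),
            "optimizer": copy.deepcopy(optimizer.state_dict())}


def guarded_step(model, optimizer, batch, good_state, lr_drop=10.0, max_norm=1.0):
    """One step that rolls back and cuts the LR if the loss is NaN/inf.

    Returns True if the update was applied, False if it was rolled back.
    """
    x, y = batch
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    if not math.isfinite(loss.item()):
        # Restore the last good weights and optimizer state, then reduce the LR
        model.load_state_dict(good_state["model"])
        optimizer.load_state_dict(good_state["optimizer"])
        for group in optimizer.param_groups:
            group["lr"] /= lr_drop
        return False
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # safety net
    optimizer.step()
    good_state.update(snapshot(model, optimizer))  # this state is now "last good"
    return True


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
good_state = snapshot(model, optimizer)

clean = (torch.randn(8, 4), torch.randn(8, 1))
poisoned = (torch.full((8, 4), float("nan")), torch.randn(8, 1))  # simulated blow-up

applied = guarded_step(model, optimizer, clean, good_state)
rolled_back = not guarded_step(model, optimizer, poisoned, good_state)  # lr 0.1 -> 0.01
print(applied, rolled_back, optimizer.param_groups[0]["lr"])
```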

**If training stagnates:**
1. Check gradient norms
2. Increase learning rate by 2-5x
3. Verify data quality
4. Consider architecture changes

### 💡 Best Practices

- **Always monitor gradients** during training
- **Use gradient clipping** as a safety net
- **Save checkpoints frequently** for recovery
- **Start with proven hyperparameters** then tune
- **Validate on held-out data** to catch overfitting
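Frequent checkpoints only help if each one restores everything needed to resume. That includes the optimizer state (for example AdamW's moment estimates), not just the model weights. The sketch below uses `torch.save`/`torch.load`. The dictionary keys and the `save_checkpoint`/`load_checkpoint` helpers are illustrative names, not part of this notebook's `src` package, and the saved loss value is a placeholder:

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, step, loss):
    """Write model + optimizer state so training can resume where it left off."""
    torch.save({
        "step": step,
        "loss": loss,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
    }, path)


def load_checkpoint(path, model, optimizer, device="cpu"):
    """Restore both states in place and return the saved (step, loss)."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["step"], ckpt["loss"]


torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()  # populate AdamW state so there is something to restore

path = os.path.join(tempfile.mkdtemp(), "ckpt_step100.pt")
save_checkpoint(path, model, optimizer, step=100, loss=2.31)  # placeholder loss value

# Fresh objects, as after a crash and restart
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=1e-3)
resumed_step, resumed_loss = load_checkpoint(path, model2, optimizer2)
print(resumed_step, resumed_loss, torch.equal(model.weight, model2.weight))
```

Storing the step alongside the states lets a learning rate scheduler and any logging resume at the right position instead of restarting from zero.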

### 🎓 What You've Mastered

You can now:
- **Diagnose** common training failures
- **Monitor** training health in real-time
- **Interpret** loss curves and gradient behavior
- **Troubleshoot** systematically when things go wrong
- **Prevent** most common training disasters

This debugging toolkit will save you countless hours of frustration and help you train transformers reliably! 🛠️