# Debugging and Monitoring Transformer Training

Training transformers is like flying a plane: you need instruments to know if you're about to crash. This notebook teaches you to diagnose problems before they derail a training run.

## The Physics of Training Failure

Every training failure has a root cause in the mathematics:

**Gradient Explosion**: When gradients compound through layers, they can grow exponentially. If each layer multiplies gradients by factor > 1, then after L layers: gradient ∝ (factor)^L → ∞

**Gradient Vanishing**: When gradients shrink through layers. If each layer multiplies by factor < 1, then: gradient ∝ (factor)^L → 0

**Learning Rate Problems**: The learning rate η controls the step size in the loss landscape. Too high and updates overshoot the minimum; too low and the parameters barely move.
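
To make the compounding concrete, here is a minimal sketch in plain Python of how a constant per-layer factor scales a gradient across depth. The helper name `gradient_scale` and the factors 1.5 and 0.5 are illustrative assumptions, not values measured from a real model.

```python
def gradient_scale(factor: float, n_layers: int) -> float:
    """Scale applied to a gradient after n_layers layers,
    assuming every layer multiplies it by the same factor."""
    return factor ** n_layers

# Compare a growing factor (> 1) with a shrinking factor (< 1) at several depths
for depth in (2, 6, 12, 24):
    grow = gradient_scale(1.5, depth)
    shrink = gradient_scale(0.5, depth)
    print(f"L={depth:2d}  factor=1.5 -> {grow:10.2f}   factor=0.5 -> {shrink:.2e}")
```

Real layers do not share a single factor, but the same exponential trend is why depth makes both failure modes worse.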

## What You'll Master

1. **Recognize symptoms** of common failures instantly
2. **Monitor training health** with key metrics
3. **Diagnose root causes** systematically
4. **Apply targeted fixes** to save your training

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
from collections import defaultdict

from src.model.transformer import GPTModel, create_model_config
from src.data.tokenizer import create_tokenizer
from src.data.dataset import SimpleTextDataset, create_dataloader

torch.manual_seed(42)
np.random.seed(42)

plt.style.use('default')
sns.set_palette("husl")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("Training debugger loaded! 🔧")

## 1. The Three Deadly Training Failures

Understanding the fundamental physics helps you diagnose problems instantly.

### A. Exploding Gradients

**Root Cause**: During backpropagation, gradients multiply through layers. In deep networks:
```
gradient_layer_1 = gradient_output × weight_L × weight_(L-1) × ... × weight_2
```

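If the weight matrices have large norms, this product grows exponentially with depth. A learning rate that is too high makes things worse, because oversized updates push the weights toward even larger values.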

**Symptoms**: Loss suddenly jumps to infinity, parameters become NaN, training crashes

### B. Vanishing Gradients

**Root Cause**: Same chain rule, but weights are too small or activations saturate:
```
If each factor has magnitude < 1, then product → 0 as depth increases
```

**Symptoms**: Loss barely decreases, early layers don't learn, painfully slow progress
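
The saturation effect can be sketched with a toy chain of scalar sigmoid layers. This is an illustrative assumption rather than a transformer: each hypothetical layer computes `a = sigmoid(weight * a_prev)`, and since the sigmoid derivative never exceeds 0.25, the chain-rule product shrinks quickly when `weight` is 1.

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def chain_gradient(x: float, weight: float, n_layers: int) -> float:
    """d(output)/d(input) for n_layers stacked scalar layers a = sigmoid(weight * a_prev)."""
    grad = 1.0
    a = x
    for _ in range(n_layers):
        a_new = sigmoid(weight * a)
        # Chain rule: d a_new / d a = weight * sigmoid'(z), where sigmoid'(z) = s * (1 - s)
        grad *= weight * a_new * (1.0 - a_new)
        a = a_new
    return grad

for depth in (1, 5, 10):
    print(f"{depth:2d} layers -> input gradient = {chain_gradient(0.5, 1.0, depth):.3e}")
```

Residual connections counter this by adding an identity path, so the gradient no longer has to pass through every shrinking factor.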

### C. Wrong Learning Rate

**Root Cause**: Learning rate controls step size in parameter space:
```
θ_new = θ_old - η × gradient
```

Too high η = overshoot minimum, too low η = tiny steps

**Symptoms**: High η causes oscillations, low η causes stagnation
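
The update rule above can be tried on the simplest possible loss, f(θ) = θ², as a rough sketch. The function name `gradient_descent` and the four learning rates are assumptions chosen for illustration. On this loss each step multiplies θ by (1 - 2η), so the behavior of each setting is easy to predict.

```python
def gradient_descent(lr: float, theta: float = 1.0, steps: int = 20) -> list:
    """Run plain gradient descent on f(theta) = theta**2 and return all iterates."""
    history = [theta]
    for _ in range(steps):
        grad = 2.0 * theta          # derivative of theta**2
        theta = theta - lr * grad   # theta_new = theta_old - lr * gradient
        history.append(theta)
    return history

# Too low, reasonable, oscillating but converging, diverging
for lr in (1e-4, 0.1, 0.9, 1.1):
    final = gradient_descent(lr)[-1]
    print(f"lr={lr:<6} final theta = {final:+.4e}")
```

With η = 1.1 the iterates flip sign and grow, with η = 0.9 they flip sign but shrink, and with η = 1e-4 they barely move from the starting point.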

Let's simulate each failure to see their signatures:

In [None]:
# Setup: Create simulation environment for training failures

# Simple model configuration for fast experiments
config = {
    'vocab_size': 100,
    'd_model': 64,
    'n_heads': 4,
    'n_layers': 2,
    'd_ff': 128,
    'max_seq_len': 32,
    'dropout': 0.1
}

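def create_dummy_batch(batch_size=4, seq_len=16):
    """Create dummy training data for experiments.

    Inputs and targets are independent random tokens, so there is nothing real
    to learn: the expected loss cannot drop below ln(vocab_size) ≈ 4.6. The
    experiments compare loss and gradient behavior, not learning progress.
    """
    x = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
    targets = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
    return x, targets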

def simulate_training_failure(model, optimizer, steps=20, name=""):
    """Simulate training and track failure patterns."""
    print(f"\n🧪 Simulating: {name}")
    print("-" * 40)
    
    criterion = nn.CrossEntropyLoss()
    losses = []
    grad_norms = []
    
    for step in range(steps):
        x, targets = create_dummy_batch()
        
        optimizer.zero_grad()
        
        try:
            # Forward pass
            outputs = model(x)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
            
            # Check for explosion
            if not torch.isfinite(loss) or loss.item() > 100:
                print(f"💥 Step {step}: Loss exploded to {loss.item():.2f}")
                break
            
            # Backward pass
            loss.backward()
            
            # Calculate gradient norm (key diagnostic)
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    total_norm += p.grad.data.norm(2).item() ** 2
            total_norm = total_norm ** 0.5
            
            optimizer.step()
            
            # Record metrics
            losses.append(loss.item())
            grad_norms.append(total_norm)
            
            if step % 5 == 0:
                print(f"Step {step}: Loss = {loss.item():.4f}, Grad Norm = {total_norm:.4f}")
                
        except Exception as e:
            print(f"💥 Crashed at step {step}: {e}")
            break
    
    return losses, grad_norms

print("Failure simulation setup complete! 🔬")

In [None]:
# Experiment 1: Exploding Gradients (Learning Rate Too High)

model1 = GPTModel(**config).to(device)
optimizer1 = optim.Adam(model1.parameters(), lr=1.0)  # Dangerously high!

exploding_losses, exploding_grads = simulate_training_failure(
    model1, optimizer1, steps=15, 
    name="EXPLODING GRADIENTS (LR = 1.0)"
)

print("\n🔍 Key Insight: Watch the gradient norms - with LR = 1.0 they typically climb sharply before the loss blows up")

In [None]:
# Experiment 2: Vanishing Gradients (Deep Model + Poor Initialization)

# Create deeper model with tiny weights
deep_config = config.copy()
deep_config['n_layers'] = 6  # Much deeper

model2 = GPTModel(**deep_config).to(device)

# Initialize weights extremely small (bad practice!)
for p in model2.parameters():
    if p.dim() > 1:
        nn.init.normal_(p, 0, 0.001)  # Way too small

optimizer2 = optim.Adam(model2.parameters(), lr=1e-4)

vanishing_losses, vanishing_grads = simulate_training_failure(
    model2, optimizer2, steps=25,
    name="VANISHING GRADIENTS (Deep + Tiny Init)"
)

print("\n🔍 Key Insight: Gradients are tiny, model barely learns despite many steps")

In [None]:
# Experiment 3: Learning Rate Too Low

model3 = GPTModel(**config).to(device)
optimizer3 = optim.Adam(model3.parameters(), lr=1e-8)  # Way too small!

slow_losses, slow_grads = simulate_training_failure(
    model3, optimizer3, steps=30,
    name="LEARNING RATE TOO LOW (LR = 1e-8)"
)

print("\n🔍 Key Insight: Gradients are reasonable but progress is painfully slow")

In [None]:
# Experiment 4: Healthy Training (For Comparison)

model4 = GPTModel(**config).to(device)
optimizer4 = optim.AdamW(model4.parameters(), lr=3e-4, weight_decay=0.01)

healthy_losses, healthy_grads = simulate_training_failure(
    model4, optimizer4, steps=25,
    name="HEALTHY TRAINING (LR = 3e-4)"
)

print("\n🔍 Key Insight: This is what good training looks like - stable and consistent")

In [None]:
# Visualize All Failure Modes Side by Side

fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves comparison
axes[0, 0].plot(exploding_losses, 'r-', linewidth=3, label='Exploding', marker='o')
axes[0, 0].plot(slow_losses, 'b-', linewidth=2, label='Too Slow', marker='s')
axes[0, 0].plot(healthy_losses, 'g-', linewidth=2, label='Healthy', marker='^')
axes[0, 0].set_title('Loss Patterns: Success vs Failure', fontsize=14, weight='bold')
axes[0, 0].set_xlabel('Training Steps')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Gradient norm comparison
axes[0, 1].plot(exploding_grads, 'r-', linewidth=3, label='Exploding', marker='o')
axes[0, 1].plot(vanishing_grads, 'orange', linewidth=2, label='Vanishing', marker='x')
axes[0, 1].plot(slow_grads, 'b-', linewidth=2, label='Too Slow', marker='s')
axes[0, 1].plot(healthy_grads, 'g-', linewidth=2, label='Healthy', marker='^')
axes[0, 1].set_title('Gradient Signatures', fontsize=14, weight='bold')
axes[0, 1].set_xlabel('Training Steps')
axes[0, 1].set_ylabel('Gradient Norm')
axes[0, 1].set_yscale('log')
axes[0, 1].grid(True, alpha=0.3)

# Add healthy zones before building the legend so their labels are included
axes[0, 1].axhspan(0.1, 10, alpha=0.2, color='green', label='Healthy Zone')
axes[0, 1].axhspan(10, 1000, alpha=0.2, color='red', label='Danger Zone')
axes[0, 1].axhspan(1e-6, 0.1, alpha=0.2, color='orange', label='Vanishing Zone')
axes[0, 1].legend()

# Diagnostic guide
diagnostic_text = [
    "🚨 EXPLODING GRADIENTS:",
    "• Loss increases rapidly",
    "• Grad norms > 10",
    "• Training crashes",
    "💊 Fix: Lower LR, clip gradients",
    "",
    "🐌 VANISHING GRADIENTS:",
    "• Loss decreases slowly",
    "• Grad norms < 1e-4",
    "• Model barely learns",
    "💊 Fix: Better init, residuals"
]

axes[1, 0].text(0.05, 0.95, '\n'.join(diagnostic_text), 
                transform=axes[1, 0].transAxes, fontsize=10,
                verticalalignment='top', fontfamily='monospace',
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))
axes[1, 0].set_xlim(0, 1)
axes[1, 0].set_ylim(0, 1)
axes[1, 0].axis('off')
axes[1, 0].set_title('Quick Diagnosis Guide', fontsize=14, weight='bold')

treatment_text = [
    "🐢 LEARNING RATE TOO LOW:",
    "• Painfully slow progress",
    "• Reasonable gradients",
    "• Takes forever",
    "💊 Fix: Increase LR 5-10x",
    "",
    "✅ HEALTHY TRAINING:",
    "• Smooth loss decrease",
    "• Stable grad norms (0.1-10)",
    "• Consistent progress",
    "💊 Keep going!"
]

axes[1, 1].text(0.05, 0.95, '\n'.join(treatment_text), 
                transform=axes[1, 1].transAxes, fontsize=10,
                verticalalignment='top', fontfamily='monospace',
                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))
axes[1, 1].set_xlim(0, 1)
axes[1, 1].set_ylim(0, 1)
axes[1, 1].axis('off')
axes[1, 1].set_title('Treatment Guide', fontsize=14, weight='bold')

plt.tight_layout()
plt.show()

print("\n🎯 KEY INSIGHT: Each failure has a unique signature!")
print("Learn these patterns and you can diagnose the most common training problems quickly.")

## 2. Real-Time Health Monitor

Training problems start small then explode. Early detection saves your model.

### The Physics of Gradient Health

Gradients tell you a great deal about the training state:

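**Gradient Norm**: ||∇L|| measures the overall gradient magnitude. The thresholds here are rough rules of thumb for the small models in this notebook, not universal constants:
- Too large (>10): Updates are too big, causing instability  
- Too small (<0.001): Updates are too tiny, learning is slow
- Just right (0.1-10): Model learns efficiently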

**Gradient-to-Parameter Ratio**: ||∇L|| / ||θ|| is a proxy for relative update size (the actual update is also scaled by the learning rate and the optimizer)
- Healthy range: 1e-5 to 0.1
- Too high: Learning rate might be excessive
- Too low: Learning rate might be insufficient

**Loss Trends**: First derivative of loss curve
- Decreasing: Good progress
- Oscillating: Learning rate too high
- Flat: Learning rate too low or convergence
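
As a framework-free sketch of the two gradient metrics above, the global norm treats all tensors as one long vector, which is the same as taking the square root of the summed squared per-tensor norms. The helper names `global_norm` and `grad_param_ratio` and the toy values are hypothetical, chosen so the result is easy to check by hand.

```python
import math

def global_norm(tensors):
    """L2 norm of all values in a list of flat lists, treated as one long vector."""
    return math.sqrt(sum(v * v for t in tensors for v in t))

def grad_param_ratio(grads, params, eps=1e-8):
    """Gradient norm relative to parameter norm; eps guards against division by zero."""
    return global_norm(grads) / (global_norm(params) + eps)

# Hypothetical toy values standing in for two parameter tensors and their gradients
params = [[3.0, 4.0], [12.0]]
grads = [[0.03, 0.04], [0.12]]

print(f"param norm = {global_norm(params):.2f}")
print(f"grad norm  = {global_norm(grads):.2f}")
print(f"ratio      = {grad_param_ratio(grads, params):.4f}")
```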

Let's build a monitoring system:

In [None]:
class TrainingHealthMonitor:
    """Real-time health monitoring for transformer training."""
    
    def __init__(self):
        self.reset()
    
    def reset(self):
        """Reset all monitoring data."""
        self.losses = []
        self.grad_norms = []
        self.learning_rates = []
        self.grad_param_ratios = []
        self.steps = []
        
    def check_vitals(self, model, loss, optimizer, step):
        """Check model's vital signs after each step."""
        
        # Basic metrics
        self.losses.append(loss.item())
        self.learning_rates.append(optimizer.param_groups[0]['lr'])
        self.steps.append(step)
        
        # Calculate gradient and parameter norms
        total_grad_norm = 0.0
        total_param_norm = 0.0
        
        for param in model.parameters():
            if param.grad is not None:
                grad_norm = param.grad.data.norm(2).item()
                param_norm = param.data.norm(2).item()
                
                total_grad_norm += grad_norm ** 2
                total_param_norm += param_norm ** 2
        
        total_grad_norm = total_grad_norm ** 0.5
        total_param_norm = total_param_norm ** 0.5
        
        self.grad_norms.append(total_grad_norm)
        
        # Key diagnostic: gradient-to-parameter ratio
        ratio = total_grad_norm / (total_param_norm + 1e-8)
        self.grad_param_ratios.append(ratio)
        
        return self._diagnose_current_health()
    
    def _diagnose_current_health(self):
        """Instant health diagnosis."""
        if not self.grad_norms:
            return "No data yet"
        
        current_grad = self.grad_norms[-1]
        current_ratio = self.grad_param_ratios[-1]
        current_loss = self.losses[-1]
        
        # Critical issues (emergency stop needed)
        if not np.isfinite(current_loss):
            return "🚨 CRITICAL: Loss is NaN or infinite - STOP TRAINING"
        
        if current_grad > 10:
            return "🚨 DANGER: Gradient explosion imminent"
        
        if current_grad < 1e-5:
            return "🐌 WARNING: Vanishing gradients detected"
        
        # Learning rate issues
        if current_ratio > 0.1:
            return "⚠️ CAUTION: Learning rate might be too high"
        
        if current_ratio < 1e-5:
            return "⚠️ CAUTION: Learning rate might be too low"
        
        # Trend analysis (if enough data)
        if len(self.losses) > 5:
            recent_losses = self.losses[-5:]
            if all(loss_val > recent_losses[0] for loss_val in recent_losses[2:]):
                return "📈 WARNING: Loss trending upward - possible instability"
        
        return "✅ HEALTHY: All vitals normal"
    
    def emergency_stop_needed(self):
        """Check if training should stop immediately."""
        if not self.grad_norms:
            return False
        
        current_grad = self.grad_norms[-1]
        current_loss = self.losses[-1]
        
        return (not np.isfinite(current_loss) or 
                current_grad > 100 or 
                current_loss > 50)
    
    def plot_dashboard(self):
        """Plot comprehensive health dashboard."""
        if len(self.steps) < 2:
            print("Need more data for dashboard")
            return
        
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # Loss curve with trend
        axes[0, 0].plot(self.steps, self.losses, 'b-', linewidth=2, marker='o', markersize=4)
        axes[0, 0].set_title('📉 Loss Curve', fontsize=14, weight='bold')
        axes[0, 0].set_xlabel('Steps')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Gradient health with zones
        axes[0, 1].plot(self.steps, self.grad_norms, 'g-', linewidth=2, marker='s', markersize=4)
        axes[0, 1].axhspan(0.1, 10, alpha=0.2, color='green', label='Healthy Zone')
        axes[0, 1].axhspan(10, 1000, alpha=0.2, color='red', label='Danger Zone')
        axes[0, 1].axhspan(1e-6, 0.1, alpha=0.2, color='orange', label='Vanishing Zone')
        axes[0, 1].set_title('🌡️ Gradient Health', fontsize=14, weight='bold')
        axes[0, 1].set_xlabel('Steps')
        axes[0, 1].set_ylabel('Gradient Norm')
        axes[0, 1].set_yscale('log')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)
        
        # Learning rate schedule
        axes[1, 0].plot(self.steps, self.learning_rates, 'r-', linewidth=2, marker='^', markersize=4)
        axes[1, 0].set_title('📊 Learning Rate', fontsize=14, weight='bold')
        axes[1, 0].set_xlabel('Steps')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Update magnitude (key diagnostic)
        axes[1, 1].plot(self.steps, self.grad_param_ratios, 'm-', linewidth=2, marker='d', markersize=4)
        axes[1, 1].axhspan(1e-5, 0.1, alpha=0.2, color='green', label='Healthy Zone')
        axes[1, 1].axhspan(0.1, 10, alpha=0.2, color='red', label='Too High')
        axes[1, 1].set_title('⚖️ Update Magnitude (Grad/Param Ratio)', fontsize=14, weight='bold')
        axes[1, 1].set_xlabel('Steps')
        axes[1, 1].set_ylabel('Ratio')
        axes[1, 1].set_yscale('log')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Print current status
        print(f"\n📊 HEALTH SUMMARY (Step {self.steps[-1]}):")
        print(f"   Status: {self._diagnose_current_health()}")
        print(f"   Loss: {self.losses[-1]:.4f}")
        print(f"   Grad Norm: {self.grad_norms[-1]:.4f}")
        print(f"   Learning Rate: {self.learning_rates[-1]:.2e}")
        print(f"   Update Ratio: {self.grad_param_ratios[-1]:.6f}")

print("Health monitoring system ready! 🏥")

In [None]:
# Demo: Monitor Healthy Training in Real-Time

print("🏥 DEMONSTRATING REAL-TIME HEALTH MONITORING")
print("=" * 50)

# Setup healthy training
model = GPTModel(**config).to(device)
monitor = TrainingHealthMonitor()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# Training loop with monitoring
for step in range(30):
    x, targets = create_dummy_batch()
    
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
    loss.backward()
    
    # Monitor health on the raw gradients (clipping first would hide large norms)
    diagnosis = monitor.check_vitals(model, loss, optimizer, step)
    
    # Apply gradient clipping (good practice!)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    # Emergency stop check
    if monitor.emergency_stop_needed():
        print(f"🚨 EMERGENCY STOP at step {step}!")
        break
    
    optimizer.step()
    
    # Report every 5 steps
    if step % 5 == 0:
        print(f"Step {step}: Loss = {loss.item():.4f} | {diagnosis}")

# Show comprehensive dashboard
monitor.plot_dashboard()

print("\n✅ This is what healthy, monitored training looks like!")
print("Notice how all metrics stay in healthy ranges.")

## 3. Loss Curve Interpretation

Loss curves tell the story of your training. Learn to read the signs.

### The Mathematics of Loss Patterns

**Healthy Decreasing Loss**: L(t) ≈ L_min + A·e^(-t/τ), where τ is the time constant and L_min is the loss floor
- Exponential decay toward minimum
- Smooth curve with low noise
- Consistent rate of improvement

**Oscillating Loss**: L(t) = baseline + A·sin(ωt) + noise
- Learning rate too high causes overshooting
- Model bounces around minimum
- High frequency variations

**Plateau Loss**: dL/dt ≈ 0
- Learning rate too low (can't escape a local minimum or flat region)
- Model capacity may be exhausted, or training has converged
- If raising the learning rate doesn't help, consider architectural changes

**Exploding Loss**: L(t) ∝ e^(t/τ_explode)
- Unstable dynamics
- Gradients grow exponentially
- Training becomes chaotic
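
If a healthy curve is modeled as L(t) ≈ floor + A·e^(-t/τ), the time constant τ can be estimated with a straight-line fit to log(L(t) - floor). The sketch below assumes the floor is known and the curve is noise-free. The function name `estimate_time_constant` and the synthetic curve are illustrative, not part of the notebook's toolkit.

```python
import math

def estimate_time_constant(losses, floor):
    """Least-squares fit of log(L(t) - floor) = log(A) - t / tau; returns tau."""
    # Keep only points above the floor so the logarithm is defined
    points = [(t, math.log(loss - floor)) for t, loss in enumerate(losses) if loss > floor]
    n = len(points)
    mean_t = sum(t for t, _ in points) / n
    mean_y = sum(y for _, y in points) / n
    cov = sum((t - mean_t) * (y - mean_y) for t, y in points)
    var = sum((t - mean_t) ** 2 for t, _ in points)
    slope = cov / var
    return -1.0 / slope  # slope of the fitted line is -1 / tau

# Synthetic curve with a known time constant of 25 steps
curve = [0.5 + 4.0 * math.exp(-t / 25) for t in range(100)]
tau = estimate_time_constant(curve, floor=0.5)
print(f"Estimated tau = {tau:.2f} steps")
```

On real, noisy curves the floor is unknown and points near it dominate the error, so treat the estimate as a rough guide to how fast training is progressing.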

Let's analyze different loss patterns:

In [None]:
class LossCurveAnalyzer:
    """Analyze and interpret loss curve patterns."""
    
    @staticmethod
    def analyze_pattern(losses: List[float], window_size: int = 10) -> Dict[str, str]:
        """Analyze loss curve and provide diagnosis."""
        if len(losses) < window_size:
            return {"status": "Insufficient data for analysis"}
        
        analysis = {}
        
        # Overall trend analysis
        start_avg = np.mean(losses[:window_size])
        end_avg = np.mean(losses[-window_size:])
        improvement = (start_avg - end_avg) / start_avg
        
        if improvement > 0.2:
            analysis["trend"] = "✅ EXCELLENT - Strong improvement"
        elif improvement > 0.05:
            analysis["trend"] = "✅ GOOD - Steady improvement"
        elif improvement > 0.01:
            analysis["trend"] = "⚠️ SLOW - Marginal improvement"
        elif improvement < -0.01:
            analysis["trend"] = "🚨 BAD - Loss increasing"
        else:
            analysis["trend"] = "📊 FLAT - No clear trend"
        
        # Stability analysis
        recent_losses = losses[-window_size:]
        volatility = np.std(recent_losses) / np.mean(recent_losses)
        
        if volatility > 0.2:
            analysis["stability"] = "🌊 VERY NOISY - High variance"
        elif volatility > 0.05:
            analysis["stability"] = "🌀 NOISY - Some oscillation"
        else:
            analysis["stability"] = "📉 SMOOTH - Stable convergence"
        
        # Oscillation detection
        if len(losses) > 20:
            recent = losses[-20:]
            direction_changes = 0
            for i in range(1, len(recent) - 1):
                # Count local maxima and minima
                if (recent[i] > recent[i-1] and recent[i] > recent[i+1]) or \
                   (recent[i] < recent[i-1] and recent[i] < recent[i+1]):
                    direction_changes += 1
            
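            # Caveat: random step-to-step noise alone yields many local extrema,
            # so treat an "oscillating" verdict as a hint rather than proof
            oscillation_rate = direction_changes / len(recent)
            if oscillation_rate > 0.3:
                analysis["pattern"] = "🌀 OSCILLATING - Learning rate too high"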
            elif oscillation_rate > 0.1:
                analysis["pattern"] = "📈 BOUNCY - Some instability"
            else:
                analysis["pattern"] = "➡️ MONOTONIC - Smooth progress"
        
        # Health check
        if any(not np.isfinite(l) for l in losses[-5:]):
            analysis["health"] = "💀 BROKEN - NaN or infinite loss"
        elif max(losses[-5:]) > 100:
            analysis["health"] = "🚨 EXPLODING - Loss too high"
        elif min(losses[-5:]) < 1e-6:
            analysis["health"] = "🎯 CONVERGED - Loss very low"
        else:
            analysis["health"] = "✅ HEALTHY - Normal values"
        
        return analysis
    
    @staticmethod
    def generate_example_patterns():
        """Generate example loss curves for different scenarios."""
        steps = np.arange(100)
        patterns = {}
        
        # Healthy exponential decay
        healthy = 4.0 * np.exp(-steps / 25) + 0.5 + 0.05 * np.random.randn(100)
        patterns["Healthy Learning"] = np.maximum(healthy, 0.1)
        
        # Overfitting (U-shape, as typically seen in validation loss)
        overfitting = 4.0 * np.exp(-steps / 15) + 0.5
        upturn = np.where(steps > 50, 0.03 * (steps - 50) ** 1.2, 0)
        overfitting += upturn + 0.08 * np.random.randn(100)
        patterns["Overfitting"] = np.maximum(overfitting, 0.1)
        
        # Learning rate too high (oscillating)
        oscillating = 2.5 + 0.8 * np.sin(steps * 0.4) * np.exp(-steps / 50)
        oscillating += 0.1 * np.random.randn(100)
        patterns["LR Too High"] = np.maximum(oscillating, 0.1)
        
        # Learning rate too low (plateau)
        plateau = 4.0 * np.exp(-steps / 80) + 2.5 + 0.05 * np.random.randn(100)
        patterns["LR Too Low"] = np.maximum(plateau, 0.1)
        
        # Exploding gradients
        exploding = 2.0 + np.where(steps > 20, 0.3 * (steps - 20) ** 1.3, 0)
        exploding += 0.1 * np.random.randn(100)
        patterns["Exploding"] = np.clip(exploding, 0.1, 30)
        
        # No learning (flat)
        flat = 4.5 + 0.1 * np.random.randn(100)
        patterns["No Learning"] = np.maximum(flat, 0.1)
        
        return patterns

# Generate and analyze example patterns
analyzer = LossCurveAnalyzer()
example_patterns = analyzer.generate_example_patterns()

# Visualize all patterns with analysis
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for i, (name, curve) in enumerate(example_patterns.items()):
    if i < len(axes):
        # Plot the curve
        axes[i].plot(curve, linewidth=2, color=f'C{i}')
        axes[i].set_title(f'{name}', fontsize=14, weight='bold')
        axes[i].set_xlabel('Training Step')
        axes[i].set_ylabel('Loss')
        axes[i].grid(True, alpha=0.3)
        
        # Add analysis as text box
        analysis = analyzer.analyze_pattern(curve.tolist())
        analysis_text = '\n'.join([f'{k}: {v.split(" - ")[0]}' for k, v in analysis.items()])
        
        axes[i].text(0.02, 0.98, analysis_text, transform=axes[i].transAxes,
                    verticalalignment='top', fontsize=9, fontfamily='monospace',
                    bbox=dict(boxstyle='round', facecolor='white', alpha=0.9, edgecolor=f'C{i}'))

plt.tight_layout()
plt.show()

print("\n📊 LOSS CURVE INTERPRETATION GUIDE:")
print("=" * 50)
print("✅ Healthy: Smooth exponential decay, low noise")
print("📈 Overfitting: Validation loss decreases, then rises again (U-shape)")
print("🌀 Oscillating: Learning rate too high, bouncy loss")
print("📊 Plateau: Learning rate too low or convergence")
print("🚨 Exploding: Unstable dynamics, loss shoots up")
print("😴 No Learning: Flat line, model not learning")

## Summary: Master Transformer Training Diagnostics

You now have a complete diagnostic toolkit for transformer training!

### 🔍 Key Diagnostic Skills

**1. Recognize Failure Signatures**
- **Exploding Gradients**: Loss shoots up, grad norms > 10 → Lower LR, add clipping
- **Vanishing Gradients**: Loss flat, grad norms < 1e-4 → Better init, check architecture  
- **Wrong Learning Rate**: Oscillations (too high) or plateau (too low) → Adjust carefully

**2. Monitor Critical Metrics**
- **Gradient Norm**: Healthy range 0.1-10
- **Grad/Param Ratio**: Should be 1e-5 to 0.1  
- **Loss Trend**: Should decrease smoothly
- **Learning Rate**: Typically 1e-5 to 1e-3 for transformers

**3. Read Loss Curves**
- **Exponential decay** = healthy learning
- **Oscillations** = learning rate too high
- **Plateau** = learning rate too low or convergence
- **U-shape** (in validation loss) = overfitting
- **Explosion** = unstable dynamics

### 🚨 Emergency Procedures

**Critical Issues (Stop immediately):**
- NaN or infinite loss
- Gradient explosion (norm > 100)
- Loss > 50

**Emergency Response:**
1. Stop training
2. Restore from checkpoint
3. Reduce learning rate 10x
4. Add gradient clipping
5. Restart carefully
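
The emergency response could look roughly like the sketch below, which keeps the checkpoint in memory for brevity. The helpers `take_checkpoint` and `rollback_and_reduce_lr` are hypothetical names, and the tiny `nn.Linear` model with SGD stands in for a real transformer and optimizer.

```python
import copy

import torch
import torch.nn as nn

def take_checkpoint(model, optimizer):
    """Deep-copy model and optimizer state so later updates cannot mutate it."""
    return {
        "model": copy.deepcopy(model.state_dict()),
        "optimizer": copy.deepcopy(optimizer.state_dict()),
    }

def rollback_and_reduce_lr(model, optimizer, checkpoint, lr_factor=0.1):
    """Restore the last good state, then shrink the learning rate."""
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    for group in optimizer.param_groups:
        group["lr"] *= lr_factor

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
checkpoint = take_checkpoint(model, optimizer)

# Simulate a blown-up step by corrupting the weights
with torch.no_grad():
    model.weight.fill_(float("nan"))

if not torch.isfinite(model.weight).all():
    rollback_and_reduce_lr(model, optimizer, checkpoint)
    print(f"Rolled back, new LR = {optimizer.param_groups[0]['lr']:.3g}")
```

The deep copy matters because `state_dict()` returns tensors that share storage with the live parameters. Loading the optimizer state also resets the learning rate to the checkpointed value before the reduction is applied, so repeated rollbacks to the same checkpoint do not compound.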

### 💡 Best Practices

- **Always monitor** gradients during training
- **Use gradient clipping** as safety net (max_norm=1.0)
- **Save checkpoints frequently** for recovery
- **Start with proven hyperparameters** then tune
- **Validate regularly** to catch overfitting
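
One detail worth knowing when combining monitoring with clipping: `torch.nn.utils.clip_grad_norm_` rescales gradients in place and returns the total norm measured before clipping, so the raw value can still be logged. The sketch below uses a toy `nn.Linear` model and an artificially scaled loss, both assumptions chosen only to force large gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
x = torch.randn(8, 4)

# Scale the outputs up to force large gradients for the demonstration
loss = (model(x) * 1000.0).pow(2).mean()
loss.backward()

# Returns the pre-clipping global norm; gradients are rescaled in place
pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the global norm from the (now clipped) gradients
post_clip_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

print(f"Before clipping: {float(pre_clip_norm):.2e}")
print(f"After clipping:  {float(post_clip_norm):.4f}")
```

Logging the returned value keeps the gradient-norm diagnostics meaningful even when clipping is always on.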

### 🎯 What You've Mastered

- **Instant diagnosis** of common training failures
- **Real-time monitoring** with key health metrics
- **Loss curve interpretation** for all patterns
- **Emergency procedures** to save crashed training

With these skills, you can debug the most common transformer training problems and keep your models healthy! 🛠️