# Training Transformers

This notebook provides an interactive guide to training transformer models effectively, covering optimization strategies, learning rate schedules, and debugging techniques.

## Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import math
from typing import Dict, List, Optional
from collections import defaultdict

# Seed NumPy and PyTorch so the random data and weight initialization repeat
np.random.seed(42)
torch.manual_seed(42)

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## 1. Understanding the Training Process

Let's start by building a simple transformer model to demonstrate training concepts.

In [None]:
class SimpleTransformer(nn.Module):
    """A simplified transformer encoder for demonstration.

    Token ids are embedded, summed with a learned positional table, run
    through a stack of ``nn.TransformerEncoderLayer`` blocks and projected to
    one logit per vocabulary entry.

    Args:
        vocab_size: Size of the embedding table and of the output logits.
        d_model: Embedding / hidden width.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        max_len: Number of positions in the learned positional table; longer
            input sequences are rejected in ``forward``.
        dim_feedforward: Width of the feed-forward sublayer in each block.
        dropout: Dropout probability used inside each encoder layer.
    """

    def __init__(self, vocab_size=1000, d_model=128, nhead=8, num_layers=4,
                 max_len=512, dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(1, max_len, d_model))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        self.output_projection = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        """Return logits of shape (batch, seq_len, vocab_size) for token ids ``x``.

        Args:
            x: Integer token ids indexed as (batch, seq_len).
            mask: Optional mask forwarded unchanged to the encoder as
                ``src_key_padding_mask``.

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(1)
        if seq_len > max_len:
            # Fail early with a readable message instead of a broadcast error
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional table length {max_len}"
            )

        # Embed and add positional encoding
        x = self.embedding(x)
        x = x + self.pos_encoding[:, :seq_len, :]

        # The encoder layers are built without batch_first, so they expect
        # (seq_len, batch, features); transpose in and back out.
        x = x.transpose(0, 1)
        x = self.transformer(x, src_key_padding_mask=mask)
        x = x.transpose(0, 1)

        # Output projection
        return self.output_projection(x)

# Instantiate the demo model, move it to the selected device and report its size
model = SimpleTransformer()
model = model.to(device)
print(f"Model parameters: {sum(map(torch.numel, model.parameters())):,}")

## 2. Data Preparation and Loading

Efficient data loading is crucial for training performance.

In [None]:
class SimpleDataset(Dataset):
    """Toy dataset of random token sequences paired with random label sequences."""

    def __init__(self, num_samples=1000, seq_length=50, vocab_size=1000):
        shape = (num_samples, seq_length)
        # Ids are sampled from [1, vocab_size), so id 0 never appears.
        # Inputs are drawn before targets, which fixes the RNG stream order.
        self.data = torch.randint(low=1, high=vocab_size, size=shape)
        self.targets = torch.randint(low=1, high=vocab_size, size=shape)

    def __len__(self):
        # Number of stored sequences
        return self.data.size(0)

    def __getitem__(self, idx):
        # One sample as the dict layout the training helpers read
        sample = dict(input_ids=self.data[idx])
        sample['labels'] = self.targets[idx]
        return sample

# Build the two splits; the data is random, so only the sizes differ
train_dataset, val_dataset = SimpleDataset(num_samples=1000), SimpleDataset(num_samples=200)

# Only the training loader shuffles; validation uses a larger batch
# and a fixed order
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

# len() of a DataLoader is its number of batches
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 3. Learning Rate Schedules

Let's visualize different learning rate schedules commonly used for transformers.

In [None]:
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR that ramps up linearly and then decays linearly.

    The LR multiplier rises from 0 to 1 over ``num_warmup_steps``, then falls
    in a straight line to 0 at ``num_training_steps`` and stays at 0 afterwards.
    """
    # Both spans are floored at 1 so neither division can hit zero
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        in_warmup = current_step < num_warmup_steps
        if in_warmup:
            return float(current_step) / warmup_span
        remaining = float(num_training_steps - current_step)
        return max(0.0, remaining / decay_span)

    return LambdaLR(optimizer, lr_lambda)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Cosine learning rate schedule with linear warmup.

    The LR multiplier rises linearly from 0 to 1 over ``num_warmup_steps``,
    then follows half a cosine wave down to 0 at ``num_training_steps``.
    Stepping past ``num_training_steps`` keeps the multiplier at 0.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warmup steps.
        num_training_steps: Step at which the multiplier reaches 0.

    Returns:
        A ``LambdaLR`` wrapping the schedule.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps))
        # Clamp so the cosine cannot swing back up when training runs past
        # num_training_steps (progress > 1 would otherwise raise the LR again)
        progress = min(1.0, progress)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return LambdaLR(optimizer, lr_lambda)

def get_inverse_sqrt_schedule(optimizer, num_warmup_steps):
    """Inverse square root schedule (original Transformer).

    The LR multiplier rises linearly from 0 to 1 over ``num_warmup_steps`` and
    then decays as ``sqrt(num_warmup_steps / step)``. With no warmup
    (``num_warmup_steps <= 0``) it decays as ``1 / sqrt(step)`` from a
    multiplier of 1.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of linear warmup steps.

    Returns:
        A ``LambdaLR`` wrapping the schedule.
    """
    # Floor at 1: a zero warmup would otherwise put sqrt(0) in the numerator
    # and pin the multiplier to 0 for the entire run.
    decay_anchor = float(max(1, num_warmup_steps))

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / decay_anchor
        return decay_anchor ** 0.5 / float(max(current_step, 1)) ** 0.5
    return LambdaLR(optimizer, lr_lambda)

# Visualize schedules
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Throwaway model: the optimizer only needs some parameters to manage
dummy_model = nn.Linear(10, 10)
lr = 1e-3
warmup_steps = 100
total_steps = 1000

# Every entry is called as fn(optimizer, warmup_steps, total_steps). The
# inverse-sqrt schedule has no end step, so its adapter drops the third argument.
schedules = [
    ('Linear', get_linear_schedule_with_warmup),
    ('Cosine', get_cosine_schedule_with_warmup),
    ('Inverse Sqrt', lambda opt, warm, total: get_inverse_sqrt_schedule(opt, warm))
]

for ax, (name, schedule_fn) in zip(axes, schedules):
    # Fresh optimizer per schedule so each curve starts from the base LR
    optimizer = torch.optim.Adam(dummy_model.parameters(), lr=lr)
    scheduler = schedule_fn(optimizer, warmup_steps, total_steps)

    lrs = []
    for step in range(total_steps):
        lrs.append(optimizer.param_groups[0]['lr'])
        # No gradients exist, so this step changes no weights; it only keeps
        # the optimizer-then-scheduler call order that PyTorch expects and
        # avoids its ordering warning.
        optimizer.step()
        scheduler.step()

    ax.plot(lrs)
    ax.set_title(f'{name} Schedule')
    ax.set_xlabel('Step')
    ax.set_ylabel('Learning Rate')
    ax.axvline(x=warmup_steps, color='r', linestyle='--', alpha=0.5, label='Warmup End')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Optimizer Configuration

AdamW is the standard optimizer for transformers. Let's explore its configuration.

In [None]:
def configure_optimizer(model, learning_rate=3e-4, weight_decay=0.01):
    """Configure AdamW with weight decay applied only where it belongs.

    Biases and LayerNorm parameters are exempt from weight decay; every other
    trainable parameter is decayed. LayerNorm parameters are identified by
    module type, because name patterns such as ``LayerNorm.weight`` do not
    match the ``norm1`` / ``norm2`` names used by ``nn.TransformerEncoderLayer``.
    The name patterns are still honoured for models that do use them.

    Args:
        model: Module whose trainable parameters are optimized.
        learning_rate: AdamW learning rate.
        weight_decay: Decay coefficient for the decayed group.

    Returns:
        ``torch.optim.AdamW`` with two parameter groups: index 0 is decayed,
        index 1 has ``weight_decay=0.0``.
    """
    # Name fragments that mark a parameter as exempt from weight decay
    no_decay = ('bias', 'LayerNorm.weight', 'layer_norm.weight')

    # Identify normalization parameters by owning module, independent of naming
    norm_param_ids = {
        id(param)
        for module in model.modules()
        if isinstance(module, nn.LayerNorm)
        for param in module.parameters(recurse=False)
    }

    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        exempt = id(param) in norm_param_ids or any(nd in name for nd in no_decay)
        (no_decay_params if exempt else decay_params).append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    return optimizer

# Configure optimizer
optimizer = configure_optimizer(model)

# Print parameter groups
for i, group in enumerate(optimizer.param_groups):
    print(f"Group {i}: {len(group['params'])} parameters, weight_decay={group['weight_decay']}")

# Count parameters by type.
# LayerNorm parameters are found by module type: nn.TransformerEncoderLayer
# names its norms norm1/norm2, so matching 'LayerNorm' in the parameter name
# would never hit and the layer_norm bucket would stay empty.
layer_norm_param_ids = {
    id(p)
    for m in model.modules()
    if isinstance(m, nn.LayerNorm)
    for p in m.parameters(recurse=False)
}

param_counts = defaultdict(int)
for name, param in model.named_parameters():
    # Biases are checked first, so LayerNorm biases count as 'bias' and the
    # 'layer_norm' bucket holds the LayerNorm weights.
    if 'bias' in name:
        param_counts['bias'] += param.numel()
    elif id(param) in layer_norm_param_ids or 'LayerNorm' in name or 'layer_norm' in name:
        param_counts['layer_norm'] += param.numel()
    else:
        param_counts['other'] += param.numel()

print("\nParameter distribution:")
for ptype, count in param_counts.items():
    print(f"  {ptype}: {count:,}")

## 5. Training Loop with Gradient Clipping

Let's implement a training loop with proper gradient clipping and monitoring.

In [None]:
def train_step(model, batch, optimizer, clip_value=1.0):
    """Run one optimization step on ``batch``.

    Returns a ``(loss, grad_norm)`` pair of Python floats, where ``grad_norm``
    is the value returned by ``clip_grad_norm_``.
    """
    model.train()
    # Start from clean gradients; the forward pass below does not touch them
    optimizer.zero_grad()

    tokens = batch['input_ids'].to(device)
    targets = batch['labels'].to(device)

    # Flatten (batch, seq) into one axis so cross-entropy sees one row per token
    logits = model(tokens)
    vocab_dim = logits.size(-1)
    loss = F.cross_entropy(logits.view(-1, vocab_dim), targets.view(-1))

    loss.backward()

    # Rescale gradients in place so their total norm is at most clip_value
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
    optimizer.step()

    return loss.item(), total_norm.item()

def evaluate(model, dataloader):
    """Return the token-weighted mean cross-entropy of ``model`` over ``dataloader``."""
    model.eval()
    loss_sum = 0
    token_count = 0

    with torch.no_grad():
        for batch in dataloader:
            tokens = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)

            logits = model(tokens)
            batch_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1)
            )

            # Weight each batch mean by its token count so a smaller final
            # batch does not skew the average
            rows, cols = tokens.size(0), tokens.size(1)
            loss_sum += batch_loss.item() * rows * cols
            token_count += rows * cols

    return loss_sum / token_count

# Training loop with monitoring
num_epochs = 3
train_losses = []
val_losses = []
grad_norms = []
learning_rates = []

# Setup scheduler
total_steps = len(train_loader) * num_epochs
# Size the warmup relative to the run (10% of all steps). A fixed 100-step
# warmup can be longer than the whole run (1000 samples at batch size 32 for
# 3 epochs is only 96 steps), in which case the LR never reaches its peak and
# the cosine decay never starts.
num_warmup_steps = max(1, total_steps // 10)
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                           num_warmup_steps=num_warmup_steps,
                                           num_training_steps=total_steps)

# Training
step = 0
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")

    # Training
    pbar = tqdm(train_loader, desc="Training")
    epoch_losses = []
    epoch_grad_norms = []

    for batch in pbar:
        # Read the LR before the scheduler advances, so the logged value is
        # the one this update actually used
        current_lr = optimizer.param_groups[0]['lr']

        loss, grad_norm = train_step(model, batch, optimizer)
        scheduler.step()

        epoch_losses.append(loss)
        epoch_grad_norms.append(grad_norm)
        learning_rates.append(current_lr)

        # Update progress bar
        pbar.set_postfix({
            'loss': f'{loss:.4f}',
            'grad_norm': f'{grad_norm:.2f}',
            'lr': f'{current_lr:.2e}'
        })

        step += 1

    # Store metrics
    train_losses.extend(epoch_losses)
    grad_norms.extend(epoch_grad_norms)

    # Validation runs once per epoch, after the last training step of the epoch
    val_loss = evaluate(model, val_loader)
    val_losses.append(val_loss)

    print(f"Epoch {epoch + 1} - Train Loss: {np.mean(epoch_losses):.4f}, "
          f"Val Loss: {val_loss:.4f}, Avg Grad Norm: {np.mean(epoch_grad_norms):.2f}")

## 6. Visualizing Training Metrics

Let's visualize the training progress to identify any issues.

In [None]:
# Create subplots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Clip value used by train_step's default; drawn as a reference line and used
# to count how often clipping was active
clip_threshold = 1.0

# Plot training loss
ax = axes[0, 0]
ax.plot(train_losses, alpha=0.7)
ax.set_title('Training Loss')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.grid(True, alpha=0.3)

# Plot validation loss
ax = axes[0, 1]
# Validation is measured at the END of each epoch, so the k-th point belongs
# at step k * steps_per_epoch (not at step 0 for the first one)
steps_per_epoch = len(train_losses) // max(1, len(val_losses))
val_steps = [steps_per_epoch * (k + 1) for k in range(len(val_losses))]
ax.plot(val_steps, val_losses, 'o-', color='red', label='Validation')
ax.set_title('Validation Loss')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot gradient norms
ax = axes[1, 0]
ax.plot(grad_norms, alpha=0.7, color='green')
ax.axhline(y=clip_threshold, color='r', linestyle='--', label='Clip threshold')
ax.set_title('Gradient Norms')
ax.set_xlabel('Step')
ax.set_ylabel('Gradient Norm')
ax.legend()
ax.grid(True, alpha=0.3)

# Plot learning rate
ax = axes[1, 1]
ax.plot(learning_rates, color='orange')
ax.set_title('Learning Rate Schedule')
ax.set_xlabel('Step')
ax.set_ylabel('Learning Rate')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print summary statistics
print("\nTraining Summary:")
print(f"Final train loss: {train_losses[-1]:.4f}")
print(f"Final validation loss: {val_losses[-1]:.4f}")
print(f"Average gradient norm: {np.mean(grad_norms):.2f}")
print(f"Max gradient norm: {np.max(grad_norms):.2f}")
print(f"Gradient clipping activated: {sum(g > clip_threshold for g in grad_norms)} times")

## 7. Mixed Precision Training

Mixed precision training can significantly speed up training and reduce memory usage.

In [None]:
from torch.cuda.amp import autocast, GradScaler

def train_step_mixed_precision(model, batch, optimizer, scaler, clip_value=1.0):
    """Run one optimization step under autocast with gradient scaling.

    Returns a ``(loss, grad_norm)`` pair of Python floats, where ``grad_norm``
    is the value returned by ``clip_grad_norm_``.
    """
    model.train()

    tokens = batch['input_ids'].to(device)
    targets = batch['labels'].to(device)

    optimizer.zero_grad()

    # Forward pass and loss run inside the autocast region
    with autocast():
        logits = model(tokens)
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))

    # Backpropagate the scaled loss
    scaler.scale(loss).backward()

    # Gradients are still scaled here; unscale them first so clip_value is
    # compared against their true magnitude
    scaler.unscale_(optimizer)
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

    # Step through the scaler, then let it adjust its scale factor
    scaler.step(optimizer)
    scaler.update()

    return loss.item(), total_norm.item()

# Demonstrate mixed precision (only on a CUDA device)
if device.type != 'cuda':
    print("Mixed precision requires CUDA device")
else:
    print("Mixed precision training available!")

    # One scaler instance is shared by every step of the demo run
    scaler = GradScaler()

    # Fresh model and optimizer so the demo does not disturb the trained model
    model_mp = SimpleTransformer().to(device)
    optimizer_mp = configure_optimizer(model_mp)

    mp_losses = []
    for i, batch in enumerate(train_loader):
        if i >= 10:  # Ten steps are enough for a demo
            break
        loss, grad_norm = train_step_mixed_precision(model_mp, batch, optimizer_mp, scaler)
        mp_losses.append(loss)

    print(f"\nMixed precision training losses: {mp_losses[:5]}")
    print(f"Average loss: {np.mean(mp_losses):.4f}")

## 8. Gradient Accumulation

Gradient accumulation sums gradients over several small batches before each optimizer step, giving a larger effective batch size without the memory cost of a larger batch.

In [None]:
def train_with_gradient_accumulation(model, dataloader, optimizer, accumulation_steps=4,
                                     max_batches=21, clip_value=1.0):
    """Train with gradient accumulation and return the loss of each update.

    Gradients from ``accumulation_steps`` consecutive batches are summed
    (each batch loss is divided by ``accumulation_steps``) before one clipped
    optimizer step. Gradients left over from a final, incomplete window are
    rescaled to a true mean and applied too, instead of being dropped.

    Args:
        model: Model to train.
        dataloader: Iterable of batches with ``'input_ids'`` and ``'labels'``.
        optimizer: Optimizer stepped once per accumulation window.
        accumulation_steps: Number of batches per optimizer step (>= 1).
        max_batches: Stop after this many batches; ``None`` consumes the whole
            dataloader. The default of 21 keeps the notebook demo short.
        clip_value: Maximum gradient norm applied before each optimizer step.

    Returns:
        List with the mean loss of every optimizer update, in order.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be at least 1")

    model.train()
    optimizer.zero_grad()

    step_losses = []

    def _apply_update(window_loss, window_batches, window_samples):
        """Clip, step and report one accumulated window."""
        if window_batches < accumulation_steps:
            # Partial window: every batch was scaled by 1/accumulation_steps,
            # so rescale to the mean over the batches actually accumulated
            correction = accumulation_steps / window_batches
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.mul_(correction)
            window_loss *= correction

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        optimizer.zero_grad()

        step_losses.append(window_loss)
        print(f"Step {len(step_losses)}: "
              f"Loss = {window_loss:.4f} "
              f"(effective batch size = {window_samples})")

    window_loss = 0.0
    window_batches = 0
    window_samples = 0

    for i, batch in enumerate(dataloader):
        if max_batches is not None and i >= max_batches:
            break

        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs = model(input_ids)
        loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), labels.view(-1))

        # Scale so the summed gradients equal the mean over the window
        loss = loss / accumulation_steps
        loss.backward()

        window_loss += loss.item()
        window_batches += 1
        # Count real samples rather than trusting dataloader.batch_size, which
        # can be None and overstates a smaller final batch
        window_samples += input_ids.size(0)

        # Update weights once the window is full
        if window_batches == accumulation_steps:
            _apply_update(window_loss, window_batches, window_samples)
            window_loss, window_batches, window_samples = 0.0, 0, 0

    # Apply gradients from a trailing, incomplete window
    if window_batches:
        _apply_update(window_loss, window_batches, window_samples)

    return step_losses

# Run the accumulation demo on a fresh model so the trained model is untouched
print("Training with gradient accumulation:")
model_ga = SimpleTransformer()
model_ga = model_ga.to(device)
optimizer_ga = configure_optimizer(model_ga)

train_with_gradient_accumulation(
    model=model_ga,
    dataloader=train_loader,
    optimizer=optimizer_ga,
    accumulation_steps=4,
)

## 9. Monitoring Training Health

Let's implement tools to monitor and diagnose training issues.

In [None]:
def analyze_gradients(model):
    """Collect summary statistics of the current gradients, per parameter.

    Args:
        model: Module whose parameters are inspected. Parameters without a
            gradient (``grad is None``) are skipped.

    Returns:
        Dict mapping each parameter name to a dict with the gradient's
        ``'mean'``, ``'std'``, ``'max'`` (largest absolute value) and
        ``'norm'``, all as Python floats.
    """
    gradient_stats = {}

    for name, param in model.named_parameters():
        if param.grad is None:
            continue

        grad = param.grad.detach()
        # The sample std of a single value is undefined and comes back as NaN;
        # report 0.0 so one-element parameters do not poison later checks
        grad_std = grad.std().item() if grad.numel() > 1 else 0.0

        gradient_stats[name] = {
            'mean': grad.mean().item(),
            'std': grad_std,
            'max': grad.abs().max().item(),
            'norm': grad.norm().item(),
        }

    return gradient_stats

def check_gradient_flow(model):
    """Check gradient flow with one dummy forward/backward pass.

    Runs a random batch through ``model``, plots per-layer gradient norms and
    absolute gradient means for the weight parameters, and prints any
    parameter whose gradient norm looks suspiciously small or large.
    Gradients are cleared before and after the pass, so the statistics
    describe this pass only and nothing leaks into later optimizer steps.

    Args:
        model: Model to probe; called with a (2, 10) batch of random token ids.

    Returns:
        The per-parameter statistics dict produced by ``analyze_gradients``.
    """
    # Size the dummy ids from the model's embedding table when it has one;
    # 1000 is SimpleTransformer's default vocabulary size
    embedding = getattr(model, 'embedding', None)
    vocab_size = getattr(embedding, 'num_embeddings', 1000)

    dummy_input = torch.randint(0, vocab_size, (2, 10)).to(device)
    dummy_target = torch.randint(0, vocab_size, (2, 10)).to(device)

    # backward() accumulates into .grad, so drop gradients left by earlier steps
    model.zero_grad()

    output = model(dummy_input)
    loss = F.cross_entropy(output.view(-1, output.size(-1)), dummy_target.view(-1))
    loss.backward()

    # Analyze gradients, then clear them so the dummy batch leaves no trace
    grad_stats = analyze_gradients(model)
    model.zero_grad()

    # Visualize gradient norms by layer
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Only weight parameters are plotted, labelled by their owning layer
    layer_names = []
    layer_grad_norms = []
    layer_grad_means = []

    for name, stats in grad_stats.items():
        if 'weight' in name:
            layer_names.append(name.replace('.weight', ''))
            layer_grad_norms.append(stats['norm'])
            layer_grad_means.append(abs(stats['mean']))

    # Plot gradient norms
    ax1.bar(range(len(layer_grad_norms)), layer_grad_norms)
    ax1.set_xlabel('Layer')
    ax1.set_ylabel('Gradient Norm')
    ax1.set_title('Gradient Norms by Layer')
    ax1.set_xticks(range(len(layer_names)))
    ax1.set_xticklabels(layer_names, rotation=45, ha='right')

    # Plot absolute gradient means on a log scale
    ax2.bar(range(len(layer_grad_means)), layer_grad_means)
    ax2.set_xlabel('Layer')
    ax2.set_ylabel('|Gradient Mean|')
    ax2.set_title('Gradient Means by Layer (abs value)')
    ax2.set_yscale('log')
    ax2.set_xticks(range(len(layer_names)))
    ax2.set_xticklabels(layer_names, rotation=45, ha='right')

    plt.tight_layout()
    plt.show()

    # Flag every parameter (weights and biases) with an extreme gradient norm
    issues = []
    for name, stats in grad_stats.items():
        if stats['norm'] < 1e-7:
            issues.append(f"Very small gradients in {name}")
        elif stats['norm'] > 100:
            issues.append(f"Very large gradients in {name}")

    if issues:
        print("\nPotential issues detected:")
        for issue in issues:
            print(f"  - {issue}")
    else:
        print("\nGradient flow looks healthy!")

    return grad_stats

# Probe the trained model and keep the per-parameter gradient statistics
print("Analyzing gradient flow...")
grad_stats = check_gradient_flow(model=model)

## 10. Learning Rate Finder

Finding the optimal learning rate is crucial. Let's implement a learning rate finder.

In [None]:
def find_learning_rate(model, dataloader, start_lr=1e-7, end_lr=1, num_iter=100):
    """Find a good learning rate using the LR range test.

    Trains a deep copy of ``model`` (the original is left untouched) for up to
    ``num_iter`` steps while raising the learning rate exponentially from
    ``start_lr`` to ``end_lr``, records an exponentially smoothed loss, plots
    it against the learning rate and suggests the learning rate at which the
    smoothed loss falls fastest.

    Args:
        model: Model to probe; any configuration is supported because the
            probe copy is cloned from it.
        dataloader: Source of training batches; restarted when exhausted.
        start_lr: Learning rate of the first iteration.
        end_lr: Learning rate the sweep grows towards.
        num_iter: Maximum number of iterations (>= 1).

    Returns:
        Tuple ``(lrs, losses, suggested_lr)``: the learning rates tried, the
        smoothed losses recorded at them, and the suggested learning rate.

    Raises:
        ValueError: If ``num_iter`` is smaller than 1.
    """
    import copy  # function-scope import: only this helper needs it

    if num_iter < 1:
        raise ValueError("num_iter must be at least 1")

    # Clone the given model rather than rebuilding SimpleTransformer with
    # default hyperparameters, which only works for the default configuration
    model_copy = copy.deepcopy(model).to(device)
    model_copy.train()

    optimizer = torch.optim.Adam(model_copy.parameters(), lr=start_lr)

    lrs = []
    losses = []

    # Exponential learning rate schedule: the multiplier grows from 1 towards
    # end_lr / start_lr over num_iter steps
    lr_lambda = lambda x: math.exp(x * math.log(end_lr / start_lr) / num_iter)
    scheduler = LambdaLR(optimizer, lr_lambda)

    data_iter = iter(dataloader)
    smooth_loss = None

    for iteration in range(num_iter):
        # Get batch, restarting the dataloader when it runs out
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter)

        # Training step
        optimizer.zero_grad()

        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        outputs = model_copy(input_ids)
        loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), labels.view(-1))
        loss_value = loss.item()

        # Smooth the loss with an exponential moving average
        if smooth_loss is None:
            smooth_loss = loss_value
        else:
            smooth_loss = 0.98 * smooth_loss + 0.02 * loss_value

        # Record
        lrs.append(optimizer.param_groups[0]['lr'])
        losses.append(smooth_loss)

        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Stop if loss explodes
        if smooth_loss > 4 * losses[0] or math.isnan(smooth_loss):
            break

    # Find suggested LR (steepest descent of the smoothed loss).
    # np.gradient needs at least two points, and a diverged (NaN) tail must be
    # ignored: plain argmin would return the index of the NaN itself.
    min_grad_idx = 0
    if len(losses) >= 2:
        slopes = np.gradient(losses)
        if not np.all(np.isnan(slopes)):
            min_grad_idx = int(np.nanargmin(slopes))
    suggested_lr = lrs[min_grad_idx]

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Finder')
    plt.grid(True, alpha=0.3)
    plt.axvline(x=suggested_lr, color='r', linestyle='--',
                label=f'Suggested LR: {suggested_lr:.2e}')
    plt.legend()
    plt.show()

    return lrs, losses, suggested_lr

# Sweep learning rates on a copy of the trained model and report the suggestion
print("Running learning rate finder...")
lrs, losses, suggested_lr = find_learning_rate(model=model, dataloader=train_loader)
print(f"\nSuggested learning rate: {suggested_lr:.2e}")

## 11. Debugging Training Issues

Let's create tools to diagnose common training problems.

In [None]:
class TrainingDebugger:
    """Tools for debugging transformer training.

    Hooks every ``nn.Linear`` and ``nn.LayerNorm`` in the model to record
    summary statistics of its output activations and of the gradients flowing
    into its output, then turns those statistics into a list of issues.
    """

    def __init__(self, model):
        self.model = model
        # Latest statistics per hooked module, keyed by the module object
        self.activation_stats = {}
        self.gradient_stats = {}
        # Handles of the currently registered hooks
        self.hooks = []
        # Qualified names, so reports can tell same-typed layers apart
        self.module_names = {module: name for name, module in model.named_modules()}

    def _label(self, module):
        """Return a readable label such as ``'layers.0.linear1 (Linear)'``."""
        class_name = module.__class__.__name__
        name = self.module_names.get(module)
        return f"{name} ({class_name})" if name else class_name

    @staticmethod
    def _summarize(tensor):
        """Return mean / std / max-abs of ``tensor`` as Python floats."""
        tensor = tensor.detach()
        return {
            'mean': tensor.mean().item(),
            # std of a single value is undefined (NaN); report 0.0 instead
            'std': tensor.std().item() if tensor.numel() > 1 else 0.0,
            'max': tensor.abs().max().item(),
        }

    def register_hooks(self):
        """Register forward and backward hooks on Linear and LayerNorm modules.

        Hooks from a previous call are removed first, so calling this twice
        does not record everything twice.
        """
        self.remove_hooks()

        def forward_hook(module, input, output):
            if isinstance(output, torch.Tensor):
                self.activation_stats[module] = self._summarize(output)

        def backward_hook(module, grad_input, grad_output):
            if isinstance(grad_output[0], torch.Tensor):
                self.gradient_stats[module] = self._summarize(grad_output[0])

        for module in self.model.modules():
            if isinstance(module, (nn.Linear, nn.LayerNorm)):
                self.hooks.append(module.register_forward_hook(forward_hook))
                # register_backward_hook is deprecated; the "full" variant is
                # the supported replacement
                self.hooks.append(module.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Remove all hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def diagnose(self, dataloader, num_batches=5):
        """Run diagnostic forward and backward passes.

        Args:
            dataloader: Source of batches with ``'input_ids'`` and ``'labels'``.
            num_batches: Maximum number of batches to inspect.

        Returns:
            List of distinct issue descriptions in the order first seen.
        """
        self.register_hooks()

        issues = []

        # try/finally guarantees the hooks come off even if a pass raises
        try:
            for i, batch in enumerate(dataloader):
                if i >= num_batches:
                    break

                # Forward pass
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)

                outputs = self.model(input_ids)
                loss = F.cross_entropy(outputs.view(-1, outputs.size(-1)), labels.view(-1))

                # Backward pass
                loss.backward()

                # Check for issues
                for module, stats in self.activation_stats.items():
                    label = self._label(module)
                    if math.isnan(stats['mean']) or math.isinf(stats['mean']):
                        issues.append(f"NaN/Inf activations in {label}")
                    elif stats['std'] < 1e-6:
                        issues.append(f"Dead neurons in {label} (std={stats['std']:.2e})")

                for module, stats in self.gradient_stats.items():
                    label = self._label(module)
                    if math.isnan(stats['mean']) or math.isinf(stats['mean']):
                        issues.append(f"NaN/Inf gradients in {label}")
                    elif stats['max'] > 100:
                        issues.append(f"Exploding gradients in {label} (max={stats['max']:.2f})")
                    elif stats['max'] < 1e-6:
                        issues.append(f"Vanishing gradients in {label} (max={stats['max']:.2e})")

                # Clear gradients so diagnostic batches do not accumulate
                self.model.zero_grad()
        finally:
            self.remove_hooks()

        # Remove duplicates while keeping first-seen order
        return list(dict.fromkeys(issues))

    def plot_activation_distribution(self):
        """Plot activation statistics."""
        if not self.activation_stats:
            print("No activation statistics available. Run diagnose() first.")
            return

        layers = []
        means = []
        stds = []

        for module, stats in self.activation_stats.items():
            layers.append(self._label(module))
            means.append(stats['mean'])
            stds.append(stats['std'])

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        # Plot means
        ax1.bar(range(len(means)), means)
        ax1.set_xlabel('Layer')
        ax1.set_ylabel('Activation Mean')
        ax1.set_title('Activation Means by Layer')
        ax1.set_xticks(range(len(layers)))
        ax1.set_xticklabels(layers, rotation=45, ha='right')

        # Plot stds
        ax2.bar(range(len(stds)), stds)
        ax2.set_xlabel('Layer')
        ax2.set_ylabel('Activation Std')
        ax2.set_title('Activation Stds by Layer')
        ax2.set_xticks(range(len(layers)))
        ax2.set_xticklabels(layers, rotation=45, ha='right')

        plt.tight_layout()
        plt.show()

# Hook the trained model, inspect a few training batches and collect issues
print("Running training diagnostics...")
debugger = TrainingDebugger(model)
issues = debugger.diagnose(train_loader)

# Header depends on whether anything was found; the loop is a no-op when empty
print("\nIssues detected:" if issues else "\nNo issues detected!")
for issue in issues:
    print(f"  - {issue}")

# Bar charts of the activation statistics recorded during diagnose()
debugger.plot_activation_distribution()

## Summary and Best Practices

### Key Takeaways:

1. **Learning Rate Schedule**: Use warmup + cosine/linear decay
2. **Optimizer**: AdamW with proper weight decay settings
3. **Gradient Clipping**: Essential for stable training (typically 1.0)
4. **Mixed Precision**: Often up to ~2x faster on GPUs with hardware half-precision support, with minimal accuracy loss
5. **Gradient Accumulation**: Simulate larger batch sizes
6. **Monitoring**: Track loss, gradients, and learning rate

### Common Issues and Solutions:

| Issue | Symptoms | Solutions |
|-------|----------|----------|
| Loss explosion | NaN loss, huge gradients | Lower LR, gradient clipping |
| Slow convergence | Loss plateaus | Increase LR, check data |
| Overfitting | Train loss far below validation loss | Dropout, weight decay, more data |
| Unstable training | Loss spikes | LR warmup, better initialization |

### Training Recipe:

1. Start with small model/data to debug
2. Use learning rate finder
3. Monitor gradients and activations
4. Scale up gradually
5. Save checkpoints frequently