# Training Optimization Fundamentals

This notebook teaches you the essential techniques to train transformers efficiently and stably. We'll focus on the most impactful optimizations that every practitioner needs to know.

## Why Training Optimization Matters

When training neural networks, especially transformers, you'll encounter several fundamental challenges:

1. **Learning Rate Problems**: Too high causes divergence, too low is inefficient
2. **Gradient Issues**: Gradients can explode or vanish, destroying training
3. **Memory Constraints**: Large models require clever memory management
4. **Training Efficiency**: Techniques such as mixed precision can give roughly 1.5-2x speedups on GPUs with 16-bit hardware support

## What You'll Learn

1. **Learning Rate Scheduling**: Control learning rate changes over time
2. **Gradient Clipping**: Prevent gradient explosions
3. **Mixed Precision**: Train faster with 16-bit floats
4. **Memory Optimization**: Train larger models with less memory

Let's start with the foundations.

In [None]:
import sys
import os
sys.path.append('..')

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import (
    LambdaLR, CosineAnnealingLR, OneCycleLR, 
    ReduceLROnPlateau, StepLR
)
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import time
import psutil
import warnings
warnings.filterwarnings('ignore')

# Import our transformer components
from src.model.transformer import GPTModel
from src.data.tokenizer import CharacterTokenizer

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 1. Learning Rate Scheduling

### The Problem with Fixed Learning Rates

Imagine you're looking for the lowest point in a valley while blindfolded. If you take huge steps, you'll overshoot and bounce around. If you take tiny steps, you'll barely move. Learning rate scheduling solves this by:

1. **Starting small** (warmup): Prevents early instability
2. **Increasing gradually**: Allows faster learning once stable
3. **Decreasing over time**: Fine-tunes the solution

### The Science Behind It

Neural networks are sensitive to learning rates because:
- **Too high**: Parameters oscillate wildly, loss explodes
- **Too low**: Training is painfully slow, gets stuck in bad regions
- **Just right**: Smooth convergence to good solutions

### Essential Schedules

1. **Warmup + Cosine**: Linear increase, then smooth decay (the usual default for transformers)
2. **OneCycle**: Single peak, good for fast training
3. **Step Decay**: Sudden drops at intervals (simple but effective)
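
Step decay is easy to try with PyTorch's built-in `StepLR`. The sketch below is a minimal illustration: the dummy `nn.Linear` model, `step_size=100`, and `gamma=0.5` are assumptions chosen for readability, not tuned recommendations.

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Dummy model and optimizer, used only to drive the scheduler
step_model = nn.Linear(10, 1)
step_optimizer = optim.SGD(step_model.parameters(), lr=1e-3)

# Halve the learning rate every 100 steps (illustrative values)
step_scheduler = StepLR(step_optimizer, step_size=100, gamma=0.5)

step_lrs = []
for _ in range(300):
    step_lrs.append(step_optimizer.param_groups[0]['lr'])
    step_optimizer.step()   # no gradients were computed, so no parameter changes
    step_scheduler.step()

print(step_lrs[0], step_lrs[100], step_lrs[200])
```

The recorded learning rate stays at its initial value for the first 100 steps, then drops by half at step 100 and again at step 200.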

Now let's see these in action:

In [None]:
# Simple learning rate schedule implementations
import sys
sys.path.append('..')
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
import numpy as np
import matplotlib.pyplot as plt

def warmup_cosine_schedule(step, warmup_steps, total_steps, base_lr=1e-4, min_lr=1e-6):
    """The most important schedule for transformers."""
    if step < warmup_steps:
        # Linear warmup: gradually increase from 0 to base_lr
        return base_lr * (step + 1) / warmup_steps
    else:
        # Cosine decay: smooth decrease to min_lr
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return min_lr + (base_lr - min_lr) * 0.5 * (1 + np.cos(np.pi * progress))

# Visualize the schedules
total_steps = 5000
warmup_steps = 500
steps = np.arange(total_steps)

# Calculate different schedules
warmup_cosine = [warmup_cosine_schedule(s, warmup_steps, total_steps) for s in steps]

# Create a dummy model to get PyTorch scheduler curves
model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# OneCycle schedule
onecycle_scheduler = OneCycleLR(optimizer, max_lr=5e-4, total_steps=total_steps)
onecycle_lrs = []
for _ in range(total_steps):
    onecycle_lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()
    onecycle_scheduler.step()

# Reset for cosine
optimizer = optim.Adam(model.parameters(), lr=1e-4)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
cosine_lrs = []
for _ in range(total_steps):
    cosine_lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()
    cosine_scheduler.step()

# Plot comparison
plt.figure(figsize=(12, 6))
plt.plot(steps, warmup_cosine, label='Warmup + Cosine (Recommended)', linewidth=3)
plt.plot(steps, onecycle_lrs, label='OneCycle', linewidth=2)
plt.plot(steps, cosine_lrs, label='Pure Cosine', linewidth=2)

plt.axvline(x=warmup_steps, color='red', linestyle='--', alpha=0.7, label='Warmup End')
plt.title('Learning Rate Schedules Comparison')
plt.xlabel('Training Steps')
plt.ylabel('Learning Rate')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("Key Insights:")
print("• Warmup + Cosine is the usual default for transformers")
print("• Warmup prevents early training instability") 
print("• Cosine decay provides smooth convergence")
print("• OneCycle can be faster but less stable")

## 2. Gradient Clipping

### The Exploding Gradient Problem

During backpropagation, gradients can become extremely large, causing:
- **Parameter updates that are too big**: Model parameters jump wildly
- **Loss spikes**: Training loss suddenly shoots up to infinity
- **Training collapse**: Model becomes impossible to train

Think of it like driving a car: if you turn the steering wheel too hard, you'll crash.

### How Gradient Clipping Works

Gradient clipping constrains the magnitude of gradients:

1. **Calculate gradient norm**: √(sum of all squared gradients)
2. **Check if too large**: Compare to threshold (e.g., 1.0)
3. **Scale down if needed**: Multiply all gradients by (threshold / norm)

This preserves the direction but limits the magnitude.
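
The three steps can be written out by hand in a few lines. This is a sketch for understanding only: the helper name `clip_by_global_norm` and the toy `nn.Linear` layer are hypothetical, and in real training code `torch.nn.utils.clip_grad_norm_` does the same job.

```python
import torch
import torch.nn as nn

def clip_by_global_norm(params, max_norm, eps=1e-6):
    """Hand-rolled global-norm clipping; returns the norm before clipping."""
    grads = [p.grad for p in params if p.grad is not None]
    # Step 1: global L2 norm over every gradient element
    total_norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
    # Steps 2-3: rescale only when the norm exceeds the threshold
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        for g in grads:
            g.mul_(scale)
    return total_norm.item()

torch.manual_seed(0)
layer = nn.Linear(4, 1)
# Multiply the output by 100 to force a large gradient
(layer(torch.randn(8, 4)) * 100).sum().backward()

before = torch.cat([p.grad.flatten() for p in layer.parameters()]).clone()
norm_before = clip_by_global_norm(list(layer.parameters()), max_norm=1.0)
after = torch.cat([p.grad.flatten() for p in layer.parameters()])

cosine = torch.cosine_similarity(before, after, dim=0).item()
print(f"norm before: {norm_before:.2f}, norm after: {after.norm().item():.2f}")
print(f"cosine similarity between old and new gradient: {cosine:.4f}")
```

A cosine similarity of 1 between the gradient before and after clipping confirms that only the length changed, not the direction.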

### Why It's Essential for Transformers

Transformers are especially prone to gradient explosions because:
- **Deep networks**: Gradients multiply through many layers
- **Attention mechanism**: Can create very large gradients
- **Residual connections**: Pass gradients back through every layer unattenuated, so a large gradient is not damped on its way to earlier layers

Let's implement gradient monitoring and clipping:

In [None]:
from src.model.transformer import GPTModel
import warnings
warnings.filterwarnings('ignore')

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

def calculate_gradient_norm(model):
    """Calculate the L2 norm of all gradients."""
    total_norm = 0.0
    for param in model.parameters():
        if param.grad is not None:
            total_norm += param.grad.data.norm(2).item() ** 2
    return total_norm ** 0.5

def demonstrate_gradient_clipping():
    """Show the effect of gradient clipping on training stability."""
    # Create a small transformer for demonstration
    config = {
        'vocab_size': 100,
        'd_model': 64,
        'n_heads': 4,
        'n_layers': 2,
        'd_ff': 128,
        'max_seq_len': 32,
        'dropout': 0.1
    }
    
    clip_values = [None, 1.0, 0.1]  # No clipping, moderate clipping, strong clipping
    
    plt.figure(figsize=(15, 5))
    
    for idx, clip_value in enumerate(clip_values):
        model = GPTModel(config).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)  # High LR to cause instability
        
        grad_norms = []
        losses = []
        
        for step in range(100):
            # Create random batch
            x = torch.randint(0, config['vocab_size'], (4, 16), device=device)
            targets = torch.randint(0, config['vocab_size'], (4, 16), device=device)
            
            # Forward pass
            optimizer.zero_grad()
            logits = model(x)
            loss = nn.CrossEntropyLoss()(logits.reshape(-1, config['vocab_size']), targets.reshape(-1))
            
            # Backward pass
            loss.backward()
            
            # Record gradient norm before clipping
            grad_norm = calculate_gradient_norm(model)
            grad_norms.append(grad_norm)
            losses.append(loss.item())
            
            # Apply clipping if specified
            if clip_value is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
            
            optimizer.step()
        
        # Plot results
        plt.subplot(1, 3, idx + 1)
        plt.plot(grad_norms, alpha=0.8, label='Gradient Norm')
        if clip_value is not None:
            plt.axhline(y=clip_value, color='red', linestyle='--', label=f'Clip threshold: {clip_value}')
            plt.title(f'Clipping: {clip_value}')
        else:
            plt.title('No Clipping')
        
        plt.xlabel('Training Steps')
        plt.ylabel('Gradient Norm')
        plt.yscale('log')
        plt.legend()
        plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("Observations:")
    print("• Without clipping: Gradients can explode (>100)")
    print("• With clipping: Gradients stay bounded") 
    print("• Too aggressive clipping can slow learning")
    print("• Typical values: 0.5-2.0 for transformers")

demonstrate_gradient_clipping()

## 3. Mixed Precision Training

### The Memory and Speed Problem

Training large transformers faces two major bottlenecks:
1. **Memory**: Models can require 20+ GB of GPU memory
2. **Speed**: Training can take days or weeks

### What is Mixed Precision?

Mixed precision uses two number formats:
- **FP32 (32-bit floats)**: High precision for critical operations
- **FP16 (16-bit floats)**: Lower precision for most operations

This provides:
- **Up to 2x less activation memory**: Activations are stored in 16-bit (with PyTorch AMP, the weights and their gradients stay in FP32)
- **1.5-2x speed boost**: Modern GPUs have specialized 16-bit units
- **Minimal accuracy loss**: Careful handling preserves model quality

### How It Works

1. **Forward pass**: Compute in FP16 (faster, less memory)
2. **Loss scaling**: Multiply loss to prevent gradient underflow
3. **Backward pass**: Activation gradients are computed in FP16, scaled up by the same factor as the loss
4. **Parameter updates**: Unscale and update in FP32 (precision)
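
To see which tensors actually change precision, it helps to inspect dtypes around a single layer. The sketch below assumes PyTorch's `torch.autocast` context manager. FP16 autocast needs a GPU, so on a CPU-only machine it falls back to bfloat16, which is a substitution made only so the example runs anywhere. Variable names are illustrative.

```python
import torch
import torch.nn as nn

# FP16 on GPU; bfloat16 on CPU as a stand-in (assumption for portability)
if torch.cuda.is_available():
    amp_device, amp_dtype = 'cuda', torch.float16
else:
    amp_device, amp_dtype = 'cpu', torch.bfloat16

amp_layer = nn.Linear(16, 16).to(amp_device)
amp_input = torch.randn(4, 16, device=amp_device)

with torch.autocast(device_type=amp_device, dtype=amp_dtype):
    amp_out = amp_layer(amp_input)   # the matmul runs in the 16-bit dtype

# Reduce in FP32, then backpropagate
amp_out.float().sum().backward()

print("weight dtype:  ", amp_layer.weight.dtype)
print("output dtype:  ", amp_out.dtype)
print("gradient dtype:", amp_layer.weight.grad.dtype)
```

The weight and its gradient remain `torch.float32`, while the layer output is 16-bit. The savings therefore come from activations and faster matrix multiplies, not from smaller weights.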

### The Challenge: Gradient Underflow

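FP16 has a much smaller range than FP32: its largest value is 65504, and anything smaller in magnitude than about 6e-8 rounds to zero. Very small gradients can therefore underflow to zero and stall learning. Loss scaling solves this by multiplying the loss (and, through the chain rule, every gradient) by a large factor before backprop, then dividing it back out before the optimizer step.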
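
A single number is enough to show the effect. In this sketch the gradient value `1e-8` is an arbitrary illustration, and the scale factor `65536` matches the default initial scale of PyTorch's `GradScaler`.

```python
import torch

tiny_grad = torch.tensor(1e-8, dtype=torch.float32)

# Cast directly to FP16: the value is too small to represent and rounds to zero
lost = tiny_grad.to(torch.float16)

# Scale first, cast to FP16, then unscale in FP32
loss_scale = 65536.0
scaled_fp16 = (tiny_grad * loss_scale).to(torch.float16)
recovered = scaled_fp16.float() / loss_scale

print("direct cast:", lost.item())
print("scaled cast:", scaled_fp16.item())
print("recovered:  ", recovered.item())
```

The direct cast prints zero, while the scaled path recovers the original value to within FP16's rounding error.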

Let's see mixed precision in action:

In [None]:
import time

def benchmark_mixed_precision():
    """Compare FP32 vs FP16 training."""
    if not torch.cuda.is_available():
        print("CUDA not available - skipping mixed precision demo")
        return
    
    # Medium-sized model for noticeable differences
    config = {
        'vocab_size': 1000,
        'd_model': 256,
        'n_heads': 8,
        'n_layers': 4,
        'd_ff': 512,
        'max_seq_len': 128,
        'dropout': 0.1
    }
    
    batch_size = 8
    seq_len = 64
    num_steps = 30
    
    results = {}
    
    for precision in ['FP32', 'FP16']:
        print(f"\nTesting {precision}...")
        
        # Create fresh model
        model = GPTModel(config).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        
        # Setup mixed precision if needed
        scaler = torch.cuda.amp.GradScaler() if precision == 'FP16' else None
        
        # Warmup
        for _ in range(3):
            x = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
            targets = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
            
            optimizer.zero_grad()
            if precision == 'FP16':
                with torch.cuda.amp.autocast():
                    outputs = model(x)
                    loss = nn.CrossEntropyLoss()(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(x)
                loss = nn.CrossEntropyLoss()(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
                loss.backward()
                optimizer.step()
        
        # Clear memory and start timing
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        start_time = time.time()
        start_memory = torch.cuda.memory_allocated()
        
        losses = []
        
        # Actual benchmark
        for step in range(num_steps):
            x = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
            targets = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
            
            optimizer.zero_grad()
            
            if precision == 'FP16':
                with torch.cuda.amp.autocast():
                    outputs = model(x)
                    loss = nn.CrossEntropyLoss()(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(x)
                loss = nn.CrossEntropyLoss()(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
                loss.backward()
                optimizer.step()
            
            losses.append(loss.item())
        
        torch.cuda.synchronize()
        end_time = time.time()
        peak_memory = torch.cuda.max_memory_allocated()
        
        # Store results
        results[precision] = {
            'time': end_time - start_time,
            'memory': (peak_memory - start_memory) / 1e9,  # GB
            'final_loss': losses[-1]
        }
        
        print(f"  Time: {results[precision]['time']:.2f}s")
        print(f"  Memory: {results[precision]['memory']:.2f}GB")
        print(f"  Final loss: {results[precision]['final_loss']:.4f}")
        
        # Cleanup
        del model, optimizer
        torch.cuda.empty_cache()
    
    # Calculate improvements
    if 'FP32' in results and 'FP16' in results:
        speedup = results['FP32']['time'] / results['FP16']['time']
        memory_savings = (results['FP32']['memory'] - results['FP16']['memory']) / results['FP32']['memory'] * 100
        
        print(f"\n🚀 Mixed Precision Benefits:")
        print(f"  Speedup: {speedup:.1f}x faster")
        print(f"  Memory savings: {memory_savings:.0f}%")
        print(f"  Final loss: FP32 {results['FP32']['final_loss']:.4f} vs FP16 {results['FP16']['final_loss']:.4f}")

benchmark_mixed_precision()

## 4. Memory Optimization

### The Memory Crisis

Large transformer models require enormous amounts of GPU memory:
- **Model parameters**: Billions of weights in FP32
- **Activations**: Intermediate values stored for backpropagation  
- **Gradients**: Same size as parameters
- **Optimizer states**: Adam stores momentum and variance

A 1B parameter model needs ~16 GB before counting activations: 4 GB of FP32 weights, 4 GB of gradients, and 8 GB for Adam's two state tensors!
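
A back-of-the-envelope estimate of the activation-free part of that budget is sketched below. It assumes every value is stored in FP32 and that Adam keeps two state tensors per weight. The helper name `training_memory_gb` is hypothetical, and activations, temporary buffers, and allocator overhead come on top, depending on batch size and sequence length.

```python
def training_memory_gb(num_params, bytes_per_value=4):
    """Rough memory for weights, gradients, and Adam states (no activations)."""
    weights = num_params * bytes_per_value
    gradients = num_params * bytes_per_value         # one gradient per weight
    adam_states = 2 * num_params * bytes_per_value   # momentum + variance
    total = weights + gradients + adam_states
    parts = [('weights', weights), ('gradients', gradients),
             ('adam_states', adam_states), ('total', total)]
    return {name: value / 1e9 for name, value in parts}

estimate = training_memory_gb(1_000_000_000)
for name, gb in estimate.items():
    print(f"{name:>12}: {gb:.0f} GB")
```

For one billion parameters this gives 4 GB of weights, 4 GB of gradients, and 8 GB of Adam states, or 16 GB in total.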

### Gradient Accumulation: The Simple Solution

Instead of processing large batches at once, accumulate gradients over multiple smaller batches:

1. **Forward pass**: Process small batch (e.g., 2 samples)
2. **Backward pass**: Compute gradients, but don't update yet
3. **Accumulate**: Add gradients to running total
4. **Update**: After N mini-batches, average gradients and update parameters

This simulates large batch training with less memory.

### How Gradient Accumulation Works

```
Normal training (batch=8):
  forward(8 samples) → backward → update

Gradient accumulation (batch=2, steps=4):
  forward(2) → backward → accumulate
  forward(2) → backward → accumulate  
  forward(2) → backward → accumulate
  forward(2) → backward → update (with average)
```

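The effective batch size is the same, but peak activation memory is much lower, because only one mini-batch's activations are alive at a time. Weights, gradients, and optimizer states take the same space either way.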
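
The equivalence can be checked numerically on a toy problem. The sketch below assumes a mean-reduced loss, equally sized micro-batches, and a model without dropout or batch statistics. The `nn.Linear` layer and random data are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
acc_model = nn.Linear(8, 1)
data, labels = torch.randn(8, 8), torch.randn(8, 1)
mse = nn.MSELoss()   # mean-reduced, like the default CrossEntropyLoss

# Reference: one backward pass over the full batch of 8
acc_model.zero_grad()
mse(acc_model(data), labels).backward()
full_grad = acc_model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2, each loss divided by the number of micro-batches
acc_model.zero_grad()
accum_steps = 4
for micro_x, micro_y in zip(data.chunk(accum_steps), labels.chunk(accum_steps)):
    # backward() adds into .grad, so the scaled gradients sum to the average
    (mse(acc_model(micro_x), micro_y) / accum_steps).backward()
accum_grad = acc_model.weight.grad

print("max difference:", (full_grad - accum_grad).abs().max().item())
```

The two gradients agree up to floating-point rounding, which is why dividing each mini-batch loss by the number of accumulation steps matters.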

Let's see this in practice:

In [None]:
def demonstrate_gradient_accumulation():
    """Show how gradient accumulation saves memory while maintaining quality."""
    
    # Larger model to see memory differences  
    config = {
        'vocab_size': 1000,
        'd_model': 128,
        'n_heads': 8,
        'n_layers': 4,
        'd_ff': 256,
        'max_seq_len': 64,
        'dropout': 0.1
    }
    
    strategies = {
        'Large Batch': {'batch_size': 16, 'accum_steps': 1},
        'Grad Accum 4x': {'batch_size': 4, 'accum_steps': 4},
        'Grad Accum 8x': {'batch_size': 2, 'accum_steps': 8}
    }
    
    results = {}
    
    for name, params in strategies.items():
        print(f"\nTesting {name}...")
        
        model = GPTModel(config).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        
        losses = []
        peak_memory = 0
        
        # Reset the allocator's peak counter so earlier runs don't leak into this measurement
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        
        for step in range(15):  # Short run for demo
            optimizer.zero_grad()
            step_loss = 0
            
            # Gradient accumulation loop
            for accum_step in range(params['accum_steps']):
                # Create mini-batch
                x = torch.randint(0, config['vocab_size'], (params['batch_size'], 32), device=device)
                targets = torch.randint(0, config['vocab_size'], (params['batch_size'], 32), device=device)
                
                # Forward pass
                outputs = model(x)
                loss = nn.CrossEntropyLoss()(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
                
                # Scale loss for accumulation
                loss = loss / params['accum_steps']
                step_loss += loss.item()
                
                # Backward pass
                loss.backward()
                
                # Track peak memory; memory_allocated() after backward would miss
                # the activations, which are already freed by then
                if torch.cuda.is_available():
                    peak_memory = torch.cuda.max_memory_allocated() / 1e9
            
            # Update after accumulation
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            losses.append(step_loss * params['accum_steps'])  # Unscale for comparison
        
        results[name] = {
            'losses': losses,
            'peak_memory': peak_memory,
            'final_loss': losses[-1]
        }
        
        print(f"  Peak memory: {peak_memory:.2f}GB")
        print(f"  Final loss: {losses[-1]:.4f}")
    
    # Plot comparison
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    for name, data in results.items():
        plt.plot(data['losses'], label=name, linewidth=2)
    plt.title('Training Loss (Same Effective Batch Size)')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    names = list(results.keys())
    memories = [results[name]['peak_memory'] for name in names]
    colors = ['red', 'orange', 'green']
    
    bars = plt.bar(names, memories, color=colors, alpha=0.7)
    plt.title('Peak Memory Usage')
    plt.ylabel('Memory (GB)')
    plt.xticks(rotation=45)
    
    # Add value labels on bars
    for bar, memory in zip(bars, memories):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                f'{memory:.2f}GB', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    print("\n💡 Key Insights:")
    print("• Gradient accumulation maintains training quality")
    print("• Memory usage scales with mini-batch size, not effective batch size")
    print("• Essential technique for training large models")

demonstrate_gradient_accumulation()

## 5. Essential Training Recipe

### Putting It All Together

Now you know the four pillars of efficient transformer training. Here's how to combine them into a production-ready training loop:

### The Complete Recipe

1. **Learning Rate**: Warmup + cosine decay
2. **Gradient Clipping**: Clip to 1.0 
3. **Mixed Precision**: Enable for speed and memory
4. **Gradient Accumulation**: Use when memory is limited
5. **Optimizer**: AdamW with weight decay

### Why AdamW?

AdamW (Adam with decoupled Weight decay) is the gold standard because:
- **Adaptive learning rates**: Different rates for each parameter
- **Momentum**: Smooths out noisy gradients
- **Proper weight decay**: Regularizes without interfering with gradients
- **Proven track record**: Used by GPT, BERT, and most successful models
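
The "proper weight decay" point can be made concrete with a one-step experiment. This sketch sets the gradient to zero so that weight decay is the only force acting on the parameter. The learning rate and decay values are exaggerated for visibility, and `one_step` is a hypothetical helper.

```python
import torch

def one_step(optimizer_cls):
    """Take one optimizer step on w=1.0 with a zero gradient and weight decay on."""
    w = torch.nn.Parameter(torch.tensor([1.0]))
    opt = optimizer_cls([w], lr=0.1, weight_decay=0.1)
    w.grad = torch.zeros_like(w)   # the loss contributes nothing; only decay acts
    opt.step()
    return w.item()

w_adam = one_step(torch.optim.Adam)
w_adamw = one_step(torch.optim.AdamW)

print(f"Adam  (decay added to the gradient): {w_adam:.4f}")
print(f"AdamW (decoupled decay):             {w_adamw:.4f}")
```

Adam folds the decay term into the gradient, where the adaptive normalization turns it into a full learning-rate-sized step (the weight moves to about 0.9). AdamW shrinks the weight by `lr * weight_decay` directly (to 0.99), independent of the gradient statistics.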

Here's the complete training template:

In [None]:
def create_optimized_trainer(model, total_steps, warmup_steps=1000, lr=1e-4):
    """Create an optimized training setup with all best practices."""
    
    # 1. AdamW optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.95))
    
    # 2. Learning rate scheduler (warmup + cosine)
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # +1 so the very first update has a nonzero LR
        else:
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            return 0.1 + 0.9 * 0.5 * (1 + np.cos(np.pi * progress))
    
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    
    # 3. Mixed precision scaler
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
    
    return optimizer, scheduler, scaler

def optimized_training_step(model, batch, optimizer, scheduler, scaler, 
                          gradient_accumulation_steps=1, clip_value=1.0):
    """A single optimized training step with all techniques."""
    
    inputs, targets = batch
    
    # Mixed precision forward pass
    if scaler is not None:
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = nn.CrossEntropyLoss()(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
            loss = loss / gradient_accumulation_steps  # Scale for accumulation
        
        # Mixed precision backward pass
        scaler.scale(loss).backward()
    else:
        # Regular precision
        outputs = model(inputs)
        loss = nn.CrossEntropyLoss()(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
        loss = loss / gradient_accumulation_steps
        loss.backward()
    
    return loss.item() * gradient_accumulation_steps  # Return unscaled loss

def optimized_update_step(model, optimizer, scheduler, scaler, clip_value=1.0):
    """Update parameters with clipping and mixed precision handling."""
    
    if scaler is not None:
        # Mixed precision parameter update
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        scaler.step(optimizer)
        scaler.update()
    else:
        # Regular precision parameter update
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
    
    scheduler.step()
    optimizer.zero_grad()

# Demonstrate the complete training loop
def demonstrate_optimized_training():
    """Show the complete optimized training setup."""
    
    # Small model for quick demo
    config = {
        'vocab_size': 500,
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 2,
        'd_ff': 256,
        'max_seq_len': 32,
        'dropout': 0.1
    }
    
    model = GPTModel(config).to(device)
    total_steps = 200
    gradient_accumulation_steps = 4
    
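    # Create optimized setup (warmup must be shorter than total_steps,
    # otherwise the schedule never leaves the warmup phase)
    optimizer, scheduler, scaler = create_optimized_trainer(model, total_steps, warmup_steps=20)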
    
    print("🚀 Starting optimized training...")
    print(f"• Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"• Mixed precision: {'Enabled' if scaler else 'Disabled'}")
    print(f"• Gradient accumulation: {gradient_accumulation_steps}x")
    
    losses = []
    learning_rates = []
    
    for step in range(total_steps):
        step_loss = 0
        
        # Gradient accumulation loop
        for accum_step in range(gradient_accumulation_steps):
            # Create batch
            x = torch.randint(0, config['vocab_size'], (2, 16), device=device)
            targets = torch.randint(0, config['vocab_size'], (2, 16), device=device)
            batch = (x, targets)
            
            # Forward and backward
            loss = optimized_training_step(model, batch, optimizer, scheduler, scaler, 
                                         gradient_accumulation_steps)
            step_loss += loss / gradient_accumulation_steps
        
        # Parameter update
        optimized_update_step(model, optimizer, scheduler, scaler)
        
        # Record metrics
        losses.append(step_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])
        
        if (step + 1) % 50 == 0:
            print(f"Step {step + 1}: Loss = {step_loss:.4f}, LR = {learning_rates[-1]:.6f}")
    
    # Plot results
    plt.figure(figsize=(12, 5))
    
    plt.subplot(1, 2, 1)
    plt.plot(losses, linewidth=2)
    plt.title('Training Loss')
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.grid(True, alpha=0.3)
    
    plt.subplot(1, 2, 2)
    plt.plot(learning_rates, linewidth=2, color='orange')
    plt.title('Learning Rate Schedule')
    plt.xlabel('Steps')
    plt.ylabel('Learning Rate')
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print(f"\n✅ Training completed!")
    print(f"• Final loss: {losses[-1]:.4f}")
    print(f"• Loss reduction: {((losses[0] - losses[-1]) / losses[0] * 100):.1f}%")

demonstrate_optimized_training()

## Summary: Your Training Optimization Checklist

### ✅ Essential Techniques (Use Always)

1. **Learning Rate Scheduling**
   - Use warmup (1000-4000 steps) for stability
   - Apply cosine decay for smooth convergence
   - Typical values: 1e-4 base LR, 1e-6 minimum

2. **Gradient Clipping** 
   - Always clip to prevent explosions
   - Typical values: 0.5-2.0 for transformers
   - Monitor gradient norms to set threshold

3. **Mixed Precision**
   - 1.5-2x speedup with minimal quality loss
   - Essential for large models
   - Use PyTorch's `autocast()` and `GradScaler()`

4. **AdamW Optimizer**
   - Better than Adam for transformers
   - Use weight decay: 0.01-0.1
   - Betas: (0.9, 0.95) for transformers

### 💡 Memory Optimization (Use When Needed)

5. **Gradient Accumulation**
   - Simulate large batches with small memory
   - Essential for training large models
   - Accumulate 4-16 steps typically

### 🚫 Common Mistakes to Avoid

- **No warmup**: Causes early training instability
- **Learning rate too high**: Loss explodes and training fails
- **No gradient clipping**: Gradients explode, destroying training
- **Ignoring memory limits**: OOM errors halt training
- **Using Adam with weight decay instead of AdamW**: The decay gets rescaled by the adaptive step size, which weakens regularization

### 📝 Production Training Template

```python
# 1. Setup
from transformers import get_cosine_schedule_with_warmup  # Hugging Face helper

model = GPTModel(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
scaler = torch.cuda.amp.GradScaler()

# 2. Training loop
for step in range(num_steps):
    for accum_step in range(gradient_accumulation_steps):
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = loss_fn(outputs, targets) / gradient_accumulation_steps
        
        scaler.scale(loss).backward()
    
    # 3. Update with clipping
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()
    optimizer.zero_grad()
```

### 🎯 Key Results

With these optimizations, you can expect:
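- **1.5-2x faster training** (mixed precision, on GPUs with 16-bit hardware support)
- **Lower peak memory** (gradient accumulation; the saving depends on how much of the footprint is activations)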
- **More stable training** (gradient clipping + warmup)
- **Better final performance** (AdamW + proper scheduling)

You now have the tools to train transformers efficiently and effectively!

In [None]:
class LabelSmoothingCrossEntropy(nn.Module):
    """Label smoothing cross entropy loss."""
    
    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing
    
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # pred: (batch_size * seq_len, vocab_size)
        # target: (batch_size * seq_len,)
        
        vocab_size = pred.size(-1)
        log_probs = torch.log_softmax(pred, dim=-1)
        
        # Convert targets to one-hot with smoothing
        target_one_hot = torch.zeros_like(log_probs)
        target_one_hot.fill_(self.smoothing / (vocab_size - 1))
        target_one_hot.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        
        # Compute loss
        loss = -(target_one_hot * log_probs).sum(dim=-1).mean()
        return loss

class LayerwiseLROptimizer:
    """Optimizer with layer-wise learning rates."""
    
    def __init__(self, model: nn.Module, base_lr: float = 1e-4, 
                 decay_factor: float = 0.8):
        self.model = model
        self.base_lr = base_lr
        self.decay_factor = decay_factor
        
        # Create parameter groups with different learning rates
        param_groups = self._create_param_groups()
        self.optimizer = optim.AdamW(param_groups, weight_decay=0.01)
    
    def _create_param_groups(self) -> List[Dict]:
        """Create parameter groups with layer-wise learning rates."""
        param_groups = []
        
        # Embedding layers - lowest LR
        embedding_params = []
        if hasattr(self.model, 'embedding'):
            embedding_params.extend(self.model.embedding.parameters())
        if hasattr(self.model, 'pos_embedding'):
            embedding_params.extend(self.model.pos_embedding.parameters())
        
        if embedding_params:
            param_groups.append({
                'params': embedding_params,
                'lr': self.base_lr * (self.decay_factor ** 3)
            })
        
        # Transformer blocks - layer-wise decay
        if hasattr(self.model, 'transformer_blocks'):
            for i, block in enumerate(self.model.transformer_blocks):
                layer_lr = self.base_lr * (self.decay_factor ** (len(self.model.transformer_blocks) - i - 1))
                param_groups.append({
                    'params': list(block.parameters()),
                    'lr': layer_lr
                })
        
        # Output layers - highest LR
        if hasattr(self.model, 'ln_f'):
            param_groups.append({
                'params': list(self.model.ln_f.parameters()),
                'lr': self.base_lr
            })
        if hasattr(self.model, 'lm_head'):
            param_groups.append({
                'params': list(self.model.lm_head.parameters()),
                'lr': self.base_lr
            })
        
        return param_groups
    
    def step(self):
        self.optimizer.step()
    
    def zero_grad(self):
        self.optimizer.zero_grad()
    
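    def get_lr_info(self) -> Dict[str, float]:
        """Get learning rate information for each group."""
        # Rebuild the names with the same checks used in _create_param_groups,
        # so labels stay aligned with the groups that actually exist
        group_names = []
        if hasattr(self.model, 'embedding') or hasattr(self.model, 'pos_embedding'):
            group_names.append('embeddings')
        if hasattr(self.model, 'transformer_blocks'):
            group_names += [f'block_{i}' for i in range(len(self.model.transformer_blocks))]
        if hasattr(self.model, 'ln_f'):
            group_names.append('ln_f')
        if hasattr(self.model, 'lm_head'):
            group_names.append('lm_head')
        
        return {name: group['lr'] for name, group in zip(group_names, self.optimizer.param_groups)}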

def demonstrate_advanced_techniques():
    """Demonstrate advanced training techniques."""
    print("Demonstrating advanced training techniques...")
    
    # Model configuration
    config = {
        'vocab_size': 300,
        'd_model': 128,
        'n_heads': 4,
        'n_layers': 3,
        'd_ff': 256,
        'max_seq_len': 64,
        'dropout': 0.1
    }
    
    # Compare different techniques
    techniques = {
        'Baseline': {
            'loss_fn': nn.CrossEntropyLoss(),
            'optimizer_type': 'standard'
        },
        'Label Smoothing': {
            'loss_fn': LabelSmoothingCrossEntropy(smoothing=0.1),
            'optimizer_type': 'standard'
        },
        'Layer-wise LR': {
            'loss_fn': nn.CrossEntropyLoss(),
            'optimizer_type': 'layerwise'
        },
        'Both': {
            'loss_fn': LabelSmoothingCrossEntropy(smoothing=0.1),
            'optimizer_type': 'layerwise'
        }
    }
    
    results = {}
    num_steps = 150
    
    for tech_name, tech_config in techniques.items():
        print(f"\nTraining with {tech_name}...")
        
        # Create model
        model = GPTModel(config).to(device)
        
        # Create optimizer
        if tech_config['optimizer_type'] == 'layerwise':
            optimizer = LayerwiseLROptimizer(model)
            print(f"  Layer-wise learning rates: {optimizer.get_lr_info()}")
        else:
            optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
        
        loss_fn = tech_config['loss_fn']
        
        losses = []
        accuracies = []
        
        for step in range(num_steps):
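            # Create a synthetic batch. Inputs and targets are independent random
            # tokens, so the curves compare optimization behavior, not real learning.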
            x = torch.randint(0, config['vocab_size'], (8, 32), device=device)
            targets = torch.randint(0, config['vocab_size'], (8, 32), device=device)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(x)
            loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
            
            # Compute accuracy
            with torch.no_grad():
                pred_tokens = outputs.argmax(dim=-1)
                accuracy = (pred_tokens == targets).float().mean().item()
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            losses.append(loss.item())
            accuracies.append(accuracy)
        
        results[tech_name] = {
            'losses': losses,
            'accuracies': accuracies
        }
        
        print(f"  Final loss: {losses[-1]:.4f}")
        print(f"  Final accuracy: {accuracies[-1]:.4f}")
    
    # Plot results
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss curves
    for tech_name, data in results.items():
        axes[0, 0].plot(data['losses'], label=tech_name, linewidth=2)
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].set_xlabel('Steps')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].set_yscale('log')
    
    # Accuracy curves
    for tech_name, data in results.items():
        axes[0, 1].plot(data['accuracies'], label=tech_name, linewidth=2)
    axes[0, 1].set_title('Training Accuracy')
    axes[0, 1].set_xlabel('Steps')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # Smoothed loss (rolling average)
    window = 20
    for tech_name, data in results.items():
        smoothed = np.convolve(data['losses'], np.ones(window)/window, mode='valid')
        axes[1, 0].plot(range(window-1, len(data['losses'])), smoothed, 
                       label=tech_name, linewidth=2)
    axes[1, 0].set_title(f'Smoothed Loss (window={window})')
    axes[1, 0].set_xlabel('Steps')
    axes[1, 0].set_ylabel('Loss')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # Final comparison
    final_losses = [results[tech]['losses'][-1] for tech in techniques.keys()]
    final_accuracies = [results[tech]['accuracies'][-1] for tech in techniques.keys()]
    
    x = np.arange(len(techniques))
    
    ax2 = axes[1, 1]
    color = 'tab:red'
    ax2.set_xlabel('Technique')
    ax2.set_ylabel('Final Loss', color=color)
    bars1 = ax2.bar(x - 0.2, final_losses, 0.4, color=color, alpha=0.7)
    ax2.tick_params(axis='y', labelcolor=color)
    ax2.set_yscale('log')
    
    ax3 = ax2.twinx()
    color = 'tab:blue'
    ax3.set_ylabel('Final Accuracy', color=color)
    bars2 = ax3.bar(x + 0.2, final_accuracies, 0.4, color=color, alpha=0.7)
    ax3.tick_params(axis='y', labelcolor=color)
    
    ax2.set_xticks(x)
    ax2.set_xticklabels(list(techniques.keys()), rotation=45)
    ax2.set_title('Final Performance Comparison')
    
    plt.tight_layout()
    plt.show()
    
    # Print summary
    print("\nTechnique Performance Summary:")
    print(f"{'Technique':<15} {'Final Loss':<12} {'Final Acc':<12} {'Best Loss':<12}")
    print("-" * 55)
    
    for tech_name, data in results.items():
        final_loss = data['losses'][-1]
        final_acc = data['accuracies'][-1]
        best_loss = min(data['losses'])
        
        print(f"{tech_name:<15} {final_loss:<12.4f} {final_acc:<12.4f} {best_loss:<12.4f}")

# Demonstrate advanced techniques
demonstrate_advanced_techniques()
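
A note on reading these numbers: the demo trains on random tokens, so there is nothing to learn, and the plain cross-entropy should settle near the chance level of `ln(vocab_size)` with accuracy around `1 / vocab_size`. The short sketch below computes that baseline for the `vocab_size` of 300 used above. It is a sanity check under that assumption, not part of the original experiment.

```python
import math
import torch
import torch.nn.functional as F

vocab_size = 300  # matches config['vocab_size'] in the demo above

# Chance-level cross-entropy for uniformly random targets is ln(V)
chance_loss = math.log(vocab_size)
chance_accuracy = 1.0 / vocab_size

# All-zero logits are a uniform prediction, which yields ln(V) for any targets
logits = torch.zeros(8 * 32, vocab_size)
targets = torch.randint(0, vocab_size, (8 * 32,))
uniform_loss = F.cross_entropy(logits, targets).item()

print(f"chance loss: {chance_loss:.4f}, uniform-logit loss: {uniform_loss:.4f}")
print(f"chance accuracy: {chance_accuracy:.4f}")
```

A final loss far above this value suggests instability. Label-smoothed losses are measured against a different target distribution, so compare them with the unsmoothed runs only loosely.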

## 7. Key Takeaways and Best Practices

### Training Optimization Checklist:

#### Learning Rate Scheduling:
- ✅ Use warmup (1000-4000 steps) for transformer training
- ✅ Consider cosine annealing for smooth decay
- ✅ OneCycle can be effective for faster convergence
- ✅ Monitor how the loss responds as the learning rate changes
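
As a compact reminder of the first two items, here is a minimal sketch of linear warmup followed by cosine decay, built on `LambdaLR`. The helper name `warmup_cosine_factor` and the step counts (1000 warmup, 10000 total) are illustrative assumptions, and the tiny `nn.Linear` model only stands in for a transformer.

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine_factor(step, warmup_steps, total_steps):
    """LR multiplier: linear warmup from 0 to 1, then cosine decay toward 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

model = torch.nn.Linear(4, 4)  # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = LambdaLR(optimizer, lambda s: warmup_cosine_factor(s, 1000, 10000))

lrs = []
for _ in range(10000):
    lrs.append(optimizer.param_groups[0]["lr"])  # record the LR used at this step
    optimizer.step()    # no gradients here; we only trace the schedule
    scheduler.step()

print(f"start={lrs[0]:.2e}, peak={max(lrs):.2e}, end={lrs[-1]:.2e}")
```

Plotting `lrs` next to the loss curve is a simple way to act on the last checklist item.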

#### Gradient Management:
- ✅ Use gradient clipping by default (a max norm of 0.5-2.0 is typical for transformers)
- ✅ Monitor gradient norms regularly
- ✅ Watch for exploding/vanishing gradients
- ✅ Consider adaptive clipping strategies
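
Clipping and monitoring can share one call: `torch.nn.utils.clip_grad_norm_` returns the total gradient norm measured before clipping, so you can log it for free. The sketch below uses a small stand-in MLP and deliberately large inputs; both are assumptions chosen only to provoke large gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
max_norm = 1.0

x = torch.randn(8, 16) * 100.0  # deliberately large inputs
y = torch.randint(0, 4, (8,))

optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()

# The return value is the total norm *before* clipping: log this for monitoring
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
# Recompute the norm to confirm the gradients were rescaled in place
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters())).item()
optimizer.step()

print(f"grad norm before clipping: {norm_before:.3f}, after: {norm_after:.3f}")
```

If `norm_before` sits well above `max_norm` on most steps, clipping is doing a lot of work and the learning rate may be too high.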

#### Memory Optimization:
- ✅ Use mixed precision training when available
- ✅ Implement gradient accumulation for large effective batch sizes
- ✅ Consider gradient checkpointing for memory-limited scenarios
- ✅ Monitor memory usage and optimize accordingly
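
Gradient checkpointing trades compute for memory: activations inside a checkpointed block are recomputed during the backward pass instead of being stored. The sketch below wraps each layer with `torch.utils.checkpoint.checkpoint` and checks that the gradients match the non-checkpointed run. `CheckpointedStack` is a hypothetical stand-in, not the notebook's `GPTModel`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
    """Hypothetical stack of layers with optional activation checkpointing."""

    def __init__(self, d_model=32, n_layers=3, use_checkpoint=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Sequential(nn.Linear(d_model, d_model), nn.GELU()) for _ in range(n_layers)]
        )
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for layer in self.layers:
            if self.use_checkpoint and self.training:
                # Activations inside `layer` are recomputed during backward
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return x

torch.manual_seed(0)
model = CheckpointedStack()
model.train()
x = torch.randn(4, 32)

def grads_for(use_checkpoint):
    """Run one forward/backward pass and return copies of all gradients."""
    model.use_checkpoint = use_checkpoint
    model.zero_grad()
    model(x).sum().backward()
    return [p.grad.clone() for p in model.parameters()]

grads_plain = grads_for(False)
grads_ckpt = grads_for(True)
max_diff = max((a - b).abs().max().item() for a, b in zip(grads_plain, grads_ckpt))
print(f"max gradient difference: {max_diff:.2e}")
```

The memory savings only become visible with deep models and long sequences; on a toy model like this one the point is just that the results are unchanged.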

#### Optimizer Selection:
- ✅ AdamW is generally the best starting point
- ✅ Consider Lion for lower optimizer memory (it keeps one moment buffer instead of Adam's two)
- ✅ Experiment with layer-wise learning rates
- ✅ Use appropriate weight decay (0.01-0.1)
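
To make the layer-wise learning rate rule concrete, the sketch below applies the same formula used in `LayerwiseLROptimizer` above, `base_lr * decay_factor ** (n_layers - i - 1)`, to three stand-in blocks. The values `base_lr=1e-4` and `decay_factor=0.8` are illustrative assumptions, and `layerwise_param_groups` is a hypothetical helper.

```python
import torch
import torch.nn as nn

def layerwise_param_groups(blocks, base_lr, decay_factor):
    """One param group per block; earlier blocks get smaller learning rates."""
    n = len(blocks)
    return [
        {"params": list(block.parameters()),
         "lr": base_lr * decay_factor ** (n - i - 1)}
        for i, block in enumerate(blocks)
    ]

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])  # stand-in blocks
groups = layerwise_param_groups(blocks, base_lr=1e-4, decay_factor=0.8)
optimizer = torch.optim.AdamW(groups, weight_decay=0.01)

lrs = [g["lr"] for g in optimizer.param_groups]
for i, lr in enumerate(lrs):
    print(f"block_{i}: lr={lr:.2e}")
```

The last block trains at the full base rate and each earlier block is scaled down by another factor of 0.8, while the weight decay of 0.01 applies to every group.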

#### Advanced Techniques:
- ✅ Label smoothing can improve generalization
- ✅ Consider curriculum learning for complex tasks
- ✅ Consider scheduling dropout over the course of training
- ✅ Checkpoint regularly and use early stopping
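
Label smoothing is easy to check by hand. In the uniform-smoothing formulation used by PyTorch's built-in `label_smoothing` argument, the loss is a weighted mix of the usual negative log-likelihood and the loss against a uniform target. The sketch below verifies that identity on random logits. The notebook's own `LabelSmoothingCrossEntropy` may use a slightly different variant, so treat this as an illustration of the idea rather than a description of that class.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, smoothing = 300, 0.1
logits = torch.randn(16, vocab_size)
targets = torch.randint(0, vocab_size, (16,))

log_probs = F.log_softmax(logits, dim=-1)
nll = -log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # loss on the true class
uniform = -log_probs.mean(dim=-1)                            # loss against a uniform target
manual = ((1 - smoothing) * nll + smoothing * uniform).mean()

# Built-in equivalent
builtin = F.cross_entropy(logits, targets, label_smoothing=smoothing)

print(f"manual={manual.item():.4f}, builtin={builtin.item():.4f}")
```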

### Common Pitfalls to Avoid:
1. **No warmup**: Can cause early training instability
2. **Learning rate too high**: Causes loss spikes and divergence
3. **No gradient clipping**: A single exploding-gradient step can derail training
4. **Ignoring memory usage**: OOM errors halt training
5. **Not monitoring training**: You miss important signals about model health
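
For the last pitfall, even a very small monitor catches the most common failure signals. The sketch below flags non-finite losses and sudden spikes relative to a running mean. `LossMonitor`, its window size, and the spike threshold are hypothetical choices for illustration; tune them for your own runs.

```python
import math
from collections import deque

class LossMonitor:
    """Hypothetical helper: flags non-finite losses and sudden loss spikes."""

    def __init__(self, window=50, spike_factor=2.0):
        self.history = deque(maxlen=window)
        self.spike_factor = spike_factor

    def update(self, loss_value):
        """Return a warning string, or None if the step looks healthy."""
        if not math.isfinite(loss_value):
            return "non-finite loss"
        warning = None
        # Only compare against the running mean once the window is full
        if len(self.history) == self.history.maxlen:
            running_mean = sum(self.history) / len(self.history)
            if loss_value > self.spike_factor * running_mean:
                warning = f"loss spike: {loss_value:.3f} vs running mean {running_mean:.3f}"
        self.history.append(loss_value)
        return warning

monitor = LossMonitor(window=5, spike_factor=2.0)
sample_losses = [5.0, 4.8, 4.7, 4.6, 4.5, 4.4, 12.0, float("nan")]
warnings_seen = [monitor.update(v) for v in sample_losses]
for value, warning in zip(sample_losses, warnings_seen):
    if warning is not None:
        print(f"loss={value}: {warning}")
```

In a real training loop you would call `monitor.update(loss.item())` once per step and log or stop on any warning.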

### Recommended Training Pipeline:
```python
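# Imports for the pieces used below
import torch
from torch.optim import AdamW
from torch.cuda.amp import GradScaler, autocast

# 1. Setup with proper initialization
model = GPTModel(config).to(device)
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Assumed to be a warmup + cosine helper with this signature; note that the
# Hugging Face function of the same name takes num_warmup_steps instead.
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=1000, num_training_steps=10000)

# 2. Enable optimizations
scaler = GradScaler()  # Mixed precision (loss scaling for float16)
gradient_accumulation_steps = 4

# 3. Training loop with monitoring
# `inputs`, `targets`, `loss_fn`, and `num_steps` come from your own data pipeline
for step in range(num_steps):
    # Forward pass with mixed precision
    with autocast():
        outputs = model(inputs)
        # Flatten (batch, seq, vocab) logits to (batch * seq, vocab) for the loss
        loss = loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
        loss = loss / gradient_accumulation_steps
    
    # Backward pass (gradients accumulate across micro-batches)
    scaler.scale(loss).backward()
    
    if (step + 1) % gradient_accumulation_steps == 0:
        # Unscale first so clipping sees the true gradient magnitudes
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()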
```
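
One detail of this pipeline is easy to get wrong: the scheduler advances once per optimizer update, not once per micro-batch. The arithmetic below spells out what that means for the settings above. The micro-batch size of 8 is a hypothetical value borrowed from the earlier demos.

```python
gradient_accumulation_steps = 4   # from the pipeline above
micro_batch_size = 8              # hypothetical per-step batch size
num_training_steps = 10_000       # scheduler length, counted in optimizer updates
warmup_steps = 1_000

# Each optimizer update sees this many sequences in total
effective_batch_size = micro_batch_size * gradient_accumulation_steps

# Loop iterations needed for the scheduler to run its full course
micro_batches_needed = num_training_steps * gradient_accumulation_steps
warmup_micro_batches = warmup_steps * gradient_accumulation_steps

print(f"effective batch size: {effective_batch_size}")
print(f"loop iterations for full schedule: {micro_batches_needed}")
print(f"loop iterations spent in warmup: {warmup_micro_batches}")
```

With these numbers, `num_steps` in the loop should be 40,000 for the cosine schedule to reach its end.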

This completes our deep dive into training optimization! These techniques will help you train transformers more efficiently and effectively.