# Training Optimization Deep Dive

This notebook teaches essential techniques for training transformers efficiently: learning rate scheduling, gradient clipping, mixed precision training, and memory optimization.

In [ ]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, OneCycleLR
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
import time
import warnings
# NOTE(review): this silences every warning for the rest of the session,
# not just noisy plotting ones.
warnings.filterwarnings('ignore')

# Plot styling applied to the figures drawn below.
plt.style.use('default')
sns.set_palette("husl")

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [ ]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input tensor.

    Args:
        d_model: Feature width of the input and of the output.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional tensor broadcastable to the score tensor of shape
                ``(batch, n_heads, seq_len, seq_len)``. Positions where the
                mask equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(projected):
            # (batch, seq, d_model) -> (batch, heads, seq, d_k)
            return projected.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        Q = split_heads(self.w_q(x))
        K = split_heads(self.w_k(x))
        V = split_heads(self.w_v(x))

        # Scale with a plain Python float: no tensor is allocated per call and
        # the scores keep their own dtype and device.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            # Use the most negative value representable in the score dtype.
            # A fixed -1e9 does not fit in float16, which is what the scores
            # are when this runs under autocast.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # (batch, heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.w_o(attn_output)

class TransformerBlock(nn.Module):
    """One transformer layer: self-attention followed by a GELU feed-forward net.

    Each sub-layer is wrapped in a residual connection, and LayerNorm is
    applied after the residual sum.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter
        # initialisation consumes the random generator in the same sequence.
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.feed_forward = nn.Sequential(expand, nn.GELU(), contract, nn.Dropout(dropout))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply attention and feed-forward sub-layers to ``x``; ``mask`` is passed to attention."""
        attended = self.dropout(self.attention(x, mask))
        hidden = self.norm1(x + attended)
        return self.norm2(hidden + self.feed_forward(hidden))

class GPTModel(nn.Module):
    """Small transformer language model built from ``TransformerBlock`` layers.

    Args:
        config: Mapping with the keys ``vocab_size``, ``d_model``,
            ``max_seq_len``, ``n_heads``, ``d_ff``, ``dropout`` and
            ``n_layers``.
    """

    def __init__(self, config):
        super().__init__()
        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.pos_embedding = nn.Embedding(config['max_seq_len'], config['d_model'])

        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(config['d_model'], config['n_heads'], config['d_ff'], config['dropout'])
            for _ in range(config['n_layers'])
        ])

        self.ln_f = nn.LayerNorm(config['d_model'])
        self.lm_head = nn.Linear(config['d_model'], config['vocab_size'])
        self.dropout = nn.Dropout(config['dropout'])

    def forward(self, x, mask=None):
        """Compute next-token logits for a batch of token ids.

        Args:
            x: Token-id tensor whose second dimension is the sequence length.
            mask: Optional attention mask handed unchanged to every block.
                When omitted no mask is applied, so every position attends to
                every other position.

        Returns:
            Logits whose last dimension has size ``vocab_size``.

        Raises:
            IndexError: If the sequence is longer than the positional
                embedding table (``max_seq_len``).
        """
        seq_len = x.size(1)
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            # Fail early with a readable message instead of an opaque
            # out-of-range lookup inside the embedding layer.
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)

        x = self.embedding(x) + self.pos_embedding(pos_ids)
        x = self.dropout(x)

        for block in self.transformer_blocks:
            x = block(x, mask)

        x = self.ln_f(x)
        return self.lm_head(x)

## Learning Rate Scheduling

Learning rate scheduling controls how the learning rate changes during training. Warmup prevents early instability, while cosine decay provides smooth convergence. This is essential for transformer training.

In [ ]:
def warmup_cosine_schedule(step, warmup_steps, total_steps, base_lr=1e-4, min_lr=1e-6):
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    Args:
        step: Zero-based training step.
        warmup_steps: Number of steps over which the rate rises linearly to
            ``base_lr``.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        base_lr: Peak learning rate, reached at the end of warmup.
        min_lr: Floor the rate decays to.

    Returns:
        The learning rate for ``step``. Steps at or beyond ``total_steps``
        stay at ``min_lr``.
    """
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps

    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # No decay phase exists; avoid dividing by zero.
        return min_lr

    # Clamp so the cosine does not wrap around and raise the rate again once
    # training runs past total_steps.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + np.cos(np.pi * progress))

total_steps = 5000
warmup_steps = 500
steps = np.arange(total_steps)


def _trace_lr(opt, sched, num_steps):
    """Step ``opt`` and ``sched`` ``num_steps`` times, returning the LR seen before each step."""
    trace = []
    for _ in range(num_steps):
        trace.append(opt.param_groups[0]['lr'])
        opt.step()
        sched.step()
    return trace


# Hand-written warmup + cosine curve, evaluated at every step.
warmup_cosine = [warmup_cosine_schedule(s, warmup_steps, total_steps) for s in steps]

# A throwaway model gives the optimizers some parameters to manage.
model = nn.Linear(10, 1)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
onecycle_scheduler = OneCycleLR(optimizer, max_lr=5e-4, total_steps=total_steps)
onecycle_lrs = _trace_lr(optimizer, onecycle_scheduler, total_steps)

# Fresh optimizer so the cosine schedule starts from the base rate again.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
cosine_lrs = _trace_lr(optimizer, cosine_scheduler, total_steps)

plt.figure(figsize=(12, 6))
for curve, curve_label, width in (
    (warmup_cosine, 'Warmup + Cosine (Recommended)', 3),
    (onecycle_lrs, 'OneCycle', 2),
    (cosine_lrs, 'Pure Cosine', 2),
):
    plt.plot(steps, curve, label=curve_label, linewidth=width)
plt.axvline(x=warmup_steps, color='red', linestyle='--', alpha=0.7, label='Warmup End')
plt.title('Learning Rate Schedules Comparison')
plt.xlabel('Training Steps')
plt.ylabel('Learning Rate')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("Warmup + Cosine works best for transformers")

## Gradient Clipping

Gradient clipping prevents gradient explosions by rescaling the gradients whenever their overall norm exceeds a threshold. This is crucial for stable transformer training; typical thresholds lie between 0.5 and 2.0.

In [ ]:
def calculate_gradient_norm(model):
    """Return the L2 norm of all of ``model``'s parameter gradients taken together.

    Parameters that have no gradient are skipped; the result is ``0.0`` when
    no parameter has one.
    """
    squared_norms = [
        p.grad.data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    ]
    return sum(squared_norms, 0.0) ** 0.5

config = {
    'vocab_size': 100,
    'd_model': 64,
    'n_heads': 4,
    'n_layers': 2,
    'd_ff': 128,
    'max_seq_len': 32,
    'dropout': 0.1
}

# One panel is drawn per entry; None means "train without clipping".
clip_values = [None, 1.0, 0.1]

# The loss module holds no state, so build it once rather than every step.
criterion = nn.CrossEntropyLoss()

plt.figure(figsize=(15, 5))

for idx, clip_value in enumerate(clip_values):
    model = GPTModel(config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    grad_norms = []

    for step in range(100):
        x = torch.randint(0, config['vocab_size'], (4, 16), device=device)
        targets = torch.randint(0, config['vocab_size'], (4, 16), device=device)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits.reshape(-1, config['vocab_size']), targets.reshape(-1))
        loss.backward()

        # Record the norm before clipping so the plot shows the raw gradients.
        grad_norms.append(calculate_gradient_norm(model))

        if clip_value is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        optimizer.step()

    # Column count follows clip_values so adding a setting needs no other edit.
    plt.subplot(1, len(clip_values), idx + 1)
    plt.plot(grad_norms, alpha=0.8, label='Gradient Norm')
    if clip_value is not None:
        plt.axhline(y=clip_value, color='red', linestyle='--', label=f'Clip: {clip_value}')
        plt.title(f'Clipping: {clip_value}')
    else:
        plt.title('No Clipping')

    plt.xlabel('Training Steps')
    plt.ylabel('Gradient Norm')
    plt.yscale('log')
    plt.legend()
    plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Gradient clipping prevents explosions and stabilizes training")

## Mixed Precision Training

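Mixed precision runs most of the forward and backward pass in FP16 while keeping the parameters and their updates in FP32; a gradient scaler multiplies the loss before backpropagation so that small FP16 gradients do not underflow to zero. On GPUs with fast FP16 hardware this can give up to ~2x speedup and memory savings with minimal quality loss. Essential for modern GPU training.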

In [ ]:
def benchmark_mixed_precision(num_warmup_steps=3):
    """Time FP32 against FP16 autocast training of a small ``GPTModel`` on CUDA.

    Prints time, memory and final loss for each precision, then the relative
    speedup and memory savings. Prints a notice and does nothing when CUDA is
    unavailable.

    Args:
        num_warmup_steps: Untimed training steps run before measurement for
            each precision. Without them the first precision measured also
            pays one-off CUDA start-up costs, which biases the comparison.

    Returns:
        None.
    """
    if not torch.cuda.is_available():
        print("CUDA not available - skipping mixed precision demo")
        return

    config = {
        'vocab_size': 1000,
        'd_model': 256,
        'n_heads': 8,
        'n_layers': 4,
        'd_ff': 512,
        'max_seq_len': 128,
        'dropout': 0.1
    }

    batch_size, seq_len, num_steps = 8, 64, 30
    # The loss module holds no state, so build it once for all steps.
    criterion = nn.CrossEntropyLoss()
    results = {}

    def run_step(model, optimizer, scaler):
        """Run one optimisation step on random data and return the loss tensor."""
        x = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)
        targets = torch.randint(0, config['vocab_size'], (batch_size, seq_len), device=device)

        optimizer.zero_grad()

        if scaler is not None:
            with torch.cuda.amp.autocast():
                outputs = model(x)
                loss = criterion(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(x)
            loss = criterion(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
            loss.backward()
            optimizer.step()
        return loss

    for precision in ['FP32', 'FP16']:
        print(f"\nTesting {precision}...")

        model = GPTModel(config).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-4)
        scaler = torch.cuda.amp.GradScaler() if precision == 'FP16' else None

        torch.cuda.empty_cache()
        # Baseline taken before any training so the reported memory still
        # covers gradients, optimizer state and activations.
        start_memory = torch.cuda.memory_allocated()

        for _ in range(num_warmup_steps):
            run_step(model, optimizer, scaler)

        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()

        losses = []
        for step in range(num_steps):
            losses.append(run_step(model, optimizer, scaler).item())

        torch.cuda.synchronize()
        end_time = time.perf_counter()
        peak_memory = torch.cuda.max_memory_allocated()

        results[precision] = {
            'time': end_time - start_time,
            'memory': (peak_memory - start_memory) / 1e9,
            'final_loss': losses[-1]
        }

        print(f"  Time: {results[precision]['time']:.2f}s")
        print(f"  Memory: {results[precision]['memory']:.2f}GB")
        print(f"  Final loss: {results[precision]['final_loss']:.4f}")

        del model, optimizer
        torch.cuda.empty_cache()

    if 'FP32' in results and 'FP16' in results:
        fp32, fp16 = results['FP32'], results['FP16']
        # Guard the ratios so a zero measurement cannot raise.
        speedup = fp32['time'] / fp16['time'] if fp16['time'] > 0 else float('inf')
        if fp32['memory'] > 0:
            memory_savings = (fp32['memory'] - fp16['memory']) / fp32['memory'] * 100
        else:
            memory_savings = 0.0

        print(f"\nMixed Precision Benefits:")
        print(f"  Speedup: {speedup:.1f}x faster")
        print(f"  Memory savings: {memory_savings:.0f}%")

benchmark_mixed_precision()

## Gradient Accumulation

Gradient accumulation simulates large batch training with smaller memory footprint. Instead of processing large batches, we accumulate gradients over multiple mini-batches before updating parameters.

In [ ]:
config = {
    'vocab_size': 1000,
    'd_model': 128,
    'n_heads': 8,
    'n_layers': 4,
    'd_ff': 256,
    'max_seq_len': 64,
    'dropout': 0.1
}

# Every strategy sees batch_size * accum_steps = 16 sequences per update.
strategies = {
    'Large Batch': {'batch_size': 16, 'accum_steps': 1},
    'Grad Accum 4x': {'batch_size': 4, 'accum_steps': 4},
    'Grad Accum 8x': {'batch_size': 2, 'accum_steps': 8}
}

results = {}

# The loss module holds no state, so build it once for all strategies.
criterion = nn.CrossEntropyLoss()

for name, params in strategies.items():
    print(f"\nTesting {name}...")

    model = GPTModel(config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    if torch.cuda.is_available():
        # Start a fresh high-water mark so each strategy reports its own peak.
        torch.cuda.reset_peak_memory_stats()

    losses = []

    for step in range(15):
        optimizer.zero_grad()
        step_loss = 0.0

        for accum_step in range(params['accum_steps']):
            x = torch.randint(0, config['vocab_size'], (params['batch_size'], 32), device=device)
            targets = torch.randint(0, config['vocab_size'], (params['batch_size'], 32), device=device)

            outputs = model(x)
            loss = criterion(outputs.reshape(-1, config['vocab_size']), targets.reshape(-1))
            # Divide so the accumulated gradient is the mean over micro-batches.
            loss = loss / params['accum_steps']
            step_loss += loss.item()
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # step_loss is already the mean micro-batch loss. Multiplying it by
        # accum_steps again would inflate the accumulation curves 4x and 8x.
        losses.append(step_loss)

    # Read the allocator's high-water mark. Sampling memory_allocated() after
    # backward misses the activation peak, which is what differs between
    # strategies.
    peak_memory = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0

    results[name] = {
        'losses': losses,
        'peak_memory': peak_memory,
        'final_loss': losses[-1]
    }

    print(f"  Peak memory: {peak_memory:.2f}GB")
    print(f"  Final loss: {losses[-1]:.4f}")

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
for name, data in results.items():
    plt.plot(data['losses'], label=name, linewidth=2)
plt.title('Training Loss (Same Effective Batch Size)')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
names = list(results.keys())
memories = [results[name]['peak_memory'] for name in names]
colors = ['red', 'orange', 'green']

bars = plt.bar(names, memories, color=colors, alpha=0.7)
plt.title('Peak Memory Usage')
plt.ylabel('Memory (GB)')
plt.xticks(rotation=45)

for bar, memory in zip(bars, memories):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
            f'{memory:.2f}GB', ha='center', va='bottom')

plt.tight_layout()
plt.show()

print("\nGradient accumulation maintains quality with less memory")

## Complete Optimized Training Setup

This demonstrates how to combine all optimization techniques into a production-ready training pipeline with AdamW optimizer, warmup+cosine scheduling, mixed precision, and gradient clipping.

In [ ]:
def create_optimized_trainer(model, total_steps, warmup_steps=1000, lr=1e-4):
    """Build an AdamW optimizer, a warmup + cosine LR scheduler and a grad scaler.

    The learning rate rises linearly to ``lr`` over ``warmup_steps`` scheduler
    steps, then follows a cosine curve down to ``0.1 * lr`` at ``total_steps``
    and stays there.

    Args:
        model: Module whose parameters are optimized.
        total_steps: Scheduler step at which the decay reaches its floor.
        warmup_steps: Length of the linear warmup, in scheduler steps.
        lr: Peak learning rate.

    Returns:
        Tuple ``(optimizer, scheduler, scaler)``. ``scaler`` is a
        ``GradScaler`` when CUDA is available and ``None`` otherwise.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.95))

    # At least one decay step, so total_steps == warmup_steps cannot divide by zero.
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step < warmup_steps:
            # step + 1 so the first update does not run with a zero rate.
            return (step + 1) / warmup_steps
        # Clamp so the cosine does not wrap around and raise the rate again
        # once training runs past total_steps.
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        # float() keeps NumPy scalars out of the optimizer's param groups.
        return 0.1 + 0.9 * 0.5 * (1 + float(np.cos(np.pi * progress)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

    return optimizer, scheduler, scaler

def optimized_training_step(model, batch, scaler, gradient_accumulation_steps=1):
    """Run forward and backward on one micro-batch, accumulating gradients.

    Args:
        model: Module called on the inputs to produce logits.
        batch: Pair ``(inputs, targets)``.
        scaler: Gradient scaler; when not ``None`` the forward pass runs under
            CUDA autocast and the loss is scaled before ``backward``.
        gradient_accumulation_steps: Divisor applied to the loss before
            ``backward``.

    Returns:
        The undivided loss for this micro-batch as a Python float.
    """
    inputs, targets = batch

    def _micro_batch_loss():
        # Cross-entropy over flattened positions, divided for accumulation.
        logits = model(inputs)
        flat_logits = logits.reshape(-1, logits.size(-1))
        cross_entropy = nn.CrossEntropyLoss()(flat_logits, targets.reshape(-1))
        return cross_entropy / gradient_accumulation_steps

    if scaler is None:
        loss = _micro_batch_loss()
        loss.backward()
    else:
        with torch.cuda.amp.autocast():
            loss = _micro_batch_loss()
        scaler.scale(loss).backward()

    # Undo the division so callers see the real loss value.
    return loss.item() * gradient_accumulation_steps

def optimized_update_step(model, optimizer, scheduler, scaler, clip_value=1.0):
    """Clip gradients, apply the optimizer update, advance the LR schedule, clear gradients.

    Args:
        model: Module whose gradients are clipped.
        optimizer: Optimizer to step.
        scheduler: LR scheduler, advanced once per applied optimizer step.
        scaler: Gradient scaler, or ``None`` for full precision.
        clip_value: Maximum gradient norm passed to ``clip_grad_norm_``.
    """
    if scaler is not None:
        # Gradients must be unscaled before clipping, otherwise the threshold
        # would be compared against scaled values.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)

        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        # The scaler skips optimizer.step() when it finds inf/NaN gradients
        # and then lowers its scale; a lowered scale marks a skipped step.
        optimizer_stepped = scaler.get_scale() >= scale_before
    else:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        optimizer.step()
        optimizer_stepped = True

    # Only advance the schedule for updates that were actually applied.
    if optimizer_stepped:
        scheduler.step()
    optimizer.zero_grad()

config = {
    'vocab_size': 500,
    'd_model': 128,
    'n_heads': 4,
    'n_layers': 2,
    'd_ff': 256,
    'max_seq_len': 32,
    'dropout': 0.1
}

model = GPTModel(config).to(device)
total_steps = 200
gradient_accumulation_steps = 4

# Warm up for the first 10% of the run. The trainer's default of 1000 warmup
# steps is longer than this 200-step run, so the schedule would never reach
# its peak or enter the cosine decay.
optimizer, scheduler, scaler = create_optimized_trainer(
    model, total_steps, warmup_steps=total_steps // 10
)

print("Starting optimized training...")
print(f"• Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"• Mixed precision: {'Enabled' if scaler else 'Disabled'}")
print(f"• Gradient accumulation: {gradient_accumulation_steps}x")

losses = []
learning_rates = []

for step in range(total_steps):
    step_loss = 0

    for accum_step in range(gradient_accumulation_steps):
        x = torch.randint(0, config['vocab_size'], (2, 16), device=device)
        targets = torch.randint(0, config['vocab_size'], (2, 16), device=device)
        batch = (x, targets)

        loss = optimized_training_step(model, batch, scaler, gradient_accumulation_steps)
        # Average the micro-batch losses into one value per update.
        step_loss += loss / gradient_accumulation_steps

    # Read the rate before the update advances the scheduler, so the logged
    # value is the one this step actually used.
    learning_rates.append(optimizer.param_groups[0]['lr'])

    optimized_update_step(model, optimizer, scheduler, scaler)

    losses.append(step_loss)

    if (step + 1) % 50 == 0:
        print(f"Step {step + 1}: Loss = {step_loss:.4f}, LR = {learning_rates[-1]:.6f}")

plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(losses, linewidth=2)
plt.title('Training Loss')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(learning_rates, linewidth=2, color='orange')
plt.title('Learning Rate Schedule')
plt.xlabel('Steps')
plt.ylabel('Learning Rate')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nTraining completed!")
print(f"• Final loss: {losses[-1]:.4f}")
print(f"• Loss reduction: {((losses[0] - losses[-1]) / losses[0] * 100):.1f}%")

## Summary

Essential training optimization techniques for transformers:

- **Learning Rate Scheduling**: Warmup + cosine decay prevents instability and ensures smooth convergence
- **Gradient Clipping**: Threshold of 0.5-2.0 prevents gradient explosions
- **Mixed Precision**: FP16/FP32 combination provides up to ~2x speedup (hardware-dependent) with minimal quality loss
- **Gradient Accumulation**: Simulates large batches with limited memory
- **AdamW Optimizer**: Superior to Adam for transformer training with proper weight decay

These techniques are essential for efficient transformer training at scale.