# 🏋️ Lecture 12: Efficient Training - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/12_efficient_training/demo.ipynb)

## What You'll Learn
- Gradient checkpointing for memory savings
- Mixed precision training (FP16/BF16)
- Gradient accumulation
- Memory-efficient optimizers

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
import numpy as np

# Seed PyTorch's global RNG so random weights and inputs repeat across runs.
torch.manual_seed(42)
print('Ready for Efficient Training!')

## Part 1: Training Memory Breakdown

Training memory = Model + Gradients + Optimizer States + Activations

In [None]:
def training_memory_breakdown(params_millions, seq_len=512, batch_size=8, dtype='fp32'):
    """
    Estimate the memory breakdown for training a model with Adam.

    Args:
        params_millions: Parameter count, in millions.
        seq_len: Sequence length. The activation estimate scales linearly
            with it relative to a 512-token baseline.
        batch_size: Batch size. The activation estimate scales linearly
            with it relative to a baseline of 8.
        dtype: 'fp32' uses 4 bytes per weight/gradient; any other value is
            treated as a 2-byte format.

    Returns:
        Dict with keys 'model', 'gradients', 'optimizer', 'activations' and
        'total', each expressed in GB (1e9 bytes).
    """
    params = params_millions * 1e6
    bytes_per_param = 4 if dtype == 'fp32' else 2

    # Model weights
    model_mem = params * bytes_per_param

    # Gradients (same size as model)
    grad_mem = params * bytes_per_param

    # Optimizer states (Adam: momentum + variance, always FP32)
    opt_mem = params * 4 * 2  # Two FP32 states

    # Activations: rough rule of thumb of ~10x the weight memory at the
    # baseline of batch_size=8 and seq_len=512, scaled linearly in both.
    # seq_len used to be accepted but ignored; at the default of 512 the
    # extra factor is exactly 1.0, so existing results are unchanged.
    act_mem = model_mem * 10 * (batch_size / 8) * (seq_len / 512)

    total = model_mem + grad_mem + opt_mem + act_mem

    return {
        'model': model_mem / 1e9,
        'gradients': grad_mem / 1e9,
        'optimizer': opt_mem / 1e9,
        'activations': act_mem / 1e9,
        'total': total / 1e9
    }

# Tabulate the FP32 training-memory estimate for a few well-known models
print('üìä TRAINING MEMORY BREAKDOWN (FP32)')
print('=' * 70)
print(' '.join(
    [f'{"Model":<15}']
    + [f'{title:<12}' for title in ('Weights', 'Grads', 'Optimizer', 'Acts', 'Total')]
))
print('-' * 70)

# Model name -> parameter count in millions
models = {'BERT-base': 110, 'GPT-2': 1500, 'LLaMA-7B': 7000, 'LLaMA-70B': 70000}

for name, params in models.items():
    mem = training_memory_breakdown(params)
    # One cell per component; print's default single-space separator joins them
    print(
        f'{name:<15}',
        *(f'{mem[key]:>10.1f}GB'
          for key in ('model', 'gradients', 'optimizer', 'activations', 'total')),
    )

print('\n‚ö†Ô∏è Activations often dominate training memory!')

In [None]:
# Two panels: component shares for LLaMA-7B, and total memory per model
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: pie chart of the LLaMA-7B breakdown
mem = training_memory_breakdown(7000)
labels = ['Model', 'Gradients', 'Optimizer', 'Activations']
# Each pie label is the capitalised form of the matching dict key
sizes = [mem[label.lower()] for label in labels]
colors = ['#3b82f6', '#22c55e', '#f59e0b', '#ef4444']

axes[0].pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
axes[0].set_title(f'LLaMA-7B Training Memory\nTotal: {mem["total"]:.0f} GB')

# Right panel: total memory for every model, on a log scale
model_names = list(models)
totals = [training_memory_breakdown(models[model_name])['total'] for model_name in model_names]

bars = axes[1].bar(model_names, totals, color='#3b82f6')
axes[1].set_ylabel('Total Memory (GB)')
axes[1].set_title('Training Memory by Model Size')
axes[1].set_yscale('log')

# Horizontal reference lines at common GPU memory capacities
for gpu_gb, gpu_color, gpu_label in (
    (24, 'green', 'RTX 4090 (24GB)'),
    (80, 'orange', 'A100 (80GB)'),
):
    axes[1].axhline(y=gpu_gb, color=gpu_color, linestyle='--', label=gpu_label)
axes[1].legend()

# Write each total slightly above its bar
for bar, total in zip(bars, totals):
    axes[1].text(
        bar.get_x() + bar.get_width()/2,
        bar.get_height() * 1.1,
        f'{total:.0f}GB',
        ha='center',
    )

plt.tight_layout()
plt.show()

## Part 2: Gradient Checkpointing

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, then a GELU feed-forward net.

    Each sub-layer is wrapped in a residual connection and is preceded by its
    own LayerNorm.

    Args:
        d_model: Width of the input and output embeddings.
        n_heads: Number of attention heads.
    """
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape.

        The attention layer is built with ``batch_first=True``, so a batched
        input is laid out as (batch, seq, d_model).
        """
        # Normalise once and reuse the result as query, key and value.
        # Calling norm1 three times tripled the LayerNorm compute and the
        # activations autograd had to keep for the backward pass.
        normed = self.norm1(x)
        x = x + self.attn(normed, normed, normed)[0]  # [0] = attention output
        x = x + self.ffn(self.norm2(x))
        return x

class TransformerWithCheckpointing(nn.Module):
    """Stack of TransformerBlocks with an optional gradient-checkpointed forward.

    Args:
        n_layers: Number of transformer blocks in the stack.
        d_model: Embedding width passed to every block and to the head.
        use_checkpointing: When true, blocks are run through ``checkpoint``
            while the module is in training mode.
    """
    def __init__(self, n_layers=12, d_model=512, use_checkpointing=False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        blocks = []
        for _ in range(n_layers):
            blocks.append(TransformerBlock(d_model))
        self.layers = nn.ModuleList(blocks)
        self.head = nn.Linear(d_model, 1000)

    def forward(self, x):
        """Run every block, mean-pool over dim 1, and apply the linear head."""
        # Checkpointing is only used in training mode; decide once up front.
        recompute = self.use_checkpointing and self.training
        for layer in self.layers:
            # Checkpointed path recomputes the block's activations in backward
            x = checkpoint(layer, x, use_reentrant=False) if recompute else layer(x)
        pooled = x.mean(dim=1)
        return self.head(pooled)

def measure_memory(model, input_shape, backward=True):
    """Measure peak CUDA memory for one forward (and optional backward) pass.

    The model is switched to training mode and, when CUDA is available, moved
    to the GPU in place. Gradients produced by the backward pass are left on
    the model's parameters.

    Args:
        model: Module to run.
        input_shape: Shape of the random input tensor fed to the model.
        backward: If true, also backpropagate ``out.sum()``.

    Returns:
        Peak allocated CUDA memory in GB (1e9 bytes), or 0.0 when CUDA is not
        available.
    """
    # Query availability once instead of at every branch.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    x = torch.randn(input_shape)
    if use_cuda:
        x = x.cuda()
        model = model.cuda()

    # Forward
    model.train()
    out = model(x)

    if backward:
        out.sum().backward()

    if not use_cuda:
        # Nothing to measure on CPU; return a float for a consistent type.
        return 0.0
    return torch.cuda.max_memory_allocated() / 1e9

print('üìä GRADIENT CHECKPOINTING CONCEPT')
print('=' * 60)
print('\nWithout checkpointing:')
print('  - Store ALL intermediate activations')
print('  - Memory: O(n_layers √ó batch √ó seq √ó d)')
print('\nWith checkpointing:')
print('  - Store only layer inputs')
print('  - Recompute activations during backward')
print('  - Memory: O(‚àön_layers √ó batch √ó seq √ó d)')
print('  - Time: +30% (recomputation cost)')

# Demo models
model_standard = TransformerWithCheckpointing(n_layers=8, use_checkpointing=False)
model_checkpointed = TransformerWithCheckpointing(n_layers=8, use_checkpointing=True)

# Count parameters
params = sum(p.numel() for p in model_standard.parameters())
print(f'\nModel parameters: {params/1e6:.1f}M')

In [None]:
# Back-of-the-envelope model of checkpointing savings
def simulate_checkpoint_savings(n_layers, batch_size=8, seq_len=512, d_model=512):
    """
    Estimate activation memory with and without gradient checkpointing.

    Each layer is modelled as one FP32 tensor of batch_size * seq_len *
    d_model values. Without checkpointing all n_layers tensors are kept;
    with checkpointing only int(sqrt(n_layers)) of them are.

    Returns:
        Tuple (no_checkpointing_gb, with_checkpointing_gb), in GB (1e9 bytes).
    """
    fp32_bytes = 4
    per_layer_gb = batch_size * seq_len * d_model * fp32_bytes / 1e9

    # Square-root rule, truncated to a whole number of stored layers
    kept_layers = int(np.sqrt(n_layers))

    return n_layers * per_layer_gb, kept_layers * per_layer_gb

# Compare across different layer counts
layer_counts = [6, 12, 24, 48, 96]

print('üìä CHECKPOINTING MEMORY SAVINGS')
print('=' * 50)
print(f'{"Layers":<10} {"No Ckpt (GB)":<15} {"With Ckpt (GB)":<15} {"Savings":<10}')
print('-' * 50)

# Per-layer-count activation memory (GB), collected for the bar chart below
no_ckpt_mems = []
ckpt_mems = []

for n in layer_counts:
    no_ckpt, with_ckpt = simulate_checkpoint_savings(n)
    # Relative reduction in activation memory, as a percentage
    savings = (no_ckpt - with_ckpt) / no_ckpt * 100
    no_ckpt_mems.append(no_ckpt)
    ckpt_mems.append(with_ckpt)
    # Last column has no width spec so the '%' sits right after the number;
    # padding the number to width 10 first printed e.g. "67        %".
    print(f'{n:<10} {no_ckpt:<15.2f} {with_ckpt:<15.2f} {savings:.0f}%')

# Visualize: grouped bars, one pair per layer count
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(layer_counts))
width = 0.35

# Offset each series by half a bar width so the pairs sit side by side
bars1 = ax.bar(x - width/2, no_ckpt_mems, width, label='Standard', color='#ef4444')
bars2 = ax.bar(x + width/2, ckpt_mems, width, label='Checkpointed', color='#22c55e')

ax.set_xlabel('Number of Layers')
ax.set_ylabel('Activation Memory (GB)')
ax.set_title('üìä Gradient Checkpointing Memory Savings')
ax.set_xticks(x)
ax.set_xticklabels(layer_counts)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## Part 3: Mixed Precision Training

In [None]:
def mixed_precision_demo():
    """
    Print a short primer on mixed precision training.

    Emits the four key ingredients, a table comparing FP32/FP16/BF16, and a
    weight-memory comparison for a 1B-parameter model. Returns nothing.
    """
    intro = [
        'üìä MIXED PRECISION TRAINING',
        '=' * 60,
        '\nKey components:',
        '  1. Forward pass: FP16 (half memory, faster compute)',
        '  2. Loss scaling: Prevent underflow in FP16 gradients',
        '  3. Master weights: FP32 copy for accumulation',
        '  4. Gradient unscaling: Before optimizer step',
    ]
    for line in intro:
        print(line)

    # One row per format: (name, bytes per value, range, precision)
    formats = [
        ('FP32', 4, '¬±3.4e38', '~7 digits'),
        ('FP16', 2, '¬±65504', '~3 digits'),
        ('BF16', 2, '¬±3.4e38', '~2 digits'),
    ]

    print(f'\n{"Type":<8} {"Bytes":<8} {"Range":<15} {"Precision":<15}')
    print('-' * 50)
    for label, n_bytes, value_range, digits in formats:
        print(f'{label:<8} {n_bytes:<8} {value_range:<15} {digits:<15}')

    # Weight memory for 1B parameters: millions * bytes / 1000 gives GB
    print('\nüìä MEMORY SAVINGS')
    params_m = 1000

    fp32_mem = params_m * 4 / 1000
    fp16_copy = params_m * 2 / 1000
    # Mixed precision keeps an FP16 working copy plus the FP32 master weights
    mixed_mem = fp16_copy + fp32_mem

    print('1B param model:')
    print(f'  FP32: {fp32_mem:.1f} GB (model only)')
    print(f'  Mixed: {mixed_mem:.1f} GB (FP16 + FP32 master)')
    print('  Activations: 2x savings in FP16!')

mixed_precision_demo()

In [None]:
# Simulate loss scaling
def demonstrate_loss_scaling():
    """
    Illustrate why FP16 training needs loss scaling.

    Draws 10,000 synthetic gradient magnitudes around 1e-5, reports the share
    that falls below the approximate FP16 minimum normal value (6e-5) before
    and after multiplying by 1024, and plots both distributions side by side
    on a log10 axis. Reseeds NumPy's global RNG with 42.
    """
    np.random.seed(42)
    gradients = np.abs(np.random.randn(10000)) * 1e-5  # small synthetic gradients

    fp16_min = 6e-5  # approximate smallest positive normal FP16 value
    scale = 1024
    scaled_gradients = gradients * scale

    # Percentage of values below the threshold, without and with scaling
    n_total = len(gradients)
    underflow_pct = (gradients < fp16_min).sum() / n_total * 100
    scaled_underflow = (scaled_gradients < fp16_min).sum() / n_total * 100

    print('üìä LOSS SCALING DEMONSTRATION')
    print('=' * 50)
    print(f'Sample gradient magnitude: ~{np.mean(gradients):.2e}')
    print(f'FP16 minimum normal: ~{fp16_min:.0e}')
    print(f'\nWithout scaling: {underflow_pct:.1f}% gradients underflow')
    print(f'With scaling (1024x): {scaled_underflow:.1f}% gradients underflow')

    # One histogram per scenario, each with the FP16 threshold marked
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = (
        (gradients, '#ef4444', 'Without Loss Scaling'),
        (scaled_gradients, '#22c55e', 'With Loss Scaling (1024x)'),
    )
    for ax, (values, color, title) in zip(axes, panels):
        # The 1e-10 offset keeps log10 finite if a value is exactly zero
        ax.hist(np.log10(values + 1e-10), bins=50, color=color, alpha=0.7)
        ax.axvline(x=np.log10(fp16_min), color='black', linestyle='--',
                   linewidth=2, label=f'FP16 min ({fp16_min:.0e})')
        ax.set_xlabel('log10(gradient)')
        ax.set_ylabel('Count')
        ax.set_title(title)
        ax.legend()

    plt.tight_layout()
    plt.show()

demonstrate_loss_scaling()

## Part 4: Gradient Accumulation

In [None]:
def gradient_accumulation_demo():
    """
    Print an explanation of gradient accumulation.

    Shows the motivating problem, a pseudocode training loop, and a table of
    effective batch sizes for a mini-batch of 8. Returns nothing.
    """
    print('üìä GRADIENT ACCUMULATION')
    print('=' * 60)
    print('\nProblem: Want large batch (e.g., 1024) but only fit batch=8')
    print('\nSolution: Accumulate gradients over multiple mini-batches')
    print('\nPseudocode:')

    # Spelled out line by line so the whitespace-only lines stay explicit
    pseudocode = [
        '    ',
        '    accumulation_steps = 128  # 1024 / 8',
        '    for i, (x, y) in enumerate(dataloader):',
        '        loss = model(x, y) / accumulation_steps  # Scale loss',
        '        loss.backward()  # Accumulate gradients',
        '        ',
        '        if (i + 1) % accumulation_steps == 0:',
        '            optimizer.step()  # Update weights',
        '            optimizer.zero_grad()  # Reset gradients',
        '    ',
    ]
    print('\n'.join(pseudocode))

    # Effective batch = mini-batch size * number of accumulation steps
    mini_batch = 8
    step_options = [
        ('Standard', 1),
        ('4x Accumulation', 4),
        ('16x Accumulation', 16),
        ('128x Accumulation', 128),
    ]

    print('\nüìä EFFECTIVE BATCH SIZE COMPARISON')
    print(f'{"Method":<20} {"Mini-batch":<12} {"Accum Steps":<12} {"Effective":<12}')
    print('-' * 60)
    for label, steps in step_options:
        print(f'{label:<20} {mini_batch:<12} {steps:<12} {mini_batch * steps:<12}')

gradient_accumulation_demo()

In [None]:
# Verify gradient accumulation produces same result
class SimpleModel(nn.Module):
    """Single linear layer mapping 10 input features to 1 output."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# Create data
torch.manual_seed(42)
X = torch.randn(32, 10)
y = torch.randn(32, 1)

# Method 1: Full batch
model1 = SimpleModel()
# Overwrite the parameters in place under no_grad rather than rebinding .data
with torch.no_grad():
    model1.fc.weight.copy_(torch.randn_like(model1.fc.weight))
    model1.fc.bias.zero_()
initial_weight = model1.fc.weight.detach().clone()

out1 = model1(X)
loss1 = F.mse_loss(out1, y)
loss1.backward()
grad_full = model1.fc.weight.grad.clone()

# Method 2: Accumulated (4 mini-batches of 8), starting from the same weights
model2 = SimpleModel()
with torch.no_grad():
    model2.fc.weight.copy_(initial_weight)
    model2.fc.bias.zero_()

accum_steps = 4
mini_batch = 8

model2.zero_grad()
for i in range(accum_steps):
    start = i * mini_batch
    end = (i + 1) * mini_batch
    out2 = model2(X[start:end])
    # Each mini-batch loss is a mean over 8 samples; dividing by the number
    # of steps makes the summed gradients equal the full-batch mean.
    loss2 = F.mse_loss(out2, y[start:end]) / accum_steps  # Scale loss!
    loss2.backward()

grad_accum = model2.fc.weight.grad.clone()

# Compare the two gradients instead of assuming they agree
max_diff = (grad_full - grad_accum).abs().max()
grads_match = torch.allclose(grad_full, grad_accum, atol=1e-5)

print('üìä GRADIENT ACCUMULATION VERIFICATION')
print('=' * 50)
print(f'Full batch gradient norm: {grad_full.norm():.6f}')
print(f'Accumulated gradient norm: {grad_accum.norm():.6f}')
print(f'Difference: {max_diff:.8f}')
if grads_match:
    print('\n‚úÖ Gradients match!')
else:
    print('\nGradients do NOT match - check the loss scaling!')

## Part 5: Memory-Efficient Optimizers

In [None]:
def optimizer_memory_comparison(params_billions):
    """
    Compare the optimizer-state memory of several optimizers.

    Prints a table of state memory in GB and returns the raw numbers.

    Args:
        params_billions: Model size, in billions of parameters.

    Returns:
        Dict mapping optimizer name to its state memory in bytes, in the
        order the rows are printed.
    """
    params = params_billions * 1e9

    optimizers = {
        # Plain SGD keeps no per-parameter state. It was previously charged
        # a full FP32 buffer, the same as SGD with momentum.
        'SGD': params * 0.0,
        'SGD + Momentum': params * 4,  # One FP32 momentum buffer
        'Adam': params * 4 * 2,  # m and v states (FP32)
        'AdaFactor': params * 4 * 0.5,  # Factorized states (rough estimate)
        '8-bit Adam': params * 1 * 2,  # Two states at 1 byte each
    }

    print('üìä OPTIMIZER MEMORY USAGE')
    print('=' * 50)
    print(f'Model: {params_billions}B parameters')
    print(f'\n{"Optimizer":<20} {"State Memory (GB)":<20}')
    print('-' * 40)

    for name, mem in optimizers.items():
        mem_gb = mem / 1e9
        print(f'{name:<20} {mem_gb:<20.1f}')

    return optimizers

# Optimizer state sizes for a 7B-parameter model (LLaMA-7B)
opt_mems = optimizer_memory_comparison(7)

# Bar chart of the state memory per optimizer, in GB
fig, ax = plt.subplots(figsize=(10, 6))

names = list(opt_mems)
mems = [state_bytes / 1e9 for state_bytes in opt_mems.values()]
# Both SGD variants share green; the other three each get their own colour
colors = ['#22c55e'] * 2 + ['#ef4444', '#f59e0b', '#3b82f6']

bars = ax.bar(names, mems, color=colors)
ax.set_ylabel('Memory (GB)')
ax.set_title('üìä Optimizer State Memory for 7B Model')
ax.grid(True, alpha=0.3, axis='y')

# Label each bar with its value, placed 1 GB above the bar top
for bar, mem in zip(bars, mems):
    ax.text(
        bar.get_x() + bar.get_width()/2,
        bar.get_height() + 1,
        f'{mem:.0f}GB',
        ha='center',
    )

plt.xticks(rotation=15)
plt.tight_layout()
plt.show()

In [None]:
print('üéØ KEY TAKEAWAYS')
print('=' * 60)
print('\n1. Training memory: Model + Grads + Optimizer + Activations')
print('\n2. Gradient Checkpointing: ‚àöN memory, +30% time')
print('\n3. Mixed Precision: 2x faster, needs loss scaling')
print('\n4. Gradient Accumulation: Large batch without memory increase')
print('\n5. 8-bit Optimizers: 4x less optimizer memory')
print('\n6. Combine all for maximum efficiency!')
print('\n' + '=' * 60)
print('\nüìö Next: On-Device Training!')