# Week 10 — Efficient Training (Training at Scale)

This notebook focuses on building efficient training pipelines. You'll:
- Implement efficient DataLoaders with prefetching
- Profile and optimize data loading bottlenecks
- Use learning rate schedulers
- Build robust checkpoint systems

In [None]:
# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import time
from pathlib import Path
%matplotlib inline

# Seed torch's global RNG so torch-based random draws later in the notebook
# (synthetic data, weight init, DataLoader shuffling) repeat across runs.
torch.manual_seed(42)
print(f"PyTorch version: {torch.__version__}")

## 1. Custom Dataset with Efficient Loading

Implement a custom `Dataset` whose `__getitem__` can simulate per-sample I/O latency, so the effect of different DataLoader settings can be measured.

In [None]:
# Custom dataset with simulated I/O delay
class SyntheticDataset(Dataset):
    """In-memory random classification dataset with an optional per-item delay.

    The sleep in ``__getitem__`` stands in for the cost of reading one sample
    from storage, so DataLoader settings (``num_workers``, ``pin_memory``)
    can be benchmarked without real I/O.
    """

    def __init__(self, n_samples=1000, n_features=100, delay_ms=0, n_classes=10):
        """
        Args:
            n_samples: Number of samples
            n_features: Number of features per sample
            delay_ms: Simulated I/O delay per item in milliseconds (<= 0 disables it)
            n_classes: Number of distinct labels; labels are drawn uniformly
                from ``[0, n_classes)``. Defaults to 10.

        Raises:
            ValueError: If ``n_classes`` is smaller than 1.
        """
        if n_classes < 1:
            raise ValueError(f"n_classes must be >= 1, got {n_classes}")

        self.n_samples = n_samples
        self.n_features = n_features
        self.delay_ms = delay_ms
        self.n_classes = n_classes

        # Pre-generate data (in practice, you'd load from disk).
        # Draw order (features, then labels) is kept stable so seeded runs reproduce.
        self.data = torch.randn(n_samples, n_features)
        self.labels = torch.randint(0, n_classes, (n_samples,))

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        """Return ``(features, label)`` for ``idx``, sleeping ``delay_ms`` first."""
        # Simulate I/O delay
        if self.delay_ms > 0:
            time.sleep(self.delay_ms / 1000.0)

        return self.data[idx], self.labels[idx]

# Two equally sized datasets: one returns samples instantly, the other pays a
# 1 ms simulated read cost per sample (fast is built first, then slow).
dataset_fast, dataset_slow = (
    SyntheticDataset(n_samples=1000, delay_ms=ms) for ms in (0, 1)
)

# Quick sanity check on the fast dataset
print(f"Dataset size: {len(dataset_fast)}")
sample_x, sample_y = dataset_fast[0]
print(f"Sample shape: {sample_x.shape}, label: {sample_y}")

## 2. DataLoader Performance Comparison

Compare different DataLoader configurations and measure throughput.

In [None]:
# Benchmark function
def benchmark_dataloader(dataset, batch_size=32, num_workers=0, pin_memory=False, n_batches=50):
    """
    Benchmark DataLoader throughput.

    Builds a fresh shuffled DataLoader over ``dataset`` and consumes at most
    ``n_batches`` batches, touching each one with a cheap reduction. The timer
    starts before the loader's iterator is created, so worker start-up cost
    (for ``num_workers > 0``) is part of the measurement.

    Args:
        dataset: Map-style dataset whose items are ``(x, y)`` pairs.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes (0 = load in the main process).
        pin_memory: Passed through to DataLoader (only useful with a CUDA device).
        n_batches: Upper bound on the number of batches to consume.

    Returns:
        ``(elapsed_seconds, batches_per_second)``. Throughput is computed from
        the number of batches actually consumed, which is smaller than
        ``n_batches`` when the dataset runs out first.
    """
    from itertools import islice

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True
    )

    n_done = 0
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start_time = time.perf_counter()
    # islice stops after n_batches without requesting (and timing) an extra batch
    for batch_x, _ in islice(loader, max(n_batches, 0)):
        # Simulate processing
        _ = batch_x.mean()
        n_done += 1
    elapsed = time.perf_counter() - start_time

    if elapsed > 0:
        throughput = n_done / elapsed
    else:
        throughput = float("inf") if n_done else 0.0

    return elapsed, throughput

# Compare configurations
configs = [
    {'num_workers': 0, 'pin_memory': False, 'name': '0 workers'},
    {'num_workers': 2, 'pin_memory': False, 'name': '2 workers'},
    {'num_workers': 4, 'pin_memory': False, 'name': '4 workers'},
    {'num_workers': 4, 'pin_memory': True, 'name': '4 workers + pin_memory'},
]

print("Benchmarking DataLoader configurations...\n")
results = []

for config in configs:
    elapsed, throughput = benchmark_dataloader(
        dataset_slow,
        batch_size=32,
        num_workers=config['num_workers'],
        pin_memory=config['pin_memory'],
        n_batches=30
    )
    results.append({'name': config['name'], 'elapsed': elapsed, 'throughput': throughput})
    print(f"{config['name']:25s}: {elapsed:.2f}s ({throughput:.1f} batches/s)")

# Plot results
names = [r['name'] for r in results]
throughputs = [r['throughput'] for r in results]

plt.figure(figsize=(10, 5))
plt.barh(names, throughputs, color='skyblue')
plt.xlabel('Throughput (batches/second)')
plt.title('DataLoader Performance Comparison')
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()

# Report what was actually measured instead of a fixed conclusion: with only
# 30 batches, worker start-up can eat much of the gain (notably where workers
# are spawned rather than forked), and pin_memory only helps when batches are
# later copied to a CUDA device.
baseline = results[0]
fastest = max(results, key=lambda r: r['throughput'])
if baseline['throughput'] > 0:
    speedup = fastest['throughput'] / baseline['throughput']
    print(f"\n→ Fastest: {fastest['name']} ({speedup:.1f}x the throughput of '{baseline['name']}')")
else:
    print(f"\n→ Fastest: {fastest['name']}")
if not torch.cuda.is_available():
    print("→ No CUDA device detected, so pin_memory cannot speed anything up here.")

## 3. Learning Rate Schedulers

Experiment with different LR scheduling strategies.

In [None]:
# Simple demonstration model: 100 inputs -> 128 hidden units (ReLU) -> 10 outputs.
model = nn.Sequential(*(
    nn.Linear(100, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
))

# Create different schedulers
def visualize_schedulers(initial_lr=0.1, n_epochs=100, seed=None):
    """
    Record the learning rate produced by several LR schedulers over ``n_epochs``.

    Each scheduler wraps its own SGD optimizer over a private dummy parameter,
    so the result does not depend on any global model. Before every
    ``scheduler.step()`` the optimizer is stepped once. That step is a no-op
    because no gradients exist, but it keeps the call order PyTorch requires
    (``optimizer.step()`` before ``scheduler.step()``) and so avoids its warning.

    Args:
        initial_lr: Starting learning rate for every optimizer.
        n_epochs: Number of epochs to simulate (also ``T_max`` for cosine annealing).
        seed: Optional seed for the noise in the fake loss fed to
            ``ReduceLROnPlateau``; ``None`` gives a non-reproducible run.

    Returns:
        Dict mapping scheduler name -> list of ``n_epochs`` learning rates,
        where entry ``i`` is the LR in effect during epoch ``i``.
    """
    # LR trajectories do not depend on parameter values, and with grad=None the
    # optimizer steps below never modify this parameter.
    params = [nn.Parameter(torch.zeros(1))]
    rng = np.random.default_rng(seed)

    schedulers_to_test = {
        'StepLR': optim.lr_scheduler.StepLR(
            optim.SGD(params, lr=initial_lr),
            step_size=30,
            gamma=0.1
        ),
        'ExponentialLR': optim.lr_scheduler.ExponentialLR(
            optim.SGD(params, lr=initial_lr),
            gamma=0.95
        ),
        'CosineAnnealingLR': optim.lr_scheduler.CosineAnnealingLR(
            optim.SGD(params, lr=initial_lr),
            T_max=n_epochs
        ),
        'ReduceLROnPlateau': optim.lr_scheduler.ReduceLROnPlateau(
            optim.SGD(params, lr=initial_lr),
            mode='min',
            factor=0.5,
            patience=10
        )
    }

    lr_histories = {name: [] for name in schedulers_to_test}

    for epoch in range(n_epochs):
        for name, scheduler in schedulers_to_test.items():
            optimizer = scheduler.optimizer
            lr_histories[name].append(optimizer.param_groups[0]['lr'])

            # Required ordering: optimizer.step() first, then scheduler.step()
            optimizer.step()
            if name == 'ReduceLROnPlateau':
                # Simulate a slowly improving, slightly noisy validation loss
                fake_loss = 1.0 / (epoch + 1) + rng.standard_normal() * 0.01
                scheduler.step(fake_loss)
            else:
                scheduler.step()

    return lr_histories

# Visualize: simulate 100 epochs, then draw each schedule's learning rate on a
# log-scaled y-axis so exponential decay shows up as a straight line.
lr_histories = visualize_schedulers(initial_lr=0.1, n_epochs=100)

plt.figure(figsize=(12, 5))
for scheduler_name, history in lr_histories.items():
    plt.plot(history, label=scheduler_name, linewidth=2)

plt.yscale('log')
plt.grid(alpha=0.3)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedulers Comparison')
plt.legend()
plt.tight_layout()
plt.show()

## 4. Robust Checkpointing System

Build a comprehensive checkpoint system with metadata and resumption support.

In [None]:
# Checkpoint manager
class CheckpointManager:
    """Save and restore training state, keeping only the most recent epoch checkpoints.

    Files written to ``checkpoint_dir``:
      * ``checkpoint_epoch_{epoch}.pth`` -- one per ``save`` call; after each
        save only the ``keep_last_n`` highest epochs are kept.
      * ``best_model.pth`` -- overwritten whenever ``save`` is called with
        ``is_best=True``; never removed by the cleanup.
    """

    def __init__(self, checkpoint_dir='checkpoints', keep_last_n=3):
        """
        Args:
            checkpoint_dir: Directory for checkpoint files; created (including
                missing parent directories) if it does not exist.
            keep_last_n: Number of epoch checkpoints to retain. A value <= 0
                disables cleanup, so every epoch checkpoint is kept.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_last_n = keep_last_n

    def save(self, model, optimizer, scheduler, epoch, metrics, is_best=False):
        """
        Save checkpoint with metadata

        Args:
            model: PyTorch model
            optimizer: Optimizer whose state is saved
            scheduler: LR scheduler whose state is saved, or None
            epoch: Current epoch (used in the file name)
            metrics: Dict of metrics (e.g., {'train_loss': 0.5, 'val_acc': 0.9})
            is_best: Whether this is the best model so far
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
            'metrics': self._to_builtin(metrics),
        }

        # Save regular checkpoint
        checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pth'
        self._atomic_save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

        # Save best model
        if is_best:
            best_path = self.checkpoint_dir / 'best_model.pth'
            self._atomic_save(checkpoint, best_path)
            print(f"Best model saved: {best_path}")

        # Clean up old checkpoints
        self._cleanup_old_checkpoints()

    def load(self, checkpoint_path, model, optimizer=None, scheduler=None, map_location=None):
        """
        Load checkpoint and restore state

        Args:
            checkpoint_path: Path of a file written by ``save``.
            model: Model to load the weights into.
            optimizer: Optional optimizer to restore.
            scheduler: Optional scheduler to restore; skipped if the checkpoint
                has no scheduler state.
            map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load a
                GPU checkpoint on a CPU-only machine).

        Returns:
            ``(epoch, metrics)`` stored in the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

        model.load_state_dict(checkpoint['model_state_dict'])
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler_state = checkpoint.get('scheduler_state_dict')
        if scheduler is not None and scheduler_state:
            scheduler.load_state_dict(scheduler_state)

        return checkpoint['epoch'], checkpoint['metrics']

    @staticmethod
    def _to_builtin(metrics):
        """Convert numpy scalar metric values to plain Python numbers."""
        # numpy scalars (e.g. np.float64) pickle as numpy objects, which
        # torch.load may reject when weights_only=True (the default in recent
        # PyTorch releases); plain floats/ints always load.
        if isinstance(metrics, dict):
            return {k: (v.item() if isinstance(v, np.generic) else v)
                    for k, v in metrics.items()}
        if isinstance(metrics, np.generic):
            return metrics.item()
        return metrics

    @staticmethod
    def _atomic_save(obj, path):
        """Write ``obj`` to ``path`` via a temporary file plus rename."""
        # An interrupted torch.save then leaves only a stray *.tmp file, never a
        # truncated checkpoint under the real name.
        tmp_path = path.with_name(path.name + '.tmp')
        torch.save(obj, tmp_path)
        tmp_path.replace(path)

    def _cleanup_old_checkpoints(self):
        """Keep only the ``keep_last_n`` epoch checkpoints with the highest epochs."""
        if self.keep_last_n <= 0:
            return

        prefix = 'checkpoint_epoch_'
        checkpoints = []
        for ckpt in self.checkpoint_dir.glob('checkpoint_epoch_*.pth'):
            try:
                epoch = int(ckpt.stem[len(prefix):])
            except ValueError:
                continue  # not a name produced by save(); leave it alone
            checkpoints.append((epoch, ckpt))

        # Sort numerically: a lexicographic sort puts epoch 10 before epoch 2
        # and would delete the newest checkpoints.
        checkpoints.sort(key=lambda item: item[0])
        if len(checkpoints) > self.keep_last_n:
            for _, ckpt in checkpoints[:-self.keep_last_n]:
                ckpt.unlink()
                print(f"Removed old checkpoint: {ckpt}")

# Example usage
ckpt_manager = CheckpointManager(checkpoint_dir='./checkpoints', keep_last_n=3)

# Simulate training and saving checkpoints
model = nn.Linear(10, 2)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)
criterion = nn.CrossEntropyLoss()

best_val_acc = 0.0
for epoch in range(1, 6):
    # One real optimisation step on a random batch: gives Adam some state worth
    # checkpointing, and ensures optimizer.step() precedes scheduler.step()
    # as PyTorch requires.
    inputs = torch.randn(16, 10)
    targets = torch.randint(0, 2, (16,))
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Store plain Python floats (not numpy scalars) in the checkpoint metrics
    train_loss = loss.item()
    val_acc = float(min(0.95, 0.5 + epoch * 0.1 + np.random.rand() * 0.05))  # simulated

    metrics = {'train_loss': train_loss, 'val_acc': val_acc}
    is_best = val_acc > best_val_acc

    if is_best:
        best_val_acc = val_acc

    ckpt_manager.save(model, optimizer, scheduler, epoch, metrics, is_best=is_best)
    scheduler.step()
    print()

# Load best model (path derived from the manager instead of hard-coded)
print("\nLoading best model...")
best_path = ckpt_manager.checkpoint_dir / 'best_model.pth'
best_epoch, best_metrics = ckpt_manager.load(best_path, model, optimizer, scheduler)
print(f"Loaded best model from epoch {best_epoch}")
print(f"Metrics: {best_metrics}")

## Exercises for Further Practice

1. **Mixed Precision Training**: Use `torch.amp` (`torch.autocast` together with a `GradScaler`) for faster training; the older `torch.cuda.amp` entry points are deprecated
2. **Gradient Accumulation**: Implement gradient accumulation for larger effective batch sizes
3. **Profiling**: Use PyTorch profiler to identify bottlenecks
4. **Distributed Training**: Explore `torch.nn.parallel.DistributedDataParallel` (which PyTorch recommends even on a single machine) and compare it with the older `torch.nn.DataParallel`
5. **Real Dataset**: Apply these techniques to a large dataset (e.g., ImageNet subset)

## Deliverables Checklist

- [ ] Custom dataset with efficient data loading
- [ ] DataLoader performance benchmarks
- [ ] Learning rate scheduler experiments
- [ ] Robust checkpoint system implementation
- [ ] Short report on throughput optimizations

## Recommended Resources

- PyTorch documentation on data loading and DataLoader
- PyTorch performance tuning guide
- Papers on distributed training and mixed precision