# Advanced Techniques

This section covers mixed precision training, learning rate scheduling, optimizer choice, regularization, quantization, and performance profiling.

## Mixed Precision Training

Use lower precision (FP16/BF16) for faster, more memory-efficient training while maintaining accuracy.

### Why Mixed Precision?

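| Metric | FP32 | FP16 | BF16 | FP8 |
|--------|------|------|------|-----|
| Memory per value | 4 bytes | 2 bytes | 2 bytes | 1 byte |
| Relative speed (hardware-dependent) | 1x | 2-3x | 2-3x | 4-8x |
| Precision | Full (23-bit mantissa) | Half (10-bit mantissa) | Half (7-bit mantissa, FP32 range) | Very low (2-3-bit mantissa) |
| Training | ✓✓ | ✓ (w/ scaling) | ✓✓ | Limited (needs hardware support) |
| Inference | ✓ | ✓ | ✓ | ✓ |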

**Key Insight**: Run the forward and backward passes in lower precision, but keep numerically sensitive steps (loss computation, reductions, and the weight update) in FP32.
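
To make the memory row of the table concrete, the sketch below multiplies a parameter count by the bytes per value. The 10-million-parameter model and the `weight_memory_mb` helper are hypothetical, and the figures cover stored weights only (activations, gradients, and optimizer state are ignored).

```python
# Bytes per stored value, taken from the table above.
BYTES_PER_VALUE = {"FP32": 4, "FP16": 2, "BF16": 2, "FP8": 1}


def weight_memory_mb(num_params: int, dtype: str) -> float:
    """Memory needed to store the weights alone, in megabytes (10^6 bytes)."""
    return num_params * BYTES_PER_VALUE[dtype] / 1e6


NUM_PARAMS = 10_000_000  # hypothetical model size, for illustration only

for dtype in BYTES_PER_VALUE:
    print(f"{dtype}: {weight_memory_mb(NUM_PARAMS, dtype):.1f} MB")
```

Note that a mixed precision setup which keeps an FP32 master copy next to the FP16 weights stores both, so the weight memory alone does not shrink; the savings come mainly from activations and faster arithmetic.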

### FP16 Training Requirements

To train with FP16 without numerical issues:

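1. **Loss Scaling**: Multiply the loss by a large scale factor so that small gradients do not underflow to zero in FP16, then divide the gradients by the same factor before the update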
   ```mojo
   scaled_loss = loss * 1024.0  # Scale up
   scaled_loss.backward()        # Compute gradients with scaled loss
   gradients = gradients / 1024.0 # Scale back down
   ```

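2. **Gradient Clipping**: Prevent exploding gradients (apply it after unscaling, so that `max_norm` refers to the true gradient magnitude)
   ```mojo
   # Pseudocode: clip each parameter's gradient to an L2 norm of at most max_norm
   for param in model.parameters:
       norm = sqrt(sum(grad ** 2 for grad in param.grad))
       if norm > max_norm:
           param.grad = param.grad / norm * max_norm
   ```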

3. **Keep master weights in FP32**: Apply updates to a full-precision copy so that small updates are not rounded away
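
The three requirements above can be combined into a single update step. The following is a minimal sketch in plain Python, not the ML Odyssey implementation: `to_fp16` emulates FP16 rounding with the standard `struct` module, Python floats (64-bit) stand in for the FP32 master copy, and `MAX_NORM`, `LR`, and `mixed_precision_step` are hypothetical names chosen for illustration.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


LOSS_SCALE = 1024.0  # illustrative constant, same factor as in the list above
MAX_NORM = 1.0       # hypothetical clipping threshold
LR = 0.1             # hypothetical learning rate


def mixed_precision_step(master_weights, true_grads):
    """One update with FP16-rounded gradients and full-precision master weights."""
    # 1. Gradients of the scaled loss, as they would be stored in FP16.
    scaled = [to_fp16(g * LOSS_SCALE) for g in true_grads]
    # 2. Unscale in full precision before anything else uses the gradients.
    grads = [g / LOSS_SCALE for g in scaled]
    # 3. Clip by global L2 norm (after unscaling, so MAX_NORM keeps its meaning).
    norm = math.sqrt(sum(g * g for g in grads))
    if norm > MAX_NORM:
        grads = [g * MAX_NORM / norm for g in grads]
    # 4. Apply the update to the full-precision master copy.
    return [w - LR * g for w, g in zip(master_weights, grads)]


# A gradient of 1e-8 underflows in plain FP16 but survives with loss scaling.
lost = to_fp16(1e-8)
rescued = to_fp16(1e-8 * LOSS_SCALE) / LOSS_SCALE

# A small update vanishes if the weight itself is stored in FP16.
fp16_weight = to_fp16(to_fp16(1.0) + to_fp16(1e-4))

new_weights = mixed_precision_step([1.0, -2.0], [3.0, 4.0])
print(lost, rescued, fp16_weight, new_weights)
```

A fixed scale of 1024 is used only to mirror the snippet above; frameworks commonly adjust the scale dynamically and skip steps whose gradients overflow.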

In [None]:
# Illustrative mixed precision comparison (hard-coded values, not measured results)
import numpy as np
import matplotlib.pyplot as plt

# Training curves: FP32 vs FP16
epochs = range(1, 11)
fp32_loss = np.array([0.45, 0.20, 0.085, 0.042, 0.025, 0.018, 0.012, 0.010, 0.008, 0.007])
fp16_loss = np.array([0.45, 0.20, 0.086, 0.043, 0.025, 0.018, 0.012, 0.010, 0.008, 0.007])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Loss comparison
ax1.plot(epochs, fp32_loss, 'b-o', label='FP32', linewidth=2)
ax1.plot(epochs, fp16_loss, 'r--s', label='FP16 (with loss scaling)', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('FP32 vs FP16 Training')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Training time comparison (simulated)
precisions = ['FP32', 'FP16', 'BF16', 'FP8']
times = [100, 35, 35, 18]  # Relative time
colors = ['blue', 'orange', 'green', 'red']
ax2.barh(precisions, times, color=colors, alpha=0.7)
ax2.set_xlabel('Training Time (arbitrary units)')
ax2.set_title('Training Speed by Precision')
ax2.invert_yaxis()
for i, t in enumerate(times):
    ax2.text(t + 2, i, f'{times[0] / t:.1f}x', va='center')  # speedup relative to FP32

plt.tight_layout()
plt.show()

print("Results:")
print(f"  FP32 final loss: {fp32_loss[-1]:.4f}")
print(f"  FP16 final loss: {fp16_loss[-1]:.4f}")
print(f"  Difference: {abs(fp32_loss[-1] - fp16_loss[-1]):.6f}")
print(f"\n  In this illustrative data, FP16 is ~{times[0] / times[1]:.1f}x faster with negligible accuracy loss.")

## Learning Rate Scheduling

Adapt the learning rate during training for better convergence:

In [None]:
import numpy as np
import matplotlib.pyplot as plt

epochs = np.arange(0, 100)

# Different learning rate schedules
lr_constant = np.ones_like(epochs) * 0.001
lr_linear = 0.001 * (1 - epochs / 100)
lr_exponential = 0.001 * np.exp(-epochs / 30)
lr_cosine = 0.001 * (1 + np.cos(np.pi * epochs / 100)) / 2
lr_step = 0.001 * (0.5 ** (epochs // 25))

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, lr_constant, label='Constant', linewidth=2)
ax.plot(epochs, lr_linear, label='Linear Decay', linewidth=2)
ax.plot(epochs, lr_exponential, label='Exponential Decay', linewidth=2)
ax.plot(epochs, lr_cosine, label='Cosine Annealing', linewidth=2)
ax.plot(epochs, lr_step, label='Step Decay', linewidth=2, linestyle='--')

ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedules')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_yscale('log')
plt.tight_layout()
plt.show()

print("Popular schedules:")
print("  - Step Decay: Reduce by factor every N epochs")
print("  - Cosine Annealing: Smooth decay following cosine curve")
print("  - Linear Decay: Linearly reduce from initial to final LR")
print("  - Exponential: Geometric decay")
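
The schedules plotted above can also be written as small per-epoch functions, which is closer to how a training loop would query them. This is a sketch using the same constants as the plot (`0.001` base rate, 100 epochs); the function names and the `tau`, `drop`, and `every` parameters are illustrative, not an existing API.

```python
import math

BASE_LR = 0.001     # initial learning rate used in the plot above
TOTAL_EPOCHS = 100  # length of the schedule used in the plot above


def constant(epoch: int) -> float:
    return BASE_LR


def linear_decay(epoch: int) -> float:
    return BASE_LR * (1 - epoch / TOTAL_EPOCHS)


def exponential_decay(epoch: int, tau: float = 30.0) -> float:
    return BASE_LR * math.exp(-epoch / tau)


def cosine_annealing(epoch: int) -> float:
    return BASE_LR * (1 + math.cos(math.pi * epoch / TOTAL_EPOCHS)) / 2


def step_decay(epoch: int, drop: float = 0.5, every: int = 25) -> float:
    return BASE_LR * drop ** (epoch // every)


SCHEDULES = {
    "constant": constant,
    "linear": linear_decay,
    "exponential": exponential_decay,
    "cosine": cosine_annealing,
    "step": step_decay,
}

for name, schedule in SCHEDULES.items():
    print(f"{name:<12}", [round(schedule(e), 6) for e in (0, 25, 50, 75, 99)])
```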

## Optimization Algorithms

Different optimizers suit different scenarios:

In [None]:
optimizers = {
    "SGD": {
        "description": "Stochastic Gradient Descent",
        "pros": ["Simple", "Generalizes well"],
        "cons": ["Slow convergence", "Sensitive to LR"],
        "best_for": "CNNs, strong baseline",
    },
    "SGD+Momentum": {
        "description": "SGD with momentum term",
        "pros": ["Faster convergence", "Less sensitive to LR"],
        "cons": ["One extra hyperparameter (β)"],
        "best_for": "Most training tasks",
    },
    "Adam": {
        "description": "Adaptive Moment Estimation",
        "pros": ["Adaptive per-parameter LR", "Works well with default hyperparameters"],
        "cons": ["May not generalize as well", "Extra memory (two moment buffers per parameter)"],
        "best_for": "Transformers, NLP",
    },
    "AdamW": {
        "description": "Adam with decoupled weight decay",
        "pros": ["Better regularization", "SOTA for transformers"],
        "cons": ["Weight decay is one more hyperparameter to tune"],
        "best_for": "Vision Transformers, modern models",
    },
}

for name, info in optimizers.items():
    print(f"\n{name}: {info['description']}")
    print(f"  Pros: {', '.join(info['pros'])}")
    print(f"  Cons: {', '.join(info['cons'])}")
    print(f"  Best for: {info['best_for']}")
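
To show what separates these optimizers, here is a minimal sketch of the SGD+Momentum and AdamW update rules for a single scalar parameter, applied to the toy objective `(w - 3)^2`. The function names, the toy objective, and the hyperparameter values are assumptions for illustration; real implementations operate on tensors and differ in details.

```python
import math


def sgd_momentum_step(w, velocity, grad, lr=0.1, beta=0.9):
    """SGD with momentum: accumulate a velocity, then step along it."""
    velocity = beta * velocity + grad
    return w - lr * velocity, velocity


def adamw_step(w, m, v, grad, t, lr=0.1, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.0):
    """AdamW: Adam update plus weight decay applied directly to the weight."""
    w = w * (1 - lr * weight_decay)            # decoupled weight decay
    m = beta1 * m + (1 - beta1) * grad         # first moment (mean of gradients)
    v = beta2 * v + (1 - beta2) * grad * grad  # second moment (mean of squares)
    m_hat = m / (1 - beta1 ** t)               # bias correction, t starts at 1
    v_hat = v / (1 - beta2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


def grad_fn(w):
    """Gradient of the toy objective (w - 3)^2."""
    return 2.0 * (w - 3.0)


w_sgd, vel = 0.0, 0.0
w_adam, m, v = 0.0, 0.0, 0.0
for t in range(1, 501):
    w_sgd, vel = sgd_momentum_step(w_sgd, vel, grad_fn(w_sgd))
    w_adam, m, v = adamw_step(w_adam, m, v, grad_fn(w_adam), t)

# With a zero gradient, decoupled weight decay alone shrinks the weight.
decayed, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, weight_decay=0.1)
print(w_sgd, w_adam, decayed)
```

The two moment values `m` and `v` are the per-parameter state behind Adam's extra memory cost; SGD+Momentum keeps only `velocity`.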

## Regularization Techniques

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Synthetic loss curves illustrating overfitting under different regularization (not measured)
epochs = np.arange(1, 21)

# No regularization - severe overfitting
train_loss_no_reg = 1.0 / (1 + 2*epochs**0.5)
val_loss_no_reg = 1.0 / (1 + epochs**0.5) + 0.1 * epochs**0.5

# L2 regularization - smaller train/val gap
train_loss_l2 = 1.0 / (1 + 1.8*epochs**0.5)
val_loss_l2 = 1.0 / (1 + 1.8*epochs**0.5) + 0.02 * epochs**0.3

# Dropout - smallest train/val gap in this synthetic example
train_loss_dropout = 1.0 / (1 + 1.7*epochs**0.5)
val_loss_dropout = 1.0 / (1 + 1.7*epochs**0.5) + 0.01 * epochs**0.2

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

configs = [
    ("No Regularization", train_loss_no_reg, val_loss_no_reg, axes[0]),
    ("L2 Regularization", train_loss_l2, val_loss_l2, axes[1]),
    ("Dropout", train_loss_dropout, val_loss_dropout, axes[2]),
]

for title, train, val, ax in configs:
    ax.plot(epochs, train, 'b-o', label='Train', linewidth=2)
    ax.plot(epochs, val, 'r-s', label='Val', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 1.2])

plt.tight_layout()
plt.show()

print("Regularization Impact:")
print(f"  No Reg - Train/Val gap at epoch 20: {(val_loss_no_reg[-1] - train_loss_no_reg[-1]):.3f}")
print(f"  L2 Reg - Train/Val gap at epoch 20: {(val_loss_l2[-1] - train_loss_l2[-1]):.3f}")
print(f"  Dropout - Train/Val gap at epoch 20: {(val_loss_dropout[-1] - train_loss_dropout[-1]):.3f}")
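
The curves above are synthetic, so it may help to see what the two regularizers actually compute. The sketch below shows the extra gradient term contributed by an L2 penalty and an inverted dropout mask; the helper names and default values are hypothetical, and `p` is assumed to be strictly less than 1.

```python
import random


def l2_regularized_grad(grad, weight, weight_decay=0.01):
    """Gradient of loss + (weight_decay / 2) * weight**2 with respect to weight."""
    return grad + weight_decay * weight


def dropout(activations, p=0.5, rng=None, training=True):
    """Inverted dropout: zero each value with probability p, rescale the rest."""
    if not training or p == 0.0:
        return list(activations)  # identity at inference time
    rng = rng or random.Random()
    keep = 1.0 - p
    # Dividing by `keep` preserves the expected value of each activation.
    return [a / keep if rng.random() < keep else 0.0 for a in activations]


rng = random.Random(0)  # fixed seed so the example is repeatable
acts = [1.0] * 10_000
dropped = dropout(acts, p=0.5, rng=rng)
mean_after = sum(dropped) / len(dropped)

print(f"L2 gradient: {l2_regularized_grad(0.5, 2.0, weight_decay=0.1):.2f}")
print(f"Mean activation after dropout: {mean_after:.3f}")
```

Because the surviving activations are rescaled during training, no extra scaling is needed when dropout is switched off for evaluation.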

## Quantization for Inference

Reduce model size for deployment using quantization:

In [None]:
import matplotlib.pyplot as plt

quantization_comparison = {
    "FP32 (baseline)": {
        "model_size": 100,  # Baseline
        "inference_speed": 100,
        "accuracy": 100,
    },
    "FP16": {
        "model_size": 50,
        "inference_speed": 200,
        "accuracy": 99.9,
    },
    "INT8": {
        "model_size": 25,
        "inference_speed": 400,
        "accuracy": 99.0,
    },
    "INT4": {
        "model_size": 12.5,  # 4 bits / 32 bits
        "inference_speed": 800,
        "accuracy": 98.0,
    },
}

fig, axes = plt.subplots(1, 3, figsize=(14, 4))

keys = list(quantization_comparison.keys())
sizes = [quantization_comparison[k]["model_size"] for k in keys]
speeds = [quantization_comparison[k]["inference_speed"] for k in keys]
accs = [quantization_comparison[k]["accuracy"] for k in keys]

axes[0].bar(keys, sizes, color='steelblue', alpha=0.7)
axes[0].set_ylabel('Model Size (relative)')
axes[0].set_title('Model Size by Quantization')
axes[0].tick_params(axis='x', rotation=45)

axes[1].bar(keys, speeds, color='orange', alpha=0.7)
axes[1].set_ylabel('Inference Speed (relative)')
axes[1].set_title('Inference Speed by Quantization')
axes[1].tick_params(axis='x', rotation=45)

axes[2].plot(keys, accs, 'go-', linewidth=2, markersize=8)
axes[2].set_ylabel('Accuracy (%)')
axes[2].set_title('Accuracy vs Quantization')
axes[2].tick_params(axis='x', rotation=45)
axes[2].set_ylim([97, 100.5])
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nQuantization Strategy:")
print("  1. Train at full precision (FP32)")
print("  2. Evaluate accuracy at different bit widths")
print("  3. Use INT8 for roughly 4x smaller models and faster inference at about 1% accuracy loss")
print("  4. Use INT4 for extreme compression (edge devices)")
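
As a concrete illustration of step 3, the sketch below applies symmetric per-tensor INT8 quantization to a handful of made-up weights and measures the round-trip error. This is one common scheme among several (asymmetric and per-channel variants also exist), and the function names are hypothetical.

```python
def quantize_int8(values):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map the integers back to approximate float values."""
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.0, 1.0]  # made-up values for illustration
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(w - r) for w, r in zip(weights, restored))

print("quantized:", q)
print(f"scale: {scale:.4f}, max round-trip error: {max_error:.6f}")
```

Each value is stored in one byte plus a single shared scale, which is where the 4x size reduction relative to FP32 comes from; the rounding error per value is at most half the scale.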

## Performance Benchmarking

In [None]:
# Illustrative profile with hard-coded numbers (not measured)
print("Performance Profiling Example:\n")

operations = [
    {"name": "Conv2D 3×3", "time_ms": 15.2, "memory_mb": 32.5},
    {"name": "Linear (256→120)", "time_ms": 3.1, "memory_mb": 2.1},
    {"name": "BatchNorm", "time_ms": 2.8, "memory_mb": 1.5},
    {"name": "ReLU", "time_ms": 1.2, "memory_mb": 0.8},
    {"name": "MaxPool", "time_ms": 0.9, "memory_mb": 0.5},
]

total_time = sum(op["time_ms"] for op in operations)
total_memory = sum(op["memory_mb"] for op in operations)

print(f"{'Operation':<20} {'Time (ms)':<12} {'Memory (MB)':<15} {'% Time':<10}")
print("-" * 60)
for op in operations:
    pct = (op["time_ms"] / total_time) * 100
    print(f"{op['name']:<20} {op['time_ms']:<12.1f} {op['memory_mb']:<15.1f} {pct:.1f}%")
print("-" * 60)
print(f"{'Total':<20} {total_time:<12.1f} {total_memory:<15.1f} {'100.0%':<10}")

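conv = operations[0]  # Conv2D is the first (and slowest) entry
conv_time_pct = conv["time_ms"] / total_time * 100
conv_mem_pct = conv["memory_mb"] / total_memory * 100
print("\nConclusions:")
print(f"  - Conv layers are the bottleneck ({conv_time_pct:.0f}% of time)")
print("  - Consider depthwise-separable convolutions to reduce their cost")
print("  - Activations and pooling are cheap (about 5% of time or less each)")
print(f"  - Conv2D also accounts for most of the memory ({conv_mem_pct:.0f}%)")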
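
The table above uses fixed numbers; to collect real timings, a small helper around `time.perf_counter` is usually enough to get started. The sketch below is one possible shape for such a helper: `benchmark` and `workload` are hypothetical names, and the pure-Python workload merely stands in for a model operation.

```python
import time


def benchmark(fn, repeats=20, warmup=3):
    """Time fn() over several runs and return summary statistics in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up runs are not timed (caches, lazy initialization)
    timings_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return {
        "min_ms": min(timings_ms),
        "mean_ms": sum(timings_ms) / len(timings_ms),
        "max_ms": max(timings_ms),
    }


def workload():
    """Stand-in for an operation under test."""
    return sum(i * i for i in range(50_000))


stats = benchmark(workload)
print(f"min {stats['min_ms']:.3f} ms, mean {stats['mean_ms']:.3f} ms, "
      f"max {stats['max_ms']:.3f} ms")
```

Reporting the minimum alongside the mean helps separate the cost of the operation itself from noise caused by other processes.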

## Takeaways

1. **Mixed Precision**: FP16 or BF16 can give a 2-3x speedup on supporting hardware with minimal accuracy loss
2. **Learning Rate Scheduling**: Cosine annealing works well for most tasks
3. **Optimizer Choice**: SGD+Momentum for vision, Adam or AdamW for transformers
4. **Regularization**: Dropout and L2 are effective and complementary
5. **Quantization**: INT8 gives a 4x smaller model and up to roughly 4x faster inference, at about 1% accuracy loss in the example above
6. **Profiling**: Find bottlenecks before optimizing

These techniques are production-ready in ML Odyssey!