# Lab 4: Memory-Efficient Transformer Training Techniques

This notebook compares several modern memory-optimization techniques used during Transformer training:
1. **Baseline** with TF32 + FP32 full precision
2. **BF16 Automatic Mixed Precision** 
3. **FlashAttention** (FlashAttention 2)
4. **Windowed (Local) Attention**
5. **Gradient Checkpointing**

For each technique, we measure:
- GPU memory usage
- Maximum batch size that fits into memory
- Training speed (time per step and total time for 1 epoch)
- Final model performance (perplexity after 1 epoch)

## 1. Setup and Imports

In [10]:
import os
import gc
import json
import torch
import pandas as pd
import matplotlib.pyplot as plt

# Release GPU memory left over from a previous run of this notebook.
# The steps run strictly in this order: wait for pending kernels, drop cached
# blocks, collect Python garbage (which may free more tensors), drop the cache
# again, and finally reset the peak-memory counters.
if torch.cuda.is_available():
    for cleanup_step in (
        torch.cuda.synchronize,
        torch.cuda.empty_cache,
        gc.collect,
        torch.cuda.empty_cache,
        torch.cuda.reset_peak_memory_stats,
    ):
        cleanup_step()

# Import from modules
from modules import (
    DEVICE,
    TransformerConfig,
    TrainingConfig,
    TransformerLanguageModel,
    build_tokenizer,
    run_experiment,
    count_parameters,
    clear_cuda_memory,
)
from modules.transformer import FLASH_ATTN_AVAILABLE

# Report the software and hardware environment the notebook is running in
cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    bytes_per_gb = 1e9  # sizes are reported in decimal gigabytes
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / bytes_per_gb
    print(f"GPU Memory: {gpu_mem:.2f} GB")
    allocated_gb = torch.cuda.memory_allocated(0) / bytes_per_gb
    print(f"Currently allocated: {allocated_gb:.2f} GB")
    reserved_gb = torch.cuda.memory_reserved(0) / bytes_per_gb
    print(f"Currently reserved: {reserved_gb:.2f} GB")
    # Smoke test: allocate and free a one-element tensor on the GPU
    try:
        test_tensor = torch.zeros(1, device="cuda")
        del test_tensor
        torch.cuda.empty_cache()
        print("CUDA test: OK")
    except Exception as e:
        print(f"CUDA test failed: {e}")
print(f"\nUsing device: {DEVICE}")
print(f"FlashAttention available: {FLASH_ATTN_AVAILABLE}")

PyTorch version: 2.8.0+cu128
CUDA available: True
CUDA version: 12.8
GPU: NVIDIA GeForce RTX 4090
GPU Memory: 25.26 GB
Currently allocated: 4.37 GB
Currently reserved: 7.71 GB
CUDA test: OK

Using device: cuda
FlashAttention available: True


## 2. Initialize Tokenizer and Results Storage

In [2]:
# Build the GPT-2 tokenizer that every experiment below receives
tokenizer = build_tokenizer("gpt2")
print(
    f"Tokenizer vocab size: {tokenizer.vocab_size}",
    f"Pad token ID: {tokenizer.pad_token_id}",
    sep="\n",
)

# One result dict per experiment is collected here by the cells that follow
all_results = []

Tokenizer vocab size: 50257
Pad token ID: 50256


## 3. Experiment 0: Baseline (TF32/FP32 Full Precision)

This is the baseline experiment with full precision training using TF32 for matrix multiplications.

In [3]:
# Experiment 0: Baseline with TF32/FP32
# All optimization flags are off; the analysis cell treats all_results[0]
# (this run) as the reference every other experiment is compared against.
baseline_model_config = TransformerConfig(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    n_heads=8,
    n_layers=4,
    ff_dim=1024,
    dropout=0.1,
    max_seq_len=512,
    pad_token_id=tokenizer.pad_token_id,
    use_flash_attention=False,
    use_windowed_attention=False,
    gradient_checkpointing=False,
)

baseline_train_config = TrainingConfig(
    batch_size=16,  # Starting value; the batch-size search (find_max_bs=True) picks the one actually used
    max_length=256,
    steps_per_epoch=200,
    num_epochs=1,
    lr=3e-4,
    warmup_steps=200,
    grad_clip=1.0,
    use_bf16=False,
    use_flash_attention=False,
    use_windowed_attention=False,
    gradient_checkpointing=False,
)

baseline_results = run_experiment(
    "Baseline (TF32/FP32)",
    baseline_model_config,
    baseline_train_config,
    tokenizer,
    find_max_bs=True,
)

# Record the result idempotently: re-running this cell overwrites the earlier
# entry with the same name (keeping its position) instead of appending a
# duplicate row that would skew the summary table and plots.
_existing_idx = next(
    (i for i, r in enumerate(all_results) if r.get("name") == baseline_results["name"]),
    None,
)
if _existing_idx is None:
    all_results.append(baseline_results)
else:
    all_results[_existing_idx] = baseline_results


Experiment: Baseline (TF32/FP32)

Finding maximum batch size...
  Batch size 2: OK (Peak: 0.56 GB)
  Batch size 4: OK (Peak: 1.06 GB)
  Batch size 8: OK (Peak: 2.03 GB)
  Batch size 16: OK (Peak: 3.98 GB)
  Batch size 32: OK (Peak: 7.87 GB)
  Batch size 64: OK (Peak: 15.67 GB)
  Batch size 128: OOM
Maximum batch size: 64

Model parameters: 16,021,248
Batch size: 64
BF16: False
Flash Attention: False
Windowed Attention: False
Gradient Checkpointing: False

Training...


Training:   0%|          | 0/200 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]


Evaluating perplexity on validation set...


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]


Results for Baseline (TF32/FP32):
  name: Baseline (TF32/FP32)
  max_batch_size: 64
  avg_loss: 8.6448
  perplexity: 5680.2949
  avg_step_time: 0.0795
  total_time: 16.0645
  steps: 200
  forward_memory_gb: 3.4410
  backward_memory_gb: 3.4410
  peak_memory_gb: 15.7940
  val_perplexity: 1083.6308
  val_loss: 6.9881


## 4. Experiment 1: BF16 Automatic Mixed Precision

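Use BF16 autocast for mixed precision training: `torch.autocast(device_type="cuda", dtype=torch.bfloat16)` is the current spelling of the deprecated `torch.cuda.amp.autocast(dtype=torch.bfloat16)`.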

In [7]:
# Experiment 1: BF16 Mixed Precision
# Same model as the baseline; the only change is use_bf16=True in the training config.
bf16_model_config = TransformerConfig(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    n_heads=8,
    n_layers=4,
    ff_dim=1024,
    dropout=0.1,
    max_seq_len=512,
    pad_token_id=tokenizer.pad_token_id,
    use_flash_attention=False,
    use_windowed_attention=False,
    gradient_checkpointing=False,
)

bf16_train_config = TrainingConfig(
    batch_size=16,  # Starting value; the batch-size search (find_max_bs=True) picks the one actually used
    max_length=256,
    steps_per_epoch=200,
    num_epochs=1,
    lr=3e-4,
    warmup_steps=200,
    grad_clip=1.0,
    use_bf16=True,  # Enable BF16
    use_flash_attention=False,
    use_windowed_attention=False,
    gradient_checkpointing=False,
)

bf16_results = run_experiment(
    "BF16 Mixed Precision",
    bf16_model_config,
    bf16_train_config,
    tokenizer,
    find_max_bs=True,
)

# Record the result idempotently: re-running this cell overwrites the earlier
# entry with the same name (keeping its position) instead of appending a
# duplicate row that would skew the summary table and plots.
_existing_idx = next(
    (i for i, r in enumerate(all_results) if r.get("name") == bf16_results["name"]),
    None,
)
if _existing_idx is None:
    all_results.append(bf16_results)
else:
    all_results[_existing_idx] = bf16_results


Experiment: BF16 Mixed Precision

Finding maximum batch size...
  Batch size 2: OK (Peak: 4.83 GB)
  Batch size 4: OK (Peak: 5.19 GB)
  Batch size 8: OK (Peak: 5.91 GB)
  Batch size 16: OK (Peak: 7.35 GB)
  Batch size 32: OK (Peak: 10.24 GB)
  Batch size 64: OK (Peak: 16.01 GB)
  Batch size 128: OOM
Maximum batch size: 64

Model parameters: 16,021,248
Batch size: 64
BF16: True
Flash Attention: False
Windowed Attention: False
Gradient Checkpointing: False

Training...


Training:   0%|          | 0/200 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]


Evaluating perplexity on validation set...


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

  with autocast(dtype=torch.bfloat16):



Results for BF16 Mixed Precision:
  name: BF16 Mixed Precision
  max_batch_size: 64
  avg_loss: 8.6508
  perplexity: 5714.9407
  avg_step_time: 0.0606
  total_time: 12.1390
  steps: 200
  forward_memory_gb: 6.1460
  backward_memory_gb: 6.1460
  peak_memory_gb: 16.1360
  val_perplexity: 1082.1886
  val_loss: 6.9867


## 5. Experiment 2: FlashAttention

Replace the default attention mechanism with FlashAttention 2 for memory-efficient attention computation.

In [9]:
# Experiment 2: FlashAttention (with BF16 - required for FlashAttention)
# Skipped with an install hint when the flash-attn package could not be imported.
if FLASH_ATTN_AVAILABLE:
    flash_model_config = TransformerConfig(
        vocab_size=tokenizer.vocab_size,
        emb_dim=256,
        n_heads=8,
        n_layers=4,
        ff_dim=1024,
        dropout=0.1,
        max_seq_len=512,
        pad_token_id=tokenizer.pad_token_id,
        use_flash_attention=True,  # Enable FlashAttention
        use_windowed_attention=False,
        gradient_checkpointing=False,
    )

    flash_train_config = TrainingConfig(
        batch_size=16,  # Starting value; the batch-size search (find_max_bs=True) picks the one actually used
        max_length=256,
        steps_per_epoch=200,
        num_epochs=1,
        lr=3e-4,
        warmup_steps=200,
        grad_clip=1.0,
        use_bf16=True,  # FlashAttention requires BF16/FP16
        use_flash_attention=True,
        use_windowed_attention=False,
        gradient_checkpointing=False,
    )

    flash_results = run_experiment(
        "FlashAttention + BF16",
        flash_model_config,
        flash_train_config,
        tokenizer,
        find_max_bs=True,
    )

    # Record the result idempotently: re-running this cell overwrites the
    # earlier entry with the same name (keeping its position) instead of
    # appending a duplicate row that would skew the summary table and plots.
    _existing_idx = next(
        (i for i, r in enumerate(all_results) if r.get("name") == flash_results["name"]),
        None,
    )
    if _existing_idx is None:
        all_results.append(flash_results)
    else:
        all_results[_existing_idx] = flash_results
else:
    print("FlashAttention not available. Skipping this experiment.")
    print("To install FlashAttention, run: pip install flash-attn --no-build-isolation")


Experiment: FlashAttention + BF16

Finding maximum batch size...
  Batch size 2: OK (Peak: 4.80 GB)
  Batch size 4: OK (Peak: 5.13 GB)
  Batch size 8: OK (Peak: 5.79 GB)
  Batch size 16: OK (Peak: 7.12 GB)
  Batch size 32: OK (Peak: 9.77 GB)
  Batch size 64: OK (Peak: 15.07 GB)
  Batch size 128: OOM
Maximum batch size: 64

Model parameters: 16,021,248
Batch size: 64
BF16: True
Flash Attention: True
Windowed Attention: False
Gradient Checkpointing: False

Training...


Training:   0%|          | 0/200 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]


Evaluating perplexity on validation set...


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]


Results for FlashAttention + BF16:
  name: FlashAttention + BF16
  max_batch_size: 64
  avg_loss: 8.6575
  perplexity: 5753.1768
  avg_step_time: 0.0436
  total_time: 8.8873
  steps: 200
  forward_memory_gb: 6.1460
  backward_memory_gb: 6.1460
  peak_memory_gb: 15.1980
  val_perplexity: 1083.4000
  val_loss: 6.9879


## 6. Experiment 3: Windowed (Local) Attention

Replace full self-attention with sliding-window attention to reduce memory complexity from O(n²) to O(n × window_size).

In [None]:
# Experiment 3: Windowed Attention (using FlashAttention's sliding window if available)
# When FlashAttention is unavailable, both the FlashAttention flag and BF16 are
# switched off together, and the experiment name drops the " + FlashAttn" suffix.
window_model_config = TransformerConfig(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    n_heads=8,
    n_layers=4,
    ff_dim=1024,
    dropout=0.1,
    max_seq_len=512,
    pad_token_id=tokenizer.pad_token_id,
    use_flash_attention=FLASH_ATTN_AVAILABLE,  # Use FlashAttention for windowed if available
    use_windowed_attention=True,  # Enable windowed attention
    window_size=128,  # Sliding window size
    gradient_checkpointing=False,
)

window_train_config = TrainingConfig(
    batch_size=16,  # Starting value; the batch-size search (find_max_bs=True) picks the one actually used
    max_length=256,
    steps_per_epoch=200,
    num_epochs=1,
    lr=3e-4,
    warmup_steps=200,
    grad_clip=1.0,
    use_bf16=FLASH_ATTN_AVAILABLE,  # BF16 required for FlashAttention
    use_flash_attention=FLASH_ATTN_AVAILABLE,
    use_windowed_attention=True,
    window_size=128,
    gradient_checkpointing=False,
)

window_results = run_experiment(
    "Windowed Attention (w=128)" + (" + FlashAttn" if FLASH_ATTN_AVAILABLE else ""),
    window_model_config,
    window_train_config,
    tokenizer,
    find_max_bs=True,
)

# Record the result idempotently: re-running this cell overwrites the earlier
# entry with the same name (keeping its position) instead of appending a
# duplicate row that would skew the summary table and plots.
_existing_idx = next(
    (i for i, r in enumerate(all_results) if r.get("name") == window_results["name"]),
    None,
)
if _existing_idx is None:
    all_results.append(window_results)
else:
    all_results[_existing_idx] = window_results


Experiment: Windowed Attention (w=128) + FlashAttn

Finding maximum batch size...
  Batch size 2: OK (Peak: 4.80 GB)
  Batch size 4: OK (Peak: 5.13 GB)
  Batch size 8: OK (Peak: 5.79 GB)
  Batch size 16: OK (Peak: 7.12 GB)
  Batch size 32: OK (Peak: 9.77 GB)
  Batch size 64: OK (Peak: 15.07 GB)
  Batch size 128: OOM
Maximum batch size: 64

Model parameters: 16,021,248
Batch size: 64
BF16: True
Flash Attention: True
Windowed Attention: True
Gradient Checkpointing: False

Training...


Training:   0%|          | 0/200 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]


Evaluating perplexity on validation set...


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

## 7. Experiment 4: Gradient Checkpointing

Enable gradient checkpointing to trade compute for memory by recomputing activations during backward pass.

In [None]:
# Experiment 4: Gradient Checkpointing (with BF16)
# Standard attention with gradient_checkpointing=True in both configs.
gc_model_config = TransformerConfig(
    vocab_size=tokenizer.vocab_size,
    emb_dim=256,
    n_heads=8,
    n_layers=4,
    ff_dim=1024,
    dropout=0.1,
    max_seq_len=512,
    pad_token_id=tokenizer.pad_token_id,
    use_flash_attention=False,
    use_windowed_attention=False,
    gradient_checkpointing=True,  # Enable gradient checkpointing
)

gc_train_config = TrainingConfig(
    batch_size=16,  # Starting value; the batch-size search (find_max_bs=True) picks the one actually used
    max_length=256,
    steps_per_epoch=200,
    num_epochs=1,
    lr=3e-4,
    warmup_steps=200,
    grad_clip=1.0,
    use_bf16=True,  # Also use BF16 for fair comparison
    use_flash_attention=False,
    use_windowed_attention=False,
    gradient_checkpointing=True,
)

gc_results = run_experiment(
    "Gradient Checkpointing + BF16",
    gc_model_config,
    gc_train_config,
    tokenizer,
    find_max_bs=True,
)

# Record the result idempotently: re-running this cell overwrites the earlier
# entry with the same name (keeping its position) instead of appending a
# duplicate row that would skew the summary table and plots.
_existing_idx = next(
    (i for i, r in enumerate(all_results) if r.get("name") == gc_results["name"]),
    None,
)
if _existing_idx is None:
    all_results.append(gc_results)
else:
    all_results[_existing_idx] = gc_results

## 8. Results Summary and Comparison

In [None]:
# Build a readable summary table from the collected experiment results,
# print it, and persist the raw results as JSON.
results_df = pd.DataFrame(all_results)

# Raw result key -> column header shown in the summary (insertion order = column order)
column_labels = {
    'name': 'Technique',
    'max_batch_size': 'Max Batch Size',
    'peak_memory_gb': 'Peak Memory (GB)',
    'avg_step_time': 'Step Time (s)',
    'total_time': 'Total Time (s)',
    'val_perplexity': 'Validation Perplexity',
    'avg_loss': 'Training Loss',
}

# Keep only the keys that are actually present, then rename by mapping.
# Renaming through the mapping (rather than assigning a fixed-length list of
# headers) stays correct when a column is missing or no results exist yet;
# the positional assignment would raise a length-mismatch ValueError there.
display_cols = [c for c in column_labels if c in results_df.columns]
summary_df = results_df[display_cols].rename(columns=column_labels)

print("\n" + "="*80)
print("RESULTS SUMMARY")
print("="*80)
print(summary_df.to_string(index=False))
print("="*80)

# Save results to JSON
results_path = "outputs/summary.json"
os.makedirs("outputs", exist_ok=True)
with open(results_path, "w") as f:
    json.dump(all_results, f, indent=2)
print(f"\nResults saved to {results_path}")

## 9. Visualization

In [None]:
# Draw a 2x2 grid of bar charts (one metric per panel, one bar per experiment)
# and save the figure to outputs/comparison_plots.png.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

names = [r['name'] for r in all_results]
bar_positions = range(len(names))
bar_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd'][:len(names)]

# One entry per panel, in row-major order:
# (result key, y-axis label, title, vertical offset of the value label, value formatter)
panel_specs = [
    ('peak_memory_gb', 'Peak Memory (GB)', 'Peak GPU Memory Usage', 0.1, '{:.2f}'.format),
    ('max_batch_size', 'Max Batch Size', 'Maximum Batch Size Fitting in Memory', 1, str),
    ('avg_step_time', 'Step Time (s)', 'Average Step Time', 0.01, '{:.3f}'.format),
    ('val_perplexity', 'Perplexity', 'Validation Perplexity After 1 Epoch', 0.5, '{:.1f}'.format),
]

for ax, (key, ylabel, title, label_offset, fmt) in zip(axes.flat, panel_specs):
    # Missing metrics are plotted as 0 so every experiment still gets a bar
    values = [r.get(key, 0) for r in all_results]
    bars = ax.bar(bar_positions, values, color=bar_colors)
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(names, rotation=45, ha='right')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    # Annotate each bar with its value just above the bar top
    for bar, val in zip(bars, values):
        ax.text(
            bar.get_x() + bar.get_width()/2,
            bar.get_height() + label_offset,
            fmt(val),
            ha='center',
            va='bottom',
        )

plt.tight_layout()
# Make sure the target directory exists so this cell also works when the
# summary cell (which creates it) has not been run.
os.makedirs('outputs', exist_ok=True)
plt.savefig('outputs/comparison_plots.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nPlots saved to outputs/comparison_plots.png")

## 10. Analysis and Conclusions

### Memory Efficiency
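- **BF16 Mixed Precision** reduces activation memory by running most operations, and storing their outputs, in 16-bit format instead of 32-bit; under autocast the weights, gradients, and optimizer state stay in FP32, so the savings scale with batch size and sequence length rather than with model size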
- **FlashAttention** avoids materializing the full attention matrix, reducing attention memory from O(n²) to O(n) in the sequence length n
- **Windowed Attention** limits attention to local context, further reducing memory requirements
- **Gradient Checkpointing** trades compute for memory by not storing intermediate activations

### Speed Trade-offs
- **BF16** is generally faster due to reduced memory bandwidth and Tensor Core utilization
- **FlashAttention** is optimized for modern GPUs and often faster than standard attention
- **Windowed Attention** can be faster for long sequences but may hurt model quality
- **Gradient Checkpointing** increases compute time (~20-30%) due to recomputation

### Perplexity Impact
- **BF16 and FlashAttention** should have minimal impact on final perplexity
- **Windowed Attention** may degrade perplexity as it limits the model's ability to attend to distant tokens
- **Gradient Checkpointing** should have no impact on perplexity (mathematically equivalent)

In [None]:
# Print detailed analysis: the baseline's metrics, then every other experiment's
# metrics with their signed change relative to the baseline.

def _fmt_metric(value, spec):
    """Format `value` with format spec `spec`; fall back to str() for placeholders.

    A missing metric is represented by the string 'N/A', which cannot take a
    float format spec such as '.4f' (that raises ValueError), so non-numeric
    values are returned as plain text instead.
    """
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


def _pct_vs_baseline(value, base):
    """Return the signed percentage change of `value` relative to `base`.

    Positive means larger than the baseline, negative means smaller. The same
    convention is used for every metric so the printed signs are comparable.
    """
    return (value / base - 1) * 100


print("\n" + "="*80)
print("DETAILED ANALYSIS")
print("="*80)

if len(all_results) > 0:
    # The first recorded experiment is treated as the reference run
    baseline = all_results[0]

    print(f"\nBaseline ({baseline['name']}):")
    print(f"  Peak Memory: {baseline.get('peak_memory_gb', 'N/A')} GB")
    print(f"  Max Batch Size: {baseline.get('max_batch_size', 'N/A')}")
    print(f"  Step Time: {_fmt_metric(baseline.get('avg_step_time', 'N/A'), '.4f')}s")
    print(f"  Val Perplexity: {_fmt_metric(baseline.get('val_perplexity', 'N/A'), '.2f')}")

    print("\nComparison to Baseline:")
    for result in all_results[1:]:
        name = result['name']
        print(f"\n{name}:")

        # Each metric is reported only when both runs have a non-zero value for it

        # Peak memory: negative percentage = less memory than the baseline
        if baseline.get('peak_memory_gb') and result.get('peak_memory_gb'):
            mem_change = _pct_vs_baseline(result['peak_memory_gb'], baseline['peak_memory_gb'])
            print(f"  Memory: {result['peak_memory_gb']:.2f} GB ({mem_change:+.1f}%)")

        # Max batch size: positive percentage = larger batch fits
        if baseline.get('max_batch_size') and result.get('max_batch_size'):
            bs_increase = _pct_vs_baseline(result['max_batch_size'], baseline['max_batch_size'])
            print(f"  Max Batch Size: {result['max_batch_size']} ({bs_increase:+.1f}%)")

        # Step time: negative percentage = faster than the baseline
        if baseline.get('avg_step_time') and result.get('avg_step_time'):
            time_change = _pct_vs_baseline(result['avg_step_time'], baseline['avg_step_time'])
            print(f"  Step Time: {result['avg_step_time']:.4f}s ({time_change:+.1f}%)")

        # Validation perplexity: absolute difference (negative = better than the baseline)
        if baseline.get('val_perplexity') and result.get('val_perplexity'):
            ppl_change = result['val_perplexity'] - baseline['val_perplexity']
            print(f"  Val Perplexity: {result['val_perplexity']:.2f} ({ppl_change:+.2f})")

print("\n" + "="*80)