# Memory Optimization with torch.compile

This notebook explores memory optimization techniques when using torch.compile, including memory profiling, gradient checkpointing, and memory-efficient compilation modes.

In [None]:
import torch
import torch.nn as nn
import torch.profiler
from torch.utils.checkpoint import checkpoint
import matplotlib.pyplot as plt
import numpy as np
from triton.testing import do_bench
import gc

## Memory Profiling Utilities

In [None]:
def get_memory_usage():
    """Return the GPU memory currently allocated, in GB (0 when CUDA is absent)."""
    if not torch.cuda.is_available():
        return 0
    bytes_per_gb = 1024**3
    return torch.cuda.memory_allocated() / bytes_per_gb

def memory_profile(func, *args, **kwargs):
    """Run ``func`` and report the peak extra GPU memory it needed.

    Args:
        func: Callable to execute.
        *args: Positional arguments forwarded to ``func``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        A ``(result, peak_delta_gb)`` tuple, where ``result`` is whatever
        ``func`` returned and ``peak_delta_gb`` is the peak allocated GPU
        memory reached during the call minus the amount allocated just
        before it, in GB. Without CUDA the delta is 0.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return func(*args, **kwargs), 0

    torch.cuda.empty_cache()
    # Let pending kernels finish so the baseline is not taken mid-flight.
    torch.cuda.synchronize()
    # Restart peak tracking; otherwise the peak would include earlier work.
    torch.cuda.reset_peak_memory_stats()
    baseline_bytes = torch.cuda.memory_allocated()

    result = func(*args, **kwargs)

    torch.cuda.synchronize()
    # Read the high-water mark rather than the current allocation: temporary
    # buffers freed before ``func`` returns would otherwise be missed.
    peak_bytes = torch.cuda.max_memory_allocated()

    return result, (peak_bytes - baseline_bytes) / 1024**3

def compare_memory_usage(eager_fn, compiled_fn, *args, **kwargs):
    """Compare memory usage between eager and compiled modes.

    Both callables are run through ``memory_profile`` with the same
    arguments and the measurements are printed.

    Args:
        eager_fn: Callable representing the eager implementation.
        compiled_fn: Callable representing the compiled implementation.
        *args: Positional arguments forwarded to both callables.
        **kwargs: Keyword arguments forwarded to both callables.

    Returns:
        A ``(eager_memory, compiled_memory)`` tuple in GB.
    """
    _, eager_memory = memory_profile(eager_fn, *args, **kwargs)
    _, compiled_memory = memory_profile(compiled_fn, *args, **kwargs)

    print(f"Eager mode memory: {eager_memory:.3f} GB")
    print(f"Compiled mode memory: {compiled_memory:.3f} GB")
    # A relative reduction is undefined when the eager baseline is zero
    # (e.g. on a CPU-only machine); report that instead of dividing by zero.
    if eager_memory > 0:
        reduction = (eager_memory - compiled_memory) / eager_memory * 100
        print(f"Memory reduction: {reduction:.1f}%")
    else:
        print("Memory reduction: n/a (no eager memory usage measured)")

    return eager_memory, compiled_memory

## Memory-Efficient Model Architectures

In [None]:
class MemoryEfficientModel(nn.Module):
    """Stack of Linear -> ReLU -> Dropout blocks followed by a 10-way linear head.

    Gradient checkpointing of the blocks is off by default and can be
    switched on with ``enable_gradient_checkpointing``; it is only applied
    while the module is in training mode.
    """

    def __init__(self, input_dim=1024, hidden_dim=2048, num_layers=8):
        super().__init__()
        blocks = []
        in_features = input_dim  # only the first block sees the raw input width
        for _ in range(num_layers):
            blocks.append(
                nn.Sequential(
                    nn.Linear(in_features, hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(0.1),
                )
            )
            in_features = hidden_dim
        self.layers = nn.ModuleList(blocks)
        self.output = nn.Linear(hidden_dim, 10)
        self.use_checkpointing = False

    def enable_gradient_checkpointing(self):
        """Turn on checkpointing of each block for subsequent training passes."""
        self.use_checkpointing = True

    def forward(self, x):
        """Run ``x`` through every block in order and return the head's output."""
        recompute = self.use_checkpointing and self.training
        for block in self.layers:
            x = checkpoint(block, x, use_reentrant=False) if recompute else block(x)
        return self.output(x)

# Create models for comparison.
# Every configuration gets its own module instance so that parameters and
# gradients left behind by one configuration cannot skew the measurement
# of another.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_eager = MemoryEfficientModel().to(device)
model_compiled = torch.compile(MemoryEfficientModel().to(device), mode="reduce-overhead")
model_checkpointed = MemoryEfficientModel().to(device)
model_checkpointed.enable_gradient_checkpointing()
# Separate instance: compiling `model_checkpointed` itself would only wrap
# the same module, making the two checkpointed configurations share state.
_checkpointed_for_compile = MemoryEfficientModel().to(device)
_checkpointed_for_compile.enable_gradient_checkpointing()
model_checkpointed_compiled = torch.compile(_checkpointed_for_compile, mode="reduce-overhead")

## Memory Usage Comparison

In [None]:
def train_step(model, x, target, optimizer):
    """Perform one optimisation step and return the loss as a Python float.

    Gradients are cleared, the cross-entropy between ``model(x)`` and
    ``target`` is back-propagated, and ``optimizer`` updates the parameters.
    """
    functional = torch.nn.functional

    optimizer.zero_grad()
    batch_loss = functional.cross_entropy(model(x), target)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

# Synthetic classification batch shared by every configuration
batch_size = 32
x = torch.randn(batch_size, 1024).to(device)
target = torch.randint(0, 10, (batch_size,)).to(device)

# (label, model) pairs, in the order they are measured and reported
configs = [
    ("Eager", model_eager),
    ("Compiled", model_compiled),
    ("Checkpointed", model_checkpointed),
    ("Checkpointed + Compiled", model_checkpointed_compiled),
]

# Measure a single training step per configuration, each with a new optimizer
memory_results = {}
for name, model in configs:
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    def train_fn():
        return train_step(model, x, target, optimizer)

    memory_used = memory_profile(train_fn)[1]
    memory_results[name] = memory_used
    print(f"{name}: {memory_used:.3f} GB")

## Detailed Memory Profiling with PyTorch Profiler

In [None]:
def profile_memory_timeline(model, x, target):
    """Profile three training steps of ``model`` and print a memory summary.

    A new Adam optimizer is created for ``model``; the steps run under
    ``torch.profiler.profile`` with memory, shape and stack recording
    enabled. The ten rows using the most memory are printed.

    Args:
        model: Module to train.
        x: Input batch passed to ``train_step``.
        target: Labels passed to ``train_step``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Only request CUDA tracing when a GPU is present, and sort by the
    # memory column that will actually be populated.
    activities = [torch.profiler.ProfilerActivity.CPU]
    sort_key = "cpu_memory_usage"
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
        sort_key = "cuda_memory_usage"

    with torch.profiler.profile(
        activities=activities,
        profile_memory=True,
        record_shapes=True,
        with_stack=True
    ) as prof:
        for _ in range(3):
            train_step(model, x, target, optimizer)

    # Print memory summary
    print(prof.key_averages().table(
        sort_by=sort_key,
        row_limit=10
    ))

# Profile the eager model first, then the compiled one, each under its heading
for heading, profiled_model in (
    ("Memory profiling - Eager mode:", model_eager),
    ("\nMemory profiling - Compiled mode:", model_compiled),
):
    print(heading)
    profile_memory_timeline(profiled_model, x, target)

## Memory-Efficient Compilation Modes

In [None]:
# Test different compilation modes for memory efficiency.
# Each mode compiles its own copy of the model (initialised with the weights
# of `base_model`): wrapping one shared module three times would let the
# parameters and gradients left by one mode distort the next measurement.
base_model = MemoryEfficientModel().to(device)
compilation_modes = {}
for mode_name in ("default", "reduce-overhead", "max-autotune"):
    mode_model = MemoryEfficientModel().to(device)
    mode_model.load_state_dict(base_model.state_dict())
    compilation_modes[mode_name] = torch.compile(mode_model, mode=mode_name)

mode_memory_results = {}
for mode_name, compiled_model in compilation_modes.items():
    compiled_model.train()
    optimizer = torch.optim.Adam(compiled_model.parameters(), lr=0.001)

    def train_fn():
        return train_step(compiled_model, x, target, optimizer)

    _, memory_used = memory_profile(train_fn)
    mode_memory_results[mode_name] = memory_used
    print(f"Mode '{mode_name}': {memory_used:.3f} GB")

## Visualization of Memory Usage

In [None]:
# Side-by-side bar charts: per-configuration and per-compilation-mode memory
def _plot_memory_bars(ax, labels, values, colors, title, tick_rotation=None):
    """Draw a bar chart of `values` on `ax`, label each bar, and return the bars."""
    bars = ax.bar(labels, values, color=colors)
    ax.set_ylabel('Memory Usage (GB)')
    ax.set_title(title)
    if tick_rotation is not None:
        ax.tick_params(axis='x', rotation=tick_rotation)

    # Write each value just above its bar, centred horizontally
    for bar, value in zip(bars, values):
        x_center = bar.get_x() + bar.get_width()/2
        ax.text(x_center, bar.get_height() + 0.01,
                f'{value:.3f}', ha='center', va='bottom')
    return bars


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Left panel: memory usage by configuration
configs = list(memory_results.keys())
memory_usage = list(memory_results.values())
bars1 = _plot_memory_bars(
    ax1, configs, memory_usage,
    ['red', 'blue', 'green', 'purple'],
    'Memory Usage by Configuration',
    tick_rotation=45,
)

# Right panel: memory usage by compilation mode
modes = list(mode_memory_results.keys())
mode_memory = list(mode_memory_results.values())
bars2 = _plot_memory_bars(
    ax2, modes, mode_memory,
    ['orange', 'cyan', 'magenta'],
    'Memory Usage by Compilation Mode',
)

plt.tight_layout()
plt.show()

## Memory Optimization Best Practices

### Key Takeaways:

1. **Gradient Checkpointing**: Trades computation for memory by not storing intermediate activations
2. **Compilation Modes**: Different modes have varying memory footprints
3. **Memory Profiling**: Essential for identifying memory bottlenecks
4. **Combined Approaches**: Checkpointing + compilation can provide best memory efficiency

### Recommendations:

- Use `reduce-overhead` mode to cut launch overhead, but measure its footprint first in memory-constrained environments: its CUDA graphs can increase memory usage
- Enable gradient checkpointing for very deep models
- Profile regularly to monitor memory usage patterns
- Consider mixed precision training for further memory savings