# Memory Profiling

This notebook profiles memory usage of different attention implementations.

## Learning Objectives
- Understand memory consumption of attention mechanisms
- Use the ATO profiling tools
- Compare memory scaling of standard vs linear attention

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

# Check for CUDA: the memory counters used below only exist on a CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Total device memory in bytes. Defined (as 0) on CPU too, so later cells can
# test it without risking a NameError.
total_memory = 0
if device == 'cuda':
    # Query one device index for both calls so name and capacity always match.
    gpu_index = torch.cuda.current_device()
    print(f"GPU: {torch.cuda.get_device_name(gpu_index)}")
    total_memory = torch.cuda.get_device_properties(gpu_index).total_memory
    print(f"Total Memory: {total_memory / 1e9:.1f} GB")

## Memory Tracking Utilities

In [None]:
class MemoryTracker:
    """Thin wrapper over the ``torch.cuda`` memory counters.

    Every reading is reported in megabytes (1 MB = 1e6 bytes). When the
    module-level ``device`` is not ``'cuda'``, every reading is 0 and
    ``reset`` is a no-op.
    """

    def __init__(self):
        self.reset()

    @staticmethod
    def _using_cuda():
        """Return True when the notebook-wide ``device`` is CUDA."""
        return device == 'cuda'

    def reset(self):
        """Clear the peak-allocation counter, then release cached blocks."""
        if not self._using_cuda():
            return
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    def current_allocated(self):
        """Memory currently occupied by tensors, in MB."""
        return torch.cuda.memory_allocated() / 1e6 if self._using_cuda() else 0

    def peak_allocated(self):
        """Highest tensor allocation since the peak counter was last reset, in MB."""
        return torch.cuda.max_memory_allocated() / 1e6 if self._using_cuda() else 0

    def reserved(self):
        """Memory managed by the caching allocator, in MB."""
        return torch.cuda.memory_reserved() / 1e6 if self._using_cuda() else 0

# Shared tracker instance: measure_memory() resets and reads it on every call.
# Constructing it also calls reset(), clearing CUDA peak stats when on GPU.
tracker = MemoryTracker()

In [None]:
def measure_memory(fn, *args, **kwargs):
    """
    Measure the extra memory a function allocates while it runs.

    The tracker's peak counter is reset and the currently allocated memory is
    recorded as a baseline *before* ``fn`` runs. Tensors that already exist
    (for example the inputs passed in ``args``) are therefore excluded from
    the result.

    Args:
        fn: Callable to profile.
        *args, **kwargs: Forwarded unchanged to ``fn``.

    Returns:
        (delta_mb, result): ``delta_mb`` is the peak allocated memory observed
        during the call minus the baseline, in MB. It is always 0 when not
        running on CUDA, because the tracker reports 0 there. ``result`` is
        whatever ``fn`` returned.
    """
    tracker.reset()

    # Memory already in use (inputs, earlier results) is subtracted out below.
    baseline = tracker.current_allocated()

    result = fn(*args, **kwargs)

    # Make sure all queued GPU work has finished before reading the counters.
    if device == 'cuda':
        torch.cuda.synchronize()

    peak = tracker.peak_allocated()

    return peak - baseline, result

## Attention Implementations

In [None]:
def standard_attention(q, k, v):
    """Scaled dot-product attention that materializes the full score matrix.

    Computes ``softmax(q @ k^T / sqrt(d)) @ v`` over the last two dimensions,
    where ``d`` is the size of ``q``'s last dimension. The raw scores and the
    softmax weights are both full ``(..., n_q, n_k)`` tensors. That is what
    makes this variant quadratic in sequence length.
    """
    scale = q.size(-1) ** 0.5
    logits = q @ k.transpose(-2, -1) / scale
    return F.softmax(logits, dim=-1) @ v


def linear_attention_bidir(q, k, v, eps=1e-6):
    """Non-causal linear attention using the ``elu(x) + 1`` feature map.

    All keys and values are first reduced to a per-(batch, head) state of
    shape ``(d_k, d_v)``, so the size of that state does not depend on the
    sequence length. The einsum subscripts index inputs as ``(b, h, n, d)``.
    ``eps`` guards the normalizing division.
    """
    phi_q = F.elu(q) + 1
    phi_k = F.elu(k) + 1

    # Summarize every key/value pair once: (B, H, D_k, D_v) and (B, H, D_k).
    state = torch.einsum('bhnd,bhnv->bhdv', phi_k, v)
    key_total = phi_k.sum(dim=2)

    numerator = torch.einsum('bhnd,bhdv->bhnv', phi_q, state)
    denominator = torch.einsum('bhnd,bhd->bhn', phi_q, key_total)
    return numerator / (denominator.unsqueeze(-1) + eps)


def linear_attention_causal(q, k, v, eps=1e-6):
    """Causal linear attention computed with prefix sums (naive version).

    Position ``i`` only sees positions ``<= i``. The per-position outer
    products ``phi(k_j) v_j^T`` are prefix-summed along the sequence axis,
    which materializes a ``(B, H, N, D_k, D_v)`` tensor. That O(N * D^2)
    footprint is what this notebook measures. ``eps`` guards the division.
    """
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1

    outer = torch.einsum('bhnd,bhnv->bhndv', phi_k, v)
    running_state = torch.cumsum(outer, dim=2)
    running_keys = torch.cumsum(phi_k, dim=2)

    numerator = torch.einsum('bhnd,bhndv->bhnv', phi_q, running_state)
    denominator = torch.einsum('bhnd,bhnd->bhn', phi_q, running_keys)
    return numerator / (denominator.unsqueeze(-1) + eps)

## Theoretical Memory Analysis

Before measuring, let's compute theoretical memory requirements:

### Standard Attention
- Attention matrix: `B × H × N × N × sizeof(dtype)`
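- At N=8192, B=4, H=8, fp16: 4 × 8 × 8192² × 2 bytes = **4.3 GB** just for the attention scores! The softmax weights are a second matrix of the same size, and both are alive at once in the implementation above, so it briefly holds ~8.6 GB.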

### Linear Attention (Bidirectional)
- KV state: `B × H × D × D × sizeof(dtype)`
- At D=64, B=4, H=8, fp16: 4 × 8 × 64 × 64 × 2 bytes = 262,144 bytes ≈ **0.26 MB**, regardless of N

### Linear Attention (Causal)
- Cumulative KV: `B × H × N × D × D × sizeof(dtype)`
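- At N=8192, D=64, B=4, H=8, fp16: 4 × 8 × 8192 × 64 × 64 × 2 bytes ≈ **2.1 GB** per copy. The naive implementation keeps two copies alive (`kv` and its cumsum), ≈ 4.3 GB. Because this grows as N·D² rather than N², it matches or exceeds standard attention whenever D² ≥ N (with D = 64, that is every N ≤ 4096).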

In [None]:
# Bytes per element for the dtype names / torch dtypes accepted by
# theoretical_memory(). Unlisted values fall back to 4 bytes.
_DTYPE_BYTES = {
    'fp16': 2, 'float16': 2, 'half': 2,
    'bf16': 2, 'bfloat16': 2,
    'fp32': 4, 'float32': 4, 'float': 4,
    'fp64': 8, 'float64': 8, 'double': 8,
    torch.float16: 2, torch.bfloat16: 2,
    torch.float32: 4, torch.float64: 8,
}


def theoretical_memory(batch_size, num_heads, seq_len, head_dim, dtype='fp16'):
    """
    Compute theoretical memory requirements.

    Only the dominant tensors are counted: the Q/K/V inputs plus each
    variant's characteristic intermediate. That is the N x N score matrix,
    the D x D key-value state, or the N x D x D cumulative state. Other
    temporaries (e.g. the softmax output or the feature-mapped copies of
    q/k) are not included, so measured peaks can be higher.

    Args:
        batch_size, num_heads, seq_len, head_dim: Dimensions B, H, N, D.
        dtype: Element type, given either as a name ('fp16', 'bf16', 'fp32',
            'fp64' and common aliases) or as a torch floating dtype.
            Unrecognized values are sized at 4 bytes per element.

    Returns:
        dict with memory in MB (1 MB = 1e6 bytes) under the keys 'inputs',
        'standard_peak', 'linear_bidir_peak' and 'linear_causal_peak'.
    """
    bytes_per_elem = _DTYPE_BYTES.get(dtype, 4)
    B, H, N, D = batch_size, num_heads, seq_len, head_dim

    # Input memory (Q, K, V)
    input_mem = 3 * B * H * N * D * bytes_per_elem / 1e6

    # Standard attention: attention matrix
    standard_attn_matrix = B * H * N * N * bytes_per_elem / 1e6

    # Linear bidir: KV state
    linear_bidir_state = B * H * D * D * bytes_per_elem / 1e6

    # Linear causal: cumulative KV (N copies of state)
    linear_causal_state = B * H * N * D * D * bytes_per_elem / 1e6

    return {
        'inputs': input_mem,
        'standard_peak': input_mem + standard_attn_matrix,
        'linear_bidir_peak': input_mem + linear_bidir_state,
        'linear_causal_peak': input_mem + linear_causal_state,
    }

# Example calculation
theory = theoretical_memory(4, 8, 8192, 64)
print("Theoretical memory at seq_len=8192:")
for key, val in theory.items():
    print(f"  {key}: {val:.1f} MB")

## Empirical Memory Measurements

In [None]:
# Parameters: tensor dimensions shared by every measurement below.
batch_size, num_heads, head_dim = 4, 8, 64

# Sequence lengths to sweep: powers of two from 512 up to 4096.
seq_lengths = [2 ** p for p in range(9, 13)]

# The 8192-token run needs several GB for the N x N matrices, so it is only
# attempted on a GPU with more than 20 GB of memory.
if device == 'cuda' and total_memory > 20e9:
    seq_lengths.append(8192)

print(f"Testing sequence lengths: {seq_lengths}")

In [None]:
# Run memory measurements.
# NOTE: measure_memory() reports memory allocated *on top of* what already
# exists when it is called, so the measured *_memory_mb columns exclude the
# q/k/v inputs. The *_theoretical_mb columns (from theoretical_memory)
# include them.

# (label printed in the log, DataFrame column prefix, implementation)
implementations = [
    ('Standard', 'standard', standard_attention),
    ('Linear (bidir)', 'linear_bidir', linear_attention_bidir),
    ('Linear (causal)', 'linear_causal', linear_attention_causal),
]


def _is_oom(err):
    """Return True if *err* is an out-of-memory failure rather than a real bug."""
    oom_type = getattr(torch.cuda, 'OutOfMemoryError', None)
    if oom_type is not None and isinstance(err, oom_type):
        return True
    message = str(err).lower()
    return 'out of memory' in message or "can't allocate memory" in message


results = []

for seq_len in seq_lengths:
    print(f"\nMeasuring seq_len={seq_len}...")

    # Create fp16 inputs of shape (B, H, N, D)
    shape = (batch_size, num_heads, seq_len, head_dim)
    q = torch.randn(shape, device=device, dtype=torch.float16)
    k = torch.randn(shape, device=device, dtype=torch.float16)
    v = torch.randn(shape, device=device, dtype=torch.float16)

    input_mem = 3 * q.numel() * q.element_size() / 1e6  # 3 tensors, in MB

    row = {
        'seq_len': seq_len,
        'input_memory_mb': input_mem,
    }

    for label, prefix, attn_fn in implementations:
        column = f'{prefix}_memory_mb'
        try:
            mem, out = measure_memory(attn_fn, q, k, v)
        except RuntimeError as err:
            # Only an OOM is an expected outcome here. Re-raise anything else
            # so genuine bugs (shape/dtype errors, ...) are not logged as NaN.
            if not _is_oom(err):
                raise
            row[column] = float('nan')
            print(f"  {label}: OOM")
            if device == 'cuda':
                # Release cached blocks left behind by the failed call.
                torch.cuda.empty_cache()
        else:
            del out  # free the output before the next measurement
            row[column] = mem
            print(f"  {label}: {mem:.1f} MB")

    # Theoretical values (inputs + dominant intermediate tensor)
    theory = theoretical_memory(batch_size, num_heads, seq_len, head_dim)
    for _label, prefix, _fn in implementations:
        row[f'{prefix}_theoretical_mb'] = theory[f'{prefix}_peak']

    results.append(row)

    # Cleanup
    del q, k, v
    if device == 'cuda':
        torch.cuda.empty_cache()

df = pd.DataFrame(results)
print("\n" + "="*80)
print(df.to_string(index=False))

## Visualize Memory Scaling

In [None]:
from pathlib import Path

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: measured extra memory for every implementation.
# NaN entries (OOM runs) leave gaps in the lines.
ax = axes[0]
empirical_series = [
    ('standard_memory_mb', 'o-', 'Standard'),
    ('linear_bidir_memory_mb', 's-', 'Linear (bidir)'),
    ('linear_causal_memory_mb', '^-', 'Linear (causal)'),
]
for column, fmt, label in empirical_series:
    ax.plot(df['seq_len'], df[column], fmt, label=label, linewidth=2)
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Peak Memory (MB)')
ax.set_title('Empirical Memory Usage')
ax.legend()
ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

# Right panel: theoretical vs empirical. The measured values exclude the
# q/k/v inputs (they are part of measure_memory's baseline), while the
# theoretical values include them.
ax = axes[1]
valid = ~df['standard_memory_mb'].isna()
ax.scatter(df.loc[valid, 'seq_len'], df.loc[valid, 'standard_memory_mb'],
           s=100, label='Standard (measured)')
ax.plot(df['seq_len'], df['standard_theoretical_mb'], '--', label='Standard (theoretical)')
ax.scatter(df['seq_len'], df['linear_bidir_memory_mb'],
           s=100, marker='s', label='Linear bidir (measured)')
ax.plot(df['seq_len'], df['linear_bidir_theoretical_mb'], '--', label='Linear bidir (theoretical)')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Peak Memory (MB)')
ax.set_title('Theoretical vs Empirical Memory')
ax.legend()
ax.set_xscale('log', base=2)
ax.set_yscale('log')
ax.grid(True, alpha=0.3)

plt.tight_layout()
# savefig does not create missing directories, so create the target first.
plot_dir = Path('../results/plots')
plot_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(plot_dir / 'memory_scaling.png', dpi=150, bbox_inches='tight')
plt.show()

## Memory Savings Analysis

In [None]:
from pathlib import Path

# Compute memory savings: a ratio > 1 means linear (bidir) used less memory.
df['savings_bidir'] = df['standard_memory_mb'] / df['linear_bidir_memory_mb']

# Keep only finite ratios. NaN comes from an OOM run; +/-inf comes from a
# zero denominator (e.g. every measurement is 0 when running on CPU).
finite = np.isfinite(df['savings_bidir'])
plotted = df[finite]
positions = range(len(plotted))

fig, ax = plt.subplots(figsize=(10, 5))

ax.bar(positions, plotted['savings_bidir'], color='steelblue')
ax.set_xlabel('Sequence Length')
ax.set_ylabel('Memory Savings (Standard / Linear)')
ax.set_title('Linear Attention Memory Savings (Bidirectional)')
ax.set_xticks(positions)
ax.set_xticklabels(plotted['seq_len'].astype(int))

# Label every bar with its ratio.
for i, ratio in enumerate(plotted['savings_bidir']):
    ax.annotate(f'{ratio:.1f}x',
                xy=(i, ratio),
                xytext=(0, 5), textcoords="offset points",
                ha='center', fontsize=11, fontweight='bold')

plt.tight_layout()
# savefig does not create missing directories, so create the target first.
plot_dir = Path('../results/plots')
plot_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(plot_dir / 'memory_savings.png', dpi=150, bbox_inches='tight')
plt.show()

## Key Observations

1. **Standard attention memory scales O(n²)** - The attention matrix dominates
2. **Linear bidirectional scales O(n)** - Apart from O(n·d) tensors such as the feature-mapped q/k and the output, only the O(d²) KV state is stored
3. **Linear causal can be worse!** - The naive cumsum implementation stores O(n × d²)

### The Causal Linear Attention Dilemma

The naive causal implementation materializes the entire cumulative state:
```python
kv_cumsum = torch.cumsum(kv, dim=2)  # Shape: (B, H, N, D, D)
```

This is O(N × D²) which can exceed standard attention's O(N²) when D² > N!

**Solutions:**
1. Chunked processing (only keep current state)
2. Recurrent implementation (sequential but memory-efficient)
3. Parallel scan with state checkpointing

In [None]:
from pathlib import Path

# Save results. to_csv does not create missing directories, so make the
# target folder first.
csv_path = Path('../results/benchmarks/memory_profiling.csv')
csv_path.parent.mkdir(parents=True, exist_ok=True)
df.to_csv(csv_path, index=False)
print("Results saved to ../results/benchmarks/memory_profiling.csv")

## Next Steps

- [04_benchmark_comparison.ipynb](04_benchmark_comparison.ipynb): Full benchmark with optimized implementations