# Problem 7: High Memory Footprint

Demonstrates how attention and KV-cache memory grow with sequence length, and how gradient checkpointing trades extra compute for lower activation memory.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/transformer_problems/blob/main/07_memory_footprint/demo.ipynb)


In [None]:
!pip install torch -q
import torch
import torch.nn as nn

# Calculate memory usage for attention
def attention_memory_mb(seq_len, d_model, n_heads, n_layers, batch_size=1, dtype_bytes=2):
    """Estimate attention-score plus KV-cache memory, in MiB.

    The estimate is the sum of two terms:

    * one attention score matrix of shape (batch, heads, seq, seq). Only a
      single layer's scores are counted; this term is not multiplied by
      ``n_layers``.
    * the KV cache for all layers: keys and values, each
      batch x layers x seq x d_model elements.

    Args:
        seq_len: Sequence length in tokens.
        d_model: Hidden size. Keys/values are assumed to span the full
            width (standard multi-head attention, no grouped/multi-query KV).
        n_heads: Number of attention heads.
        n_layers: Number of transformer layers.
        batch_size: Number of sequences processed together.
        dtype_bytes: Bytes per element (2 for fp16/bf16, 4 for fp32).

    Returns:
        Estimated memory as a float, in MiB (bytes / 1024**2).

    Raises:
        ValueError: If any size argument is negative.
    """
    sizes = {
        "seq_len": seq_len,
        "d_model": d_model,
        "n_heads": n_heads,
        "n_layers": n_layers,
        "batch_size": batch_size,
        "dtype_bytes": dtype_bytes,
    }
    for name, value in sizes.items():
        # A negative size would silently yield a meaningless negative estimate.
        if value < 0:
            raise ValueError(f"{name} must be non-negative, got {value}")

    # Attention scores for one layer: batch x heads x seq x seq (quadratic in seq)
    attn_mem = batch_size * n_heads * seq_len * seq_len * dtype_bytes
    # KV cache: (K and V) x batch x layers x seq x d_model (linear in seq)
    kv_cache = 2 * batch_size * n_layers * seq_len * d_model * dtype_bytes
    return (attn_mem + kv_cache) / (1024 * 1024)

# Sweep context lengths (512 doubling up to 16384) for the 7B-class config
# used here (d_model=4096, 32 heads, 32 layers) and tabulate the estimate.
print(
    "Memory Usage vs Sequence Length (7B model config)",
    "Seq Length | Attention+KV (MB)",
    "-" * 35,
    sep="\n",
)

for seq_len in [512 * 2 ** k for k in range(6)]:
    mem = attention_memory_mb(seq_len, d_model=4096, n_heads=32, n_layers=32)
    print(f"{seq_len:>10} | {mem:>15.1f}")

print("\n⚠️ Memory grows quadratically with sequence length!")


In [None]:
# Solution: Gradient Checkpointing
class CheckpointedBlock(nn.Module):
    """Two-layer feed-forward block: d_model -> 4*d_model -> ReLU -> d_model.

    By default ``forward`` is a plain MLP, and autograd keeps the hidden
    activations for backward. Callers can still wrap the module in
    ``torch.utils.checkpoint.checkpoint`` themselves.

    With ``use_checkpoint=True``, the block checkpoints itself whenever
    autograd is recording. Only the block input is saved, and the hidden
    activations are recomputed during the backward pass.

    Args:
        d_model: Input and output feature size (last dimension of ``x``).
        use_checkpoint: Whether ``forward`` applies gradient checkpointing
            internally. Defaults to False, which preserves plain-MLP behavior.
    """

    def __init__(self, d_model, use_checkpoint=False):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_model * 4)
        self.linear2 = nn.Linear(d_model * 4, d_model)
        self.use_checkpoint = use_checkpoint

    def _mlp(self, x):
        """Apply linear1 -> ReLU -> linear2 over the last dimension of x."""
        return self.linear2(torch.relu(self.linear1(x)))

    def forward(self, x):
        # Checkpointing only pays off when a backward pass will follow;
        # under no_grad there are no saved activations to drop.
        if self.use_checkpoint and torch.is_grad_enabled():
            # Local import keeps the class usable without a module-level import.
            from torch.utils.checkpoint import checkpoint
            return checkpoint(self._mlp, x, use_reentrant=False)
        return self._mlp(x)

# Without checkpointing: autograd keeps the block's intermediate activations
# (for CheckpointedBlock: the linear1 output and the ReLU output) until backward.
# With checkpointing: only the block input is kept; the intermediates are
# recomputed when backward reaches the block.

model = CheckpointedBlock(512)
x = torch.randn(1, 100, 512, requires_grad=True)

# Normal forward + backward
out_normal = model(x)
out_normal.sum().backward()
grad_normal = x.grad.clone()
x.grad = None  # reset so the checkpointed run accumulates into a clean gradient

# With gradient checkpointing. The recomputation happens inside backward(),
# so a backward pass is required to actually exercise it.
from torch.utils.checkpoint import checkpoint
out_ckpt = checkpoint(model, x, use_reentrant=False)
out_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

print("Gradient Checkpointing:")
print("  - Normal: Store all intermediate activations")
print("  - Checkpointed: Recompute activations during backward")
print(
    f"  - Same outputs: {torch.allclose(out_normal, out_ckpt)}, "
    f"same input gradients: {torch.allclose(grad_normal, grad_ckpt)}"
)
# Rough rule of thumb; real savings depend on how many layers are checkpointed.
print("  - Trade-off: ~30% more compute, ~50% less memory")
print("\n✓ Essential for training large models!")
