<a href="https://colab.research.google.com/github/kiankyars/Ultra-Scale-Playbook-Series/blob/main/2_activation_memory_recomputation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Ultra-Scale Playbook: Part 2

In this notebook, we'll explore:

- Memory costs per parameter in FP32 and mixed precision (BF16)
- What activation memory is and why it becomes the dominant cost at scale
- Gradient checkpointing (activation recomputation) to reduce memory usage
- Measuring memory and recomputation tradeoffs in a toy transformer

This notebook accompanies Video 2 in the series.

In [None]:
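# Memory footprint breakdown (bytes per parameter)
# Full FP32 training with Adam: weight + gradient + two optimizer states
fp32_param = 4      # FP32 weight
fp32_grad = 4       # FP32 gradient
momentum_bytes = 4  # Adam first moment
variance_bytes = 4  # Adam second moment
total_fp32_bytes = fp32_param + fp32_grad + momentum_bytes + variance_bytes

print(f"Total per-parameter memory (FP32): {total_fp32_bytes} bytes")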

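# Mixed precision (BF16 compute, FP32 master weights and optimizer states)
bf16_param = 2     # BF16 working copy of the weight
bf16_grad = 2      # BF16 gradient
master_weight = 4  # FP32 master copy of the weight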

total_mp_bytes = bf16_param + bf16_grad + master_weight + momentum_bytes + variance_bytes
print(f"Total per-parameter memory (Mixed Precision): {total_mp_bytes} bytes")
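
To get a feel for what 16 bytes per parameter means in practice, the sketch below scales that figure up to whole models. The helper name `training_memory_gb` and the example model sizes are illustrative assumptions, and the estimate covers only weights, gradients, and Adam states, not activations.

```python
GB = 1e9  # decimal gigabytes, matching the 1e9 divisor used in this notebook


def training_memory_gb(num_params, bytes_per_param=16):
    """Weights + gradients + Adam states, in GB (activations excluded)."""
    return num_params * bytes_per_param / GB


# Hypothetical model sizes, chosen only for illustration
for name, num_params in [("1B", 1_000_000_000), ("7B", 7_000_000_000), ("70B", 70_000_000_000)]:
    print(f"{name} parameters: {training_memory_gb(num_params):.0f} GB")
```

Even before any activations are stored, a model of several billion parameters already exceeds the memory of a single typical accelerator under this estimate.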

### What are activations?

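Activations are the intermediate outputs computed during a model's forward pass. They are cached so that the backward pass can reuse them when computing gradients via the chain rule.

Attention that materializes the full score matrix adds a term quadratic in sequence length, but to a first approximation activation memory grows **linearly** with: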

- Batch size
- Sequence length
- Hidden dimension
- Number of layers
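
A commonly cited estimate for a standard transformer trained in 16-bit precision, without recomputation or fused attention kernels, is roughly `layers * seq * batch * hidden * (34 + 5 * n_heads * seq / hidden)` bytes. The sketch below implements it so the scaling factors can be checked numerically. The function name `activation_memory_bytes` and the head count of 32 are assumptions for illustration, and the constants 34 and 5 depend on the exact layer design.

```python
def activation_memory_bytes(layers, seq_len, batch_size, hidden_dim, n_heads):
    """Rough activation footprint for a 16-bit transformer, no recomputation."""
    # Linear in every factor: MLP, layer norm, and projection activations
    linear_term = 34 * seq_len * batch_size * hidden_dim
    # Quadratic in sequence length: attention scores, softmax, dropout mask
    attention_term = 5 * n_heads * seq_len**2 * batch_size
    return layers * (linear_term + attention_term)


# Illustrative configuration (n_heads=32 is an assumption, not from the text)
for seq_len in [512, 1024, 2048, 4096, 8192]:
    gb = activation_memory_bytes(24, seq_len, 1, 4096, 32) / 1e9
    print(f"seq_len={seq_len}: {gb:.1f} GB")
```

Doubling the batch size, hidden dimension, or layer count doubles the linear term, while doubling the sequence length quadruples the attention term.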

In [None]:
import matplotlib.pyplot as plt

# Simulate activation memory growth
context_lengths = [512, 1024, 2048, 4096, 8192]
layers = 24
hidden_dim = 4096
batch_size = 1

# Toy per-token estimate for one layer: 2 bytes (BF16) per value, times
# 2 hidden_dim-sized tensors cached per token (an illustrative assumption)
bytes_per_token = 2 * hidden_dim * 2

activation_memory = [cl * batch_size * bytes_per_token * layers / 1e9 for cl in context_lengths]

plt.plot(context_lengths, activation_memory, marker='o')
plt.title("Activation Memory vs. Context Length")
plt.xlabel("Context Length (tokens)")
plt.ylabel("Activation Memory (GB)")
plt.grid(True)
plt.show()

## Gradient Checkpointing Demo

Gradient checkpointing trades compute for memory: instead of caching every activation from the forward pass, we keep only a few and recompute the rest during the backward pass.

In [None]:
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4096, 4096)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4096, 4096)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

block = ToyBlock()
x = torch.randn(8, 4096, requires_grad=True)

# Regular forward
out_regular = block(x)
loss_regular = out_regular.mean()
loss_regular.backward()
print("Done regular forward/backward")

# Clear gradients
block.zero_grad()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

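# With gradient checkpointing: activations inside `block` are not stored;
# they are recomputed from `x_cp` during the backward pass
x_cp = x.detach().requires_grad_()
out_cp = checkpoint(block, x_cp, use_reentrant=False)
loss_cp = out_cp.mean()
loss_cp.backward()
print("Done checkpointing forward/backward")
print("Input gradients match:", torch.allclose(x.grad, x_cp.grad))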
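
The demo above runs both paths but does not measure anything. The sketch below is one way to compare them on a small stack of blocks: it checks that the gradients agree and, when a CUDA device is available, reports peak allocated memory for each mode. `ToyStack`, `run_step`, and the chosen sizes are hypothetical names and values introduced here for illustration.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyStack(nn.Module):
    """Hypothetical stack of MLP blocks; checkpointing is applied per block."""

    def __init__(self, dim=256, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
             for _ in range(depth)]
        )
        self.use_checkpoint = False

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                # Keep only the block input; recompute the rest in backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def run_step(model, x):
    """One forward/backward pass; returns input grad and peak CUDA bytes (None on CPU)."""
    model.zero_grad()
    if x.is_cuda:
        torch.cuda.reset_peak_memory_stats()
    x = x.detach().requires_grad_()
    model(x).mean().backward()
    peak = torch.cuda.max_memory_allocated() if x.is_cuda else None
    return x.grad, peak


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyStack().to(device)
x = torch.randn(8, 256, device=device)

model.use_checkpoint = False
grad_regular, peak_regular = run_step(model, x)
model.use_checkpoint = True
grad_cp, peak_cp = run_step(model, x)

print("Gradients match:", torch.allclose(grad_regular, grad_cp))
if peak_regular is not None:
    print(f"Peak memory: regular {peak_regular / 1e6:.1f} MB, "
          f"checkpointed {peak_cp / 1e6:.1f} MB")
```

With tensors this small, parameters and gradients may dominate the peak, so the gap between the two modes can be modest. Increasing the batch size or depth should make the activation savings easier to see.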

### Quick Quiz

1. Why does activation memory grow linearly with sequence length?
2. What trade-off does gradient checkpointing make?
3. In what case is full recomputation useful?
4. Which component benefits most from recomputation in Transformers?

### Answers

1. Each token produces intermediate outputs in every layer, and all of them must be stored for the backward pass.
2. It reduces memory usage at the cost of extra computation.
3. When GPU memory is very limited but training time is less critical.
4. Attention layers: their activations are large but cheap to recompute.
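
To make answer 4 concrete, here is a minimal sketch of selective recomputation, assuming a plain self-attention layer: only the attention core (scores, softmax, weighted sum) is wrapped in `checkpoint`, while the projections are stored as usual. `attention_core`, `SelectiveAttention`, and the dimensions are hypothetical and not taken from any particular library.

```python
import math

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def attention_core(q, k, v):
    """Scaled dot-product attention; materializes the (seq x seq) score matrix."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


class SelectiveAttention(nn.Module):
    """Hypothetical self-attention layer that recomputes only the attention core."""

    def __init__(self, dim=64, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.recompute = True

    def forward(self, x):
        b, s, d = x.shape
        head_dim = d // self.n_heads
        # Split into heads: q, k, v each have shape (batch, heads, seq, head_dim)
        q, k, v = (
            t.reshape(b, s, self.n_heads, head_dim).transpose(1, 2)
            for t in self.qkv(x).chunk(3, dim=-1)
        )
        if self.recompute:
            # Scores and softmax outputs are not stored after the forward pass;
            # they are rebuilt from q, k, v during backward
            ctx = checkpoint(attention_core, q, k, v, use_reentrant=False)
        else:
            ctx = attention_core(q, k, v)
        return self.out(ctx.transpose(1, 2).reshape(b, s, d))


torch.manual_seed(0)
layer = SelectiveAttention()
x = torch.randn(2, 16, 64, requires_grad=True)

layer.recompute = False
out_plain = layer(x)
layer.recompute = True
out_selective = layer(x)
out_selective.mean().backward()

print("Outputs match:", torch.allclose(out_plain, out_selective))
```

The design choice is to pay for recomputing a few matrix multiplications and a softmax in exchange for not storing the sequence-by-sequence tensors, which are the part of the footprint that grows fastest with context length.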