# How much peak memory does running AdamW require?

Recall the end-to-end transformer forward pass:

```
Input:
  x: (B, T, d_model)

For each transformer block:
  x_norm = RMSNorm(x)

  # Attention
  QKV = x_norm @ W_QKV
  Q,K,V → reshape to (B, H, T, d_head)
  attn = softmax(QKᵀ / √d_head + mask)
  out = attn @ V
  out = concat_heads(out) @ W_O
  x = x + out

  # MLP
  x_norm = RMSNorm(x)
  u = x_norm @ W_up
  v = x_norm @ W_gate
  out = (swish(v) ⊙ u) @ W_down
  x = x + out

Final:
  x = RMSNorm(x)
  logits = x @ W_vocab
```

AdamW per-param cost:
* 1st moment: 3 FLOPs (m = b1*m + (1-b1)*g => 2 mult, 1 add)
* 2nd moment: 3 FLOPs
* param update: 5 FLOPs
* param decay: 2 FLOPs

Overall roughly 13 FLOPs per param, i.e. on the order of 10.
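
To make the per-parameter work concrete, here is a minimal sketch of one AdamW step for a single scalar parameter. The function name `adamw_step` and the default hyperparameters are illustrative assumptions, and the ordering of the decay and update steps is one common convention rather than the only one.

```python
import math

def adamw_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter (hypothetical helper).

    p: parameter, g: gradient, m/v: first/second moment estimates, t: step count (>= 1).
    Returns the updated (p, m, v).
    """
    # First moment: exponential moving average of the gradient.
    m = b1 * m + (1 - b1) * g
    # Second moment: exponential moving average of the squared gradient.
    v = b2 * v + (1 - b2) * g * g
    # Bias-corrected step size; depends only on t, so it is shared by all parameters.
    lr_t = lr * math.sqrt(1 - b2 ** t) / (1 - b1 ** t)
    # Decoupled weight decay, applied directly to the parameter.
    p = p - lr * weight_decay * p
    # Adaptive parameter update.
    p = p - lr_t * m / (math.sqrt(v) + eps)
    return p, m, v

# Example: first step from zero-initialized moments.
p, m, v = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1)
```

Every line touches each parameter a constant number of times, which is why the cost is a small constant number of FLOPs per parameter and why `m` and `v` must be stored for every parameter.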

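Now looking into a single transformer block (assuming d_ff = 4*d_model):
1. RMSNorm (x2): 2 * d_model params
2. MHA: 4 * d_model^2 params (W_QKV is 3 * d_model^2, W_O is d_model^2)
3. SwiGLU: 3 * d_ff * d_model = 12 * d_model^2 params (W_up, W_gate, W_down)

And also after the transformer blocks:
1. final RMSNorm: d_model params
2. output embedding: d_model * d_vocab params
3. cross-entropy on logits: no params (the d_vocab-sized logits are activations)

So the total number of params is:

> num_layers * (2*d_model + 16*d_model^2) + d_model + d_model * d_vocab

This excludes the input token embedding, which the flow above takes as given; it adds another d_vocab * d_model params unless it is tied to the output embedding.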
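
As a sanity check, the count can be written as a small function that tallies every weight named in the flow above (W_QKV, W_O, W_up, W_gate, W_down, W_vocab) plus the RMSNorm gains. This is a sketch: `count_params` is a hypothetical helper, the input token embedding is deliberately left out because the flow starts from an already-embedded x, and the example shape below is an assumed GPT-2 XL-like configuration with d_ff = 4*d_model.

```python
def count_params(num_layers, d_model, d_ff, d_vocab):
    """Count trainable parameters for the flow above (input embedding excluded)."""
    rmsnorm = d_model                                       # one gain vector per RMSNorm
    attention = 3 * d_model * d_model + d_model * d_model   # W_QKV and W_O
    mlp = 3 * d_model * d_ff                                # W_up, W_gate, W_down
    block = 2 * rmsnorm + attention + mlp                   # two RMSNorms per block
    final = rmsnorm + d_model * d_vocab                     # final RMSNorm and W_vocab
    return num_layers * block + final

# Assumed GPT-2 XL-like shape: 48 layers, d_model=1600, d_ff=6400, vocab=50257.
n = count_params(num_layers=48, d_model=1600, d_ff=6400, d_vocab=50257)
print(n)  # 2046646400
```

The cross-entropy step contributes nothing here because it has no weights; its d_vocab-sized logits count toward activations instead.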

RAM cost:

AdamW stores two extra tensors per parameter (m and v), so optimizer state alone costs ~2x the model size. Including the parameters themselves and their gradients, total parameter-related memory is ~4x num_params values (16 bytes per parameter in float32). Peak training memory, however, is usually dominated by activations (i.e. intermediate tensors that must be saved for the backward pass, e.g. the x's), not by AdamW state.
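
The parameter-related part of that budget is simple arithmetic, sketched below. The helper name `param_state_bytes` and the float32 default (4 bytes per value) are assumptions for illustration; mixed-precision setups would change the byte counts.

```python
def param_state_bytes(num_params, bytes_per_value=4):
    """Bytes held by parameters, gradients, and AdamW's m and v (hypothetical helper)."""
    copies = 4  # parameters + gradients + first moment m + second moment v
    return copies * num_params * bytes_per_value

# Example: a 1-billion-parameter model in float32.
total = param_state_bytes(1_000_000_000)
print(total / 1e9, "GB")  # 16.0 GB
```

This term is independent of batch size, whereas activation memory grows with it.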

Note that although x is intuitively "a vector" and W "a matrix", which makes W feel more expensive, in practice there are batch_size * seq_len x vectors of size d_model each, versus a single d_model^2-parameter W matrix.
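
A quick element count illustrates the comparison. The batch size, sequence length, and d_model below are assumed values chosen only for illustration, and the helper names are hypothetical.

```python
def activation_elements(batch_size, seq_len, d_model):
    """Number of values in one saved activation tensor x of shape (B, T, d_model)."""
    return batch_size * seq_len * d_model

def weight_elements(d_model):
    """Number of values in one square d_model x d_model weight matrix W."""
    return d_model * d_model

# Assumed illustrative sizes: B=8, T=1024, d_model=1600.
x_size = activation_elements(8, 1024, 1600)   # 13,107,200 values
w_size = weight_elements(1600)                # 2,560,000 values
ratio = x_size / w_size                       # equals B * T / d_model
```

The ratio reduces to batch_size * seq_len / d_model, so a single saved x outweighs a square W whenever the number of tokens in the batch exceeds d_model.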

# Instantiate your answer for a GPT-2 XL-shaped model to get an expression that only depends on the batch_size. What is the maximum batch size you can use and still fit within 80GB memory?

