<a href="https://colab.research.google.com/github/jman4162/Sizing-AI-Training-by-Cost-per-Memory-Bandwidth/blob/main/Sizing_AI_Training_by_Cost_per_Memory_Bandwidth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sizing AI Training by **Cost per Memory Bandwidth**

*A practical model (with math + Python) to tell if you’re compute-, memory-, or network-bound—and what to buy next*

Author: John Hodge

Date: 09/03/2025

## Introduction

When you’re training big transformers, the question that actually determines throughput isn’t “How many TFLOPs do I have?” It’s “How many **bytes per second** can I push through HBM, and at what **cost**?”

This post gives you a compact, first-order model—both in math and in runnable Python—to:

* Diagnose whether your run is **compute**, **memory**, or **network** bound,
* Estimate tokens/sec per GPU and GPUs needed to hit a target,
* Compare hardware using **\$/TB/s·hr** (cost per memory bandwidth), which often tracks throughput/\$ better than TFLOPs/\$ for large LLM training.

You’ll get tunable knobs for things like activation checkpointing, FlashAttention, and 8-bit optimizers, so you can adapt the model to your exact stack.

---

## 1) The core idea

Large-scale transformer training frequently hits the **memory wall**: step time is limited by how fast parameters, activations, and optimizer states move to/from HBM, not by peak FLOPs. That’s why the cost metric that matters is:

$$
\textbf{Cost per memory bandwidth} \;\equiv\;
\frac{\$ / \text{GPU-hour}}{\text{HBM bandwidth (TB/s)}}
\quad\Rightarrow\quad \frac{\$}{\text{TB/s·hour}}
$$

Lower is better. Hardware with high HBM BW (and good interconnect) can deliver more tokens/sec per dollar when your workload is bandwidth bound.
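
As a quick illustration, the metric is a single division. The sketch below uses the placeholder prices and bandwidths from the example catalog in section 7; the helper name `cost_per_tbps_hour` is hypothetical and not part of any library.

```python
def cost_per_tbps_hour(price_per_gpu_hr: float, hbm_tbps: float) -> float:
    """Dollars per (TB/s of HBM bandwidth) per hour for one GPU."""
    if hbm_tbps <= 0:
        raise ValueError("hbm_tbps must be positive")
    return price_per_gpu_hr / hbm_tbps

# Placeholder (price in $/GPU-hr, HBM bandwidth in TB/s); replace with real quotes.
catalog = {
    "H100-SXM": (3.93, 3.35),
    "H200": (4.33, 4.80),
    "L4": (1.15, 0.30),
}

# Cheapest bandwidth first.
ranked = sorted(catalog, key=lambda name: cost_per_tbps_hour(*catalog[name]))
for name in ranked:
    print(f"{name}: ${cost_per_tbps_hour(*catalog[name]):.2f} per TB/s-hour")
```

The ranking only reflects the price and bandwidth you feed in, so it is worth re-running whenever quotes change.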

---

## 2) A simple per-token model

Let:

* $N$: trainable parameters
* $B_g$: **global tokens per step** (global batch × sequence length)
* $\kappa$: FLOPs/token coefficient (≈ **6** for a forward+backward training pass, i.e. roughly 2 forward + 4 backward FLOPs per parameter; ≈ **2** for forward only)
* $\gamma \ge 1$: recompute multiplier (activation checkpointing overhead; e.g., **1.2–1.4**)
* $\alpha_{\text{opt}}$: **optimizer traffic** in **bytes/param/step** (Adam bf16 often **16–20 B**; 8-bit Adam lower)
* $L$: layers, $d_{\text{model}}$: hidden size, $b$: **bytes/element** (2 for bf16/fp16)
* $c_{\text{act}}$: **activation traffic coefficient** (how many hidden-vectors’ worth of traffic per token per layer; lower with FlashAttention/fused kernels)

Work per token (FLOPs):
$$
F_{\text{tok}} \;=\; \gamma \cdot \kappa \cdot N
$$
HBM bytes per token:
$$
M_{\text{tok}} \;=\;
\underbrace{\frac{\alpha_{\text{opt}}\,N}{B_g}}_{\text{optimizer + grads (amortized)}}
\;+\;
\underbrace{c_{\text{act}}\;L\,d_{\text{model}}\,b}_{\text{activations}}
$$
Arithmetic intensity and “machine balance”:
$$
I \;=\; \frac{F_{\text{tok}}}{M_{\text{tok}}}\quad(\text{FLOPs/byte}),
\qquad
\text{Machine balance} \;=\; \frac{C_{\max}}{W_{\text{HBM}}}
$$
where $C_{\max}$ is the GPU’s usable FLOPs/s and $W_{\text{HBM}}$ is its HBM bytes/s.

If $I$ is below the machine balance, you’re **memory-bound**; otherwise you’re **compute-bound**.
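
To make the comparison concrete, here is a minimal sketch that evaluates $F_{\text{tok}}$, $M_{\text{tok}}$, and $I$ for the placeholder 70B configuration used in section 7, against a placeholder H100-class machine balance (900 TFLOPs/s over 3.35 TB/s). The function name `arithmetic_intensity` is an assumption made for this illustration.

```python
def arithmetic_intensity(n_params, layers, d_model, bytes_per_elem,
                         kappa, gamma, alpha_opt, c_act, global_tokens):
    """Return (F_tok, M_tok, I) for the per-token model above."""
    f_tok = gamma * kappa * n_params
    m_tok = (alpha_opt * n_params / global_tokens
             + c_act * layers * d_model * bytes_per_elem)
    return f_tok, m_tok, f_tok / m_tok

f_tok, m_tok, intensity = arithmetic_intensity(
    n_params=70e9, layers=80, d_model=8192, bytes_per_elem=2,
    kappa=6.0, gamma=1.2, alpha_opt=16.0, c_act=5.0, global_tokens=512_000,
)

# Placeholder hardware numbers: usable FLOPs/s divided by HBM bytes/s.
machine_balance = 900e12 / 3.35e12

regime = "memory-bound" if intensity < machine_balance else "compute-bound"
print(f"F_tok={f_tok:.3e} FLOPs, M_tok={m_tok:.3e} bytes, I={intensity:.0f} FLOPs/byte")
print(f"machine balance={machine_balance:.0f} FLOPs/byte -> {regime}")
```

With these placeholder inputs the intensity lands well above the machine balance, so the configuration reads as compute-bound. A smaller $B_g$ or a larger $c_{\text{act}}$ raises $M_{\text{tok}}$ and pushes it the other way.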

---

## 3) From per-token to tokens/sec per GPU

Given hardware per GPU:

* Compute: $C_{\max}$ FLOPs/s
* HBM BW: $W_{\text{HBM}}$ bytes/s
* Utilization $u \in [0.6, 0.9]$ to capture kernel overlap, scheduling, etc.

Then
$$
r_{\text{comp}} = \frac{C_{\max}}{F_{\text{tok}}},\qquad
r_{\text{mem}}  = \frac{W_{\text{HBM}}}{M_{\text{tok}}},\qquad
r_{\text{gpu}}  = u \cdot \min(r_{\text{comp}}, r_{\text{mem}})
$$
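
As a rough worked example, take the placeholder values from the reference implementation in section 7: $N = 70\times10^{9}$, $\gamma = 1.2$, $\kappa = 6$, $M_{\text{tok}} \approx 8.74\times10^{6}$ bytes, and an H100-class GPU with $C_{\max} = 900$ TFLOPs/s, $W_{\text{HBM}} = 3.35$ TB/s, $u = 0.75$. Under those assumptions the rates (in tokens/s) come out approximately as:

$$
r_{\text{comp}} = \frac{900\times10^{12}}{1.2 \cdot 6 \cdot 70\times10^{9}} \approx 1786,\qquad
r_{\text{mem}} = \frac{3.35\times10^{12}}{8.74\times10^{6}} \approx 3.83\times10^{5},\qquad
r_{\text{gpu}} = 0.75 \cdot \min\!\left(1786,\; 3.83\times10^{5}\right) \approx 1339
$$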

---

## 4) Add the network term (data parallel)

All-reduce traffic for data-parallel gradients adds a per-token network cost (per GPU):
$$
M_{\text{tok}}^{\text{net}} \;\approx\; \frac{2\,N\,b_g}{B_g}
\quad\Rightarrow\quad
r_{\text{net}} \;=\; \frac{W_{\text{NIC}}}{M_{\text{tok}}^{\text{net}}}
$$
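where $b_g$ is bytes/grad element (2 for bf16) and $W_{\text{NIC}}$ is effective inter-node BW per GPU. The factor of 2 reflects the two phases of a ring all-reduce (reduce-scatter, then all-gather), each of which moves roughly $N\,b_g$ bytes per GPU.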

Final bound:
$$
r_{\text{gpu}} = u \cdot \min(r_{\text{comp}},\, r_{\text{mem}},\, r_{\text{net}})
$$

Reading the tea leaves:

* If **network-bound**, increase $B_g$ (bigger global batch), reduce DP shards (use TP/PP/ZeRO), or increase effective NIC BW (EFA/IB, overlap, grad compression).
* If **memory-bound**, raise arithmetic intensity: FlashAttention, fused kernels, selective recompute, 8-bit optimizer/state, larger $B_g$.
* If **compute-bound**, you’re doing great: push higher clocks/utilization or add parallelism.
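
One consequence of the formulas above is a break-even batch size. Setting $r_{\text{net}} \ge r_{\text{comp}}$ and solving for $B_g$ gives $B_g \ge 2\,b_g\,C_{\max} / (\gamma\,\kappa\,W_{\text{NIC}})$; note that $N$ cancels. The sketch below evaluates that bound for placeholder H100-class numbers. The helper name is hypothetical, and the result inherits every simplification of this first-order model.

```python
def min_tokens_per_step_to_hide_allreduce(peak_flops: float,
                                          nic_bytes_per_s: float,
                                          kappa: float = 6.0,
                                          gamma: float = 1.2,
                                          bytes_per_grad: int = 2) -> float:
    """Smallest B_g for which r_net >= r_comp in the model above.

    r_comp = C_max / (gamma * kappa * N)
    r_net  = W_NIC * B_g / (2 * N * b_g)
    Requiring r_net >= r_comp, N cancels:
        B_g >= 2 * b_g * C_max / (gamma * kappa * W_NIC)
    """
    return 2.0 * bytes_per_grad * peak_flops / (gamma * kappa * nic_bytes_per_s)

# Placeholder numbers: 900 TFLOPs/s usable compute, 200 Gb/s NIC per GPU.
nic_bytes_per_s = 200e9 / 8.0  # convert Gb/s to bytes/s
b_min = min_tokens_per_step_to_hide_allreduce(900e12, nic_bytes_per_s)
print(f"break-even B_g ~ {b_min:,.0f} tokens/step")
```

For these inputs the bound is about 20,000 tokens per step, so the 512,000-token example batch in section 7 sits comfortably on the compute side of the network limit.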

---

## 5) From per-GPU to cluster sizing

For a target throughput $R_{\text{target}}$ tokens/s:
$$
G \;=\; \left\lceil \frac{R_{\text{target}}}{r_{\text{gpu}}} \right\rceil
$$
Approximate aggregate HBM bandwidth needed across the cluster:
$$
\text{Cluster TB/s} \;\approx\; \frac{R_{\text{target}} \cdot M_{\text{tok}}}{10^{12}}
$$
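
The two sizing formulas translate directly into a few lines of Python. This is a sketch with placeholder inputs (a 200k tokens/s target, about 1339 tokens/s per GPU, and about 8.74 MB of HBM traffic per token, matching the section 7 example); `size_cluster` is a name chosen for this illustration.

```python
from math import ceil

def size_cluster(target_tokens_per_s: float, r_gpu: float, m_tok_bytes: float):
    """Return (GPU count, aggregate HBM bandwidth in TB/s) for a target rate."""
    if r_gpu <= 0:
        raise ValueError("r_gpu must be positive")
    gpus = ceil(target_tokens_per_s / r_gpu)                 # G = ceil(R_target / r_gpu)
    cluster_tbps = target_tokens_per_s * m_tok_bytes / 1e12  # bytes/s -> TB/s
    return gpus, cluster_tbps

gpus, cluster_tbps = size_cluster(200_000, 1339.29, 8_741_100)
print(gpus, round(cluster_tbps, 3))
```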

---

## 6) Converting to cost per bandwidth

For each candidate GPU:
$$
\frac{\$}{\text{TB/s·hr}} \;=\; \frac{\$ / \text{GPU-hr}}{\text{HBM TB/s per GPU}}
$$
Choose the configuration that:

* Meets $R_{\text{target}}$ without being network-bound,
* Minimizes \$/TB/s·hr while maintaining enough compute balance.
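
A minimal sketch of that selection rule follows. It encodes only the two explicit steps (drop network-bound options, then minimize \$/TB/s·hr) and leaves the "enough compute balance" judgment to the reader. The dictionary keys and the candidate entries are hypothetical.

```python
def pick_hardware(candidates):
    """Pick the cheapest-per-bandwidth option that is not network-bound.

    Each candidate is a dict with hypothetical keys:
      name, price_per_gpu_hr, hbm_tbps, bound ("compute" | "memory" | "network").
    Returns None if every candidate is network-bound.
    """
    eligible = [c for c in candidates if c["bound"] != "network"]
    if not eligible:
        return None
    return min(eligible, key=lambda c: c["price_per_gpu_hr"] / c["hbm_tbps"])

# Made-up candidates for illustration only.
candidates = [
    {"name": "A", "price_per_gpu_hr": 3.93, "hbm_tbps": 3.35, "bound": "compute"},
    {"name": "B", "price_per_gpu_hr": 4.33, "hbm_tbps": 4.80, "bound": "network"},
    {"name": "C", "price_per_gpu_hr": 1.15, "hbm_tbps": 0.30, "bound": "memory"},
]
best = pick_hardware(candidates)
print(best["name"])
```

Here option B has the lowest cost per bandwidth but is filtered out for being network-bound, so A is selected.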

---

## 7) The Python reference implementation

Below is a ready-to-use script that encodes the model. Plug in your own model/hardware numbers and compare options. (Numbers shown are placeholders—replace with your real specs and prices.)

```python
from dataclasses import dataclass
from math import ceil

@dataclass
class Hardware:
    name: str
    peak_flops_tflops: float     # usable BF16/FP16 TFLOPs per GPU
    hbm_tbps: float              # HBM/GDDR TB/s per GPU
    nic_gbps: float              # effective inter-node BW per GPU (Gb/s); 0 if none
    price_per_gpu_hr: float      # $/GPU-hour
    utilization: float = 0.75

    @property
    def machine_balance_flops_per_byte(self):
        return (self.peak_flops_tflops * 1e12) / (self.hbm_tbps * 1e12)

@dataclass
class Model:
    n_params: float              # trainable params
    layers: int
    d_model: int
    bytes_per_elem: int = 2      # 2 for bf16/fp16

@dataclass
class TrainingCfg:
    k_flops_per_token: float = 6.0     # ≈6*N for Adam (fwd+bwd)
    recompute_mult: float = 1.0        # ≥1.0 (activation checkpointing)
    alpha_opt_bytes_per_param: float = 16.0  # Adam traffic B/param/step
    c_act: float = 6.0                 # activation traffic coefficient
    global_tokens_per_step: int = 512_000    # global_batch * seq_len
    bytes_per_grad_elem: int = 2

def per_token_flops(model: Model, train: TrainingCfg) -> float:
    return train.recompute_mult * train.k_flops_per_token * model.n_params

def per_token_hbm_bytes(model: Model, train: TrainingCfg) -> float:
    opt_bytes = (train.alpha_opt_bytes_per_param * model.n_params) / train.global_tokens_per_step
    act_bytes = train.c_act * model.layers * model.d_model * model.bytes_per_elem
    return opt_bytes + act_bytes

def per_token_net_bytes(model: Model, train: TrainingCfg, dp_world_size: int) -> float:
    if dp_world_size <= 1:
        return 0.0
    per_step = 2.0 * model.n_params * train.bytes_per_grad_elem
    return per_step / train.global_tokens_per_step

def tokens_per_sec_per_gpu(hw: Hardware, model: Model, train: TrainingCfg, dp_world_size: int = 1) -> dict:
    F_tok = per_token_flops(model, train)
    M_tok = per_token_hbm_bytes(model, train)
    M_tok_net = per_token_net_bytes(model, train, dp_world_size)

    r_comp = (hw.peak_flops_tflops * 1e12) / F_tok
    r_mem  = (hw.hbm_tbps * 1e12)        / M_tok
    if M_tok_net > 0 and hw.nic_gbps > 0:
        nic_Bps = hw.nic_gbps * 1e9 / 8.0
        r_net = nic_Bps / M_tok_net
    else:
        r_net = float('inf')

    r_gpu = hw.utilization * min(r_comp, r_mem, r_net)
    which_bound = ["compute","memory","network"][ [r_comp, r_mem, r_net].index(min(r_comp, r_mem, r_net)) ]
    return dict(
        r_gpu=r_gpu, r_comp=r_comp, r_mem=r_mem, r_net=r_net, bound=which_bound,
        intensity=(F_tok/M_tok), machine_balance=hw.machine_balance_flops_per_byte
    )

def plan_cluster(hw: Hardware, model: Model, train: TrainingCfg, tokens_per_sec_target: float, dp_world_size:int=1):
    rates = tokens_per_sec_per_gpu(hw, model, train, dp_world_size)
    gpus = max(1, ceil(tokens_per_sec_target / rates["r_gpu"]))
    cost_per_hr = gpus * hw.price_per_gpu_hr
    cluster_tbs = (tokens_per_sec_target * per_token_hbm_bytes(model, train)) / 1e12
    return dict(
        hardware=hw.name, gpus=gpus, per_hr=round(cost_per_hr,2),
        per_gpu_tokens_s=round(rates["r_gpu"],2), bound=rates["bound"],
        cluster_HBM_TBs=round(cluster_tbs,3),
        per_TBs_hr=round(hw.price_per_gpu_hr / hw.hbm_tbps, 2)
    )

# --- Example catalog (replace with your real specs/prices) ---
h100 = Hardware(name="H100-SXM", peak_flops_tflops=900, hbm_tbps=3.35, nic_gbps=200, price_per_gpu_hr=3.93)
h200 = Hardware(name="H200",     peak_flops_tflops=1000, hbm_tbps=4.80, nic_gbps=400, price_per_gpu_hr=4.33)
l4   = Hardware(name="L4",       peak_flops_tflops=120,  hbm_tbps=0.30, nic_gbps=100, price_per_gpu_hr=1.15)

# --- Example model/run (tune to your case) ---
model = Model(n_params=70e9, layers=80, d_model=8192, bytes_per_elem=2)  # 70B decoder
train  = TrainingCfg(
    k_flops_per_token=6.0,  # Adam training
    recompute_mult=1.2,     # some checkpointing
    alpha_opt_bytes_per_param=16.0,
    c_act=5.0,              # FlashAttention/fused kernels
    global_tokens_per_step=512_000,
    bytes_per_grad_elem=2
)

def compare_options(target_tokens_s=200_000, dp_world_size=8):
    out = []
    for hw in [h100, h200, l4]:
        out.append(plan_cluster(hw, model, train, target_tokens_s, dp_world_size))
    return sorted(out, key=lambda x: x["per_hr"])

if __name__ == "__main__":
    from pprint import pprint
    pprint(compare_options())
```

```
[{'bound': 'compute',
  'cluster_HBM_TBs': 1.748,
  'gpus': 135,
  'hardware': 'H200',
  'per_TBs_hr': 0.9,
  'per_gpu_tokens_s': 1488.1,
  'per_hr': 584.55},
 {'bound': 'compute',
  'cluster_HBM_TBs': 1.748,
  'gpus': 150,
  'hardware': 'H100-SXM',
  'per_TBs_hr': 1.17,
  'per_gpu_tokens_s': 1339.29,
  'per_hr': 589.5},
 {'bound': 'compute',
  'cluster_HBM_TBs': 1.748,
  'gpus': 1120,
  'hardware': 'L4',
  'per_TBs_hr': 3.83,
  'per_gpu_tokens_s': 178.57,
  'per_hr': 1288.0}]
```


---

## 8) Tuning the knobs (what the coefficients mean)

$\alpha_{\text{opt}}$ (**bytes/param/step**)

* Captures parameter, gradient, and optimizer state traffic.
* Adam bf16/fp16: often **16–20 B**
* 8-bit Adam / sharded optimizers: lower (tune from logs)

$c_{\text{act}}$ (**activation traffic**)

* Encodes how IO-aware your kernels are.
* Baseline attention: higher $c_{\text{act}}$
* FlashAttention + fused MLP/LayerNorm: lowers $c_{\text{act}}$

$\gamma$ (**recompute multiplier**)

* Activation checkpointing trades FLOPs for less memory. Expect **1.1–1.4** depending on policy.

$B_g$ (**global tokens/step**)

* Increasing $B_g$ amortizes optimizer and network bytes per token, often moving you from network/memory bound toward compute bound—until optimizer dynamics/generalization limit further growth.
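
To see the amortization effect, the sketch below sweeps $B_g$ and prints the per-token HBM and network bytes from the formulas in sections 2 and 4. The default arguments reuse the placeholder 70B configuration from section 7; `per_token_bytes` is a name chosen for this illustration.

```python
def per_token_bytes(global_tokens: int, n_params: float = 70e9,
                    alpha_opt: float = 16.0, bytes_per_grad: int = 2,
                    act_bytes: float = 5.0 * 80 * 8192 * 2):
    """Return (HBM bytes/token, network bytes/token) for a given B_g."""
    hbm = alpha_opt * n_params / global_tokens + act_bytes  # optimizer term shrinks with B_g
    net = 2.0 * n_params * bytes_per_grad / global_tokens   # all-reduce term shrinks with B_g
    return hbm, net

for b_g in (64_000, 256_000, 1_024_000, 4_096_000):
    hbm, net = per_token_bytes(b_g)
    print(f"B_g={b_g:>9,}  HBM={hbm / 1e6:6.2f} MB/token  net={net / 1e6:5.2f} MB/token")
```

The activation term does not depend on $B_g$, so it sets a floor on HBM bytes per token; past that point only a lower $c_{\text{act}}$ helps.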

---

## 9) Common failure modes (and fixes)

- **Network-bound** (`bound == "network"`): Increase $B_g$; reduce pure DP (add TP/PP, ZeRO), overlap comms, or boost effective NIC (EFA/IB, topology-aware placement).

- **Memory-bound** (`bound == "memory"`): Reduce bytes/token: FlashAttention/Flash-decoding, fused kernels, selective recompute, lower-precision optimizer/state, MoE (if appropriate), or simply pick GPUs with better \$/TB/s·hr.

- **Compute-bound** but low tokens/sec: Your achieved utilization is below the $u$ you assumed. Improve kernels (Flash/Flash-MLA, fused ops), overlap (streams), and ensure you’re not secretly constrained by PCIe/NVLink or CPU input pipelines.
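
For the last case, one way to quantify the gap is to back out the utilization a measured run implies. The sketch below divides a measured tokens/sec figure by the model's roofline rate. The 900 tokens/s measurement is hypothetical, the rates are the approximate H100 placeholder values from the section 7 example, and `implied_utilization` is a name chosen for this illustration.

```python
def implied_utilization(measured_tokens_per_s: float, r_comp: float,
                        r_mem: float, r_net: float = float("inf")) -> float:
    """Fraction of the model's roofline rate that a measured run achieves."""
    roofline = min(r_comp, r_mem, r_net)
    return measured_tokens_per_s / roofline

# Hypothetical measurement compared with placeholder per-GPU rates.
u_measured = implied_utilization(900.0, r_comp=1785.7, r_mem=383_000.0)
print(f"implied u = {u_measured:.2f}")
```

If the implied value sits well under the $u$ used for planning, the cluster size from section 5 is too small by roughly the same ratio.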

---

## 10) Extensions

* Inference: Set $\kappa \approx 2$, $\alpha_{\text{opt}}=0$, and replace activation term with KV-cache bytes/token.
* MoE: Replace $N$ with the active parameters per token; keep memory/network terms aligned with your routing fraction and expert parallelism.
* Long context: $c_{\text{act}}$ rises; IO-aware attention matters even more.
* Optimizer offload: Reduce $\alpha_{\text{opt}}$ but watch network/PCIe traffic tradeoffs.
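
As a sketch of the inference bullet, the code below applies that substitution literally: $\kappa \approx 2$, no optimizer term, and the activation term replaced by KV-cache bytes read per decoded token. The KV-cache formula is an assumption of this example, not something the model above specifies: it assumes full multi-head attention (one key and one value vector of width $d_{\text{model}}$ per layer per cached token, no grouped-query sharing), a single sequence, and a full cache read per generated token. Weight reads are not counted.

```python
def decode_intensity(n_params: float, layers: int, d_model: int,
                     context_len: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for one decoded token under the stated assumptions."""
    flops = 2.0 * n_params  # kappa ~ 2, forward only
    # Assumed KV traffic: keys + values, every layer, every cached position.
    kv_bytes = 2 * layers * d_model * bytes_per_elem * context_len
    return flops / kv_bytes

# Placeholder 70B-class shape with a 4096-token context.
i_decode = decode_intensity(n_params=70e9, layers=80, d_model=8192, context_len=4096)
print(f"decode intensity ~ {i_decode:.1f} FLOPs/byte")
```

Under these assumptions the intensity is on the order of ten FLOPs/byte, far below the machine balance of the example GPUs, which is consistent with the long-context bullet's point that IO-aware attention matters more as context grows.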

---

## Conclusion

Training frontier-scale models efficiently is mostly about buying and feeding the right bandwidth—on the GPU (HBM), across GPUs (NVLink/NVSwitch), and across nodes (NIC/fabric). The simple model above lets you quantify that trade space, decide whether you’re compute, memory, or network bound, and choose the cheapest TB/s that achieves your tokens/sec target.

If you want, share your model size, sequence length, global batch, optimizer, and two or three candidate GPU types. I’ll plug them into this model and hand back a concrete capacity and cost comparison.

## References

* **Roofline model (compute vs. memory bound):** Williams, Waterman, Patterson, *CACM* (2009). ([ACM Digital Library][1])
* **FlashAttention (I/O-aware attention):** Dao et al., *arXiv:2205.14135* (2022). ([arXiv][2])
* **FLOPs per token ≈ 6·N (training rule of thumb):** *How To Scale Your Model* (JAX-ML book), “All the transformer math you need to know.” ([jax-ml.github.io][3])
* **Megatron-LM (scaling & comms patterns):** Shoeybi et al., *arXiv:1909.08053*; NVIDIA Megatron-LM repo notes. ([arXiv][4], [GitHub][5])
* **Data-parallel gradient all-reduce (collective definition & cost intuition):** NCCL User Guide (AllReduce); UCSD notes on reduce-scatter + all-gather. ([NVIDIA Docs][6], [Hao AI Lab][7])
* **ZeRO (optimizer/activation sharding to cut memory traffic):** Rajbhandari et al., *SC20* / arXiv:1910.02054. ([arXiv][8], [aiichironakano][9])
* **8-bit optimizers (reduce optimizer-state bytes):** Dettmers et al., *arXiv:2110.02861*; bitsandbytes docs. ([arXiv][10], [Hugging Face][11])
* **HBM bandwidth figures (for \$/TB/s·hr):**

  * NVIDIA **H100** datasheet—HBM3 up to **3.35 TB/s**. ([Megware][12])
  * NVIDIA **H200**—HBM3e **4.8 TB/s**. ([NVIDIA][13])
  * NVIDIA **B200**—HBM3e up to **8.0 TB/s**. ([primeline-solutions.com][14])
  * NVIDIA **L4**—GDDR6 **300 GB/s**. ([NVIDIA][15])
* **AWS EFA + NCCL (networking for multi-node training):** AWS EFA + NCCL guide; aws-ofi-nccl plugin. ([AWS Documentation][16], [GitHub][17])
* **AWS UltraClusters / Capacity Blocks (cluster placement & pricing examples):** Capacity Blocks overview & pricing tables (p5/H100, p5e/H200). ([AWS Documentation][18], [Amazon Web Services, Inc.][19])
* **P6-B200 (Blackwell) instance announcement & specs:** AWS blog (May 15, 2025). ([Amazon Web Services, Inc.][20])

[1]: https://dl.acm.org/doi/10.1145/1498765.1498785 "Roofline: an insightful visual performance model for multicore ..."
[2]: https://arxiv.org/abs/2205.14135 "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
[3]: https://jax-ml.github.io/scaling-book/transformers/ "All the Transformer Math You Need to Know | How To Scale ..."
[4]: https://arxiv.org/pdf/1909.08053 "Megatron-LM: Training Multi-Billion Parameter Language ..."
[5]: https://github.com/NVIDIA/Megatron-LM "NVIDIA/Megatron-LM: Ongoing research training ..."
[6]: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html "Collective Operations — NCCL 2.27.5 documentation"
[7]: https://hao-ai-lab.github.io/dsc204a-w24/assets/scribe_notes/Feb_5_scribe_note.pdf "13: Collective Communication 2 - Hao AI Lab @ UCSD"
[8]: https://arxiv.org/abs/1910.02054 "ZeRO: Memory Optimizations Toward Training Trillion ..."
[9]: https://aiichironakano.github.io/cs596/Rajbhandari-ZeRO-SC20.pdf "ZeRO: Memory Optimizations Toward Training Trillion ..."
[10]: https://arxiv.org/abs/2110.02861 "8-bit Optimizers via Block-wise Quantization"
[11]: https://huggingface.co/docs/bitsandbytes/main/en/optimizers "8-bit optimizers"
[12]: https://www.megware.com/fileadmin/user_upload/LandingPage%20NVIDIA/nvidia-h100-datasheet.pdf "NVIDIA H100 Tensor Core GPU Datasheet"
[13]: https://www.nvidia.com/en-us/data-center/h200/ "NVIDIA H200 Tensor Core GPU"
[14]: https://www.primeline-solutions.com/media/categories/server/nach-gpu/nvidia-hgx-h200/nvidia-blackwell-b200-datasheet.pdf "nvidia-blackwell-b200-datasheet.pdf"
[15]: https://www.nvidia.com/en-us/data-center/l4/ "L4 Tensor Core GPU for AI & Graphics"
[16]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/efa-start-nccl.html "Get started with EFA and NCCL for ML workloads on Amazon ..."
[17]: https://github.com/aws/aws-ofi-nccl "aws/aws-ofi-nccl: This is a plugin which lets EC2 ..."
[18]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-capacity-blocks.html "Capacity Blocks for ML - Amazon Elastic Compute Cloud"
[19]: https://aws.amazon.com/ec2/capacityblocks/pricing/ "Amazon EC2 Capacity Blocks for ML Pricing"
[20]: https://aws.amazon.com/blogs/aws/new-amazon-ec2-p6-b200-instances-powered-by-nvidia-blackwell-gpus-to-accelerate-ai-innovations/ "New Amazon EC2 P6-B200 instances powered by NVIDIA ..."
