In [6]:
# Communication Accounting, part a)
# Assumptions:
# - Each FFN is two linear layers — one from d_model to d_ff and one from d_ff to d_model (no activation)
# - Each block consists entirely of a single FFN
# - We omit attention, input embeddings, norms, and the output linear layer (i.e. we only count the FFN parameters)
# - We are not using activation checkpointing, so must keep the input to each linear layer

import math

print("Assumption: batch size = 128, sequence length = 1024\n-")

def bytes_to_mb(bytes):
    return bytes / (1024 * 1024)

def bytes_to_gb(bytes):
    return bytes / (1024 * 1024 * 1024)

d_model = 16384
d_ff = 53248
num_blocks = 126

params_per_block = d_model * d_ff + d_ff * d_model
total_params = params_per_block * num_blocks

print(f"Total parameters: {total_params}")

param_bytes = total_params * 4 # keeping params, accumulated gradients, and optimizer state in float32
accum_grad_bytes = param_bytes # one gradient per parameter, in float32
optim_bytes = param_bytes * 2 # first and second moment for AdamW, in float32
m_static = param_bytes + accum_grad_bytes + optim_bytes

print(f"Parameter bytes: {bytes_to_gb(param_bytes)} GB")
print(f"Accumulated gradient bytes: {bytes_to_gb(accum_grad_bytes)} GB")
print(f"Optimizer bytes: {bytes_to_gb(optim_bytes)} GB")
print(f"Static memory: {bytes_to_gb(m_static)} GB")
print("-")

# ----- Activation memory -----

# To compute gradients, we need the input to the second linear layer (d_ff) and the first linear layer (d_model)
batch_size = 128
seq_len = 1024
num_samples = batch_size * seq_len
elements_per_sample = num_blocks * (d_model + d_ff) 
activation_bytes_per_sample = elements_per_sample * 2 # we keep activations in bfloat16
activation_bytes = activation_bytes_per_sample * num_samples
m_act = activation_bytes

print(f"Activation bytes: {bytes_to_gb(activation_bytes)} GB")
print(f"Number of samples (B * L): {num_samples}")
print(f"Activation bytes per sample: {bytes_to_gb(activation_bytes_per_sample)} GB")

print("-")

bytes_per_h100 = 80 * 1024**3
n_h100 = math.ceil((m_static + m_act) / bytes_per_h100)

print(f"Number of H100 GPUs: {n_h100}")


Assumption: batch size = 128, sequence length = 1024
-
Total parameters: 219848638464
Parameter bytes: 819.0 GB
Accumulated gradient bytes: 819.0 GB
Optimizer bytes: 1638.0 GB
Static memory: 3276.0 GB
-
Activation bytes: 2142.0 GB
Number of samples (B * L): 131072
Activation bytes per sample: 0.0163421630859375 GB
-
Number of H100 GPUs: 68


In [10]:
# Communication Accounting, part b)
# - Assume your master weights, optimizer state, and half of the activations (in practice every second layer) are sharded across n_fsdp devices
# - Write an expression for how much memory this would take per device
# - What value does n_fsdp need to be to keep the memory cost under 1 v5p TPU (95GB per device)?
# - Deliverable: calculations and a one-sentence response

m_act_half = m_act / 2
# m_fsdp = (m_static + m_act_half) / n_fsdp
m_total = m_static + m_act_half

bytes_per_v5p = 95 * 1024**3
n_v5p = math.ceil(m_total / bytes_per_v5p)

print(f"Number of v5p TPUs: {n_v5p}")


Number of v5p TPUs: 46
