In [1]:
! nvidia-smi

Thu Aug  7 13:51:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.20             Driver Version: 570.133.20     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA B200                    On  |   00000000:DC:00.0 Off |                    0 |
| N/A   22C    P0            140W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from torch import Tensor
from omegaconf.dictconfig import DictConfig
torch.backends.cuda.matmul.allow_tf32 = True

In [3]:
# helpers
def sizeof_fmt(num):
    """Format a byte count as a human-readable string using decimal (1000-based) units.

    Args:
        num: Size in bytes (int or float; negative values keep their sign).

    Returns:
        A string with two decimals and a unit suffix, e.g. ``"26.36GB"``.
        Values of 1000 TB or more are reported in petabytes rather than
        falling off the end of the unit table and returning ``None``.
    """
    for unit in ("", "K", "M", "G", "T"):
        if abs(num) < 1000:
            return f"{num:.2f}{unit}B"
        num /= 1000
    # larger than every unit in the loop: `num` has already been scaled to petabytes
    return f"{num:.2f}PB"

# Muon implementation

In [4]:
class Muon(torch.optim.Optimizer):
    """Muon that batches over >2D layers, based on https://github.com/KellerJordan/Muon/blob/master/muon.py

    Every parameter with a gradient is updated as
    ``p <- p * (1 - lr * weight_decay) - lr * muon_update(grad, momentum_buffer)``.
    The update returned by `muon_update` is applied as-is (no reshape), so it
    must already have the parameter's shape.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        lr: Step size.
        weight_decay: Decoupled weight-decay coefficient (scaled by `lr`).
        momentum: Momentum coefficient, forwarded to `muon_update` as `beta`.
    """
    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        super().__init__(params, dict(lr=lr, weight_decay=weight_decay, momentum=momentum))

    @torch.no_grad()
    def step(self, closure=None):
        """Run one optimization step; returns the closure's loss, or None without a closure."""
        loss = None
        if closure is not None:
            # the closure re-evaluates the model, so it needs autograd enabled
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            decay = group["weight_decay"]
            beta = group["momentum"]
            for param in group["params"]:
                if param.grad is None:
                    continue
                param_state = self.state[param]
                if not param_state:
                    # first time this parameter is seen: start momentum at zero
                    param_state["momentum_buffer"] = torch.zeros_like(param)
                delta = muon_update(param.grad, param_state["momentum_buffer"], beta=beta)
                param.mul_(1 - lr * decay)  # decoupled weight decay
                param.add_(delta, alpha=-lr)

        return loss

def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    """Compute the Muon update direction for one parameter.

    Both `momentum` and (when `nesterov` is true) `grad` are modified in place.

    Args:
        grad: Gradient tensor with at least 2 dims; the last two are treated as the matrix.
        momentum: Momentum buffer, same shape as `grad`; updated in place to
            ``beta * momentum + (1 - beta) * grad``.
        beta: Momentum coefficient.
        ns_steps: Number of Newton-Schulz iterations.
        nesterov: If true, orthogonalize ``(1 - beta) * grad + beta * momentum``
            (written into `grad`); otherwise orthogonalize the momentum buffer.

    Returns:
        The orthogonalized update, same shape as `grad`, scaled by
        ``sqrt(max(1, grad.size(-1) / grad.size(-2)))``. No flattening is done,
        so >2D tensors are handled as batches of matrices.
    """
    momentum.lerp_(grad, 1 - beta)
    if nesterov:
        direction = grad.lerp_(momentum, beta)
    else:
        direction = momentum
    orthogonal = zeropower_via_newtonschulz5(direction, steps=ns_steps)
    n_rows, n_cols = grad.size(-2), grad.size(-1)
    # scale by sqrt(cols / rows) when the matrix is wider than tall (matches the JAX logic)
    orthogonal *= max(1, n_cols / n_rows)**0.5
    return orthogonal

def zeropower_via_newtonschulz5(G, steps: int):
    """Approximately orthogonalize the trailing matrices of `G` with a quintic Newton-Schulz iteration.

    Leading dimensions are treated as batch dimensions (batched Muon
    implementation by @scottjmaddox, put into practice in the record by
    @YouJiacheng).

    Args:
        G: Tensor with at least 2 dims.
        steps: Number of Newton-Schulz iterations.

    Returns:
        A bfloat16 tensor with the same shape as `G`.
    """
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750,  2.0315)
    # iterate on the wide orientation so the Gram matrix X @ X.mT is the smaller one
    is_tall = G.size(-2) > G.size(-1)
    X = G.bfloat16()
    if is_tall:
        X = X.mT

    # Dividing by the Frobenius norm keeps the spectral norm at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
    for _ in range(steps):
        gram = X @ X.mT
        poly = b * gram + c * gram @ gram
        X = a * X + poly @ X

    # undo the transpose so the result matches the input layout
    return X.mT if is_tall else X

# RoPE implementation

In [5]:
def apply_rope(
    inputs: Tensor, # [B, N, T, H]
    positions: Tensor, # [B, T]
    max_wavelength: int = 10_000,
    scale_factor: float = 1.0,
) -> Tensor:
    """Apply rotary position embeddings (RoPE) to `inputs`.

    The last dimension is split into two halves; pair ``i`` (element ``i`` of
    each half) is rotated by the angle
    ``position / max_wavelength**(2*i/H) / scale_factor``.

    Args:
        inputs: Tensor of shape [B, N, T, H].
        positions: Token positions of shape [B, T] (broadcast over N).
        max_wavelength: Base of the geometric progression of timescales.
        scale_factor: Divisor applied to the angles; must be >= 1.0.

    Returns:
        Tensor of shape [B, N, T, H] with the same dtype as `inputs`.

    Raises:
        ValueError: If `scale_factor` is below 1.0.
    """
    _, _, _, head_dim = inputs.shape
    if scale_factor < 1.0:
        raise ValueError(f'scale_factor must be >= 1.0, got {scale_factor}')

    exponents = 2 * torch.arange(0, head_dim // 2, device=inputs.device) / head_dim # [H/2]
    timescale = max_wavelength**exponents # [H/2]

    # rotation angle per (batch, position, frequency), with a singleton head axis for broadcasting
    angles = positions[:, None, :, None] / timescale[None, None, None, :] # [B, 1, T, H/2]
    angles = angles / scale_factor # [B, 1, T, H/2]
    sin, cos = torch.sin(angles), torch.cos(angles) # [B, 1, T, H/2] each

    lower, upper = torch.chunk(inputs, 2, dim=-1) # [B, N, T, H/2] each
    rotated = torch.concatenate(
        [lower * cos - upper * sin, upper * cos + lower * sin],
        dim=-1,
    ) # [B, N, T, H]
    # the trig terms are floating point, so cast back to the caller's dtype
    return rotated.to(inputs.dtype)

# Transformer implementation

In [6]:
class TransformerDecoder(nn.Module):
    """Decoder-only transformer whose forward pass returns the next-token loss.

    Args:
        c: Config providing `V` (vocab size), `D` (model dim), `H` (head dim),
            `L` (number of blocks), `remat` (enable gradient checkpointing) and
            `dtype` (name of a torch dtype, e.g. ``'bfloat16'``).
        dtype: Optional torch dtype overriding `c.dtype`. When None (the
            default) the dtype named in the config is used.
    """
    def __init__(self, c: DictConfig, dtype=None):
        super().__init__()
        # honor an explicit override; previously this argument was silently ignored
        if dtype is None:
            dtype = getattr(torch, c.dtype)
        self.token_embed_in = nn.Embedding(c.V, c.D, dtype=dtype)
        self.token_embed_out = nn.Linear(c.D, c.V, bias=False, dtype=dtype)
        self.blocks = nn.ModuleList([TransformerBlock(c.D, c.H, dtype) for _ in range(c.L)])
        self.out_ln = nn.RMSNorm(c.D, elementwise_affine=False, dtype=dtype)
        self.remat = c.remat

    def forward(self, x): # [B, T]
        """Return the mean next-token cross-entropy loss for integer token ids `x` of shape [B, T]."""

        # token embedding
        h = self.token_embed_in(x) # [B, T, D]

        # transformer blocks (optionally recomputing activations during backward)
        for block in self.blocks:
            h = checkpoint(block, h, use_reentrant=False) if self.remat else block(h)

        # project back to vocabulary
        h = self.out_ln(h)
        logits = self.token_embed_out(h) # [B, T, V]

        # get loss
        # we return loss (rather than logits) to reduce peak memory usage
        # targets are the inputs shifted left by one; the wrapped-around last column is masked out
        y = torch.roll(x, -1, dims=1).to(torch.int64)
        y[:, -1] = -1 # do not train on these indices
        loss = F.cross_entropy(logits.flatten(end_dim=-2), y.flatten(), ignore_index=-1)

        return loss


class TransformerBlock(nn.Module):
    """Pre-norm residual block: an attention sub-layer followed by an MLP sub-layer.

    Args:
        D: Model dimension.
        H: Head dimension, passed to the attention layer.
        dtype: Torch dtype for all sub-modules.
    """
    def __init__(self, D, H, dtype):
        super().__init__()
        norm_kwargs = dict(elementwise_affine=False, dtype=dtype)
        self.ln1 = nn.RMSNorm(D, **norm_kwargs)
        self.ln2 = nn.RMSNorm(D, **norm_kwargs)
        self.attn = MultiHeadAttention(D, H, dtype)
        self.mlp = MLP(D, dtype)

    def forward(self, x): # [B, T, D]
        """Apply both residual sub-layers; input and output are [B, T, D]."""
        # attention block
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        # MLP block
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with QK-norm and rotary position embeddings.

    Args:
        D: Model dimension.
        H: Head dimension. The number of heads is ``D // H`` (floor division),
            so when `D` is not a multiple of `H` the heads span fewer than `D`
            features in total.
        dtype: Torch dtype for the projections and norms.
    """
    def __init__(self, D, H, dtype):
        super().__init__()
        num_heads = D // H
        # a single kernel produces q, k and v stacked along the leading axis S=3
        self.qkv_proj = Einsum('BTd,SNdH->SBNTH', (3, num_heads, D, H), dtype=dtype)
        # merges heads and projects back to the model dimension in one contraction
        self.out_proj = Einsum('BnTh,nhD->BTD', (num_heads, H, D), dtype=dtype)
        self.query_norm = nn.RMSNorm(H, elementwise_affine=False, dtype=dtype)
        self.key_norm = nn.RMSNorm(H, elementwise_affine=False, dtype=dtype)

    def forward(self, x): # [B, T, D]
        """Attend causally over the sequence axis; input and output are [B, T, D]."""
        _, seq_len, _ = x.shape

        # input projection
        q, k, v = self.qkv_proj(x) # each [B, N, T, H]

        # positions 0..T-1 with a singleton batch axis (broadcast inside apply_rope)
        positions = torch.arange(seq_len, device=x.device)[None] # [1, T]

        # qk-norm followed by position embedding
        q = apply_rope(self.query_norm(q), positions)
        k = apply_rope(self.key_norm(k), positions)

        # attention
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True) # [B, N, T, H]

        # output projection followed by contraction back to original dims
        return self.out_proj(attended) # [B, T, D]

class MLP(nn.Module):
    """Two-layer feed-forward network with a GELU non-linearity and no biases.

    Args:
        D: Model dimension; the hidden layer has ``4 * D`` units.
        dtype: Torch dtype for both linear layers.
    """
    def __init__(self, D, dtype):
        super().__init__()
        hidden_dim = 4 * D
        self.fc1 = nn.Linear(D, hidden_dim, bias=False, dtype=dtype)
        self.fc2 = nn.Linear(hidden_dim, D, bias=False, dtype=dtype)

    def forward(self, x): # [B, T, D]
        """Expand to the hidden width, apply GELU, project back to [B, T, D]."""
        hidden = self.fc1(x) # [B, T, F]
        activated = F.gelu(hidden) # [B, T, F]
        return self.fc2(activated) # [B, T, D]


class Einsum(nn.Module):
    """Learned tensor contraction: ``einsum(einsum_str, x, weight)``.

    Args:
        einsum_str: Two-operand einsum equation; the first operand is the
            input, the second is the learned kernel.
        kernel_shape: Shape of the kernel parameter.
        dtype: Optional torch dtype of the kernel.
    """
    def __init__(self, einsum_str, kernel_shape, dtype=None):
        super().__init__()
        self.einsum_str = einsum_str
        kernel = torch.empty(kernel_shape, dtype=dtype)
        self.weight = nn.Parameter(kernel)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the kernel in place with Kaiming-uniform (a = sqrt(5))."""
        # NOTE(review): kaiming_uniform_ infers fan-in from the tensor layout (dim 1 times the
        # trailing dims), which need not match the contracted axis of the einsum -- confirm
        # the resulting init scale is what is intended for >2D kernels.
        negative_slope = math.sqrt(5)
        nn.init.kaiming_uniform_(self.weight, a=negative_slope)

    def forward(self, x):
        """Contract `x` with the kernel according to the stored equation."""
        equation, kernel = self.einsum_str, self.weight
        return torch.einsum(equation, x, kernel)

# Profiling

In [7]:
# Model configuration (GPT-3 13B).
# Within this notebook the model classes read D, L, H, V, remat and dtype, and the
# profiling step reads T and V. F and N are recorded for reference only: the MLP
# hard-codes a 4*D hidden width and the attention layer derives its head count as D // H.
model_config = DictConfig({
    'D': 5140, # model/embed/qkv dim
    'L': 40, # num. block layers
    'H': 128, # head dimension
    'F': 5140 * 4, # FF inner dimension
    'N': 5140 // 128, # num. attention heads
    'T': 1024, # context/sequence length
    'V': 50257, # vocabulary size
    'remat': True, # gradient checkpointing
    'dtype': 'bfloat16', # name of the torch dtype used for all parameters
})

In [8]:
def run():
    """Profile peak CUDA memory of one training step with the optimizer fused into backward.

    Builds a `TransformerDecoder` from `model_config` on the GPU and attaches one
    optimizer per parameter, stepped from a post-accumulate-grad hook so each
    gradient is applied and cleared during the backward pass. After two warm-up
    steps it measures a single step and prints the parameter count, model size,
    optimizer-state size, peak allocated memory, and the remainder ("other").

    Sizes are computed from each tensor's actual element size rather than
    assuming 2 bytes per element, and non-tensor optimizer state entries are
    skipped. The model and optimizers are released in a `finally` block so a
    failure (e.g. out of memory) does not leave them referenced by this frame.
    """

    # model
    with torch.device('cuda'):
        model = TransformerDecoder(model_config)
    params = list(model.parameters())
    n_params = sum(p.numel() for p in params)
    model_bytes = sum(p.numel() * p.element_size() for p in params)
    device = params[0].device # query memory stats on the device the model actually lives on
    print(f'{n_params=:_}')
    print('size of model:', sizeof_fmt(model_bytes))

    # standard optimizer (not fused)
    # optimizer = torch.optim.Adam(model.parameters(), foreach=False)
    # optimizer_dict = {'opt': optimizer}

    # fused optimizer
    # based on https://lightning.ai/pages/community/tutorial/faster-pytorch-training-by-reducing-peak-memory/
    optimizer_dict = {p:torch.optim.SGD([p], foreach=False) for p in params} # all params
    # optimizer_dict = {p:Muon([p]) for p in model.blocks.parameters()} # non-embedding params
    # optimizer_dict |= {p:torch.optim.Adam([p], foreach=False) for p in [*model.token_embed_in.parameters(), *model.token_embed_out.parameters()]} # embedding params

    try:
        def optimizer_hook(parameter):
            # runs as soon as this parameter's gradient has been accumulated
            optimizer_dict[parameter].step()
            optimizer_dict[parameter].zero_grad()
        for p in params:
            p.register_post_accumulate_grad_hook(optimizer_hook)

        # define training step
        def step():
            x = torch.randint(model_config.V, [1, model_config.T], dtype=torch.int32, device='cuda')
            loss = model(x)
            loss.backward()
            # optimizer.step()
            # optimizer.zero_grad()

        # warm up model
        for _ in range(2):
            step()
        torch.cuda.synchronize()

        # get optimizer state size
        opt_bytes = 0
        for opt in optimizer_dict.values():
            for param_state in opt.state_dict()['state'].values():
                for value in param_state.values():
                    # state may hold non-tensor entries (e.g. None or Python scalars); count tensors only
                    if torch.is_tensor(value):
                        opt_bytes += value.numel() * value.element_size()
        print('size of opt. state:', sizeof_fmt(opt_bytes))

        # plot step trace
        # with torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) as p:
        #     step()
        # p.export_memory_timeline('stack.html', 'cuda:0')

        # print max. memory during step
        torch.cuda.reset_peak_memory_stats(device)
        step()
        max_mem = torch.cuda.max_memory_allocated(device)
        print('max. memory allocated:', sizeof_fmt(max_mem))

        # compute size of 'other'
        other_size = max_mem - model_bytes - opt_bytes
        print('size of "other":', sizeof_fmt(other_size))
    finally:
        # manually free memory (required given the circular reference btw model and optimizer)
        del model, params
        optimizer_dict.clear()

run()

n_params=13_181_601_960
size of model: 26.36GB
size of opt. state: 0.00B
max. memory allocated: 27.58GB
size of "other": 1.21GB
