In [1]:
! nvidia-smi

Sun Jun 29 11:22:01 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.133.20             Driver Version: 570.133.20     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA B200                    On  |   00000000:DC:00.0 Off |                    0 |
| N/A   22C    P0            143W / 1000W |       0MiB / 183359MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
# Allow TF32 tensor cores for float32 CUDA matmuls. Note: this has no effect on
# bfloat16 matmuls, which is the dtype run() builds the model in.
torch.backends.cuda.matmul.allow_tf32 = True

In [3]:
# helpers
def sizeof_fmt(num, base=1000):
    """Format a byte count as a human-readable string with decimal prefixes.

    Args:
        num: number of bytes (int or float; negative values keep their sign).
        base: factor between successive prefixes; the default 1000 gives the
            decimal K/M/G/T/P units used throughout this notebook.

    Returns:
        A string such as ``"26.36GB"``. Values beyond the TB range are reported
        in PB instead of silently returning ``None``.
    """
    units = ("", "K", "M", "G", "T", "P")
    for unit in units[:-1]:
        if abs(num) < base:
            return f"{num:.2f}{unit}B"
        num /= base
    # the largest unit absorbs whatever is left
    return f"{num:.2f}{units[-1]}B"

In [4]:
class Muon(torch.optim.Optimizer):
    """Muon optimizer that also handles >2D parameters by batching over their leading dims.

    Based on https://github.com/KellerJordan/Muon/blob/master/muon.py. For every
    parameter with a gradient, a step applies decoupled weight decay
    (``p *= 1 - lr * weight_decay``) and then adds ``-lr`` times the update returned
    by ``muon_update``. That update already has the parameter's shape, so no reshape
    is needed.

    Args:
        params: iterable of parameters or parameter groups.
        lr: learning rate.
        weight_decay: decoupled weight-decay coefficient.
        momentum: momentum coefficient forwarded to ``muon_update`` as ``beta``.
    """

    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        super().__init__(params, dict(lr=lr, weight_decay=weight_decay, momentum=momentum))

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss, or None."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            decay = 1 - lr * group["weight_decay"]
            beta = group["momentum"]
            for param in group["params"]:
                if param.grad is None:
                    continue
                state = self.state[param]
                if not state:
                    # lazily create a momentum buffer with the parameter's shape/dtype
                    state["momentum_buffer"] = torch.zeros_like(param)
                direction = muon_update(param.grad, state["momentum_buffer"], beta=beta)
                param.mul_(decay)
                param.add_(direction, alpha=-lr)

        return loss

def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    """Compute the Muon update direction for one (possibly batched) parameter.

    Steps:
      1. ``momentum`` is updated in place to ``beta * momentum + (1 - beta) * grad``.
      2. With ``nesterov=True`` the lookahead ``(1 - beta) * grad + beta * momentum``
         is used, and it is written into ``grad`` in place. Otherwise the momentum
         buffer itself is used.
      3. That direction is orthogonalized over its last two dims with
         ``zeropower_via_newtonschulz5``.
      4. The result is scaled by ``sqrt(max(1, cols / rows))`` of grad's trailing
         matrix shape.

    Args:
        grad: gradient tensor (ndim >= 2); overwritten when ``nesterov`` is True.
        momentum: momentum buffer with the same shape as ``grad``; mutated in place.
        beta: momentum coefficient.
        ns_steps: number of Newton-Schulz iterations.
        nesterov: whether to use the Nesterov-style lookahead.

    Returns:
        bfloat16 tensor with the same shape as ``grad``.
    """
    momentum.lerp_(grad, 1 - beta)
    if nesterov:
        direction = grad.lerp_(momentum, beta)
    else:
        direction = momentum
    direction = zeropower_via_newtonschulz5(direction, steps=ns_steps)
    # NOTE(review): the reference implementation scales by size(-2)/size(-1), which
    # suits torch's (out, in) nn.Linear weights. This variant uses size(-1)/size(-2),
    # i.e. it assumes an (in, out) layout, as in JAX. In this notebook the Einsum
    # kernels' trailing dims look like (in, out) while nn.Linear weights are
    # (out, in) -- confirm which convention is intended.
    rows, cols = grad.size(-2), grad.size(-1)
    direction *= max(1, cols / rows) ** 0.5
    return direction

def zeropower_via_newtonschulz5(G, steps: int):
    """Approximately orthogonalize ``G`` over its last two dims via quintic Newton-Schulz.

    Leading dims are treated as a batch. This batched Muon implementation is by
    @scottjmaddox and was put into practice in the record by @YouJiacheng.

    How it works:
      - The computation runs in bfloat16.
      - Tall matrices are transposed so the Gram matrix ``X @ X.mT`` is the smaller one.
      - The input is first divided by its Frobenius norm, which bounds the spectral
        norm by 1.
      - ``steps`` iterations of ``X <- a X + (b A + c A A) X`` with ``A = X X^T``
        then push the singular values toward ~1 (not exactly 1).

    Args:
        G: tensor with ndim >= 2.
        steps: number of Newton-Schulz iterations.

    Returns:
        bfloat16 tensor with the same shape as ``G``.
    """
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    tall = G.size(-2) > G.size(-1)
    X = G.bfloat16().mT if tall else G.bfloat16()

    # Frobenius norm >= spectral norm, so this keeps the spectral norm <= 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        gram = X @ X.mT
        # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        poly = b * gram + (c * gram) @ gram
        X = a * X + poly @ X

    return X.mT if tall else X

In [5]:
# define model
class TransformerDecoder(nn.Module):
    def __init__(self, L, D, V, H=128, dtype=None):
        super().__init__()
        self.token_embed_in = nn.Embedding(V, D, dtype=dtype)
        self.token_embed_out = nn.Linear(D, V, bias=False, dtype=dtype)
        self.blocks = nn.ModuleList([TransformerBlock(D, H, dtype) for _ in range(L)])
        self.out_ln = nn.RMSNorm(D, elementwise_affine=False, dtype=dtype)

    def forward(self, x, y=None): # [B, S]

        # token embedding
        h = self.token_embed_in(x) # [B, T, D]

        # transformer blocks
        for block in self.blocks:
            h = checkpoint(block, h, use_reentrant=False)

        # project back to vocabulary
        h = self.out_ln(h)
        logits = self.token_embed_out(h) # [B, T, V]

        # get loss
        loss = F.cross_entropy(logits.flatten(end_dim=-2), y.flatten())

        return loss


class TransformerBlock(nn.Module):
    def __init__(self, D, H, dtype):
        super().__init__()
        self.ln1 = nn.RMSNorm(D, elementwise_affine=False, dtype=dtype)
        self.ln2 = nn.RMSNorm(D, elementwise_affine=False, dtype=dtype)
        self.attn = MultiHeadAttention(D, H, dtype)
        self.mlp = Mlp(D, dtype)

    def forward(self, x): # [B, T, D]
        x = x + self.attn(self.ln1(x)) # attention block
        return x + self.mlp(self.ln2(x)) # MLP block


class MultiHeadAttention(nn.Module):
    """Causal attention layer."""
    def __init__(self, D, H, dtype):
        super().__init__()
        N = D // H # number of heads
        self.qkv_proj = Einsum('BTd,SNdH->SBTNH', (3, N, D, H), dtype=dtype)
        self.out_proj = Einsum('BTnh,nhD->BTD', (N, H, D), dtype=dtype)
        self.query_norm = nn.RMSNorm(H, elementwise_affine=False, dtype=dtype)
        self.key_norm = nn.RMSNorm(H, elementwise_affine=False, dtype=dtype)

    def forward(self, x): # [B, T, D]
        B, T, D = x.shape

        # input projection
        q, k, v = self.qkv_proj(x) # [B, T, N, H]

        # qk-norm
        q = self.query_norm(q)
        k = self.key_norm(k)

        # position embedding
        # (ommited)

        # attention
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # [B, T, N, H]

        # output projection followed by contraction back to original dims
        out = self.out_proj(out) # [B, T, D]
        return out


class Mlp(nn.Module):
    """Bias-free two-layer feed-forward network: D -> 4*D -> D with GELU in between.

    Args:
        D: input/output width.
        dtype: parameter dtype.
    """
    def __init__(self, D, dtype):
        super().__init__()
        hidden = 4 * D
        self.fc1 = nn.Linear(D, hidden, bias=False, dtype=dtype)
        self.fc2 = nn.Linear(hidden, D, bias=False, dtype=dtype)

    def forward(self, x):  # [B, T, D] -> [B, T, D]
        return self.fc2(F.gelu(self.fc1(x)))


class Einsum(nn.Module):
    """Learnable linear map computed as ``torch.einsum(einsum_str, x, weight)``.

    Args:
        einsum_str: two-operand einsum equation whose first operand is the input and
            second operand is the weight, e.g. ``'BTd,dH->BTH'``.
        kernel_shape: shape of the weight; must match the weight's subscripts.
        dtype: weight dtype.

    The weight is drawn from N(0, 1 / fan_in), where fan_in is the product of the
    weight dims that are summed out (subscripts absent from the output). This keeps
    the output scale close to the input scale. Unit-variance weights would instead
    grow activations by ~sqrt(fan_in). If the fan-in cannot be determined
    (e.g. an ellipsis in the equation), std 1 is used.
    """
    def __init__(self, einsum_str, kernel_shape, dtype=None):
        super().__init__()
        self.einsum_str = einsum_str
        std = self._fan_in(einsum_str, kernel_shape) ** -0.5
        # scale in place to avoid a second full-size temporary for large kernels
        self.weight = nn.Parameter(torch.randn(kernel_shape, dtype=dtype).mul_(std))

    @staticmethod
    def _fan_in(einsum_str, kernel_shape):
        """Product of the weight dims contracted by ``einsum_str`` (1 if unknown)."""
        eq = einsum_str.replace(" ", "")
        if "..." in eq:
            return 1
        lhs, arrow, out = eq.partition("->")
        if not arrow:
            # implicit mode: the output keeps subscripts that appear exactly once
            out = "".join(ch for ch in lhs if ch != "," and lhs.count(ch) == 1)
        w_sub = lhs.split(",")[-1]
        shape = (kernel_shape,) if isinstance(kernel_shape, int) else tuple(kernel_shape)
        if len(w_sub) != len(shape):
            return 1
        fan_in = 1
        for ch, size in zip(w_sub, shape):
            if ch not in out:
                fan_in *= size
        return max(fan_in, 1)

    def forward(self, x):
        return torch.einsum(self.einsum_str, x, self.weight)

In [6]:
def run(L=40, D=5140, V=50257, T=1024):
    """Report where peak GPU memory goes during one training step with fused optimizers.

    Setup:
      - Builds a bfloat16 ``TransformerDecoder`` on CUDA.
      - Attaches one optimizer per parameter, stepped (and its gradient cleared)
        from a post-accumulate-grad hook.
      - Warms up for two steps.

    Then prints:
      - parameter count and model size;
      - optimizer-state size;
      - peak allocated memory during one step;
      - the remainder labelled "other".

    Sizes are real byte counts (numel * element_size), not a fixed 2 bytes/element.

    Args:
        L: number of transformer blocks.
        D: model width.
        V: vocabulary size.
        T: sequence length of the random training batch (batch size 1).

    The defaults reproduce the 13B configuration. Smaller presets are
    ``L=12, D=768`` (124M) and ``L=24, D=2048`` (1.3B).
    """

    def tensor_bytes(tensors):
        # bytes held by the tensors; non-tensor entries (e.g. int counters) are skipped
        return sum(t.numel() * t.element_size() for t in tensors if torch.is_tensor(t))

    # model
    with torch.device('cuda'):
        model = TransformerDecoder(L=L, D=D, V=V, dtype=torch.bfloat16)
    n_params = sum(p.numel() for p in model.parameters())
    model_bytes = tensor_bytes(model.parameters())
    print(f'{n_params=:_}')
    print('size of model:', sizeof_fmt(model_bytes))

    # standard optimizer (not fused)
    # optimizer = torch.optim.Adam(model.parameters(), foreach=False)
    # optimizer_dict = {'opt': optimizer}

    # fused optimizer: one optimizer per parameter, stepped from a hook during backward
    # so each gradient is cleared (zero_grad) right after its update
    # based on https://lightning.ai/pages/community/tutorial/faster-pytorch-training-by-reducing-peak-memory/
    optimizer_dict = {p: torch.optim.SGD([p], foreach=False) for p in model.parameters()}  # all params
    # optimizer_dict = {p: Muon([p]) for p in model.blocks.parameters()}  # non-embedding params
    # optimizer_dict |= {p: torch.optim.Adam([p], foreach=False) for p in [*model.token_embed_in.parameters(), *model.token_embed_out.parameters()]}  # embedding params

    def optimizer_hook(parameter):
        optimizer_dict[parameter].step()
        optimizer_dict[parameter].zero_grad()

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(optimizer_hook)

    # define training step on random tokens
    def step():
        x = torch.randint(V, [1, T], dtype=torch.int32, device='cuda')
        y = torch.randint(V, [1, T], dtype=torch.int64, device='cuda')
        loss = model(x, y)
        loss.backward()  # parameters are updated by the hooks during backward
        # optimizer.step()
        # optimizer.zero_grad()

    # warm up model (optimizer state, if any, is created on the first step)
    for _ in range(2):
        step()
    torch.cuda.synchronize()

    # get optimizer state size
    opt_bytes = sum(
        tensor_bytes(param_state.values())
        for opt in optimizer_dict.values()
        for param_state in opt.state.values()
    )
    print('size of opt. state:', sizeof_fmt(opt_bytes))

    # plot step trace
    # with torch.profiler.profile(record_shapes=True, profile_memory=True, with_stack=True) as p:
    #     step()
    # p.export_memory_timeline('stack.html', 'cuda:0')

    # print max. memory during step
    torch.cuda.reset_peak_memory_stats("cuda:0")
    step()
    max_mem = torch.cuda.max_memory_allocated("cuda:0")
    print('max. memory allocated:', sizeof_fmt(max_mem))

    # compute size of 'other': everything live at the peak that is neither
    # parameters nor optimizer state (e.g. activations, gradients)
    other_size = max_mem - model_bytes - opt_bytes
    print('size of "other":', sizeof_fmt(other_size))

    # manually free memory (required given the circular reference btw model and optimizer)
    del model; optimizer_dict.clear()

run()

n_params=13_181_601_960
size of model: 26.36GB
size of opt. state: 0.00B
max. memory allocated: 27.58GB
size of "other": 1.21GB
