In [9]:
from dataclasses import dataclass

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

In [18]:
def benchmark_in_us(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e3

In [19]:
gs = [
    *[torch.rand(768, 4*768, dtype=torch.float32, device="cuda") for _ in range(12)],
    *[torch.rand(4*768, 768, dtype=torch.float32, device="cuda") for _ in range(12)],
    *[torch.rand(768, 768, dtype=torch.float32, device="cuda") for _ in range(4*12)],
]

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
s3 = torch.cuda.Stream()

In [None]:
def zeropower_via_newtonschulz5(G: torch.Tensor, n: int) -> torch.Tensor:
    ...


def single(gs: list[torch.Tensor], n: int):
    g = gs[0]
    zeropower_via_newtonschulz5(g, n)


def multiple(gs: list[torch.Tensor], n: int):
    for g in gs:
        zeropower_via_newtonschulz5(g, n)


def multiplexed_2_streams(gs: list[torch.Tensor], n: int):
    with torch.cuda.stream(s1):
        for g in gs[:9]:
            zeropower_via_newtonschulz5(g, n)
    with torch.cuda.stream(s2):
        for g in gs[9:]:
            zeropower_via_newtonschulz5(g, n)


def multiplexed_3_streams(gs: list[torch.Tensor], n: int):
    with torch.cuda.stream(s1):
        for g in gs[:6]:
            zeropower_via_newtonschulz5(g, n)
    with torch.cuda.stream(s2):
        for g in gs[6:12]:
            zeropower_via_newtonschulz5(g, n)
    with torch.cuda.stream(s3):
        for g in gs[12:]:
            zeropower_via_newtonschulz5(g, n)


def run_benchmarks():
    benchmark_in_us(single, gs, 4)
    benchmark_in_us(single, gs, 5)

    runtime_single_4_steps = benchmark_in_us(single, gs, 4)
    runtime_single_5_steps = benchmark_in_us(single, gs, 5)
    runtime_multiple_4_steps = benchmark_in_us(multiple, gs, 4) / len(gs)
    runtime_multiple_5_steps = benchmark_in_us(multiple, gs, 5) / len(gs)
    runtime_mux_2_streams_4_steps = benchmark_in_us(multiplexed_2_streams, gs, 4) / len(gs)
    runtime_mux_2_streams_5_steps = benchmark_in_us(multiplexed_2_streams, gs, 5) / len(gs)
    runtime_mux_3_streams_4_steps = benchmark_in_us(multiplexed_3_streams, gs, 4) / len(gs)
    runtime_mux_3_streams_5_steps = benchmark_in_us(multiplexed_3_streams, gs, 5) / len(gs)
    print(f"Single, 4 steps: {runtime_single_4_steps} ms")
    print(f"Single, 5 steps: {runtime_single_5_steps} ms")
    print(f"Multiple, 4 steps: {runtime_multiple_4_steps} ms")
    print(f"Multiple, 5 steps: {runtime_multiple_5_steps} ms")
    print(f"Multiplexed 2 streams, 4 steps: {runtime_mux_2_streams_4_steps} ms")
    print(f"Multiplexed 2 streams, 5 steps: {runtime_mux_2_streams_5_steps} ms")
    print(f"Multiplexed 3 streams, 4 steps: {runtime_mux_3_streams_4_steps} ms")
    print(f"Multiplexed 3 streams, 5 steps: {runtime_mux_3_streams_5_steps} ms")

In [None]:
@torch.compile
def zeropower_via_newtonschulz5(G: torch.Tensor, n: int):
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X.div_(X.norm() + 1e-7)
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(n):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("Original - naive compile")
run_benchmarks()

In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, n: int):
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X.div_(X.norm() + 1e-7)
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(n):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("Original")
run_benchmarks()



Single, 4 steps: 17.88294709995171 ms
Single, 5 steps: 22.460382799999934 ms
Multiple, 4 steps: 35.80253222222988 ms
Multiple, 5 steps: 44.54966661109615 ms
Multiplexed 2 streams, 4 steps: 35.27657105557106 ms
Multiplexed 2 streams, 5 steps: 43.89762205553578 ms
Multiplexed 3 streams, 4 steps: 35.296148777787394 ms
Multiplexed 3 streams, 5 steps: 43.86552161112276 ms


In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("X = a * X + (b * A + c * A @ A) @ X")
run_benchmarks()



Single, 4 steps: 13330.02134997514 us
Single, 5 steps: 16657.34174998761 us
Multiple, 4 steps: 28866.108499995687 us
Multiple, 5 steps: 36011.61483331655 us
Multiplexed 2 streams, 4 steps: 28352.225111120788 us
Multiplexed 2 streams, 5 steps: 35578.03144440186 us
Multiplexed 3 streams, 4 steps: 28533.46405557507 us
Multiplexed 3 streams, 5 steps: 35852.94805558684 us


In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = (a * I + A @ (b * I + c * A)) @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("X = (a * I + A @ (b * I + c * A)) @ X")
run_benchmarks()



Single, 4 steps: 12.982452750020457 ms
Single, 5 steps: 16.230022550053036 ms
Multiple, 4 steps: 29.205880777782212 ms
Multiple, 5 steps: 36.75411722224453 ms
Multiplexed 2 streams, 4 steps: 28.635804111066438 ms
Multiplexed 2 streams, 5 steps: 35.807331499957705 ms
Multiplexed 3 streams, 4 steps: 28.552707444456853 ms
Multiplexed 3 streams, 5 steps: 35.63938483335328 ms


In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        S = A @ (b * I + c * A)
        torch.diagonal(S).add_(a)
        X = S @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("w/ S")
run_benchmarks()

  check(


Single, 4 steps: 13.24205756000083 ms
Single, 5 steps: 16.41230694999649 ms
Multiple, 4 steps: 29.21538677780215 ms
Multiple, 5 steps: 37.094243277756 ms
Multiplexed 2 streams, 4 steps: 28.628497833324218 ms
Multiplexed 2 streams, 5 steps: 35.851054277777116 ms
Multiplexed 3 streams, 4 steps: 28.66707005558864 ms
Multiplexed 3 streams, 5 steps: 35.92424205554481 ms


In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for a, b, c in (
        (4.8969, -14.0610, 10.1415),
        (4.7285, -10.0664, 5.4487),
        (4.0968, -5.9557, 2.3200),
        (3.0319, -3.3993, 1.1814),
    ):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("Original - 4-step")
run_benchmarks()



Single, 4 steps: 18210.466850041485 us
Single, 5 steps: 18138.601850023402 us
Multiple, 4 steps: 35475.30911110799 us
Multiple, 5 steps: 36041.44805553409 us
Multiplexed 2 streams, 4 steps: 35088.7146666739 us
Multiplexed 2 streams, 5 steps: 35514.98688890812 us
Multiplexed 3 streams, 4 steps: 35314.020166677234 us
Multiplexed 3 streams, 5 steps: 35532.948000006094 us


In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    a, b, c = (3.4445, -4.7750, 2.0315)
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for a, b, c in (
        (4.8969, -14.0610, 10.1415),
        (4.7285, -10.0664, 5.4487),
        (4.0968, -5.9557, 2.3200),
        (3.0319, -3.3993, 1.1814),
    ):
        A = X @ X.T
        X = (a * I + b * A @ (I + c/b * A)) @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("X = (a * I + b * A @ (I + c/b * A)) @ X - 4-step")
run_benchmarks()



Single, 4 steps: 13137.565900024128 us
Single, 5 steps: 13093.40089997022 us
Multiple, 4 steps: 13150.539777775015 us
Multiple, 5 steps: 13062.95088887863 us
Multiplexed 2 streams, 4 steps: 12591.63422221516 us
Multiplexed 2 streams, 5 steps: 12495.095277775667 us
Multiplexed 3 streams, 4 steps: 12366.734222218332 us
Multiplexed 3 streams, 5 steps: 12408.445333322663 us


In [None]:
# reduce-overhead is slower
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int):
    n, m = G.size()
    X = G.bfloat16()
    I = torch.eye(min(n, m), dtype=X.dtype, device=X.device)
    X.div_(X.norm() + 1e-7)
    if n > m:
        X = X.T
    for a, b, c in (
        (4.8969, -14.0610, 10.1415),
        (4.7285, -10.0664, 5.4487),
        (4.0968, -5.9557, 2.3200),
        (3.0319, -3.3993, 1.1814),
    ):
        A = X @ X.T
        S = A @ (b * I + c * A)
        torch.diagonal(S).add_(a)
        X = S @ X
    if G.size(0) > G.size(1):
        X = X.T
    return X


print("w/ S - 4-step")
run_benchmarks()

In [None]:
def zeropower_via_svd(G, steps=None, eps=1e-7):
    U, _, V = G.svd()
    return U @ V.T


@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7, coeffs=(3.4445, -4.7750,  2.0315)):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' \sim Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = coeffs
    X = G.bfloat16()
    X /= (X.norm() + eps) # ensure top singular value <= 1
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X

zeropower_backends = dict(svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5)

class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=3e-4, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5, backend_coeffs=(3.4445, -4.7750,  2.0315),
                 rank=0, world_size=1):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps, backend_coeffs=backend_coeffs)
        super().__init__(params, defaults)
        self.rank = rank
        self.world_size = world_size

    @torch.compile
    def step(self):

        for group in self.param_groups:

            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]

            # generate weight updates in distributed fashion
            total_params = sum(p.numel() for p in group['params'])
            updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
            curr_idx = 0
            for i, p in enumerate(group['params']):
                # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
                if i % self.world_size == self.rank:
                    g = p.grad
                    if g is None:
                        continue
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    if group['nesterov']:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    g *= max(g.size(0), g.size(1))**0.5 # scale to have update.square().mean() == 1
                    updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
                curr_idx += p.numel()

            # sync updates across devices. we are not memory-constrained so can do this simple deserialization
            dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            # deserialize and apply updates
            curr_idx = 0
            for p in group['params']:
                g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()

In [4]:
class Rotary(torch.nn.Module):

    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq).to(x.device)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4 # multihead attention
    d = x.shape[3]//2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).type_as(x)

class CausalSelfAttention(torch.nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = torch.nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = torch.nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = torch.nn.Linear(self.n_embd, self.n_embd, bias=False)
        # output projection
        self.c_proj = torch.nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        cos, sin = self.rotary(q)
        q, k = F.rms_norm(q, (q.size(-1),)), F.rms_norm(k, (k.size(-1),)) # QK norm suggested by @Grad62304977
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        y = y.transpose(1, 2).contiguous().view_as(x) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc    = torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj  = torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x

class Block(torch.nn.Module):

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
        x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
        return x

In [5]:
@dataclass
class GPTConfig:
    vocab_size : int = 50304
    n_layer : int = 12
    n_head : int = 6 # head dim 128 suggested by @Grad62304977
    n_embd : int = 768
    sequence_length : int = 1024

class GPT(torch.nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = torch.nn.ModuleDict(dict(
            wte = torch.nn.Embedding(config.vocab_size, config.n_embd),
            h = torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = torch.nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

    def forward(self, idx, targets=None, return_logits=True):

        # forward the GPT model itself
        x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            logits = logits.float() # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            logits = logits.float() # use tf32/fp32 for logits
            loss = None

        # there are performance reasons why not returning logits is prudent, if not needed
        if not return_logits:
            logits = None

        return logits, loss

In [6]:
num_vocab = 50304
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))

In [7]:
model = model.cuda()
model = torch.compile(model)
model

OptimizedModule(
  (_orig_mod): GPT(
    (transformer): ModuleDict(
      (wte): Embedding(50304, 768)
      (h): ModuleList(
        (0-11): 12 x Block(
          (attn): CausalSelfAttention(
            (c_q): Linear(in_features=768, out_features=768, bias=False)
            (c_k): Linear(in_features=768, out_features=768, bias=False)
            (c_v): Linear(in_features=768, out_features=768, bias=False)
            (c_proj): Linear(in_features=768, out_features=768, bias=False)
            (rotary): Rotary()
          )
          (mlp): MLP(
            (c_fc): Linear(in_features=768, out_features=3072, bias=False)
            (c_proj): Linear(in_features=3072, out_features=768, bias=False)
          )
        )
      )
    )
    (lm_head): Linear(in_features=768, out_features=50304, bias=False)
  )
)

In [8]:
model_param_size = sum(p.numel() for p in model.parameters())
print(f"Num params: {model_param_size}")

Num params: 123568128


In [9]:
optimizer = Muon(model.transformer.h.parameters(),           lr=0.02,  momentum=0.95)

In [10]:
x = torch.arange(10).unsqueeze(0).expand(2, -1)
y = torch.arange(20)

In [11]:
_, loss = model(x, y)

TorchRuntimeError: Failed running call_function <function cross_entropy at 0x7f507b23e700>(*(FakeTensor(..., device='cuda:0', size=(20, 50304), grad_fn=<ViewBackward0>), FakeTensor(..., size=(20,), dtype=torch.int64)), **{'ignore_index': -1}):
Unhandled FakeTensor Device Propagation for aten.gather.default, found two different devices cuda:0, cpu

from user code:
   File "/tmp/ipykernel_3468/802198899.py", line 34, in forward
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True


In [28]:
loss.backward()

In [29]:
optimizer.step()

KeyboardInterrupt: 

In [20]:
learning_rate = 0.001
weight_decay = 0.9
ddp_rank = 1
ddp_world_size = 2

In [21]:
square_params = [
    param
    for param in model.transformer.h.parameters()
    if param.shape[0] == param.shape[1]
]
nonsquare_params = [
    param
    for param in model.transformer.h.parameters()
    if param.shape[0] != param.shape[1]
]

In [22]:
optimizer1 = torch.optim.AdamW(model.lm_head.parameters(), lr=learning_rate, betas=(0.9, 0.95),
                               weight_decay=weight_decay, fused=True)
optimizer2 = Muon(square_params, lr=0.1*learning_rate, momentum=0.95,
                  rank=ddp_rank, world_size=ddp_world_size)
optimizer3 = Muon(nonsquare_params, lr=0.1*learning_rate, momentum=0.95,
                  rank=ddp_rank, world_size=ddp_world_size,
                  backend_steps=3, backend_coeffs=(3.5981, -5.1223,  2.2324))
optimizers = [optimizer1, optimizer2, optimizer3]

In [111]:
nonsquare_param_size = sum(p.numel() for p in nonsquare_params)
print(f"Num params: {nonsquare_param_size} or {nonsquare_param_size/model_param_size:.2f}")

Num params: 56623104 or 0.46


In [112]:
square_param_size = sum(p.numel() for p in square_params)
print(f"Num params: {square_param_size} or {square_param_size/model_param_size:.2f}")

Num params: 28311552 or 0.23


## MISC

In [None]:
key = jax.random.PRNGKey(0)

for N, M in MATRIX_SHAPES:
    errors = run_experiment(key, N, M, iterator_newton_schulz, 10, 10)
    plt.plot(errors, label=f"{N}x{M}")

plt.xlabel("Iteration steps")
plt.ylabel("Mean squared distance from 1.0")
plt.ylim(0, 1)
plt.legend()
plt.show()

In [None]:
key = jax.random.PRNGKey(0)

for N, M in MATRIX_SHAPES:
    errors = run_experiment(key, N, M, iterator_keller_jordan_opt, 10, 10)
    print(f"{N}x{M}: {errors[3]}")
    plt.plot(errors, label=f"{N}x{M}")

plt.xlabel("Iteration steps")
plt.ylabel("Mean squared distance from 1.0")
plt.ylim(0, 1)
plt.legend()
plt.show()

In [None]:
key = jax.random.PRNGKey(0)

gamma = 2.5
basin_radius = 0.15
roots = [
    1 - basin_radius,
    1 + basin_radius,
    0,
    -1 - basin_radius,
    -1 + basin_radius
]
h = lambda x: x + gamma * (x - roots[0])*(x - roots[1])*(x - roots[2])*(x - roots[3])*(x - roots[4])

for N, M in MATRIX_SHAPES:
    errors = run_experiment(key, N, M, h, 10, 10)
    print(f"{N}x{M}: {errors[3]}, {errors[5]}, {errors[7]}")
    plt.plot(errors, label=f"{N}x{M}")

plt.xlabel("Iteration steps")
plt.ylabel("Mean squared distance from 1.0")
plt.ylim(0, 1)
plt.legend()
plt.show()

In [None]:
peak, trough = get_odd_degree_iterator_stats(2.4, 0.13)
peak, trough

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


x = jnp.linspace(0, 1.2, 100)

plt.plot(x, x, label=f"$y = x$")

# y = g(x)
# plt.plot(x, y, label="g(x)")

# dy = jax.vmap(jax.grad(g))(x)
# plt.plot(x, dy, label="g'(x)")

for gamma, r in [[-0.7530, 0.13], [-0.9223, 0.11], [2.4, 0.13], [1.2284936824712378, 0.05], [1.25, 0]]:
    b1 = 1 - r
    b2 = 1 + r
    h = lambda x: gamma * (x - b1)*(x - b2)*(x)*(x + b1)*(x + b2) + x

    y = h(x)
    plt.plot(x, y, label=f"$h(x)$ w/ $\gamma = {gamma}$, $b1 = {b1}$, $b2 = {b2}$")

    # dy = jax.vmap(jax.grad(h))(x)
    # plt.plot(x, dy, label=f"$h'(x)$ w/ $\gamma = {gamma}$, $b1 = {b1}$, $b2 = {b2}$")

plt.legend()
plt.show()

In [None]:
x = np.linspace(-1.3, 1.3, 100)
y = g(x)

plt.plot(x, x, label='y=x', color="black", linestyle="--")
plt.plot(x, y, label='y=g(x)', color="blue")

for i, (in_features, out_features) in enumerate(MATRIX_SHAPES[-1:]):
    y_ = h_iterators[i](x)
    plt.plot(x, y_, label=f"y=h(x) for {out_features}x{in_features}")

gamma = -3
r = 0.11
fp = [
    0,
    1 - r,
    1 + r,
    -1 - r,
    -1 + r,
    -1,
    1,
]
h = lambda x: x + gamma * (x - fp[0])*(x - fp[1])*(x - fp[2])*(x - fp[3])*(x - fp[4])*(x - fp[5])*(x - fp[6])

y_ = h(x)
plt.plot(x, y_, label=f"test")

plt.xlim(-(1.26373+0.1), 1.26373+0.1)
plt.ylim(-(1.26373+0.1), 1.26373+0.1)
plt.grid()
plt.legend()

In [None]:
in_features, out_features = 2**14, 768
K = jnp.sqrt(1/in_features)

X = jax.random.uniform(
    key,
    shape=(out_features, in_features),
    minval=-K,
    maxval=K,
)

In [None]:
S = jnp.linalg.svd(X, compute_uv=False)
jnp.median(S)

In [None]:
gamma, r = 0.8, 0.11
fp = [1+r, 1-r, -1-r, -1+r]
h = lambda x: x + gamma * (x - fp[0])*(x - fp[1])*(x - fp[2])*(x - fp[3])

In [None]:
# a, b, c = 3.4445, -4.7750, 2.0315
a, b, c = 3.3196, -4.8811, 2.4

In [None]:
X_1 = X / (jnp.linalg.norm(X, ord='fro') + 1e-6)

S_1 = jnp.linalg.svd(X_1, compute_uv=False)
print(jnp.median(S_1))

for _ in range(5):
    A = X_1 @ X_1.T
    B = A @ X_1
    X_1 = a * X_1 + b * B + c * A @ B
    S_1 = jnp.linalg.svd(X_1, compute_uv=False)
    print(jnp.median(S_1))

In [None]:
X_1 = X / (jnp.linalg.norm(X) + 1e-6)

S_1 = jnp.linalg.svd(X_1, compute_uv=False)
print(jnp.median(S_1))

for _ in range(5):
    S_1 = h(S_1)
    print(jnp.median(S_1))

In [None]:
in_features, out_features = 768, 768
K = jnp.sqrt(1/in_features)

X = jax.random.uniform(
    key,
    shape=(out_features, in_features),
    minval=-K,
    maxval=K,
)

In [None]:
S = jnp.linalg.svd(X, compute_uv=False)
jnp.median(S)

In [None]:
gamma, r = 0.8, 0.11
fp = [1+r, 1-r, -1+r, -1-r]
h = lambda x: x + gamma * (x - fp[0])*(x - fp[1])*(x - fp[2])*(x-fp[3])

In [None]:
X_1 = X / (jnp.linalg.norm(X, ord='fro') + 1e-6)

S_1 = jnp.linalg.svd(X_1, compute_uv=False)
print(jnp.median(S_1))

for i in range(5):
    if i % 2 == 0:
        A = X_1 @ X_1.T
    else:
        A = X_1.T @ X_1
        X_1 = X_1.T
    X_1 = A @ A.T - 2.0242 * A + X_1 + 0.97594641
    S_1 = jnp.linalg.svd(X_1, compute_uv=False)
    print(jnp.median(S_1))

In [None]:
X_1 = X / (jnp.linalg.norm(X) + 1e-6)

S_1 = jnp.linalg.svd(X_1, compute_uv=False)
print(jnp.median(S_1))

for _ in range(5):
    S_1 = h(S_1)
    print(jnp.median(S_1))

In [None]:
import sympy as sp


gamma = sp.Symbol("gamma", positive=True)
r = sp.Symbol("r", interval=(0, 1), left_open=False, right_open=True)
x = sp.Symbol("x")

In [None]:
# fp = [1 - r, 1 + r, 0, -(1 + r), -(1 - r)]
# h = x + gamma * (x - fp[0])*(x - fp[1])*(x - fp[2])*(x - fp[3])*(x - fp[4])
fp = [1 - r, 1 + r, -1-r, -1+r]
h = x + gamma * (x - fp[0])*(x - fp[1])*(x - fp[2])*(x-fp[3])
h

In [None]:
h_simplified = sp.collect(sp.expand(h), x)
h_simplified = h_simplified.refine(sp.Q.positive(gamma)).refine(sp.Q.positive(r))
h_simplified

In [None]:
h_simplified.subs({gamma: 0.8, r: 0.11})

In [None]:
sp.collect(
    sp.expand(
        x*(gamma*(x - 1)*(x**2 + (-r**2 - 1)) + (-2*gamma*r**2+1))
    ),
    x,
)

In [None]:
h_div = sp.collect(sp.simplify((h_simplified - x) / (x - 1)), x)
h_div.subs({gamma: 0.8, r: 0.11})

In [None]:
a, b, c = sp.symbols("a b c")

f = (a*x + b*x**3 + c*x**5).subs({a: 3.4445, b: -4.7750, c: 2.0315})

dfdx = sp.simplify(sp.expand(sp.diff(f, x, real=True)))
dfdx = dfdx.refine(sp.Q.real(x)).refine(sp.Q.real(gamma)).refine(sp.Q.real(r))

all_roots = sp.nroots(dfdx)
real_roots = [root.evalf() for root in all_roots if root.as_real_imag()[1] == 0]
for root in real_roots:
    print(f.subs({x: root}))

In [None]:
sp.collect(sp.expand(h_exp), x)

In [None]:
sp.collect(sp.collect(sp.expand(h_exp), x), x**2)

In [None]:
fp = [1 - r, 1 + r, -1 - r, -1 + r]
h_exp = x + gamma * (x - fp[0])*(x - fp[1])*(x - fp[2])*(x - fp[3])
sp.collect(sp.expand(h_exp), x)

In [None]:
h1_exp = h_exp.subs({
    gamma: best_hyperparams[3][0],
    r: best_hyperparams[3][1],
})
sp.collect(sp.expand(h1_exp), x)

Optimizer (1D) | Steepest Descent in Norm
--- | ---
SGD | Euclidean norm ($\|\cdot\|_2$)
Adam w/o accumulation | $\infty$-norm ($\|\cdot\|_\infty$)
Adam | Dynamically-learned norm from $\infty$-norm

Optimizer (2D) | Steepest Descent in Norm
--- | ---
SGD | Euclidean norm = Frobenius norm ($\|\|\cdot\|\|_F$) = Schatten-$2$ norm ($\|\|\cdot\|\|_{S_2}$)
Adam w/o accumulation | Max-of-Max norm ($\|\|W\|\| = \max_{l}\max_{r} \| M_{l,r} \|$)
Adam | Dynamically-learned norm from Max-of-Max norm
Shampoo w/o accumulation | Spectral norm = Schatten-$\infty$ norm ($\|\|\cdot\|\|_{S_\infty}$)
Shampoo | Dynamically-learned (approx.) $S_p$-norm from Spectral norm
SOAP | Same as above, but with momentum on $\Delta p$
**Muon** | **Static range of $S_p$-norms for large $p$**
Shape-Aware Muon (mine) | Same as above, but range of $p$ depends on matrix shape

### Marchenko–Pastur distribution

Let $X \in R^{d_{out} \times d_{in}}$ be a random matrix with i.i.d. entries drawn from $\mathcal{N}(0, \sigma^2)$.

Then, the lower and upper bounds of the Marchenko-Pastur distribution are $\lambda_{\pm} = \sigma^2(1 \pm \sqrt{M/N})^2$.

### Expected Frobenius Norm of Random Matrix

Recall that if each $x_i$ is drawn from $\mathcal{N}(0, \sigma^2)$, then $\mathbb{E}[|x_i|^2] = \sigma^2$. Thus,

$$\mathbb{E}[\|X\|_F^2] = \Sigma_{c=1}^{M}\Sigma_{r=1}^{N} \mathbb{E}[|a_{ij}|^2] = MN\sigma^2$$
$$\mathbb{E}[\|X\|_F] = \sigma\sqrt{MN}$$

### Expected Initial range of Singular Values

$$\lambda'_{\pm} = \frac{\sigma^2(1 \pm \sqrt{M/N})^2}{\sigma\sqrt{MN}}$$
$$\lambda'_{\pm} = \sigma\frac{(1 \pm \sqrt{M/N})^2}{\sqrt{MN}}$$
$$\lambda'_{\pm} = \sigma\frac{1 \pm 2\sqrt{M/N} + M/N}{\sqrt{MN}}$$

In [None]:
import sympy as sp

sigma = sp.Symbol("sigma")
M = sp.Symbol("M", positive=True)
N = sp.Symbol("N", positive=True)
k = sp.Symbol("k", positive=True)

In [None]:
N <= k

N <= k

In [None]:
lambda_min = sigma * (1 - 2*sp.sqrt(M/N) + M/N) / sp.sqrt(M*N)
lambda_max = sigma * (1 + 2*sp.sqrt(M/N) + M/N) / sp.sqrt(M*N)
lambda_min

In [None]:
sp.radsimp(sp.simplify(sp.solve(lambda_min - 1/2, sigma)[0].subs({M: k*N})))

In [None]:
lambda_min.subs({M: k*N})

In [None]:
lambda_min.subs({M: 4*N, sigma: N})

In [None]:
lambda_min.subs({M: k*N, k: 1}), lambda_min.subs({M: k*N, k: 1/4}), lambda_min.subs({M: k*N, k: 4})

In [None]:
lambda_max.subs({M: k*N, k: 1}), lambda_max.subs({M: k*N, k: 1/4}), lambda_max.subs({M: k*N, k: 4})

In [None]:
lambda_max.subs({M: 4*N, sigma: N/9})

$$\|W\|_{\alpha \to \beta} = \max_{x \in \mathbb{R}^{d_{in}}} \frac{\|Wx\|_\beta}{\|x\|_\alpha}$$

$$\|W\|_{2 \to 2} = \max_{x \in \mathbb{R}^{d}} \frac{\|Wx\|_2}{\|x\|_2} = \|W\|_{\text{spectral}} = \|W\|_{S_{\infty}}$$