# Optimizers

Implementations of SGD, Momentum, etc., up to AdamW and Muon.

## SGD (Stochastic Gradient Descent)

**Idea:** Take a step in the negative gradient direction.

```
p = p - lr * g
```

That's it. Simple, but struggles with noisy gradients and ill-conditioned landscapes.

In [None]:
import torch

In [None]:
class SDG(torch.optim.Optimizer):
    """Plain stochastic gradient descent with optional coupled (L2) weight decay.

    Per-parameter update: ``p <- p - lr * (g + weight_decay * p)``.

    NOTE: the class name is kept as ``SDG`` so existing callers keep working.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        weight_decay: L2 coefficient added to the gradient (0 disables it).
    """

    def __init__(self, params, lr=0.01, weight_decay=0.01):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled, before the
                parameters are updated.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if weight_decay != 0.0:
                    # Equivalent to: loss = ... + weight_decay / 2 * W**2.
                    # Out of place on purpose so p.grad is left untouched.
                    grad = grad.add(p, alpha=weight_decay)
                p.add_(grad, alpha=-lr)
        return loss


In [None]:
# Equivalence check: run torch.optim.SGD and the custom SDG side by side on
# identical weights and require bit-identical parameters after every step.

# Seed (fixes the random draws below)
torch.manual_seed(42)

# Input batch and regression target
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Create identical weights (W2 is an independent leaf copy of W1)
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

# Params
lr = 0.02

# Optimizers: reference implementation vs. custom implementation
opt_torch = torch.optim.SGD([W1], lr=lr, weight_decay=0.01)
opt_custom = SDG([W2], lr=lr, weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    # Same mean-squared-error loss, evaluated separately for each weight copy
    loss1 = ((x @ W1.T - target) ** 2).mean()
    loss2 = ((x @ W2.T - target) ** 2).mean()
    loss1.backward()
    loss2.backward()

    opt_torch.step()
    opt_custom.step()

    # Exact equality is required, not just closeness
    weight_max_diff = (W1 - W2).abs().max().item()
    assert weight_max_diff == 0.0
print(f"All good after {i+1} iterations!")

## SGD with Momentum

**Idea:** Accumulate gradients over time into a "velocity." Smooths out noise, builds up speed in consistent directions.

```
v = B * v + g
p = p - lr * v
```

Typical B = 0.9 (averages ~10 steps). Steady-state velocity is g/(1-B), so effective step is larger than vanilla SGD.

In [None]:
import torch

In [None]:
class SDGMomentum(torch.optim.Optimizer):
    """SGD with (heavy-ball) momentum and optional coupled (L2) weight decay.

    Per-parameter update::

        v <- momentum * v + (g + weight_decay * p)
        p <- p - lr * v

    NOTE: the class name is kept as ``SDGMomentum`` so existing callers keep
    working.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        momentum: Decay factor applied to the velocity buffer.
        weight_decay: L2 coefficient added to the gradient (0 disables it).
    """

    def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=0.01):
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled, before the
                parameters are updated.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Lazy init: the velocity starts at zero on the first step.
                state = self.state[p]
                if not state:
                    state['momentum_buffer'] = torch.zeros_like(p)

                grad = p.grad
                if weight_decay != 0.0:
                    # Equivalent to: loss = ... + weight_decay / 2 * W**2.
                    # Out of place on purpose so p.grad is left untouched.
                    grad = grad.add(p, alpha=weight_decay)

                v = state['momentum_buffer']
                v.mul_(momentum).add_(grad)
                p.add_(v, alpha=-lr)
        return loss

In [None]:
# Equivalence check: torch.optim.SGD with momentum vs. the custom SDGMomentum.
# Both the momentum buffers and the weights must match exactly at every step.

# Seed (fixes the random draws below)
torch.manual_seed(42)

# Input batch and regression target
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Create identical weights (W2 is an independent leaf copy of W1)
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

# Params
lr = 0.02

# Optimizers: reference implementation vs. custom implementation
opt_torch = torch.optim.SGD([W1], lr=lr, momentum=0.9, weight_decay=0.01)
opt_custom = SDGMomentum([W2], lr=lr, momentum=0.9, weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    # Same mean-squared-error loss, evaluated separately for each weight copy
    loss1 = ((x @ W1.T - target) ** 2).mean()
    loss2 = ((x @ W2.T - target) ** 2).mean()
    loss1.backward()
    loss2.backward()

    opt_torch.step()
    opt_custom.step()

    # The reference optimizer is expected to keep exactly one state entry
    state1 = opt_torch.state[W1]
    state2 = opt_custom.state[W2]
    assert list(state1.keys()) == ['momentum_buffer']
    assert torch.equal(state1['momentum_buffer'], state2['momentum_buffer'])

    # Exact equality is required, not just closeness
    weight_max_diff = (W1 - W2).abs().max().item()
    assert weight_max_diff == 0.0
    print(f"Diff {weight_max_diff}")

print(f"All good after {i+1} iterations!")

## RMSprop

**Idea:** Fix AdaGrad by using EMA instead of sum. Old gradients decay away.

```
s = B * s + (1 - B) * g^2
p = p - lr * g / (sqrt(s) + eps)
```

Typical B = 0.99. Learning rate stabilizes instead of decaying to zero.

In [None]:
import torch

In [None]:
class RMSProp(torch.optim.Optimizer):
    """RMSprop: scale each gradient by a running RMS of its recent magnitude.

    Per-parameter update (``g`` includes the coupled weight-decay term)::

        s <- alpha * s + (1 - alpha) * g * g
        p <- p - lr * g / (sqrt(s) + eps)

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        alpha: Decay factor of the squared-gradient moving average.
        eps: Constant added to the denominator for numerical stability.
        weight_decay: L2 coefficient added to the gradient (0 disables it).
    """

    def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, weight_decay=0.01):
        defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled, before the
                parameters are updated.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            alpha = group['alpha']
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Lazy init: the squared-gradient average starts at zero.
                state = self.state[p]
                if not state:
                    state['square_avg'] = torch.zeros_like(p)

                grad = p.grad
                if weight_decay != 0.0:
                    # Equivalent to: loss = ... + weight_decay / 2 * W**2.
                    # Out of place on purpose so p.grad is left untouched.
                    grad = grad.add(p, alpha=weight_decay)

                # s = alpha * s + (1 - alpha) * grad * grad
                s = state['square_avg']
                s.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                # p = p - lr * grad / (sqrt(s) + eps)
                avg = s.sqrt().add_(eps)
                p.addcdiv_(grad, avg, value=-lr)
        return loss

In [None]:
# Equivalence check: torch.optim.RMSprop vs. the custom RMSProp.
# The squared-gradient averages and the weights must match exactly.

# Seed (fixes the random draws below)
torch.manual_seed(42)

# Input batch and regression target
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Create identical weights (W2 is an independent leaf copy of W1)
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

# Params
lr = 0.02

# Optimizers: reference implementation vs. custom implementation
opt_torch = torch.optim.RMSprop([W1], lr=lr, alpha=0.99, weight_decay=0.01)
opt_custom = RMSProp([W2], lr=lr, alpha=0.99, weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    # Same mean-squared-error loss, evaluated separately for each weight copy
    loss1 = ((x @ W1.T - target) ** 2).mean()
    loss2 = ((x @ W2.T - target) ** 2).mean()
    loss1.backward()
    loss2.backward()

    opt_torch.step()
    opt_custom.step()

    # The reference optimizer is expected to also hold a 'step' entry;
    # only 'square_avg' is compared against the custom state.
    state1 = opt_torch.state[W1]
    state2 = opt_custom.state[W2]
    assert list(state1.keys()) == ['step', 'square_avg']
    assert torch.equal(state1['square_avg'], state2['square_avg'])

    # Exact equality is required, not just closeness
    weight_max_diff = (W1 - W2).abs().max().item()
    assert weight_max_diff == 0.0
    print(f"Diff {weight_max_diff}")

print(f"All good after {i+1} iterations!")

## Adam

**Idea:** Combine momentum (first moment) with RMSprop (second moment). Add bias correction for early steps.

```
v = B1 * v + (1 - B1) * g          # first moment (direction)
s = B2 * s + (1 - B2) * g^2        # second moment (scaling)

v_corrected = v / (1 - B1^t)       # bias correction
s_corrected = s / (1 - B2^t)

p = p - lr * v_corrected / (sqrt(s_corrected) + eps)
p = p - lr * wd * p                # AdamW: decoupled weight decay
```

Typical: B1 = 0.9, B2 = 0.999, eps = 1e-8.

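Bias correction compensates for zero initialization. The correction factor 1 - B^t approaches 1 as t grows: within ~50 steps for B1 = 0.9, but only after several thousand steps for B2 = 0.999 (at t = 1000 it is still ≈ 0.63).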

In [None]:
import torch

In [None]:
class Adam(torch.optim.Optimizer):
    """Adam with either coupled (L2) or decoupled (AdamW-style) weight decay.

    Per-parameter update at step ``t``::

        v <- B1 * v + (1 - B1) * g
        s <- B2 * s + (1 - B2) * g * g
        p <- p - lr * (v / (1 - B1**t)) / (sqrt(s / (1 - B2**t)) + eps)

    With ``decoupled_weight_decay=False`` the term ``weight_decay * p`` is
    added to ``g``; with ``True`` the parameter is instead shrunk by
    ``1 - lr * weight_decay`` before the Adam update.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate.
        betas: ``(B1, B2)`` decay factors of the first/second moment averages.
        eps: Constant added to the denominator for numerical stability.
        weight_decay: Weight-decay coefficient (0 disables it).
        decoupled_weight_decay: Select AdamW-style decay instead of L2.
    """

    def __init__(self, params, lr=0.01, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, decoupled_weight_decay=False):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, decoupled_weight_decay=decoupled_weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled, before the
                parameters are updated.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']
            decoupled = group['decoupled_weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Lazy init: step counter and both moment estimates start at zero.
                state = self.state[p]
                if not state:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)
                state['step'] += 1
                t = state['step']

                grad = p.grad
                if weight_decay != 0.0:
                    if decoupled:
                        # AdamW: p = p - lr * weight_decay * p
                        p.mul_(1 - lr * weight_decay)
                    else:
                        # Adam: equivalent to loss = ... + weight_decay / 2 * W**2.
                        # Out of place on purpose so p.grad is left untouched.
                        grad = grad.add(p, alpha=weight_decay)

                # v = B1 * v + (1 - B1) * g
                v = state['exp_avg']
                v.lerp_(grad, 1 - beta1)

                # s = B2 * s + (1 - B2) * g**2
                s = state['exp_avg_sq']
                s.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction folded into the step size and denominator:
                #   p = p - lr * (v / bias1) / (sqrt(s / bias2) + eps)
                bias1 = 1 - beta1 ** t
                bias2_sqrt = (1 - beta2 ** t) ** 0.5
                denom = (s.sqrt() / bias2_sqrt).add_(eps)
                p.addcdiv_(v, denom, value=-lr / bias1)
        return loss



In [None]:
# Equivalence check: torch.optim.Adam vs. the custom Adam (coupled weight decay).
# Step counter, both moment estimates and the weights must match exactly.

# Seed (fixes the random draws below)
torch.manual_seed(42)

# Input batch and regression target
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Create identical weights (W2 is an independent leaf copy of W1)
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

# Params
lr = 0.02

# Optimizers: reference implementation vs. custom implementation
opt_torch = torch.optim.Adam([W1], lr=lr, betas=(0.9, 0.999), weight_decay=0.01)
opt_custom = Adam([W2], lr=lr, betas=(0.9, 0.999), weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    # Same mean-squared-error loss, evaluated separately for each weight copy
    loss1 = ((x @ W1.T - target) ** 2).mean()
    loss2 = ((x @ W2.T - target) ** 2).mean()
    loss1.backward()
    loss2.backward()

    opt_torch.step()
    opt_custom.step()

    # Compare the full optimizer state entry by entry
    state1 = opt_torch.state[W1]
    state2 = opt_custom.state[W2]
    assert list(state1.keys()) == ['step', 'exp_avg', 'exp_avg_sq']
    assert state1['step'] == state2['step']
    assert torch.equal(state1['exp_avg'], state2['exp_avg'])
    assert torch.equal(state1['exp_avg_sq'], state2['exp_avg_sq'])

    # Exact equality is required, not just closeness
    weight_max_diff = (W1 - W2).abs().max().item()
    assert weight_max_diff == 0.0
    print(f"Diff {weight_max_diff}")

print(f"All good after {i+1} iterations!")

In [None]:
# Equivalence check: torch.optim.AdamW vs. the custom Adam with
# decoupled_weight_decay=True. State and weights must match exactly.

# Seed (fixes the random draws below)
torch.manual_seed(42)

# Input batch and regression target
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Create identical weights (W2 is an independent leaf copy of W1)
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

# Params
lr = 0.02

# Optimizers: reference implementation vs. custom implementation
opt_torch = torch.optim.AdamW([W1], lr=lr, betas=(0.9, 0.999), weight_decay=0.01)
opt_custom = Adam([W2], lr=lr, betas=(0.9, 0.999), weight_decay=0.01, decoupled_weight_decay=True)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    # Same mean-squared-error loss, evaluated separately for each weight copy
    loss1 = ((x @ W1.T - target) ** 2).mean()
    loss2 = ((x @ W2.T - target) ** 2).mean()
    loss1.backward()
    loss2.backward()

    opt_torch.step()
    opt_custom.step()

    # Compare the full optimizer state entry by entry
    state1 = opt_torch.state[W1]
    state2 = opt_custom.state[W2]
    assert list(state1.keys()) == ['step', 'exp_avg', 'exp_avg_sq']
    assert state1['step'] == state2['step']
    assert torch.equal(state1['exp_avg'], state2['exp_avg'])
    assert torch.equal(state1['exp_avg_sq'], state2['exp_avg_sq'])

    # Exact equality is required, not just closeness
    weight_max_diff = (W1 - W2).abs().max().item()
    assert weight_max_diff == 0.0
    print(f"Diff {weight_max_diff}")

print(f"All good after {i+1} iterations!")

## Newton-Schulz Iteration

**Not an optimizer**—a subroutine to orthogonalize a matrix cheaply.

**Problem:** Given matrix G, find closest orthogonal matrix U. SVD works but is O(n³) and slow on GPU.

**Solution:** Iterative approximation.

```
X = G / ||G||                        # scale so singular values < 1

repeat 5 times:
    X = 1.5 * X - 0.5 * X @ X.T @ X

return X
```

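Converges to an orthogonal matrix. Each iteration is just matrix multiplies, which makes it GPU-friendly. The implementation below uses a quintic variant of this update, `X = a*X + (b*A + c*A@A) @ X` with `A = X @ X.T`, instead of the cubic one shown here.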

For non-square matrices, adjust multiplication order based on shape (orthonormalize the smaller dimension).

In [None]:
import torch

In [None]:
def zeropower_via_newtonschulz(grad, steps=5, eps=1e-7):
    """Approximately orthogonalize a 2D tensor with a Newton-Schulz iteration.

    The input is cast to bfloat16, scaled to at most unit Frobenius norm and
    then refined ``steps`` times with the quintic update
    ``X <- a*X + (b*A + c*A@A) @ X`` where ``A = X @ X.T``. Tall inputs are
    processed transposed so that ``A`` is built on the shorter side.

    Args:
        grad: 2D tensor to orthogonalize. It is never modified.
        steps: Number of Newton-Schulz iterations.
        eps: Lower bound on the norm used for scaling, to avoid division by zero.

    Returns:
        A bfloat16 tensor with the same shape as ``grad``.

    Raises:
        AssertionError: If ``grad`` is not 2-dimensional.
    """
    assert grad.ndim == 2, "zeropower_via_newtonschulz expects a 2D tensor"
    a, b, c = 3.4445, -4.7750, 2.0315
    transposed = grad.size(0) > grad.size(1)

    X = grad.bfloat16()
    if transposed:
        X = X.T
    # Scale down to norm at most 1.
    # Out of place on purpose: .bfloat16() returns the input tensor itself
    # when it is already bfloat16, so an in-place division would overwrite
    # the caller's tensor (for example an optimizer's momentum buffer).
    X = X / X.norm().clamp(min=eps)

    for _ in range(steps):
        A = X @ X.T
        B = torch.addmm(A, A, A, beta=b, alpha=c)  # b*A + c*(A @ A)
        X = torch.addmm(X, B, X, beta=a)           # a*X + B @ X

    return X.T if transposed else X

In [None]:
# Sanity check: the Newton-Schulz output of a random tall matrix should have
# approximately orthonormal columns.
# NOTE(review): no seed is set in this cell, so the input differs between runs.
grad = torch.randn(128, 64)
X = zeropower_via_newtonschulz(grad, steps=5)
assert X.shape == grad.shape
# Short side orthogonality: X.T @ X should be ~I since n < m
# (cast to float32 before forming the Gram matrix)
eye = X.float().T @ X.float()
# off-diagonals are small
assert (eye - torch.diag(eye.diag())).abs().max() < 0.2
# diagonal is in reasonable range
assert eye.diag().min() > 0.5 and eye.diag().max() < 1.5

## Muon

**Idea:** Replace Adam's element-wise scaling with matrix-level orthogonalization. Each update is "balanced" - no direction gets special treatment.

```
p = p - lr * wd * p                # decoupled weight decay

v = B * v + (1-B) * g              # momentum or EMA equivalent, because orthogonalization wipes scale
                                   # EMA is possibly more numerically stable

vv = B * v + (1-B) * g             # optional, Nesterov look-ahead (note it's just lerp again)

U = newton_schulz(vv)              # orthogonalize

lr_adj = lr * sqrt(max(1, m/n))    # adjust for aspect ratio
p = p - lr_adj * U
```

Typical: B = 0.95, lr = 0.02, 5 Newton-Schulz iterations.

**Key insight:** Adam treats weight matrices as bags of independent numbers. Muon treats them as transformations with structure. Orthogonalization preserves the "direction" of the update while balancing magnitudes across the matrix.

**Scope:** Only applies to 2D weight matrices. Use Adam for embeddings, biases, LayerNorm params.

In [None]:
class Muon(torch.optim.Optimizer):
    """Muon: momentum followed by Newton-Schulz orthogonalization of the update.

    Per-parameter update (``m, n`` are the parameter's row/column counts)::

        p  <- p * (1 - lr * weight_decay)          # decoupled weight decay
        v  <- B * v + (1 - B) * g                  # EMA momentum
        vv <- B * v + (1 - B) * g  if nesterov else v
        U  <- zeropower_via_newtonschulz(vv, ns_steps)
        p  <- p - lr * sqrt(max(1, m / n)) * U

    Parameters must be 2D, since the orthogonalization helper only accepts
    2D tensors.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Base learning rate (scaled per parameter by its aspect ratio).
        momentum: EMA factor ``B`` of the momentum buffer.
        nesterov: Use the Nesterov look-ahead blend instead of the raw buffer.
        ns_steps: Number of Newton-Schulz iterations.
        weight_decay: Decoupled weight-decay coefficient (0 disables it).
    """

    def __init__(self, params, lr=0.01, momentum=0.95, nesterov=True, ns_steps=5, weight_decay=0.1):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled, before the
                parameters are updated.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so re-enable autograd here.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            base_lr = group['lr']
            momentum = group['momentum']
            nesterov = group['nesterov']
            ns_steps = group['ns_steps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                # Lazy init: the momentum buffer starts at zero.
                state = self.state[p]
                if not state:
                    state['momentum_buffer'] = torch.zeros_like(p)

                grad = p.grad

                # Decoupled weight decay: p = p - lr * weight_decay * p
                if weight_decay != 0:
                    p.mul_(1 - base_lr * weight_decay)

                # v = B * v + (1 - B) * g
                v = state['momentum_buffer']
                v.lerp_(grad, 1 - momentum)

                # Optional Nesterov look-ahead: vv = B * v + (1 - B) * g
                vv = grad.lerp(v, momentum) if nesterov else v

                # Orthogonalize and apply with an aspect-ratio-adjusted step size.
                update = zeropower_via_newtonschulz(vv, ns_steps)
                adjusted_lr = base_lr * (max(1, p.size(0) / p.size(1))) ** 0.5
                p.add_(update, alpha=-adjusted_lr)
        return loss

In [None]:
# Equivalence check: torch.optim.Muon (adjust_lr_fn="original") vs. the custom
# Muon. Momentum buffers and weights must match exactly at every step.

# Seed (fixes the random draws below)
torch.manual_seed(42)

# Input batch and regression target
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Create identical weights (W2 is an independent leaf copy of W1)
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

# Params (shared by both optimizers)
lr = 0.02
momentum = 0.95
nesterov = True
ns_steps = 5

# Optimizers: reference implementation vs. custom implementation
opt_torch = torch.optim.Muon([W1], lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps,
                             weight_decay=0.01, adjust_lr_fn="original")
opt_custom = Muon([W2], lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    # Same mean-squared-error loss, evaluated separately for each weight copy
    loss1 = ((x @ W1.T - target) ** 2).mean()
    loss2 = ((x @ W2.T - target) ** 2).mean()
    loss1.backward()
    loss2.backward()

    opt_torch.step()
    opt_custom.step()

    # The reference optimizer is expected to keep exactly one state entry
    state1 = opt_torch.state[W1]
    state2 = opt_custom.state[W2]
    assert list(state1.keys()) == ['momentum_buffer']
    assert torch.equal(state1['momentum_buffer'], state2['momentum_buffer'])

    # Exact equality is required, not just closeness
    weight_max_diff = (W1 - W2).abs().max().item()
    assert weight_max_diff == 0.0
    print(f"Diff {weight_max_diff}")

print(f"All good after {i+1} iterations!")