# Optimizers

Implementations of SGD, SGD with momentum, etc., up to AdamW and Muon.

## SGD (Stochastic Gradient Descent)

**Idea:** Take a step in the negative gradient direction.

```
θ = θ - lr * g
```

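That's it. Simple, but it struggles with noisy gradients and ill-conditioned loss landscapes. The implementations below also support L2 weight decay, which adds `weight_decay * θ` to `g` before the update.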

In [None]:
import torch

In [None]:
class SDG(torch.optim.Optimizer):
    """Plain stochastic gradient descent with optional L2 weight decay.

    Per parameter ``p`` with gradient ``g``::

        g = g + weight_decay * p      # only if weight_decay != 0
        p = p - lr * g

    This is the update ``torch.optim.SGD`` performs with ``momentum=0``.
    The class name looks like a typo for "SGD"; it is kept unchanged so
    existing callers keep working.

    Args:
        params: iterable of tensors (or param-group dicts) to optimize.
        lr: learning rate.
        weight_decay: L2 penalty coefficient; equivalent to adding
            ``weight_decay / 2 * ||p||**2`` to the loss.
    """

    def __init__(self, params, lr=0.01, weight_decay=0.01):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model, calls
                ``backward()`` and returns the loss (the standard
                ``torch.optim.Optimizer.step`` contract).

        Returns:
            The loss returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, but the closure must backpropagate.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if weight_decay != 0.0:
                    # Out-of-place so p.grad itself is left untouched.
                    grad = grad.add(p, alpha=weight_decay)
                # In-place update on the parameter is safe under no_grad.
                p.add_(grad, alpha=-lr)
        return loss


In [39]:
# Fix the RNG so inputs and initial weights are reproducible.
torch.manual_seed(42)

# Toy regression problem: 16 samples, 32 input features, 64 outputs.
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Two independent leaf tensors that start from the same weights.
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

lr = 0.02

# Reference PyTorch optimizer vs. the custom one, same hyperparameters.
opt_torch = torch.optim.SGD([W1], lr=lr, weight_decay=0.01)
opt_custom = SDG([W2], lr=lr, weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()

    # Identical MSE objective for each weight copy.
    loss1, loss2 = (((x @ W.T - target) ** 2).mean() for W in (W1, W2))
    loss1.backward()
    loss2.backward()

    opt_torch.step()
    opt_custom.step()

    # The custom optimizer must reproduce PyTorch bit-for-bit.
    weight_max_diff = torch.max(torch.abs(W1 - W2)).item()
    assert weight_max_diff == 0.0
print(f"All good after {i+1} iterations!")

All good after 20 iterations!


## SGD with Momentum

**Idea:** Accumulate gradients over time into a "velocity." Smooths out noise, builds up speed in consistent directions.

```
v = β * v + g
θ = θ - lr * v
```

Typical β = 0.9 (averages over roughly the last 10 steps). For a constant gradient the velocity converges to g/(1-β), so the effective step is 1/(1-β) times larger than in vanilla SGD (10× for β = 0.9).

In [40]:
import torch

In [46]:
class SDGMomentum(torch.optim.Optimizer):
    """SGD with momentum (heavy-ball or Nesterov) and optional L2 weight decay.

    Per parameter ``p`` with gradient ``g`` and velocity ``v``::

        g = g + weight_decay * p                   # only if weight_decay != 0
        v = g                                      # first step
        v = momentum * v + (1 - dampening) * g     # subsequent steps
        d = g + momentum * v  if nesterov else v
        p = p - lr * d

    With the defaults (``dampening=0``, ``nesterov=False``) this reduces to
    ``v = momentum * v + g; p = p - lr * v``, the update performed by
    ``torch.optim.SGD(momentum=...)``. The class name looks like a typo for
    "SGD"; it is kept unchanged so existing callers keep working.

    Args:
        params: iterable of tensors (or param-group dicts) to optimize.
        lr: learning rate.
        momentum: velocity decay factor (beta).
        weight_decay: L2 penalty coefficient; equivalent to adding
            ``weight_decay / 2 * ||p||**2`` to the loss.
        dampening: fraction of each new gradient dropped when accumulating
            into the velocity (not applied on the first step).
        nesterov: use Nesterov momentum instead of heavy-ball momentum.

    Raises:
        ValueError: if ``nesterov`` is requested without a positive
            ``momentum`` and zero ``dampening``.
    """

    def __init__(self, params, lr=0.01, momentum=0.9, weight_decay=0.01,
                 dampening=0.0, nesterov=False):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                        dampening=dampening, nesterov=nesterov)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model, calls
                ``backward()`` and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` if none was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad, but the closure must backpropagate.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            weight_decay = group['weight_decay']
            # .get(): groups restored from an older state_dict may lack these keys.
            dampening = group.get('dampening', 0.0)
            nesterov = group.get('nesterov', False)
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if weight_decay != 0.0:
                    # Out-of-place so p.grad itself is left untouched.
                    grad = grad.add(p, alpha=weight_decay)

                state = self.state[p]
                buf = state.get('momentum_buffer')
                if buf is None:
                    # First step: the velocity starts as the raw gradient
                    # (same as momentum * 0 + g, and no dampening applied).
                    buf = grad.clone().detach()
                    state['momentum_buffer'] = buf
                else:
                    buf.mul_(momentum).add_(grad, alpha=1 - dampening)

                direction = grad.add(buf, alpha=momentum) if nesterov else buf
                p.add_(direction, alpha=-lr)
        return loss

In [47]:
# Reproducible data and initial weights.
torch.manual_seed(42)

# Toy regression problem (16 samples, 32 -> 64 features).
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Identical starting weights for the reference and the custom optimizer.
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

lr = 0.02
opt_torch = torch.optim.SGD([W1], lr=lr, momentum=0.9, weight_decay=0.01)
opt_custom = SDGMomentum([W2], lr=lr, momentum=0.9, weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    loss1, loss2 = (((x @ W.T - target) ** 2).mean() for W in (W1, W2))
    loss1.backward()
    loss2.backward()
    opt_torch.step()
    opt_custom.step()

    # The internal velocity must match PyTorch's exactly...
    state1, state2 = opt_torch.state[W1], opt_custom.state[W2]
    assert [*state1] == ['momentum_buffer']
    assert torch.equal(state1['momentum_buffer'], state2['momentum_buffer'])

    # ...and so must the resulting weights.
    weight_max_diff = torch.max(torch.abs(W1 - W2)).item()
    assert weight_max_diff == 0.0
    print(f"Diff {weight_max_diff}")

print(f"All good after {i+1} iterations!")

Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
All good after 20 iterations!


## RMSprop

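**Idea:** Fix AdaGrad's ever-shrinking learning rate by tracking an exponential moving average (EMA) of squared gradients instead of their running sum, so old gradients decay away.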

```
s = β * s + (1 - β) * g²
θ = θ - lr * g / (√s + ε)
```

Typical β = 0.99 (called `alpha` in PyTorch's `RMSprop`). The effective per-parameter learning rate stabilizes instead of decaying to zero as in AdaGrad.

In [2]:
import torch

In [48]:
class RMSProp(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01, alpha=0.99, eps=1e-8, weight_decay=0.01):
        defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
    
    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Lazy Init
                if p not in self.state:
                    self.state[p] = {
                        'square_avg': torch.zeros_like(p),
                    }

                # Weight Decay
                grad = p.grad
                if group['weight_decay'] != 0.0:
                    # equivalent to: loss = ... + weight_decay/2 * W**2
                    grad = grad.add(p, alpha=group['weight_decay'])  # not in place!

                # Update Step
                s = self.state[p]['square_avg']
                s.mul_(group['alpha'])

                # s = s + (1-group['alpha']) * grad * grad
                s.addcmul_(grad, grad, value=1-group['alpha'])

                avg = s.sqrt().add_(group['eps'])
                
                # p.add_(grad / avg, alpha=-group['lr'])
                p.data.addcdiv_(grad, avg, value=-group['lr'])

In [50]:
# Reproducible data and initial weights.
torch.manual_seed(42)

# Toy regression problem (16 samples, 32 -> 64 features).
x = torch.randn(16, 32)
target = torch.randn(16, 64)

# Identical starting weights for the reference and the custom optimizer.
W1 = torch.randn(64, 32, requires_grad=True)
W2 = W1.clone().detach().requires_grad_(True)

lr = 0.02
opt_torch = torch.optim.RMSprop([W1], lr=lr, alpha=0.99, weight_decay=0.01)
opt_custom = RMSProp([W2], lr=lr, alpha=0.99, weight_decay=0.01)

for i in range(20):
    opt_torch.zero_grad()
    opt_custom.zero_grad()
    loss1, loss2 = (((x @ W.T - target) ** 2).mean() for W in (W1, W2))
    loss1.backward()
    loss2.backward()
    opt_torch.step()
    opt_custom.step()

    # PyTorch keeps a step counter plus the squared-gradient EMA;
    # the EMA must match ours exactly.
    state1, state2 = opt_torch.state[W1], opt_custom.state[W2]
    assert [*state1] == ['step', 'square_avg']
    assert torch.equal(state1['square_avg'], state2['square_avg'])

    # The resulting weights must match bit-for-bit as well.
    weight_max_diff = torch.max(torch.abs(W1 - W2)).item()
    assert weight_max_diff == 0.0
    print(f"Diff {weight_max_diff}")

print(f"All good after {i+1} iterations!")

Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
Diff 0.0
All good after 20 iterations!
