# HW03 Participation D – Lion and SOAP Solutions

This notebook contains staff solutions for the **Lion** (part c) and **SOAP-style** (part h) optimizer questions from `q_mup_coding.ipynb`. It is not student-facing.

## Part c – Lion (staff solution)

This section records a reference implementation of `SimpleLion` and brief notes on expected behavior for parts c.1 (activation deltas) and c.2 (hyperparameter sweep).

In [None]:
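import torch
from torch.optim.optimizer import Optimizer
from typing import Any


class SimpleLion(Optimizer):
    """Reference implementation matching the student TODO for part c.

    Simplified Lion: one coefficient ``b1`` drives both the momentum EMA and the
    sign of the update, i.e. Lion with beta1 = beta2 = b1 and no weight decay.
    """

    def __init__(
        self,
        params: Any,
        lr: float = 1e-1,
        b1: float = 0.9,
    ):
        defaults = dict(lr=lr, b1=b1)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        # Optionally re-evaluate the loss, following the torch.optim convention
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            b1 = group["b1"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                if len(state) == 0:  # initialization
                    state["momentum"] = torch.zeros_like(p)

                m = state["momentum"]
                # Exponential moving average: m <- b1 * m + (1 - b1) * grad
                m.lerp_(grad, 1 - b1)
                # Sign-based update: every coordinate moves by exactly lr (or 0)
                u = m.sign()
                # Parameter update: p <- p - lr * sign(m)
                p.add_(u, alpha=-lr)
        return loss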
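
`SimpleLion` keeps a single `b1`. The original Lion update takes the sign of an interpolation `beta1 * m + (1 - beta1) * g` and then updates the momentum with a separate `beta2`. The pure-Python sketch below (the helpers `sign`, `simple_lion_step`, and `lion_step` are hypothetical names working on lists of floats, and the gradients are made-up values) checks that the two rules coincide when `beta1 == beta2 == b1` and weight decay is zero.

```python
def sign(x: float) -> float:
    """Return -1.0, 0.0, or 1.0, like torch.sign for real inputs."""
    return float((x > 0) - (x < 0))


def simple_lion_step(p, m, g, lr, b1):
    """One SimpleLion step: update the EMA first, then step by the sign of the EMA."""
    new_m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    new_p = [pi - lr * sign(mi) for pi, mi in zip(p, new_m)]
    return new_p, new_m


def lion_step(p, m, g, lr, beta1, beta2):
    """One step of two-coefficient Lion (no weight decay)."""
    # Sign of an interpolation between the old momentum and the new gradient...
    c = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    new_p = [pi - lr * sign(ci) for pi, ci in zip(p, c)]
    # ...while the momentum itself is tracked with a separate coefficient.
    new_m = [beta2 * mi + (1 - beta2) * gi for mi, gi in zip(m, g)]
    return new_p, new_m


grads = [[0.5, -2.0, 1e-4], [-0.1, 0.3, -1e-4], [0.2, 0.2, 0.0]]
p_a, m_a = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
p_b, m_b = list(p_a), list(m_a)
for g in grads:
    p_a, m_a = simple_lion_step(p_a, m_a, g, lr=0.01, b1=0.9)
    p_b, m_b = lion_step(p_b, m_b, g, lr=0.01, beta1=0.9, beta2=0.9)

# With beta1 == beta2 both rules perform the same arithmetic, so the results match.
print(p_a == p_b, m_a == m_b)
```

Setting `beta2` different from `beta1` (e.g. `0.99`) makes the two rules diverge, which is the extra freedom the full Lion update has over this simplified version.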

### Notes for c.1 (activation deltas)

- Compared to `SimpleAdam`, `SimpleLion` often produces:
  - Slightly larger or more uniform activation deltas across layers, because every coordinate of a sign-based update has magnitude `lr` (or 0), so steps do not shrink when gradients are small.
  - Less variation in per-layer delta magnitudes when some layers have much smaller raw gradients.
- Students should comment qualitatively on which layers change most and how sign-based updates affect the pattern.
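
The first bullet follows directly from the update rule: rescaling all gradients of a layer leaves a sign-based step unchanged, whereas a momentum-SGD step scales with them. A minimal pure-Python sketch, with hypothetical helpers `lion_update` and `momentum_sgd_update` and made-up gradient values chosen only for illustration, compares the two after the same two steps:

```python
def sign(x: float) -> float:
    """Return -1.0, 0.0, or 1.0, like torch.sign for real inputs."""
    return float((x > 0) - (x < 0))


def ema(grads, b1):
    """Exponential moving average over a list of per-step gradient lists."""
    m = [0.0] * len(grads[0])
    for g in grads:
        m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    return m


def lion_update(grads, lr=0.01, b1=0.9):
    """Sign-based (SimpleLion-style) update after seeing `grads`."""
    return [-lr * sign(mi) for mi in ema(grads, b1)]


def momentum_sgd_update(grads, lr=0.01, b1=0.9):
    """Plain EMA-momentum update after seeing `grads`, for comparison."""
    return [-lr * mi for mi in ema(grads, b1)]


layer_grads = [[0.4, -0.2, 0.1], [0.3, -0.1, 0.2]]           # a "large-gradient" layer
tiny_grads = [[gi * 1e-3 for gi in g] for g in layer_grads]  # same direction, 1000x smaller

print(lion_update(layer_grads), lion_update(tiny_grads))                  # identical
print(momentum_sgd_update(layer_grads), momentum_sgd_update(tiny_grads))  # 1000x apart
```

In a real network this is why layers with much smaller raw gradients still receive steps of size `lr` per coordinate under Lion.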

### Notes for c.2 (hyperparameter sweep)

- A small grid over `lr` ∈ {0.001, 0.003, 0.01, 0.03} and `b1` ∈ {0.8, 0.9, 0.95} is sufficient.
- Reasonable settings typically look like:
  - `lr` around `0.003` or `0.01`.
  - `b1` around `0.9`.
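- Too large a learning rate (e.g. `0.03`) can make Lion unstable: update magnitudes do not shrink as gradients get small, so every coordinate keeps moving by `lr` per step and the iterates can oscillate around a minimum instead of settling.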
- The goal is a brief qualitative statement about sensitivity to `lr` and `b1`, not finding a single perfect pair.
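
The grid above is small enough to loop over directly. The sketch below runs the SimpleLion rule on a hypothetical ill-conditioned quadratic (`toy_lion_run`, its curvatures, and its starting point are illustrative stand-ins for the notebook's model and training loop, which are not reproduced here) for every `(lr, b1)` pair, so the loop structure and bookkeeping can be reused; the toy losses say nothing about the real network.

```python
import itertools


def sign(x: float) -> float:
    """Return -1.0, 0.0, or 1.0, like torch.sign for real inputs."""
    return float((x > 0) - (x < 0))


def toy_lion_run(lr: float, b1: float, steps: int = 200) -> float:
    """Minimize f(x) = 0.5 * sum(c_i * x_i**2) with the SimpleLion rule; return final f."""
    curv = [1.0, 10.0, 0.1]  # hypothetical per-coordinate curvatures
    x = [1.0, -0.5, 0.2]     # hypothetical starting point
    m = [0.0] * len(x)
    for _ in range(steps):
        g = [c * xi for c, xi in zip(curv, x)]                   # gradient of f
        m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]    # EMA momentum
        x = [xi - lr * sign(mi) for xi, mi in zip(x, m)]         # sign step
    return 0.5 * sum(c * xi * xi for c, xi in zip(curv, x))


lrs = [0.001, 0.003, 0.01, 0.03]
b1s = [0.8, 0.9, 0.95]
results = {(lr, b1): toy_lion_run(lr, b1) for lr, b1 in itertools.product(lrs, b1s)}

for (lr, b1), loss in sorted(results.items(), key=lambda kv: kv[1]):
    print(f"lr={lr:<6} b1={b1:<5} final loss={loss:.3e}")
```

One budget effect is visible even here: at `lr = 0.001`, 200 steps can move each coordinate by at most 0.2, so the first coordinate cannot get below 0.8, while larger `lr` values reach the minimum but then keep stepping by `lr` around it.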

## Part h – SOAP-style Optimizer (staff solution)

This section records a reference implementation of `SimpleSOAP` and notes on expected behavior for the update-norm comparison (h.1) and the learning-rate comparison (h.2).

In [None]:
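import torch
from torch.optim.optimizer import Optimizer
from typing import Any


class SimpleSOAP(Optimizer):
    """Reference implementation matching the student TODO for part h.

    "SOAP-style" here means the homework's simplified rule (orthogonalize the
    momentum, then rescale it to the momentum's Frobenius norm), not the full
    SOAP algorithm.
    """

    def __init__(
        self,
        params: Any,
        lr: float = 1e-1,
        b1: float = 0.9,
    ):
        defaults = dict(lr=lr, b1=b1)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        # Optionally re-evaluate the loss, following the torch.optim convention
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            b1 = group["b1"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad

                state = self.state[p]
                if len(state) == 0:  # initialization
                    state["step"] = torch.tensor(0.0)  # step counter (not used by the update)
                    state["momentum"] = torch.zeros_like(p)

                state["step"] += 1
                m = state["momentum"]
                # Exponential moving average: m <- b1 * m + (1 - b1) * grad
                m.lerp_(grad, 1 - b1)

                if m.ndim == 1:
                    # For biases, just use the momentum directly
                    u = m
                else:
                    # SOAP-style matrix update: orthogonalize, then match Frobenius norm.
                    # Flatten to 2-D so conv kernels are treated as (d_out, rest) matrices.
                    m2d = m.reshape(m.shape[0], -1)
                    U, S, Vh = torch.linalg.svd(m2d, full_matrices=False)
                    u0 = U @ Vh  # all singular values equal to 1
                    m_frob = torch.linalg.matrix_norm(m2d)  # Frobenius norm by default
                    u0_frob = torch.linalg.matrix_norm(u0) + 1e-16
                    scale = m_frob / u0_frob
                    u = (u0 * scale).reshape_as(m)

                p.add_(u, alpha=-lr)
        return loss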
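
As a quick check of the matrix branch, the sketch below restates it as a standalone function (`soap_style_update` is a hypothetical name, not part of the notebook) and verifies on a random tall matrix that the result is a scaled semi-orthogonal matrix, `u.T @ u = s**2 * I` with `s = ||m||_F / sqrt(min(d_out, d_in))`, whose Frobenius norm equals the momentum's.

```python
import math

import torch


def soap_style_update(m: torch.Tensor) -> torch.Tensor:
    """Matrix branch of SimpleSOAP: orthogonalize via SVD, then match ||m||_F."""
    U, S, Vh = torch.linalg.svd(m, full_matrices=False)
    u0 = U @ Vh
    scale = torch.linalg.matrix_norm(m) / (torch.linalg.matrix_norm(u0) + 1e-16)
    return u0 * scale


torch.manual_seed(0)
m = torch.randn(6, 4, dtype=torch.float64)  # tall "momentum" matrix: d_out=6, d_in=4
u = soap_style_update(m)

k = min(m.shape)
s = torch.linalg.matrix_norm(m) / math.sqrt(k)  # common singular value of u
gram = u.T @ u                                   # (4, 4); should equal s**2 * I

print(torch.allclose(gram, s**2 * torch.eye(k, dtype=m.dtype)))
print(torch.isclose(torch.linalg.matrix_norm(u), torch.linalg.matrix_norm(m)))
```

Float64 is used only to keep the comparison tolerances tight; the optimizer itself runs in whatever dtype the parameters use.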

### Notes for h.1 (update norms)

- **Adam**:
  - Frobenius and spectral norms reflect raw gradient magnitudes and vary by layer.
- **Shampoo**:
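  - Spectral norms are roughly 1 for matrix parameters, because orthogonalization sets every singular value of the update to 1.
  - Frobenius norms scale like `sqrt(min(d_out, d_in))` (the square root of the number of unit singular values) and can differ substantially from Adam.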
- **SOAP**:
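  - By construction, the Frobenius norm of each matrix update equals that of the momentum, so it varies by layer with gradient scale (closer to Adam). Its singular values are all equal, as for Shampoo, but their common value is `||m||_F / sqrt(min(d_out, d_in))` rather than 1, so spectral norms also track the momentum's scale.
  - Students should notice that SOAP preserves the overall update "energy" (Frobenius norm) per layer while keeping Shampoo's flat, orthogonalized singular-value spectrum.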
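
These norm relationships can be checked without training. Assuming the homework's `SimpleShampoo` uses the orthogonalized momentum `U @ Vh` as its matrix update (its code is not reproduced in this notebook), the sketch below compares the Frobenius and spectral norms of the momentum, the Shampoo-style update, and the SOAP-style update for a few hypothetical layer shapes with random momenta of very different scales. The helper names `orthogonalize`, `soap_style`, and `norms` are illustrative.

```python
import torch


def orthogonalize(m: torch.Tensor) -> torch.Tensor:
    """Shampoo-style matrix update (assumed): U @ Vh from the reduced SVD."""
    U, _, Vh = torch.linalg.svd(m, full_matrices=False)
    return U @ Vh


def soap_style(m: torch.Tensor) -> torch.Tensor:
    """SOAP-style update: orthogonalized momentum rescaled to ||m||_F."""
    u0 = orthogonalize(m)
    return u0 * (torch.linalg.matrix_norm(m) / torch.linalg.matrix_norm(u0))


def norms(x: torch.Tensor) -> tuple[float, float]:
    """Return (Frobenius norm, spectral norm) of a matrix."""
    return torch.linalg.matrix_norm(x).item(), torch.linalg.matrix_norm(x, ord=2).item()


torch.manual_seed(0)
# Hypothetical layers: (d_out, d_in, momentum scale)
layers = {"in (64x784)": (64, 784, 1e-2), "hidden (64x64)": (64, 64, 1e-3), "out (10x64)": (10, 64, 1e-1)}
rows = {}
for name, (d_out, d_in, scale) in layers.items():
    m = scale * torch.randn(d_out, d_in, dtype=torch.float64)
    rows[name] = {"momentum": norms(m), "shampoo": norms(orthogonalize(m)), "soap": norms(soap_style(m))}
    print(name, {key: tuple(round(v, 4) for v in fs) for key, fs in rows[name].items()})
```

The Shampoo-style rows are identical across momentum scales for a given shape, while the SOAP-style rows inherit the momentum's Frobenius norm and report a spectral norm of `||m||_F / sqrt(min(d_out, d_in))`.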

### Notes for h.2 (learning-rate comparison)

- A small sweep over `lr` ∈ {0.001, 0.003, 0.01, 0.03} for `SimpleShampoo` and `SimpleSOAP` is sufficient.
- Typical qualitative observations:
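  - SOAP often behaves more like Adam in terms of sensible learning-rate scales, because its update magnitudes follow the momentum's Frobenius norm.
  - Shampoo can require more careful tuning of `lr` because its update has a fixed spectral norm regardless of gradient scale, so `lr` alone sets the step size.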
  - There is usually an overlapping range of good learning rates; SOAP may degrade less sharply when `lr` is slightly mis-specified.
- Students only need to report a brief comparison, not an exhaustive grid search.
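
One way to frame this comparison: for each matrix parameter, the SOAP-style update is the Shampoo-style update times the scalar `||m||_F / sqrt(min(d_out, d_in))`, so `SimpleSOAP` at learning rate `lr` takes the same step that Shampoo would take with a per-layer, per-step learning rate `lr * ||m||_F / sqrt(min(d_out, d_in))`. The sketch checks this identity numerically, again assuming Shampoo's matrix update is `U @ Vh`; `shampoo_step` and `soap_step` are hypothetical helpers.

```python
import math

import torch


def shampoo_step(p, m, lr):
    """Assumed Shampoo-style step: p - lr * (U @ Vh)."""
    U, _, Vh = torch.linalg.svd(m, full_matrices=False)
    return p - lr * (U @ Vh)


def soap_step(p, m, lr):
    """SOAP-style step: orthogonalize, rescale to ||m||_F, then p - lr * u."""
    U, _, Vh = torch.linalg.svd(m, full_matrices=False)
    u0 = U @ Vh
    u = u0 * (torch.linalg.matrix_norm(m) / torch.linalg.matrix_norm(u0))
    return p - lr * u


torch.manual_seed(0)
p = torch.randn(32, 16, dtype=torch.float64)
m = 1e-3 * torch.randn(32, 16, dtype=torch.float64)  # hypothetical small-magnitude momentum
lr = 0.01

# Per-layer learning rate at which Shampoo reproduces the SOAP-style step
effective_lr = lr * torch.linalg.matrix_norm(m).item() / math.sqrt(min(m.shape))
print(f"effective Shampoo lr for this layer: {effective_lr:.2e}")
print(torch.allclose(soap_step(p, m, lr), shampoo_step(p, m, effective_lr)))
```

Because the effective learning rate shrinks and grows with each layer's momentum, SOAP's step size adapts to gradient scale automatically, while Shampoo's does not.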