# Correct u-µP under SGD

It has been suggested that our current implementation of u-µP under SGD is incorrect,
insofar as it doesn't scale in the same way as µP with hidden size, even accounting
for abc-symmetry (see the u-µP paper Eq. 2).

To test this we will take the SGD implementation from the original mup repo, and
iteratively change it towards our current implementation, seeing where it breaks down.
We will conclude by running our fixed u-µP SGD alongside the mup SGD to show that the
two are equivalent.

## Train model using SGD from mup repo

In [None]:
%pip install mup

import mup
import torch
from torch import nn
import unit_scaling as uu
from typing import Optional

In [2]:
b = 3
d_inp = 5
d_hid = 7
d_out = 11
steps = 3

xs = torch.randn(steps, b, d_inp)
# Set up some arbitrary function we wish to learn
ys = torch.tanh(xs @ torch.randn(d_inp, d_out))
ys /= ys.std()

In [3]:
# Standard torch nn.Linear init is proportional to 1/sqrt(fan_in), but with some
# convoluted constant. For simplicity these two classes just give 1/sqrt(fan_in) init.
class Linear(nn.Linear):
    def reset_parameters(self) -> None:
        nn.init.normal_(self.weight, std=self.weight.shape[1] ** -0.5)


class MuReadout(mup.MuReadout):
    def reset_parameters(self) -> None:
        nn.init.normal_(self.weight, std=self.weight.shape[1] ** -0.5)


# Basic 3-layer MLP, no biases
class ModelA(nn.Sequential):
    def __init__(self, d_inp: int, d_hid: int, d_out: int):
        super().__init__(
            Linear(d_inp, d_hid, bias=False),
            Linear(d_hid, d_hid, bias=False),
            MuReadout(d_hid, d_out, bias=False),
        )

In [4]:
def training_loop(model, opt):
    for x, y in zip(xs, ys):
        y_pred = model(x)
        loss = ((y - y_pred) ** 2).mean()
        loss.backward()
        opt.step()
        model.zero_grad()
        print(f"loss={loss.item():.4f}")

In [5]:
torch.manual_seed(1472)
model_a = ModelA(d_inp, d_hid, d_out)
# mup requires this "base shapes" step. This sets it up such that there is no base shape
# (i.e. it equals 1), and only d_hid determines the model-width
mup.set_base_shapes(model_a, ModelA(d_inp, 1, d_out))
opt_a = mup.optim.MuSGD(model_a.parameters(), lr=1e-1)

In [6]:
training_loop(model_a, opt_a)

loss=1.1976
loss=1.1297
loss=0.9949


Our test for whether subsequent SGD implementations are correct is that they reproduce
these three loss values.

## Permute under abc-symmetry in the same way as u-µP

This doesn't give our full u-µP implementation (see next section for that), but is our
starting point. Here we'll create the same model, but using the unit scaling library
(`uu.`) which unit-inits the weights and adds unit-scaled multipliers. We adjust the
lrs accordingly under abc-symmetry.

To make this concrete, here is the original table of SGD scaling factors for mup, as
implemented in their library (table 8 in the Tensor Programs V paper):

| | input | hidden | output |
|-|-|-|-|
| mult (A) | 1 | 1 | 1/d_hid |
| init std (B) | 1/sqrt(d_in) | 1/sqrt(d_hid) | 1 |
| SGD lr (C) | d_hid | 1 | d_hid |

abc-symmetry says that our model's dynamics are invariant to changes of the form:

```A ← Aθ, B ← B/θ, C ← C/θ^2```
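To see why this holds under SGD, here is a minimal sketch on a single scalar "layer" `y = mult * w * x`. The function name, the toy values and the squared-error loss are illustrative assumptions, not part of either library:

```python
def train(mult, init_std, lr, z=0.3, x=1.5, target=2.0, steps=5):
    """Plain SGD on y = mult * w * x, where w starts at init_std * z."""
    w = init_std * z
    outputs = []
    for _ in range(steps):
        y = mult * w * x
        outputs.append(y)
        grad_w = (y - target) * mult * x  # d/dw of 0.5 * (y - target) ** 2
        w -= lr * grad_w
    return outputs


theta = 4.0
base = train(mult=1.0, init_std=0.5, lr=0.1)
# Apply A <- A * theta, B <- B / theta, C <- C / theta**2
permuted = train(mult=1.0 * theta, init_std=0.5 / theta, lr=0.1 / theta**2)

# The layer outputs (and hence the losses) agree at every step
assert all(abs(p - q) < 1e-9 for p, q in zip(base, permuted))
```

The effective weight `mult * w` starts at the same value, and each update to it is proportional to `mult**2 * lr`, which `θ` leaves unchanged.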

Setting `θ` to give `B=1` everywhere gives our unit-scaled table:

| | input | hidden | output |
|-|-|-|-|
| mult (A) | 1/sqrt(d_in) | 1/sqrt(d_hid) | 1/d_hid |
| init std (B) | 1 | 1 | 1 |
| SGD lr (C) | d_hid * d_in | d_hid | d_hid |

Our unit-scaled layers implement A and B - this gives our second model:

In [7]:
class ModelB(nn.Sequential):
    def __init__(self, d_inp: int, d_hid: int, d_out: int):
        super().__init__(
            uu.Linear(d_inp, d_hid),
            uu.Linear(d_hid, d_hid),
            uu.LinearReadout(d_hid, d_out, constraint="to_output_scale"),
        )
# The "to_output_scale" is necessary here, for reasons explained in the next section

For C, our SGD LRs, the bottom row of the above table suggests we wish to scale by

```(d_hid * d_inp, d_hid, d_hid)```.

However our unit-scaling library doesn't apply the A multipliers when computing grad_ws
(this ensures unit-scale), meaning we actually need to use the top row multiplied by the
bottom row for our SGD LRs. This gives:

```(d_hid * d_inp**0.5, d_hid**0.5, 1)```,

which we use below.
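As a sanity check on this arithmetic, the following sketch starts from the mup table, applies `θ = B` per layer, and multiplies the resulting A and C rows together. It uses the same `d_inp` and `d_hid` as this notebook; the dictionary layout is just a convenience for the example:

```python
import math

d_inp, d_hid = 5, 7  # same sizes as used in this notebook

# (mult A, init std B, SGD lr C) for each layer, taken from the mup table
mup_table = {
    "input": (1.0, d_inp**-0.5, d_hid),
    "hidden": (1.0, d_hid**-0.5, 1.0),
    "output": (1.0 / d_hid, 1.0, d_hid),
}

lr_mods = {}
for name, (a, b, c) in mup_table.items():
    theta = b  # choosing theta = B sends B -> B / theta = 1
    a_unit, b_unit, c_unit = a * theta, b / theta, c / theta**2
    assert math.isclose(b_unit, 1.0)
    # grad_w is computed without the A multiplier, so the LR absorbs one factor of A
    lr_mods[name] = a_unit * c_unit

expected = (d_hid * d_inp**0.5, d_hid**0.5, 1.0)
assert all(math.isclose(lr_mods[k], e) for k, e in zip(mup_table, expected))
```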

Our unit-scaled layers also contain a special multiplier for the grad_w calculation,
designed to maintain unit scale. This multiplier equals 1/sqrt(batch_size). To get back
to equivalence with the original SGD we therefore also multiply the base lr by 
sqrt(batch_size).

In [8]:
def gen_param_groups(model, base_lr, lr_mods):
    parameter_groups = []
    for params, lr_mod in zip(model.parameters(), lr_mods):
        parameter_groups.append({"params": [params], "lr": base_lr * lr_mod})
    return parameter_groups

In [9]:
torch.manual_seed(1472)

model_b = ModelB(d_inp, d_hid, d_out)
base_lr = 1e-1 * b**0.5
opt_b = torch.optim.SGD(
    gen_param_groups(model_b, base_lr, lr_mods=(d_hid * d_inp**0.5, d_hid**0.5, 1))
)

In [10]:
training_loop(model_b, opt_b)

loss=1.1976
loss=1.1297
loss=0.9949


Our loss here is exactly the same as in the previous case, meaning our implementation
is still correct.

## Introduce unit-scaled readout scaling

The final step to get a unit-scaled model here is to introduce the following
trick in the backward pass:

Our above model has to use 1/d_hid as our output mult.
This differs from the 1/sqrt(d) mult we usually use to get unit-scaling at init.
This is fine in the forward pass as the readout is the last layer, but in the backward
pass it's more of a problem as this mis-scaling propagates.
To fix this, we simply hack the gradient of `uu.LinearReadout` to use the ideal unit
scaling factor in the backward pass, which in this case is 1/sqrt(d_out), keeping
1/d_hid in the forward.
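The idea of using one scale in the forward pass and a different one in the backward pass can be sketched with a custom autograd function. This is a simplified illustration, not the actual code inside `uu.LinearReadout`; the class name `ScaleFwdBwd` is hypothetical:

```python
import torch


class ScaleFwdBwd(torch.autograd.Function):
    """Multiply by `fwd_scale` going forward and by `bwd_scale` going backward."""

    @staticmethod
    def forward(ctx, x, fwd_scale, bwd_scale):
        ctx.bwd_scale = bwd_scale
        return x * fwd_scale

    @staticmethod
    def backward(ctx, grad_out):
        # One return value per forward input; the two scales receive no gradient
        return grad_out * ctx.bwd_scale, None, None


d_hid, d_out = 7, 11
x = torch.ones(3, requires_grad=True)
y = ScaleFwdBwd.apply(x, 1 / d_hid, d_out**-0.5)
y.sum().backward()

# Forward uses 1/d_hid, while the gradient reaching x uses 1/sqrt(d_out)
assert torch.allclose(y.detach(), torch.full((3,), 1 / d_hid))
assert torch.allclose(x.grad, torch.full((3,), d_out**-0.5))
```

Because the gradient is no longer the true derivative of the forward computation, anything upstream of this op sees gradients rescaled by a constant factor.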

This is the default implementation of `uu.LinearReadout`, and is why we had to add
`constraint="to_output_scale"` in our previous `ModelB`. This change gives us `ModelC`.

In [11]:
class ModelC(nn.Sequential):
    def __init__(self, d_inp: int, d_hid: int, d_out: int):
        super().__init__(
            uu.Linear(d_inp, d_hid),
            uu.Linear(d_hid, d_hid),
            uu.LinearReadout(d_hid, d_out),
        )

We use the same optimizer as before:

In [12]:
torch.manual_seed(1472)

model_c = ModelC(d_inp, d_hid, d_out)
base_lr = 1e-1 * b**0.5
opt_c = torch.optim.SGD(
    gen_param_groups(model_c, base_lr, lr_mods=(d_inp**0.5 * d_hid, d_hid**0.5, 1))
)

In [13]:
training_loop(model_c, opt_c)

loss=1.1976
loss=1.1471
loss=0.9887


We see now that the loss values are different — something has changed here and our SGD
implementation is now broken.

## Train model using above + fix - show it's the same

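The problem is that our backward pass re-scaling must be compensated for in our LRs.
We have gone from 1/d_hid scaling in our readout grad to 1/sqrt(d_out), so every gradient
flowing back into the input and hidden layers is now d_hid/sqrt(d_out) times larger. To
compensate we must multiply the learning rates of those two layers by sqrt(d_out)/d_hid;
the readout's own learning rate is unchanged.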

Recall that previously we had:

```(d_hid * d_inp**0.5, d_hid**0.5, 1)```,

which now changes to:

```
(d_hid * d_inp**0.5 * sqrt(d_out)/d_hid, d_hid**0.5 * sqrt(d_out)/d_hid, 1)
= (d_inp**0.5 * d_out**0.5, d_hid**-0.5 * d_out**0.5, 1)
```
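The compensation can be checked by hand on a toy scalar chain with a hand-written backward pass. The function `hidden_update` and its toy values are assumptions made purely for illustration:

```python
import math

d_inp, d_hid, d_out = 5, 7, 11


def hidden_update(bwd_scale, lr, w1=0.4, w2=-0.8, x=1.2, target=0.5):
    """SGD update applied to w1 in the chain y = (1 / d_hid) * w2 * (w1 * x)."""
    h = w1 * x
    y = (1 / d_hid) * w2 * h
    grad_y = y - target  # d/dy of 0.5 * (y - target) ** 2
    grad_h = bwd_scale * w2 * grad_y  # the readout's (possibly re-scaled) input grad
    return lr * grad_h * x


lr = 0.1
original = hidden_update(bwd_scale=1 / d_hid, lr=lr)
rescaled = hidden_update(bwd_scale=d_out**-0.5, lr=lr)
fixed = hidden_update(bwd_scale=d_out**-0.5, lr=lr * d_out**0.5 / d_hid)

assert not math.isclose(original, rescaled)  # re-scaling alone changes the update
assert math.isclose(original, fixed)  # the LR correction restores it

# The simplified LR multipliers agree with the unsimplified ones
old_lr_mods = (d_hid * d_inp**0.5, d_hid**0.5)
new_lr_mods = tuple(m * d_out**0.5 / d_hid for m in old_lr_mods)
simplified = (d_inp**0.5 * d_out**0.5, d_hid**-0.5 * d_out**0.5)
assert all(math.isclose(n, s) for n, s in zip(new_lr_mods, simplified))
```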

In [14]:
torch.manual_seed(1472)

model_c2 = ModelC(d_inp, d_hid, d_out)
base_lr = 1e-1 * b**0.5
opt_c2 = torch.optim.SGD(
    gen_param_groups(
        model_c2,
        base_lr,
        lr_mods=(d_inp**0.5 * d_out**0.5, d_hid**-0.5 * d_out**0.5, 1),
    )
)

In [15]:
training_loop(model_c2, opt_c2)

loss=1.1976
loss=1.1297
loss=0.9949


This is equal to our original mup SGD loss, meaning that we've fixed the problem! We
now have a unit-scaled model with identical dynamics to the mup model under SGD.

## Show corrected library implementation

All that remains is to fix our library and show that it indeed matches the original mup
SGD implementation.

Below is our library SGD implementation, but with one modification added to handle the
case (which we make the default) that the readout layer has a `None` constraint - i.e.
the gradient is re-scaled, and we must correct for it.

In [16]:
def lr_scale_func_sgd(readout_constraint: Optional[str]):
    """Calculate the LR scaling factor for :class:`torch.optim.SGD`."""

    def lr_scale_func_sgd_inner(param):
        scale = uu.optim.lr_scale_for_depth(param)
        if param.mup_type in ("bias", "norm"):
            return scale * param.shape[0]
        if param.mup_type == "weight":
            if readout_constraint is None:  # <<< NEW MODIFICATION
                return scale * uu.optim._get_fan_in(param) ** -0.5
            elif readout_constraint == "to_output_scale":  # <<< existing case
                return scale * uu.optim._get_fan_in(param) ** 0.5
            else:
                assert False, f"Unhandled readout constraint: {readout_constraint}"
        if param.mup_type == "output":
            return scale
        assert False, f"Unexpected mup_type {param.mup_type}"

    return lr_scale_func_sgd_inner


class SGD(torch.optim.SGD):
    def __init__(
        self,
        params,
        lr,
        *args,
        weight_decay: float = 0,
        independent_weight_decay: bool = True,
        allow_non_unit_scaling_params: bool = False,
        readout_constraint: Optional[str] = None,
        **kwargs,
    ) -> None:
        params = uu.optim.scaled_parameters(
            params,
            lr_scale_func_sgd(readout_constraint),
            lr=lr,
            weight_decay=weight_decay,
            independent_weight_decay=independent_weight_decay,
            allow_non_unit_scaling_params=allow_non_unit_scaling_params,
        )
        # No need to forward {lr, weight_decay}, as each group has these specified
        super().__init__(params, *args, **kwargs)

In [17]:
torch.manual_seed(1472)

model_c3 = ModelC(d_inp, d_hid, d_out)
base_lr = 1e-1
opt_c3 = SGD(model_c3.parameters(), base_lr)

In [18]:
training_loop(model_c3, opt_c3)

loss=1.1976
loss=1.1236
loss=1.0068


This is slightly different from our previous loss values. However, the difference comes
only from a few constant factors introduced by the two schemes, crucially _not_ factors
that depend on model width.

Specifically, these factors are the 1/sqrt(d_in) scaling factor in the very first mup
table presented, the 1/sqrt(d_out) factor that we introduce in u-µP's readout grad
scaling, and the 1/sqrt(batch_size) grad_w multiplier described earlier.
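A quick way to confirm that the mismatch is width-independent is to compare the library's `fan_in ** -0.5` rule against the hand-derived LR multipliers from the fixed `ModelC` run, for several hidden sizes. This sketch treats the depth factor from `uu.optim.lr_scale_for_depth` as 1 and assumes both the input and hidden weights have `mup_type == "weight"`; the helper names are hypothetical:

```python
import math

d_inp, d_out = 5, 11


def hand_derived(d_hid):
    # (input, hidden) LR multipliers from the fixed ModelC run
    return d_inp**0.5 * d_out**0.5, d_hid**-0.5 * d_out**0.5


def library(d_hid):
    # fan_in ** -0.5 for "weight" params, with the depth factor taken as 1 (assumption)
    return d_inp**-0.5, d_hid**-0.5


for d_hid in (7, 64, 1024):
    ratios = [lib / hand for lib, hand in zip(library(d_hid), hand_derived(d_hid))]
    # Neither ratio depends on d_hid
    assert math.isclose(ratios[0], d_inp**-1.0 * d_out**-0.5)
    assert math.isclose(ratios[1], d_out**-0.5)
```

The two ratios are constants in `d_inp` and `d_out` only, so they can be absorbed into per-layer LR corrections without changing how either scheme scales with width.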

Below we show how to correct for these constants under SGD such that the two schemes
are exactly the same. This is just our original `ModelA` with some non-width-dependent
LR tweaks:

In [19]:
torch.manual_seed(1472)

model_a2 = ModelA(d_inp, d_hid, d_out)
mup.set_base_shapes(model_a2, ModelA(d_inp, 1, d_out))
base_lr = 1e-1 * b**-0.5
opt_a2 = mup.optim.MuSGD(
    gen_param_groups(
        model_a2, base_lr, lr_mods=(d_inp**-1.0 * d_out**-0.5, d_out**-0.5, 1)
    )
)

In [20]:
training_loop(model_a2, opt_a2)

loss=1.1976
loss=1.1236
loss=1.0068


Sure enough, we get exactly the same loss, confirming that this fixed u-µP
SGD implementation scales in the same way as the original mup SGD.