# Analysis of the effect of the embedding LR update on the subsequent matmul

I wanted to write this out in a notebook to make sure I understood how the embedding update affects the subsequent matmul.

No revelations, unfortunately: it still seems as though our rule can't be justified this way (it is "unnatural"!). Under the "no-alignment" assumption the standard embedding LR breaks, but our fix does nothing to help. Oh well.

In [1]:
import torch
from torch import randn
from typing import Union

In [2]:
def rms(*xs: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
    """Root-mean-square of each input: a single tensor for one input, else a tuple."""
    if len(xs) == 1:
        return xs[0].pow(2).mean().sqrt()
    return tuple(rms(x) for x in xs)

## Setup

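Toggle `full_alignment` and `umup_lr_rule` to see the effect. `full_alignment` sets `w_lr = 1/d` (otherwise `1/sqrt(d)`), and `umup_lr_rule` sets `e_lr = 1/sqrt(d)` (otherwise `1`). By default, muP scaling is used.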

In [3]:
d = 2**11
full_alignment = True
umup_lr_rule = False

w_lr = d ** -(1 if full_alignment else 0.5)
e_lr = d ** -(0.5 if umup_lr_rule else 0)
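It can help to see all four toggle settings side by side. The cell below is a minimal sketch. It uses a hypothetical helper `lrs` that copies the two expressions above, and it loops with its own variable names so the notebook's globals are left untouched.

```python
def lrs(d: int, full_alignment: bool, umup_lr_rule: bool) -> tuple[float, float]:
    """Return (w_lr, e_lr) using the same rules as the Setup cell above."""
    w_lr = d ** -(1 if full_alignment else 0.5)
    e_lr = d ** -(0.5 if umup_lr_rule else 0)
    return w_lr, e_lr

# Print all four toggle combinations at d = 2**11 without overwriting the
# notebook's `full_alignment`, `umup_lr_rule`, `w_lr` or `e_lr`.
for fa in (True, False):
    for umup in (False, True):
        lr_w, lr_e = lrs(2**11, fa, umup)
        print(f"full_alignment={fa!s:5}  umup_lr_rule={umup!s:5}  w_lr={lr_w:.3g}  e_lr={lr_e:.3g}")
```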

## Model & update

Everything can be described in terms of these three tensors. The first is a single embedding vector `e1` of shape `(d, 1)`. The second is a weight matrix `W1` of shape `(d + 1, d)`. The third is a gradient vector `g` of shape `(d + 1, 1)`, the gradient w.r.t. the matmul output. Note that I assume the gradient is unit-scale, and then just use the Adam LR rules but under an SGD-like update (I appreciate this is a bit odd, but it's simple and the maths should work out).

In [4]:
e1 = randn(d, 1)
W1 = randn(d + 1, d) * d**-0.5
g = randn(d + 1, 1)
rms(
    e1, W1, g
)  # all "well-scaled", except the weight which is 1/sqrt(d) (this isn't unit scaling!)

(tensor(0.9984), tensor(0.0221), tensor(0.9882))
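The updates below treat `W1.T @ g` and `g @ e1.T` as the gradients w.r.t. the embedding and the weight. As a quick autograd sanity check, here is a sketch that assumes a surrogate loss `L = g.T @ (W @ e)`. This loss is hypothetical and chosen only so that dL/dx = g. The sketch uses a small width and a private generator, so the notebook's RNG state and globals are unaffected.

```python
import torch

gen = torch.Generator().manual_seed(0)
d_check = 64  # small width; the identity being checked doesn't depend on d
e = torch.randn(d_check, 1, generator=gen, requires_grad=True)
W = (torch.randn(d_check + 1, d_check, generator=gen) * d_check**-0.5).requires_grad_()
g_out = torch.randn(d_check + 1, 1, generator=gen)  # plays the role of g = dL/dx

# Surrogate loss whose gradient w.r.t. x = W @ e is exactly g_out.
loss = (g_out * (W @ e)).sum()
loss.backward()

emb_grad_matches = torch.allclose(e.grad, W.detach().T @ g_out, atol=1e-5)
weight_grad_matches = torch.allclose(W.grad, g_out @ e.detach().T, atol=1e-5)
print(emb_grad_matches, weight_grad_matches)
```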

Then we just run:

In [5]:
x1 = W1 @ e1
rms(x1)  # well-scaled

tensor(0.9953)

In [6]:
u_e = W1.T @ g * e_lr
u_W = g @ e1.T * w_lr
# The weight update is under-scaled (to be expected, I think), though as a rank-1
# matrix it has a much higher (O(1)) spectral norm! This means its effect doesn't
# "go to zero" at infinite width, though its rms does.
(
    rms(u_e, u_W),
    1 / d,
)

((tensor(0.9977), tensor(0.0005)), 0.00048828125)
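The claim in the comment above is that the rank-1 update has rms ~`1/d` but spectral norm O(1). The sketch below checks it directly across a few widths, assuming `w_lr = 1/d` (full alignment). It uses `torch.linalg.matrix_norm(..., ord=2)`, the largest singular value. The widths are kept small because that call runs an SVD.

```python
import torch

gen = torch.Generator().manual_seed(0)
norms = {}
for width in (64, 256, 1024):
    e_vec = torch.randn(width, 1, generator=gen)
    g_vec = torch.randn(width + 1, 1, generator=gen)
    upd = g_vec @ e_vec.T / width  # rank-1 weight update with w_lr = 1/d
    rms_upd = upd.pow(2).mean().sqrt().item()
    spec_upd = torch.linalg.matrix_norm(upd, ord=2).item()  # largest singular value
    norms[width] = (rms_upd, spec_upd)
    print(f"d={width:5d}  rms={rms_upd:.2e}  spectral_norm={spec_upd:.3f}")
```

For a rank-1 matrix the ratio of spectral norm to rms is exactly `sqrt(d * (d + 1))`, so the gap between the two grows linearly with width.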

In [7]:
e2 = e1 + u_e
e2_std = e2.std()
# Why is dividing by `e2.std()` allowed/justified? Normally the embedding update would be
# much smaller (scaled down by a small LR constant), and the original embedding would be
# decayed a bit, keeping e2 at about rms=1. This re-scaling does something similar,
# though it lets us see the effect of the update's width scaling more clearly.
e2 /= e2_std
W2 = W1 + u_W
rms(
    e2, W2
)  # Update is well-scaled. Weight has barely changed from its 1/sqrt(d) starting point

(tensor(0.9998), tensor(0.0221))
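The rescaling factor `e2_std` is itself O(1) under both embedding-LR rules. Here is a rough estimate, assuming `e1` and `W1.T @ g` are independent with entries of variance about 1 and `(d + 1)/d` respectively: Var(e2) ≈ 1 + e_lr² (d + 1)/d. That gives roughly `sqrt(2)` under the standard rule and roughly 1 under ours. The sketch below compares this estimate with a fresh draw, using new variable names so the notebook state is unchanged.

```python
import math
import torch

gen = torch.Generator().manual_seed(0)
width = 2**11
e_vec = torch.randn(width, 1, generator=gen)
W_mat = torch.randn(width + 1, width, generator=gen) * width**-0.5
g_vec = torch.randn(width + 1, 1, generator=gen)

stds = {}
for rule, lr_e in (("standard", 1.0), ("umup", width**-0.5)):
    e_next = e_vec + W_mat.T @ g_vec * lr_e  # same form as e2 before rescaling
    predicted = math.sqrt(1 + lr_e**2 * (width + 1) / width)
    stds[rule] = (e_next.std().item(), predicted)
    print(f"{rule:8s}  measured={stds[rule][0]:.3f}  predicted={predicted:.3f}")
```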

In [8]:
x2 = W2 @ e2
rms(x2)  # ~well-scaled. Certainly doesn't scale with a significant power of d

tensor(1.7412)
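To back up "doesn't scale with a significant power of d", the sketch below reruns the model and update above at several widths. It uses a hypothetical helper `rms_x2` with a fixed seed, and everything it creates is local so the notebook's tensors are untouched. It compares full alignment (`w_lr = 1/d`) with no alignment (`w_lr = 1/sqrt(d)`), both with the standard embedding LR.

```python
import torch

def rms_x2(width: int, full_alignment: bool, umup_lr_rule: bool, seed: int = 0) -> float:
    """Re-run the notebook's model & update at a given width and return rms(x2)."""
    gen = torch.Generator().manual_seed(seed)
    w_lr = width ** -(1 if full_alignment else 0.5)
    e_lr = width ** -(0.5 if umup_lr_rule else 0)
    e1 = torch.randn(width, 1, generator=gen)
    W1 = torch.randn(width + 1, width, generator=gen) * width**-0.5
    g = torch.randn(width + 1, 1, generator=gen)
    e2 = e1 + W1.T @ g * e_lr
    e2 = e2 / e2.std()
    W2 = W1 + g @ e1.T * w_lr
    return (W2 @ e2).pow(2).mean().sqrt().item()

sweep = {w: (rms_x2(w, True, False), rms_x2(w, False, False)) for w in (2**8, 2**10, 2**12)}
for w, (aligned, unaligned) in sweep.items():
    print(f"d={w:5d}  full_alignment={aligned:.3f}  no_alignment={unaligned:.2f}")
```

With full alignment rms(x2) should stay roughly constant. Without it, rms(x2) should grow roughly like `sqrt(d)`, driven by the weight-update term.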

## Analysis

Now we break this down into its constituent terms.

First, check that they recombine to give the original:

In [9]:
# `u_e` already includes `e_lr`, so it must not be multiplied by it again.
assert torch.allclose(x2, (W1 + u_W) @ (e1 + u_e) / e2_std, atol=1e-6)
torch.allclose(x2, (W1 + g @ e1.T * w_lr) @ (e1 + W1.T @ g * e_lr) / e2_std, atol=1e-6)

True

In [10]:
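# x2 * e2_std == x1 + t2 + t3 + t4, one term per pairing in (W1 + u_W) @ (e1 + u_e):
# t1 = W1 @ e1 (== x1)
t2 = W1 @ W1.T @ g * e_lr  # weight @ emb_update
t3 = g @ e1.T * w_lr @ e1  # weight_update @ emb
t4 = g @ e1.T * w_lr @ W1.T @ g * e_lr  # weight_update @ emb_update
torch.allclose(x2, (x1 + t2 + t3 + t4) / e2_std, atol=1e-5)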

True
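Written out, with $\eta_e$ for `e_lr`, $\eta_W$ for `w_lr` and $s$ for `e2_std`, the decomposition checked above is:

$$
x_2 = \frac{1}{s}\Big(\underbrace{W_1 e_1}_{x_1} + \underbrace{\eta_e\, W_1 W_1^\top g}_{t_2} + \underbrace{\eta_W\, g\,(e_1^\top e_1)}_{t_3} + \underbrace{\eta_W \eta_e\, g\,(e_1^\top W_1^\top g)}_{t_4}\Big)
$$

This gives a rough scaling estimate. It assumes i.i.d. Gaussian tensors (so $g$ is independent of $x_1 = W_1 e_1$) and $s = \Theta(1)$. Under those assumptions $W_1 W_1^\top g$ has $\Theta(1)$ rms, $e_1^\top e_1 \approx d$, and $e_1^\top W_1^\top g = x_1^\top g$ is a sum of $d + 1$ uncorrelated $\Theta(1)$ terms, hence $\Theta(\sqrt{d})$. So:

$$
\operatorname{rms}(t_2) = \Theta(\eta_e), \qquad
\operatorname{rms}(t_3) = \Theta(\eta_W\, d), \qquad
\operatorname{rms}(t_4) = \Theta(\eta_W\, \eta_e \sqrt{d}).
$$

Take $\eta_W = d^{-1}$ (full alignment). With $\eta_e = 1$ this gives $t_2 = \Theta(1)$ and $t_4 = \Theta(d^{-1/2})$. With $\eta_e = d^{-1/2}$ it gives $t_2 = \Theta(d^{-1/2})$ and $t_4 = \Theta(d^{-1})$. In both cases $t_3 = \Theta(1)$.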

### Weight @ emb_update (t2)

This is well-scaled under the original emb LR rule. Under our rule (`e_lr = d**-0.5`) it shrinks like `1/sqrt(d)`, which isn't a great sign for our approach.

In [11]:
print(f"{rms(W1, g), e_lr=}")
print(f"{rms(W1 @ W1.T)=}")
print(f"{rms(W1.T @ g)=}")
print(f"{rms(W1 @ W1.T @ g * e_lr / e2_std)=}")

rms(W1, g), e_lr=((tensor(0.0221), tensor(0.9882)), 1)
rms(W1 @ W1.T)=tensor(0.0312)
rms(W1.T @ g)=tensor(0.9977)
rms(W1 @ W1.T @ g * e_lr / e2_std)=tensor(0.9857)
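To see the width dependence of t2 directly, the sketch below uses a hypothetical helper `rms_t2` with a fixed seed. It computes t2 before dividing by `e2_std` under both embedding-LR rules. Because both rules reuse the same draws, the u-muP value times `sqrt(d)` should reproduce the standard value exactly.

```python
import torch

def rms_t2(width: int, umup_lr_rule: bool, seed: int = 0) -> float:
    """rms of t2 = W1 @ W1.T @ g * e_lr at a given width (before dividing by e2_std)."""
    gen = torch.Generator().manual_seed(seed)
    e_lr = width ** -(0.5 if umup_lr_rule else 0)
    W1 = torch.randn(width + 1, width, generator=gen) * width**-0.5
    g = torch.randn(width + 1, 1, generator=gen)
    t2 = W1 @ (W1.T @ g) * e_lr  # matrix-vector products only; avoids forming W1 @ W1.T
    return t2.pow(2).mean().sqrt().item()

t2_sweep = {w: (rms_t2(w, False), rms_t2(w, True)) for w in (2**8, 2**10, 2**12)}
for w, (standard, umup) in t2_sweep.items():
    print(f"d={w:5d}  standard={standard:.3f}  umup={umup:.4f}  umup*sqrt(d)={umup * w**0.5:.3f}")
```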


### Weight_update @ emb (t3)

This term doesn't involve `e_lr`, so it's well-scaled under both the original emb LR rule and our rule (given `full_alignment = True`, i.e. `w_lr = 1/d`).

In [12]:
print(f"{rms(g, e1)=}")
print(f"{rms(g @ e1.T)=}")
print(f"{rms(e1.T @ e1 * w_lr)=}")
print(f"{rms(g @ e1.T * w_lr @ e1)=}")

rms(g, e1)=(tensor(0.9882), tensor(0.9984))
rms(g @ e1.T)=tensor(0.9866)
rms(e1.T @ e1 * w_lr)=tensor(0.9968)
rms(g @ e1.T * w_lr @ e1)=tensor(0.9850)
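Since t3 only depends on `w_lr`, the alignment toggle is what matters here. The sketch below uses a hypothetical helper `rms_t3` with a fixed seed. It shows that t3 stays O(1) with `w_lr = 1/d` but grows like `sqrt(d)` with `w_lr = 1/sqrt(d)`. It uses the scalar `e1.T @ e1` instead of building the outer product, which makes large widths cheap.

```python
import torch

def rms_t3(width: int, full_alignment: bool, seed: int = 0) -> float:
    """rms of t3 = g @ e1.T * w_lr @ e1 (independent of e_lr)."""
    gen = torch.Generator().manual_seed(seed)
    w_lr = width ** -(1 if full_alignment else 0.5)
    e1 = torch.randn(width, 1, generator=gen)
    g = torch.randn(width + 1, 1, generator=gen)
    # Equivalent to (g @ e1.T) * w_lr @ e1, without materialising the (d + 1, d) outer product.
    t3 = g * (e1.T @ e1) * w_lr
    return t3.pow(2).mean().sqrt().item()

t3_sweep = {w: (rms_t3(w, True), rms_t3(w, False)) for w in (2**8, 2**10, 2**12, 2**14)}
for w, (aligned, unaligned) in t3_sweep.items():
    print(f"d={w:6d}  full_alignment={aligned:.3f}  no_alignment={unaligned:.2f}  "
          f"no_alignment/sqrt(d)={unaligned / w**0.5:.3f}")
```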


### Weight_update @ emb_update (t4)

With `w_lr = 1/d`, this vanishes with width under both the original emb LR rule (like `1/sqrt(d)`) and our rule (like `1/d`). Probably a good thing?

In [13]:
print(f"{rms(g @ e1.T @ W1.T @ g)=}")
print(f"{rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=}")

rms(g @ e1.T @ W1.T @ g)=tensor(46.5558)
rms(g @ e1.T * w_lr @ W1.T @ g * e_lr)=tensor(0.0227)
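A single draw of t4 is noisy, because `e1.T @ W1.T @ g` is one random scalar. A cleaner check of the decay rate is to average `rms(t4)**2` over seeds and read a log-log slope off the endpoints. The sketch below does this with a hypothetical helper `rms_t4`, assuming `w_lr = 1/d` (full alignment). The seed count and widths are arbitrary choices. The expected slopes from the scaling estimate are about -0.5 for the standard rule and about -1 for ours.

```python
import math
import torch

def rms_t4(width: int, umup_lr_rule: bool, seed: int = 0) -> float:
    """rms of t4 = g @ e1.T * w_lr @ W1.T @ g * e_lr, with w_lr = 1/d (full alignment)."""
    gen = torch.Generator().manual_seed(seed)
    w_lr = width ** -1.0
    e_lr = width ** -(0.5 if umup_lr_rule else 0)
    e1 = torch.randn(width, 1, generator=gen)
    W1 = torch.randn(width + 1, width, generator=gen) * width**-0.5
    g = torch.randn(width + 1, 1, generator=gen)
    t4 = g * (e1.T @ (W1.T @ g)) * w_lr * e_lr  # e1.T @ W1.T @ g is a 1x1 scalar
    return t4.pow(2).mean().sqrt().item()

widths = (2**6, 2**8, 2**10)
n_seeds = 64

def mean_rms_t4(width: int, umup_lr_rule: bool) -> float:
    # Root of the mean of rms^2 over seeds, to average out the random scalar.
    sq = [rms_t4(width, umup_lr_rule, seed) ** 2 for seed in range(n_seeds)]
    return math.sqrt(sum(sq) / n_seeds)

slopes = {}
for umup in (False, True):
    vals = [mean_rms_t4(w, umup) for w in widths]
    slopes[umup] = math.log(vals[-1] / vals[0]) / math.log(widths[-1] / widths[0])
    print(f"umup_lr_rule={umup!s:5}  rms(t4)={[round(v, 4) for v in vals]}  slope~{slopes[umup]:.2f}")
```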
