# LM-Adagrad

In this notebook I describe a limited-memory version of the full-matrix Adagrad algorithm. If this algorithm has already been described somewhere else, please let me know.

### Full-matrix Adagrad
The standard full-matrix Adagrad update rule is:

Hyperparameters: learning rate $\alpha$ and regularization parameter $\epsilon$

To initialize the algorithm, create an accumulator $A$ of shape $d\times d$, initialized to zeros, where $d$ is the number of parameters.

At time step $t$, with parameters $w_t$, do the following:

1. Evaluate the gradient vector $g_t$
2. Compute the outer product of the gradient with itself: $G=g_t g_t^T$
3. Add it to the accumulator: $A_{t+1}=A_{t}+G$
4. Regularize $A_{t+1}$ by adding a small multiple of the identity matrix: $A_{reg} = A_{t+1}+I\cdot\epsilon$
5. Update the parameters: $w_{t+1}=w_{t}-\alpha \cdot A_{reg}^{-1/2}g_{t}$
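The five steps above can be sketched in PyTorch as follows. This is a minimal illustration rather than a practical implementation: `full_matrix_adagrad_step` is a hypothetical helper name, and $A_{reg}^{-1/2}$ is computed with an eigendecomposition (`torch.linalg.eigh`), which costs $O(d^3)$ time and $O(d^2)$ memory, so it is only feasible for small $d$.

```python
import torch

def full_matrix_adagrad_step(w, g, A, lr=1e-2, eps=1e-6):
    """One full-matrix Adagrad step (hypothetical helper). Updates A in place, returns the new w."""
    A += torch.outer(g, g)  # steps 2-3: accumulate the outer product g g^T
    eye = torch.eye(A.shape[0], dtype=A.dtype, device=A.device)
    A_reg = A + eps * eye   # step 4: regularize
    # step 5: A_reg is symmetric positive definite, so A_reg = Q diag(lam) Q^T
    # and A_reg^{-1/2} g = Q diag(lam^{-1/2}) Q^T g
    lam, Q = torch.linalg.eigh(A_reg)
    return w - lr * (Q @ ((Q.T @ g) / lam.sqrt()))

# usage on a toy loss ||w||^2, whose gradient is 2w
torch.manual_seed(0)
d = 5
w = torch.randn(d, dtype=torch.float64)
A = torch.zeros(d, d, dtype=torch.float64)  # accumulator, initialized to zeros
for t in range(3):
    g = 2 * w                               # step 1: evaluate the gradient
    w = full_matrix_adagrad_step(w, g, A)
```

On the very first step the accumulator is just $gg^T$, so $A_{reg}^{-1/2}g = g/\sqrt{\|g\|^2+\epsilon}$: the first update is the normalized gradient scaled by $\alpha$.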

The main disadvantage of full-matrix Adagrad is that the matrix $A$ has $d^2$ entries, which is prohibitively large for models with many parameters. That's why the authors suggested using only the diagonal of $A$; however, this discards all the information contained in the off-diagonal elements. Methods like Shampoo and KFAC have been proposed to use a little more than just the diagonal of $A$, without storing the full matrix.

### Derivation of Limited-Memory Adagrad

I realized that full-matrix Adagrad effectively applies ZCA whitening to the gradients, except that it doesn't center them. ZCA whitening is commonly applied to large datasets via the singular value decomposition (SVD), and the same approach can be adapted to make a memory-efficient version of Adagrad. Here is how:

Suppose we take the past $k$ gradient vectors and stack them as columns of a single matrix $M\in \mathbb{R}^{d\times k}$. If the accumulator $A$ only contains these $k$ gradients, it is the sum of the outer products of the columns of $M$, so it can be computed as $A=MM^T$.
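A quick numerical check of $A=MM^T$, as a sketch in which random vectors stand in for the past gradients:

```python
import torch

torch.manual_seed(0)
d, k = 6, 3
grads = [torch.randn(d, dtype=torch.float64) for _ in range(k)]  # stand-ins for the past k gradients
M = torch.stack(grads, dim=1)                  # gradients as columns, shape (d, k)
A_sum = sum(torch.outer(g, g) for g in grads)  # accumulator as a sum of outer products
A_mat = M @ M.T                                # the same matrix as a single product
print(torch.allclose(A_sum, A_mat))
```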

Now let's consider a thin SVD of $M$:

$M=U\Sigma V^T$

Here $U$ is a $d\times k$ matrix with orthonormal columns ($U^TU=I$), $\Sigma$ is a $k\times k$ diagonal matrix of singular values, and $V$ is an orthogonal $k\times k$ matrix.

We know that $A=MM^T$; substituting the SVD of $M$ gives:

$A = MM^T = (U \Sigma V^T)(U \Sigma V^T)^T = U \Sigma V^T V \Sigma^T U^T = U \Sigma^2 U^T$

Note: because $V$ is orthogonal, $V^TV=V^{-1}V=I$ (identity), and $\Sigma^T=\Sigma$ because $\Sigma$ is diagonal.
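These identities can be checked numerically with PyTorch's thin SVD (`full_matrices=False`). This is a sketch with random stand-in gradients; note that PyTorch returns `Vh`, which is $V^T$ in the notation above.

```python
import torch

torch.manual_seed(0)
d, k = 50, 10
M = torch.randn(d, k, dtype=torch.float64)           # k stand-in gradients as columns
U, S, Vh = torch.linalg.svd(M, full_matrices=False)  # thin SVD: U (d, k), S (k,), Vh = V^T (k, k)

print(U.shape, S.shape, Vh.shape)
# V is orthogonal: V^T V = Vh @ Vh.T = I
print(torch.allclose(Vh @ Vh.T, torch.eye(k, dtype=torch.float64)))
# therefore A = M M^T collapses to U Sigma^2 U^T
print(torch.allclose(M @ M.T, U @ torch.diag(S**2) @ U.T))
```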

Step 5 of full-matrix Adagrad needs $A^{-1/2}$ (we will add the regularization from step 4 back later):

$A^{-1/2} = (U \Sigma^2 U^T)^{-1/2} = U (\Sigma^2)^{-1/2} U^T = U \Sigma^{-1} U^T$

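Note: $(U \Sigma^2 U^T)^{-1/2}$ is equal to $U (\Sigma^2)^{-1/2} U^T$ because $U \Sigma^2 U^T$ is an eigendecomposition of $A$: the columns of $U$ are orthonormal eigenvectors of $A$ and the diagonal of $\Sigma^2$ holds the corresponding eigenvalues, so the inverse square root acts only on the eigenvalues. Strictly speaking, when $k<d$ the matrix $A$ has rank at most $k$ and is not invertible, so $U\Sigma^{-1}U^T$ is a pseudo-inverse square root: it inverts $A^{1/2}$ on the span of the columns of $U$ and maps every vector orthogonal to that span to zero.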
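A numerical sketch of this identity, with random stand-in gradients. Since $k<d$, $A$ is singular, so the check compares the square of $U\Sigma^{-1}U^T$ against the Moore-Penrose pseudo-inverse of $A$ (computed with `torch.linalg.pinv`) and confirms that directions orthogonal to the columns of $U$ are mapped to zero. The explicit $d\times d$ matrix is formed only because this is a small check.

```python
import torch

torch.manual_seed(0)
d, k = 50, 10
M = torch.randn(d, k, dtype=torch.float64)  # stand-in gradients as columns
A = M @ M.T                                 # rank k < d, so A is singular
U, S, _ = torch.linalg.svd(M, full_matrices=False)
A_inv_sqrt = U @ torch.diag(1 / S) @ U.T    # U Sigma^{-1} U^T

# (U Sigma^{-1} U^T)^2 = U Sigma^{-2} U^T, the pseudo-inverse of A = U Sigma^2 U^T;
# rtol sets the cutoff below which pinv treats eigenvalues of A as zero
pinv_A = torch.linalg.pinv(A, rtol=1e-10, hermitian=True)
print(torch.allclose(A_inv_sqrt @ A_inv_sqrt, pinv_A))

# the component of a vector orthogonal to the columns of U is mapped to (numerically) zero
v = torch.randn(d, dtype=torch.float64)
v_perp = v - U @ (U.T @ v)
print((A_inv_sqrt @ v_perp).abs().max())
```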

So, the update rule is:

$w_{t+1} = w_t - \alpha A^{-1/2} g_t = w_t - \alpha U \Sigma^{-1} U^T g_t$

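This is exactly the update full-matrix Adagrad would make if its accumulator only contained the outer products of the last $k$ gradients. Restricting the update to the span of $U$ loses nothing here, because $g_t$ is itself a column of $M$ and therefore lies in that span. The same holds with regularization, since on that span $(A+\epsilon I)^{-1/2}$ acts as $U(\Sigma^2+\epsilon I)^{-1/2}U^T$.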
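A sketch comparing the two computations directly, with regularization included and random vectors standing in for the last $k$ gradients. The full-matrix version forms the $d\times d$ matrix $A+\epsilon I$ and applies its inverse square root via `torch.linalg.eigh`, while the limited-memory version only ever forms $d\times k$ matrices.

```python
import torch

torch.manual_seed(0)
d, k, eps = 50, 10, 1e-6
grads = [torch.randn(d, dtype=torch.float64) for _ in range(k)]  # stand-ins for the last k gradients
g = grads[-1]                                                   # current gradient g_t

# full-matrix Adagrad on the last k gradients: (A + eps I)^{-1/2} g, forms a (d, d) matrix
A = sum(torch.outer(gi, gi) for gi in grads)
lam, Q = torch.linalg.eigh(A + eps * torch.eye(d, dtype=torch.float64))
full_update = Q @ ((Q.T @ g) / lam.sqrt())

# limited-memory version: U (Sigma^2 + eps)^{-1/2} U^T g, only (d, k) matrices are formed
M = torch.stack(grads, dim=1)
U, S, _ = torch.linalg.svd(M, full_matrices=False)
lm_update = U @ ((U.T @ g) / (S**2 + eps).sqrt())

print(torch.allclose(full_update, lm_update))
```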

### Limited-Memory Adagrad update rule

Hyperparameters: learning rate $\alpha$, history size $k$, and regularization parameter $\epsilon$. I use $k=10$ and $\epsilon=10^{-6}$; if I find better values, I will update them here.

To initialize the algorithm, create an empty list `history` to store the past $k$ gradients.

At time step $t$, with parameters $w_t$, do the following:

1. Evaluate the gradient vector $g_t$
2. Append $g_t$ to the end of `history`
3. If `length(history) > k`, delete the first (oldest) element of `history`, so that it only holds the last $k$ gradient vectors
4. Stack the gradients in `history` as columns of a matrix $M\in \mathbb{R}^{d\times k}$ (it has fewer than $k$ columns during the first $k-1$ steps).
    * Optionally, center $M$ by subtracting the mean of each row:
    * $M_{centered} = M - \bar{g}$, where $\bar{g}$ is the vector of row means of $M$ (the mean gradient), subtracted from every column.
    * This is how ZCA whitening is performed. It does seem to help, but it is also less stable.
5. Compute the thin SVD $M = U\Sigma V^T$. $V^T$ is not needed and can be discarded. We assume $\Sigma$ is returned as a vector of singular values.
6. Add regularization to the singular values: $\Sigma = (\Sigma^2+\epsilon)^{1/2}$
7. Update the parameters: $w_{t+1} = w_t - \alpha U \Sigma^{-1} U^T g_t$

#### Notes:
* To implement step 7, first compute the temporary vector $Z=(U^Tg_t)/S$, where $S$ is the vector of regularized singular values and the division is elementwise, then $w_{t+1} = w_t - \alpha UZ$. If you multiplied out $U \Sigma^{-1} U^T$ first, you would get a $d\times d$ matrix and run out of memory, which defeats the purpose.
* There are fast SVD drivers for tall matrices. For example, in PyTorch use `U, S, Vh = torch.linalg.svd(M, full_matrices=False, driver="gesvda")`; otherwise the SVD can be so slow that it appears to freeze. The `driver` argument only works on CUDA tensors, and `full_matrices=False` is needed for the thin SVD, because the default returns a $d\times d$ matrix $U$.
* Regularization is computed as $\Sigma = (\Sigma^2+\epsilon)^{1/2}$ because the eigenvalues of $A+I\cdot\epsilon$ on the span of $U$ are $\Sigma^2+\epsilon$. Using $\Sigma = \Sigma+\epsilon$ instead would be equivalent to adding $I\cdot\epsilon$ to the square root of $A$ rather than to $A$ itself.
* Before applying the update rule, the gradients of all layers can be concatenated into a single gradient vector. Alternatively, the update rule can be applied to each layer separately, which ignores interactions between layers.
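A minimal sketch of both options for a `torch.nn` model, assuming the uncentered update rule above. `lm_adagrad_direction` and `lm_adagrad_model_step` are hypothetical helper names; the global option concatenates all gradients with `torch.cat` and splits the flat update back into per-parameter pieces with `Tensor.split`.

```python
import torch
from torch import nn
import torch.nn.functional as F

def lm_adagrad_direction(history, g, k=10, eps=1e-6):
    """U (Sigma^2 + eps)^{-1/2} U^T g over the last k gradients; appends g to `history`."""
    history.append(g.detach().clone())
    if len(history) > k:
        del history[0]
    M = torch.stack(history, dim=1)                     # (d, m), m <= k
    U, S, _ = torch.linalg.svd(M, full_matrices=False)  # thin SVD
    return U @ ((U.T @ g) / (S**2 + eps).sqrt())

@torch.no_grad()
def lm_adagrad_model_step(params, state, lr=1e-2, per_layer=False, k=10, eps=1e-6):
    """Hypothetical helper: one LM-Adagrad step over a list of parameter tensors.

    per_layer=False: concatenate all gradients into one vector with one shared history.
    per_layer=True: keep a separate history per parameter tensor (no cross-layer interactions).
    """
    grads = [p.grad.reshape(-1) for p in params]
    if per_layer:
        histories = state.setdefault("histories", [[] for _ in params])
        updates = [lm_adagrad_direction(h, g, k, eps) for h, g in zip(histories, grads)]
    else:
        flat = lm_adagrad_direction(state.setdefault("history", []), torch.cat(grads), k, eps)
        updates = flat.split([g.numel() for g in grads])  # slice the flat update per tensor
    for p, u in zip(params, updates):
        p -= lr * u.view_as(p)

# usage on the same toy regression problem as the torchzero example
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 10), nn.Tanh(), nn.Linear(10, 10))
inputs, targets = torch.randn(64, 10), torch.randn(64, 10)
params, state = list(model.parameters()), {}
for i in range(100):
    model.zero_grad()
    loss = F.mse_loss(model(inputs), targets)
    loss.backward()
    lm_adagrad_model_step(params, state, lr=1e-2)
    if i % 20 == 0:
        print(i, loss.item())
```

In this sketch, "per layer" is approximated as "per parameter tensor", so a weight matrix and its bias get separate histories and separate SVDs.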


### Tips
* Momentum makes it much better. This can be done in a few ways:
    * in step 7, $g_t$ can be replaced with a momentum buffer (an exponential moving average of gradients), like in Adam
    * momentum can be applied to the update itself ($U \Sigma^{-1} U^T g_t$), like in LaProp.
* Clip the update norm. Alternatively, what I found works well is to clip or graft the update norm to an exponential moving average of the norms of past (unclipped) updates.
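A sketch of how these tips might be combined, under several assumptions that are not from the text above: `LMAdagradWithTips` and `lm_adagrad_precondition` are hypothetical names, the history always stores raw gradients (as Adam's second moment uses raw gradients), the momentum buffer is initialized with its first value instead of zeros plus bias correction, and the norm EMA tracks the norms of the updates before clipping or grafting.

```python
import torch

def lm_adagrad_precondition(history, g, eps=1e-6):
    """U (Sigma^2 + eps)^{-1/2} U^T g, where the columns of M are the gradients in `history`."""
    M = torch.stack(history, dim=1)
    U, S, _ = torch.linalg.svd(M, full_matrices=False)
    return U @ ((U.T @ g) / (S**2 + eps).sqrt())

class LMAdagradWithTips:
    """Hypothetical sketch: LM-Adagrad with momentum and norm clipping/grafting.

    momentum="grad":   precondition an EMA of gradients (Adam-style)
    momentum="update": take an EMA of preconditioned updates (LaProp-style)
    norm_mode="clip":  shrink the update if its norm exceeds the EMA of past update norms
    norm_mode="graft": always rescale the update to that EMA norm
    """
    def __init__(self, k=10, eps=1e-6, beta=0.9, momentum="update", norm_mode=None, norm_beta=0.99):
        self.k, self.eps, self.beta = k, eps, beta
        self.momentum, self.norm_mode, self.norm_beta = momentum, norm_mode, norm_beta
        self.history, self.buf, self.norm_ema = [], None, None

    @torch.no_grad()
    def step(self, w, g, lr=1e-2):
        self.history.append(g.detach().clone())
        if len(self.history) > self.k:
            del self.history[0]

        if self.momentum == "grad":
            self.buf = g.clone() if self.buf is None else self.beta * self.buf + (1 - self.beta) * g
            update = lm_adagrad_precondition(self.history, self.buf, self.eps)
        else:  # "update"
            u = lm_adagrad_precondition(self.history, g, self.eps)
            self.buf = u if self.buf is None else self.beta * self.buf + (1 - self.beta) * u
            update = self.buf

        if self.norm_mode is not None:
            norm = update.norm()
            # EMA of unclipped update norms, initialized with the first norm
            if self.norm_ema is None:
                self.norm_ema = norm
            else:
                self.norm_ema = self.norm_beta * self.norm_ema + (1 - self.norm_beta) * norm
            if self.norm_mode == "graft" or norm > self.norm_ema:
                update = update * (self.norm_ema / norm.clamp_min(1e-12))
        w -= lr * update  # w is modified in place
        return update

# usage on an ill-conditioned quadratic 0.5 * w^T H w, whose gradient is H w
torch.manual_seed(0)
d = 20
H = torch.diag(torch.logspace(-2, 2, d, dtype=torch.float64))
w = torch.randn(d, dtype=torch.float64)
opt = LMAdagradWithTips(momentum="update", norm_mode="graft")
for t in range(200):
    opt.step(w, H @ w, lr=1e-2)
print(0.5 * w @ H @ w)
```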

### Reference implementation:

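```python
import torch

@torch.no_grad()
def limited_memory_adagrad(
    w: torch.Tensor, # current parameters vector, shape (d,), updated in place
    g: torch.Tensor, # current gradient vector, shape (d,)
    history: list[torch.Tensor], # history of past gradients, initialized to an empty list
    lr=1e-2, k=10, eps=1e-6, centered=False # hyperparameters
):

    # update history; store a copy so that later in-place changes to g don't alter it
    history.append(g.detach().clone())
    if len(history) > k: del history[0]

    # stack history as columns of M
    M = torch.stack(history, dim=1) # (d, m), where m = len(history) <= k

    # optionally apply centering
    if centered:
        M -= M.mean(1, keepdim=True)

    # compute thin SVD; full_matrices=False makes U (d, m) instead of (d, d).
    # The "gesvda" driver is fast for tall matrices but only works on CUDA tensors.
    if torch.cuda.is_available():
        U, S, _ = torch.linalg.svd(M.cuda(), full_matrices=False, driver="gesvda")
    else:
        U, S, _ = torch.linalg.svd(M, full_matrices=False)
    U = U.to(w); S = S.to(w) # move back to weights device and dtype

    # regularize singular values
    S = (S**2 + eps).sqrt()

    # compute U S^-1 U^T g without forming a (d, d) matrix
    # start with Z = S^-1 U^T g, a vector of length m
    Z = (U.T @ g) / S

    # now update is U@Z
    w -= (U @ Z) * lr
    return w
```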




```python
# Of course there is a torchzero implementation too
import torch
import torchzero as tz
from torch import nn
from torch.nn import functional as F
model = nn.Sequential(nn.Linear(10, 10), nn.Tanh(), nn.Linear(10, 10))

inputs = torch.randn(64, 10)
targets = torch.randn(64, 10)

opt = tz.Modular(
    model.parameters(),
    tz.m.LMAdagrad(),
    tz.m.WeightDecay(1e-3),
    tz.m.LR(1),
)

for i in range(100):
    preds = model(inputs)
    loss = F.mse_loss(preds, targets)
    model.zero_grad() # clear gradients from the previous step before backward
    loss.backward()
    opt.step()
    if i % 10 == 0: print(f'{i}, {loss = }')
```

```
0, loss = tensor(1.0929, grad_fn=<MseLossBackward0>)
10, loss = tensor(0.7662, grad_fn=<MseLossBackward0>)
20, loss = tensor(0.6243, grad_fn=<MseLossBackward0>)
30, loss = tensor(0.5514, grad_fn=<MseLossBackward0>)
40, loss = tensor(0.5425, grad_fn=<MseLossBackward0>)
50, loss = tensor(0.5285, grad_fn=<MseLossBackward0>)
60, loss = tensor(0.4883, grad_fn=<MseLossBackward0>)
70, loss = tensor(0.4897, grad_fn=<MseLossBackward0>)
80, loss = tensor(0.5217, grad_fn=<MseLossBackward0>)
90, loss = tensor(0.5008, grad_fn=<MseLossBackward0>)
```
