In [107]:
import torch
import torch.nn as nn
import math

# Simple model with different parameter types
class ToyModel(nn.Module):
    """Two-layer linear toy model computing ``(x @ W1 + b1) @ W2 + b2`` (no nonlinearity).

    It deliberately holds two matrix (2-D) parameters and two vector (1-D)
    parameters, so that optimizer parameter groups can be split by
    dimensionality. Parameters are registered in the order weight1, weight2,
    bias1, bias2, and ``parameters()`` yields them in that order.
    """

    def __init__(self):
        super().__init__()
        # 2-D (matrix) parameters: 4 -> 8 -> 2 features.
        self.weight1 = nn.Parameter(torch.randn(4, 8))
        self.weight2 = nn.Parameter(torch.randn(8, 2))
        # 1-D (vector) parameters: one bias per layer output.
        self.bias1 = nn.Parameter(torch.randn(8))
        self.bias2 = nn.Parameter(torch.randn(2))

    def forward(self, x):
        """Apply both affine maps: trailing dim 4 in, trailing dim 2 out."""
        hidden = x @ self.weight1 + self.bias1
        return hidden @ self.weight2 + self.bias2

def loss_fn(x, y):
    """Euclidean (L2) distance between prediction ``x`` and target ``y``.

    This is NOT mean squared error: it returns ``sqrt(sum((x - y) ** 2))``.
    The norm is not squared and the result is not averaged over elements.
    Use ``torch.nn.functional.mse_loss(x, y)`` if true MSE is wanted.

    Returns:
        A 0-d tensor that is differentiable w.r.t. ``x`` and ``y``.
    """
    return (x - y).norm()

# Parameter group example
# Seed first: ToyModel() and the input/target below all draw from torch's RNG,
# so without a seed every re-run produces different numbers.
torch.manual_seed(0)
model = ToyModel()

input = torch.randn(4)  # NOTE: shadows the builtin `input`; name kept in case other cells use it
output = model(input)
loss = loss_fn(output, torch.randn(2))

# Parameter groups are chosen when the optimizer is constructed. They are not
# frozen afterwards: Optimizer.add_param_group() can append new groups, and
# per-group hyperparameters (e.g. group['lr']) can be edited in place.

# Method 1: Using a single parameter group
opt1 = torch.optim.SGD(model.parameters(), lr=0.1)
print(f"Optimizer 1 has {len(opt1.param_groups)} parameter group(s)")

# Method 2: Using multiple parameter groups with different learning rates
opt2 = torch.optim.SGD([
    {'params': [model.weight1, model.weight2], 'lr': 0.1},
    {'params': [model.bias1, model.bias2], 'lr': 0.01}
])
print(f"Optimizer 2 has {len(opt2.param_groups)} parameter group(s)")

# Print learning rates for each group; len(group['params']) counts parameter tensors, not scalars
for i, group in enumerate(opt2.param_groups):
    print(f"Group {i} has lr={group['lr']} and contains {len(group['params'])} parameters")


Optimizer 1 has 1 parameter group(s)
Optimizer 2 has 2 parameter group(s)
Group 0 has lr=0.1 and contains 2 parameters
Group 1 has lr=0.01 and contains 2 parameters


In [108]:
loss

tensor(2.9734, grad_fn=<LinalgVectorNormBackward0>)

In [71]:
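# loss.backward()  # NOTE: p.grad stays None until backward() has been run on `loss`
# TODO: unfinished note ("regarding ...") — complete or remove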

import torch 
import math 

def decide_rank(loss, max_loss=None, max_rank: int = 8, min_rank: int = 1) -> int:
    """Map the current loss to a low-rank budget in ``[min_rank, max_rank]``.

    The rank scales linearly with ``loss / max_loss``. A loss equal to the
    reference maximum gets ``max_rank``; smaller losses get proportionally fewer
    ranks (truncated toward zero), but never fewer than ``min_rank``.

    Args:
        loss: current loss, either a single-element tensor or a plain number.
        max_loss: reference ("maximal") loss, as a tensor or a number. If None,
            the current loss is used, so the ratio is 1 and ``max_rank`` is
            returned. A non-positive value cannot act as a scale and is handled
            the same way.
        max_rank: upper bound on the returned rank.
        min_rank: lower bound on the returned rank.

    Returns:
        An int rank. Losses above ``max_loss`` are clamped to ``max_rank``, so
        the result stays a valid ``q`` for ``torch.svd_lowrank`` whenever
        ``max_rank <= min(matrix.shape)``.
    """
    loss_value = loss.item() if torch.is_tensor(loss) else float(loss)
    if max_loss is None:
        reference = loss_value
    else:
        reference = max_loss.item() if torch.is_tensor(max_loss) else float(max_loss)

    # A zero/negative reference cannot scale the loss, so fall back to ratio 1.
    # This also avoids ZeroDivisionError when the current loss is exactly 0.
    ratio = loss_value / reference if reference > 0 else 1.0
    rank = int(ratio * max_rank)
    # Clamp on both sides: loss > max_loss would otherwise exceed max_rank.
    return max(min_rank, min(max_rank, rank))


@torch.compile # speed-up for overhead, might raise memory consumption
def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int) -> torch.Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: 2-D tensor (asserted below) to orthogonalize. It is cast to bfloat16 internally.
        steps: number of quintic Newton-Schulz iterations to run.

    Returns:
        A bfloat16 tensor with the same shape as G. The result is NOT cast back to G's dtype.
    """
    assert len(G.shape) == 2
    # Quintic coefficients: each iteration maps every singular value s of X to
    # a*s + b*s**3 + c*s**5 (see the update below).
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Work with X in wide orientation so the Gram matrix A = X @ X.T below is the
    # smaller min(m, n) x min(m, n) matrix. The transpose is undone at the end.
    if G.size(0) > G.size(1):
        X = X.T
    # Ensure spectral norm is at most 1: the Frobenius norm (X.norm()) upper-bounds
    # the spectral norm. The 1e-7 guards against dividing by zero for an all-zero G.
    X = X / (X.norm() + 1e-7)
    # Perform the NS iterations: X <- a*X + b*(X X^T) X + c*(X X^T)^2 X
    for _ in range(steps):
        A = X @ X.T
        B = (
            b * A + c * A @ A
        )  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    # Restore the original (tall) orientation if it was transposed above.
    if G.size(0) > G.size(1):
        X = X.T
    return X

In [106]:
# Fixed grouping, no adjustment on magnitude & direction change version (ARG)

# Pick the first matrix-shaped (2-D) parameter of the model. Using next() instead of
# a for/break loop avoids silently leaving `p` bound to the LAST parameter (a 1-D
# bias) when the model has no 2-D parameter at all.
p = next((param for param in model.parameters() if param.ndim == 2), None)
if p is None:
    raise ValueError("model has no 2-D (matrix) parameter to decompose")
# TODO: once this lives inside an optimizer, use max_loss = self.state[p]['max_loss']

# Gradients only exist after backward(). No earlier cell in this notebook runs it,
# so on a fresh kernel p.grad would be None and svd_lowrank would fail.
if p.grad is None:
    loss.backward()
orig_grad = p.grad  # gradient of the loss w.r.t. p (same shape as p)
# With max_loss=None the loss ratio is 1, so r == min(p.shape), i.e. full rank.
r = decide_rank(loss, None, max_rank=min(p.shape))

# Truncated SVD: U, V have orthonormal columns, S is the vector of singular values.
U, S, V = torch.svd_lowrank(orig_grad, q=r, niter=2)
# lr_approx = U @ torch.diag(S) @ V.T # low rank approximation of gradient

# momentum computations
# For S, we accumulate 1st moment and 2nd moment like Adam
# For U, V, we accumulate 1st moment and use NS to orthogonalize 1st momentum



In [101]:
# Shapes of the truncated SVD factors, in (V, S, U) order
tuple(factor.shape for factor in (V, S, U))

(torch.Size([8, 4]), torch.Size([4]), torch.Size([4, 4]))

In [110]:
# NOTE(review): S.norm() is a single scalar (the 2-norm of the singular values), so this
# scales every entry of V by the same factor. If per-direction scaling was intended,
# that would be V * S (column j scaled by S[j]). TODO confirm intent.
V * S.norm()

tensor([[  4.9877,  -4.5207,   9.5587,   5.0980],
        [ -0.7154,  -1.1230,  -1.6472,   3.6788],
        [  3.5238, -10.2778,  -0.1282,  -5.9239],
        [  5.4887,   6.9303,   5.5392,  -0.3661],
        [  1.0456,  -0.0203,  -0.3516,  -2.3887],
        [  1.4112,  -2.2756,  -2.9140,   1.8091],
        [ 10.7519,   1.5227,  -7.3522,   2.1512],
        [ -2.2711,  -2.9870,  -2.0802,  10.1696]])

In [None]:
# A traditional optimizer assumes the 'gradient' and 'parameter' tensors are fixed.
# - Our idea combines wrapping a low-rank adaptor around the parameters, calling .backward(), and the optimization gadgets together.
# - In code, this is therefore more than a custom optimizer: some additional logic must run before .backward() is called.
