In [1]:
import torch 
import torch.nn as nn 

# model_dtype = torch.bfloat16 
model_dtype = torch.float32

model = nn.Linear(10, 5)
model.to(model_dtype)

# Toy case: 
# - multiplication operation (dimension pair multiply)
N = 50
input = torch.randn((N, 10)).to(model_dtype)
output = input[:,::2] * input[:,1::2].to(model_dtype)

def criterion(pred, target): 
    l2_loss = torch.norm(pred - output, p=2, dim=1)
    return l2_loss.mean()

In [2]:
# from yao import YAO
from temp_yao import YAO

# yao optimizer
yao_optimizer = YAO(params=model.parameters())

In [3]:
# loop 
epochs = 100 

for epoch in range(epochs): 
    pred = model(input)
    loss = criterion(pred, output)

    yao_optimizer.zero_grad() 
    loss.backward() 
    yao_optimizer.step(loss)
    print(f"Train iter {epoch}/{epochs} - loss {loss.item()}")

:: Global Step ::
Train iter 0/100 - loss 2.2631921768188477
Train iter 1/100 - loss 2.262282371520996
Train iter 2/100 - loss 2.261791467666626
Train iter 3/100 - loss 2.2610867023468018
Train iter 4/100 - loss 2.260451555252075
Train iter 5/100 - loss 2.2594103813171387
Train iter 6/100 - loss 2.2584586143493652
Train iter 7/100 - loss 2.257549285888672
Train iter 8/100 - loss 2.256727695465088
Train iter 9/100 - loss 2.2558224201202393
:: Global Step ::
Train iter 10/100 - loss 2.2548744678497314
Train iter 11/100 - loss 2.253951072692871
Train iter 12/100 - loss 2.253005027770996
Train iter 13/100 - loss 2.2522101402282715
Train iter 14/100 - loss 2.251286268234253
Train iter 15/100 - loss 2.250624418258667
Train iter 16/100 - loss 2.249948024749756
Train iter 17/100 - loss 2.249108076095581
Train iter 18/100 - loss 2.2481565475463867
Train iter 19/100 - loss 2.24727725982666
:: Global Step ::
Train iter 20/100 - loss 2.2463643550872803
Train iter 21/100 - loss 2.2454776763916016
T

In [5]:
# Update max_loss and compute new rank
self = yao_optimizer 
current_loss = loss.item() 

if self.max_loss is None:
    self.max_loss = current_loss
else:
    self.max_loss = max(self.max_loss, current_loss)

for group in self.param_groups:
    for p in group["params"]:
        if not self.state[p]["use_arg"]:
            continue  # Skip non-low-rank params

        state = self.state[p]
        if "moment1_u" not in state:
            continue  # Not initialized yet

        # Adaptive rank for each parameter
        new_rank = calculate_rank(current_loss, self.max_loss, max_rank=min(p.shape))

        # Get current rank and buffers
        current_rank = state["moment1_u"].shape[1]
        if new_rank == current_rank:
            continue  # No change needed

        # Project momentum buffers to new rank
        state["moment1_u"] = self._adjust_rank(state["moment1_u"], new_rank)
        state["moment1_v"] = self._adjust_rank(state["moment1_v"], new_rank)
        state["moment1_s"] = self._adjust_rank(state["moment1_s"], new_rank)
        state["moment2_s"] = self._adjust_rank(state["moment2_s"], new_rank)

In [8]:
# new_rank # new rank is 99 --> wrong ... 
print(f"Current loss: {current_loss} | Max loss: {self.max_loss}")
# adjust rank functional issue ... 
state["moment1_u"].shape


Current loss: 2.1349761486053467 | Max loss: 2.1414403915405273


torch.Size([5, 99])

In [10]:
state["moment1_u"].shape, state["moment1_v"].shape

(torch.Size([5, 99]), torch.Size([10, 99]))

In [11]:
p

Parameter containing:
tensor([ 0.1801, -0.2629, -0.2272,  0.2899, -0.1162], requires_grad=True)

In [24]:
# Debugging run for local_step 

from temp_yao import * 
self = yao_optimizer


for group in self.param_groups:
    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["wd"]
    
    # --- Low-Rank Params ---
    lowrank_params = [p for p in group["params"] if self.state[p]["use_arg"]]
    for p in lowrank_params:
        g = p.grad
        if g is None:
            continue
    
        state = self.state[p]
        if "step" not in state:
            # Initialize on first step
            rank = min(g.shape)  # Default rank (adjust if needed)
            state["step"] = 0
            state["moment1_u"] = torch.zeros(g.shape[0], rank)
            state["moment1_v"] = torch.zeros(g.shape[1], rank)
            state["moment1_s"] = torch.zeros(rank)
            state["moment2_s"] = torch.zeros(rank)
    
        # Low-rank SVD approximation | this guy does not seem to support bfloat16 input type
        U, S, V = svd_lowrank(g, q=state["moment1_u"].shape[1], niter=2)
    
        print(f"Shape of U: {to_shape(U)} | V: {to_shape(V)} | g: {to_shape(g)}") 
        print(f"Dtype of U: {U.dtype} | V: {V.dtype} | g: {g.dtype}")
    
        print("Dtype of U Moment: ", state['moment1_u'].dtype)
        print("Dtype of V Moment: ", state["moment1_v"].dtype)
        
        # Update momentum buffers
        state["step"] += 1
        
        # _beta1 = torch.tensor(beta1, dtype=torch.bfloat16)
        # _beta2 = torch.tensor(beta2, dtype=torch.bfloat16)
    
        state["moment1_u"].lerp_(U, 1 - beta1)
        state["moment1_v"].lerp_(V, 1 - beta1)
        state["moment1_s"].lerp_(S, 1 - beta1)
        state["moment2_s"].lerp_(S.norm()**2, 1 - beta2)
    
        # Newton-Schulz orthogonalization
        U = zeropower_via_newtonschulz5(state["moment1_u"], group["ns_steps"])
        V = zeropower_via_newtonschulz5(state["moment1_v"], group["ns_steps"])
        _mid = (state["moment1_s"] / (eps + state["moment2_s"].sqrt())).unsqueeze(-1).to(torch.bfloat16)
        g = U @ (_mid * V.T)

Shape of U: 5, 5 | V: 10, 5 | g: 5, 10
Dtype of U: torch.float32 | V: torch.float32 | g: torch.float32
Dtype of U Moment:  torch.float32
Dtype of V Moment:  torch.float32
