In [None]:
import torch 
import torch.nn as nn 

# model_dtype = torch.bfloat16 
model_dtype = torch.float32

model = nn.Linear(10, 5)
model.to(model_dtype)

# Toy case: 
# - multiplication operation (dimension pair multiply)
N = 50
input = torch.randn((N, 10)).to(model_dtype)
output = input[:,::2] * input[:,1::2].to(model_dtype)

def criterion(pred, target): 
    l2_loss = torch.norm(pred - output, p=2, dim=1)
    return l2_loss.mean()

In [None]:
# from yao import YAO
from temp_yao import YAO

# yao optimizer
yao_optimizer = YAO(params=model.parameters())

In [None]:
# compute loss 
pred = model(input)
loss = criterion(pred, output)

# Backward pass
yao_optimizer.zero_grad()
loss.backward()

In [None]:
# work for bfloat typed model parameters 
yao_optimizer._local_step()

In [None]:
g = U @ ((state["moment1_s"] / (eps + state["moment2_s"].sqrt())).unsqueeze(-1) * V.T)

In [24]:
from temp_yao import * 
self = yao_optimizer


for group in self.param_groups:
    lr = group["lr"]
    beta1, beta2 = group["adamw_betas"]
    eps = group["adamw_eps"]
    weight_decay = group["wd"]
    
    # --- Low-Rank Params ---
    lowrank_params = [p for p in group["params"] if self.state[p]["use_arg"]]
    for p in lowrank_params:
        g = p.grad
        if g is None:
            continue
    
        state = self.state[p]
        if "step" not in state:
            # Initialize on first step
            rank = min(g.shape)  # Default rank (adjust if needed)
            state["step"] = 0
            state["moment1_u"] = torch.zeros(g.shape[0], rank)
            state["moment1_v"] = torch.zeros(g.shape[1], rank)
            state["moment1_s"] = torch.zeros(rank)
            state["moment2_s"] = torch.zeros(rank)
    
        # Low-rank SVD approximation | this guy does not seem to support bfloat16 input type
        U, S, V = svd_lowrank(g, q=state["moment1_u"].shape[1], niter=2)
    
        print(f"Shape of U: {to_shape(U)} | V: {to_shape(V)} | g: {to_shape(g)}") 
        print(f"Dtype of U: {U.dtype} | V: {V.dtype} | g: {g.dtype}")
    
        print("Dtype of U Moment: ", state['moment1_u'].dtype)
        print("Dtype of V Moment: ", state["moment1_v"].dtype)
        
        # Update momentum buffers
        state["step"] += 1
        
        # _beta1 = torch.tensor(beta1, dtype=torch.bfloat16)
        # _beta2 = torch.tensor(beta2, dtype=torch.bfloat16)
    
        state["moment1_u"].lerp_(U, 1 - beta1)
        state["moment1_v"].lerp_(V, 1 - beta1)
        state["moment1_s"].lerp_(S, 1 - beta1)
        state["moment2_s"].lerp_(S.norm()**2, 1 - beta2)
    
        # Newton-Schulz orthogonalization
        U = zeropower_via_newtonschulz5(state["moment1_u"], group["ns_steps"])
        V = zeropower_via_newtonschulz5(state["moment1_v"], group["ns_steps"])
        _mid = (state["moment1_s"] / (eps + state["moment2_s"].sqrt())).unsqueeze(-1).to(torch.bfloat16)
        g = U @ (_mid * V.T)

Shape of U: 5, 5 | V: 10, 5 | g: 5, 10
Dtype of U: torch.float32 | V: torch.float32 | g: torch.float32
Dtype of U Moment:  torch.float32
Dtype of V Moment:  torch.float32


In [19]:
U.dtype, V.dtype

(torch.bfloat16, torch.bfloat16)

In [23]:
suffix = (state["moment1_s"] / (eps + state["moment2_s"].sqrt())).unsqueeze(-1) * V.T
prefix = U 
print(f"Matrix multiplication between: {prefix.dtype} and {suffix.dtype}") 

# Ok so it turns out matrix multiplication in pytorch can't be broadcasted ... (wow ...)
torch.matmul(prefix, suffix)
# prefix.to(torch.float32) @ suffix

Matrix multiplication between: torch.bfloat16 and torch.float32


RuntimeError: expected m1 and m2 to have the same dtype, but got: c10::BFloat16 != float

In [13]:
(eps + state["moment2_s"].sqrt()).unsqueeze(-1) * V.T

tensor([[ 0.0195, -0.0508,  0.0397, -0.0012, -0.0154, -0.0679, -0.0075, -0.0488,
         -0.0284,  0.0791],
        [ 0.0272, -0.0519, -0.0193,  0.0081, -0.0672,  0.0202, -0.0529,  0.0829,
         -0.0118,  0.0127],
        [-0.0449, -0.0012,  0.0557, -0.0449, -0.0411,  0.0204,  0.0367,  0.0152,
         -0.0749, -0.0259],
        [-0.0453,  0.0112,  0.0470, -0.0059, -0.0011,  0.0357, -0.0406, -0.0013,
          0.0434,  0.0357],
        [ 0.0651,  0.0397,  0.0373,  0.0178, -0.0228,  0.0301,  0.0341,  0.0047,
          0.0083,  0.0215]])

In [10]:
g = 1.0

type(g)

float

In [7]:
# Data Type Promoting experimentation ... 

import torch 
from moun import zeropower_via_newtonschulz5 as ns 

g = torch.rand(5, 10)
print(f"Data type of tensor g: {g.dtype}") 
g = g.to(torch.bfloat16)
print(f"Data type after conversion for g: {g.dtype}")
x = ns(g, 2)
print(f"NS orthogonalization result dtype: {x.dtype}")
m = g 
m.add_(x, alpha=1.)
print(f"Addition between {m.dtype} tensor and {x.dtype} tensor becomes: {m.dtype}")

# Ok, so there are some addition-based broadcasting of tensor across dtype (indeed 'promotion of sort')
# But why do I get the previous incompatibility issue if such broadcasting across dtype exists??



Data type of tensor g: torch.float32
Data type after conversion for g: torch.bfloat16
NS orthogonalization result dtype: torch.bfloat16
Addition between torch.bfloat16 tensor and torch.bfloat16 tensor becomes: torch.bfloat16


In [15]:
from temp_yao import svd_lowrank # proudly presented functional

print(":: Testing wrapped randomized low-rank operator ::")
g = torch.rand(5,10)

print(f"LowRank SVD on input: {g.dtype}")
U, S, V = svd_lowrank(g, q=4, niter=2)
print(f"- output type U : {U.dtype} | S: {S.dtype} | V: {V.dtype} -")

g = g.to(torch.bfloat16)
print(f"LowRank SVD on input: {g.dtype}")
U, S, V = svd_lowrank(g, q=4, niter=2)
print(f"- output type U : {U.dtype} | S: {S.dtype} | V: {V.dtype} -")

g_approx = U @ (S.unsqueeze(-1) * V.T)

:: Testing wrapped randomized low-rank operator ::
LowRank SVD on input: torch.float32
- output type U : torch.float32 | S: torch.float32 | V: torch.float32 -
LowRank SVD on input: torch.bfloat16
- output type U : torch.bfloat16 | S: torch.bfloat16 | V: torch.bfloat16 -


In [27]:
print(":: Pytorch requires aligning dtype for matrix multiplication")
u = torch.randn(8, 4).to(torch.bfloat16)
v = torch.randn(4, 8).to(torch.bfloat16)
u @ v 

u = torch.randn(8, 4).to(torch.float32)
v = torch.randn(4, 8).to(torch.bfloat16)
u @ v 

:: Pytorch requires aligning dtype for matrix multiplication


RuntimeError: expected m1 and m2 to have the same dtype, but got: float != c10::BFloat16

In [11]:
g = torch.ones(5,10).to("cuda")
# U, S, V = torch.svd_lowrank(g, q=8, niter=2)

In [18]:
import torch
import torch.nn as nn

# Define model (10 input features → 5 output features)
model = nn.Linear(10, 5)

# Toy case: Ensure output shape matches model's output (N, 5)
N = 50
input = torch.randn((N, 10))
output = input[:, :5] * input[:, 5:]  # Shape (N, 5)

def criterion(pred, target):
    l2_loss = torch.norm(pred - target, p=2, dim=1)  # L2-norm per sample
    return l2_loss.mean()

# Optimizer (assuming YAO is imported)
from yao import YAO
yao_optimizer = YAO(params=model.parameters())

# Forward pass
pred = model(input)
loss = criterion(pred, output)

# Backward pass
yao_optimizer.zero_grad()
loss.backward()
yao_optimizer.step()

tensor(2.2630, grad_fn=<MeanBackward0>)