In [23]:
import sys
from pathlib import Path

# Put the parent of the current working directory on sys.path (once) so the
# `src.*` imports below can be resolved when the notebook runs from its own folder.
project_root = Path("..").resolve()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

import src.seed as seed
import src.models as models
import src.functions as fn

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import time
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Shared compute device and random generator taken from the project's seed module.
# NOTE(review): presumably a torch.device and a seeded torch.Generator — confirm in src/seed.py.
device = seed.device
generator = seed.generator

In [2]:
def generate_PSD_matrix(dim, sharpness, rng=None):
    """Build a random symmetric matrix with eigenvalues (1, ..., 1, sharpness).

    A random orthogonal basis Q is obtained from the QR factorisation of a
    Gaussian matrix, and the result is Q @ diag(1, ..., 1, sharpness) @ Q.T.
    The matrix is positive semi-definite whenever ``sharpness >= 0``.

    Args:
        dim: Side length of the square matrix.
        sharpness: Value of the single non-unit eigenvalue.
        rng: Optional random source exposing ``standard_normal(shape)``
            (e.g. ``np.random.default_rng(seed)``). When omitted, NumPy's
            global RNG is used, exactly as before.

    Returns:
        A ``(dim, dim)`` float ndarray.
    """
    if rng is None:
        A = np.random.randn(dim, dim)
    else:
        A = rng.standard_normal((dim, dim))
    # Only the orthogonal factor is needed; the triangular factor is discarded.
    Q, _ = np.linalg.qr(A)
    lam = np.ones(dim, dtype=float)
    lam[-1] = sharpness
    return Q @ np.diag(lam) @ Q.T

def f(W, A, B):
    """Quadratic objective 0.5 * tr(W^T A W B) evaluated at W."""
    product = W.T @ A
    product = product @ W
    product = product @ B
    return 0.5 * np.trace(product)

def grad_f(W, A, B):
    """Return A @ W @ B.

    This equals the gradient of 0.5 * tr(W^T A W B) with respect to W when
    both A and B are symmetric.
    """
    left = A @ W
    return left @ B

def muon_step(W, lr, beta, B_momentum, A, B, eps=1e-10):
    """Apply one Muon-style update to W on the quadratic objective.

    The momentum buffer is updated as ``beta * B_momentum + grad_f(W, A, B)``
    and then replaced by its nearest (semi-)orthogonal matrix ``U @ Vh`` from a
    reduced SVD; W moves by ``lr`` along that direction.

    Works for square and rectangular W. For square input the result is the
    same as with a full SVD.

    Args:
        W: Current weight matrix, shape ``(m, n)``.
        lr: Step size.
        beta: Momentum coefficient.
        B_momentum: Momentum buffer with the same shape as W.
        A: Left matrix of the quadratic, shape ``(m, m)``.
        B: Right matrix of the quadratic, shape ``(n, n)``.
        eps: Size of the diagonal perturbation added before the SVD.

    Returns:
        Tuple ``(W_new, B_momentum_new)``.
    """
    G = grad_f(W, A, B)
    B_momentum = beta * B_momentum + G
    # Tiny perturbation on the main diagonal: for an all-zero buffer it makes
    # the orthogonal factor the (rectangular) identity instead of an arbitrary
    # one. np.eye(*shape) keeps this valid for non-square buffers.
    B_reg = B_momentum + eps * np.eye(*B_momentum.shape)
    # Reduced SVD so that U @ Vh has the same shape as W in the rectangular case.
    U, _, Vh = np.linalg.svd(B_reg, full_matrices=False)
    O = U @ Vh
    W = W - lr * O
    return W, B_momentum


# Sweep Muon over learning rate, momentum and sharpness on the quadratic
# f(W) = 0.5 * tr(W^T A W B), recording the loss after every step.
dim = 3
lrs = np.array([0.9, 0.3, 0.1, 0.03, 0.01, 0.003, 0.001]).round(3)
betas = np.array([0.99, 0.95, 0.9, 0.5, 0]).round(2)
sharpnesses = np.arange(2, 11)

T = 10000
rows = []

# generate_PSD_matrix draws from NumPy's global RNG. Seeding here makes the
# random A/B matrices (and therefore every curve below) identical each time
# this cell is re-run.
QUADRATIC_SEED = 0
np.random.seed(QUADRATIC_SEED)

for lr in lrs:
    print(f"lr={lr}")
    for beta in betas:
        for sharpness in sharpnesses:
            A = generate_PSD_matrix(dim, sharpness)
            B = generate_PSD_matrix(dim, 1)
            W = np.ones((dim, dim))
            M = np.zeros_like(W)
            for i in range(T):
                W, M = muon_step(W, lr, beta, M, A, B)
                # Loss is logged after the update, so t = i is the loss after i + 1 steps.
                rows.append((lr, beta, sharpness, f(W, A, B), dim, i))

df = pd.DataFrame(rows, columns=["lr","beta","sharpness","loss","dim","t"])

# One loss curve per learning rate, for a fixed momentum / sharpness setting.
beta = 0.0
sharpness = 10
xs = np.arange(0, T)
for lr in lrs:
    # Select the single run matching (lr, beta, sharpness) from the sweep table.
    ys = df.loc[
        (df["lr"] == lr) & (df["beta"] == beta) & (df["sharpness"] == sharpness),
        "loss",
    ].values

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            mode='lines',
            name=f"lr={lr}, beta={beta}, sharpness={sharpness}",
        )
    )
    fig.update_layout(
        title=f"Muon on Quadratic Loss: Sharpness = {sharpness}, lr={lr}, beta={beta}",
        xaxis_title="Iteration",
        yaxis_title="Loss",
    )
    # Zoom on the first 2000 iterations and on small loss values.
    fig.update_yaxes(range=[-0.01, 0.1])
    fig.update_xaxes(range=[0,2000])
    fig.show()


lr=0.9
lr=0.3
lr=0.1
lr=0.03
lr=0.01
lr=0.003
lr=0.001


In [24]:
# Load the CIFAR-10 train/test split through the project helper.
# NOTE(review): presumably returns torch tensors already placed on `device` — confirm in src/functions.py.
X, y, X_test, y_test = fn.load_cifar_10()


dtype(): align should be passed as Python or NumPy boolean but got `align=0`. Did you mean to pass a tuple to create a subarray type? (Deprecated NumPy 2.4)



In [25]:
class Muon(torch.optim.Muon):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for g in self.param_groups:
            g["adjust_lr_fn"] = "none"

In [203]:
def max_muon_layer_sharpness(model, opt_muon, criterion, X, y, generator,
                             subsample_dim=1024, iters=30, tol=1e-4):
    """Estimate the largest per-layer Hessian eigenvalue over the Muon weights.

    For every 2D parameter held by ``opt_muon`` the Hessian of the loss with
    respect to that parameter alone (a diagonal block of the full Hessian) is
    probed by power iteration on Hessian-vector products. The maximum over
    layers is returned.

    Power iteration converges to the eigenvalue of largest magnitude, so the
    per-layer value can be negative if the dominant curvature is negative.

    Args:
        model: Callable mapping a batch of inputs to outputs.
        opt_muon: Optimizer whose ``param_groups`` list the Muon parameters.
        criterion: Loss callable ``criterion(outputs, targets)``.
        X: Inputs; rows are subsampled along dimension 0.
        y: Targets aligned with ``X`` along dimension 0.
        generator: torch.Generator used for the subsample permutation and for
            the random start vectors (must live on the device of ``X``).
        subsample_dim: Maximum number of rows used for the loss.
        iters: Maximum number of power-iteration steps per layer.
        tol: Relative change of the eigenvalue estimate below which the
            iteration stops early.

    Returns:
        float: the largest per-layer eigenvalue estimate.

    Raises:
        ValueError: if ``opt_muon`` holds no 2D parameter.
    """
    # Muon weight matrices: every 2D parameter across all param groups.
    muon_ws = [p for g in opt_muon.param_groups for p in g["params"] if p.ndim == 2]
    if not muon_ws:
        raise ValueError("No 2D Muon parameters found in opt_muon.")

    # Random subsample of at most `subsample_dim` rows.
    n = X.shape[0]
    m = min(subsample_dim, n)
    idx = torch.randperm(n, device=X.device, generator=generator)[:m]
    Xs, ys = X[idx], y[idx]

    # Single forward pass shared by all layers.
    loss = criterion(model(Xs), ys)

    # One differentiable backward pass yields the gradients of every Muon
    # matrix at once, instead of rebuilding the double-backward graph per layer.
    grads = torch.autograd.grad(loss, muon_ws, create_graph=True)

    def power_iteration_for_param(W, gW):
        g_flat = gW.reshape(-1)

        def Hv(v):
            # Hessian-vector product restricted to W: d(g . v) / dW.
            (hW,) = torch.autograd.grad(g_flat @ v, W, retain_graph=True)
            return hW.reshape(-1)

        v = torch.randn(g_flat.numel(), device=g_flat.device, generator=generator)
        v = v / (v.norm() + 1e-12)

        eig_old = 0.0
        for _ in range(iters):
            w = Hv(v)
            eig = (v @ w).item()  # Rayleigh quotient for the current unit vector
            v = w / (w.norm() + 1e-12)

            if abs(eig - eig_old) / (abs(eig_old) + 1e-12) < tol:
                break
            eig_old = eig

        # Final Rayleigh quotient with the last normalised iterate.
        return (v @ Hv(v)).item()

    # Layers are visited in optimizer order, so generator draws are consumed
    # in the same order as before.
    lambdas = [power_iteration_for_param(W, gW) for W, gW in zip(muon_ws, grads)]
    return max(lambdas)

def train_muon_model(model, opt_muon, opt_adam, criterion, epochs, accuracy,
                     X, y, X_test, y_test, output_dir, generator):
    """Train `model` full-batch with a Muon + Adam optimizer pair and log metrics.

    Every "epoch" is a single optimizer step on the whole of ``X``. Training
    stops after ``epochs`` steps or as soon as the training accuracy reaches
    ``accuracy``. Loss and accuracies are recorded each epoch. The two
    sharpness metrics are recorded roughly 100 times over the run (every
    ``max(1, epochs // 100)`` epochs), and all other entries stay NaN. A
    metadata row and the per-epoch table are appended to the files in
    ``output_dir`` through the ``fn`` helpers.

    Args:
        model: Network to train. Must expose ``activation`` (its ``__name__``
            is stored in the metadata) and, when ``criterion`` is
            ``nn.MSELoss``, ``num_labels`` for one-hot encoding of ``y``.
        opt_muon: Optimizer stepping the Muon parameters; its first param
            group supplies the logged learning rate and momentum.
        opt_adam: Optimizer stepping the remaining parameters; its first param
            group supplies the logged ``learning_rate_adam``.
        criterion: Loss function.
        epochs: Maximum number of full-batch steps.
        accuracy: Training accuracy at which to stop early (a value above 1
            disables early stopping).
        X: Training inputs. Used as given — assumes they already live on the
            model's device (TODO confirm).
        y: Integer class labels for ``X``.
        X_test: Test inputs.
        y_test: Integer class labels for ``X_test``.
        output_dir: Directory handed to the ``fn`` output helpers.
        generator: Random generator forwarded to the sharpness estimators.
    """
    print(f"Training {model.__class__.__name__} with " +
          f"{opt_muon.__class__.__name__} and learning rate " +
          f"{opt_muon.param_groups[0]['lr']} for {epochs} epochs.")

    learning_rate = opt_muon.param_groups[0]['lr']
    momentum = opt_muon.param_groups[0].get('momentum', 0.0)

    model.to(device)
    model.train()

    # NaN-filled so that epochs never reached (early stop) or never measured
    # (sharpness) are distinguishable from real values.
    train_losses = np.full(epochs, np.nan)
    train_accuracies = np.full(epochs, np.nan)
    test_accuracies = np.full(epochs, np.nan)
    H_sharps = np.full(epochs, np.nan)
    A_sharps = np.full(epochs, np.nan)

    if isinstance(criterion, nn.MSELoss):
        # MSE needs dense targets: one-hot encode the integer labels.
        y_loss = torch.nn.functional.one_hot(
            y, num_classes=model.num_labels).float().to(device)
    else:
        y_loss = y.to(device)

    # `epochs // 100` is 0 for runs shorter than 100 epochs, which previously
    # raised ZeroDivisionError in the modulo below; clamp the interval to 1.
    sharpness_every = max(1, epochs // 100)

    start = time.time()

    train_acc = 0.0
    epoch = 0

    while train_acc < accuracy and epoch < epochs:

        opt_muon.zero_grad(set_to_none=True)
        opt_adam.zero_grad(set_to_none=True)

        outputs = model(X)
        loss = criterion(outputs, y_loss)
        loss.backward()

        opt_muon.step()
        opt_adam.step()

        train_losses[epoch] = loss.item()

        if epoch % sharpness_every == 0:
            # Only the first value returned by the helper is logged.
            H_sharps[epoch], _ = fn.get_hessian_metrics(
                model, opt_muon, criterion, X, y_loss, epoch+1, generator=generator
            )
            A_sharps[epoch] = max_muon_layer_sharpness(
                model, opt_muon, criterion, X, y_loss, generator=generator
            )

        with torch.no_grad():
            model.eval()
            # Train accuracy reuses the pre-step forward pass; test accuracy
            # is evaluated with the freshly updated weights.
            train_preds = outputs.argmax(dim=1)
            test_preds = model(X_test).argmax(dim=1)
            train_acc = (train_preds == y).float().mean().item()
            test_acc = (test_preds == y_test).float().mean().item()
            train_accuracies[epoch] = train_acc
            test_accuracies[epoch] = test_acc
        model.train()

        if (epoch+1) % 1000 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}, " +
                  f"Time: {round(((time.time() - start) / 60), 2)}, " +
                  f"Train Acc: {train_accuracies[epoch]:.4f}, " +
                  f"Test Acc: {test_accuracies[epoch]:.4f}, ")
        epoch += 1

    metadata, output_data = fn.setup_output_files(output_dir)
    model_id = metadata.shape[0] + 1

    # Older metadata files may lack the Adam learning-rate column.
    if 'learning_rate_adam' not in metadata.columns:
        metadata['learning_rate_adam'] = np.nan

    metadata.loc[metadata.shape[0]] ={
        "model_id": model_id,
        "model_type": model.__class__.__name__,
        "activation_function": model.activation.__name__,
        "optimizer": opt_muon.__class__.__name__,
        "criterion": criterion.__class__.__name__,
        "learning_rate": learning_rate,
        "learning_rate_adam": opt_adam.param_groups[0]['lr'],
        "beta1": momentum,
        "num_epochs": epochs,
        "time_minutes": round((time.time() - start) / 60, 2),
    }

    output_data = pd.concat([output_data, pd.DataFrame({
        "model_id": np.ones_like(train_losses) * model_id,
        "epoch": np.arange(1, epochs + 1),
        "train_loss": train_losses,
        "sharpness_H": H_sharps.round(4),
        "sharpness_A": A_sharps.round(4),
        "test_accuracy": test_accuracies,
        "train_accuracy": train_accuracies,
    })], ignore_index=True)

    fn.save_output_files(metadata, output_data, output_dir)

class MLP4(nn.Module):
    """Fully connected network with four equal-width hidden layers.

    Args:
        input_size: Number of features after flattening one sample.
        hidden_layer_size: Width of each of the four hidden layers.
        num_labels: Number of output units.
        activation: Callable applied after every hidden layer (e.g. ``F.relu``).
    """

    def __init__(self, input_size, hidden_layer_size, num_labels, activation):
        super().__init__()
        self.input_size = input_size
        self.hidden_layers_size = hidden_layer_size
        self.num_labels = num_labels
        self.activation = activation

        self.h1  = nn.Linear(input_size,  hidden_layer_size)
        self.h2  = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.h3  = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.h4  = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.out = nn.Linear(hidden_layer_size, num_labels)

        # Flat list of all parameters, kept as a public attribute.
        self.param_list = list(self.parameters())

    def forward(self, x):
        """Flatten each sample and return the ``(batch, num_labels)`` outputs."""
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)
        # Use the configured activation; it was previously stored but ignored
        # in favour of a hard-coded ReLU.
        for layer in (self.h1, self.h2, self.h3, self.h4):
            x = self.activation(layer(x))
        return self.out(x)

In [175]:
# Grid of 1000 learning rates: 0.0001, 0.0002, ..., 0.1.
learning_rates = np.arange(1, 1001) / 10_000
learning_rates

array([0.0001, 0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007, 0.0008,
       0.0009, 0.001 , 0.0011, 0.0012, 0.0013, 0.0014, 0.0015, 0.0016,
       0.0017, 0.0018, 0.0019, 0.002 , 0.0021, 0.0022, 0.0023, 0.0024,
       0.0025, 0.0026, 0.0027, 0.0028, 0.0029, 0.003 , 0.0031, 0.0032,
       0.0033, 0.0034, 0.0035, 0.0036, 0.0037, 0.0038, 0.0039, 0.004 ,
       0.0041, 0.0042, 0.0043, 0.0044, 0.0045, 0.0046, 0.0047, 0.0048,
       0.0049, 0.005 , 0.0051, 0.0052, 0.0053, 0.0054, 0.0055, 0.0056,
       0.0057, 0.0058, 0.0059, 0.006 , 0.0061, 0.0062, 0.0063, 0.0064,
       0.0065, 0.0066, 0.0067, 0.0068, 0.0069, 0.007 , 0.0071, 0.0072,
       0.0073, 0.0074, 0.0075, 0.0076, 0.0077, 0.0078, 0.0079, 0.008 ,
       0.0081, 0.0082, 0.0083, 0.0084, 0.0085, 0.0086, 0.0087, 0.0088,
       0.0089, 0.009 , 0.0091, 0.0092, 0.0093, 0.0094, 0.0095, 0.0096,
       0.0097, 0.0098, 0.0099, 0.01  , 0.0101, 0.0102, 0.0103, 0.0104,
       0.0105, 0.0106, 0.0107, 0.0108, 0.0109, 0.011 , 0.0111, 0.0112,
      

In [204]:
# Train one MLP4 per (Adam lr, Muon lr) pair and append the results to output_dir.
output_dir = "eos/muon_MJ"
input_size = 32 * 32 * 3
hidden_layer_size = 170
num_labels = 10
activation = F.relu
criterion = nn.MSELoss()

adam_learning_rates = [1e-4, 3e-4, 1e-3]
muon_learning_rates = np.arange(1,1001,1)/10000

for adam_lr in adam_learning_rates:
    for muon_lr in muon_learning_rates:
        model = MLP4(input_size, hidden_layer_size, num_labels, activation)

        # Set Muon and Adam Parameters.
        # Ordered lists instead of sets: set iteration order over tensors can
        # differ between runs, which would change the order in which the
        # per-layer sharpness estimate consumes draws from `generator`.
        all_params = list(model.parameters())
        muon_params = [model.h2.weight, model.h3.weight]
        muon_param_ids = {id(p) for p in muon_params}
        adamw_params = [p for p in all_params if id(p) not in muon_param_ids]

        opt_muon = Muon(muon_params, nesterov=False, lr=muon_lr, weight_decay=0)
        opt_adamw = torch.optim.Adam(adamw_params, lr=adam_lr)

        train_muon_model(
            model=model,
            opt_muon=opt_muon,
            opt_adam=opt_adamw,
            criterion=criterion,
            epochs=500,
            accuracy=1.1,  # above 1.0, so training never stops early
            X=X,
            y=y,
            X_test=X_test,
            y_test=y_test,
            output_dir=output_dir,
            generator=generator
        )

Training MLP4 with Muon and learning rate 0.0001 for 500 epochs.


In [209]:
# Remove stored results for model ids 0..999 from the output directory.
# NOTE(review): this cell sits between the training sweep and the cell that
# loads the results; on a top-to-bottom "Run All" it would presumably wipe the
# runs just produced — confirm what fn.delete_model_data does and whether this
# cell should be kept.
output_dir = "eos/muon_MJ"
fn.delete_model_data(range(0,1000),output_dir=output_dir)

In [205]:
# Load the saved results: `md` is used below as the per-model metadata table
# and `out` as the per-epoch metrics table.
output_dir = "eos/muon_MJ"
md, out = fn.load_output_files(output_dir)

In [186]:
def plot_output_data(metadata, output, model_id, momentum=0.9):
    """Plot loss, sharpness and test accuracy over epochs for one trained model.

    Top panel: training loss (left axis) with the Hessian sharpness and the
    per-layer Muon sharpness as markers (right axis), plus a dashed reference
    line at ``2 * (1 + momentum) / ((1 - momentum) * learning_rate)``.
    Bottom panel: test accuracy.

    Args:
        metadata: Table with one row per model; must contain ``model_id``,
            ``num_epochs`` and ``learning_rate``.
        output: Per-epoch table; must contain ``model_id``, ``train_loss``,
            ``sharpness_H``, ``sharpness_A`` and ``test_accuracy``.
        model_id: Identifier of the run to plot.
        momentum: Momentum coefficient used in the reference line. Defaults
            to 0.9, the value that was previously hard-coded. Must differ
            from 1.
    """
    metadata = metadata[metadata['model_id']==model_id]
    output = output[output['model_id']==model_id]

    xs = np.arange(metadata['num_epochs'].iloc[0])
    losses = output['train_loss']
    sharpness_H = output['sharpness_H']
    sharpness_A = output['sharpness_A']
    test_accuracy = output['test_accuracy']
    learning_rate = metadata['learning_rate'].iloc[0]
    sharpness_H_lim = 2 * (1 + momentum)  / ((1 - momentum) * learning_rate)

    fig = make_subplots(rows = 2, cols = 1,
                        specs=[[{"secondary_y": True}],
                               [{"secondary_y": True}]],
                        shared_xaxes=True,
                        vertical_spacing=0.1)

    fig.add_trace(
        go.Scatter(x=xs, y=losses, name="Training Loss",line=dict(width=2)),
        secondary_y=False, row=1, col=1
    )

    # Sharpness is only measured on a subset of epochs, hence markers.
    fig.add_trace(
        go.Scatter(x=xs, y=sharpness_H, name="Sharpness of Hessian", mode='markers', line=dict(width=2)),
        secondary_y=True, row=1, col=1
    )

    fig.add_trace(
        go.Scatter(x=xs, y=sharpness_A, name="Max Sharpness of Muon Layers", mode='markers', line=dict(width=2)),
        secondary_y=True, row=1, col=1
    )

    fig.add_trace(
        go.Scatter(x=xs, y=test_accuracy, name="Test Accuracy", line=dict(width=2)),
        secondary_y=False, row=2, col=1
    )
    fig.add_hline(y=sharpness_H_lim, line_dash="dash", line_color="black",
                  row=1, col=1, secondary_y=True)

    fig.update_yaxes(title_text="Training Loss", secondary_y=False,
                     range = [0,0.1], showgrid=False,
                     row=1, col=1)
    # The right axis is scaled to the largest recorded Hessian sharpness.
    fig.update_yaxes(title_text="Max Sharpness of Muon Layers", secondary_y=True,
                     range = [0, output['sharpness_H'].max()*1.1],
                     row=1, col=1)

    # Limit the x range to the epochs that were actually trained.
    fig.update_xaxes(title_text="epoch",
                     range = [0,output['train_loss'].notna().sum()])
    fig.update_layout(title_text = f"Stability of Muon ; learning rate = {learning_rate}", height = 1000, width = 1000)

    fig.show()


In [207]:
# Show the last ten rows of the loaded metadata table.
md.tail(10)

Unnamed: 0,model_id,model_type,activation_function,optimizer,criterion,learning_rate,beta1,beta2,num_epochs,time_minutes,learning_rate_adam
98,99,MLP4,relu,Muon,MSELoss,0.099,,,500,0.28,
99,100,MLP4,relu,Muon,MSELoss,0.1,,,500,0.31,
100,101,MLP4,relu,Muon,MSELoss,0.0001,,,500,0.31,
101,102,MLP4,relu,Muon,MSELoss,0.0002,,,500,0.3,
102,103,MLP4,relu,Muon,MSELoss,0.0003,,,500,0.2,
103,104,MLP4,relu,Muon,MSELoss,0.0001,,,500,0.2,
104,105,MLP4,relu,Muon,MSELoss,0.0002,,,500,0.27,
105,106,MLP4,relu,Muon,MSELoss,0.0003,,,500,0.25,
106,107,MLP4,relu,Muon,MSELoss,0.0001,,,500,0.32,
107,108,MLP4,relu,Muon,MSELoss,0.0001,0.95,,500,0.3,0.001


In [182]:
# Plot the training curves of the run stored under model_id 103.
plot_output_data(md, out, model_id=103)

In [172]:
# Rolling mean (window 5, at least 2 samples) over the measured Hessian
# sharpness values of each model; epochs without a measurement stay NaN.
out['sharpness_H_smooth'] = (
    out.loc[out['sharpness_H'].notna()]
    .groupby('model_id')['sharpness_H']
    .transform(lambda series: series.rolling(window=5, min_periods=2).mean())
)

# Peak smoothed sharpness per model, joined with that model's learning rate.
max_sharpness = out.groupby('model_id')[['sharpness_H_smooth']].max().merge(
    right=md[["model_id","learning_rate"]],
    left_index=True,
    right_on="model_id",
)

fig = go.Figure(
    data=[
        go.Scatter(
            x=max_sharpness['learning_rate'],
            y=max_sharpness['sharpness_H_smooth'],
            mode='markers',
        )
    ]
)
fig.update_layout(
    title="Max Sharpness vs Learning Rate",
    xaxis_title="Learning Rate",
    yaxis_title="Max Sharpness of Hessian",
)
fig.show()