In [1]:
import sys
from pathlib import Path

project_root = Path("..").resolve()
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

import src.seed as seed
import src.models as models
import src.functions as fn

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import time
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Shared device and torch RNG from the project's seed module. The generator is
# later passed to the sharpness estimators (random subsampling / start vectors).
device = seed.device
generator = seed.generator

In [2]:
# Load CIFAR-10 train (X, y) and test (X_test, y_test) splits via the project helper.
# NOTE(review): train_muon_model never moves X / X_test to `device`, so these are
# presumably returned on the right device already — confirm in src.functions.
X, y, X_test, y_test = fn.load_cifar_10()

  entry = pickle.load(f, encoding="latin1")


In [3]:
class Muon(torch.optim.Muon):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for g in self.param_groups:
            g["adjust_lr_fn"] = "none"

In [None]:
def newtonschulz5(G, steps=5, eps=1e-7, use_bfloat16=True,
                  coefficients=(3.4445, -4.7750, 2.0315)):
    """Quintic Newton-Schulz iteration that pushes singular values towards 1.

    Each matrix is first scaled by its Frobenius norm (so all singular values
    are <= 1). ``steps`` iterations of X <- a*X + (b*A + c*A@A) @ X with
    A = X @ X^T are then applied. Tall matrices are transposed so the Gram
    matrix A is built on the smaller side, and are transposed back at the end.

    Args:
        G (torch.Tensor): a matrix ``(m, n)`` or a batch of matrices
            ``(..., m, n)``; batch dimensions are processed independently.
        steps (int): number of iterations.
        eps (float): added to the norm to avoid division by zero.
        use_bfloat16 (bool): run the iteration in bfloat16.
        coefficients (tuple[float, float, float]): quintic coefficients
            ``(a, b, c)``; the default matches the Muon reference values.

    Returns:
        torch.Tensor: same shape and dtype as ``G``.
    """
    assert G.ndim >= 2, "newtonschulz5 expects a matrix or a batch of matrices"
    a, b, c = coefficients
    X = G.to(torch.bfloat16) if use_bfloat16 else G
    # Per-matrix Frobenius norm; for a plain 2-D input this equals X.norm().
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)
    # Work on the wide orientation so A = X @ X^T is the smaller Gram matrix.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)

def max_muon_layer_sharpness(model, opt_muon, criterion, X, y, generator,
                             subsample_dim=1024, iters=30, tol=1e-4):
    """Largest per-layer Hessian eigenvalue estimate over Muon's weight matrices.

    The loss is evaluated once on a random subsample of at most
    ``subsample_dim`` rows of ``X``. For every 2-D parameter owned by
    ``opt_muon``, the Hessian block of that loss with respect to this
    parameter alone is probed by power iteration. Hessian-vector products
    come from double backprop, and the per-layer estimates are reduced with
    ``max``.

    Power iteration converges to the eigenvalue of largest magnitude. It
    stops early once the relative change of the Rayleigh quotient drops
    below ``tol``, otherwise after ``iters`` iterations.

    Args:
        model: network mapping ``X`` rows to outputs accepted by ``criterion``.
        opt_muon: optimizer whose ``param_groups`` hold the Muon parameters.
        criterion: loss function ``criterion(outputs, targets)``.
        X, y: inputs and targets (row-aligned).
        generator (torch.Generator): RNG for the subsample and start vectors.
        subsample_dim (int): maximum number of rows used.
        iters (int): maximum number of power-iteration steps per layer.
        tol (float): relative tolerance for early stopping.

    Returns:
        float: the maximum estimate across the Muon weight matrices.

    Raises:
        ValueError: if ``opt_muon`` owns no 2-D parameters.
    """
    matrices = [
        param
        for group in opt_muon.param_groups
        for param in group["params"]
        if param.ndim == 2
    ]
    if not matrices:
        raise ValueError("No 2D Muon parameters found in opt_muon.")

    num_rows = X.shape[0]
    keep = torch.randperm(num_rows, device=X.device, generator=generator)
    keep = keep[:min(subsample_dim, num_rows)]
    # One forward pass shared by all layers; graphs are retained for reuse.
    batch_loss = criterion(model(X[keep]), y[keep])

    def block_top_eigenvalue(weight):
        # First derivative kept differentiable so it can be differentiated again.
        (grad_w,) = torch.autograd.grad(batch_loss, weight,
                                        create_graph=True, retain_graph=True)
        flat_grad = grad_w.reshape(-1)

        def hvp(vec):
            # d/dW (grad . vec) = H_WW @ vec, restricted to this weight.
            (hess_vec,) = torch.autograd.grad(flat_grad @ vec, weight,
                                              retain_graph=True)
            return hess_vec.reshape(-1)

        vec = torch.randn(flat_grad.numel(), device=flat_grad.device,
                          generator=generator)
        vec = vec / (vec.norm() + 1e-12)

        previous = 0.0
        step = 0
        while step < iters:
            product = hvp(vec)
            estimate = (vec @ product).item()
            vec = product / (product.norm() + 1e-12)
            if abs(estimate - previous) / (abs(previous) + 1e-12) < tol:
                break
            previous = estimate
            step += 1

        # Final Rayleigh quotient at the (normalized) last iterate.
        return (vec @ hvp(vec)).item()

    return max(block_top_eigenvalue(w) for w in matrices)

def train_muon_model(model, opt_muon, opt_adam, criterion, epochs, accuracy,
                     X, y, X_test, y_test, output_dir, generator):
    """Train ``model`` full-batch with a Muon + Adam optimizer pair and log metrics.

    Each epoch does one full-batch forward/backward pass on ``X`` and steps
    both optimizers. Training stops after ``epochs`` epochs, or as soon as the
    training accuracy reaches ``accuracy``.

    Every ``max(1, epochs // 100)`` epochs, two sharpness estimates are
    recorded: the first value returned by ``fn.get_hessian_metrics``, and
    ``max_muon_layer_sharpness`` over the Muon-managed weight matrices. When
    training ends, a metadata row and a per-epoch table are appended to the
    files handled by ``fn.setup_output_files`` / ``fn.save_output_files``.

    Args:
        model (nn.Module): network to train; must expose ``num_labels`` and
            ``activation`` attributes.
        opt_muon (torch.optim.Optimizer): Muon optimizer; the ``lr`` and
            ``momentum`` of its first param group are recorded.
        opt_adam (torch.optim.Optimizer): optimizer for the remaining params.
        criterion (nn.Module): loss; with ``nn.MSELoss`` the labels are
            one-hot encoded, otherwise integer labels are passed directly.
        epochs (int): maximum number of full-batch steps.
        accuracy (float): training-accuracy threshold for early stopping.
        X, y: training inputs and integer class labels.
        X_test, y_test: test inputs and integer class labels.
        output_dir (str): directory holding the metadata/output files.
        generator (torch.Generator): RNG forwarded to the sharpness estimators.

    Epochs that are skipped by early stopping remain NaN in the saved table.
    """
    learning_rate = opt_muon.param_groups[0]['lr']
    momentum = opt_muon.param_groups[0].get('momentum', 0.0)

    print(f"Training {model.__class__.__name__} with " +
          f"{opt_muon.__class__.__name__} and learning rate " +
          f"{learning_rate} for {epochs} epochs.")

    model.to(device)
    model.train()

    train_losses = np.full(epochs, np.nan)
    train_accuracies = np.full(epochs, np.nan)
    test_accuracies = np.full(epochs, np.nan)
    H_sharps = np.full(epochs, np.nan)
    A_sharps = np.full(epochs, np.nan)

    if isinstance(criterion, nn.MSELoss):
        y_loss = F.one_hot(y, num_classes=model.num_labels).float().to(device)
    else:
        y_loss = y.to(device)

    # Labels for accuracy, on the same device as the model's predictions.
    y_train_eval = y.to(device)
    y_test_eval = y_test.to(device)

    # Sharpness is expensive, so it is sampled ~100 times per run; max(1, ...)
    # avoids a modulo-by-zero ZeroDivisionError when epochs < 100.
    sharpness_every = max(1, epochs // 100)

    start = time.time()
    train_acc = 0.0
    epoch = 0

    while train_acc < accuracy and epoch < epochs:
        opt_muon.zero_grad(set_to_none=True)
        opt_adam.zero_grad(set_to_none=True)

        outputs = model(X)
        loss = criterion(outputs, y_loss)
        loss.backward()

        opt_muon.step()
        opt_adam.step()

        train_losses[epoch] = loss.item()

        if epoch % sharpness_every == 0:
            # Only the first metric returned by the helper is recorded.
            lambda_H, _ = fn.get_hessian_metrics(
                model, opt_muon, criterion, X, y_loss, generator=generator
            )
            H_sharps[epoch] = lambda_H
            A_sharps[epoch] = max_muon_layer_sharpness(
                model, opt_muon, criterion, X, y_loss, generator=generator
            )

        with torch.no_grad():
            model.eval()
            # `outputs` predates this epoch's optimizer steps, whereas the
            # test predictions below use the updated weights.
            train_preds = outputs.argmax(dim=1)
            test_preds = model(X_test).argmax(dim=1)
            train_acc = (train_preds == y_train_eval).float().mean().item()
            test_acc = (test_preds == y_test_eval).float().mean().item()
            train_accuracies[epoch] = train_acc
            test_accuracies[epoch] = test_acc
        model.train()

        if (epoch + 1) % 1000 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}, " +
                  f"Time: {round(((time.time() - start) / 60), 2)}, " +
                  f"Train Acc: {train_accuracies[epoch]:.4f}, " +
                  f"Test Acc: {test_accuracies[epoch]:.4f}, ")
        epoch += 1

    metadata, output_data = fn.setup_output_files(output_dir)

    # Next id = largest existing id + 1. Using the row count can reuse an id
    # that is still present after rows were removed (fn.delete_model_data).
    existing_max = (metadata["model_id"].max()
                    if "model_id" in metadata.columns else np.nan)
    if pd.notna(existing_max):
        model_id = int(existing_max) + 1
    else:
        model_id = metadata.shape[0] + 1

    # Fresh index label: `.loc[row_count]` would overwrite an existing row
    # if the index has gaps. For a 0..n-1 RangeIndex this is still n.
    new_row = metadata.index.max() + 1 if len(metadata) else 0
    metadata.loc[new_row] = {
        "model_id": model_id,
        "model_type": model.__class__.__name__,
        "activation_function": model.activation.__name__,
        "optimizer": opt_muon.__class__.__name__,
        "criterion": criterion.__class__.__name__,
        "learning_rate": learning_rate,
        "momentum": momentum,
        "num_epochs": epochs,
        "time_minutes": round((time.time() - start) / 60, 2),
    }

    output_data = pd.concat([output_data, pd.DataFrame({
        "model_id": np.ones_like(train_losses) * model_id,
        "epoch": np.arange(1, epochs + 1),
        "train_loss": train_losses,
        "sharpness_H": H_sharps.round(4),
        "sharpness_A": A_sharps.round(4),
        "test_accuracy": test_accuracies,
        "train_accuracy": train_accuracies,
    })], ignore_index=True)

    fn.save_output_files(metadata, output_data, output_dir)

class MLP4(nn.Module):
    """Fully connected network: four equal-width hidden layers plus a linear head.

    Inputs are flattened to ``(batch, -1)`` before the first layer, so
    ``input_size`` must equal the number of features per example.

    Args:
        input_size (int): flattened input dimension.
        hidden_layer_size (int): width of each of the four hidden layers.
        num_labels (int): number of output logits.
        activation (callable): elementwise nonlinearity applied after every
            hidden layer (not after the output layer).

    Note:
        The constructor reseeds torch's global RNG with ``seed.SEED`` before
        creating the layers, so every instance gets the same initial weights.
        This also resets the global RNG state for whatever runs afterwards.
    """

    def __init__(self, input_size, hidden_layer_size, num_labels, activation):
        super().__init__()
        # Reset seed for reproducible initialization
        torch.manual_seed(seed.SEED)

        self.input_size = input_size
        # Attribute name kept as-is (plural "layers") in case callers read it.
        self.hidden_layers_size = hidden_layer_size
        self.num_labels = num_labels
        self.activation = activation

        self.h1 = nn.Linear(input_size, hidden_layer_size)
        self.h2 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.h3 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.h4 = nn.Linear(hidden_layer_size, hidden_layer_size)
        self.out = nn.Linear(hidden_layer_size, num_labels)

        # Snapshot of all parameters in registration order.
        self.param_list = list(self.parameters())

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs are also accepted.
        x = x.reshape(x.size(0), -1)
        # Apply the configured activation rather than a hard-coded ReLU, so
        # the `activation` argument actually takes effect.
        for layer in (self.h1, self.h2, self.h3, self.h4):
            x = self.activation(layer(x))
        return self.out(x)


In [None]:
output_dir = "eos/muon_MJ"
input_size = 32 * 32 * 3          # flattened 32x32 RGB image
hidden_layer_size = 170
num_labels = 10
activation = F.relu
criterion = nn.MSELoss()

learning_rates = [1e-2, 3e-3, 1e-3, 3e-4, 1e-4]
num_epochs = 500
target_accuracy = 1  # early-stop only once training accuracy reaches 100%

for lr in learning_rates:
    model = MLP4(input_size, hidden_layer_size, num_labels, activation)

    # Muon gets the h2/h3 weight matrices; every other parameter goes to Adam.
    # NOTE(review): h4.weight is also a hidden square matrix but is left to
    # Adam — confirm this split is intentional.
    muon_params = [model.h2.weight, model.h3.weight]
    muon_ids = {id(p) for p in muon_params}
    adamw_params = [p for p in model.parameters() if id(p) not in muon_ids]
    # Every parameter must be owned by exactly one optimizer.
    assert len(muon_ids) == len(muon_params)
    assert len(muon_params) + len(adamw_params) == sum(1 for _ in model.parameters())

    opt_muon = Muon(muon_params, lr=lr, weight_decay=0)
    # NOTE: despite the variable name this is plain Adam (default weight_decay=0), not AdamW.
    opt_adamw = torch.optim.Adam(adamw_params, lr=lr)

    train_muon_model(
        model=model,
        opt_muon=opt_muon,
        opt_adam=opt_adamw,
        criterion=criterion,
        epochs=num_epochs,
        accuracy=target_accuracy,
        X=X,
        y=y,
        X_test=X_test,
        y_test=y_test,
        output_dir=output_dir,
        generator=generator
    )


Training MLP4 with Muon and learning rate 0.01 for 500 epochs.



The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.



Training MLP4 with Muon and learning rate 0.003 for 500 epochs.
Training MLP4 with Muon and learning rate 0.001 for 500 epochs.
Training MLP4 with Muon and learning rate 0.0003 for 500 epochs.
Training MLP4 with Muon and learning rate 0.0001 for 500 epochs.


In [42]:
# Deleting stored runs is destructive. When the notebook is run top to bottom,
# it would remove the runs trained by the sweep before they are plotted
# (presumably model ids 0-9 — confirm fn.delete_model_data's semantics).
# Set the flag to True deliberately when a clean slate is wanted.
DELETE_STORED_RUNS = False
output_dir = "eos/muon_MJ"
if DELETE_STORED_RUNS:
    fn.delete_model_data(range(10), output_dir=output_dir)

In [45]:
# Reload the run metadata (md) and per-epoch metrics (out) stored in output_dir.
md, out = fn.load_output_files(output_dir)

In [56]:
def plot_output_data(metadata, output, model_id, beta=0.9):
    """Plot loss, sharpness estimates and accuracy for one stored training run.

    Top panel:
        - left axis: training loss;
        - right axis: both recorded sharpness estimates, plus the dashed
          threshold ``2(1+beta) / ((1-beta) * lr)``.
    Bottom panel: train and test accuracy.

    Args:
        metadata (pd.DataFrame): run table with ``model_id``, ``num_epochs``
            and ``learning_rate`` columns.
        output (pd.DataFrame): per-epoch table with ``model_id``,
            ``train_loss``, ``sharpness_H``, ``sharpness_A``,
            ``train_accuracy`` and ``test_accuracy`` columns.
        model_id (int): run to plot.
        beta (float): momentum coefficient in the threshold formula. Defaults
            to 0.9, the value previously hard-coded here.
            NOTE(review): the recorded Muon momentum (metadata ``momentum``)
            may differ; confirm which beta the threshold refers to.

    Raises:
        ValueError: if ``model_id`` is not present in ``metadata``.
    """
    run_meta = metadata[metadata['model_id'] == model_id]
    run_out = output[output['model_id'] == model_id]
    if run_meta.empty:
        raise ValueError(f"model_id {model_id} not found in metadata")

    xs = np.arange(run_meta['num_epochs'].iloc[0])
    losses = run_out['train_loss']
    sharpness_H = run_out['sharpness_H']
    sharpness_A = run_out['sharpness_A']
    train_accuracy = run_out['train_accuracy']
    test_accuracy = run_out['test_accuracy']
    learning_rate = run_meta['learning_rate'].iloc[0]
    sharpness_H_lim = 2 * (1 + beta) / ((1 - beta) * learning_rate)

    # The right-axis range must fit both sharpness series (NaN entries are
    # skipped); fall back to autoscale when nothing was recorded.
    sharp_max = pd.concat([sharpness_H, sharpness_A]).max()
    sharp_range = [0, sharp_max * 1.1] if pd.notna(sharp_max) else None

    fig = make_subplots(rows=2, cols=1,
                        specs=[[{"secondary_y": True}],
                               [{"secondary_y": True}]],
                        shared_xaxes=True,
                        vertical_spacing=0.1)

    fig.add_trace(
        go.Scatter(x=xs, y=losses, name="Training Loss", line=dict(width=2)),
        secondary_y=False, row=1, col=1
    )
    # Sharpness columns are sparse (NaN on epochs where it was not sampled),
    # so they are drawn as markers.
    fig.add_trace(
        go.Scatter(x=xs, y=sharpness_H, name="Hessian Sharpness", mode='markers'),
        secondary_y=True, row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=xs, y=sharpness_A, name="Max Sharpness of Muon Layers", mode='markers'),
        secondary_y=True, row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=xs, y=train_accuracy, name="Train Accuracy", line=dict(width=2)),
        secondary_y=False, row=2, col=1
    )
    fig.add_trace(
        go.Scatter(x=xs, y=test_accuracy, name="Test Accuracy", line=dict(width=2)),
        secondary_y=False, row=2, col=1
    )
    fig.add_hline(y=sharpness_H_lim, line_dash="dash", line_color="black",
                  row=1, col=1, secondary_y=True)

    fig.update_yaxes(title_text="Training Loss", secondary_y=False,
                     range=[0, 0.1], showgrid=False,
                     row=1, col=1)
    fig.update_yaxes(title_text="Sharpness", secondary_y=True,
                     range=sharp_range,
                     row=1, col=1)
    fig.update_yaxes(title_text="Accuracy", secondary_y=False, row=2, col=1)

    # Limit the x-axis to the epochs actually trained (early stopping leaves NaN).
    fig.update_xaxes(title_text="epoch",
                     range=[0, losses.notna().sum()])
    fig.update_layout(title_text=f"Stability of Muon ; learning rate = {learning_rate}",
                      height=1000, width=1000)

    fig.show()


In [66]:
# Inspect a single stored run; valid ids are the values in md["model_id"].
plot_output_data(md, out, model_id=4)