In [1]:
import sys
from pathlib import Path

# Resolve the parent of the current working directory and register it on
# sys.path (once) so the `src` package imported below can be found.
project_root = Path("..").resolve()
if sys.path.count(str(project_root)) == 0:
    sys.path.append(str(project_root))

import src.seed as seed
import src.models as models
import src.functions as fn

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import time
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Re-export the device and RNG generator objects defined in src.seed under
# short notebook-level names.
device, generator = seed.device, seed.generator

In [2]:
# Load CIFAR-10: (X, y) feed the training loader and (X_test, y_test) the test
# loader below. X is indexed as X.shape[1..3] later, so it has at least four
# dimensions — presumably (N, channels, height, width); TODO confirm layout.
X, y, X_test, y_test = fn.load_cifar_10()

  entry = pickle.load(f, encoding="latin1")


In [3]:
from torch.utils.data import TensorDataset, DataLoader

# Define model parameters
# output_dir is passed to fn.train_sgd_model below and read back by
# fn.load_output_files in the next cell.
output_dir = "eos/sgd_MJ"
# Flattened input width: product of X's dimensions 1..3 (presumably
# channels * height * width of one image — TODO confirm).
input_size = X.shape[1] * X.shape[2] * X.shape[3]
num_hidden_layers = 2
hidden_layer_size = 200
batch_size = 128

# Create data loaders
# The training loader shuffles using seed.cpu_generator (presumably so the
# batch order is reproducible across runs — confirm in src.seed); the test
# loader keeps the dataset order.
train_dataset = TensorDataset(X, y)
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size, 
    shuffle=True,
    generator=seed.cpu_generator
)

test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# SGD training parameters
epochs = 500
lr = 0.1

# Fully connected network with Tanh activations and 10 outputs.
# NOTE(review): the model is not moved to `device` in this cell — presumably
# fn.train_sgd_model handles device placement; confirm.
model = models.FullyConnectedNet(
    input_size=input_size,
    num_hidden_layers=num_hidden_layers,
    hidden_layer_size=hidden_layer_size,
    num_labels=10,
    activation=nn.Tanh
)
# NOTE(review): MSE loss against a 10-output model assumes y is encoded to
# match the output shape (e.g. one-hot) — confirm against fn.load_cifar_10.
criterion = nn.MSELoss()
# Plain SGD: only the learning rate is set; no momentum or weight-decay
# argument is passed.
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Train with mini-batch SGD, compute Hessian on full batch
# accuracy=0.99 is presumably a target accuracy at which training may stop
# before `epochs` is reached — confirm in fn.train_sgd_model.
fn.train_sgd_model(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    epochs=epochs,
    accuracy=0.99,
    train_loader=train_loader,
    test_loader=test_loader,
    X=X,
    y=y,
    output_dir=output_dir
)


Training FullyConnectedNet with SGD and learning rate 0.1 for 500 epochs.


In [4]:
# Load the two tables stored under output_dir: `md` is used as the per-run
# metadata frame and `out` as the per-epoch output frame by plot_output_data
# (presumably the files written by fn.train_sgd_model above — confirm).
md, out = fn.load_output_files(output_dir)

In [5]:
def plot_output_data(metadata, output, model_id, momentum=0.9, optimizer_name="Muon"):
    """Plot training loss, sharpness and test accuracy for one recorded run.

    Draws a two-row Plotly figure with a shared x-axis and shows it:

    * row 1: training loss (left axis) and ``sharpness_H`` as markers
      (right axis), with a dashed horizontal line at the stability
      threshold ``2 * (1 + momentum) / ((1 - momentum) * learning_rate)``;
    * row 2: test accuracy.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Per-run table; must contain the columns ``model_id``, ``num_epochs``
        and ``learning_rate``. The first row matching ``model_id`` is used.
    output : pandas.DataFrame
        Per-epoch table; must contain the columns ``model_id``,
        ``train_loss``, ``sharpness_H`` and ``test_accuracy``.
    model_id
        Value of the ``model_id`` column selecting the run to plot.
    momentum : float, default 0.9
        Momentum coefficient used in the threshold formula. The default
        reproduces the previously hard-coded value; pass ``0.0`` for an
        optimizer without momentum, in which case the threshold reduces to
        ``2 / learning_rate``.
        NOTE(review): the formula presumably corresponds to the
        edge-of-stability bound for EMA-style momentum — confirm it matches
        the optimizer that produced the run.
    optimizer_name : str, default "Muon"
        Name inserted into the sharpness trace label, the right-hand axis
        title and the figure title.

    Raises
    ------
    ValueError
        If ``momentum`` is outside ``[0, 1)`` (the threshold would be
        undefined or negative).
    IndexError
        If ``metadata`` has no row for ``model_id``.
    """
    if not 0 <= momentum < 1:
        raise ValueError(f"momentum must lie in [0, 1), got {momentum!r}")

    run_metadata = metadata[metadata['model_id'] == model_id]
    run_output = output[output['model_id'] == model_id]
    if run_metadata.empty:
        # Same exception type the bare `.iloc[0]` would raise, but with a
        # message that names the actual problem.
        raise IndexError(f"no metadata rows found for model_id={model_id!r}")

    xs = np.arange(run_metadata['num_epochs'].iloc[0])
    losses = run_output['train_loss']
    sharpness_H = run_output['sharpness_H']
    test_accuracy = run_output['test_accuracy']
    learning_rate = run_metadata['learning_rate'].iloc[0]

    sharpness_H_lim = 2 * (1 + momentum) / ((1 - momentum) * learning_rate)
    sharpness_label = f"Max Sharpness of {optimizer_name} Layers"

    fig = make_subplots(rows=2, cols=1,
                        specs=[[{"secondary_y": True}],
                               [{"secondary_y": True}]],
                        shared_xaxes=True,
                        vertical_spacing=0.1)

    # Row 1: loss on the primary axis, sharpness on the secondary axis.
    fig.add_trace(
        go.Scatter(x=xs, y=losses, name="Training Loss", line=dict(width=2)),
        secondary_y=False, row=1, col=1
    )
    fig.add_trace(
        go.Scatter(x=xs, y=sharpness_H, name=sharpness_label, mode='markers'),
        secondary_y=True, row=1, col=1
    )
    fig.add_hline(y=sharpness_H_lim, line_dash="dash", line_color="black",
                  row=1, col=1, secondary_y=True)

    # Row 2: test accuracy.
    fig.add_trace(
        go.Scatter(x=xs, y=test_accuracy, name="Test Accuracy", line=dict(width=2)),
        secondary_y=False, row=2, col=1
    )

    fig.update_yaxes(title_text="Training Loss", secondary_y=False,
                     range=[0, 0.1], showgrid=False,
                     row=1, col=1)
    # The sharpness axis is scaled to the data only, so the threshold line is
    # visible only once the measured sharpness gets close to it.
    fig.update_yaxes(title_text=sharpness_label, secondary_y=True,
                     range=[0, sharpness_H.max() * 1.1],
                     row=1, col=1)
    fig.update_yaxes(title_text="Test Accuracy", secondary_y=False,
                     row=2, col=1)

    # Limit the x-axis to the epochs that actually have a recorded loss.
    fig.update_xaxes(title_text="epoch",
                     range=[0, losses.notna().sum()])
    fig.update_layout(
        title_text=f"Stability of {optimizer_name} ; learning rate = {learning_rate}",
        height=1000, width=1000)

    fig.show()


In [6]:
# Render the diagnostics for the run recorded under model_id 1.
# NOTE(review): the training cell builds torch.optim.SGD without a momentum
# argument — confirm that the stability threshold and the "Muon" labels drawn
# by plot_output_data are the ones intended for this run.
plot_output_data(metadata=md, output=out, model_id=1)