In [None]:
import torch
import math
import sys; sys.path.append(".")
from mrd import core, data, models, callbacks
import lightning as pl
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [None]:
# ---- reproducibility / paths ------------------------------------------------
SEED = 0
DATA_ROOT = "./data_mnist"

# ---- model ------------------------------------------------------------------
WIDTH = 64             # <<<<< model size knob
DEPTH = 1              # number of hidden blocks after the first
USE_LAYERNORM = False
ACTIVATION = "gelu"    # "gelu" or "tanh"
D = 8                  # input dimension of the regression problem

# ---- batch sizes (train / test / diagnostics) -------------------------------
TRAIN_BS = TEST_BS = DIAG_BS = 1024

# ---- optimizer --------------------------------------------------------------
LR = 1e-3
MOMENTUM = 0.99
WEIGHT_DECAY = 0       # alternative: 1e-4
NESTEROV = True
GRAD_CLIP = None       # alternative: 1.0

# ---- schedule, expressed in training steps ----------------------------------
MAX_STEPS = 1_000
VAL_EVERY_N_STEPS = 200
DIAG_EVERY_N_STEPS = 10
DIAG_K = 16

# ---- hardware / numerics ----------------------------------------------------
DEVICES = 1
PRECISION = "32-true"

In [None]:
# Dedicated CPU generator, seeded from SEED, used only for the two draws below.
g = torch.Generator(device="cpu").manual_seed(SEED)
# Random parameters of the target function, each scaled by 1/sqrt(D).
# `w` is drawn before `A`; the order matters for reproducing the same values.
w = torch.randn(D, generator=g).div(math.sqrt(D))
A = torch.randn(D, D, generator=g).div(math.sqrt(D))

def f_star(X):
    """Target function handed to the data module.

    Forwards the batch ``X`` (expected shape ``(B, D)``) to
    ``data.target_sin_mix`` together with the module-level parameters
    ``w`` and ``A``, and returns whatever that helper returns.
    """
    target_params = {"w": w, "A": A}
    return data.target_sin_mix(X, **target_params)

# Configuration of the analytic regression data module.
adm_cfg = data.AnalyticDMConfig(
    # problem definition
    d=D,
    x_dist="normal",
    standardize_x=True,   # "whiten-ish" inputs
    noise_std=0.05,
    fixed_dataset=True,
    # dataset sizes
    n_train=50_000,
    n_test=10_000,
    # loading
    train_bs=TRAIN_BS,
    test_bs=TEST_BS,
    num_workers=1,
)

# Build the data module around the target function and prepare it right away,
# so its dataloaders can be used before the trainer is created.
dm = data.AnalyticRegressionData(adm_cfg, target_fn=f_star)
dm.setup()

# Sanity check: pull one training batch and actually verify its shapes.
xb, yb = next(iter(dm.train_dataloader()))
# Inputs are expected to be (B, D); the feature axis must match the configured D.
assert xb.shape[-1] == D, f"expected {D} input features, got shape {tuple(xb.shape)}"
# Inputs and targets must agree on the batch dimension.
assert xb.shape[0] == yb.shape[0], (
    f"batch size mismatch: x {tuple(xb.shape)} vs y {tuple(yb.shape)}"
)
# Last expression of the cell: show both shapes in the notebook output.
tuple(xb.shape), tuple(yb.shape)

In [None]:
# Seed the global torch RNG before building the model. In the cells visible
# here SEED is otherwise only used for the dedicated generator of the target
# parameters, so anything drawing from the global RNG (presumably the weight
# initialisation and dataloader shuffling — confirm) would differ between runs.
torch.manual_seed(SEED)

# NOTE(review): USE_LAYERNORM and ACTIVATION from the config cell are not
# passed to the model here — confirm whether models.MLP should receive them.
model = models.MLP(width=WIDTH, depth=DEPTH, in_dim=D, out_dim=1)

# Wrap the network in the training module with the optimizer settings from
# the config cell.
module = models.MRDTaskModule(
    model=model,
    task="regression",
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    nesterov=NESTEROV,
)

# Curvature diagnostics callback; its cadence comes from the config cell
# instead of a hard-coded literal, so the DIAG_EVERY_N_STEPS knob takes effect.
mrd_cb = callbacks.ExactCurvatureCallback(every_n_steps=DIAG_EVERY_N_STEPS)

from lightning.pytorch.callbacks import TQDMProgressBar

class NotebookTQDM(TQDMProgressBar):
    """TQDM progress bar whose training bar is always switched on."""

    def init_train_tqdm(self):
        """Create the parent's training bar and clear its ``disable`` flag."""
        train_bar = super().init_train_tqdm()
        # Force the bar to render regardless of what the parent decided.
        train_bar.disable = False
        return train_bar

# Trainer for a fixed number of steps; no validation batches are run.
trainer = pl.Trainer(
    devices=DEVICES,
    # Honour the precision chosen in the config cell. This was hard-coded to
    # "bf16-mixed", which silently overrode PRECISION ("32-true").
    precision=PRECISION,
    limit_val_batches=0,
    max_steps=MAX_STEPS,
    gradient_clip_val=GRAD_CLIP,
    log_every_n_steps=20,
    val_check_interval=None,
    enable_checkpointing=False,
    enable_model_summary=True,
    enable_progress_bar=True,
    callbacks=[NotebookTQDM(refresh_rate=20), mrd_cb],
    logger=False,   # keep notebook simple; diagnostics are read back from mrd_cb
)

trainer.fit(module, datamodule=dm)
# trainer.test(module, datamodule=dm)

# df = mrd_cb.dataframe()
# df.tail()

In [None]:
# Pull the diagnostics collected by the curvature callback into `df`;
# the cells below index it by column name (e.g. "step", "loss", "H_lmin").
df = mrd_cb.dataframe()

def abs_max(lmin, lmax):
    """Return the element-wise larger of ``|lmin|`` and ``|lmax|``.

    Both arguments may be scalars, arrays or Series; the result follows the
    usual NumPy broadcasting of ``np.maximum``.
    """
    low_magnitude = np.abs(lmin)
    high_magnitude = np.abs(lmax)
    return np.maximum(low_magnitude, high_magnitude)

def kappa_psd(lmin, lmax, eps=1e-12):
    """Condition number ``lmax / lmin`` for (numerically) PSD matrices only.

    Parameters
    ----------
    lmin, lmax : array-like or scalar
        Smallest and largest eigenvalue; broadcast against each other.
    eps : float, default 1e-12
        Entries with ``lmin <= eps`` (or NaN) are treated as singular or
        indefinite and get ``np.inf``.

    Returns
    -------
    numpy.ndarray
        Float array of the broadcast shape holding ``lmax / lmin`` where
        ``lmin > eps`` and ``np.inf`` elsewhere.
    """
    lmin_arr, lmax_arr = np.broadcast_arrays(
        np.asarray(lmin, dtype=float), np.asarray(lmax, dtype=float)
    )
    well_posed = lmin_arr > eps
    # Start from +inf and divide only where the denominator is safely positive.
    # np.where(cond, lmax / lmin, inf) would evaluate the quotient for every
    # entry and emit divide-by-zero / invalid RuntimeWarnings for masked ones.
    kappa = np.full(lmin_arr.shape, np.inf)
    np.divide(lmax_arr, lmin_arr, out=kappa, where=well_posed)
    return kappa

def kappa_extreme_abs(lmin, lmax, eps=1e-12):
    """Conditioning proxy built only from the two stored spectrum extremes.

    Computes ``max(|lmin|, |lmax|) / max(min(|lmin|, |lmax|), eps)``, i.e. the
    larger extreme magnitude over the smaller one, with the denominator
    floored at ``eps`` so the ratio is always defined.
    """
    low_magnitude, high_magnitude = np.abs(lmin), np.abs(lmax)
    numerator = np.maximum(low_magnitude, high_magnitude)
    smaller_magnitude = np.minimum(low_magnitude, high_magnitude)
    denominator = np.maximum(smaller_magnitude, eps)
    return numerator / denominator
    
eps = 1e-12  # small constant that keeps the ratios below away from division by zero

# Largest absolute extreme eigenvalue per matrix (plotted later as operator norms).
for op_col, stem in (("H_op", "H"), ("G_op", "G"), ("R_op", "R"), ("Hz_op", "Hztheta")):
    df[op_col] = abs_max(df[f"{stem}_lmin"], df[f"{stem}_lmax"])

# Ratios of those magnitudes.
df["metric_dominance"] = df["G_op"] / (df["R_op"] + eps)
df["model_dominance"] = df["Hz_op"] / (df["G_op"] + eps)

# True PSD-style condition number where meaningful
df["kappa_G_psd"] = kappa_psd(df["G_lmin"].to_numpy(), df["G_lmax"].to_numpy(), eps)

# For the others, provide the always-defined extreme-based proxy
for kappa_col, stem in (
    ("kappa_H_ext", "H"),
    ("kappa_R_ext", "R"),
    ("kappa_Hz_ext", "Hztheta"),
):
    df[kappa_col] = kappa_extreme_abs(
        df[f"{stem}_lmin"].to_numpy(), df[f"{stem}_lmax"].to_numpy(), eps
    )

In [None]:
# Loss recorded in the diagnostics frame, plotted against the training step.
# Opens its own figure and ends with plt.show(), like the other plotting
# cells, so the cell output is the figure rather than a Text repr.
plt.figure()
plt.plot(df["step"], df["loss"])
plt.xlabel("step")
plt.ylabel("loss")
plt.title("Loss by step")
plt.show()

In [None]:
# Magnitudes of the four curvature quantities on a shared logarithmic axis.
fig, ax = plt.subplots()
for column, curve_label in (
    ("H_op", "||H_loss||"),
    ("G_op", "||G||"),
    ("R_op", "||R||"),
    ("Hz_op", "||∂²θ z||"),
):
    ax.plot(df["step"], df[column], label=curve_label)
ax.set_yscale("log")
ax.set_xlabel("step")
ax.set_ylabel("operator norm (exact)")
ax.set_title("Exact curvature magnitudes")
ax.legend()
plt.show()

In [None]:
# Dominance ratios over training; the dashed line marks a ratio of exactly 1.
fig, ax = plt.subplots()
steps = df["step"]
ax.plot(steps, df["metric_dominance"], label="||G|| / ||R||")
ax.plot(steps, df["model_dominance"], label="||∂²θ z|| / ||G||")
ax.axhline(1.0, linestyle="--", color="black")
ax.set_yscale("log")
ax.set_xlabel("step")
ax.set_ylabel("dominance ratio")
ax.set_title("Curvature dominance ratios (exact)")
ax.legend()
plt.show()

In [None]:
# Ratios of the extreme eigenvalues of H to those of G (linear axis).
# The 1e-12 added to each denominator guards against an exactly-zero eigenvalue.
lmax_ratio = df["H_lmax"] / (df["G_lmax"] + 1e-12)
lmin_ratio = df["H_lmin"] / (df["G_lmin"] + 1e-12)

fig, ax = plt.subplots()
ax.plot(df["step"], lmax_ratio, label="λ_max(H)/λ_max(G)")
ax.plot(df["step"], lmin_ratio, label="λ_min(H)/λ_min(G)")
ax.axhline(1.0, linestyle="--", color="black")  # reference: ratio of 1
ax.set_xlabel("step")
ax.set_ylabel("ratio")
ax.set_title("Loss–metric spectral alignment")
ax.legend()
plt.show()

In [None]:
# Extreme-based conditioning proxies over training, on a logarithmic axis.
plt.figure()
# plt.plot(df["step"], df["kappa_G_psd"],  label="κ(G)  (PSD)")
for column, curve_label in (
    ("kappa_H_ext", "κ~(H)"),
    ("kappa_R_ext", "κ~(R)"),
    ("kappa_Hz_ext", "κ~(∂²θ z)"),
):
    plt.plot(df["step"], df[column], label=curve_label)
plt.yscale("log")
plt.xlabel("step")
plt.ylabel("condition number (log)")
plt.title("Conditioning over training")
plt.legend()
plt.show()

In [None]:
# Morse-index columns divided by the "P" column
# (presumably the parameter count — confirm against the callback).
P = df["P"]
for column, curve_label in (
    ("H_morse", "H / P"),
    ("G_morse", "G / P"),
    ("R_morse", "R / P"),
    ("Hztheta_morse", "Hzθ / P"),
):
    plt.plot(df["step"], df[column] / P, label=curve_label)
plt.xlabel("step")
plt.ylabel("fraction of negative directions")
plt.title("Normalized Morse indices")
plt.legend()
plt.tight_layout()
plt.show()