In [1]:
import math
import sys; sys.path.append(".")  # make the local `mrd` package importable from the notebook directory

import lightning as pl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from lightning.pytorch.callbacks import TQDMProgressBar

from mrd import core, data, models, callbacks

In [2]:
SEED = 0
# Seed torch's global RNG as well (the data cell uses its own generator for the target weights),
# so anything that draws from the global RNG later, e.g. model initialisation, is reproducible
# under Restart & Run All.
torch.manual_seed(SEED)

# NOTE(review): looks like a leftover from an MNIST setup; the analytic data module does not use it.
DATA_ROOT = "./data_mnist"

# --- model
WIDTH = 64             # <<<<< model size knob
DEPTH = 1              # number of hidden blocks after the first
USE_LAYERNORM = False  # NOTE(review): not passed to models.MLP, so it currently has no effect
ACTIVATION = "gelu"    # "gelu" or "tanh"; NOTE(review): not passed to models.MLP either
D = 32                 # input dimension of the synthetic regression problem

# --- batch sizes
TRAIN_BS = 4096
TEST_BS = 4096
DIAG_BS = 4096         # batch size handed to MRDDiagnosticsCallback

# --- optimiser
LR = 0.01
# NOTE(review): if MRDTaskModule builds a torch.optim.SGD-style optimiser (presumably, given the
# nesterov flag), momentum=1.0 means the momentum buffer never decays. Confirm this is intended;
# 0.9 was the previous value.
MOMENTUM = 1.0         # 0.9
WEIGHT_DECAY = 0       # 1e-4
NESTEROV = True
GRAD_CLIP = None       # 1.0; None disables gradient clipping in the Trainer

# --- schedule / diagnostics
MAX_STEPS = 400
VAL_EVERY_N_STEPS = 200  # NOTE(review): unused; validation is disabled via limit_val_batches=0
DIAG_EVERY_N_STEPS = 10
DIAG_K = 16

# --- hardware / numerics
DEVICES = 1
PRECISION = "32-true"

In [3]:
# Fixed random target f*(x) built from w and A. They come from a dedicated CPU generator
# seeded with SEED, so the target is reproducible independently of the global RNG.
g = torch.Generator(device="cpu")
g.manual_seed(SEED)
w = torch.randn(D, generator=g) / math.sqrt(D)
A = torch.randn(D, D, generator=g) / math.sqrt(D)


def f_star(X):
    """Ground-truth regression target for a batch of inputs.

    X: (B, D) input batch. Delegates to data.target_sin_mix with the fixed w and A above.
    """
    return data.target_sin_mix(X, w=w, A=A)


adm_cfg = data.AnalyticDMConfig(
    d=D,
    n_train=50_000,
    n_test=10_000,
    train_bs=TRAIN_BS,
    test_bs=TEST_BS,
    num_workers=1,
    noise_std=0.05,
    x_dist="normal",
    fixed_dataset=True,
    standardize_x=True,   # "whiten-ish" inputs
)

dm = data.AnalyticRegressionData(adm_cfg, target_fn=f_star)
dm.setup()

# Sanity check: one training batch has inputs of shape (B, D) and one target per input.
xb, yb = next(iter(dm.train_dataloader()))
assert xb.ndim == 2 and xb.shape[1] == D, f"expected x batch of shape (B, {D}), got {tuple(xb.shape)}"
assert yb.shape[0] == xb.shape[0], f"x/y batch sizes differ: {tuple(xb.shape)} vs {tuple(yb.shape)}"
tuple(xb.shape), tuple(yb.shape)

In [4]:
# Re-seed right before building the model, so re-running this cell alone reproduces the same
# initial weights and the same data-loader randomness.
pl.seed_everything(SEED, workers=True)

# NOTE(review): USE_LAYERNORM and ACTIVATION from the config cell are not passed here, so changing
# them has no effect. TODO: wire them through if models.MLP accepts them.
model = models.MLP(width=WIDTH, depth=DEPTH, in_dim=D, out_dim=1)

module = models.MRDTaskModule(
    model=model, task="regression",
    lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY, nesterov=NESTEROV,
)

# Diagnostics every DIAG_EVERY_N_STEPS steps, computed on DIAG_BS samples with K=DIAG_K.
mrd_cb = callbacks.MRDDiagnosticsCallback(every_n_steps=DIAG_EVERY_N_STEPS, diag_bs=DIAG_BS, K=DIAG_K)


class NotebookTQDM(TQDMProgressBar):
    """TQDM progress bar whose training bar is always shown in the notebook."""

    def init_train_tqdm(self):
        bar = super().init_train_tqdm()
        bar.disable = False  # force rendering even if tqdm decided to disable itself
        return bar


trainer = pl.Trainer(
    devices=DEVICES,
    precision=PRECISION,
    limit_val_batches=0,
    max_steps=MAX_STEPS,
    gradient_clip_val=GRAD_CLIP,
    log_every_n_steps=20,
    val_check_interval=None,
    enable_checkpointing=False,
    enable_model_summary=True,
    enable_progress_bar=True,
    callbacks=[NotebookTQDM(refresh_rate=20), mrd_cb],
    logger=False,   # keep notebook simple; metrics stored in mrd_cb.rows
)

# NOTE(review): the saved run stopped here at the start of epoch 0 with
# "element 0 of tensors does not require grad and does not have a grad_fn", so `df` was never
# built and the plotting cells below cannot run. TODO: check that MRDTaskModule.training_step
# returns a loss attached to the autograd graph, and that the diagnostics callback does not leave
# grad disabled or parameters frozen.
trainer.fit(module, datamodule=dm)
trainer.test(module, datamodule=dm)

df = mrd_cb.dataframe()
df.tail()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA A10') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


/home/ubuntu/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=29` in the `DataLoader` to improve performance.


Epoch 0:   0%|          | 0/12 [00:00<?, ?it/s]

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
def plot_logged(columns, *, ylabel, title, log_y=False, zero_line=False):
    """Plot logged diagnostic columns of `df` against df["step"] on one new figure.

    columns: (column_name, legend_label) pairs. Columns missing from `df` are skipped, so a
        diagnostic that was never logged is not drawn as an empty all-NaN line.
    log_y: use a log y-axis (only applied if at least one line was drawn).
    zero_line: draw a dashed horizontal reference line at y = 0.
    """
    fig, ax = plt.subplots()
    drawn = 0
    for name, label in columns:
        if name in df.columns:
            ax.plot(df["step"], df[name], label=label)
            drawn += 1
    if zero_line:
        ax.axhline(0.0, linestyle="--")
    if log_y and drawn:
        ax.set_yscale("log")
    ax.set(xlabel="step", ylabel=ylabel, title=title)
    # Only add a legend when some drawn line carries a label (avoids matplotlib's empty-legend warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    plt.show()


plot_logged([("loss", "loss")], ylabel="Loss", title="Loss")

plot_logged(
    [
        ("H_model_F", "||H_model||_F"),
        ("H_mix_F", "||H||_F"),
        ("G_F", "||G||_F"),
        ("R_F", "||R||_F"),
    ],
    ylabel="Frobenius norm",
    title=f"Metric vs Residual (exact, diag_bs={DIAG_BS})",
    log_y=True,
)

plot_logged(
    [
        ("Phi_align_F_fd", "||Phi_align||_F"),
        ("Phi_damp_F_fd", "||Phi_damp||_F"),
        ("Phi_trans_F_fd", "||Phi_trans||_F"),
    ],
    ylabel="Frobenius norm",
    title=f"Three-channel magnitudes (exact, diag_bs={DIAG_BS})",
    log_y=True,
)

plot_logged(
    [("m_F", None)],
    ylabel="m_F",
    title=f"Frobenius margin m_F (exact, diag_bs={DIAG_BS})",
    zero_line=True,
)

In [None]:
# Second pass over the same diagnostics, adding a model-curvature curve. It is plotted directly
# when logged; otherwise, for regression only, a crude proxy is built from ||R||_F and the loss.
TASK = "regression"  # "regression" or "classification"; must match task= passed to MRDTaskModule
EPS = 1e-12          # keeps the proxy division finite when the loss is ~0

steps = df["step"].to_numpy()


def col(name, default=np.nan):
    """Return df[name] as a float array, or an array filled with `default` if it was not logged.

    Forcing dtype=float means object-dtype columns (e.g. numbers mixed with None) still work
    with np.isfinite / np.maximum below.
    """
    if name in df.columns:
        return df[name].to_numpy(dtype=float)
    return np.full(len(df), default, dtype=float)


def plot_curves(curves, *, ylabel, title, log_y=False, zero_line=False):
    """Plot (values, label) pairs against `steps` on one new figure.

    The log y-axis is only applied when some value is positive. Otherwise matplotlib cannot
    autoscale it, e.g. for a diagnostic that was never logged and is all NaN.
    """
    fig, ax = plt.subplots()
    for values, label in curves:
        ax.plot(steps, values, label=label)
    if zero_line:
        ax.axhline(0.0, linestyle="--")
    if log_y and any(np.any(values > 0) for values, _ in curves):
        ax.set_yscale("log")
    ax.set(xlabel="step", ylabel=ylabel, title=title)
    ax.legend()
    plt.show()


# --- loss
loss = col("loss")
plot_curves([(loss, "loss")], ylabel="Loss", title="Loss")

# --- metric vs residual + model curvature
G_F = col("G_F")
R_F = col("R_F")

# Prefer a directly logged model-curvature estimate, but only if it has at least one finite value;
# a logged-but-empty column would otherwise hide the proxy below.
H_model, H_model_label = None, None
for name in ("H_model_F", "Hbar_F", "H_F"):
    values = col(name)
    if np.isfinite(values).any():
        H_model, H_model_label = values, name
        break

# Regression fallback: R = E[e H], so ||E[H]||_F ~ ||R||_F / E|e| (very crude).
# Since E|e| <= sqrt(E[e^2]) = sqrt(MSE), sqrt(loss) is used as the denominator when "loss" is the MSE.
# NOTE(review): if the module logs 0.5 * MSE instead, this proxy is off by a factor of sqrt(2).
H_proxy, H_proxy_label = None, None
if H_model is None and TASK == "regression" and np.isfinite(loss).any():
    e_rms = np.sqrt(np.maximum(loss, 0.0))  # sqrt(MSE); NaN entries stay NaN
    H_proxy = R_F / (e_rms + EPS)
    H_proxy_label = "||R||_F / sqrt(MSE)  (proxy for ||E[H]||_F)"

curvature_curves = [(G_F, "||G||_F"), (R_F, "||R||_F")]
if H_model is not None:
    curvature_curves.append((H_model, H_model_label))
elif H_proxy is not None:
    curvature_curves.append((H_proxy, H_proxy_label))
plot_curves(
    curvature_curves,
    ylabel="Frobenius norm / proxy",
    title="Metric vs Residual + Model Curvature",
    log_y=True,
)

# --- channels
Pa = col("Phi_align_F_fd")
Pd = col("Phi_damp_F_fd")
Pt = col("Phi_trans_F_fd")
plot_curves(
    [(Pa, "||Phi_align||_F"), (Pd, "||Phi_damp||_F"), (Pt, "||Phi_trans||_F")],
    ylabel="Frobenius norm",
    title="Three-channel magnitudes",
    log_y=True,
)

# --- Frobenius margin (signed, so linear axis with a zero reference line)
m_F = col("m_F")
plot_curves([(m_F, "m_F")], ylabel="m_F", title="Frobenius margin m_F", zero_line=True)

In [None]:
# Full diagnostics table returned by mrd_cb.dataframe() (presumably one row per diagnostic step).
df