In [None]:
import torch
import math
import sys; sys.path.append(".")
from mrd import core, data, models, callbacks
import lightning as pl
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
# ---- reproducibility & paths ------------------------------------------------
SEED = 0
# NOTE(review): DATA_ROOT is not read anywhere in this notebook; the data are
# synthetic (data.AnalyticRegressionData). Confirm it can be dropped.
DATA_ROOT = "./data_mnist"

# ---- problem / model size -----------------------------------------------------
D = 32                # input dimension (used as both d= for the data and in_dim= for the MLP)
WIDTH = 64            # hidden width: the main model-size knob
DEPTH = 1             # number of hidden blocks after the first
# NOTE(review): USE_LAYERNORM and ACTIVATION are never passed to models.MLP in
# the model cell, so changing them currently has no effect. TODO: wire them up
# once MLP's keyword arguments are confirmed.
USE_LAYERNORM = False
ACTIVATION = "gelu"   # "gelu" or "tanh"

# ---- batch sizes ----------------------------------------------------------------
TRAIN_BS = TEST_BS = DIAG_BS = 4096

# ---- optimiser --------------------------------------------------------------------
LR = 0.01
# NOTE(review): with SGD-style momentum, 1.0 means the velocity buffer never
# decays (pure accumulation of gradients). The previous value was 0.9; confirm
# this is intentional.
MOMENTUM = 1.0
WEIGHT_DECAY = 0      # previously 1e-4
NESTEROV = True
GRAD_CLIP = None      # passed as Trainer gradient_clip_val; None = no clipping (previously 1.0)

# ---- schedule & diagnostics ---------------------------------------------------
MAX_STEPS = 400
# NOTE(review): unused; the Trainer disables validation via limit_val_batches=0.
VAL_EVERY_N_STEPS = 200
DIAG_EVERY_N_STEPS = 10   # every_n_steps for MRDDiagnosticsCallback
DIAG_K = 16               # K for MRDDiagnosticsCallback

# ---- hardware / numerics ------------------------------------------------------
DEVICES = 1
PRECISION = "32-true"

In [None]:
# Fixed parameters of the synthetic target, drawn from a dedicated CPU generator
# seeded with SEED so they do not depend on the global torch RNG state.
g = torch.Generator(device="cpu")
g.manual_seed(SEED)
w = torch.randn(D, generator=g) / math.sqrt(D)      # shape (D,), scaled by 1/sqrt(D)
A = torch.randn(D, D, generator=g) / math.sqrt(D)   # shape (D, D), scaled by 1/sqrt(D)

def f_star(X):
    """Target function for the synthetic regression task.

    Delegates to ``data.target_sin_mix`` using the module-level parameters
    ``w`` and ``A`` (read at call time).

    X: input batch of shape (B, D).
    """
    return data.target_sin_mix(X, A=A, w=w)

adm_cfg = data.AnalyticDMConfig(
    d=D,
    n_train=50_000,
    n_test=10_000,
    train_bs=TRAIN_BS,
    test_bs=TEST_BS,
    num_workers=1,
    noise_std=0.05,
    x_dist="normal",
    fixed_dataset=True,
    standardize_x=True,   # "whiten-ish" inputs
)

# Seed the global RNGs (torch / numpy / random) before building the datamodule.
# SEED otherwise only seeds the private generator used for w and A. Presumably
# the datamodule draws X and the noise from the global RNG; verify. This makes
# the generated dataset reproducible either way.
pl.seed_everything(SEED)

dm = data.AnalyticRegressionData(adm_cfg, target_fn=f_star)
dm.setup()

# Sanity-check one training batch: inputs are (B, D), targets share the batch
# dimension, the batch is no larger than TRAIN_BS, and nothing is NaN/inf.
xb, yb = next(iter(dm.train_dataloader()))
assert xb.ndim == 2 and xb.shape[1] == D, f"expected X of shape (B, {D}), got {tuple(xb.shape)}"
assert xb.shape[0] == yb.shape[0], f"batch mismatch: X {tuple(xb.shape)} vs y {tuple(yb.shape)}"
assert 0 < xb.shape[0] <= TRAIN_BS, f"unexpected batch size {xb.shape[0]}"
assert torch.isfinite(xb).all() and torch.isfinite(yb).all(), "non-finite values in training batch"

In [None]:
# Seed every global RNG (torch / numpy / random, plus dataloader workers via
# workers=True). Model initialisation and later random draws, such as batch
# shuffling, then repeat from run to run. SEED otherwise only seeds the
# private generator used for the target parameters w and A.
pl.seed_everything(SEED, workers=True)

# NOTE(review): USE_LAYERNORM and ACTIVATION from the config cell are not passed
# here, so they currently have no effect. TODO: forward them once models.MLP's
# keyword names are confirmed.
model = models.MLP(width=WIDTH, depth=DEPTH, in_dim=D, out_dim=1)

module = models.MRDTaskModule(
    model=model,
    task="regression",
    lr=LR,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    nesterov=NESTEROV,
)

# MRD diagnostics: run every DIAG_EVERY_N_STEPS steps on batches of DIAG_BS with K=DIAG_K.
mrd_cb = callbacks.MRDDiagnosticsCallback(
    every_n_steps=DIAG_EVERY_N_STEPS, diag_bs=DIAG_BS, K=DIAG_K
)

from lightning.pytorch.callbacks import TQDMProgressBar

class NotebookTQDM(TQDMProgressBar):
    """TQDM progress bar whose training bar is always enabled.

    The parent class builds the training bar. This subclass then forces
    ``disable = False`` on it, whatever the parent decided. Presumably this
    works around the bar being switched off in some notebook frontends; confirm.
    """

    def init_train_tqdm(self):
        train_bar = super().init_train_tqdm()
        train_bar.disable = False  # override the parent's enable/disable decision
        return train_bar

# Train for MAX_STEPS optimiser steps with validation switched off, then run the
# test loop once. No logger is attached. The diagnostics callback keeps its own
# records, which are read back below as a DataFrame.
trainer = pl.Trainer(
    # hardware / numerics
    devices=DEVICES,
    precision=PRECISION,
    # schedule: validation disabled (so VAL_EVERY_N_STEPS is not used)
    max_steps=MAX_STEPS,
    limit_val_batches=0,
    val_check_interval=None,
    # optimisation
    gradient_clip_val=GRAD_CLIP,
    # bookkeeping & output
    logger=False,
    log_every_n_steps=20,
    enable_checkpointing=False,
    enable_model_summary=True,
    enable_progress_bar=True,
    callbacks=[NotebookTQDM(refresh_rate=20), mrd_cb],
)

trainer.fit(module, datamodule=dm)
trainer.test(module, datamodule=dm)

# Diagnostics collected by the callback; show the most recent rows.
df = mrd_cb.dataframe()
df.tail()

In [None]:
def _metric(col):
    """Return column ``col`` of ``df``.

    If the callback did not record that column, return an all-NaN series of
    the same length, so the curve plots as empty instead of raising.
    """
    return df.get(col, pd.Series([float("nan")] * len(df)))


def _plot_vs_step(curves, *, title, ylabel, log_y=False, show_legend=True):
    """Plot each ``(column, label)`` pair in ``curves`` against ``df["step"]`` on a new figure.

    Returns the Axes so the caller can add decorations before ``plt.show()``.
    """
    fig, ax = plt.subplots()
    for col, label in curves:
        ax.plot(df["step"], _metric(col), label=label)
    ax.set_xlabel("step")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if log_y:
        ax.set_yscale("log")
    if show_legend:
        ax.legend()
    return ax


# NOTE(review): diagnostics run with diag_bs=DIAG_BS (4096 in the config cell),
# so "tiny batch" in the titles below may be stale; confirm.

# Loss. Previously this plot carried the copy-pasted "Metric vs Residual" title.
_plot_vs_step([("loss", "loss")], title="Loss vs step", ylabel="Loss")
plt.show()

_plot_vs_step(
    [
        ("H_model_F", "||H_model||_F"),
        ("H_mix_F", "||H||_F"),
        ("G_F", "||G||_F"),
        ("R_F", "||R||_F"),
    ],
    title="Metric vs Residual (exact, tiny batch)",
    ylabel="Frobenius norm",
    log_y=True,
)
plt.show()

_plot_vs_step(
    [
        ("Phi_align_F_fd", "||Phi_align||_F"),
        ("Phi_damp_F_fd", "||Phi_damp||_F"),
        ("Phi_trans_F_fd", "||Phi_trans||_F"),
    ],
    title="Three-channel magnitudes (exact, tiny batch)",
    ylabel="Frobenius norm",
    log_y=True,
)
plt.show()

margin_ax = _plot_vs_step(
    [("m_F", None)],
    title="Frobenius margin m_F (exact, tiny batch)",
    ylabel="m_F",
    show_legend=False,
)
margin_ax.axhline(0.0, linestyle="--")  # reference line at zero
plt.show()