# Modular INR Training

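This notebook relies on the `inr.dataloader`, `inr.model`, and `inr.train` modules, so the data loading, model, and training loop are defined once in Python modules rather than redefined here. It trains and evaluates a Fourier-feature INR segmentation model on BraTS 2023, then provides an interactive viewer for predictions on MU-Glioma-Post hold-out cases.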

In [None]:
# Imports and configuration
from inr.train import train_inr, evaluate_inr

# Run settings handed to ``train_inr`` / ``evaluate_inr`` below.
config = dict(
    DATA_ROOT="../data/BraTS-2023",
    CASE_LIMIT=256,
    NUM_FOLDS=5,
    FOLD_INDEX=0,
    GLOBAL_BATCH_SIZE=16384,
    MICRO_BATCH_SIZE=4096,
    FOURIER_FREQS=8,
    HIDDEN_DIMS=[16, 16, 16, 16],
    LR=1e-3,
    MIN_LR=5e-4,
    WARMUP_STEPS=100,
    TRAIN_STEPS=1000,
    RNG_SEED=42,
    NUM_CLASSES=4,
    DICE_WEIGHT=0.8,
    CLASS_WEIGHTS=[0.3, 1.0, 0.8, 1.0],  # one weight per class (len == NUM_CLASSES)
    CLIP_NORM=1.0,
    OPTIMIZER_CHOICE="muon",
    # Weights & Biases configuration (used if use_wandb=True)
    WANDB_PROJECT="brats-inr-segmentation",
    WANDB_ENTITY=None,
    WANDB_RUN_NAME=None,
    WANDB_TAGS=["fourier-features", "INR", "medical-imaging", "segmentation"],
    WANDB_NOTES="Fourier Feature INR modular training",
)

config

In [None]:
# Train the model; returns the learned parameters and the training state.
# Pass use_wandb=False to skip Weights & Biases logging.
params, state = train_inr(
    config,
    use_wandb=True,
    resume_from=None,
)
print("Training complete.")

In [None]:
# Evaluate the trained model on the validation (or training) split.
# NOTE: ``state`` must be the training state from ``train_inr``; the viewer
# cell further down rebinds ``state`` (and ``params``), so re-run training
# before re-running this cell after using the viewer.
from inr.train import evaluate_inr

metrics, artifacts = evaluate_inr(
    params,
    state,
    config,
    use_wandb=True,
)
metrics

In [None]:
# Close the current Weights & Biases run so that a later run starts fresh.
import wandb

wandb.finish()

In [None]:
# Interactive inference viewer for MU-Glioma-Post hold-out cases
import ipywidgets as widgets
from IPython.display import display, clear_output
import pathlib as _pl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from inr.model import model_load, predict_volume, dice_score

# -------------------------------------------------------------------
# 1. Load trained INR model from checkpoint
# -------------------------------------------------------------------
# Set this to your final model NPZ. Example:
# CHECKPOINT_PATH = "../artifacts/brats-inr-segmentation/effortless-resonance-214/effortless-resonance-214.npz"
CHECKPOINT_PATH = '../artifacts/brats-inr-segmentation/vocal-planet-216/vocal-planet-216.npz'  # <-- set this before running the cell

if CHECKPOINT_PATH is None:
    raise ValueError("Please set CHECKPOINT_PATH to a trained INR NPZ file.")

# NOTE: this rebinds ``params`` produced by the training cell (and ``state`` is
# rebound below), so re-run training before re-running the evaluation cell.
params, inr_cfg = model_load(CHECKPOINT_PATH)
# The loaded metadata may nest the training config under "config"; otherwise
# the metadata dict itself is used as the config.
cfg = inr_cfg.get("config", inr_cfg)
print("Loaded INR model from:", CHECKPOINT_PATH)
print("FOURIER_FREQS:", cfg.get("FOURIER_FREQS"))
print("NUM_CLASSES:", cfg.get("NUM_CLASSES"))

# -------------------------------------------------------------------
# 2. Configure MU-Glioma-Post paths
# -------------------------------------------------------------------
PROJECT_ROOT = _pl.Path("..").resolve()
DATA_ROOT = PROJECT_ROOT / "data"
MU_GLIOMA_MANIFEST = DATA_ROOT / "MU-Glioma-Post" / "manifest.csv"

# Load manifest (columns: id, t1, t1ce, t2, flair, mask)
mu_cases = pd.read_csv(MU_GLIOMA_MANIFEST)

# Dropdown options as (label, value) pairs, both being the case id. Iterating
# the column directly avoids building a Series per row (as ``iterrows`` does)
# and keeps the column's dtype.
case_options = [(case_id, case_id) for case_id in mu_cases["id"]]
print("MU-Glioma-Post cases:", len(case_options))

hold_pred_cache = {}  # case_id -> predicted volume from predict_volume (filled lazily)
# Viewer state for the currently selected case.
# NOTE: this shadows the training ``state`` returned by ``train_inr`` above.
state = {"mods": None, "true": None, "pred": None}

# -------------------------------------------------------------------
# 3. Case loading + prediction
# -------------------------------------------------------------------
def _load_mu_case_by_id(case_id: str):
    """Load one MU-Glioma-Post case listed in the manifest.

    Each available modality is z-scored using the mean/std of its non-zero
    voxels; the whole volume (including zero background) is shifted and
    scaled. A missing T1 is replaced by zeros shaped like T1ce.

    Args:
        case_id: Value of the manifest's ``id`` column.

    Returns:
        ``(mods, seg)``: ``mods`` stacks T1, T1ce, T2, FLAIR (in that order)
        along a new axis 0; ``seg`` is the mask volume cast to int16.

    Raises:
        IndexError: If ``case_id`` is not present in the manifest.
    """
    import nibabel as nib

    matches = mu_cases.loc[mu_cases["id"] == case_id]
    if matches.empty:
        raise IndexError(f"Case id {case_id!r} not found in the MU-Glioma-Post manifest.")
    row = matches.iloc[0]

    def _load_nii(rel_path):
        # Manifest paths are relative to DATA_ROOT.
        return nib.load(str(DATA_ROOT / rel_path)).get_fdata().astype(np.float32)

    def _zscore_nonzero(arr):
        # Statistics come from non-zero voxels only; all-zero volumes are
        # returned unchanged.
        fg = arr != 0
        if not fg.any():
            return arr
        return (arr - arr[fg].mean()) / (arr[fg].std() + 1e-6)

    # T1ce, T2, FLAIR are always present in this dataset.
    t1ce_t2_flair = [_zscore_nonzero(_load_nii(row[key])) for key in ("t1ce", "t2", "flair")]

    # T1 may be empty/NaN for some rows.
    t1_path = row["t1"]
    if isinstance(t1_path, str) and t1_path.strip():
        t1 = _zscore_nonzero(_load_nii(t1_path))
    else:
        # Reuse the already-loaded T1ce for the shape instead of re-reading it.
        t1 = np.zeros_like(t1ce_t2_flair[0], dtype=np.float32)

    # Load segmentation mask as int16
    seg = nib.load(str(DATA_ROOT / row["mask"])).get_fdata().astype(np.int16)

    mods = np.stack([t1, *t1ce_t2_flair], axis=0)
    return mods, seg


def load_and_predict_cached(case_id: str):
    """Load a case and return it together with its (memoized) prediction.

    The image volumes are read from disk on every call; only the output of
    ``predict_volume`` is cached in ``hold_pred_cache``, keyed by case id.

    Returns:
        ``(mods, seg, pred)`` for ``case_id``.
    """
    mods, seg = _load_mu_case_by_id(case_id)
    if case_id in hold_pred_cache:
        return mods, seg, hold_pred_cache[case_id]

    predicted, _unused = predict_volume(
        params,
        {"mods": mods, "seg": seg},
        fourier_freqs=int(cfg["FOURIER_FREQS"]),
        chunk=120000,
    )
    hold_pred_cache[case_id] = predicted
    return mods, seg, predicted

# -------------------------------------------------------------------
# 4. Utility metrics + visualization
# -------------------------------------------------------------------
def _dice_macro_slice(pred2d, true2d, num_classes: int):
    """Return the mean per-class Dice score for one 2-D label slice.

    Classes absent from both ``pred2d`` and ``true2d`` are skipped; if no
    class qualifies the result is NaN. A 1e-6 term smooths each ratio.
    """
    per_class = []
    for label in range(num_classes):
        pred_mask = pred2d == label
        true_mask = true2d == label
        total = pred_mask.sum() + true_mask.sum()
        if total <= 0:
            continue
        overlap = (pred_mask & true_mask).sum()
        per_class.append((2 * overlap + 1e-6) / (total + 1e-6))

    if not per_class:
        return float("nan")
    return float(np.mean(per_class))


def _psnr_slice(pred2d, true2d, max_val: float):
    """Return the PSNR (dB) between two slices, treating labels as intensities.

    ``max_val`` is the peak value; identical slices (MSE <= 1e-12) give +inf.
    """
    pred_f = np.asarray(pred2d, dtype=np.float32)
    true_f = np.asarray(true2d, dtype=np.float32)
    mse = np.mean((pred_f - true_f) ** 2)
    if mse <= 1e-12:
        return float("inf")
    peak_sq = max_val * max_val
    return float(10.0 * np.log10(peak_sq / (mse + 1e-12)))


def visualize_modalities_with_overlays(mods, seg_gt, seg_pred, z):
    """Plot every modality at slice ``z`` with ground-truth and predicted overlays.

    Row 0 shows each modality with the ground-truth labels overlaid; row 1
    shows the same modality with the predicted labels plus the slice's macro
    Dice and label PSNR.

    Args:
        mods: Volume stack indexed as ``mods[m, :, :, z]`` (axis 0 = modality).
        seg_gt: Ground-truth label volume indexed as ``seg_gt[:, :, z]``.
        seg_pred: Predicted label volume with the same layout as ``seg_gt``.
        z: Index along the last axis.

    Returns:
        The Matplotlib figure; the caller is responsible for closing it.
    """
    # Interactive mode is switched off (and not restored) so the figure is
    # only shown when the caller displays it explicitly.
    plt.ioff()
    num_mods = mods.shape[0]
    num_classes = int(cfg["NUM_CLASSES"])
    # squeeze=False always returns a 2-D axes grid, even for a single modality.
    fig, axes = plt.subplots(2, num_mods, figsize=(3 * num_mods, 6), squeeze=False)

    gt_slice = seg_gt[:, :, z]
    pred_slice = seg_pred[:, :, z]
    # The slice metrics do not depend on the modality, so compute them once.
    dice = _dice_macro_slice(pred_slice, gt_slice, num_classes=num_classes)
    psnr = _psnr_slice(pred_slice, gt_slice, max_val=num_classes - 1)
    overlay_kw = dict(cmap="tab10", alpha=0.35, vmin=0, vmax=num_classes - 1)

    for m in range(num_mods):
        image = mods[m, :, :, z]
        ax_gt = axes[0, m]
        ax_pred = axes[1, m]

        ax_gt.imshow(image, cmap="gray")
        ax_gt.imshow(gt_slice, **overlay_kw)
        ax_gt.set_title(f"Mod {m} + GT", fontsize=10)
        ax_gt.axis("off")

        ax_pred.imshow(image, cmap="gray")
        ax_pred.imshow(pred_slice, **overlay_kw)
        ax_pred.set_title(f"Mod {m} + Pred", fontsize=10)
        ax_pred.text(
            0.01,
            0.99,
            f"Dice {dice:.3f} PSNR {psnr:.2f} dB",
            transform=ax_pred.transAxes,
            ha="left",
            va="top",
            fontsize=8,
            color="yellow",
            bbox=dict(boxstyle="round", fc="black", alpha=0.5, pad=0.4),
        )
        ax_pred.axis("off")

    plt.tight_layout()
    return fig


# -------------------------------------------------------------------
# 5. Widgets + callbacks
# -------------------------------------------------------------------
out = widgets.Output()
if not case_options:
    print("No MU-Glioma-Post cases available to visualize.")
else:
    dd_hold = widgets.Dropdown(options=case_options, description="MU case:")

    # Initialize first case so the slider range matches its depth.
    state["mods"], state["true"], state["pred"] = load_and_predict_cached(dd_hold.value)
    z_slider = widgets.IntSlider(
        min=0,
        max=int(state["pred"].shape[2] - 1),
        value=int(state["pred"].shape[2] // 2),
        description="Slice z",
    )

    def render_slice(z):
        """Redraw the output area for slice ``z`` of the currently loaded case."""
        with out:
            clear_output(wait=True)
            fig = visualize_modalities_with_overlays(state["mods"], state["true"], state["pred"], int(z))
            display(fig)
            plt.close(fig)

    def on_slice_change(change):
        render_slice(change["new"])

    def on_case_change(change):
        """Load the selected case, reset the slider to its middle slice, and redraw."""
        case_id = change["new"]
        loaded = None
        # Load inside the Output widget so the user sees progress and any
        # loading/prediction traceback instead of it going to the widget log.
        with out:
            clear_output(wait=True)
            print(f"Loading {case_id} ...")
            loaded = load_and_predict_cached(case_id)
        if loaded is None:
            # Loading failed and the error is shown in ``out``; keep the
            # previous case instead of redrawing over the traceback.
            return
        state["mods"], state["true"], state["pred"] = loaded

        depth = int(state["pred"].shape[2])
        # Detach the slice observer while changing the range so the figure is
        # drawn once; always re-attach it, even if updating the slider fails.
        z_slider.unobserve(on_slice_change, names="value")
        try:
            z_slider.max = depth - 1
            z_slider.value = depth // 2
        finally:
            z_slider.observe(on_slice_change, names="value")
        render_slice(z_slider.value)

    z_slider.observe(on_slice_change, names="value")
    dd_hold.observe(on_case_change, names="value")
    display(widgets.VBox([dd_hold, z_slider, out]))
    render_slice(z_slider.value)
