# SIREN INR Viewer

Load pretrained SIREN implicit-segmentation weights (`artifacts/inr_siren_brats23.npz`, read as `../artifacts/inr_siren_brats23.npz` relative to the notebook's working directory) and interactively browse predictions across BraTS 2023 cases and slices.

Features:
- Infers the number and sizes of layers from the checkpoint's `W_i`/`b_i` arrays.
- Efficient chunked full-volume prediction.
- Interactive dropdown (case) + slider (slice).
- Overlay GT + prediction per modality with per-slice Dice & PSNR.

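Assumptions: the model was trained on inputs `concat([x, y, z], intensities)`, where voxel coordinates are normalized to [-1, 1] per axis and the 4 modality intensities (t1n, t1c, t2w, t2f, in that order) are z-scored per case using the statistics of nonzero voxels.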

In [None]:
import pathlib, os, math
import numpy as np
import nibabel as nib
import jax, jax.numpy as jnp
import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Paths are relative to the notebook's working directory.
DATA_ROOT = pathlib.Path('../data') / 'BraTS-2023'
WEIGHTS_PATH = pathlib.Path('../artifacts') / 'inr_siren_brats23.npz'
CHUNK = 120000  # voxels per forward pass; lower it if memory/time is an issue
# SIREN sine frequencies. NOTE(review): neither constant is referenced by the
# inference code in this notebook — confirm whether the checkpoint needs them.
OMEGA_0 = OMEGA = 30.0
NUM_CLASSES = 4  # number of output logits; argmax labels are 0..3
MODS = ['t1n', 't1c', 't2w', 't2f']  # modality file suffixes; order = input channel order
SEG = 'seg'  # segmentation file suffix

def find_cases(root: pathlib.Path, mods=None, seg=None):
    """Return the sorted case directories under ``root`` that ``load_case`` can read.

    A directory ``<case>`` qualifies only when every modality file
    ``<case>-<mod>.nii.gz`` and the segmentation ``<case>-<seg>.nii.gz`` exist.
    Folders missing any of them are skipped, because ``load_case`` opens all
    of these files unconditionally and would fail on them.

    Args:
        root: directory containing one sub-directory per case.
        mods: modality suffixes to require; defaults to ``MODS``.
        seg: segmentation suffix to require; defaults to ``SEG``.
    """
    mods = MODS if mods is None else mods
    seg = SEG if seg is None else seg
    required = [*mods, seg]
    return [
        p for p in sorted(root.iterdir())
        if p.is_dir() and all((p / f'{p.name}-{s}.nii.gz').exists() for s in required)
    ]

def load_case(p: pathlib.Path):
    """Read one BraTS case directory into a modality stack and a label volume.

    Each modality in ``MODS`` is loaded as float32 and shifted/scaled by the
    mean and (std + 1e-6) of its nonzero voxels; an all-zero volume is left
    untouched. The shift/scale is applied to every voxel, zeros included.

    Returns:
        ``(stack, seg)``: the modalities stacked along a new leading axis and
        the ``SEG`` volume cast to int16.
    """
    case_id = p.name

    def _read(suffix):
        return nib.load(str(p / f'{case_id}-{suffix}.nii.gz')).get_fdata()

    volumes = []
    for modality in MODS:
        vol = _read(modality).astype(np.float32)
        foreground = vol[vol != 0]
        if foreground.size:
            vol = (vol - foreground.mean()) / (foreground.std() + 1e-6)
        volumes.append(vol)
    labels = _read(SEG).astype(np.int16)
    return np.stack(volumes, 0), labels

def load_params_npz(path: pathlib.Path):
    """Load SIREN layer parameters from an ``.npz`` checkpoint.

    The archive must store consecutive ``W_0, b_0, W_1, b_1, ...`` arrays;
    reading stops at the first missing ``W_i``.

    Args:
        path: location of the ``.npz`` file (``str`` is also accepted).

    Returns:
        List of ``{'W': jnp.ndarray, 'b': jnp.ndarray}`` dicts in index order.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        KeyError: if some ``W_i`` has no matching ``b_i``.
        RuntimeError: if the archive contains no ``W_0``.
    """
    path = pathlib.Path(path)
    if not path.exists():
        raise FileNotFoundError(f'Weights file not found: {path}')
    layers = []
    # np.load returns an NpzFile that keeps the archive open until closed;
    # the context manager releases it once every array is copied into JAX.
    with np.load(path) as d:
        i = 0
        while f'W_{i}' in d:
            if f'b_{i}' not in d:
                raise KeyError(f'Checkpoint has W_{i} but no b_{i}: {path}')
            layers.append({'W': jnp.array(d[f'W_{i}']), 'b': jnp.array(d[f'b_{i}'])})
            i += 1
    if not layers:
        raise RuntimeError('No layers found in checkpoint')
    return layers

def apply_siren(params, x, omega_0=1.0, omega=1.0):
    """Evaluate a SIREN MLP stored as a list of ``{'W', 'b'}`` layer dicts.

    Every layer except the last computes ``sin(scale * (h @ W + b))``; the
    last layer is linear and returns raw logits. A one-layer checkpoint is
    therefore a plain linear map.

    Args:
        params: layer dicts (e.g. from ``load_params_npz``); ``W`` is
            right-multiplied, so it has shape (fan_in, fan_out).
        x: inputs whose last axis matches the first layer's fan_in.
        omega_0: frequency scale of the first sine layer. The default 1.0 means
            no scaling. NOTE(review): SIREN training commonly uses 30
            (see OMEGA_0) — confirm whether the checkpoint folded omega into
            W/b before passing a different value.
        omega: frequency scale of the remaining sine layers (same caveat).

    Returns:
        Logits whose last axis is the final layer's fan_out.

    Raises:
        ValueError: if ``params`` is empty.
    """
    if not params:
        raise ValueError('params must contain at least one layer')
    *hidden, last = params
    h = x
    for idx, layer in enumerate(hidden):
        scale = omega_0 if idx == 0 else omega
        h = jnp.sin(scale * (jnp.dot(h, layer['W']) + layer['b']))
    return jnp.dot(h, last['W']) + last['b']

def predict_volume(params, mods, seg, chunk=CHUNK):
    """Predict a per-voxel label map by evaluating the SIREN chunk by chunk.

    Each voxel's input is its (x, y, z) index scaled to [-1, 1] per axis,
    followed by its M modality intensities.

    Args:
        params: SIREN layer dicts passed to ``apply_siren``.
        mods: modality stack of shape (M, H, W, D).
        seg: ground-truth volume; not used for inference and returned
            unchanged so callers can unpack ``(pred, seg)`` together.
        chunk: number of voxels evaluated per forward pass.

    Returns:
        ``(pred, seg)`` where ``pred`` is an int16 (H, W, D) argmax label map.
    """
    M, H, W, D = mods.shape
    n_vox = H * W * D
    # max(n - 1, 1) avoids 0/0 (NaN coordinates) on singleton axes, which
    # are mapped to -1 instead.
    scale = np.array([max(H - 1, 1), max(W - 1, 1), max(D - 1, 1)])
    # (M, H*W*D) view: column k holds the intensities of flat voxel k.
    flat_mods = mods.reshape(M, -1)
    preds = []
    for start in range(0, n_vox, chunk):
        stop = min(start + chunk, n_vox)
        # Build coordinates per chunk rather than materializing the full
        # (H*W*D, 3) grid; C-order unravel matches meshgrid(indexing='ij').
        coords = np.stack(np.unravel_index(np.arange(start, stop), (H, W, D)), axis=-1)
        norm = (coords / scale) * 2 - 1
        intens = flat_mods[:, start:stop].T
        x_in = jnp.concatenate([jnp.array(norm), jnp.array(intens)], axis=-1)
        logits = apply_siren(params, x_in)
        preds.append(np.array(jnp.argmax(logits, axis=-1), dtype=np.int16))
    if preds:
        pred = np.concatenate(preds, 0).reshape(H, W, D)
    else:
        pred = np.zeros((H, W, D), dtype=np.int16)
    return pred, seg

def dice_scores(pred, true, C=NUM_CLASSES):
    """Per-class smoothed Dice between two label arrays.

    Returns a dict mapping each class ``c`` in ``range(C)`` to
    ``(2*|P & T| + 1e-6) / (|P| + |T| + 1e-6)``, or ``np.nan`` when the class
    occurs in neither array.
    """
    eps = 1e-6

    def _one(c):
        in_pred = pred == c
        in_true = true == c
        total = in_pred.sum() + in_true.sum()
        if total <= 0:
            return np.nan
        return (2 * (in_pred & in_true).sum() + eps) / (total + eps)

    return {c: _one(c) for c in range(C)}

def slice_metrics(pred2d, true2d, num_classes=None):
    """Compute macro Dice and a label-space PSNR for one 2-D slice.

    Args:
        pred2d: predicted integer label map.
        true2d: ground-truth integer label map, same shape as ``pred2d``.
        num_classes: label values ``0..num_classes-1`` to score; defaults to
            the module-level ``NUM_CLASSES``.

    Returns:
        ``(dice_macro, psnr)``. ``dice_macro`` averages smoothed Dice over the
        classes present in either map (class 0 included), NaN if none are.
        ``psnr`` treats labels as intensities with peak ``num_classes - 1``
        and is ``inf`` when the maps are identical.
    """
    n_cls = NUM_CLASSES if num_classes is None else num_classes
    ds = []
    for c in range(n_cls):
        p = pred2d == c
        t = true2d == c
        denom = p.sum() + t.sum()
        if denom > 0:
            inter = (p & t).sum()
            ds.append((2 * inter + 1e-6) / (denom + 1e-6))
    dice_macro = float(np.mean(ds)) if ds else float('nan')
    mse = np.mean((pred2d.astype(np.float32) - true2d.astype(np.float32)) ** 2)
    # Peak is the largest label value (previously hard-coded as 3 for the 4
    # classes); clamp to 1 so a single-class setup doesn't take log10(0).
    peak = max(n_cls - 1, 1)
    psnr = float('inf') if mse <= 1e-12 else float(10 * np.log10((peak * peak) / (mse + 1e-12)))
    return dice_macro, psnr

def visualize_modalities(mods, true, pred, z):
    """Show each modality at slice ``z`` with GT (top row) and prediction
    (bottom row) label overlays; prediction panels carry the slice's macro
    Dice and PSNR from ``slice_metrics``.

    Args:
        mods: modality stack indexed ``mods[m, :, :, z]`` (M modalities).
        true: ground-truth label volume indexed ``true[:, :, z]``.
        pred: predicted label volume with the same layout as ``true``.
        z: slice index along the last axis.
    """
    M = mods.shape[0]
    # squeeze=False keeps `axes` 2-D with shape (2, M), even when M == 1.
    fig, axes = plt.subplots(2, M, figsize=(3 * M, 6), squeeze=False)
    true2d = true[:, :, z]
    pred2d = pred[:, :, z]
    dice_macro, psnr = slice_metrics(pred2d, true2d)
    # Label 0 (BraTS background) is masked so the translucent overlay does not
    # tint the whole slice; 'nearest' stops label edges from being blended
    # into colors that belong to no class.
    overlay_kw = dict(cmap='tab10', alpha=0.35, vmin=0, vmax=NUM_CLASSES - 1, interpolation='nearest')
    gt_overlay = np.ma.masked_equal(true2d, 0)
    pred_overlay = np.ma.masked_equal(pred2d, 0)
    metrics_label = f'Dice {dice_macro:.3f} PSNR {psnr:.2f} dB'
    for m in range(M):
        ax_gt, ax_pr = axes[0, m], axes[1, m]
        ax_gt.imshow(mods[m, :, :, z], cmap='gray')
        ax_gt.imshow(gt_overlay, **overlay_kw)
        ax_gt.set_title(f'Mod {m} + GT', fontsize=10)
        ax_gt.axis('off')
        ax_pr.imshow(mods[m, :, :, z], cmap='gray')
        ax_pr.imshow(pred_overlay, **overlay_kw)
        ax_pr.set_title(f'Mod {m} + Pred', fontsize=10)
        ax_pr.text(0.01, 0.99, metrics_label, transform=ax_pr.transAxes, ha='left', va='top',
                   fontsize=8, color='yellow', bbox=dict(boxstyle='round', fc='black', alpha=0.5, pad=0.4))
        ax_pr.axis('off')
    fig.tight_layout()
    plt.show()

# Load the checkpoint and sanity-check its input width against the data config.
params = load_params_npz(WEIGHTS_PATH)
expected_in = 3 + len(MODS)  # xyz coordinates + one intensity per modality
first_in = params[0]['W'].shape[0]
if first_in != expected_in:
    print(f'Warning: first layer input dim {first_in} != expected {expected_in}. Check modalities or training config.')
print(f'Loaded {len(params)} layers')
cases = find_cases(DATA_ROOT)
print(f'Found {len(cases)} cases')

# Per-case results, filled on demand by get_prediction.
pred_cache = {}  # case_path -> (mods, seg, pred)
def get_prediction(case_path, max_entries=3):
    """Return ``(mods, seg, pred)`` for ``case_path``, computing it on first use.

    Results are stored in the module-level ``pred_cache`` dict. Every entry
    holds full volumes (all modalities plus two label maps), so the cache is
    bounded: it keeps at most ``max_entries`` cases (minimum 1) and evicts
    the least recently used one. Pass ``max_entries=None`` for no bound.
    """
    if case_path in pred_cache:
        # Re-insert to mark as most recently used (dicts keep insertion order).
        entry = pred_cache.pop(case_path)
        pred_cache[case_path] = entry
        return entry
    mods, seg = load_case(case_path)
    pred, seg_true = predict_volume(params, mods, seg)
    entry = (mods, seg_true, pred)
    pred_cache[case_path] = entry
    if max_entries is not None:
        # Oldest entries sit at the front; the new entry is last and survives.
        while len(pred_cache) > max(max_entries, 1):
            del pred_cache[next(iter(pred_cache))]
    return entry

# Interactive browser: a case dropdown and a z-slice slider driving one Output.
out = widgets.Output()
if not cases:
    display(widgets.HTML('<b>No cases found.</b>'))
else:
    dd = widgets.Dropdown(options=[(c.name, c) for c in cases], description='Case:')
    mods, seg, pred = get_prediction(dd.value)
    z_slider = widgets.IntSlider(min=0, max=int(pred.shape[2]-1), value=int(pred.shape[2]//2), description='Slice z')

    def render(z):
        """Redraw the current case (module globals mods/seg/pred) at slice z."""
        with out:
            clear_output(wait=True)
            visualize_modalities(mods, seg, pred, int(z))

    def on_z(change):
        render(change['new'])

    def on_case(change):
        global mods, seg, pred
        mods, seg, pred = get_prediction(change['new'])
        depth = int(pred.shape[2])
        # Re-ranging the slider fires 'value' changes (clamping to the new max,
        # then recentering), each of which would redraw via on_z. Detach on_z
        # meanwhile so the new case is drawn exactly once below.
        z_slider.unobserve(on_z, names='value')
        try:
            z_slider.max = depth - 1
            z_slider.value = depth // 2
        finally:
            z_slider.observe(on_z, names='value')
        render(z_slider.value)

    dd.observe(on_case, names='value')
    z_slider.observe(on_z, names='value')
    display(widgets.VBox([dd, z_slider, out]))
    render(z_slider.value)


FileNotFoundError: Weights file not found: artifacts/inr_siren_brats23.npz