# SIREN INR for BraTS 2023 with Optional CNN Distillation (JAX)

This notebook implements a SIREN-based implicit neural representation (INR) for voxel-wise segmentation. It trains with a standard cross-entropy (CE) loss, reports per-class Dice at evaluation time, and supports optional knowledge distillation from a CNN teacher (e.g., nnU-Net) by loading per-voxel teacher logits when they are available.

Notes:
- SIREN uses sinusoidal activations enabling high-frequency signal fitting.
- We combine coordinates + multi-modal intensities as input.
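- For distillation, provide the teacher's per-voxel logits as an NPZ file (array `logits`, shape `(C,H,W,D)`) aligned with the target volume. The loader in this notebook reads NPZ only, so logits stored as NIfTI must be converted first.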


In [None]:
import os, math, pathlib, time
from typing import Dict, Tuple, Optional
import numpy as np
import nibabel as nib
import jax, jax.numpy as jnp
import optax

# --- Filesystem locations ---------------------------------------------------
DATA_ROOT = pathlib.Path('../data/BraTS-2023')                  # scanned for one sub-directory per case
SAVE_PATH = pathlib.Path('../artifacts/inr_siren_brats23.npz')  # trained weights are written here
# Optional per-case teacher logits (`<case>-teacher.npz`).
# NOTE(review): unlike the two paths above this one has no '../' prefix — confirm it is intended.
TEACHER_DIR = pathlib.Path('artifacts/teacher_logits')

# --- Data / batching --------------------------------------------------------
CASE_LIMIT = 16        # at most this many case directories are used
NUM_CLASSES = 4        # width of the network output / one-hot targets
GLOBAL_BATCH = 32768   # voxels per optimizer step
MICRO_BATCH = 2048     # voxels per gradient evaluation
# Number of micro-batches accumulated per optimizer step (ceil division).
ACCUM = -(-GLOBAL_BATCH // MICRO_BATCH)

# --- Model ------------------------------------------------------------------
SIREN_HIDDEN = [256] * 4
OMEGA_0 = 30.0         # SIREN frequency scale for the first layer
OMEGA = 30.0           # SIREN frequency scale for subsequent layers

# --- Optimisation -----------------------------------------------------------
LR = 2e-3              # peak learning rate
MIN_LR = 2e-4          # learning rate at the end of the schedule
WARMUP = 50            # warm-up steps
STEPS = 600            # optimizer steps
KD_WEIGHT = 0.3        # weight for distillation (0 disables KD)

# --- Randomness -------------------------------------------------------------
RNG_SEED = 123
key = jax.random.PRNGKey(RNG_SEED)

print('Devices:', jax.devices(), 'accum', ACCUM)


Devices: [CpuDevice(id=0)] accum 16


In [2]:
# File-name suffixes of the input modalities and of the label volume.
MODS = ['t1n', 't1c', 't2w', 't2f']
SEG = 'seg'


def find_cases(root: pathlib.Path):
    """Return the case directories found directly under ``root``.

    A child of ``root`` counts as a case when it is a directory and holds at
    least one file named ``<dirname>-<modality>.nii.gz`` for a modality in
    ``MODS``. The result is in sorted path order.
    """
    found = []
    for entry in sorted(root.iterdir()):
        if not entry.is_dir():
            continue
        for modality in MODS:
            if (entry / f'{entry.name}-{modality}.nii.gz').exists():
                found.append(entry)
                break  # one matching modality is enough
    return found
def load_case(p: pathlib.Path):
    """Load every modality and the segmentation of one case directory.

    Each modality ``<case>-<modality>.nii.gz`` is read as float32 and, when it
    has any non-zero voxel, standardised with the mean and standard deviation
    of its non-zero voxels. The transform is applied to the whole volume, so
    zero-valued voxels are shifted as well.

    Returns:
        A tuple ``(mods, seg)``: the modalities stacked along a new leading
        axis in ``MODS`` order, and the label volume cast to int16.
    """
    case_id = p.name
    channels = []
    for modality in MODS:
        image = nib.load(str(p / f'{case_id}-{modality}.nii.gz'))
        volume = image.get_fdata().astype(np.float32)
        nonzero = volume != 0
        if nonzero.any():
            foreground = volume[nonzero]
            volume = (volume - foreground.mean()) / (foreground.std() + 1e-6)
        channels.append(volume)
    label_image = nib.load(str(p / f'{case_id}-{SEG}.nii.gz'))
    seg = label_image.get_fdata().astype(np.int16)
    return np.stack(channels, 0), seg
# Keep at most CASE_LIMIT of the case directories discovered under DATA_ROOT.
cases = find_cases(DATA_ROOT)[:CASE_LIMIT]
# Load the first case up front; its modality count later fixes the network
# input width and it is the volume used for evaluation.
# Raises IndexError here if no case directory was found.
mods0, seg0 = load_case(cases[0])
print('Vol shape:', mods0.shape, seg0.shape)


FileNotFoundError: [Errno 2] No such file or directory: 'data/BraTS-2023'

## SIREN MLP
SIREN uses sine activations with special initialization.

In [None]:
def siren_init_first(key, in_dim, out_dim, omega0=30.0):
    """Create the parameters of the first SIREN layer.

    Weights are drawn from U(-1/in_dim, 1/in_dim) and the bias is zero.
    ``omega0`` is accepted but does not influence the initialisation.

    Returns:
        A dict with ``'W'`` of shape (in_dim, out_dim) and ``'b'`` of shape
        (out_dim,).
    """
    bound = 1.0 / in_dim
    weights = jax.random.uniform(
        key, shape=(in_dim, out_dim), minval=-bound, maxval=bound
    )
    bias = jnp.zeros((out_dim,))
    return dict(W=weights, b=bias)
def siren_init(key, in_dim, out_dim, omega=30.0):
    """Create the parameters of a hidden SIREN layer.

    Weights are drawn from U(-sqrt(6/in_dim)/omega, sqrt(6/in_dim)/omega) and
    the bias is zero.

    Returns:
        A dict with ``'W'`` of shape (in_dim, out_dim) and ``'b'`` of shape
        (out_dim,).
    """
    bound = math.sqrt(6 / in_dim) / omega
    weights = jax.random.uniform(
        key, shape=(in_dim, out_dim), minval=-bound, maxval=bound
    )
    bias = jnp.zeros((out_dim,))
    return dict(W=weights, b=bias)
def init_siren(key, in_dim, hidden, out_dim, omega0=30.0, omega=30.0):
    """Build the parameter list of a SIREN MLP.

    The list holds, in order: the first layer (``siren_init_first``), one
    ``siren_init`` layer per consecutive pair of hidden widths, and a final
    linear layer with weights in U(-1e-4, 1e-4) and zero bias. The PRNG key is
    split once per layer.

    Returns:
        ``(key, params)``: the advanced PRNG key and the list of layer dicts.
    """
    widths = [in_dim] + hidden + [out_dim]
    layers = []

    # First layer uses its own initialisation.
    key, first_key = jax.random.split(key)
    layers.append(siren_init_first(first_key, widths[0], widths[1], omega0))

    # Hidden-to-hidden layers.
    for fan_in, fan_out in zip(widths[1:-2], widths[2:-1]):
        key, layer_key = jax.random.split(key)
        layers.append(siren_init(layer_key, fan_in, fan_out, omega))

    # Final linear layer: near-zero weights, zero bias.
    key, last_key = jax.random.split(key)
    last_w = jax.random.uniform(
        last_key, (widths[-2], widths[-1]), minval=-1e-4, maxval=1e-4
    )
    layers.append({'W': last_w, 'b': jnp.zeros((widths[-1],))})
    return key, layers
def apply_siren(params, x, omega0=30.0, omega=30.0):
    """Run the SIREN forward pass.

    Args:
        params: list of layer dicts with keys ``'W'`` and ``'b'``; the first
            entry is the first sine layer, the last entry is the linear
            output layer, everything in between is a hidden sine layer.
        x: input array whose last axis matches ``params[0]['W'].shape[0]``.
        omega0: frequency scale applied to the first layer's pre-activation.
        omega: frequency scale applied to each hidden layer's pre-activation.

    Returns:
        The output of the final linear layer (no activation).
    """
    first, last = params[0], params[-1]
    # SIREN computes sin(omega * (Wx + b)). The hidden weights are initialised
    # with a 1/omega factor, so the scale must be applied here or the
    # pre-activations collapse towards zero.
    h = jnp.sin(omega0 * (jnp.dot(x, first['W']) + first['b']))
    for layer in params[1:-1]:
        h = jnp.sin(omega * (jnp.dot(h, layer['W']) + layer['b']))
    return jnp.dot(h, last['W']) + last['b']


## Sampling & Distillation Hooks
Distillation expects optional teacher logits per voxel (C,H,W,D). Provide files under `TEACHER_DIR` named `<case>-teacher.npz` with array `logits`.

In [None]:
def load_teacher_logits(case_path: pathlib.Path) -> Optional[np.ndarray]:
    """Load the optional teacher logits for one case.

    Looks for ``TEACHER_DIR/<case name>-teacher.npz`` and returns its
    ``logits`` array (documented in this notebook as (C,H,W,D)).

    Returns:
        The ``logits`` array, or ``None`` when the file does not exist or
        does not contain a ``logits`` entry.
    """
    fp = TEACHER_DIR / f'{case_path.name}-teacher.npz'
    if not fp.exists():
        return None
    # np.load on an .npz returns a lazy archive that keeps the file open;
    # this runs once per micro-batch, so close it deterministically.
    with np.load(fp) as archive:
        if 'logits' in archive:
            return archive['logits']  # read while the archive is still open
    return None

def sample_batch_np(key, cases, batch_size):
    """Draw one training batch of voxels from a randomly chosen case.

    One case is picked uniformly from ``cases`` and loaded from disk, then
    ``batch_size`` voxel positions are sampled uniformly (with replacement)
    along each axis.

    Returns:
        ``(coords, intens, labels, (teacher, ci))`` where ``coords`` are the
        voxel indices rescaled to [-1, 1], ``intens`` holds the modality
        values per sampled voxel, ``labels`` the int32 segmentation labels,
        ``teacher`` the sampled teacher logits as a NumPy array (or ``None``
        when unavailable or when their leading axis is not NUM_CLASSES), and
        ``ci`` the index of the chosen case.
    """
    _, case_key, x_key, y_key, z_key = jax.random.split(key, 5)
    ci = int(jax.random.randint(case_key, (), 0, len(cases)))
    case_dir = cases[ci]
    mods, seg = load_case(case_dir)
    H, W, D = seg.shape

    def draw(axis_key, upper):
        # Uniform integer indices in [0, upper) as a NumPy array.
        return np.array(jax.random.randint(axis_key, (batch_size,), 0, upper))

    xs, ys, zs = draw(x_key, H), draw(y_key, W), draw(z_key, D)

    intens = mods[:, xs, ys, zs].transpose(1, 0)
    labels = seg[xs, ys, zs].astype(np.int32)
    extent = np.array([H - 1, W - 1, D - 1])
    norm = (np.stack([xs, ys, zs], axis=-1) / extent) * 2 - 1

    # Teacher logits are optional and only used when the class axis matches.
    teacher = None
    tlog = load_teacher_logits(case_dir)
    if tlog is not None and tlog.shape[0] == NUM_CLASSES:
        teacher = tlog[:, xs, ys, zs].transpose(1, 0).astype(np.float32)  # (B,C)

    return jnp.array(norm), jnp.array(intens), jnp.array(labels), (teacher, ci)


## Training
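Loss = (1 - KD_WEIGHT) * CE(labels) + KD_WEIGHT * CE(teacher probs, student) when teacher logits are available and KD_WEIGHT > 0; otherwise Loss = CE(labels). The soft-target cross-entropy differs from KL(teacher || student) only by the teacher's entropy, which is constant with respect to the student.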

In [None]:
# Linear warm-up from 0 to LR over WARMUP steps, then cosine decay to MIN_LR.
# optax's `decay_steps` is the TOTAL schedule length (warm-up included), so it
# must be STEPS for the decay to finish on the last training step. It also has
# to exceed WARMUP, otherwise the cosine part has a non-positive length.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=LR,
    warmup_steps=WARMUP,
    decay_steps=max(STEPS, WARMUP + 1),
    end_value=MIN_LR,
)
# Clip the global gradient norm to 1.0 before the AdamW update.
opt = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(schedule))

# Infer in-dim
in_dim = 3 + mods0.shape[0]  # no Fourier here; SIREN learns high-freq via sine
key, params = init_siren(key, in_dim, SIREN_HIDDEN, NUM_CLASSES, omega0=OMEGA_0, omega=OMEGA)
opt_state = opt.init(params)

def forward(params, coords, intens):
    """Evaluate the SIREN on coordinates joined with intensities.

    ``coords`` and ``intens`` are concatenated along their last axis and fed
    to ``apply_siren`` with the notebook-level OMEGA_0 / OMEGA settings.
    """
    features = jnp.concatenate((coords, intens), axis=-1)
    logits = apply_siren(params, features, omega0=OMEGA_0, omega=OMEGA)
    return logits

def loss_batch(params, coords, intens, labels, tsoft=None):
    """Compute the training loss for one micro-batch.

    Without teacher logits (or with KD_WEIGHT <= 0) this is the mean softmax
    cross-entropy against the one-hot ``labels``. With teacher logits it is
    ``(1-KD_WEIGHT) * hard CE + KD_WEIGHT * soft CE``, where the soft term is
    the cross-entropy between the teacher's softmax and the student's
    log-softmax.
    """
    logits = forward(params, coords, intens)
    onehot = jax.nn.one_hot(labels, NUM_CLASSES)
    hard_ce = optax.softmax_cross_entropy(logits, onehot).mean()

    distill = tsoft is not None and KD_WEIGHT > 0
    if not distill:
        return hard_ce

    # Teacher logits -> probabilities; student logits -> log-probabilities.
    teacher_probs = jax.nn.softmax(jnp.array(tsoft), axis=-1)
    student_logp = jax.nn.log_softmax(logits, axis=-1)
    soft_ce = -(teacher_probs * student_logp).sum(-1).mean()
    return (1 - KD_WEIGHT) * hard_ce + KD_WEIGHT * soft_ce

# Jitted loss + gradient with respect to `params` (the first argument).
grad_fn = jax.jit(jax.value_and_grad(loss_batch))

loss_hist = []
for step in range(1, STEPS + 1):
    # Accumulate gradients over ACCUM micro-batches, then take one optimizer
    # step on their mean.
    grads_acc = jax.tree_util.tree_map(jnp.zeros_like, params)
    loss_acc = 0.0
    for _ in range(ACCUM):
        key, sub = jax.random.split(key)
        c, i, l, (tsoft, _ci) = sample_batch_np(sub, cases, MICRO_BATCH)
        val, grads = grad_fn(params, c, i, l, tsoft)
        loss_acc += float(val)
        grads_acc = jax.tree_util.tree_map(jnp.add, grads_acc, grads)
    grads_mean = jax.tree_util.tree_map(lambda g: g / ACCUM, grads_acc)
    updates, opt_state = opt.update(grads_mean, opt_state, params)
    params = optax.apply_updates(params, updates)
    loss_hist.append(loss_acc / ACCUM)
    if step % 25 == 0 or step == 1:
        print(f'step {step}/{STEPS} loss={loss_hist[-1]:.4f}')

# Save
# Create the output directory first: np.savez_compressed does not create
# missing parents, and failing here would discard the whole training run.
SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
flat = {}
for k, layer in enumerate(params):
    flat[f'W_{k}'] = np.array(layer['W'])
    flat[f'b_{k}'] = np.array(layer['b'])
np.savez_compressed(SAVE_PATH, **flat)
print('Saved', SAVE_PATH)


## Evaluation
Reconstruct and compute Dice on one case (no teacher needed).

In [None]:
def predict_volume(params, mods, seg, chunk=120000):
    """Predict a label for every voxel of one volume.

    Builds the full voxel grid, rescales indices to [-1, 1] the same way as
    the training sampler, and runs ``forward`` in slices of ``chunk`` voxels,
    keeping the argmax class per voxel.

    Returns:
        ``(pred, seg)``: the int16 prediction reshaped to (H, W, D) and the
        ``seg`` argument passed through unchanged.
    """
    M, H, W, D = mods.shape
    # All (x, y, z) index triples in C order, one row per voxel.
    voxel_idx = np.moveaxis(np.indices((H, W, D)), 0, -1).reshape(-1, 3)
    feats = np.transpose(mods, (1, 2, 3, 0)).reshape(-1, M)
    scaled = (voxel_idx / np.array([H - 1, W - 1, D - 1])) * 2 - 1

    pieces = []
    for start in range(0, len(voxel_idx), chunk):
        stop = start + chunk
        logits = forward(params, jnp.array(scaled[start:stop]), jnp.array(feats[start:stop]))
        pieces.append(np.array(jnp.argmax(logits, axis=-1), dtype=np.int16))

    pred = np.concatenate(pieces, 0).reshape(H, W, D)
    return pred, seg
def dice_score(pred, true, C=NUM_CLASSES):
    """Compute the Dice score of each class in ``range(C)``.

    Returns:
        A dict mapping class index to ``(2*|P∩T| + 1e-6) / (|P| + |T| + 1e-6)``,
        or to NaN when the class appears in neither ``pred`` nor ``true``.
    """
    eps = 1e-6
    scores = {}
    for cls in range(C):
        pred_mask = pred == cls
        true_mask = true == cls
        total = pred_mask.sum() + true_mask.sum()
        if total > 0:
            overlap = (pred_mask & true_mask).sum()
            scores[cls] = (2 * overlap + eps) / (total + eps)
        else:
            scores[cls] = np.nan
    return scores
# Evaluate the trained parameters on the first loaded case (mods0 / seg0);
# predict_volume returns the prediction together with the reference labels.
pred,true = predict_volume(params, mods0, seg0)
# Per-class Dice as a {class index: score} dict.
print('Dice:', dice_score(pred,true))
