In [2]:
# ==== SECTION 0: Imports & Paths ====
import os, re, glob, csv, json, time, random, gc
from pathlib import Path
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from functools import partial

# Configure matplotlib for Jupyter: draw figures inline in the notebook output
# and set the default figure size and resolution used by plots that do not
# pass their own `figsize` / `dpi`.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['figure.dpi'] = 100

import jax
import jax.numpy as jnp
from jax import random as jrandom, jit, vmap, grad
import flax.linen as nn
from flax.training import train_state
import optax

# Report the devices / backend JAX will run on.
print("JAX devices:", jax.devices())
print("JAX default backend:", jax.default_backend())

# Input datasets
BRA_TS_DIR = "../data/BraTS-2023"
MU_DIR = "../data/MU-Glioma-Post"

# Output locations written by later cells
CKPT_DIR = "../checkpoints"
ART_DIR = "../artifacts"
PRED_MU = "../pred_mu"
DISTILL_DIR = "../distill"

# Create every output directory up front (no-op when it already exists).
for d in (CKPT_DIR, ART_DIR, PRED_MU, DISTILL_DIR):
    os.makedirs(d, exist_ok=True)


JAX devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9), CpuDevice(id=10), CpuDevice(id=11)]
JAX default backend: cpu


In [2]:
# ==== SECTION 1: Build BraTS manifest ====
# Every sub-directory of the BraTS root is treated as one case.
root = BRA_TS_DIR
cases = sorted(filter(os.path.isdir, glob.glob(os.path.join(root, "*"))))

def pick(files, pattern):
    """Return the first path in ``files`` whose basename matches ``pattern``.

    The regular expression is searched case-insensitively against the file
    name only (directories are ignored). Returns "" when nothing matches.
    """
    matcher = re.compile(pattern, re.I)
    return next(
        (path for path in files if matcher.search(os.path.basename(path))),
        "",
    )

# One manifest row per case: the case id plus the first file matching each
# modality / mask pattern ("" when a file is missing).
rows = []
for case in cases:
    files = glob.glob(os.path.join(case, "*.nii*"))
    row = {"id": os.path.basename(case)}
    row["t1"] = pick(files, r"t1n|t1\.nii")
    row["t1ce"] = pick(files, r"t1c|t1ce")
    row["t2"] = pick(files, r"t2(?!f)|t2w")
    row["flair"] = pick(files, r"t2f|flair")
    row["mask"] = pick(files, r"seg|mask")
    rows.append(row)

# Persist the manifest next to the dataset.
man_bra = os.path.join(BRA_TS_DIR, "manifest.csv")
with open(man_bra, "w", newline="") as fp:
    w = csv.DictWriter(fp, fieldnames=["id","t1","t1ce","t2","flair","mask"])
    w.writeheader()
    w.writerows(rows)

print(f"Wrote {len(rows)} rows -> {man_bra}")


Wrote 1251 rows -> ../data/BraTS-2023/manifest.csv


In [3]:
import os
import glob
import csv
import re
from typing import List

def pick(files: List[str], pattern: str, flags=re.IGNORECASE) -> str:
    """Return the first file whose basename matches ``pattern``.

    Args:
        files: Candidate file paths.
        pattern: Regular expression searched against each file's basename.
        flags: ``re`` flags used to compile the pattern (case-insensitive by
            default).

    Returns:
        The matching path exactly as it appears in ``files``, or "" if no
        file matches.

    The path is returned unchanged (not made relative to the dataset root)
    so that manifest entries can be opened directly from the notebook's
    working directory, the same way the BraTS manifest entries are.
    """
    rx = re.compile(pattern, flags)
    for f in files:
        if rx.search(os.path.basename(f)):
            return f
    return ""

# ==== SECTION 2: Build MU-Glioma-Post manifest ====
# Auto-detect correct root if dataset is nested (e.g., MU-Glioma-Post/MU-Glioma-Post/...)
root = MU_DIR
patients = sorted([p for p in glob.glob(os.path.join(root, "PatientID_*")) if os.path.isdir(p)])
if not patients:
    # Look one level deeper for a folder whose name starts with "mu".
    # Sorted so the chosen root does not depend on filesystem listing order.
    candidates = sorted(
        d for d in glob.glob(os.path.join(MU_DIR, "*"))
        if os.path.isdir(d) and os.path.basename(d).lower().startswith("mu")
    )
    if candidates:
        root = candidates[0]
        patients = sorted([p for p in glob.glob(os.path.join(root, "PatientID_*")) if os.path.isdir(p)])

# One manifest row per (patient, timepoint) directory.
rows = []
for pat_dir in patients:
    tps = sorted([tp for tp in glob.glob(os.path.join(pat_dir, "Timepoint_*")) if os.path.isdir(tp)])
    for tp in tps:
        # Sorted so that "first match" is deterministic when several files match.
        files = sorted(glob.glob(os.path.join(tp, "**", "*.nii*"), recursive=True))
        row = {
            "id":   os.path.relpath(tp, root).replace(os.sep, "_"),
            # The trailing `|n` alternative accepts BraTS-style "..._t1n.nii.gz"
            # names; the previous group could only match "t1", "t1_n", "t11"
            # or "t1n1" and therefore never matched a "t1n" file.
            "t1":   pick(files, r"(^|[^a-z0-9])t1(_n|n?1|n)?(\.|_|$)"),
            "t1ce": pick(files, r"t1c|t1ce"),
            "t2":   pick(files, r"(^|[^a-z0-9])t2(?!f)(\.|_|$)|t2w"),
            "flair":pick(files, r"t2f|flair"),
            "mask": pick(files, r"seg|mask|tumorMask|_seg")
        }
        rows.append(row)

man_mu = os.path.join(MU_DIR, "manifest.csv")
with open(man_mu, "w", newline="") as fp:
    w = csv.DictWriter(fp, fieldnames=["id","t1","t1ce","t2","flair","mask"])
    w.writeheader(); w.writerows(rows)

print(f"Root used: {root}")
print(f"Wrote {len(rows)} rows -> {man_mu}")
if len(rows) == 0:
    print("Warning: no timepoints found. Check that patient and Timepoint_* directories contain .nii/.nii.gz files.")

Root used: ../data/MU-Glioma-Post/MU-Glioma-Post
Wrote 596 rows -> ../data/MU-Glioma-Post/manifest.csv


In [4]:
# ==== SECTION 3: Train/Val split for BraTS ====
MAX_CASES = 512

# Seeded RNG so the case subset and the split are the same on every run.
rng = random.Random(42)

# Keep only cases that have all four modalities and a mask.
required_cols = ["t1","t1ce","t2","flair","mask"]
with open(man_bra) as fp:
    all_rows = [
        r for r in csv.DictReader(fp)
        if all(r.get(k) for k in required_cols)
    ]

# Shuffle, cap the number of cases, then take the first 20% as validation.
rng.shuffle(all_rows)
all_rows = all_rows[:MAX_CASES]

n_total = len(all_rows)
n_val = int(0.2 * n_total)
val_rows, train_rows = all_rows[:n_val], all_rows[n_val:]

def write_rows(rows, path):
    """Write manifest ``rows`` (dicts) to ``path`` as CSV with the manifest header."""
    columns = ["id","t1","t1ce","t2","flair","mask"]
    with open(path, "w", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=columns)
        writer.writeheader()
        for record in rows:
            writer.writerow(record)

# Write the two splits next to the BraTS manifest.
train_csv, val_csv = (
    os.path.join(BRA_TS_DIR, name) for name in ("train.csv", "val.csv")
)
write_rows(train_rows, train_csv)
write_rows(val_rows, val_csv)

print(f"Train: {len(train_rows)} | Val: {len(val_rows)}")


Train: 410 | Val: 102


In [None]:
# ==== SECTION 4: Data Loading ====
def load_modalities(row):
    """Load one case as a z-scored 4-channel volume plus a binary mask.

    Args:
        row: Manifest row with file paths under "t1", "t1ce", "t2", "flair"
            and "mask".

    Returns:
        (vol, mask): ``vol`` is float32 with the modalities stacked on axis 0
        ([C, X, Y, Z]); ``mask`` is int64 with 1 wherever the label is > 0.
    """
    channels = []
    for key in ["t1","t1ce","t2","flair"]:
        v = nib.load(row[key]).get_fdata().astype(np.float32)
        # Per-volume z-score; the epsilon guards against a constant volume.
        channels.append((v - v.mean()) / (v.std() + 1e-6))
    vol = np.stack(channels, axis=0)  # [C, X, Y, Z]

    labels = nib.load(row["mask"]).get_fdata().astype(np.int64)
    mask = (labels > 0).astype(np.int64)  # collapse all labels to foreground
    return vol, mask

def create_slice_dataset(manifest_path, max_slices_per_case=8, background_drop=0.7, seed=0):
    """Build an in-memory dataset of 2D slices from the cases in a manifest.

    For every case, up to ``max_slices_per_case`` evenly spaced slices along
    the last axis are taken. A slice whose mask is empty is dropped with
    probability ``background_drop``. The kept slices are shuffled.

    Args:
        manifest_path: CSV manifest readable by ``csv.DictReader``.
        max_slices_per_case: Upper bound on slices sampled per case.
        background_drop: Drop probability for slices without foreground.
        seed: Seed for the drop decisions and the final shuffle.

    Returns:
        (X, Y): float32 images clipped to [-5, 5] and int64 masks.
    """
    sampler = random.Random(seed)
    with open(manifest_path) as fp:
        case_rows = list(csv.DictReader(fp))

    samples = []
    for case_row in case_rows:
        vol, msk = load_modalities(case_row)
        depth = vol.shape[-1]
        n_pick = min(max_slices_per_case, depth)
        for iz in np.linspace(0, depth-1, num=n_pick, dtype=int):
            label = msk[..., iz]
            # The RNG is only consulted for empty slices (short-circuit `and`).
            if label.sum() == 0 and sampler.random() < background_drop:
                continue
            samples.append((vol[..., iz], label))
    sampler.shuffle(samples)

    # Convert to numpy arrays
    X = np.array([img for img, _ in samples], dtype=np.float32)
    Y = np.array([lab for _, lab in samples], dtype=np.int64)
    return np.clip(X, -5, 5), Y

# Materialise both splits; different seeds give independent drop/shuffle draws.
train_X, train_Y = create_slice_dataset(manifest_path=train_csv, seed=0)
val_X, val_Y = create_slice_dataset(manifest_path=val_csv, seed=1)

print(f"Train: {train_X.shape} | Val: {val_X.shape}")


In [None]:
# ==== SECTION 5: UNet2D in Flax ====
class DoubleConv(nn.Module):
    """Two consecutive Conv(3x3, SAME) -> BatchNorm -> ReLU stages.

    Both convolutions produce ``out_ch`` feature channels.
    """
    out_ch: int

    @nn.compact
    def __call__(self, x, train: bool = True):
        # BatchNorm uses batch statistics when training and the stored
        # running averages otherwise.
        use_running = not train
        for _ in range(2):
            x = nn.Conv(self.out_ch, (3, 3), padding='SAME')(x)
            x = nn.BatchNorm(use_running_average=use_running)(x)
            x = nn.relu(x)
        return x

class UNet2D(nn.Module):
    """Small two-level 2D U-Net that outputs per-pixel class logits.

    The input is indexed with the batch on axis 0, the two spatial axes on
    axes 1 and 2, and channels last (skip connections are concatenated on
    axis -1). The output has the same spatial size as the input and
    ``n_classes`` channels.
    """
    # Number of output classes (channels of the final 1x1 convolution).
    n_classes: int = 2
    # Feature width of the first level; deeper levels use 2x and 4x this.
    base: int = 16
    
    @nn.compact
    def __call__(self, x, train: bool = True):
        # Encoder: each level is a DoubleConv followed by a 2x2 max-pool
        # with stride 2.
        c1 = DoubleConv(self.base)(x, train)
        p1 = nn.max_pool(c1, (2, 2), strides=(2, 2))
        
        c2 = DoubleConv(self.base * 2)(p1, train)
        p2 = nn.max_pool(c2, (2, 2), strides=(2, 2))
        
        # Bottleneck at the coarsest resolution.
        c3 = DoubleConv(self.base * 4)(p2, train)
        
        # Decoder: nearest-neighbour resize to the spatial size of the
        # matching encoder output (so odd input sizes still line up), a 2x2
        # convolution to reduce channels, concatenation with the skip
        # tensor, then a DoubleConv.
        u2 = jax.image.resize(c3, shape=(c3.shape[0], c2.shape[1], c2.shape[2], c3.shape[3]), 
                              method='nearest')
        u2 = nn.Conv(self.base * 2, (2, 2))(u2)
        u2 = jnp.concatenate([u2, c2], axis=-1)
        c2_up = DoubleConv(self.base * 2)(u2, train)
        
        u1 = jax.image.resize(c2_up, shape=(c2_up.shape[0], c1.shape[1], c1.shape[2], c2_up.shape[3]),
                              method='nearest')
        u1 = nn.Conv(self.base, (2, 2))(u1)
        u1 = jnp.concatenate([u1, c1], axis=-1)
        c1_up = DoubleConv(self.base)(u1, train)
        
        # 1x1 convolution producing the per-pixel logits (no softmax here).
        out = nn.Conv(self.n_classes, (1, 1))(c1_up)
        return out


In [None]:
# ==== SECTION 6: Training Setup ====
def create_train_state(rng, learning_rate=1e-3):
    """Initialise a UNet2D and wrap it in a Flax ``TrainState`` with Adam.

    Args:
        rng: JAX PRNG key used for parameter initialisation.
        learning_rate: Adam learning rate.

    Returns:
        (state, batch_stats): the train state holding the parameters and
        optimizer, and the initial BatchNorm statistics ({} if there are none).
    """
    net = UNet2D()
    # NHWC dummy batch, used only to initialise the parameters.
    probe = jnp.ones((1, 240, 240, 4))
    init_vars = net.init(rng, probe, train=True)
    optimizer = optax.adam(learning_rate)
    ts = train_state.TrainState.create(
        apply_fn=net.apply,
        params=init_vars['params'],
        tx=optimizer,
    )
    return ts, init_vars.get('batch_stats', {})

@jit
def compute_dice(logits, target):
    """Mean per-sample Dice score of the class-1 prediction against ``target``.

    ``logits`` carries the class scores on the last axis; ``target`` holds
    integer labels. Sums run over axes 1 and 2, and a 1e-6 smoothing term
    makes an empty prediction on an empty target score 1.
    """
    spatial_axes = (1, 2)
    fg_pred = (jnp.argmax(logits, axis=-1) == 1).astype(jnp.float32)
    fg_true = (target == 1).astype(jnp.float32)
    overlap = jnp.sum(fg_pred * fg_true, axis=spatial_axes)
    total = jnp.sum(fg_pred, axis=spatial_axes) + jnp.sum(fg_true, axis=spatial_axes)
    per_sample = (2 * overlap + 1e-6) / (total + 1e-6)
    return jnp.mean(per_sample)

@jit
def train_step(state, batch_stats, batch_x, batch_y, rng):
    """Run one optimisation step on a batch.

    Returns the updated train state, the updated BatchNorm statistics, the
    mean cross-entropy loss and the batch Dice score. ``rng`` is accepted
    but not used by this step.
    """
    def loss_fn(params):
        # 'batch_stats' is mutable so BatchNorm can update its running stats.
        logits, updates = state.apply_fn(
            {'params': params, 'batch_stats': batch_stats},
            batch_x,
            train=True,
            mutable=['batch_stats'],
        )
        per_pixel = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y)
        return per_pixel.mean(), (logits, updates)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, aux), grads = grad_fn(state.params)
    logits, updates = aux
    new_state = state.apply_gradients(grads=grads)
    return new_state, updates['batch_stats'], loss, compute_dice(logits, batch_y)

@jit
def eval_step(state, batch_stats, batch_x, batch_y):
    """Evaluate one batch in inference mode; returns (mean CE loss, Dice)."""
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': batch_stats},
        batch_x,
        train=False,
    )
    per_pixel = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y)
    return per_pixel.mean(), compute_dice(logits, batch_y)


In [None]:
# ==== SECTION 7: Training Loop ====
rng = jrandom.PRNGKey(42)
rng, init_rng = jrandom.split(rng)

state, batch_stats = create_train_state(init_rng)

EPOCHS = 30
BATCH_SIZE = 4
n_train = len(train_X)
n_val = len(val_X)

# Dedicated seeded generator for the per-epoch shuffles, so that a
# Restart & Run All reproduces the same batch order (the global np.random
# state is not seeded anywhere in this notebook).
shuffle_rng = np.random.default_rng(42)

history = {"tr_loss": [], "tr_dice": [], "va_loss": [], "va_dice": []}

for ep in range(1, EPOCHS + 1):
    # Training
    t0 = time.time()
    perm = shuffle_rng.permutation(n_train)

    tr_loss = tr_dice = 0.0
    n_batches = 0
    for i in range(0, n_train, BATCH_SIZE):
        # Index one batch at a time instead of copying the whole shuffled
        # dataset every epoch.
        idx = perm[i:i+BATCH_SIZE]
        batch_x = jnp.array(train_X[idx])
        batch_y = jnp.array(train_Y[idx])

        # Convert from NCHW to NHWC
        batch_x = jnp.transpose(batch_x, (0, 2, 3, 1))

        rng, step_rng = jrandom.split(rng)
        state, batch_stats, loss, dice = train_step(state, batch_stats, batch_x, batch_y, step_rng)

        tr_loss += loss
        tr_dice += dice
        n_batches += 1

    # max(..., 1) avoids a ZeroDivisionError when a split is empty.
    tr_loss = float(tr_loss) / max(n_batches, 1)
    tr_dice = float(tr_dice) / max(n_batches, 1)

    # Validation
    va_loss = va_dice = 0.0
    n_val_batches = 0
    for i in range(0, n_val, BATCH_SIZE):
        batch_x = jnp.array(val_X[i:i+BATCH_SIZE])
        batch_y = jnp.array(val_Y[i:i+BATCH_SIZE])

        batch_x = jnp.transpose(batch_x, (0, 2, 3, 1))

        loss, dice = eval_step(state, batch_stats, batch_x, batch_y)
        va_loss += loss
        va_dice += dice
        n_val_batches += 1

    va_loss = float(va_loss) / max(n_val_batches, 1)
    va_dice = float(va_dice) / max(n_val_batches, 1)

    history["tr_loss"].append(tr_loss)
    history["tr_dice"].append(tr_dice)
    history["va_loss"].append(va_loss)
    history["va_dice"].append(va_dice)

    # Save checkpoint. The parameter trees are stored as pickled objects
    # inside the .npz, so they must be read back with allow_pickle=True.
    ck = os.path.join(CKPT_DIR, f"unet2d_jax_ep{ep:02d}.npz")
    np.savez(ck, params=state.params, batch_stats=batch_stats)

    print(f"Epoch {ep:02d} | train loss {tr_loss:.4f} dice {tr_dice:.3f} | "
          f"val loss {va_loss:.4f} dice {va_dice:.3f} | {time.time()-t0:.1f}s")

# Plot UNet training curves: loss on the left, Dice on the right.
epoch_axis = range(1, EPOCHS+1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax1.plot(epoch_axis, history["tr_loss"], 'o-', label="Train Loss", linewidth=2, markersize=8)
ax1.plot(epoch_axis, history["va_loss"], 's-', label="Val Loss", linewidth=2, markersize=8)
ax1.set_ylabel("Cross-Entropy Loss", fontsize=12)
ax1.set_title("UNet Training & Validation Loss", fontsize=14, fontweight='bold')

# Dice curves
ax2.plot(epoch_axis, history["tr_dice"], 'o-', label="Train Dice", linewidth=2, markersize=8, color='green')
ax2.plot(epoch_axis, history["va_dice"], 's-', label="Val Dice", linewidth=2, markersize=8, color='orange')
ax2.set_ylabel("Dice Score", fontsize=12)
ax2.set_title("UNet Training & Validation Dice", fontsize=14, fontweight='bold')
ax2.set_ylim([0, 1])

# Formatting shared by both panels
for ax in (ax1, ax2):
    ax.set_xlabel("Epoch", fontsize=12)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(epoch_axis)

plt.tight_layout()
plt.savefig(os.path.join(ART_DIR, "unet_training_curves_jax.png"), dpi=150, bbox_inches='tight')
plt.show()

# Print summary statistics (epochs are reported 1-based).
best_va_loss = min(history['va_loss'])
best_va_dice = max(history['va_dice'])
rule = "="*50

print("\n" + rule)
print("UNet Training Summary")
print(rule)
print(f"Best Val Loss:  {best_va_loss:.4f} (Epoch {history['va_loss'].index(best_va_loss)+1})")
print(f"Best Val Dice:  {best_va_dice:.4f} (Epoch {history['va_dice'].index(best_va_dice)+1})")
print(f"Final Val Loss: {history['va_loss'][-1]:.4f}")
print(f"Final Val Dice: {history['va_dice'][-1]:.4f}")
print(rule + "\n")


In [None]:
# ==== SECTION 8: Qualitative Validation ====
# Run the trained UNet on the first few validation slices and show
# FLAIR / ground truth / prediction side by side.
batch_x = jnp.array(val_X[:4])
batch_y = val_Y[:4]
batch_x_hwc = jnp.transpose(batch_x, (0, 2, 3, 1))  # NCHW -> NHWC

variables = {'params': state.params, 'batch_stats': batch_stats}
logits = state.apply_fn(variables, batch_x_hwc, train=False)
pred = jnp.argmax(logits, axis=-1)

k = min(4, len(batch_x))
# squeeze=False keeps `axes` two-dimensional even when only one slice is
# available; otherwise axes[i, j] raises for k == 1.
fig, axes = plt.subplots(k, 3, figsize=(9, 3*k), squeeze=False)
for i in range(k):
    s = np.array(batch_x[i, 3])  # FLAIR channel
    axes[i,0].imshow(s, cmap="gray")
    axes[i,0].set_title("FLAIR")
    axes[i,1].imshow(batch_y[i], cmap="magma")
    axes[i,1].set_title("GT")
    axes[i,2].imshow(np.array(pred[i]), cmap="magma")
    axes[i,2].set_title("Pred")
    for j in range(3):
        axes[i,j].axis("off")
plt.tight_layout()
plt.savefig(os.path.join(ART_DIR, "val_quicklook_jax.png"), dpi=120)
plt.show()


In [None]:
# ==== SECTION 9: Inference on MU ====
def load_case_3d(row):
    """Load a full case for slice-wise inference.

    Args:
        row: Manifest row with file paths under "t1", "t1ce", "t2", "flair"
            and, optionally, "mask".

    Returns:
        (vol, aff, gt):
          * vol -- float32 array [C, X, Y, Z]; every modality is z-scored over
            its whole volume and then clipped to [-5, 5], matching the
            clipping applied to the training slices.
          * aff -- affine of the T1 image.
          * gt  -- int64 binary mask (label > 0), or None when the row has no
            mask or the mask cannot be read.
    """
    vols = [nib.load(row[k]).get_fdata().astype(np.float32) for k in ["t1","t1ce","t2","flair"]]
    vols = [(v - v.mean())/(v.std()+1e-6) for v in vols]
    vol = np.stack(vols, axis=0)  # [C,X,Y,Z]
    # The training slices are clipped to [-5, 5]; apply the same range here
    # so the network sees the input distribution it was trained on.
    vol = np.clip(vol, -5, 5)
    aff = nib.load(row["t1"]).affine
    gt = None
    if row.get("mask"):
        try:
            gt = nib.load(row["mask"]).get_fdata().astype(np.int64)
            gt = (gt>0).astype(np.int64)
        except Exception as exc:
            # Best effort: a missing or unreadable mask only disables the
            # metric for this case. Report it instead of hiding it, and do
            # not catch KeyboardInterrupt / SystemExit.
            print(f"Warning: could not load mask {row['mask']!r}: {exc}")
            gt = None
    return vol, aff, gt

# Load the most recent checkpoint, i.e. the one with the highest epoch number
# in its file name. This is the *last* epoch found on disk, not necessarily
# the one with the best validation score, and files left over from an
# earlier, longer run are picked up too.
def _ckpt_epoch(path):
    """Epoch number encoded in a checkpoint file name (-1 if there is none)."""
    found = re.search(r"ep(\d+)\.npz$", os.path.basename(path))
    return int(found.group(1)) if found else -1

# Sort numerically: plain string sorting would put "ep100" before "ep99".
ckpt_paths = sorted(glob.glob(os.path.join(CKPT_DIR, "unet2d_jax_ep*.npz")), key=_ckpt_epoch)
if not ckpt_paths:
    raise FileNotFoundError(
        f"No UNet checkpoints matching 'unet2d_jax_ep*.npz' in {CKPT_DIR}; run the training cell first."
    )
last_ck = ckpt_paths[-1]

# SECURITY: allow_pickle=True can execute arbitrary code while loading.
# Only open checkpoints that were written by this notebook.
with np.load(last_ck, allow_pickle=True) as ckpt:
    state_params = ckpt['params'].item()
    state_batch_stats = ckpt['batch_stats'].item()

# Read MU manifest, keeping only rows that have all four modalities
with open(man_mu) as fp:
    mu_rows = [r for r in csv.DictReader(fp) if all(r.get(k) for k in ["t1","t1ce","t2","flair"])]

@jit
def predict_slice(params, batch_stats, x):
    """Return the arg-max class map of the UNet for the batch ``x`` (inference mode)."""
    net_vars = {'params': params, 'batch_stats': batch_stats}
    scores = UNet2D().apply(net_vars, x, train=False)
    return jnp.argmax(scores, axis=-1)

metrics = []
for row in mu_rows:
    vol, aff, gt = load_case_3d(row)
    X, Y, Z = vol.shape[1:]
    pred_vol = np.zeros((X, Y, Z), dtype=np.uint8)

    # Slice-wise 2D inference along the last axis.
    for z in range(Z):
        x = jnp.array(vol[..., z].copy())[None, ...]  # [1,C,H,W]
        x = jnp.transpose(x, (0, 2, 3, 1))  # [1,H,W,C]
        pred = predict_slice(state_params, state_batch_stats, x)
        pred_vol[..., z] = np.array(pred[0]).astype(np.uint8)

    # Save the predicted label volume with the affine returned for the case.
    out_path = os.path.join(PRED_MU, f"{row['id']}_pred_jax.nii.gz")
    nib.save(nib.Nifti1Image(pred_vol, aff), out_path)
    print("Wrote:", out_path)

    # Dice is only computed when a mask of matching shape is available.
    if gt is None or gt.shape != pred_vol.shape:
        continue
    pred_fg = (pred_vol==1)
    gt_fg = (gt==1)
    inter = (pred_fg & gt_fg).sum()
    union = pred_fg.sum() + gt_fg.sum()
    dice = (2*inter + 1e-6)/(union + 1e-6)
    metrics.append({"id": row["id"], "dice": float(dice)})

if metrics:
    import pandas as pd
    df = pd.DataFrame(metrics)
    csv_path = os.path.join(ART_DIR, "mu_metrics_jax.csv")
    df.to_csv(csv_path, index=False)
    print("Saved:", csv_path, " | mean dice:", df["dice"].mean())


In [None]:
# ==== SECTION 10: Distillation Sampling ====
os.makedirs(DISTILL_DIR, exist_ok=True)

# Cases to sample from: the BraTS training split.
with open(train_csv) as fp:
    tr_rows = [r for r in csv.DictReader(fp)]

target_samples_per_case = 100_000   # voxels sampled per case
shard_size = 1_000_000              # flush to disk once this many are buffered

# In-memory buffers for the current shard: normalised coordinates and labels.
acc_x, acc_y, acc_z, acc_c = [], [], [], []
shard_id = 0

def flush_shard():
    """Write the buffered samples to the next shard file and reset the buffers.

    Does nothing when the buffers are empty. Uses the module-level buffers
    ``acc_x``/``acc_y``/``acc_z``/``acc_c`` and increments ``shard_id``.
    """
    global shard_id
    if len(acc_x) == 0:
        return
    path = os.path.join(DISTILL_DIR, f"brats_samples_jax_{shard_id:03d}.npz")
    np.savez_compressed(
        path,
        x=np.array(acc_x, dtype=np.float32),
        y=np.array(acc_y, dtype=np.float32),
        z=np.array(acc_z, dtype=np.float32),
        c=np.array(acc_c, dtype=np.int64),
    )
    print("Saved shard:", path, " | n=", len(acc_x))
    # Clear in place so the module-level lists stay the same objects.
    for buf in (acc_x, acc_y, acc_z, acc_c):
        buf.clear()
    shard_id += 1

# Seeded generator so the sampled voxels (and therefore the shards) are the
# same on every run; the global np.random state is not seeded in this notebook.
sample_rng = np.random.default_rng(0)

for i, row in enumerate(tr_rows):
    vol, aff, _ = load_case_3d(row)
    X, Y, Z = vol.shape[1:]
    pred_vol = np.zeros((X, Y, Z), dtype=np.uint8)

    # Teacher labels: slice-wise UNet prediction for the whole volume.
    for z in range(Z):
        x = jnp.array(vol[..., z].copy())[None, ...]
        x = jnp.transpose(x, (0, 2, 3, 1))
        pred = predict_slice(state_params, state_batch_stats, x)
        pred_vol[..., z] = np.array(pred[0]).astype(np.uint8)

    # Stratified sampling: up to half of the budget from predicted-positive
    # voxels, the remainder from predicted-negative voxels.
    pos_idx = np.argwhere(pred_vol==1)
    neg_idx = np.argwhere(pred_vol==0)
    n_pos = min(len(pos_idx), target_samples_per_case//2)
    n_neg = min(len(neg_idx), target_samples_per_case - n_pos)

    # Voxel index -> [0, 1] coordinate; max(..., 1) avoids dividing by zero
    # for an axis of length 1.
    scale = np.array([max(X-1, 1), max(Y-1, 1), max(Z-1, 1)], dtype=np.float64)

    for idx_pool, n_take, label in ((pos_idx, n_pos, 1), (neg_idx, n_neg, 0)):
        if n_take == 0:
            continue
        chosen = sample_rng.choice(len(idx_pool), n_take, replace=False)
        sel = idx_pool[chosen] / scale
        acc_x.extend(sel[:,0].tolist())
        acc_y.extend(sel[:,1].tolist())
        acc_z.extend(sel[:,2].tolist())
        acc_c.extend([label]*n_take)

    if len(acc_x) >= shard_size:
        flush_shard()

    del vol, pred_vol; gc.collect()

# Write whatever is left in the buffers.
flush_shard()
print("Distillation sampling done.")


In [None]:
# ==== SECTION 11: Train Tiny MLP on Distillation ====
class TinyMLP(nn.Module):
    """Fully connected network: Dense + ReLU per entry of ``hidden``, then a
    final Dense layer with ``out_dim`` outputs (raw logits, no activation)."""
    hidden: tuple = (64, 64, 32)
    out_dim: int = 2

    @nn.compact
    def __call__(self, x):
        for width in self.hidden:
            x = nn.relu(nn.Dense(width)(x))
        return nn.Dense(self.out_dim)(x)

# Locate the distillation shards written by the sampling cell.
shard_pattern = os.path.join(DISTILL_DIR, "brats_samples_jax_*.npz")
npz_files = sorted(glob.glob(shard_pattern))
assert npz_files, "No distillation NPZ files found!"

# Simple 95/5 split by file; at least one file always goes to training.
n = len(npz_files)
split = max(1, int(0.95*n))
train_files = npz_files[:split]
dev_files = npz_files[split:]

# Load all data into memory (for simplicity with JAX)
def load_npz_files(files):
    """Load distillation shards and concatenate them.

    Each shard must contain 1-D arrays "x", "y", "z" and "c".

    Returns:
        (X, Y): float32 coordinates of shape [N, 3] (columns x, y, z) and
        int64 labels of shape [N].
    """
    coords, labels = [], []
    for path in files:
        with np.load(path) as shard:
            xyz = np.stack([shard[axis] for axis in ("x", "y", "z")], axis=1)
            coords.append(xyz.astype(np.float32))
            labels.append(shard["c"].astype(np.int64))
    return np.concatenate(coords), np.concatenate(labels)

mlp_train_X, mlp_train_Y = load_npz_files(train_files)
if dev_files:
    mlp_dev_X, mlp_dev_Y = load_npz_files(dev_files)
else:
    # Only one shard exists, so the file-level split left nothing for dev
    # and loading an empty file list would fail. Hold out a seeded random
    # 5% of the samples instead.
    holdout_rng = np.random.default_rng(0)
    order = holdout_rng.permutation(len(mlp_train_X))
    n_dev = max(1, int(0.05 * len(order)))
    dev_sel, train_sel = order[:n_dev], order[n_dev:]
    mlp_dev_X, mlp_dev_Y = mlp_train_X[dev_sel], mlp_train_Y[dev_sel]
    mlp_train_X, mlp_train_Y = mlp_train_X[train_sel], mlp_train_Y[train_sel]

print(f"MLP Train: {mlp_train_X.shape} | Dev: {mlp_dev_X.shape}")

# Initialize MLP: fresh PRNG key, parameters traced from a dummy (1, 3)
# coordinate batch, Adam optimizer.
rng = jrandom.PRNGKey(42)
rng, init_rng = jrandom.split(rng)

mlp = TinyMLP()
mlp_init_vars = mlp.init(init_rng, jnp.ones((1, 3)))
mlp_params = mlp_init_vars['params']
mlp_tx = optax.adam(3e-3)
mlp_state = train_state.TrainState.create(apply_fn=mlp.apply, params=mlp_params, tx=mlp_tx)

@jit
def mlp_train_step(state, batch_x, batch_y):
    """One optimizer step on a batch; returns (updated state, mean CE loss)."""
    def loss_fn(params):
        per_example = optax.softmax_cross_entropy_with_integer_labels(
            state.apply_fn({'params': params}, batch_x), batch_y
        )
        return per_example.mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss

@jit
def mlp_eval(params, batch_x, batch_y):
    """Fraction of ``batch_x`` rows whose arg-max class equals ``batch_y``."""
    predicted = jnp.argmax(mlp.apply({'params': params}, batch_x), axis=-1)
    return jnp.mean(predicted == batch_y)

# NOTE: this rebinds the BATCH_SIZE used by the UNet training cell, and the
# loop below reuses that cell's EPOCHS.
BATCH_SIZE = 8192
mlp_hist = []

# Seeded generator for the per-epoch shuffle so reruns are reproducible.
mlp_shuffle_rng = np.random.default_rng(42)

for ep in range(1, EPOCHS + 1):
    t0 = time.time()

    # Training
    perm = mlp_shuffle_rng.permutation(len(mlp_train_X))

    running_loss = 0.0
    n_batches = 0
    for i in range(0, len(perm), BATCH_SIZE):
        # Index one batch at a time instead of copying the full shuffled
        # arrays every epoch.
        idx = perm[i:i+BATCH_SIZE]
        batch_x = jnp.array(mlp_train_X[idx])
        batch_y = jnp.array(mlp_train_Y[idx])
        mlp_state, loss = mlp_train_step(mlp_state, batch_x, batch_y)
        running_loss += loss
        n_batches += 1

    # Evaluation
    dev_acc = 0.0
    n_dev_batches = 0
    for i in range(0, len(mlp_dev_X), BATCH_SIZE):
        batch_x = jnp.array(mlp_dev_X[i:i+BATCH_SIZE])
        batch_y = jnp.array(mlp_dev_Y[i:i+BATCH_SIZE])
        acc = mlp_eval(mlp_state.params, batch_x, batch_y)
        dev_acc += acc
        n_dev_batches += 1

    avg_loss = float(running_loss) / max(n_batches, 1)
    # Without dev data the accuracy is undefined; report NaN instead of
    # raising ZeroDivisionError.
    avg_acc = float(dev_acc) / n_dev_batches if n_dev_batches else float("nan")

    print(f"[MLP] Epoch {ep:02d} | loss {avg_loss:.4f} | dev acc {avg_acc:.3f} | {time.time()-t0:.1f}s")
    mlp_hist.append((avg_loss, avg_acc))

# Save MLP weights (the parameter tree is stored as a pickled object).
mlp_ckpt = os.path.join(CKPT_DIR, "mlp_distill_jax.npz")
np.savez(mlp_ckpt, params=mlp_state.params)

# Plot MLP training curves: one figure per column of mlp_hist
# (column 0 = loss, column 1 = dev accuracy).
for col, title, fname in (
    (0, "MLP Loss", "mlp_loss_jax.png"),
    (1, "MLP Dev Accuracy", "mlp_acc_jax.png"),
):
    plt.figure()
    plt.plot([h[col] for h in mlp_hist])
    plt.title(title)
    plt.savefig(os.path.join(ART_DIR, fname), dpi=120)
    plt.show()


In [None]:
# ==== SECTION 12: Export MLP weights (NPZ and JSON) ====

# Export as NPZ (efficient binary format)
def export_mlp_npz(params, output_path):
    """Export the Dense layers of an MLP parameter tree to a compressed NPZ.

    Args:
        params: Mapping of layer name -> {"kernel": [in, out], "bias": [out]}.
            Only entries whose name contains "Dense" are exported.
        output_path: Destination ``.npz`` path.

    Returns:
        Number of exported layers.

    The file contains ``layer_<i>_weight`` (transposed to [out, in]) and
    ``layer_<i>_bias`` for each layer in network order, plus the metadata
    arrays ``n_layers``, ``input_dim``, ``output_dim`` and
    ``activation_code`` (0 = relu).
    """
    def _layer_order(name):
        # Order by the numeric suffix ("Dense_10" after "Dense_2"); plain
        # string sorting would scramble networks with more than 10 layers.
        suffix = name.rsplit('_', 1)[-1]
        return (int(suffix), name) if suffix.isdigit() else (float('inf'), name)

    dense_keys = sorted((k for k in params.keys() if 'Dense' in k), key=_layer_order)

    export_dict = {}
    kernels = []
    for layer_idx, key in enumerate(dense_keys):
        kernel = np.array(params[key]['kernel'])  # [in, out]
        bias = np.array(params[key]['bias'])
        export_dict[f'layer_{layer_idx}_weight'] = kernel.T  # [out, in]
        export_dict[f'layer_{layer_idx}_bias'] = bias
        kernels.append(kernel)
    n_layers = len(dense_keys)

    # Metadata as small numeric arrays. The dimensions are read from the
    # weights themselves; 3 / 2 is kept as the fallback when nothing was
    # exported.
    input_dim = kernels[0].shape[0] if kernels else 3
    output_dim = kernels[-1].shape[1] if kernels else 2
    export_dict['n_layers'] = np.array([n_layers], dtype=np.int32)
    export_dict['input_dim'] = np.array([input_dim], dtype=np.int32)
    export_dict['output_dim'] = np.array([output_dim], dtype=np.int32)
    # Activation is stored as an int code: 0=relu, 1=sigmoid, 2=tanh, etc.
    export_dict['activation_code'] = np.array([0], dtype=np.int32)  # 0 = relu

    np.savez_compressed(output_path, **export_dict)
    return n_layers

# Write the binary export next to the other artifacts.
out_npz = os.path.join(ART_DIR, "mlp_distill_jax.npz")
n_layers = export_mlp_npz(params=mlp_state.params, output_path=out_npz)
print(f"Exported MLP to NPZ: {out_npz} ({n_layers} layers)")

# Also export as JSON for compatibility (optional)
def export_dense_layer(params, layer_name):
    """Return one Dense layer as JSON-ready lists.

    "W" is the kernel transposed to [out, in]; "b" is the bias vector.
    """
    layer = params[layer_name]
    weight_out_in = np.array(layer['kernel']).T  # kernel is stored as [in, out]
    bias_vec = np.array(layer['bias'])
    return {"W": weight_out_in.tolist(), "b": bias_vec.tolist()}

def _dense_layer_order(name):
    """Sort key placing 'Dense_<n>' names in numeric order of <n>."""
    suffix = name.rsplit('_', 1)[-1]
    return (int(suffix), name) if suffix.isdigit() else (float('inf'), name)

# Collect the Dense layers in network order. Plain string sorting would put
# "Dense_10" before "Dense_2" for networks with more than 10 layers.
layers = []
for key in sorted(mlp_state.params.keys(), key=_dense_layer_order):
    if 'Dense' in key:
        layers.append(export_dense_layer(mlp_state.params, key))

data = {
    "layers": layers,
    "activation": "relu",
    "input_dim": 3,
    "output_dim": 2
}

out_json = os.path.join(ART_DIR, "mlp_distill_jax.json")
with open(out_json, "w") as f:
    json.dump(data, f, indent=2)
print(f"Exported MLP to JSON: {out_json}")

# Verify NPZ export by reading the file back and listing its contents.
with np.load(out_npz) as npz_data:
    print("\nNPZ Contents:")
    print(f"  n_layers: {npz_data['n_layers'][0]}")
    print(f"  input_dim: {npz_data['input_dim'][0]}")
    print(f"  output_dim: {npz_data['output_dim'][0]}")
    print(f"  activation_code: {npz_data['activation_code'][0]} (0=relu)")
    print("\n  Layer weights:")
    tensor_keys = [k for k in sorted(npz_data.keys()) if 'weight' in k or 'bias' in k]
    for key in tensor_keys:
        arr = npz_data[key]
        print(f"    {key}: shape={arr.shape}, dtype={arr.dtype}")

# Final pipeline summary
banner = "="*60
print("\n" + banner)
print("JAX Pipeline Complete!")
print(banner)
print(f"✓ UNet2D trained for {EPOCHS} epochs")
print(f"✓ Final validation Dice: {history['va_dice'][-1]:.3f}")
print(f"✓ MLP distillation accuracy: {mlp_hist[-1][1]:.3f}")
print(f"✓ Checkpoints saved to: {CKPT_DIR}")
print(f"✓ Artifacts saved to: {ART_DIR}")
print(f"✓ MLP exported as NPZ and JSON")
print(banner)

In [None]:
# ==== SECTION 13: Visualize MLP Predictions on 3D Grid ====

# Load a sample case for visualization
with open(train_csv) as fp:
    sample_row = list(csv.DictReader(fp))[0]

vol, aff, mask = load_case_3d(sample_row)
X, Y, Z = vol.shape[1:]

print(f"Visualizing case: {sample_row['id']}")
print(f"Volume shape: {vol.shape}")

# Create a 3D coordinate grid (normalized to [0,1]); this matches the
# index / (size - 1) normalisation used when the distillation samples were drawn.
x_coords = np.linspace(0, 1, X)
y_coords = np.linspace(0, 1, Y)
z_coords = np.linspace(0, 1, Z)

# The in-plane (x, y) grid is identical for every slice, so build it once.
coords_2d = np.stack(np.meshgrid(x_coords, y_coords, indexing='ij'), axis=-1)  # [X, Y, 2]

# MLP queries are evaluated in chunks to bound memory use.
batch_size = 8192


def _unet_slice_pred(z_slice):
    """UNet (teacher) label map of shape [X, Y] for one slice of ``vol``."""
    x_batch = jnp.array(vol[..., z_slice].copy())[None, ...]
    x_batch = jnp.transpose(x_batch, (0, 2, 3, 1))
    # Use the jitted predictor with the checkpoint parameters.
    return np.array(predict_slice(state_params, state_batch_stats, x_batch)[0])


def _mlp_slice_pred(z_slice):
    """MLP (student) label map of shape [X, Y] for one slice index."""
    coords_3d = np.concatenate(
        [coords_2d, np.full((X, Y, 1), z_coords[z_slice])], axis=-1
    )  # [X, Y, 3]
    coords_flat = coords_3d.reshape(-1, 3)  # [X*Y, 3]
    chunks = []
    for i in range(0, len(coords_flat), batch_size):
        batch = jnp.array(coords_flat[i:i+batch_size])
        logits = mlp.apply({'params': mlp_state.params}, batch)
        chunks.append(np.array(jnp.argmax(logits, axis=-1)))
    return np.concatenate(chunks).reshape(X, Y).astype(np.uint8)


# Pick a few representative slices to visualize
slice_indices = [Z//4, Z//2, 3*Z//4]

# squeeze=False keeps `axes` 2-D for any number of rows.
fig, axes = plt.subplots(len(slice_indices), 4, figsize=(16, 4*len(slice_indices)), squeeze=False)

for idx, z_slice in enumerate(slice_indices):
    flair_slice = vol[3, :, :, z_slice]  # FLAIR is channel 3
    unet_pred = _unet_slice_pred(z_slice)
    mlp_pred_slice = _mlp_slice_pred(z_slice)

    # Plot comparison
    axes[idx, 0].imshow(flair_slice, cmap='gray')
    axes[idx, 0].set_title(f'FLAIR (Slice {z_slice}/{Z})', fontsize=12)
    axes[idx, 0].axis('off')

    axes[idx, 1].imshow(unet_pred, cmap='magma', vmin=0, vmax=1)
    axes[idx, 1].set_title('UNet Prediction', fontsize=12)
    axes[idx, 1].axis('off')

    axes[idx, 2].imshow(mlp_pred_slice, cmap='magma', vmin=0, vmax=1)
    axes[idx, 2].set_title('MLP Prediction (Distilled)', fontsize=12)
    axes[idx, 2].axis('off')

    # Difference map: 1 where teacher and student disagree
    diff = np.abs(unet_pred.astype(int) - mlp_pred_slice.astype(int))
    axes[idx, 3].imshow(diff, cmap='hot', vmin=0, vmax=1)
    axes[idx, 3].set_title(f'Difference (Err: {diff.sum()}/{X*Y})', fontsize=12)
    axes[idx, 3].axis('off')

plt.tight_layout()
plt.savefig(os.path.join(ART_DIR, "mlp_vs_unet_comparison.png"), dpi=150, bbox_inches='tight')
plt.show()

# Compute overall agreement statistics over every slice of the volume
print("\n" + "="*50)
print("MLP vs UNet Agreement Statistics")
print("="*50)

total_pixels = 0
total_agreement = 0
total_tumor_pixels = 0
tumor_agreement = 0

for z_slice in range(Z):
    unet_pred = _unet_slice_pred(z_slice)
    mlp_pred_slice = _mlp_slice_pred(z_slice)

    agreement = (unet_pred == mlp_pred_slice)
    total_pixels += X * Y
    total_agreement += agreement.sum()

    tumor_mask = (unet_pred == 1)
    if tumor_mask.sum() > 0:
        total_tumor_pixels += tumor_mask.sum()
        tumor_agreement += (agreement & tumor_mask).sum()

overall_acc = total_agreement / total_pixels
tumor_acc = tumor_agreement / total_tumor_pixels if total_tumor_pixels > 0 else 0.0
# Guard the background ratio the same way as the tumor ratio.
background_pixels = total_pixels - total_tumor_pixels
background_acc = (total_agreement - tumor_agreement) / background_pixels if background_pixels > 0 else 0.0

print(f"Overall Pixel Accuracy:  {overall_acc:.4f} ({total_agreement}/{total_pixels})")
print(f"Tumor Pixel Accuracy:    {tumor_acc:.4f} ({tumor_agreement}/{total_tumor_pixels})")
print(f"Background Accuracy:     {background_acc:.4f}")
print("="*50)