# JAX Tumor INR – Data Split (BraTS 2023)

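This starter notebook locates BraTS 2023 training cases under `data/BraTS-2023`, produces (or reuses) an 80/20 train/val split, builds a slice-level JAX batch iterator with 3D Fourier positional encodings, and offers an optional interactive slice viewer.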

In [None]:
# Configuration
# Every tunable used by the later cells lives here; edit and re-run from the top.
from pathlib import Path
# Dataset root, resolved against the current working directory into an absolute path.
DATA_ROOT = Path('../data/BraTS-2023').resolve()
# Seed for the train/val shuffle and for the batch iterators (val uses SEED+1).
SEED = 42
# Split files hold one case-directory name per line; if both exist they are reused instead of re-splitting.
SPLIT_TRAIN_FILE = DATA_ROOT / 'split_train.txt'
SPLIT_VAL_FILE = DATA_ROOT / 'split_val.txt'
# Number of 2D slices per batch yielded by make_batch_iterator.
BATCH_SIZE = 2
# Number of frequency levels passed to fourier_features (level l uses frequency 2**l * pi).
PE_LEVELS = 6
# Default for whether fourier_features prepends the raw coordinates to the sin/cos features.
INCLUDE_RAW_COORDS = True
# File-name suffix per volume; the expected file is '<case dir name><suffix>' inside the case directory.
MODALITY_SUFFIX = {
    't1n': '-t1n.nii.gz',
    't1c': '-t1c.nii.gz',
    't2f': '-t2f.nii.gz',
    't2w': '-t2w.nii.gz',
    'seg': '-seg.nii.gz',
}
# Channel order used when the four image modalities are stacked along the last axis.
MODALITY_ORDER = ['t1n','t1c','t2f','t2w']
# Segmentation relabeling applied by _remap_labels: source label 4 becomes 3.
# NOTE(review): presumably this makes labels contiguous (0..3) for older label conventions — confirm against the dataset.
LABEL_MAP = {4:3}
# The viewer cell runs only when BOTH ENABLE_VIS and ENABLE_INTERACTIVE_VIS are true.
ENABLE_VIS = False
ENABLE_INTERACTIVE_VIS = True  # if True, ipywidgets viewer enabled
USE_MPL_WIDGET = True  # use %matplotlib widget backend for live updates
MASK_ALPHA = 0.7        # opacity for labels if not using RGBA colormap
# Fail fast with a readable message when the dataset folder is missing.
assert DATA_ROOT.exists(), f'BraTS-2023 folder not found at: {DATA_ROOT}'


In [34]:
# Discover cases and create/use 80/20 split
# The split is written once to SPLIT_TRAIN_FILE / SPLIT_VAL_FILE (one case name per line) and
# reused on later runs so that train/val membership stays stable across sessions.
import random
case_dirs = sorted([p for p in DATA_ROOT.iterdir() if p.is_dir() and p.name.startswith('BraTS-')])
print(f'Found {len(case_dirs)} cases under: {DATA_ROOT}')
if SPLIT_TRAIN_FILE.exists() and SPLIT_VAL_FILE.exists():
    train_cases = [DATA_ROOT / line.strip() for line in SPLIT_TRAIN_FILE.read_text().splitlines() if line.strip()]
    val_cases = [DATA_ROOT / line.strip() for line in SPLIT_VAL_FILE.read_text().splitlines() if line.strip()]
    # A stale split (cases moved/deleted since it was written) would otherwise only surface
    # later as a FileNotFoundError at some random point during batch sampling.
    missing = [p.name for p in train_cases + val_cases if not p.is_dir()]
    if missing:
        raise FileNotFoundError(
            f'{len(missing)} case(s) listed in the split files are missing under {DATA_ROOT} '
            f'(e.g. {missing[:3]}). Delete {SPLIT_TRAIN_FILE.name} and {SPLIT_VAL_FILE.name} to regenerate the split.'
        )
else:
    # Never persist an empty split: the files would be reused on every later run,
    # even after the data has been put in place.
    if not case_dirs:
        raise FileNotFoundError(f'No BraTS-* case directories found under: {DATA_ROOT}')
    rng = random.Random(SEED)
    # Shuffle a copy so `case_dirs` stays sorted and re-running the cell is idempotent.
    shuffled_dirs = list(case_dirs)
    rng.shuffle(shuffled_dirs)
    split = int(len(shuffled_dirs) * 0.8)
    train_cases = shuffled_dirs[:split]
    val_cases = shuffled_dirs[split:]
    SPLIT_TRAIN_FILE.write_text('\n'.join(p.name for p in train_cases))
    SPLIT_VAL_FILE.write_text('\n'.join(p.name for p in val_cases))
    print('Wrote split files:', SPLIT_TRAIN_FILE, 'and', SPLIT_VAL_FILE)
# Guard against data leakage (e.g. hand-edited split files).
overlap = {p.name for p in train_cases} & {p.name for p in val_cases}
assert not overlap, f'Train/val splits overlap on {len(overlap)} case(s), e.g. {sorted(overlap)[:3]}'
print(f'Train: {len(train_cases)} cases, Val: {len(val_cases)} cases')
for label, subset in [('train', train_cases[:3]), ('val', val_cases[:3])]:
    print(f'[{label}]', [p.name for p in subset])


Found 1251 cases under: /Users/kylelukaszek/Classes/AI/Project/data/BraTS-2023
Train: 1000 cases, Val: 251 cases
[train] ['BraTS-GLI-01299-000', 'BraTS-GLI-01514-000', 'BraTS-GLI-01664-000']
[val] ['BraTS-GLI-00510-000', 'BraTS-GLI-00003-000', 'BraTS-GLI-00499-000']


In [35]:
# JAX data loader with 3D positional encoding
import math, numpy as np, nibabel as nib, jax, jax.numpy as jnp

def _read_nifti(path: Path) -> np.ndarray:
    """Load a NIfTI file as a float32 array with non-finite voxels zeroed.

    Args:
        path: Location of the NIfTI file; passed to nibabel as a string.

    Returns:
        The image data converted to float32, with NaN, +inf and -inf
        all replaced by 0.0.
    """
    data = nib.load(str(path)).get_fdata()
    volume = np.asarray(data, dtype=np.float32)
    return np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0)

def _normalize_intensity(vol: np.ndarray) -> np.ndarray:
    p1, p99 = np.percentile(vol, 1), np.percentile(vol, 99)
    if p99 > p1: vol = np.clip(vol, p1, p99)
    m, s = vol.mean(), vol.std() + 1e-6
    return (vol - m) / s

def _remap_labels(mask: np.ndarray) -> np.ndarray:
    out = mask.astype(np.int32)
    for k,v in LABEL_MAP.items():
        out[out==k] = v
    return out

def find_case_files(case_dir: Path):
    """Locate the per-modality NIfTI files of one case directory.

    Each expected file is named ``<case_dir.name><suffix>`` with the suffix
    taken from ``MODALITY_SUFFIX``.

    Args:
        case_dir: Directory of a single case.

    Returns:
        Dict mapping modality key to the path of the existing file.

    Raises:
        FileNotFoundError: If any of 't1n', 't1c', 't2f', 't2w' or 'seg' is
            absent; the first missing key (in that order) is reported.
    """
    prefix = case_dir.name
    candidates = {key: case_dir / f'{prefix}{suffix}' for key, suffix in MODALITY_SUFFIX.items()}
    found = {key: path for key, path in candidates.items() if path.exists()}
    for required in ('t1n', 't1c', 't2f', 't2w', 'seg'):
        if required not in found:
            raise FileNotFoundError(f'Missing {required} in {case_dir}')
    return found

def load_case_volumes(case_dir: Path):
    """Load one case as a normalized multi-channel image plus its label volume.

    Every file reported by ``find_case_files`` is read; the modalities listed
    in ``MODALITY_ORDER`` are intensity-normalized independently and stacked
    along a new last axis, and the segmentation is passed through
    ``_remap_labels``.

    Args:
        case_dir: Directory of a single case.

    Returns:
        ``(img, seg)`` where ``img`` has the modalities stacked on the last
        axis and ``seg`` has the shape of a single modality volume.

    Raises:
        AssertionError: If the modality volumes and the segmentation do not
            all share the same shape.
    """
    raw = {key: _read_nifti(path) for key, path in find_case_files(case_dir).items()}
    channels = [_normalize_intensity(raw[name]) for name in MODALITY_ORDER]
    labels = _remap_labels(raw['seg'])
    reference_shape = channels[0].shape
    shapes_agree = labels.shape == reference_shape and all(
        channel.shape == reference_shape for channel in channels
    )
    assert shapes_agree
    stacked = np.stack(channels, axis=-1)  # [H,W,D,C]
    return stacked, labels

def build_coords(H, W, z_idx, D):
    """Build the normalized (x, y, z) coordinate grid for one slice.

    Args:
        H: Size of the first slice axis; the `y` channel spans [-1, 1] along it.
        W: Size of the second slice axis; the `x` channel spans [-1, 1] along it.
        z_idx: Index of the slice along the depth axis.
        D: Total number of slices; index 0 maps to -1 and index D-1 to +1
            (a single-slice volume maps to -1).

    Returns:
        float32 array of shape [H, W, 3] with channels ordered (x, y, z);
        the z channel is constant across the slice.
    """
    rows = np.linspace(-1, 1, H, dtype=np.float32)
    cols = np.linspace(-1, 1, W, dtype=np.float32)
    depth = np.float32(2 * (z_idx / max(D - 1, 1)) - 1)
    grid = np.empty((H, W, 3), dtype=np.float32)
    grid[..., 0] = cols[np.newaxis, :]  # x varies along the second axis
    grid[..., 1] = rows[:, np.newaxis]  # y varies along the first axis
    grid[..., 2] = depth
    return grid

def fourier_features(coords: np.ndarray, L: int, include_raw=None):
    """Encode coordinates with sin/cos Fourier features at L frequency levels.

    For each level ``l`` in ``0..L-1`` the features ``sin(2**l * pi * coords)``
    and ``cos(2**l * pi * coords)`` are appended along the last axis, in the
    order ``[sin_0, cos_0, sin_1, cos_1, ...]``.

    Args:
        coords: Array whose last axis holds the coordinate components.
        L: Number of frequency levels. ``L <= 0`` yields no sin/cos features.
        include_raw: Whether to prepend the raw coordinates. ``None`` (the
            default) defers to the notebook-level ``INCLUDE_RAW_COORDS``, so
            existing two-argument calls behave as before.

    Returns:
        Array with the leading shape of `coords` and a last axis of size
        ``2 * L * d`` (plus ``d`` when raw coordinates are included), where
        ``d = coords.shape[-1]``. For ``L <= 0`` this is `coords` itself
        (not a copy) or an empty float32 array with last axis 0.
    """
    if include_raw is None:
        include_raw = INCLUDE_RAW_COORDS
    if L <= 0:
        return coords if include_raw else np.zeros((*coords.shape[:-1], 0), dtype=np.float32)
    parts = [coords] if include_raw else []
    for level in range(L):
        # Compute the phase once per level; sin and cos share it.
        phase = ((2 ** level) * math.pi) * coords
        parts.append(np.sin(phase))
        parts.append(np.cos(phase))
    return np.concatenate(parts, axis=-1)

def make_batch_iterator(case_list, batch_size: int, seed: int = 0):
    """Endlessly yield batches of randomly sampled 2D slices.

    For every sample a case is drawn uniformly (with replacement) from
    `case_list`, its full volumes are loaded from disk, and one slice index
    along the third axis is drawn uniformly.

    Args:
        case_list: Iterable of case directories to sample from.
        batch_size: Number of slices per yielded batch.
        seed: Seed for the NumPy generator driving case and slice selection.

    Yields:
        Dict of JAX arrays stacked on a new leading batch axis:
        'image' (slice with modalities on the last axis), 'mask' (label
        slice), 'coords' (output of ``build_coords``) and 'pe' (float32
        output of ``fourier_features`` with ``PE_LEVELS`` levels).
    """
    sampler = np.random.default_rng(seed)
    pool = list(case_list)
    batch_keys = ('image', 'mask', 'coords', 'pe')
    while True:
        collected = {key: [] for key in batch_keys}
        for _ in range(batch_size):
            # RNG draw order matters for reproducibility: case first, then slice.
            case_dir = pool[sampler.integers(0, len(pool))]
            volume, labels = load_case_volumes(case_dir)
            height, width, depth, _channels = volume.shape
            z_idx = int(sampler.integers(0, depth))
            grid = build_coords(height, width, z_idx, depth)
            encoding = fourier_features(grid, PE_LEVELS).astype(np.float32)
            collected['image'].append(volume[:, :, z_idx, :])
            collected['mask'].append(labels[:, :, z_idx])
            collected['coords'].append(grid)
            collected['pe'].append(encoding)
        yield {key: jnp.array(np.stack(items, 0)) for key, items in collected.items()}

# Build the endless batch iterators; val uses SEED+1 so its sampling stream differs from train.
train_iter = make_batch_iterator(train_cases, BATCH_SIZE, seed=SEED)
val_iter = make_batch_iterator(val_cases, BATCH_SIZE, seed=SEED+1)
# Pull one training batch as a smoke test (this reads full case volumes from disk)
# and report the shape of every array in the batch.
sample = next(train_iter)
print('Sample batch shapes:', {k: tuple(v.shape) for k,v in sample.items()})


Sample batch shapes: {'image': (2, 240, 240, 4), 'mask': (2, 240, 240), 'coords': (2, 240, 240, 3), 'pe': (2, 240, 240, 39)}


In [36]:
# Optional interactive visualization — single viewer
# Shows one slice of one case with the segmentation overlaid, driven by ipywidgets controls.
# Runs only when both ENABLE_VIS and ENABLE_INTERACTIVE_VIS are true. Any failure (missing
# optional dependency, empty split, unreadable case) is reported as a one-line message.
if ENABLE_VIS and ENABLE_INTERACTIVE_VIS:
    try:
        import numpy as np
        import matplotlib.pyplot as plt
        from matplotlib.colors import ListedColormap
        from matplotlib.patches import Patch
        from ipywidgets import Dropdown, IntSlider, VBox, HBox, Play, jslink, Label, FloatSlider, Output, Widget
        from IPython.display import display
        from IPython import get_ipython

        # Use widget backend for responsive updates if requested
        if 'USE_MPL_WIDGET' in globals() and USE_MPL_WIDGET:
            ip = get_ipython()
            if ip is not None:
                try:
                    ip.run_line_magic('matplotlib', 'widget')
                except Exception:
                    # Best effort: if the widget backend cannot be activated (e.g. ipympl is
                    # not installed) the viewer falls back to re-rendering a static figure below.
                    pass

        # ---- Data helpers ----
        def get_cases(split):
            """Return the case directories of `split` ('train' -> train_cases, otherwise val_cases)."""
            return train_cases if split == 'train' else val_cases
        # Each cached entry holds full volumes for a case, so keep only a few to bound memory.
        _MAX_CACHED_CASES = 4
        _cache = {}  # (split, name) -> (img3d, seg3d)
        def get_case(split, name):
            """Load (img3d, seg3d) for a case by name, memoizing the most recently loaded cases."""
            key = (split, name)
            if key not in _cache:
                cdir = {c.name: c for c in get_cases(split)}[name]
                while len(_cache) >= _MAX_CACHED_CASES:
                    # dicts preserve insertion order: drop the oldest loaded case first
                    _cache.pop(next(iter(_cache)))
                _cache[key] = load_case_volumes(cdir)
            return _cache[key]

        # ---- UI controls ----
        split_dd = Dropdown(description='Split', options=['train','val'], value='train')
        modality_options = MODALITY_ORDER
        case_names = sorted([c.name for c in get_cases(split_dd.value)])
        case_dd = Dropdown(description='Case', options=case_names, value=(case_names[0] if case_names else None))
        mod_dd  = Dropdown(description='Modality', options=modality_options, value=modality_options[0])
        # Initial overlay opacity comes from the config cell's MASK_ALPHA (0.6 if it is not defined).
        alpha_slider = FloatSlider(description='Mask α', min=0.0, max=1.0, step=0.05, value=float(globals().get('MASK_ALPHA', 0.6)), readout_format='.2f', continuous_update=False)

        first_img, first_seg = get_case(split_dd.value, case_dd.value)
        D = first_img.shape[2]
        z_slider = IntSlider(description='Slice z', min=0, max=max(D-1,0), step=1, value=(D//2 if D>0 else 0), continuous_update=False)
        play = Play(interval=120, min=0, max=max(D-1,0), step=1, value=z_slider.value)
        jslink((play, 'value'), (z_slider, 'value'))
        status = Label(value='')

        # ---- Colormap & legend (0 transparent) ----
        cmap = ListedColormap([
            (0,0,0,0.0),
            (0.12, 0.47, 0.71, 1.0),  # 1 NCR/NET (blue)
            (0.20, 0.63, 0.17, 1.0),  # 2 ED (green)
            (0.84, 0.15, 0.16, 1.0),  # 3 ET (red)
        ])
        legend_items = [('NCR/NET (1)', cmap(1)), ('ED (2)', cmap(2)), ('ET (3)', cmap(3))]

        # ---- Figure init ----
        # Disable auto-display to avoid duplicate viewers; the finally-clause below turns
        # interactive mode back on even if building the viewer fails part-way.
        plt.ioff()
        try:
            fig, ax = plt.subplots(figsize=(5.5,5.5))
            base = first_img[:, :, z_slider.value, modality_options.index(mod_dd.value)]
            mask = first_seg[:, :, z_slider.value]
            base_im = ax.imshow(base, cmap='gray')
            mask_im = ax.imshow(np.ma.masked_where(mask==0, mask), cmap=cmap, vmin=0, vmax=3, alpha=alpha_slider.value)
            ax.axis('off')
            ax.set_title(f'{split_dd.value}:{case_dd.value} z={z_slider.value} modality={mod_dd.value}')
            handles = [Patch(facecolor=col, edgecolor='k', label=lab) for lab, col in legend_items]
            ax.legend(handles=handles, loc='lower right', frameon=True)
            status.value = f'Shape: {first_img.shape} | z={z_slider.value}'

            # A Box only accepts widgets as children. The canvas is a widget only when the
            # widget backend is actually active; otherwise render the figure into an Output
            # widget and re-render it on every change.
            canvas_is_widget = isinstance(fig.canvas, Widget)
            fig_out = Output()
            def show_static():
                """Re-render the figure inside the Output widget (non-widget backends)."""
                fig_out.clear_output(wait=True)
                with fig_out:
                    display(fig)

            # ---- Update helpers ----
            def update_depth_bounds(img):
                """Sync the slice slider and play control with the depth of `img`."""
                newD = img.shape[2]
                z_slider.max = max(newD-1, 0)
                play.max = max(newD-1, 0)
                if z_slider.value > z_slider.max:
                    z_slider.value = (newD//2 if newD>0 else 0)
            def redraw():
                """Refresh image, mask overlay, title and status from the current control values."""
                img3d, seg3d = get_case(split_dd.value, case_dd.value)
                update_depth_bounds(img3d)
                z = z_slider.value
                m_idx = modality_options.index(mod_dd.value)
                base_im.set_data(img3d[:, :, z, m_idx])
                mask_im.set_data(np.ma.masked_where(seg3d[:, :, z]==0, seg3d[:, :, z]))
                mask_im.set_alpha(alpha_slider.value)
                ax.set_title(f'{split_dd.value}:{case_dd.value} z={z} modality={mod_dd.value}')
                status.value = f'Shape: {img3d.shape} | z={z}'
                if canvas_is_widget:
                    fig.canvas.draw_idle()
                else:
                    show_static()

            # ---- Observers ----
            def on_split_change(change):
                """Repopulate the case dropdown for the newly selected split, then redraw."""
                if change['name'] == 'value':
                    names = sorted([c.name for c in get_cases(split_dd.value)])
                    case_dd.options = names
                    if names and case_dd.value not in names:
                        case_dd.value = names[0]
                    redraw()
            def on_simple_change(change):
                """Redraw whenever a control's value changes."""
                if change['name'] == 'value':
                    redraw()
            split_dd.observe(on_split_change, names='value')
            case_dd.observe(on_simple_change, names='value')
            mod_dd.observe(on_simple_change, names='value')
            z_slider.observe(on_simple_change, names='value')
            alpha_slider.observe(on_simple_change, names='value')

            # ---- Layout & single display ----
            ui = VBox([HBox([split_dd, case_dd, mod_dd]), HBox([play, z_slider]), alpha_slider, status])
            container = VBox([ui, fig.canvas if canvas_is_widget else fig_out])
            display(container)
            if not canvas_is_widget:
                show_static()
        finally:
            # Re-enable interactive display for subsequent figures
            plt.ion()
    except Exception as e:
        print('Interactive viewer error:', e)


VBox(children=(VBox(children=(HBox(children=(Dropdown(description='Split', options=('train', 'val'), value='tr…

### Notes
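- The iterator draws a random case and a random z-slice for every sample and yields `image`, `mask`, `coords` (raw coordinates in [-1, 1]) and `pe` (Fourier positional encodings; with `PE_LEVELS = 6` and `INCLUDE_RAW_COORDS = True` this is 3 + 2·6·3 = 39 channels).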
- Adjust `PE_LEVELS`, `INCLUDE_RAW_COORDS`, `BATCH_SIZE` in the config cell.
- A tiny ReLU MLP can consume `pe` (and optionally voxel intensity features).
- Next step: define lightweight INR model (e.g., small MLP) taking (x,y,z PE) -> class logits.
