 # Research-Backed BraTS 2023 Segmentation with an Implicit Neural Representation (INR)



 Key improvements, based on the 2022–2025 literature:

 - Unified Focal Loss + TV regularization (spatial coherence)

 - Uncertainty-guided coordinate sampling (intended to reduce over-segmentation)

 - Anisotropic Fourier features (accounts for voxel spacing)

 - Multi-metric evaluation (Dice, HD95, connected components)

 - Persistent train/val splits (reproducibility)

 - Two-stage training protocol (coarse → refine; stage 2 currently runs 0 steps, see `STAGE2_STEPS`)

In [None]:
import os, math, json, time, pathlib, functools
from typing import Tuple, Dict, Any
import numpy as np
import jax
import jax.numpy as jnp
import optax
import nibabel as nib
import wandb
from IPython.display import clear_output, display
import matplotlib.pyplot as plt
from scipy.ndimage import label as connected_components_3d
from scipy.ndimage import distance_transform_edt
from tqdm.auto import tqdm

# Force the CPU backend. JAX reads JAX_PLATFORM_NAME when `jax` is imported,
# so setting the variable here (after `import jax`) does not affect this
# process. Update the runtime config instead; this works as long as no backend
# has been initialised yet (jax.devices() below is what initialises one).
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # still exported for any child processes
jax.config.update("jax_platforms", "cpu")
print('JAX devices:', jax.devices())


In [None]:
# =============================================================================
# Configuration
# =============================================================================
DATA_ROOT = pathlib.Path('../data/BraTS-2023')

# Data splits (persistent files)
SPLITS_DIR = pathlib.Path('../data/splits')
SPLIT_TRAIN_FILE = SPLITS_DIR / 'train.txt'
SPLIT_VAL_FILE = SPLITS_DIR / 'val.txt'

# Sampling - Hybrid strategy (research-backed)
GLOBAL_BATCH_SIZE = 12000  # Research recommends 12k-16k for medical segmentation
MICRO_BATCH_SIZE = 4000    # GLOBAL_BATCH_SIZE is split into ACCUM_STEPS micro-batches of this size
if GLOBAL_BATCH_SIZE % MICRO_BATCH_SIZE != 0:
    # The integer division below would otherwise silently drop the remainder.
    raise ValueError(
        f"GLOBAL_BATCH_SIZE ({GLOBAL_BATCH_SIZE}) must be a multiple of "
        f"MICRO_BATCH_SIZE ({MICRO_BATCH_SIZE})"
    )
ACCUM_STEPS = GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE

# Architecture - Anisotropic Fourier features
# NOTE(review): FOURIER_FREQS is not read by the Gaussian Fourier-feature init
# in this notebook section - confirm whether it is still used elsewhere.
FOURIER_FREQS = [1, 2, 4, 8, 16]  # Multi-scale frequencies
FOURIER_DIM = 128  # Learnable Fourier features (sin/cos pairs, so keep it even)
HIDDEN_DIMS = [64, 64, 64]  # Research-backed sizes
# Hidden-layer indices whose input is extended with raw coords / modalities.
# Only indices 1..len(HIDDEN_DIMS)-1 exist, so 3 has no effect with 3 hidden layers.
COORD_INJECTION_LAYERS = [1, 2, 3]
MODALITY_INJECTION_LAYERS = [2]
DROPOUT_RATE = 0.3  # Research recommends 0.3

# Training - Two-stage protocol
STAGE1_STEPS = 3000  # Coarse learning
STAGE2_STEPS = 0  # Boundary refinement (currently 0 steps)
TOTAL_STEPS = STAGE1_STEPS + STAGE2_STEPS

LR_STAGE1 = 5e-5  # Reduced from 2e-4
LR_STAGE2 = 1e-5  # Reduced from 5e-5
MIN_LR = 1e-7
WARMUP_STEPS = 1000  # Increased from 500
WEIGHT_DECAY = 0.01
CLIP_NORM = 0.5  # Tighter gradient clipping (max norm lowered from 1.0)

# Loss weights - Research-backed Unified Focal + TV
NUM_CLASSES = 4
UNIFIED_FOCAL_GAMMA = 0.5  # Key hyperparameter
UNIFIED_FOCAL_DELTA = 0.6
UNIFIED_FOCAL_LAMBDA = 0.5

# Stage-dependent loss weights
TV_WEIGHT_STAGE1 = 0.02  # Light TV in stage 1
TV_WEIGHT_STAGE2 = 0.05  # Stronger TV in stage 2
BOUNDARY_WEIGHT = 0.1

CLASS_WEIGHTS = [0.1, 1.5, 1.0, 2.0]  # Background, NCR, ED, ET (favor small classes)

# Sampling strategy - Uncertainty-guided
UNCERTAINTY_RATIO = 0.5  # 50% uncertainty-guided (increases in stage 2)
BALANCED_RATIO = 0.3     # 30% balanced class sampling
UNIFORM_RATIO = 0.2      # 20% uniform random
if not math.isclose(UNCERTAINTY_RATIO + BALANCED_RATIO + UNIFORM_RATIO, 1.0):
    raise ValueError("UNCERTAINTY_RATIO + BALANCED_RATIO + UNIFORM_RATIO must sum to 1.0")

# Stochastic preconditioning
COORD_NOISE_SIGMA_INIT = 0.3
COORD_NOISE_SIGMA_FINAL = 0.1

# Validation & evaluation
VAL_EVAL_STEPS = 1000  # Evaluate every 1000 steps
CHUNK_SIZE = 128  # Cases per chunk
STEPS_PER_CHUNK = 100  # Rotate chunks every 100 steps

RNG_SEED = 42
jax_key = jax.random.PRNGKey(RNG_SEED)

# W&B
WANDB_PROJECT = "brats-inr-research-backed"
WANDB_TAGS = ["unified-focal", "tv-regularization", "uncertainty-sampling", "anisotropic-fourier"]
WANDB_NOTES = "Research-backed INR with spatial coherence and multi-metric evaluation"

config = {
    "global_batch_size": GLOBAL_BATCH_SIZE,
    "fourier_dim": FOURIER_DIM,
    "hidden_dims": HIDDEN_DIMS,
    "total_steps": TOTAL_STEPS,
    "unified_focal_gamma": UNIFIED_FOCAL_GAMMA,
    "tv_weight_stage1": TV_WEIGHT_STAGE1,
    "tv_weight_stage2": TV_WEIGHT_STAGE2,
    "uncertainty_ratio": UNCERTAINTY_RATIO,
    "coord_noise_sigma": COORD_NOISE_SIGMA_INIT,
    # Remaining hyperparameters, logged so a run can be reproduced from W&B alone.
    "micro_batch_size": MICRO_BATCH_SIZE,
    "accum_steps": ACCUM_STEPS,
    "coord_injection_layers": COORD_INJECTION_LAYERS,
    "modality_injection_layers": MODALITY_INJECTION_LAYERS,
    "dropout_rate": DROPOUT_RATE,
    "stage1_steps": STAGE1_STEPS,
    "stage2_steps": STAGE2_STEPS,
    "lr_stage1": LR_STAGE1,
    "lr_stage2": LR_STAGE2,
    "min_lr": MIN_LR,
    "warmup_steps": WARMUP_STEPS,
    "weight_decay": WEIGHT_DECAY,
    "clip_norm": CLIP_NORM,
    "unified_focal_delta": UNIFIED_FOCAL_DELTA,
    "unified_focal_lambda": UNIFIED_FOCAL_LAMBDA,
    "boundary_weight": BOUNDARY_WEIGHT,
    "class_weights": CLASS_WEIGHTS,
    "balanced_ratio": BALANCED_RATIO,
    "uniform_ratio": UNIFORM_RATIO,
    "coord_noise_sigma_final": COORD_NOISE_SIGMA_FINAL,
    "chunk_size": CHUNK_SIZE,
    "steps_per_chunk": STEPS_PER_CHUNK,
    "seed": RNG_SEED,
}

wandb.init(project=WANDB_PROJECT, config=config, tags=WANDB_TAGS, notes=WANDB_NOTES)

# Artifacts of this run live under ../artifacts/<project>/<run name>/.
run_name = wandb.run.name
print(f"W&B Run: {run_name}")
SAVE_PATH = pathlib.Path(f"../artifacts/{WANDB_PROJECT}/{run_name}/")


In [None]:
# =============================================================================
# Data Splitting (Persistent - Research Best Practice)
# =============================================================================
MODALITY_SUFFIXES = ['t1n', 't1c', 't2w', 't2f']
SEG_SUFFIX = 'seg'

def find_cases(root: pathlib.Path, require_seg: bool = True):
    """Return the sorted case directories under ``root`` that are complete.

    A sub-directory ``<root>/<name>/`` qualifies when every modality file
    ``<name>-<suffix>.nii.gz`` for ``suffix`` in ``MODALITY_SUFFIXES`` exists.
    When ``require_seg`` is true, ``<name>-<SEG_SUFFIX>.nii.gz`` must also
    exist. This matches what ``load_case`` reads, so an incomplete case is
    rejected here instead of crashing later during loading.

    Args:
        root: dataset directory that contains one sub-directory per case.
        require_seg: also require the segmentation file (default True).

    Returns:
        List of qualifying case directories, sorted by path.

    Raises:
        FileNotFoundError: if ``root`` is not an existing directory.
    """
    if not root.is_dir():
        raise FileNotFoundError(f"Data root not found or not a directory: {root}")

    suffixes = list(MODALITY_SUFFIXES) + ([SEG_SUFFIX] if require_seg else [])
    cases = []
    for p in sorted(root.iterdir()):
        if p.is_dir() and all((p / f'{p.name}-{s}.nii.gz').exists() for s in suffixes):
            cases.append(p)
    return cases

def load_split_file(split_file: pathlib.Path) -> list:
    """Read one case name per line from ``split_file``.

    Surrounding whitespace is stripped and blank lines are skipped.
    Returns ``None`` (not an empty list) when the file does not exist.
    """
    if not split_file.exists():
        return None
    with split_file.open('r') as handle:
        stripped = (raw.strip() for raw in handle)
        return [name for name in stripped if name]

def save_split_file(split_file: pathlib.Path, case_names: list):
    """Write ``case_names`` to ``split_file``, one per line (newline-terminated).

    Parent directories are created as needed and any existing file is
    overwritten. Prints a one-line summary afterwards.
    """
    split_file.parent.mkdir(parents=True, exist_ok=True)
    with split_file.open('w') as handle:
        handle.writelines(f"{name}\n" for name in case_names)
    print(f"Saved {len(case_names)} cases to {split_file}")

def get_case_name(case_path: pathlib.Path) -> str:
    """Return the case identifier: the final component of ``case_path``."""
    name = case_path.name
    return name

def create_splits_if_missing(all_cases: list, rng_seed: int = 42, train_fraction: float = 0.8):
    """Load the persistent train/val split, or create and save one if absent.

    If both ``SPLIT_TRAIN_FILE`` and ``SPLIT_VAL_FILE`` exist, their contents
    are returned unchanged and ``all_cases`` is ignored. Otherwise the case
    names are sorted, shuffled with ``np.random.default_rng(rng_seed)``, and
    the first ``train_fraction`` of them becomes the training split. Both
    files are then (re)written.

    Args:
        all_cases: discovered case directories (only their names are used).
        rng_seed: seed for the shuffle, so a recreated split is reproducible.
        train_fraction: share of cases assigned to training, in (0, 1).

    Returns:
        ``(train_names, val_names)`` as lists of case names.

    Raises:
        ValueError: if ``train_fraction`` is outside (0, 1), or if a new split
            would be created from zero cases. Empty split files would otherwise
            be persisted and silently reused on every later run.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}")

    train_names = load_split_file(SPLIT_TRAIN_FILE)
    val_names = load_split_file(SPLIT_VAL_FILE)

    if train_names is not None and val_names is not None:
        print(f"✓ Loaded existing splits: {len(train_names)} train, {len(val_names)} val")
        return train_names, val_names

    if train_names is not None or val_names is not None:
        # Exactly one file exists; it is about to be overwritten, which changes the split.
        print("⚠ Warning: only one split file found; regenerating both train and val splits")

    if not all_cases:
        raise ValueError("No cases discovered; refusing to write empty split files")

    print(f"Creating new splits from {len(all_cases)} cases...")
    # Sort first so the split depends only on the set of names, not on discovery order.
    shuffled_names = sorted(get_case_name(c) for c in all_cases)
    rng = np.random.default_rng(rng_seed)
    rng.shuffle(shuffled_names)

    split_idx = int(len(shuffled_names) * train_fraction)
    train_names = shuffled_names[:split_idx]
    val_names = shuffled_names[split_idx:]

    save_split_file(SPLIT_TRAIN_FILE, train_names)
    save_split_file(SPLIT_VAL_FILE, val_names)

    print(f"Created splits: {len(train_names)} train, {len(val_names)} val")
    return train_names, val_names

def match_cases_to_names(all_cases: list, case_names: list) -> list:
    """Resolve split-file case names to discovered case paths.

    The result follows the order of ``case_names``. Names with no discovered
    directory are dropped, and a single warning reports how many were missing.
    """
    lookup = {}
    for path in all_cases:
        lookup[get_case_name(path)] = path

    matched_cases = [lookup[name] for name in case_names if name in lookup]
    n_missing = sum(1 for name in case_names if name not in lookup)

    if n_missing:
        print(f"⚠ Warning: {n_missing} cases from split file not found")

    return matched_cases

# Discover and split
print("Discovering cases...")
all_cases_full = find_cases(DATA_ROOT)
print(f'Total discovered: {len(all_cases_full)} cases')

train_names, val_names = create_splits_if_missing(all_cases_full, RNG_SEED)
train_cases = match_cases_to_names(all_cases_full, train_names)
val_cases = match_cases_to_names(all_cases_full, val_names)

print(f'Using: Train={len(train_cases)}, Val={len(val_cases)}')
if not train_cases:
    # Fail here with a clear message instead of an IndexError inside the cache.
    raise RuntimeError(
        f"No training cases available (data root: {DATA_ROOT}, split: {SPLIT_TRAIN_FILE})"
    )

# Verify no overlap
train_names_set = {get_case_name(c) for c in train_cases}
val_names_set = {get_case_name(c) for c in val_cases}
overlap = train_names_set & val_names_set
if overlap:
    raise ValueError(
        f"Train/Val overlap detected! {len(overlap)} shared cases, e.g. {sorted(overlap)[:5]}"
    )
print("✓ No overlap between train and validation sets")

wandb.config.update({"train_cases": len(train_cases), "val_cases": len(val_cases)})


In [None]:
# =============================================================================
# Data Loading
# =============================================================================
def load_case(case_dir: pathlib.Path):
    """Load and normalize a single case.

    Reads ``<case_dir>/<name>-<suffix>.nii.gz`` for every suffix in
    ``MODALITY_SUFFIXES``, plus ``<name>-<SEG_SUFFIX>.nii.gz``.

    Each modality is z-scored using the mean and std of its non-zero voxels.
    The shift/scale is applied to the whole volume, background zeros included.
    A modality that is entirely zero is left unchanged.

    Returns:
        mods_arr: float32 array of shape ``(len(MODALITY_SUFFIXES), *volume_shape)``,
            stacked in ``MODALITY_SUFFIXES`` order.
        seg: int16 label volume.
    """
    base = case_dir.name
    mods = []
    for suf in MODALITY_SUFFIXES:
        fp = case_dir / f'{base}-{suf}.nii.gz'
        # Ask nibabel for float32 directly instead of building (and caching) a float64 copy.
        arr = nib.load(str(fp)).get_fdata(dtype=np.float32)
        foreground = arr[arr != 0]
        if foreground.size:
            mu = foreground.mean()
            sigma = foreground.std() + 1e-6
            arr = (arr - mu) / sigma
        mods.append(arr)

    seg_fp = case_dir / f'{base}-{SEG_SUFFIX}.nii.gz'
    # Round before the integer cast so labels stored as scaled floats are not truncated.
    seg = np.rint(nib.load(str(seg_fp)).get_fdata(dtype=np.float32)).astype(np.int16)
    mods_arr = np.stack(mods, axis=0)
    return mods_arr, seg


In [None]:
# =============================================================================
# Chunked Cache
# =============================================================================
class ChunkedBraTSCache:
    """Memory-efficient chunked data loading.

    Keeps at most ``chunk_size`` consecutive cases from ``case_paths`` in memory.

    Attributes:
        cache: one ``{'mods': ..., 'seg': ...}`` dict per loaded case (from ``load_case``).
        boundary_dists: per-voxel float32 weight map per loaded case, in the
            same order as ``cache`` (see ``_boundary_weight_map``).
        chunk_case_indices: global indices into ``case_paths`` of the loaded
            cases; list position equals position in ``cache``.
        vol_shape / n_modalities: geometry of the first loaded case. Every
            other case must match it.

    Raises:
        ValueError: if ``case_paths`` is empty, ``chunk_size`` is not positive,
            or a case's arrays differ in shape from the first loaded case.
    """

    def __init__(self, case_paths, chunk_size=128, name="cache"):
        if len(case_paths) == 0:
            raise ValueError(f"{name}: case_paths is empty, nothing to cache")
        if chunk_size <= 0:
            raise ValueError(f"{name}: chunk_size must be positive, got {chunk_size}")

        self.case_paths = case_paths
        self.chunk_size = chunk_size
        self.name = name
        self.n_cases = len(case_paths)
        self.n_chunks = math.ceil(self.n_cases / chunk_size)
        self.current_chunk_idx = 0

        # Filled from the first case read by _load_chunk, so that case is not loaded twice.
        self.vol_shape = None
        self.n_modalities = None

        print(f'Init {name}: {self.n_cases} cases, {self.n_chunks} chunks')

        self.cache = []
        self.boundary_dists = []
        self.chunk_case_indices = []
        self._load_chunk(0)

    @staticmethod
    def _boundary_weight_map(seg):
        """Return the per-voxel weight ``max_c 1 / (1 + d_c)`` over classes 1..NUM_CLASSES-1.

        ``d_c`` is the Euclidean distance (in voxels) to the nearest voxel of
        class ``c``. The weight is therefore 1 inside any tumour class and
        decays with distance outside it. It is 0 everywhere for a case without
        tumour labels.

        NOTE(review): despite the "boundary" naming, this weights proximity to
        tumour (interiors included), not distance to a class boundary.
        """
        weights = np.zeros(seg.shape, dtype=np.float32)
        for c in range(1, NUM_CLASSES):
            mask = seg == c
            if mask.any():
                # EDT of the complement = distance from each voxel to the nearest class-c voxel.
                dist = distance_transform_edt(~mask)
                # Write in place to stay float32 (the float64 EDT would otherwise upcast the map).
                np.maximum(weights, 1.0 / (1.0 + dist), out=weights)
        return weights

    def _check_shape(self, case_path, mods, seg):
        """Record the geometry of the first case; reject later cases that differ."""
        if self.vol_shape is None:
            self.vol_shape = mods.shape[1:]
            self.n_modalities = mods.shape[0]
        expected = (self.n_modalities, *self.vol_shape)
        if mods.shape != expected or seg.shape != self.vol_shape:
            raise ValueError(
                f'{self.name}: case {case_path} has modalities {mods.shape} / seg {seg.shape}, '
                f'expected {expected} / {self.vol_shape}'
            )

    def _load_chunk(self, chunk_idx):
        """Replace the in-memory cases with chunk ``chunk_idx``."""
        start_idx = chunk_idx * self.chunk_size
        end_idx = min(start_idx + self.chunk_size, self.n_cases)
        chunk_paths = self.case_paths[start_idx:end_idx]

        print(f'Loading {self.name} chunk {chunk_idx+1}/{self.n_chunks}...')

        # Release references to the previous chunk before loading the next one.
        self.cache = []
        self.boundary_dists = []
        self.chunk_case_indices = list(range(start_idx, end_idx))

        for i, cp in enumerate(chunk_paths):
            if i % 20 == 0 and i > 0:
                print(f'  {i}/{len(chunk_paths)}...')
            mods, seg = load_case(cp)
            self._check_shape(cp, mods, seg)
            self.cache.append({'mods': mods, 'seg': seg})
            self.boundary_dists.append(self._boundary_weight_map(seg))

        # Count everything held for this chunk, including the weight maps.
        chunk_bytes = sum(entry['mods'].nbytes + entry['seg'].nbytes for entry in self.cache)
        chunk_bytes += sum(weights.nbytes for weights in self.boundary_dists)
        chunk_gb = chunk_bytes / 1e9
        print(f'{self.name} chunk loaded: {len(self.cache)} cases, {chunk_gb:.2f} GB')

        self.current_chunk_idx = chunk_idx

    def next_chunk(self):
        """Load the next chunk, wrapping around to chunk 0 after the last one."""
        next_idx = (self.current_chunk_idx + 1) % self.n_chunks
        self._load_chunk(next_idx)

    def get_current_chunk_indices(self):
        """Return the global case indices of the currently loaded chunk."""
        return self.chunk_case_indices

train_cache = ChunkedBraTSCache(train_cases, chunk_size=CHUNK_SIZE, name="train")
val_cache = None
if val_cases:
    val_cache = ChunkedBraTSCache(val_cases, chunk_size=CHUNK_SIZE, name="val")

# Volume geometry and modality count (from the training cache), shared with later cells.
(H, W, D), M = train_cache.vol_shape, train_cache.n_modalities

print(f'Volume: {train_cache.vol_shape}, Modalities: {M}')
wandb.config.update({"volume_shape": list(train_cache.vol_shape), "num_modalities": M})


In [None]:
# =============================================================================
# Anisotropic Fourier Features (Research-Backed)
# =============================================================================
def init_anisotropic_fourier_features(key, fourier_dim, input_dim=3, voxel_spacing=(1.0, 1.0, 1.0), scale=5.0):
    """
    Initialize learnable anisotropic Fourier features.
    Accounts for different voxel spacings in medical imaging.

    Draws a projection matrix ``B ~ N(0, scale^2)`` of shape
    ``(fourier_dim // 2, input_dim)``. Column ``j`` is then divided by
    ``voxel_spacing[j]``.

    Args:
        key: JAX PRNG key; it is split and the advanced key is returned.
        fourier_dim: total number of output features (sin + cos). Must be even,
            because downstream layers expect exactly ``fourier_dim`` inputs.
        input_dim: coordinate dimensionality.
        voxel_spacing: one spacing per input dimension.
        scale: standard deviation of the Gaussian frequencies (default 5.0).

    Returns:
        ``(key, {'B': B})``.

    Raises:
        ValueError: on an odd or non-positive ``fourier_dim``, or when
            ``voxel_spacing`` does not have ``input_dim`` entries.
    """
    if fourier_dim <= 0 or fourier_dim % 2:
        raise ValueError(f"fourier_dim must be a positive even number (sin/cos pairs), got {fourier_dim}")
    voxel_scale = jnp.asarray(voxel_spacing, dtype=jnp.float32)
    if voxel_scale.shape != (input_dim,):
        raise ValueError(f"voxel_spacing must have {input_dim} entries, got {voxel_spacing}")

    key, subkey = jax.random.split(key)
    B = jax.random.normal(subkey, (fourier_dim // 2, input_dim)) * scale

    # Anisotropic correction: per-axis frequencies divided by that axis' spacing.
    B = B / voxel_scale[None, :]

    return key, {'B': B}

def apply_fourier_features(fourier_params, coords):
    """
    Apply Fourier feature mapping: γ(x) = [sin(2πB·x), cos(2πB·x)]
    coords: (N, 3) in range [-1, 1]
    """
    B = fourier_params['B']
    # Map to [0, 1] for frequency application
    coords_01 = (coords + 1.0) / 2.0
    
    angles = 2 * jnp.pi * jnp.dot(coords_01, B.T)  # (N, fourier_dim//2)
    features = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)  # (N, fourier_dim)
    
    return features


In [None]:
# =============================================================================
# Coordinate Injection MLP with Dropout
# =============================================================================
def glorot(key, shape):
    """Glorot/Xavier-uniform init: samples U(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    bound = math.sqrt(6.0 / (shape[0] + shape[1]))
    return jax.random.uniform(key, shape, minval=-bound, maxval=bound)

def init_coord_injection_mlp(key, coord_dim, modality_dim, fourier_dim, hidden_dims, out_dim, dropout_rate=0.3):
    """Initialize MLP with Fourier features + modalities.

    Layer 0 consumes ``[fourier features, modalities]``. Hidden layer ``i``
    (``i >= 1``) consumes the previous activation, extended by the raw coords
    when ``i`` is in ``COORD_INJECTION_LAYERS`` and by the modalities when
    ``i`` is in ``MODALITY_INJECTION_LAYERS``. A final linear layer maps to
    ``out_dim``. ``dropout_rate`` is accepted but unused here.

    Returns:
        ``(key, layers)``, where ``layers`` is a list of ``{'W', 'b'}`` dicts
        with Glorot weights and zero biases. One key split is used per layer.
    """
    layer_shapes = [(fourier_dim + modality_dim, hidden_dims[0])]
    for idx in range(1, len(hidden_dims)):
        fan_in = hidden_dims[idx - 1]
        if idx in COORD_INJECTION_LAYERS:
            fan_in += coord_dim
        if idx in MODALITY_INJECTION_LAYERS:
            fan_in += modality_dim
        layer_shapes.append((fan_in, hidden_dims[idx]))
    layer_shapes.append((hidden_dims[-1], out_dim))

    layers = []
    for fan_in, fan_out in layer_shapes:
        key, layer_key = jax.random.split(key)
        layers.append({'W': glorot(layer_key, (fan_in, fan_out)), 'b': jnp.zeros((fan_out,))})

    return key, layers

def apply_coord_injection_mlp(params, coords, modalities, fourier_features, training=True, dropout_key=None, dropout_rate=0.3):
    """Forward pass of the coordinate-injection MLP; returns unnormalized logits.

    ReLU follows every layer except the last. Inverted dropout is applied after
    each ReLU only when ``dropout_key`` is given. ``training`` is applied via
    ``jnp.where``, so it may also be a traced boolean under JIT. The same
    injection layers as ``init_coord_injection_mlp`` receive coords/modalities.
    """
    keep_prob = 1.0 - dropout_rate

    def maybe_dropout(activations, key):
        # Returns (activations, advanced key); a no-op when no key is supplied.
        if key is None:
            return activations, None
        key, mask_key = jax.random.split(key)
        keep = jax.random.bernoulli(mask_key, keep_prob, activations.shape)
        dropped = activations * keep / (keep_prob + 1e-10)
        return jnp.where(training, dropped, activations), key

    first = params[0]
    stem_input = jnp.concatenate([fourier_features, modalities], axis=-1)
    h = jax.nn.relu(jnp.dot(stem_input, first['W']) + first['b'])
    h, dropout_key = maybe_dropout(h, dropout_key)

    for idx, layer in enumerate(params[1:-1], start=1):
        extras = []
        if idx in COORD_INJECTION_LAYERS:
            extras.append(coords)
        if idx in MODALITY_INJECTION_LAYERS:
            extras.append(modalities)
        if extras:
            h = jnp.concatenate([h, *extras], axis=-1)
        h = jax.nn.relu(jnp.dot(h, layer['W']) + layer['b'])
        h, dropout_key = maybe_dropout(h, dropout_key)

    head = params[-1]
    return jnp.dot(h, head['W']) + head['b']

# Initialize model: Fourier projection first, then the MLP, threading jax_key through both.
jax_key, fourier_params = init_anisotropic_fourier_features(
    jax_key,
    FOURIER_DIM,
    input_dim=3,
    voxel_spacing=(1.0, 1.0, 1.0),  # BraTS is pre-resampled to 1mm isotropic
)
jax_key, mlp_params = init_coord_injection_mlp(
    jax_key,
    coord_dim=3,
    modality_dim=M,
    fourier_dim=FOURIER_DIM,
    hidden_dims=HIDDEN_DIMS,
    out_dim=NUM_CLASSES,
    dropout_rate=DROPOUT_RATE,
)

# Single parameter pytree: 'fourier' holds B, 'mlp' holds the layer list.
params = dict(fourier=fourier_params, mlp=mlp_params)

# Parameter bookkeeping (B is learnable, so it counts toward the total).
n_params_fourier = fourier_params['B'].size
n_params_mlp = sum(layer[part].size for layer in mlp_params for part in ('W', 'b'))
total_params = n_params_fourier + n_params_mlp

print(f'Total parameters: {total_params:,}')
print(f'  Fourier (learnable): {n_params_fourier:,}')
print(f'  MLP: {n_params_mlp:,}')

wandb.config.update({"total_parameters": total_params})


In [None]:
# =============================================================================
# Research-Backed Loss Functions
# =============================================================================

@jax.jit
def unified_focal_loss(logits, labels, gamma=UNIFIED_FOCAL_GAMMA, delta=UNIFIED_FOCAL_DELTA, lambda_param=UNIFIED_FOCAL_LAMBDA):
    """
    Unified Focal Loss from research literature.
    Combines a modified Focal Tversky Loss and a modified Focal Loss.

        UFL = λ·mFTL + (1-λ)·mFL

    mFTL = mean_c (1 - TI_c)^(1/γ), with the Tversky index
        TI_c = TP_c / (TP_c + (1-δ)·FP_c + δ·FN_c)
    computed from soft (probability) counts over the batch. A δ above 0.5
    penalizes false negatives more than false positives, i.e. favors recall.

    mFL = mean over voxels of (1 - p_t)^γ · CE, where p_t is the probability
    of the true class.

    Args:
        logits: (N, C) raw scores; clipped to [-10, 10] before the softmax.
        labels: (N,) integer class ids.
        gamma, delta, lambda_param: γ, δ, λ above.

    Returns:
        ``(ufl, {'mFTL': ..., 'mFL': ...})``; ``ufl`` is clipped to [0, 1e6].
    """
    num_classes = logits.shape[-1]
    # Clip logits for numerical stability
    probs = jax.nn.softmax(jnp.clip(logits, -10.0, 10.0), axis=-1)
    probs = jnp.clip(probs, 1e-7, 1.0 - 1e-7)  # Prevent log(0)
    targets = jax.nn.one_hot(labels, num_classes)

    # Per-class soft confusion counts (axis 0 = voxels), vectorized over classes.
    tp = (probs * targets).sum(axis=0)
    fp = (probs * (1.0 - targets)).sum(axis=0)
    fn = ((1.0 - probs) * targets).sum(axis=0)

    # δ sets the FN/FP trade-off, so UNIFIED_FOCAL_DELTA actually takes effect.
    fn_weight = delta
    fp_weight = 1.0 - delta
    tversky_index = tp / (tp + fp_weight * fp + fn_weight * fn + 1e-7)
    tversky_index = jnp.clip(tversky_index, 0.0, 1.0)
    # NOTE(review): the exponent 1/γ follows the original Focal Tversky loss;
    # the Unified Focal Loss paper uses (1-γ) - confirm which is intended.
    mFTL = jnp.power(1.0 - tversky_index + 1e-7, 1.0 / gamma).mean()

    # Modified Focal Loss (mFL)
    ce = -jnp.sum(targets * jnp.log(probs + 1e-10), axis=-1)
    p_t = jnp.clip(jnp.sum(probs * targets, axis=-1), 1e-7, 1.0 - 1e-7)
    focal_weight = jnp.power(1.0 - p_t + 1e-7, gamma)
    mFL = (focal_weight * ce).mean()

    ufl = lambda_param * mFTL + (1 - lambda_param) * mFL
    ufl = jnp.clip(ufl, 0.0, 1e6)  # Prevent explosion

    return ufl, {'mFTL': mFTL, 'mFL': mFL}

@jax.jit
def dice_loss(logits, labels, class_weights=None):
    """Soft Dice loss with optional per-class weights.

    Each class gets a smoothed soft Dice score, (2·Σp·t + 1) / (Σp + Σt + 1),
    computed from clipped-softmax probabilities over the batch.

    Returns:
        ``(loss, dice_per_class)``, where
        ``loss = clip(1 - Σ w_c·dice_c / (Σ w_c + 1e-10), 0, 10)``.
        Uniform weights are used when ``class_weights`` is None.
    """
    weights = jnp.ones(NUM_CLASSES) if class_weights is None else jnp.array(class_weights)

    probs = jnp.clip(
        jax.nn.softmax(jnp.clip(logits, -10.0, 10.0), axis=-1),
        1e-7,
        1.0 - 1e-7,
    )
    onehot = jax.nn.one_hot(labels, NUM_CLASSES)

    def soft_dice(c):
        p, t = probs[:, c], onehot[:, c]
        return (2.0 * (p * t).sum() + 1.0) / (p.sum() + t.sum() + 1.0)

    scores = jnp.stack([soft_dice(c) for c in range(NUM_CLASSES)])
    weighted = (scores * weights).sum() / (weights.sum() + 1e-10)
    return jnp.clip(1.0 - weighted, 0.0, 10.0), scores

def compute_tv_loss(logits, coords_grid=None):
    """
    Batch-variance regularizer used in place of a true total-variation term.

    Returns the sum over classes of the variance, across the batch (axis 0),
    of the softmax probabilities. ``coords_grid`` is accepted for API
    compatibility but ignored; no spatial neighbours are used.

    NOTE(review): this does not measure spatial smoothness. It penalizes any
    spread of predicted class probabilities across the sampled batch, which
    pulls all points toward the same prediction. Confirm this is the intended
    regularizer before raising its weight.
    """
    class_probs = jax.nn.softmax(logits, axis=-1)
    per_class_variance = jnp.var(class_probs, axis=0)
    return per_class_variance.sum()

def loss_fn(params, coords, modalities, labels, boundary_weights, training=True, dropout_key=None, tv_weight=0.02):
    """Combined loss with all research-backed components, plus batch metrics.

        loss = 0.5·UFL + 0.5·weighted Dice loss + tv_weight·TV term
               + BOUNDARY_WEIGHT·boundary-weighted soft error

    Args:
        params: dict with 'fourier' (Fourier projection) and 'mlp' (layer list).
        coords: normalized coordinates, presumably (N, 3) in [-1, 1] as the sampler returns them.
        modalities: per-voxel intensities, presumably (N, M).
        labels: (N,) integer class ids.
        boundary_weights: (N,) per-voxel weights for the boundary term.
        training, dropout_key: forwarded to the MLP; dropout needs a key.
        tv_weight: weight of the ``compute_tv_loss`` term.

    Returns:
        ``(loss, aux)``, where ``aux`` holds the individual terms and metrics.
    """
    # Forward pass
    fourier_feats = apply_fourier_features(params['fourier'], coords)
    logits = apply_coord_injection_mlp(
        params['mlp'],
        coords,
        modalities,
        fourier_feats,
        training=training,
        dropout_key=dropout_key,
        dropout_rate=DROPOUT_RATE
    )

    # Unified Focal Loss
    ufl, ufl_aux = unified_focal_loss(logits, labels)

    # Dice Loss
    dice_loss_val, dice_per_class = dice_loss(logits, labels, CLASS_WEIGHTS)

    # TV regularization
    tv_loss_val = compute_tv_loss(logits, coords)

    # Boundary loss: weighted soft error 1 - p(true class). Unlike an
    # argmax-based error (whose gradient is zero everywhere), this term trains
    # the model, and it does not treat class ids as ordinal values.
    probs = jax.nn.softmax(logits, axis=-1)
    p_true = jnp.take_along_axis(probs, labels[:, None], axis=-1)[:, 0]
    boundary_loss_val = jnp.mean(boundary_weights * (1.0 - p_true))

    # Combined loss
    loss = (
        0.5 * ufl +
        0.5 * dice_loss_val +
        tv_weight * tv_loss_val +
        BOUNDARY_WEIGHT * boundary_loss_val
    )

    # Metrics
    preds = jnp.argmax(logits, axis=-1)
    tumor_dice = dice_per_class[1:].mean()

    aux = {
        'ufl': ufl,
        'mFTL': ufl_aux['mFTL'],
        'mFL': ufl_aux['mFL'],
        'dice_loss': dice_loss_val,
        'tv_loss': tv_loss_val,
        'boundary_loss': boundary_loss_val,
        'dice_per_class': dice_per_class,
        'dice_mean_tumor': tumor_dice,
        'accuracy': (preds == labels).mean(),
    }

    return loss, aux


In [None]:
# =============================================================================
# Uncertainty-Guided Sampling (Research-Backed)
# =============================================================================

def compute_uncertainty_mc_dropout(params, coords, modalities, fourier_params, n_samples=5, dropout_key=None):
    """
    Compute uncertainty via MC-Dropout.

    Runs ``n_samples`` stochastic forward passes of the MLP ``params``, each
    with its own dropout subkey, and averages their softmax outputs. Returns
    the entropy of that mean distribution, shape (N,). Without a
    ``dropout_key`` every pass is identical, so this reduces to the entropy of
    a single deterministic prediction.

    Raises:
        ValueError: if ``n_samples`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")

    # Fourier features depend only on coords, so compute them once, not once per MC pass.
    fourier_feats = apply_fourier_features(fourier_params, coords)

    all_probs = []
    for _ in range(n_samples):
        if dropout_key is not None:
            dropout_key, subkey = jax.random.split(dropout_key)
        else:
            subkey = None

        logits = apply_coord_injection_mlp(
            params,
            coords,
            modalities,
            fourier_feats,
            training=True,  # Keep dropout on
            dropout_key=subkey,
            dropout_rate=DROPOUT_RATE
        )
        all_probs.append(jax.nn.softmax(logits, axis=-1))

    # Average predictions: (n_samples, N, C) -> (N, C)
    mean_probs = jnp.stack(all_probs, axis=0).mean(axis=0)

    # Entropy as uncertainty
    return -jnp.sum(mean_probs * jnp.log(mean_probs + 1e-10), axis=-1)

def _np_rng_from_key(key):
    """Return a local NumPy Generator seeded from every word of a raw uint32 JAX key."""
    return np.random.default_rng([int(word) for word in np.asarray(key).ravel()])


def _draw_voxel_locations(rng, n, n_cases, vol_shape):
    """Draw ``n`` uniform (chunk-local case index, x, y, z) voxel locations."""
    case_ids = rng.integers(0, n_cases, size=n)
    xs = rng.integers(0, vol_shape[0], size=n)
    ys = rng.integers(0, vol_shape[1], size=n)
    zs = rng.integers(0, vol_shape[2], size=n)
    return case_ids, xs, ys, zs


def _lookup_labels(cache, case_ids, xs, ys, zs):
    """Vectorized label lookup: one fancy-index per distinct case."""
    labels = np.empty(len(case_ids), dtype=np.int32)
    for local in np.unique(case_ids):
        sel = np.flatnonzero(case_ids == local)
        labels[sel] = cache.cache[local]['seg'][xs[sel], ys[sel], zs[sel]]
    return labels


def _gather_voxels(cache, case_ids, xs, ys, zs):
    """Gather (raw coords, intensities, labels, boundary weights) for the given locations."""
    n = len(case_ids)
    coords = np.stack([xs, ys, zs], axis=1).astype(np.float32)
    intens = np.empty((n, cache.n_modalities), dtype=np.float32)
    labels = np.empty(n, dtype=np.int32)
    weights = np.empty(n, dtype=np.float32)
    for local in np.unique(case_ids):
        sel = np.flatnonzero(case_ids == local)
        x, y, z = xs[sel], ys[sel], zs[sel]
        intens[sel] = cache.cache[local]['mods'][:, x, y, z].T
        labels[sel] = cache.cache[local]['seg'][x, y, z]
        weights[sel] = cache.boundary_dists[local][x, y, z]
    return coords, intens, labels, weights


def _sample_class_locations(rng, cache, cls, n_needed, n_cases, vol_shape):
    """Rejection-sample ``n_needed`` voxel locations of label ``cls``.

    Draws at most ``100 * n_needed`` random voxels. If only some matches are
    found, they are reused cyclically to fill the quota. If none are found,
    uniform real voxels (with their true labels) are returned instead of
    fabricated samples.
    """
    if n_needed <= 0:
        empty = np.empty(0, dtype=np.int64)
        return empty, empty, empty, empty

    max_draws = n_needed * 100
    draw_size = max(1024, n_needed * 5)
    hits = []
    n_found = n_drawn = 0
    while n_found < n_needed and n_drawn < max_draws:
        k = min(draw_size, max_draws - n_drawn)
        locs = _draw_voxel_locations(rng, k, n_cases, vol_shape)
        keep = _lookup_labels(cache, *locs) == cls
        hits.append(tuple(a[keep] for a in locs))
        n_found += int(keep.sum())
        n_drawn += k

    case_ids, xs, ys, zs = (np.concatenate(col)[:n_needed] for col in zip(*hits))
    if len(case_ids) == 0:
        return _draw_voxel_locations(rng, n_needed, n_cases, vol_shape)
    if len(case_ids) < n_needed:
        reuse = np.resize(np.arange(len(case_ids)), n_needed)
        return case_ids[reuse], xs[reuse], ys[reuse], zs[reuse]
    return case_ids, xs, ys, zs


def sample_batch_hybrid(rng_key, batch_size, cache, params, fourier_params,
                        uncertainty_ratio=0.5, balanced_ratio=0.3, uniform_ratio=0.2,
                        use_uncertainty=True):
    """
    Hybrid sampling: uncertainty + balanced + uniform.
    Research suggests this helps prevent over-segmentation.

    Composition:
      * ``int(batch_size * uncertainty_ratio)`` points with the highest
        MC-dropout entropy among uniform candidates (0 if ``use_uncertainty``
        is false);
      * ``int(batch_size * balanced_ratio)`` points split as evenly as possible
        across the ``NUM_CLASSES`` classes (the remainder goes to the lowest
        class ids, so the batch is not shortened);
      * the rest uniformly at random. ``uniform_ratio`` is implied by the other
        two and is not read.

    All randomness comes from ``rng_key`` through a local NumPy Generator; the
    global ``np.random`` state is left untouched. ``fourier_params`` is accepted
    for compatibility, but ``params['fourier']`` is what the model uses.

    Returns:
        ``(norm_coords (N, 3) in [-1, 1], intensities (N, M), labels (N,),
        boundary_weights (N,))`` as JAX arrays, shuffled together.

    Raises:
        ValueError: if ``batch_size`` is not positive or the chunk is empty.
    """
    n_cases = len(cache.get_current_chunk_indices())
    vol_shape = tuple(cache.vol_shape)
    if batch_size <= 0 or n_cases == 0:
        raise ValueError(f"Need batch_size > 0 and a non-empty chunk (got {batch_size}, {n_cases} cases)")

    n_uncertainty = int(batch_size * uncertainty_ratio) if use_uncertainty else 0
    n_balanced = int(batch_size * balanced_ratio)
    n_uniform = batch_size - n_uncertainty - n_balanced

    rng_key, np_key = jax.random.split(rng_key)
    rng = _np_rng_from_key(np_key)
    extent = np.array(vol_shape, dtype=np.float32) - 1.0
    parts = []

    # 1. Uniform random sampling
    if n_uniform > 0:
        parts.append(_gather_voxels(cache, *_draw_voxel_locations(rng, n_uniform, n_cases, vol_shape)))

    # 2. Balanced class sampling
    if n_balanced > 0:
        base, extra = divmod(n_balanced, NUM_CLASSES)
        for cls in range(NUM_CLASSES):
            n_needed = base + (1 if cls < extra else 0)
            if n_needed > 0:
                locs = _sample_class_locations(rng, cache, cls, n_needed, n_cases, vol_shape)
                parts.append(_gather_voxels(cache, *locs))

    # 3. Uncertainty-guided sampling: keep the most uncertain uniform candidates
    if n_uncertainty > 0:
        n_candidates = max(n_uncertainty, min(10000, n_uncertainty * 10))
        candidates = _gather_voxels(cache, *_draw_voxel_locations(rng, n_candidates, n_cases, vol_shape))
        cand_norm = candidates[0] / extent * 2.0 - 1.0

        rng_key, dropout_key = jax.random.split(rng_key)
        uncertainty = compute_uncertainty_mc_dropout(
            params['mlp'],
            jnp.asarray(cand_norm),
            jnp.asarray(candidates[1]),
            params['fourier'],
            n_samples=5,
            dropout_key=dropout_key
        )
        top = np.argsort(np.asarray(uncertainty))[-n_uncertainty:]
        parts.append(tuple(a[top] for a in candidates))

    # Concatenate and shuffle
    coords, intens, labels, boundary_weights = (np.concatenate(col, axis=0) for col in zip(*parts))
    order = rng.permutation(len(coords))
    coords, intens = coords[order], intens[order]
    labels, boundary_weights = labels[order], boundary_weights[order]

    # Normalize coordinates to [-1, 1]
    norm_coords = coords / extent * 2.0 - 1.0

    return jnp.asarray(norm_coords), jnp.asarray(intens), jnp.asarray(labels), jnp.asarray(boundary_weights)


In [None]:
# =============================================================================
# Multi-Metric Evaluation (Research-Backed)
# =============================================================================

def predict_full_volume(params, case_data, chunk_size=50000):
    """Predict a label for every voxel of one case (no dropout).

    Coordinates are normalized to [-1, 1] per axis. Voxels are processed in
    chunks of ``chunk_size`` to bound memory use.

    Returns:
        ``(pred_vol, seg_true)``: an int16 prediction shaped like one modality
        volume, and ``case_data['seg']`` unchanged.
    """
    mods, seg_true = case_data['mods'], case_data['seg']
    n_mod, h, w, d = mods.shape

    axes = (np.arange(h), np.arange(w), np.arange(d))
    grid = np.stack(np.meshgrid(*axes, indexing='ij'), axis=-1).reshape(-1, 3)
    voxel_intens = mods.transpose(1, 2, 3, 0).reshape(-1, n_mod)
    norm_coords = (grid / np.array([h - 1, w - 1, d - 1])) * 2.0 - 1.0

    def predict_chunk(start):
        chunk_coords = jnp.array(norm_coords[start:start + chunk_size])
        chunk_intens = jnp.array(voxel_intens[start:start + chunk_size])
        features = apply_fourier_features(params['fourier'], chunk_coords)
        logits = apply_coord_injection_mlp(
            params['mlp'], chunk_coords, chunk_intens, features, training=False
        )
        return np.array(jnp.argmax(logits, axis=-1), dtype=np.int16)

    starts = range(0, len(grid), chunk_size)
    pred_flat = np.concatenate([predict_chunk(s) for s in starts], axis=0)
    return pred_flat.reshape(h, w, d), seg_true

def compute_multi_metrics(pred, true, num_classes=4):
    """Compute per-case segmentation metrics for integer label volumes.

    Metrics returned:
    - ``dice_class_{c}`` for every ``c`` in ``range(num_classes)``:
      ``2|P∩T| / (|P| + |T|)``. When class ``c`` is absent from both
      ``pred`` and ``true`` the prediction is correct, so Dice is 1.0.
      A correct empty prediction must not be penalised; this matters for
      sub-regions that many cases lack entirely.
    - ``n_components_class_{c}`` for ``c >= 1``: number of connected
      components of the predicted mask of class ``c``, using
      ``scipy.ndimage.label`` with its default structuring element.
    - ``slice_consistency``: mean Dice between consecutive slices along the
      last axis (``pred[:, :, z]`` vs ``pred[:, :, z-1]``). A voxel counts as
      overlapping only if both slices carry the same non-zero label. Slice
      pairs that are background-only are skipped; 0.0 if every pair is skipped.

    Args:
        pred: predicted label volume, indexed as (x, y, z).
        true: ground-truth label volume, same shape as ``pred``.
        num_classes: number of labels including background 0.

    Returns:
        dict mapping metric name to a float (Dice / consistency) or int
        (component counts).
    """
    metrics = {}

    # Dice per class
    for c in range(num_classes):
        pred_c = pred == c
        true_c = true == c
        total = int(pred_c.sum()) + int(true_c.sum())
        if total == 0:
            dice = 1.0  # class absent in both: perfect agreement
        else:
            dice = 2.0 * int((pred_c & true_c).sum()) / total
        metrics[f'dice_class_{c}'] = float(dice)

    # Connected components of the predicted tumor classes (label() returns 0 for empty masks)
    for c in range(1, num_classes):
        _, n_components = connected_components_3d(pred == c)
        metrics[f'n_components_class_{c}'] = int(n_components)

    # Slice-to-slice Dice consistency along the last axis
    fg = pred > 0
    slice_dices = []
    for z in range(1, pred.shape[2]):
        cur_fg, prev_fg = fg[:, :, z], fg[:, :, z - 1]
        size_sum = int(cur_fg.sum()) + int(prev_fg.sum())
        if size_sum == 0:
            continue  # both slices are background-only: nothing to compare
        agree = int(((pred[:, :, z] == pred[:, :, z - 1]) & cur_fg).sum())
        # Dice denominator is |A| + |B|; identical slices therefore score 1.0.
        slice_dices.append(2.0 * agree / size_sum)

    metrics['slice_consistency'] = float(np.mean(slice_dices)) if slice_dices else 0.0

    return metrics


In [None]:
# =============================================================================
# Optimizer with Two-Stage Schedule
# =============================================================================

def create_two_stage_schedule(stage1_steps, stage2_steps, lr_stage1, lr_stage2, warmup_steps, min_lr):
    """Build a warmup + two-stage piecewise-linear learning-rate schedule.

    For a given ``step`` the returned function yields:
    - ``[0, warmup_steps)``: linear ramp from 0 to ``lr_stage1``.
    - ``[warmup_steps, stage1_steps)``: linear decay ``lr_stage1`` -> ``lr_stage2``.
    - ``[stage1_steps, stage1_steps + stage2_steps]``: linear decay
      ``lr_stage2`` -> ``min_lr``.
    - after that: held at ``min_lr`` (progress is clamped, so the rate never
      extrapolates below ``min_lr``).

    ``stage2_steps <= 0`` (single-stage training) is supported: every step from
    ``stage1_steps`` on returns ``min_lr`` instead of dividing by zero.

    Args:
        stage1_steps: step at which stage 2 begins.
        stage2_steps: length of stage 2 in steps (may be 0).
        lr_stage1: peak learning rate reached at the end of warmup.
        lr_stage2: learning rate at the stage-1 / stage-2 boundary.
        warmup_steps: length of the linear warmup.
        min_lr: final learning rate.

    Returns:
        A callable ``schedule(step)``. It uses Python control flow, so it must be
        called with concrete (non-traced) step values.
    """
    def schedule(step):
        # Warmup
        if step < warmup_steps:
            return lr_stage1 * (step / warmup_steps)

        # Stage 1: higher LR for coarse learning
        if step < stage1_steps:
            progress = (step - warmup_steps) / (stage1_steps - warmup_steps)
            return lr_stage1 * (1.0 - progress) + lr_stage2 * progress

        # Stage 2: lower LR for refinement (empty stage 2 -> already at min_lr)
        if stage2_steps <= 0:
            return min_lr
        progress = min((step - stage1_steps) / stage2_steps, 1.0)
        return lr_stage2 * (1.0 - progress) + min_lr * progress

    return schedule

schedule_fn = create_two_stage_schedule(
    STAGE1_STEPS,
    STAGE2_STEPS,
    LR_STAGE1,
    LR_STAGE2,
    WARMUP_STEPS,
    MIN_LR,
)

# optax.chain applies transforms in order: gradients are clipped by global norm
# before AdamW (whose learning rate follows schedule_fn) consumes them.
optimizer = optax.chain(
    optax.clip_by_global_norm(CLIP_NORM),
    optax.adamw(learning_rate=schedule_fn, weight_decay=WEIGHT_DECAY),
)

opt_state = optimizer.init(params)
print("✓ Two-stage optimizer initialized")

# `training` is passed by keyword as a plain Python bool from the training step.
# Marking it static keeps it concrete inside loss_fn, so Python-level
# `if training:` branches work under jit. The cost is at most two compiled variants.
loss_and_grad = jax.jit(
    jax.value_and_grad(loss_fn, has_aux=True),
    static_argnames=("training",),
)


In [None]:
# =============================================================================
# Training Step with Stochastic Preconditioning
# =============================================================================

def add_coordinate_noise(coords, sigma, rng_key):
    """Jitter coordinates with isotropic Gaussian noise (stochastic preconditioning).

    Args:
        coords: coordinate array; noise of exactly this shape is drawn.
        sigma: standard deviation of the noise.
        rng_key: JAX PRNG key used for the draw.

    Returns:
        ``coords + sigma * N(0, 1)``, same shape as ``coords``.
    """
    gaussian = jax.random.normal(rng_key, shape=coords.shape)
    return coords + gaussian * sigma

def microbatch_step(params, opt_state, rng_key, cache, step, use_uncertainty=True):
    """Run one optimizer update built from ``ACCUM_STEPS`` accumulated micro-batches.

    For each micro-batch:
    1. Draw ``MICRO_BATCH_SIZE`` points with ``sample_batch_hybrid``.
    2. Add Gaussian noise to their coordinates.
    3. Evaluate ``loss_and_grad``.

    Losses, gradients and aux metrics are averaged over the micro-batches, and a
    single ``optimizer`` update is applied.

    Step-dependent settings:
    - TV weight and uncertainty ratio switch to their stage-2 values once
      ``step >= STAGE1_STEPS``.
    - Coordinate-noise sigma interpolates linearly from ``COORD_NOISE_SIGMA_INIT``
      (step 0) to ``COORD_NOISE_SIGMA_FINAL`` (step ``TOTAL_STEPS``).
    - Dropout is enabled only when ``step > WARMUP_STEPS``.

    Args:
        params: model parameter pytree (its ``'fourier'`` entry is also passed to
            the sampler).
        opt_state: state of the module-level ``optimizer``.
        rng_key: JAX PRNG key; split once per micro-batch.
        cache: data cache forwarded to ``sample_batch_hybrid``.
        step: current training step.
        use_uncertainty: forwarded to ``sample_batch_hybrid``.

    Returns:
        ``(params, opt_state, mean_loss, aux_mean)``, where ``mean_loss`` is a
        Python float.

    Raises:
        ValueError: if ``ACCUM_STEPS < 1``.
    """
    if ACCUM_STEPS < 1:
        raise ValueError(f"ACCUM_STEPS must be >= 1, got {ACCUM_STEPS}")

    # Determine TV weight based on stage
    tv_weight = TV_WEIGHT_STAGE1 if step < STAGE1_STEPS else TV_WEIGHT_STAGE2

    # Coordinate noise decays linearly over training
    progress = step / TOTAL_STEPS
    coord_sigma = COORD_NOISE_SIGMA_INIT * (1.0 - progress) + COORD_NOISE_SIGMA_FINAL * progress

    # Uncertainty ratio increases in stage 2
    uncertainty_ratio = UNCERTAINTY_RATIO if step < STAGE1_STEPS else 0.75

    # Only use dropout after warmup
    use_dropout = (step > WARMUP_STEPS)

    # Accumulators are seeded from the first micro-batch. The aux pytree
    # structure therefore comes straight from loss_and_grad, and no extra
    # un-jitted forward pass on dummy inputs is needed to build zero templates.
    grads_acc = None
    aux_acc = None
    loss_acc = 0.0

    key = rng_key
    for _ in range(ACCUM_STEPS):
        key, sample_key, noise_key, dropout_key = jax.random.split(key, 4)

        coords, feats, labels, boundary_weights = sample_batch_hybrid(
            sample_key,
            MICRO_BATCH_SIZE,
            cache,
            params,
            params['fourier'],
            uncertainty_ratio=uncertainty_ratio,
            balanced_ratio=BALANCED_RATIO,
            uniform_ratio=UNIFORM_RATIO,
            use_uncertainty=use_uncertainty
        )

        # Apply stochastic preconditioning
        coords = add_coordinate_noise(coords, coord_sigma, noise_key)

        (loss_val, aux), grads = loss_and_grad(
            params,
            coords,
            feats,
            labels,
            boundary_weights,
            training=use_dropout,
            dropout_key=dropout_key if use_dropout else None,
            tv_weight=tv_weight
        )

        loss_acc += float(loss_val)
        if grads_acc is None:
            grads_acc, aux_acc = grads, aux
        else:
            grads_acc = jax.tree.map(lambda acc, g: acc + g, grads_acc, grads)
            aux_acc = jax.tree.map(lambda acc, a: acc + a, aux_acc, aux)

    # Average gradients and apply a single update
    grads_mean = jax.tree.map(lambda g: g / ACCUM_STEPS, grads_acc)
    updates, opt_state = optimizer.update(grads_mean, opt_state, params)
    params = optax.apply_updates(params, updates)

    # Average aux metrics
    aux_mean = jax.tree.map(lambda a: a / ACCUM_STEPS, aux_acc)

    return params, opt_state, loss_acc / ACCUM_STEPS, aux_mean

# Warm-up: one microbatch_step at step=0 before the main training loop.
# The first call into the jitted loss_and_grad traces and compiles it for this
# argument configuration. Presumably this keeps compilation time out of the
# loop timing, since `start` is set only when the training loop cell runs.
# NOTE(review): this is a real update, not a dry run. params and opt_state are
# replaced by the results, so the optimizer state advances by one update.
# With step=0, dropout is off inside microbatch_step (0 > WARMUP_STEPS is False,
# assuming WARMUP_STEPS >= 0), and uncertainty-guided sampling is disabled explicitly.
print("Running warm-up...")
jax_key, warm_key = jax.random.split(jax_key)
params, opt_state, warm_loss, warm_aux = microbatch_step(
    params,
    opt_state,
    warm_key,
    train_cache,
    step=0,
    use_uncertainty=False  # No uncertainty in warm-up
)
print(f'Warm-up loss: {warm_loss:.4f}')


In [None]:
# =============================================================================
# Training Loop with Multi-Metric Evaluation
# =============================================================================

loss_history = []
dice_history = [[] for _ in range(NUM_CLASSES)]
dice_tumor_history = []
ufl_history = []

start = time.time()
mid_z = D // 2

CLASS_LABELS = {0: 'Background', 1: 'NCR/NET', 2: 'ED', 3: 'ET'}

# Cadences, in training steps
VIS_EVERY = 10          # slice visualisation in the notebook
VIS_WANDB_EVERY = 500   # prediction figure uploaded to W&B (checked only on vis steps)
LOG_EVERY = 1           # console progress line


def _vis_slices(cache):
    """Return (ground-truth slice, modality-0 slice) at ``mid_z`` for the vis case."""
    case = cache.cache[VIS_CASE_INDEX]
    return case['seg'][:, :, mid_z], case['mods'][0, :, :, mid_z]


def _predict_vis_slice(model_params, mods):
    """Predict the (H, W) label map of slice ``mid_z`` in one batched forward pass.

    ``mods`` is indexed as ``mods[modality, x, y, z]``. Rows are built x-major
    (row = x * W + y), so the final reshape to (H, W) lines up. Coordinates are
    normalised with the global (H, W, D), as in the training sampler.
    """
    xs, ys = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
    slice_coords = np.stack(
        [xs.ravel(), ys.ravel(), np.full(H * W, mid_z)], axis=-1
    ).astype(np.float32)
    slice_feats = mods[:, :H, :W, mid_z].transpose(1, 2, 0).reshape(H * W, mods.shape[0])
    norm_slice_coords = jnp.array((slice_coords / np.array([H - 1, W - 1, D - 1])) * 2.0 - 1.0)

    fourier_feats = apply_fourier_features(model_params['fourier'], norm_slice_coords)
    logits = apply_coord_injection_mlp(
        model_params['mlp'],
        norm_slice_coords,
        jnp.array(slice_feats),
        fourier_feats,
        training=False
    )
    return np.array(jnp.argmax(logits, axis=-1)).reshape(H, W)


def _flatten_params(model_params):
    """Flatten model params into the name -> ndarray mapping stored in .npz checkpoints."""
    flat = {'fourier_B': np.array(model_params['fourier']['B'])}
    for i, layer in enumerate(model_params['mlp']):
        flat[f'mlp_W_{i}'] = np.array(layer['W'])
        flat[f'mlp_b_{i}'] = np.array(layer['b'])
    return flat


# Visualization setup: prefer a validation case, fall back to training data
vis_cache = val_cache if val_cache else train_cache
VIS_CASE_INDEX = 0
true_slice, mod0_slice = _vis_slices(vis_cache)

print(f"\n{'='*80}")
print(f"Starting Two-Stage Training - {TOTAL_STEPS} steps")
print(f"Stage 1 (Coarse): Steps 1-{STAGE1_STEPS}")
print(f"Stage 2 (Refine): Steps {STAGE1_STEPS+1}-{TOTAL_STEPS}")
print(f"{'='*80}\n")

best_val_dice = 0.0

for step in range(1, TOTAL_STEPS + 1):
    # Rotate data chunks (step starts at 1, so step 0 never triggers a rotation)
    if step % STEPS_PER_CHUNK == 0:
        print(f"\n--- Rotating to next data chunk at step {step} ---")
        train_cache.next_chunk()
        if val_cache:
            val_cache.next_chunk()
        # vis_cache is one of the caches just rotated. Refreshing from it keeps the
        # ground truth and the prediction on the same case, even without a val split.
        true_slice, mod0_slice = _vis_slices(vis_cache)

    # Training step (uncertainty-guided sampling only after warmup)
    use_uncertainty = (step > WARMUP_STEPS)

    jax_key, step_key = jax.random.split(jax_key)
    params, opt_state, loss_val, aux = microbatch_step(
        params,
        opt_state,
        step_key,
        train_cache,
        step,
        use_uncertainty=use_uncertainty
    )

    # Track metrics
    loss_history.append(float(loss_val))
    ufl_history.append(float(aux['ufl']))
    dice_k = aux['dice_per_class']
    dice_tumor_history.append(float(aux['dice_mean_tumor']))

    for k in range(NUM_CLASSES):
        dice_history[k].append(float(dice_k[k]))

    # W&B logging
    current_lr = schedule_fn(step)
    stage = "Stage1_Coarse" if step <= STAGE1_STEPS else "Stage2_Refine"

    wandb_metrics = {
        "train/loss": float(loss_val),
        "train/ufl": float(aux['ufl']),
        "train/mFTL": float(aux['mFTL']),
        "train/mFL": float(aux['mFL']),
        "train/dice_loss": float(aux['dice_loss']),
        "train/tv_loss": float(aux['tv_loss']),
        "train/boundary_loss": float(aux['boundary_loss']),
        "train/dice_mean_tumor": float(aux['dice_mean_tumor']),
        "train/accuracy": float(aux['accuracy']),
        "train/lr": float(current_lr),
        "train/stage": 1 if step <= STAGE1_STEPS else 2,
    }

    for k in range(NUM_CLASSES):
        wandb_metrics[f"train/dice_class_{k}_{CLASS_LABELS[k]}"] = float(dice_k[k])

    wandb.log(wandb_metrics, step=step)

    # Visualization every VIS_EVERY steps (and on the first step)
    if step % VIS_EVERY == 0 or step == 1:
        pred_slice = _predict_vis_slice(params, vis_cache.cache[VIS_CASE_INDEX]['mods'])

        clear_output(wait=True)

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        # Loss curves
        ax0 = axes[0]
        ax0.plot(loss_history, 'b-', label='Total Loss', alpha=0.7, linewidth=1)
        ax0.plot(ufl_history, 'r--', label='UFL', alpha=0.5, linewidth=1)
        ax0.axvline(STAGE1_STEPS, color='green', linestyle='--', alpha=0.5, label='Stage 2 Start')
        ax0.set_title('Training Loss', fontweight='bold')
        ax0.set_xlabel('Step')
        ax0.set_ylabel('Loss')
        ax0.legend()
        ax0.grid(True, alpha=0.3)

        # Per-class Dice on a twin y-axis of the loss plot
        ax1 = axes[0].twinx()
        colors = plt.cm.tab10.colors
        for k in range(NUM_CLASSES):
            label = CLASS_LABELS.get(k, f'c{k}')
            ax1.plot(dice_history[k], label=f'{label}',
                     color=colors[k], linewidth=2, alpha=0.7)
        ax1.plot(dice_tumor_history, label='Tumor Avg',
                 color='black', linewidth=3, linestyle='--')
        ax1.set_ylabel('Dice Score')
        ax1.legend(loc='lower right')
        ax1.set_ylim([0, 1])

        # Ground truth
        axes[1].imshow(mod0_slice, cmap='gray')
        axes[1].imshow(true_slice, alpha=0.4, cmap='tab10', vmin=0, vmax=3)
        axes[1].set_title('Ground Truth', fontweight='bold')
        axes[1].axis('off')

        # Prediction
        axes[2].imshow(mod0_slice, cmap='gray')
        axes[2].imshow(pred_slice, alpha=0.4, cmap='tab10', vmin=0, vmax=3)
        axes[2].set_title(f'Prediction (Step {step}) - {stage}', fontweight='bold')
        axes[2].axis('off')

        dice_tumor = aux['dice_mean_tumor']
        fig.suptitle(
            f'Step {step}/{TOTAL_STEPS} | {stage} | Loss={loss_val:.4f} | '
            f'Dice(tumor)={dice_tumor:.3f} | LR={current_lr:.2e}',
            fontsize=14, fontweight='bold'
        )
        plt.tight_layout()
        display(fig)

        if step % VIS_WANDB_EVERY == 0 or step == 1:
            wandb.log({"train/predictions": wandb.Image(fig)}, step=step)

        plt.close(fig)

    # Full validation evaluation every VAL_EVAL_STEPS and at the final step
    if val_cache and (step % VAL_EVAL_STEPS == 0 or step == TOTAL_STEPS):
        print(f"\n--- Running full validation at step {step} ---")

        n_val_cases = min(5, len(val_cache.cache))
        val_metrics_all = []

        for val_idx in range(n_val_cases):
            pred_vol, true_vol = predict_full_volume(params, val_cache.cache[val_idx], chunk_size=80000)
            metrics = compute_multi_metrics(pred_vol, true_vol, NUM_CLASSES)
            val_metrics_all.append(metrics)

        # Aggregate: mean Dice over validation cases and tumor classes 1..NUM_CLASSES-1
        val_dice_tumor = float(np.mean([
            [m[f'dice_class_{k}'] for k in range(1, NUM_CLASSES)]
            for m in val_metrics_all
        ]))

        print(f"Validation Dice (tumor avg): {val_dice_tumor:.4f}")

        # Log to W&B
        val_wandb = {
            "val/dice_tumor_avg": val_dice_tumor,
        }

        for k in range(NUM_CLASSES):
            val_dice_k = np.mean([m[f'dice_class_{k}'] for m in val_metrics_all])
            val_wandb[f"val/dice_class_{k}"] = val_dice_k
            print(f"  Class {k} ({CLASS_LABELS[k]}): {val_dice_k:.4f}")

        for k in range(1, NUM_CLASSES):
            n_comp = np.mean([m[f'n_components_class_{k}'] for m in val_metrics_all])
            val_wandb[f"val/n_components_class_{k}"] = n_comp
            print(f"  Components class {k}: {n_comp:.1f}")

        slice_cons = np.mean([m['slice_consistency'] for m in val_metrics_all])
        val_wandb["val/slice_consistency"] = slice_cons
        print(f"  Slice consistency: {slice_cons:.4f}")

        wandb.log(val_wandb, step=step)

        # Save best model
        if val_dice_tumor > best_val_dice:
            best_val_dice = val_dice_tumor
            print(f"  ✓ New best validation Dice: {best_val_dice:.4f}")

            SAVE_PATH.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(
                SAVE_PATH / f"best_model_dice{best_val_dice:.4f}.npz",
                **_flatten_params(params),
            )

            wandb.run.summary["best_val_dice_tumor"] = best_val_dice

    # Console progress
    if step % LOG_EVERY == 0 or step == 1:
        elapsed = time.time() - start
        steps_per_sec = step / elapsed
        eta_min = (TOTAL_STEPS - step) / steps_per_sec / 60

        print(f"{stage} | Step {step}/{TOTAL_STEPS} | "
              f"Loss={loss_val:.4f} | "
              f"Dice(tumor)={aux['dice_mean_tumor']:.3f} | "
              f"LR={current_lr:.2e} | "
              f"{steps_per_sec:.2f} steps/s | "
              f"ETA: {eta_min:.1f}min")

training_time = time.time() - start
print(f'\n✓ Training complete: {training_time/60:.2f} minutes')
wandb.run.summary["training_time_minutes"] = training_time / 60
if dice_tumor_history:  # empty only when TOTAL_STEPS == 0
    wandb.run.summary["final_dice_tumor"] = dice_tumor_history[-1]


In [None]:
# =============================================================================
# Save Final Model
# =============================================================================
# Flatten the parameter pytree into name -> ndarray pairs for np.savez_compressed
flat_params = {'fourier_B': np.array(params['fourier']['B'])}
for i, layer in enumerate(params['mlp']):
    flat_params.update({
        f'mlp_W_{i}': np.array(layer['W']),
        f'mlp_b_{i}': np.array(layer['b']),
    })

SAVE_PATH.mkdir(parents=True, exist_ok=True)
np.savez_compressed(SAVE_PATH / "model_final.npz", **flat_params)
print(f'\n✓ Saved final model to {SAVE_PATH}')

# Close the W&B run, then print the closing summary banner
wandb.finish()
print(f"\n{'='*80}")
print("Training Complete!")
print(f"Best Validation Dice (Tumor): {best_val_dice:.4f}")
print(f"{'='*80}")