This pipeline includes:

Multi-seed runs (≥3, ideally 5)

You're running 5 seeds: SEEDS = [42, 1337, 2026, 7, 99]

You also export:

per-seed metrics

embc_paper_master_summary_all_seeds.csv

across-seed aggregated embc_summary_across_seeds.csv (mean/std across seeds)

ASSD + HD95

Added as boundary-distance metrics: ASSD_px, HD95_px

Computed from boundary pixel sets using distance transforms (standard approach)

Weighted boundary loss

Computes a pos_weight from training boundary targets (estimate_pos_weight)

Uses weighted cross entropy (logits) + Dice via boundary_loss_factory(pos_w)

This is exactly what you need to prevent the “predict-everything / predict-nothing” instability in sparse boundaries.

Threshold tuned to tolerant-F1

find_best_global_threshold() now tunes threshold based on mean tolerant F1 across [1,2,3,5] px, not Dice.

Same for UNet threshold tuning.

In [None]:


!pip install -q nibabel scikit-image pandas scipy

import os, zipfile, glob, random, math, shutil, json
import numpy as np
import tensorflow as tf
import nibabel as nib
import matplotlib.pyplot as plt
import pandas as pd
from google.colab import files
from scipy.ndimage import distance_transform_edt, binary_erosion

# Target (height, width): every image/mask is resized to this and model inputs use it.
IMG_SIZE = (256, 256)
# Batch size applied in make_tf_dataset and used by estimate_steps.
BATCH_SIZE = 8
# Learning rate handed to every Adam optimizer built in this file.
LR = 1e-3

# multi-seed
# NOTE(review): presumably one complete train/eval run per seed - confirm in the driver loop.
SEEDS = [42, 1337, 2026, 7, 99]

# Epoch budgets per training phase.
# NOTE(review): consumed outside the visible part of the file - confirm usage.
PRETRAIN_SOBEL_EPOCHS = 2
FINETUNE_BOUNDARY_EPOCHS = 8
FINETUNE_UNET_EPOCHS = 8

# Switches for the optional baselines (presumably checked by the driver loop).
RUN_SMALL_CNN_BASELINE   = True
RUN_SOBELFRONT_CNN       = True
RUN_TINY_UNET_SEG_BASE   = True  # trains on mask; evaluated as boundary

# Number of z-slices sampled per NIfTI volume (evenly spaced between 30% and
# 70% of the depth, duplicates removed) by build_index_for_dataset.
MAX_SLICES_PER_VOLUME = 6

# Group-level split fractions (same values as split_by_group's defaults).
TRAIN_FRAC = 0.70
VAL_FRAC   = 0.15
# remainder = TEST_LABELED

# Eval + figs
EVAL_LIMIT = None    # e.g., 200 for fast debug
VIS_K = 3
# Candidate binarization thresholds (presumably passed to the find_best_* tuners).
THRESHOLDS = [0.05,0.1,0.15,0.2,0.25,0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9]
# Pixel tolerances averaged by the threshold tuners when scoring tolerant F1.
TOL_PX_LIST = [1,2,3,5]

# Boundary target / imbalance handling
# When True, load_sample returns None for slices whose mask is entirely zero.
SKIP_EMPTY_MASK_SLICES = True
# mask_boundary_target applies (BOUNDARY_THICKNESS - 1) extra 3x3 dilations.
BOUNDARY_THICKNESS = 2

# File formats
# Extensions (lower-case, with dot) that is_2d_image accepts.
VALID_2D_EXTS = {".png",".jpg",".jpeg",".bmp",".tif",".tiff"}
# Substrings that mark a file as a mask when found in its lower-cased basename.
MASK_TOKENS = ["_mask", "mask", "seg", "label", "_gt", "groundtruth"]

# Drive / BraTS
USE_GOOGLE_DRIVE = True
# Zip archive expected on the mounted Drive, and the local directory it is extracted to.
BRATS_ZIP_PATH = "/content/drive/MyDrive/Brats Dataset.zip"
BRATS_EXTRACT_DIR = "/content/brats_extracted"
# NOTE(review): left as None here; presumably resolved after extraction - confirm.
BRATS_TRAIN_DIR = None
BRATS_VAL_DIR   = None
# Optional extra datasets as {name: root_dir}; empty by default.
EXTRA_DATASET_DIRS = {
  # "BUSI": "/content/drive/MyDrive/BUSI",
}

# Reproducibility helpers
def set_all_seeds(seed: int):
    """Seed the Python, NumPy and TensorFlow RNGs and request deterministic ops.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and
            ``tf.keras.utils.set_random_seed``.

    Enabling op determinism is best-effort: if the call fails, the failure is
    reported and seeding still takes effect.
    """
    random.seed(seed)
    np.random.seed(seed)
    tf.keras.utils.set_random_seed(seed)
    try:
        tf.config.experimental.enable_op_determinism()
    except Exception as err:
        # Do not abort the run, but make the degraded reproducibility visible
        # instead of silently ignoring it (and never swallow KeyboardInterrupt).
        print(f"[set_all_seeds] op determinism not enabled: {err}")

# Global matplotlib defaults: automatic figure layout and a gray colormap for images.
plt.rc('figure', autolayout=True)
plt.rc('image', cmap='gray')

# Log the TensorFlow version and the working image size for the run record.
print("TensorFlow:", tf.__version__)
print("IMG_SIZE:", IMG_SIZE)
# detection + pairing
def is_nii(path):
    """Return True when *path* ends in ``.nii`` or ``.nii.gz`` (case-insensitive)."""
    return path.lower().endswith((".nii", ".nii.gz"))

def ext(path):
    """Return the last extension of *path* in lower case, dot included ('' if none)."""
    _, suffix = os.path.splitext(path)
    return suffix.lower()

def is_2d_image(path):
    """Return True when the extension of *path* is one of VALID_2D_EXTS."""
    suffix = ext(path)
    return suffix in VALID_2D_EXTS

def is_mask_file(path):
    """Return True when the lower-cased basename contains any MASK_TOKENS entry."""
    filename = os.path.basename(path).lower()
    for token in MASK_TOKENS:
        if token in filename:
            return True
    return False

def strip_nii_ext(name):
    """Lower-case *name* and drop its extension.

    ``.nii.gz`` and ``.nii`` are removed as a whole; any other name loses only
    its final extension.
    """
    lowered = name.lower()
    for suffix in (".nii.gz", ".nii"):
        if lowered.endswith(suffix):
            return lowered[:-len(suffix)]
    stem, _ = os.path.splitext(lowered)
    return stem

def norm_key_2d(path):
    """Build the pairing key shared by a 2D image and its mask.

    The key is the lower-cased file stem with every MASK_TOKENS substring
    removed, '-' and ' ' turned into '_', one pass of '__' -> '_', and
    leading/trailing underscores stripped.
    """
    filename = os.path.basename(path).lower()
    key, _ = os.path.splitext(filename)
    for token in MASK_TOKENS:
        key = key.replace(token, "")
    for separator in ("-", " "):
        key = key.replace(separator, "_")
    return key.replace("__", "_").strip("_")

# Trailing "_<suffix>" name parts (modalities and mask markers) dropped from NIfTI keys.
MOD_SUFFIX = {"flair","t1","t1ce","t2","seg","mask","label","gt","groundtruth"}
def norm_key_nii(path):
    """Return the volume key of a NIfTI file.

    The key is the lower-cased stem with its last '_'-separated part removed
    when that part is listed in MOD_SUFFIX and is not the only part.
    """
    stem = strip_nii_ext(os.path.basename(path))
    *head, tail = stem.split("_")
    if head and tail in MOD_SUFFIX:
        return "_".join(head)
    return stem

def group_id(desc):
    """Map a sample descriptor to the group key used for leakage-free splits.

    "nii" descriptors are keyed by norm_key_nii, "2d" descriptors by
    norm_key_2d; any other kind falls into the single group "unknown".
    """
    # patient-level grouping for BraTS (nii) and best-effort grouping for 2D
    # (if BUSI folders store per-patient, the parent folder name could be used
    #  instead: os.path.basename(os.path.dirname(desc[1])))
    keyers = {"nii": norm_key_nii, "2d": norm_key_2d}
    keyer = keyers.get(desc[0])
    if keyer is None:
        return "unknown"
    return keyer(desc[1])

def safe_read_2d_as_tensor(path):
    """Read a 2D image file as a single-channel float32 tensor scaled to [0, 1].

    Args:
        path: Path of an image file decodable by ``tf.io.decode_image``.

    Returns:
        A tensor of shape ``(IMG_SIZE[0], IMG_SIZE[1], 1)`` with values in
        [0, 1], or ``None`` when the file cannot be read or decoded.
    """
    try:
        raw = tf.io.read_file(path)
        img = tf.io.decode_image(raw, channels=1, expand_animations=False)
        # Scale to [0, 1] BEFORE resizing: tf.image.resize already returns
        # float32, and convert_image_dtype does not rescale a float input, so
        # converting after the resize would leave the values in [0, 255].
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, IMG_SIZE)
        if img.shape[-1] != 1:
            return None
        return img
    except Exception:
        # Best-effort loader: unreadable/corrupt files are skipped by callers.
        return None

def resize_mask_nearest(mask01):
    """Resize *mask01* to IMG_SIZE using nearest-neighbour sampling."""
    resized = tf.image.resize(mask01, size=IMG_SIZE, method="nearest")
    return resized

def scan_files(root):
    """Recursively list the files under *root*.

    Returns:
        (all_files, img2d, nii): every regular file, the subset accepted by
        is_2d_image, and the subset accepted by is_nii, all in glob order.
    """
    pattern = os.path.join(root, "**/*")
    all_files, img2d, nii = [], [], []
    for candidate in glob.glob(pattern, recursive=True):
        if not os.path.isfile(candidate):
            continue
        all_files.append(candidate)
        if is_2d_image(candidate):
            img2d.append(candidate)
        if is_nii(candidate):
            nii.append(candidate)
    return all_files, img2d, nii

def build_index_for_dataset(root):
    """Index every usable sample found under *root*.

    samples:
      ("2d", img_path, mask_path_or_None)
      ("nii", img_vol_path, seg_vol_path_or_None, z_index)

    2D images are paired with the first mask sharing their norm_key_2d key.
    NIfTI volumes are paired by norm_key_nii; if any file name contains
    "flair" only those volumes are used as inputs, otherwise all non-mask
    volumes are. For each input volume up to MAX_SLICES_PER_VOLUME distinct
    z-indices between 30% and 70% of the depth are emitted.

    Returns:
        (samples, has_masks): the descriptor list and whether at least one
        input was paired with a mask/segmentation.
    """
    _, img2d_all, nii_all = scan_files(root)
    samples, has_masks = [], False

    # 2D pairing
    if img2d_all:
        imgs = []
        mask_map = {}
        for p in img2d_all:
            if is_mask_file(p):
                mask_map.setdefault(norm_key_2d(p), []).append(p)
            else:
                imgs.append(p)

        for im in imgs:
            candidates = mask_map.get(norm_key_2d(im))
            mpath = candidates[0] if candidates else None
            if mpath is not None:
                has_masks = True
            samples.append(("2d", im, mpath))

    # NIfTI pairing (BraTS-like)
    if nii_all:
        nii_imgs  = [p for p in nii_all if not is_mask_file(p)]
        nii_masks = [p for p in nii_all if is_mask_file(p)]
        mask_map = {norm_key_nii(m): m for m in nii_masks}

        flair = [f for f in nii_imgs if "flair" in os.path.basename(f).lower()]
        use_inputs = flair if flair else nii_imgs

        for vf in use_inputs:
            seg = mask_map.get(norm_key_nii(vf))
            if seg is not None:
                has_masks = True

            # Only the header read can fail on a bad file; keep the try minimal
            # and do not swallow KeyboardInterrupt/SystemExit.
            try:
                shp = nib.load(vf).shape
            except Exception as err:
                print(f"[build_index_for_dataset] skipping unreadable volume: {vf} ({err})")
                continue

            if len(shp) < 3:
                continue
            zdim = shp[2]
            if zdim < 1:
                # An empty depth axis would otherwise produce the index -1.
                continue
            zs = np.linspace(zdim*0.3, zdim*0.7, MAX_SLICES_PER_VOLUME).astype(int)
            zs = np.unique(np.clip(zs, 0, zdim-1))
            samples.extend(("nii", vf, seg, int(z)) for z in zs)

    return samples, has_masks

# ----------------------------
# TARGETS
# ----------------------------
def sobel_target(img01):
    """Return the Sobel gradient magnitude of *img01*, divided by its maximum.

    The input is converted to float32 and a batch axis is added for
    tf.image.sobel_edges; the result has the same shape as the input.
    """
    image = tf.convert_to_tensor(img01, dtype=tf.float32)
    edges = tf.image.sobel_edges(image[None])[0]  # (H,W,1,2)
    grad_a = edges[..., 0]
    grad_b = edges[..., 1]
    magnitude = tf.sqrt(grad_a**2 + grad_b**2)
    peak = tf.reduce_max(magnitude)
    return magnitude / (peak + 1e-8)

def mask_boundary_target(mask01):
    """
    Boundary band of a mask: 3x3 dilation minus 3x3 erosion of the mask
    binarized at 0.5, then dilated (BOUNDARY_THICKNESS - 1) more times.
    """
    def _dilate3(t):
        # 3x3 max-pool with SAME padding acts as a binary dilation.
        return tf.nn.max_pool2d(t[None], ksize=3, strides=1, padding="SAME")[0]

    binary = tf.cast(mask01 > 0.5, tf.float32)
    dilated = _dilate3(binary)
    # Erosion expressed as the complement of a dilated complement.
    eroded = 1.0 - _dilate3(1.0 - binary)
    edge = tf.clip_by_value(dilated - eroded, 0.0, 1.0)

    extra_passes = BOUNDARY_THICKNESS - 1
    if extra_passes > 0:
        for _ in range(extra_passes):
            edge = _dilate3(edge)
        edge = tf.clip_by_value(edge, 0.0, 1.0)

    return edge

# FIXED SOBEL baseline
def fixed_sobel_conv(img01):
    """Non-learned Sobel baseline: 3x3 Sobel x/y convolutions -> magnitude / max."""
    batch = tf.convert_to_tensor(img01, dtype=tf.float32)[None]

    kernel_x = tf.reshape(
        tf.constant([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=tf.float32), (3,3,1,1))
    kernel_y = tf.reshape(
        tf.constant([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=tf.float32), (3,3,1,1))

    def _filtered(kernel):
        return tf.nn.conv2d(batch, kernel, strides=1, padding="SAME")[0]

    gx = _filtered(kernel_x)
    gy = _filtered(kernel_y)
    magnitude = tf.sqrt(gx**2 + gy**2)
    peak = tf.reduce_max(magnitude)
    return magnitude / (peak + 1e-8)

def _load_nii_slice(vol_path, z):
    """Read axial slice *z* (first frame of a 4D volume) as a float32 array."""
    vol = nib.load(vol_path)
    proxy = vol.dataobj
    plane = proxy[:, :, z] if len(vol.shape) == 3 else proxy[:, :, z, 0]
    return np.asanyarray(plane).astype(np.float32)


def load_sample(desc, target_mode="auto"):
    """
    Load one (input, target) pair for a sample descriptor.

    target_mode:
      - "sobel":    target = sobel(input)
      - "boundary": target = boundary(mask) (requires mask)
      - "mask":     target = mask (requires mask)
      - "auto":     boundary if mask exists else sobel

    Args:
        desc: ("2d", img_path, mask_path_or_None) or
              ("nii", img_vol_path, seg_vol_path_or_None, z_index).
        target_mode: One of the modes listed above.

    Returns:
        (img01, target) tensors of shape (IMG_SIZE[0], IMG_SIZE[1], 1), or
        None when the sample cannot be loaded, the mode needs a mask that is
        missing, the mask is empty and SKIP_EMPTY_MASK_SLICES is set, or the
        descriptor kind / mode is not recognised.
    """
    kind = desc[0]

    if kind == "2d":
        img_path, mask_path = desc[1], desc[2]
        img01 = safe_read_2d_as_tensor(img_path)
        if img01 is None:
            return None

        use_mask = mask_path is not None
        mode = target_mode if target_mode != "auto" else ("boundary" if use_mask else "sobel")

        if mode == "sobel":
            return img01, sobel_target(img01)

        if mode in ("boundary", "mask"):
            if not use_mask:
                return None
            m = safe_read_2d_as_tensor(mask_path)
            if m is None:
                return None
            m = resize_mask_nearest(m)
            m = tf.cast(m > 0.5, tf.float32)
            if SKIP_EMPTY_MASK_SLICES and tf.reduce_max(m).numpy() == 0.0:
                return None
            if mode == "mask":
                return img01, m
            return img01, mask_boundary_target(m)

        return None

    if kind == "nii":
        img_vol, seg_vol, z = desc[1], desc[2], desc[3]
        try:
            sl = _load_nii_slice(img_vol, z)
            # Per-slice min-max normalisation to [0, 1].
            mn, mx = float(np.min(sl)), float(np.max(sl))
            sl = (sl - mn) / (mx - mn + 1e-8)

            img01 = tf.convert_to_tensor(sl[..., None], dtype=tf.float32)
            img01 = tf.image.resize(img01, IMG_SIZE)

            use_mask = seg_vol is not None
            mode = target_mode if target_mode != "auto" else ("boundary" if use_mask else "sobel")

            if mode == "sobel":
                return img01, sobel_target(img01)

            if mode in ("boundary", "mask"):
                if not use_mask:
                    return None
                msl = _load_nii_slice(seg_vol, z)
                msl = (msl > 0).astype(np.float32)  # BraTS tumor mask
                if SKIP_EMPTY_MASK_SLICES and float(np.max(msl)) == 0.0:
                    return None

                m01 = tf.convert_to_tensor(msl[..., None], dtype=tf.float32)
                m01 = resize_mask_nearest(m01)
                if mode == "mask":
                    return img01, m01
                return img01, mask_boundary_target(m01)

        except Exception:
            # Best-effort: a broken volume/slice is skipped, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return None

    return None

def load_input_only(desc):
    """Load only the input image of a sample descriptor (no target).

    Args:
        desc: ("2d", img_path, ...) or ("nii", img_vol_path, seg, z_index).

    Returns:
        For "nii": the min-max normalised slice resized to IMG_SIZE with a
        trailing channel axis. For "2d": whatever safe_read_2d_as_tensor
        returns. None on load failure or for an unknown descriptor kind.
    """
    kind = desc[0]
    if kind == "2d":
        return safe_read_2d_as_tensor(desc[1])
    if kind == "nii":
        img_vol, z = desc[1], desc[3]
        try:
            img = nib.load(img_vol)
            data = img.dataobj
            # 4D volumes: take the first frame of the last axis.
            plane = data[:, :, z] if len(img.shape) == 3 else data[:, :, z, 0]
            sl = np.asanyarray(plane).astype(np.float32)
            mn, mx = float(np.min(sl)), float(np.max(sl))
            sl = (sl - mn) / (mx - mn + 1e-8)
            img01 = tf.convert_to_tensor(sl[..., None], dtype=tf.float32)
            return tf.image.resize(img01, IMG_SIZE)
        except Exception:
            # Best-effort: unreadable volumes are skipped, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            return None
    return None

def make_tf_dataset(samples, target_mode="auto", shuffle=True, repeat=False, seed=42):
    """Wrap *samples* in a batched, prefetched tf.data pipeline.

    Samples are loaded lazily through load_sample; descriptors for which it
    returns None are dropped. Optional shuffling uses a 256-element buffer
    reshuffled every iteration; optional repeat is applied last.
    """
    def gen():
        loaded = (load_sample(d, target_mode=target_mode) for d in samples)
        yield from (pair for pair in loaded if pair is not None)

    item_spec = tf.TensorSpec(shape=(IMG_SIZE[0], IMG_SIZE[1], 1), dtype=tf.float32)
    ds = tf.data.Dataset.from_generator(gen, output_signature=(item_spec, item_spec))

    if shuffle:
        ds = ds.shuffle(256, seed=seed, reshuffle_each_iteration=True)
    ds = ds.batch(BATCH_SIZE)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds.repeat() if repeat else ds

def estimate_steps(samples):
    """Return ceil(len(samples) / BATCH_SIZE), but never less than 1."""
    n_batches = math.ceil(len(samples) / BATCH_SIZE)
    return n_batches if n_batches > 1 else 1

def split_by_group(samples, train_frac=0.70, val_frac=0.15, seed=42):
    """Split samples into train/val/test so that no group spans two splits.

    Groups are formed with group_id, their ids are shuffled with a
    ``random.Random(seed)`` instance, and the first ``int(train_frac * n)``
    groups go to train, the next ``int(val_frac * n)`` to val, the rest to test.

    Args:
        samples: Sample descriptors accepted by group_id.
        train_frac: Fraction of groups assigned to the training split.
        val_frac: Fraction of groups assigned to the validation split.
        seed: Seed of the group shuffle.

    Returns:
        (train, val, test) lists of sample descriptors. The order inside each
        list follows the shuffled group order, so it is fully determined by
        *samples* and *seed*.
    """
    rng = random.Random(seed)
    groups = {}
    for s in samples:
        groups.setdefault(group_id(s), []).append(s)
    gids = list(groups)
    rng.shuffle(gids)

    n = len(gids)
    n_train = int(train_frac * n)
    n_val   = int(val_frac * n)

    def _collect(ids):
        # Iterate the shuffled list slice directly: going through a set of
        # strings would make the sample order depend on hash randomisation
        # (PYTHONHASHSEED) and therefore differ between interpreter runs.
        return [s for gid in ids for s in groups[gid]]

    train = _collect(gids[:n_train])
    val   = _collect(gids[n_train:n_train + n_val])
    test  = _collect(gids[n_train + n_val:])
    return train, val, test


# MODELS
def sobel_init_weights():
    """Return a (3, 3, 1, 2) float32 kernel: channel 0 = Sobel-x, channel 1 = Sobel-y."""
    sobel_x = np.array([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=np.float32)
    sobel_y = np.array([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=np.float32)
    # Stack on a new last axis (output filters), then insert the single
    # input-channel axis in front of it.
    stacked = np.stack([sobel_x, sobel_y], axis=-1)
    return stacked[:, :, None, :]

class LearnableSobelEdge(tf.keras.Model):
    """
    Minimal learnable edge/boundary model:
      bias-free 3x3 conv with 2 filters (Sobel-initialised) -> gradient magnitude
      -> optional per-image max normalisation -> alpha * mag + beta
      -> optional sigmoid (controlled by ``use_sigmoid``).
    """
    def __init__(self):
        super().__init__()
        sobel_kernels = tf.keras.initializers.Constant(sobel_init_weights())
        self.conv = tf.keras.layers.Conv2D(
            filters=2,
            kernel_size=3,
            padding="same",
            use_bias=False,
            kernel_initializer=sobel_kernels,
        )
        # Learnable scalar scale and offset applied to the magnitude map.
        self.alpha = self.add_weight(name="alpha", shape=(), initializer="ones", trainable=True, dtype=tf.float32)
        self.beta  = self.add_weight(name="beta",  shape=(), initializer="zeros", trainable=True, dtype=tf.float32)
        # Plain Python switches read by call().
        self.use_per_image_norm = True
        self.use_sigmoid = False

    def call(self, x, training=False):
        grads = self.conv(x)
        grad_x = grads[..., 0:1]
        grad_y = grads[..., 1:2]
        # Epsilon inside the sqrt keeps the gradient finite at zero magnitude.
        magnitude = tf.sqrt(grad_x*grad_x + grad_y*grad_y + 1e-8)

        if self.use_per_image_norm:
            per_image_peak = tf.reduce_max(tf.abs(magnitude), axis=[1,2,3], keepdims=True)
            magnitude = magnitude / (per_image_peak + 1e-8)

        scaled = self.alpha * magnitude + self.beta
        return tf.sigmoid(scaled) if self.use_sigmoid else scaled

def build_learnable_model(loss_obj):
    """Create a LearnableSobelEdge compiled with Adam(LR) and *loss_obj*."""
    model = LearnableSobelEdge()
    optimizer = tf.keras.optimizers.Adam(LR)
    model.compile(optimizer=optimizer, loss=loss_obj)
    return model

def build_small_cnn_boundary(loss_obj):
    """Three 16-filter 3x3 ReLU convs followed by a 1x1 sigmoid conv, compiled with Adam(LR)."""
    inputs = tf.keras.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 1))
    features = inputs
    for _ in range(3):
        features = tf.keras.layers.Conv2D(16, 3, padding="same", activation="relu")(features)
    outputs = tf.keras.layers.Conv2D(1, 1, padding="same", activation="sigmoid")(features)

    model = tf.keras.Model(inputs, outputs, name="SmallCNN_Boundary")
    model.compile(optimizer=tf.keras.optimizers.Adam(LR), loss=loss_obj)
    return model

def build_sobelfront_cnn(loss_obj):
    """
    Sobel-initialised conv front-end whose gx, gy and magnitude maps are
    concatenated with the raw input and fed to a small CNN head (sigmoid output).
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 1))

    sobel_conv = layers.Conv2D(
        filters=2,
        kernel_size=3,
        padding="same",
        use_bias=False,
        kernel_initializer=tf.keras.initializers.Constant(sobel_init_weights()),
        name="SobelInitConv",
    )
    sobel_out = sobel_conv(inputs)

    # Channel slicing and magnitude are wrapped in Lambda layers so they stay
    # part of the functional graph.
    gx = layers.Lambda(lambda t: t[..., 0:1], name="gx")(sobel_out)
    gy = layers.Lambda(lambda t: t[..., 1:2], name="gy")(sobel_out)
    mag = layers.Lambda(
        lambda pair: tf.sqrt(pair[0]*pair[0] + pair[1]*pair[1] + 1e-8),
        name="mag",
    )([gx, gy])

    features = layers.Concatenate(name="stack_feats")([inputs, mag, gx, gy])
    for _ in range(2):
        features = layers.Conv2D(16, 3, padding="same", activation="relu")(features)
    outputs = layers.Conv2D(1, 1, padding="same", activation="sigmoid")(features)

    model = tf.keras.Model(inputs, outputs, name="SobelFrontCNN")
    model.compile(optimizer=tf.keras.optimizers.Adam(LR), loss=loss_obj)
    return model

def conv_block(x, f):
    """Apply two consecutive 3x3 'same' ReLU convolutions with *f* filters to *x*."""
    for _ in range(2):
        x = tf.keras.layers.Conv2D(f, 3, padding="same", activation="relu")(x)
    return x

def build_tiny_unet_seg(loss_obj):
    """
    Two-level U-Net (16/32 filters, 64-filter bottleneck) with a sigmoid mask
    output, compiled with Adam(LR) and *loss_obj*.
    """
    layers = tf.keras.layers
    inputs = tf.keras.Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 1))

    # Encoder
    enc1 = conv_block(inputs, 16)
    down1 = layers.MaxPool2D()(enc1)
    enc2 = conv_block(down1, 32)
    down2 = layers.MaxPool2D()(enc2)

    # Bottleneck
    bottleneck = conv_block(down2, 64)

    # Decoder with skip connections
    up2 = layers.UpSampling2D()(bottleneck)
    merged2 = layers.Concatenate()([up2, enc2])
    dec2 = conv_block(merged2, 32)

    up1 = layers.UpSampling2D()(dec2)
    merged1 = layers.Concatenate()([up1, enc1])
    dec1 = conv_block(merged1, 16)

    outputs = layers.Conv2D(1, 1, padding="same", activation="sigmoid")(dec1)
    model = tf.keras.Model(inputs, outputs, name="TinyUNet_Seg")
    model.compile(optimizer=tf.keras.optimizers.Adam(LR), loss=loss_obj)
    return model


# LOSSES
# Shared binary cross-entropy loss object (constructed with default arguments);
# called by seg_loss.
bce = tf.keras.losses.BinaryCrossentropy()

def soft_dice_coef(y_true, y_pred, eps=1e-6):
    """Mean over the batch of the per-sample soft Dice, (2*sum(t*p)+eps)/(sum(t+p)+eps)."""
    def _flatten_per_sample(t):
        t = tf.cast(t, tf.float32)
        return tf.reshape(t, [tf.shape(t)[0], -1])

    truth = _flatten_per_sample(y_true)
    pred = _flatten_per_sample(y_pred)
    overlap = tf.reduce_sum(truth * pred, axis=1)
    total = tf.reduce_sum(truth + pred, axis=1)
    per_sample = (2.0 * overlap + eps) / (total + eps)
    return tf.reduce_mean(per_sample)

def dice_loss(y_true, y_pred):
    """Return 1 minus the soft Dice coefficient of the batch."""
    coefficient = soft_dice_coef(y_true, y_pred)
    return 1.0 - coefficient

def seg_loss(y_true, y_pred):
    """Segmentation loss: binary cross-entropy (module-level ``bce``) plus Dice loss."""
    ce_term = bce(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return ce_term + dice_term

def sobel_mse_loss(y_true, y_pred):
    """Mean squared error between target and prediction over all elements."""
    residual = y_true - y_pred
    return tf.reduce_mean(tf.square(residual))

# Weighted BCE (imbalance)
def estimate_pos_weight(samples, limit=300):
    """Estimate the negative/positive pixel ratio of boundary targets.

    Counts boundary and background pixels over at most *limit* loadable
    samples (all of them when *limit* is falsy) and returns neg/pos clamped
    to the range [1, 50].
    """
    n_pos = n_neg = n_used = 0
    for desc in samples:
        loaded = load_sample(desc, target_mode="boundary")
        if loaded is None:
            continue
        target = loaded[1].numpy().astype(np.uint8)
        hits = int(target.sum())
        n_pos += hits
        n_neg += int(target.size) - hits
        n_used += 1
        if limit and n_used >= limit:
            break

    ratio = float(n_neg / (n_pos + 1e-9))
    # clamp to avoid instability
    return max(1.0, min(50.0, ratio))

def weighted_bce_with_logits(pos_weight):
    """Build a positive-weighted cross-entropy loss for probability outputs.

    The returned loss clips predictions to [1e-6, 1 - 1e-6], converts them to
    logits, and averages tf.nn.weighted_cross_entropy_with_logits.
    """
    def loss(y_true, y_pred):
        labels = tf.cast(y_true, tf.float32)
        probs = tf.cast(y_pred, tf.float32)
        probs = tf.clip_by_value(probs, 1e-6, 1.0 - 1e-6)
        logits = tf.math.log(probs / (1.0 - probs))
        per_pixel = tf.nn.weighted_cross_entropy_with_logits(
            labels=labels, logits=logits, pos_weight=pos_weight
        )
        return tf.reduce_mean(per_pixel)
    return loss

def boundary_loss_factory(pos_weight):
    """Return a loss equal to weighted cross-entropy (with *pos_weight*) plus Dice loss."""
    weighted_ce = weighted_bce_with_logits(pos_weight)

    def loss(y_true, y_pred):
        ce_term = weighted_ce(y_true, y_pred)
        dice_term = dice_loss(y_true, y_pred)
        return ce_term + dice_term

    return loss


# METRICS
def bin_metrics(y_true01, y_pred01, thr=0.5, eps=1e-9):
    """Pixel-wise Dice, precision and recall.

    The ground truth is binarized at 0.5 and the prediction at *thr*; *eps*
    guards the denominators.
    """
    truth = y_true01 >= 0.5
    pred = y_pred01 >= thr

    tp = int(np.count_nonzero(truth & pred))
    fp = int(np.count_nonzero(~truth & pred))
    fn = int(np.count_nonzero(truth & ~pred))

    dice = (2*tp) / (2*tp + fp + fn + eps)
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return dice, precision, recall

def tolerant_f1(gt_b, pr_b, tol_px, eps=1e-9):
    """
    Boundary F1 with a distance tolerance:
      a predicted pixel counts as correct when it lies within tol_px of some
      GT pixel (precision), and a GT pixel counts as found when it lies within
      tol_px of some predicted pixel (recall).
    Returns 1.0 when both maps are empty and 0.0 when exactly one is empty.
    """
    gt = gt_b > 0
    pr = pr_b > 0
    n_gt, n_pr = gt.sum(), pr.sum()

    if n_gt == 0 or n_pr == 0:
        return 1.0 if (n_gt == 0 and n_pr == 0) else 0.0

    # Distance of every pixel to the nearest GT / predicted boundary pixel.
    near_gt = distance_transform_edt(~gt) <= tol_px
    near_pr = distance_transform_edt(~pr) <= tol_px

    precision = np.sum(pr & near_gt) / (n_pr + eps)
    recall = np.sum(gt & near_pr) / (n_gt + eps)
    return float((2*precision*recall) / (precision + recall + eps))

# ASSD + HD95 boundary metrics (pixels)
def boundary_map(bin_mask):
    """Return the boolean pixels of *bin_mask* that are removed by one binary erosion."""
    mask = bin_mask.astype(bool)
    if not mask.any():
        return np.zeros_like(mask, dtype=bool)
    interior = binary_erosion(mask)
    return np.logical_xor(mask, interior)

def assd_hd95(gt_bin, pr_bin):
    """
    ASSD and HD95 in pixel units, from the boundary maps of both inputs.

    Distances pred->GT and GT->pred are pooled; ASSD is their mean and HD95
    their 95th percentile. Returns (0, 0) when both boundaries are empty and
    (inf, inf) when exactly one is.
    """
    gt_edge = boundary_map(gt_bin)
    pr_edge = boundary_map(pr_bin)
    gt_empty = not gt_edge.any()
    pr_empty = not pr_edge.any()

    if gt_empty and pr_empty:
        return 0.0, 0.0
    if gt_empty or pr_empty:
        return float("inf"), float("inf")

    pred_to_gt = distance_transform_edt(~gt_edge)[pr_edge]
    gt_to_pred = distance_transform_edt(~pr_edge)[gt_edge]
    pooled = np.concatenate([pred_to_gt, gt_to_pred])

    return float(pooled.mean()), float(np.percentile(pooled, 95))

# threshold selection- tune on mean tolerant-F1 across TOL_PX_LIST (not Dice)
def find_best_global_threshold(predict_fn, val_samples, thresholds, limit=None):
    """Pick the threshold maximising the mean tolerant F1 on validation data.

    For every loadable sample the prediction is binarized at each candidate
    threshold and scored with the mean of tolerant_f1 over TOL_PX_LIST; the
    per-threshold scores are averaged over samples.

    Each sample is loaded and predicted exactly once; only the cheap
    binarization and scoring are repeated per threshold.

    Args:
        predict_fn: Callable mapping an input tensor to a prediction tensor
            (must expose ``.numpy()``).
        val_samples: Iterable of sample descriptors.
        thresholds: Candidate thresholds, tried in the given order.
        limit: Optional maximum number of loadable samples to use.

    Returns:
        (best_thr, best_mean). Ties keep the earliest threshold. With no
        usable sample every threshold scores 0.0; with no thresholds the
        result is (0.5, -1.0).
    """
    thresholds = list(thresholds)
    if not thresholds:
        return 0.5, -1.0

    scores_per_thr = [[] for _ in thresholds]
    used = 0
    for d in val_samples:
        out = load_sample(d, target_mode="boundary")
        if out is None:
            continue
        img01, gt = out

        # Threshold-independent work, done once per sample.
        pr = predict_fn(img01).numpy().squeeze().astype(np.float32)
        pr = np.clip(pr, 0, 1)
        gt_bin = (gt.numpy().squeeze().astype(np.float32) >= 0.5).astype(np.uint8)

        for bucket, t in zip(scores_per_thr, thresholds):
            pr_bin = (pr >= t).astype(np.uint8)
            bucket.append(float(np.mean(
                [tolerant_f1(gt_bin, pr_bin, tol_px=px) for px in TOL_PX_LIST]
            )))

        used += 1
        if limit and used >= limit:
            break

    best_thr, best_mean = 0.5, -1.0
    for t, bucket in zip(thresholds, scores_per_thr):
        mean_s = float(np.mean(bucket)) if bucket else 0.0
        if mean_s > best_mean:
            best_mean = mean_s
            best_thr = t
    return best_thr, best_mean

# UNet threshold selection: also tune on mean tolerant-F1 across TOL_PX_LIST
def find_best_threshold_for_unet_boundary(unet_predict_mask_fn, val_samples, thresholds, limit=None):
    """Pick the mask threshold whose derived boundary maximises mean tolerant F1.

    For every loadable sample the predicted mask probabilities are binarized
    at each candidate threshold, converted to a boundary with
    mask_boundary_target, and scored against the ground-truth boundary with
    the mean of tolerant_f1 over TOL_PX_LIST.

    Each sample is loaded and predicted exactly once. The ground truth is
    loaded only in "boundary" mode: load_sample returns None for "mask" under
    the same conditions, so a second load adds no filtering.

    Args:
        unet_predict_mask_fn: Callable mapping an input tensor to mask
            probabilities (must expose ``.numpy()``).
        val_samples: Iterable of sample descriptors.
        thresholds: Candidate thresholds, tried in the given order.
        limit: Optional maximum number of loadable samples to use.

    Returns:
        (best_thr, best_mean). Ties keep the earliest threshold. With no
        usable sample every threshold scores 0.0; with no thresholds the
        result is (0.5, -1.0).
    """
    thresholds = list(thresholds)
    if not thresholds:
        return 0.5, -1.0

    scores_per_thr = [[] for _ in thresholds]
    used = 0
    for d in val_samples:
        out_b = load_sample(d, target_mode="boundary")
        if out_b is None:
            continue
        img01, gt_b = out_b

        # Threshold-independent work, done once per sample.
        prob_m = unet_predict_mask_fn(img01).numpy().squeeze()
        gt_b_np = gt_b.numpy().squeeze().astype(np.float32)
        gt_bin = (gt_b_np >= 0.5).astype(np.uint8)

        for bucket, t in zip(scores_per_thr, thresholds):
            pr_m = (prob_m >= t).astype(np.float32)
            pr_m_tf = tf.convert_to_tensor(pr_m[..., None], dtype=tf.float32)
            pr_b = mask_boundary_target(pr_m_tf).numpy().squeeze().astype(np.float32)
            pr_bin = (pr_b >= 0.5).astype(np.uint8)
            bucket.append(float(np.mean(
                [tolerant_f1(gt_bin, pr_bin, tol_px=px) for px in TOL_PX_LIST]
            )))

        used += 1
        if limit and used >= limit:
            break

    best_thr, best_mean = 0.5, -1.0
    for t, bucket in zip(thresholds, scores_per_thr):
        mean_s = float(np.mean(bucket)) if bucket else 0.0
        if mean_s > best_mean:
            best_mean = mean_s
            best_thr = t
    return best_thr, best_mean

def eval_boundary_method(name, predict_boundary_fn, test_samples, thr_global, tol_list, limit=None):
    """Score one boundary predictor on labeled samples, one DataFrame row per sample.

    Each row holds the method name, Dice at 0.5, Dice/precision/recall at
    *thr_global*, ASSD and HD95 (pixels), the threshold used, and a tolerant
    F1 column for every tolerance in *tol_list*. Evaluation stops after
    *limit* loadable samples when *limit* is truthy.
    """
    records = []
    for desc in test_samples:
        loaded = load_sample(desc, target_mode="boundary")
        if loaded is None:
            continue
        img01, gt = loaded

        gt_np = gt.numpy().squeeze().astype(np.float32)
        pred = predict_boundary_fn(img01).numpy().squeeze().astype(np.float32)
        pred = np.clip(pred, 0, 1)

        dice_half, _, _ = bin_metrics(gt_np, pred, thr=0.5)
        dice_val, prec_val, rec_val = bin_metrics(gt_np, pred, thr=thr_global)

        pr_bin = (pred >= thr_global).astype(np.uint8)
        gt_bin = (gt_np >= 0.5).astype(np.uint8)
        assd, hd95 = assd_hd95(gt_bin.astype(bool), pr_bin.astype(bool))

        record = {
            "method": name,
            "Dice@0.5": dice_half,
            "Dice@valThr": dice_val,
            "Prec@valThr": prec_val,
            "Rec@valThr": rec_val,
            "ASSD_px": assd,
            "HD95_px": hd95,
            "Thr_val": thr_global,
        }
        record.update(
            (f"F1@tol{tol}px", tolerant_f1(gt_bin, pr_bin, tol_px=tol))
            for tol in tol_list
        )
        records.append(record)

        if limit and len(records) >= limit:
            break
    return pd.DataFrame(records)

def overlay_edges(img01, gt_b=None, pr_b=None, thr=0.5):
    """Render edges on top of the min-max normalised grayscale image.

    GT pixels (gt_b >= 0.5) are painted green, predicted pixels (pr_b >= thr)
    red, and pixels present in both yellow. *img01* and *gt_b* must expose
    ``.numpy()``; *pr_b* is used as an array directly.
    """
    gray = img01.numpy().squeeze()
    lo, hi = gray.min(), gray.max()
    gray = (gray - lo)/(hi - lo + 1e-8)
    rgb = np.stack((gray,) * 3, axis=-1)

    gt_mask = None if gt_b is None else (gt_b.numpy().squeeze() >= 0.5)
    pr_mask = None if pr_b is None else (pr_b >= thr)

    # Paint order matters: green first, red over it, yellow for the overlap.
    if gt_mask is not None:
        rgb[gt_mask] = [0,1,0]
    if pr_mask is not None:
        rgb[pr_mask] = [1,0,0]
    if gt_mask is not None and pr_mask is not None:
        rgb[gt_mask & pr_mask] = [1,1,0]
    return rgb

def show_overlay_panel(models_dict, title, samples, thr_map, k=3, save_path=None):
    """Plot GT and per-model boundary overlays for the first *k* labeled samples.

    One row per sample: the GT overlay, then one overlay per model in
    *models_dict*, binarized at ``thr_map.get(name, 0.5)``. The figure is
    saved to *save_path* when given, then shown.
    """
    picked = []
    for desc in samples:
        loaded = load_sample(desc, target_mode="boundary")
        if loaded is not None:
            picked.append((loaded[0], loaded[1]))
        if len(picked) >= k:
            break
    if len(picked) == 0:
        print("No labeled samples for overlay.")
        return

    n_rows = len(picked)
    n_cols = 1 + len(models_dict)
    plt.figure(figsize=(4*n_cols, 4*n_rows))
    for row, (img01, gt_b) in enumerate(picked):
        first_cell = row*n_cols + 1
        plt.subplot(n_rows, n_cols, first_cell)
        plt.imshow(overlay_edges(img01, gt_b=gt_b, pr_b=None, thr=0.5))
        plt.title("GT (green)")
        plt.axis("off")

        for offset, (name, model) in enumerate(models_dict.items(), start=1):
            pred = model(img01[None], training=False)[0].numpy().squeeze()
            model_thr = thr_map.get(name, 0.5)
            plt.subplot(n_rows, n_cols, first_cell + offset)
            plt.imshow(overlay_edges(img01, gt_b=gt_b, pr_b=pred, thr=model_thr))
            plt.title(f"{name} (red), overlap(yellow)")
            plt.axis("off")

    plt.suptitle(title)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()

def show_inference_panel(models_dict, title, samples, thr=0.5, k=3, save_path=None):
    """Plot model boundary predictions on inputs that have no ground truth.

    One figure row is drawn per sample: the raw input in the first column,
    then one prediction overlay per entry of ``models_dict``.

    Args:
        models_dict: Mapping of display name -> callable model, invoked as
            ``model(batch, training=False)``.
        title: Figure super-title.
        samples: Iterable of sample descriptors passed to ``load_input_only``;
            descriptors that load as ``None`` are skipped.
        thr: Single threshold applied to every model's prediction.
        k: Stop collecting samples once this many have loaded.
        save_path: If truthy, the figure is saved there before being shown.

    Returns:
        None. Prints a message and returns early when no sample loads.
    """
    picked = []
    for sample in samples:
        loaded = load_input_only(sample)
        if loaded is not None:
            picked.append(loaded)
        if len(picked) >= k:
            break
    if not picked:
        print("No inference samples.")
        return

    n_rows = len(picked)
    n_cols = 1 + len(models_dict)
    plt.figure(figsize=(4*n_cols, 4*n_rows))
    for row_idx, img01 in enumerate(picked):
        # Subplot indices are 1-based; this is the leftmost cell of the row.
        row_start = row_idx*n_cols + 1
        plt.subplot(n_rows, n_cols, row_start)
        plt.imshow(img01.numpy().squeeze(), cmap="gray")
        plt.title("Input")
        plt.axis("off")

        for offset, (name, model) in enumerate(models_dict.items(), start=1):
            pred = model(img01[None], training=False)[0].numpy().squeeze()
            overlay = overlay_edges(img01, gt_b=None, pr_b=pred, thr=thr)
            plt.subplot(n_rows, n_cols, row_start + offset)
            plt.imshow(overlay)
            plt.title(name)
            plt.axis("off")

    plt.suptitle(title)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()

# OUTPUT DIRS
# Every artefact of the run is written below OUT_DIR.
BASE_DIR = "/content/embc_run"
OUT_DIR  = os.path.join(BASE_DIR, "outputs")
os.makedirs(OUT_DIR, exist_ok=True)


# DATASET ROOTS
# Maps a dataset label to the directory holding it. It is only filled in when
# USE_GOOGLE_DRIVE is enabled.
dataset_roots = {}
if USE_GOOGLE_DRIVE:
    from google.colab import drive
    drive.mount("/content/drive")

    if not os.path.isfile(BRATS_ZIP_PATH):
        raise ValueError(f"BraTS zip not found at: {BRATS_ZIP_PATH}")

    # The marker file is written only after extractall() returns, so its
    # absence means the extract directory is missing or incomplete and has to
    # be rebuilt from scratch.
    sentinel = os.path.join(BRATS_EXTRACT_DIR, ".extracted_ok")
    if os.path.exists(sentinel):
        print(" BraTS already unzipped at:", BRATS_EXTRACT_DIR)
    else:
        if os.path.exists(BRATS_EXTRACT_DIR):
            shutil.rmtree(BRATS_EXTRACT_DIR)
        os.makedirs(BRATS_EXTRACT_DIR, exist_ok=True)
        print("Unzipping BraTS zip (this can take a bit)...")
        with zipfile.ZipFile(BRATS_ZIP_PATH, "r") as z:
            z.extractall(BRATS_EXTRACT_DIR)
        with open(sentinel, "w") as f:
            f.write("ok")
        print("Unzipped BraTS to:", BRATS_EXTRACT_DIR)

    # Two folder spellings are searched for each split, at any depth.
    train_hits = [
        hit
        for folder in ("MICCAI_BraTS2020_TrainingData", "BraTS2020_TrainingData")
        for hit in glob.glob(os.path.join(BRATS_EXTRACT_DIR, "**/" + folder), recursive=True)
    ]
    val_hits = [
        hit
        for folder in ("MICCAI_BraTS2020_ValidationData", "BraTS2020_ValidationData")
        for hit in glob.glob(os.path.join(BRATS_EXTRACT_DIR, "**/" + folder), recursive=True)
    ]

    def best_dir_with_nii(hits):
        """Return the candidate directory holding the most ``*.nii*`` files.

        Non-directories are ignored. Ties are broken by the longer path and
        then by the lexicographically greater path. Returns ``None`` when no
        candidate directory exists.
        """
        scored = [
            (len(glob.glob(os.path.join(h, "**/*.nii*"), recursive=True)), len(h), h)
            for h in hits
            if os.path.isdir(h)
        ]
        if not scored:
            return None
        return max(scored)[2]

    BRATS_TRAIN_DIR = best_dir_with_nii(train_hits)
    BRATS_VAL_DIR   = best_dir_with_nii(val_hits)

    if not (BRATS_TRAIN_DIR and BRATS_VAL_DIR):
        raise ValueError("Could not auto-detect BraTS Training/Validation folders inside the zip.")

    dataset_roots["BraTS2020_train"] = BRATS_TRAIN_DIR
    dataset_roots["BraTS2020_official_val_unlabeled"] = BRATS_VAL_DIR

    # Extra datasets are merged last, so a clashing label overrides the
    # auto-detected BraTS entry.
    dataset_roots.update(EXTRA_DATASET_DIRS)

print("\nDataset roots:", dataset_roots)

# Build indices
# dataset_roots is filled in only inside the USE_GOOGLE_DRIVE branch; when that
# branch is skipped the lookups below would otherwise fail with a bare KeyError.
_missing_roots = [
    name
    for name in ("BraTS2020_train", "BraTS2020_official_val_unlabeled")
    if name not in dataset_roots
]
if _missing_roots:
    raise ValueError(
        f"Dataset roots missing: {_missing_roots}. "
        "Enable USE_GOOGLE_DRIVE (or populate dataset_roots) before building indices."
    )

# Each call returns the sample index for a dataset root plus a flag telling
# whether masks were detected for it.
train_samples_all, has_masks_train = build_index_for_dataset(dataset_roots["BraTS2020_train"])
official_val_samples, has_masks_off = build_index_for_dataset(dataset_roots["BraTS2020_official_val_unlabeled"])

# Supervised training and evaluation below need labeled TrainingData.
if not has_masks_train:
    raise ValueError("BraTS TrainingData masks not detected. Check tokens / folder structure.")

print("\nBraTS index sizes:")
print("  TrainingData slices:", len(train_samples_all), " labeled=", has_masks_train)
print("  Official Val slices:", len(official_val_samples), " labeled=", has_masks_off)

# MAIN: TRAIN/EVAL per seed
#
# For every seed: split the labeled data, train the learnable model and the
# enabled baselines, tune one global threshold per method on the validation
# split, evaluate on the held-out labeled test split, save per-seed CSVs and
# figures, and collect one summary table per seed in all_seed_masters.

all_seed_masters = []

for seed in SEEDS:
    print("\n" + "="*70)
    print(f"SEED = {seed}")
    print("="*70)

    # Every iteration builds up to four new Keras models; release the global
    # Keras state left over from the previous seed so it does not pile up.
    tf.keras.backend.clear_session()
    set_all_seeds(seed)

    # Leak-free patient-level split
    train_s, val_s, test_labeled_s = split_by_group(train_samples_all, TRAIN_FRAC, VAL_FRAC, seed=seed)
    print(f"\nLeak-free patient split:")
    print("  train:", len(train_s))
    print("  val  :", len(val_s))
    print("  test_labeled:", len(test_labeled_s))

    # Compute pos_weight from training boundary targets
    pos_w = estimate_pos_weight(train_s, limit=300)
    b_loss = boundary_loss_factory(pos_w)
    print("Boundary pos_weight:", pos_w)

    # Datasets: one train/val pair per target mode (sobel, boundary, mask).
    pretrain_ds = make_tf_dataset(train_s, target_mode="sobel", shuffle=True, repeat=True, seed=seed)
    pre_val_ds  = make_tf_dataset(val_s,   target_mode="sobel", shuffle=False, repeat=True, seed=seed)

    finetune_ds = make_tf_dataset(train_s, target_mode="boundary", shuffle=True, repeat=True, seed=seed)
    fin_val_ds  = make_tf_dataset(val_s,   target_mode="boundary", shuffle=False, repeat=True, seed=seed)

    unet_ds     = make_tf_dataset(train_s, target_mode="mask", shuffle=True, repeat=True, seed=seed)
    unet_val_ds = make_tf_dataset(val_s,   target_mode="mask", shuffle=False, repeat=True, seed=seed)

    # The datasets are created with repeat=True, so fit() is given explicit
    # step counts for both training and validation.
    steps_pre  = estimate_steps(train_s)
    vsteps_pre = estimate_steps(val_s)
    steps_fin  = estimate_steps(train_s)
    vsteps_fin = estimate_steps(val_s)

    # Models
    learnable = build_learnable_model(loss_obj=sobel_mse_loss)
    # NOTE(review): these two flags look like output-mode switches read by the
    # model itself; they are flipped again before the boundary finetune.
    learnable.use_per_image_norm = True
    learnable.use_sigmoid = False

    if PRETRAIN_SOBEL_EPOCHS > 0:
        _ = learnable.fit(
            pretrain_ds, validation_data=pre_val_ds,
            epochs=PRETRAIN_SOBEL_EPOCHS,
            steps_per_epoch=steps_pre, validation_steps=vsteps_pre,
            verbose=1
        )

    # Boundary finetune (weighted loss)
    learnable.compile(optimizer=tf.keras.optimizers.Adam(LR), loss=b_loss)
    learnable.use_per_image_norm = False
    learnable.use_sigmoid = True

    _ = learnable.fit(
        finetune_ds, validation_data=fin_val_ds,
        epochs=FINETUNE_BOUNDARY_EPOCHS,
        steps_per_epoch=steps_fin, validation_steps=vsteps_fin,
        verbose=1
    )

    # Optional baselines stay None when their RUN_* flag is off; every later
    # use is gated with an explicit `is not None` check.
    smallcnn = None
    sobelfront = None
    unet = None

    if RUN_SMALL_CNN_BASELINE:
        smallcnn = build_small_cnn_boundary(loss_obj=b_loss)
        _ = smallcnn.fit(
            finetune_ds, validation_data=fin_val_ds,
            epochs=FINETUNE_BOUNDARY_EPOCHS,
            steps_per_epoch=steps_fin, validation_steps=vsteps_fin,
            verbose=1
        )

    if RUN_SOBELFRONT_CNN:
        sobelfront = build_sobelfront_cnn(loss_obj=b_loss)
        _ = sobelfront.fit(
            finetune_ds, validation_data=fin_val_ds,
            epochs=FINETUNE_BOUNDARY_EPOCHS,
            steps_per_epoch=steps_fin, validation_steps=vsteps_fin,
            verbose=1
        )

    if RUN_TINY_UNET_SEG_BASE:
        unet = build_tiny_unet_seg(loss_obj=seg_loss)
        _ = unet.fit(
            unet_ds, validation_data=unet_val_ds,
            epochs=FINETUNE_UNET_EPOCHS,
            steps_per_epoch=steps_fin, validation_steps=vsteps_fin,
            verbose=1
        )

    # Predict fns (boundary): each takes one image and returns one prediction;
    # the model-backed ones add and then strip the batch axis.
    def pred_fixed(img01): return fixed_sobel_conv(img01)
    def pred_learnable(img01): return learnable(img01[None], training=False)[0]
    def pred_smallcnn(img01): return smallcnn(img01[None], training=False)[0]
    def pred_sobelfront(img01): return sobelfront(img01[None], training=False)[0]
    def pred_unet_mask(img01): return unet(img01[None], training=False)[0]

    # Threshold selection on VAL (obj = mean tolerant F1)
    thr_fixed, best_fixed = find_best_global_threshold(pred_fixed, val_s, THRESHOLDS, limit=EVAL_LIMIT)
    thr_learn, best_learn = find_best_global_threshold(pred_learnable, val_s, THRESHOLDS, limit=EVAL_LIMIT)

    thr_small = None
    thr_sf = None
    thr_unet = None

    if smallcnn is not None:
        thr_small, best_small = find_best_global_threshold(pred_smallcnn, val_s, THRESHOLDS, limit=EVAL_LIMIT)
    if sobelfront is not None:
        thr_sf, best_sf = find_best_global_threshold(pred_sobelfront, val_s, THRESHOLDS, limit=EVAL_LIMIT)
    if unet is not None:
        thr_unet, best_unet = find_best_threshold_for_unet_boundary(pred_unet_mask, val_s, THRESHOLDS, limit=EVAL_LIMIT)

    print("\nVAL-tuned global thresholds (tuned on mean tolerant-F1):")
    print("  FixedSobel :", thr_fixed, "score:", best_fixed)
    print("  Learnable  :", thr_learn, "score:", best_learn)
    if thr_small is not None: print("  SmallCNN   :", thr_small)
    if thr_sf is not None:    print("  SobelFront :", thr_sf)
    if thr_unet is not None:  print("  TinyUNet(mask->boundary) mask_thr:", thr_unet)

    # Evaluate on TEST_LABELED (no leakage)
    print("\n PAPER EVAL: TEST_LABELED (leak-free, patient split)\n")

    # boundary from unet at chosen mask threshold
    def pred_unet_boundary(img01):
        """Binarise the UNet mask at thr_unet and convert it to a boundary map."""
        pr_m = pred_unet_mask(img01).numpy().squeeze()
        pr_m = (pr_m >= thr_unet).astype(np.float32)
        pr_m_tf = tf.convert_to_tensor(pr_m[..., None], dtype=tf.float32)
        pr_b = mask_boundary_target(pr_m_tf).numpy().squeeze().astype(np.float32)
        return tf.convert_to_tensor(pr_b[..., None], dtype=tf.float32)

    # Per-method eval
    df_fixed = eval_boundary_method("B_fixed_sobel_edges", pred_fixed, test_labeled_s, thr_fixed, TOL_PX_LIST, limit=EVAL_LIMIT)
    df_learn = eval_boundary_method("C_learnable_sobel_pretrain_finetune", pred_learnable, test_labeled_s, thr_learn, TOL_PX_LIST, limit=EVAL_LIMIT)

    dfs = [df_fixed, df_learn]

    if smallcnn is not None:
        df_small = eval_boundary_method("D_smallcnn_baseline", pred_smallcnn, test_labeled_s, thr_small, TOL_PX_LIST, limit=EVAL_LIMIT)
        dfs.append(df_small)
    if sobelfront is not None:
        df_sf = eval_boundary_method("E_sobelfront_cnn", pred_sobelfront, test_labeled_s, thr_sf, TOL_PX_LIST, limit=EVAL_LIMIT)
        dfs.append(df_sf)
    if unet is not None:
        # pred_unet_boundary is already thresholded at thr_unet, so the
        # evaluator receives 0.5 and the tuned mask threshold is recorded
        # separately in the Thr_val column.
        df_unet = eval_boundary_method("F_tinyunet_seg_to_boundary", pred_unet_boundary, test_labeled_s, 0.5, TOL_PX_LIST, limit=EVAL_LIMIT)
        df_unet["Thr_val"] = thr_unet
        dfs.append(df_unet)

    df_all = pd.concat(dfs, ignore_index=True)
    df_all["seed"] = seed

    # Summarize (mean/std) per seed
    metric_cols = ["Dice@0.5","Dice@valThr","Prec@valThr","Rec@valThr","ASSD_px","HD95_px"] + [f"F1@tol{t}px" for t in TOL_PX_LIST]
    summary = (df_all.groupby("method")[metric_cols]
               .agg(["mean","std"])
               .reset_index())
    # Flatten the (metric, stat) MultiIndex to "metric_stat"; the grouping key
    # ("method", "") collapses to plain "method".
    summary.columns = ["_".join(c).strip("_") for c in summary.columns]

    # Per-run metadata, built once and shared by the method rows and the
    # reference row so the two cannot drift apart.
    run_meta = {
        "seed": seed,
        "alpha": float(learnable.alpha.numpy()),
        "beta": float(learnable.beta.numpy()),
        "pos_weight": float(pos_w),
        "learnable_params": learnable.count_params(),
        "smallcnn_params": smallcnn.count_params() if smallcnn is not None else None,
        "sobelfront_params": sobelfront.count_params() if sobelfront is not None else None,
        "unet_params": unet.count_params() if unet is not None else None,
    }
    for meta_key, meta_value in run_meta.items():
        summary[meta_key] = meta_value

    # Perfect row (for reference)
    perfect_row = {"method_": "A_perfect_target_vs_target", **run_meta}

    # Fill perfect metrics: zero distance for the boundary-distance metrics,
    # a score of 1.0 for everything else, and no spread.
    for col in metric_cols:
        perfect_row[f"{col}_mean"] = 0.0 if col in ("ASSD_px", "HD95_px") else 1.0
        perfect_row[f"{col}_std"]  = 0.0

    perfect_row["Thr_val_mean"] = 0.5
    perfect_row["Thr_val_std"]  = 0.0

    # Save per-seed outputs
    seed_dir = os.path.join(OUT_DIR, f"seed_{seed}")
    os.makedirs(seed_dir, exist_ok=True)

    df_all.to_csv(os.path.join(seed_dir, "per_sample_metrics.csv"), index=False)
    summary.to_csv(os.path.join(seed_dir, "summary_metrics.csv"), index=False)

    # Overlays on TEST_LABELED
    thr_map = {
        "Learnable": thr_learn,
        "SmallCNN": thr_small if thr_small is not None else 0.5,
        "SobelFrontCNN": thr_sf if thr_sf is not None else 0.5,
    }

    overlay_models = {"Learnable": learnable}
    if smallcnn is not None: overlay_models["SmallCNN"] = smallcnn
    if sobelfront is not None: overlay_models["SobelFrontCNN"] = sobelfront

    show_overlay_panel(
        overlay_models,
        title=f"BraTS TEST_LABELED (seed {seed}): GT green, Pred red, overlap yellow",
        samples=test_labeled_s,
        thr_map=thr_map,
        k=VIS_K,
        save_path=os.path.join(seed_dir, "FIG_overlay_TEST_LABELED.png")
    )

    # Inference-only on official unlabeled val
    inf_models = {"Learnable(boundary)": learnable}
    if smallcnn is not None: inf_models["SmallCNN(boundary)"] = smallcnn
    if sobelfront is not None: inf_models["SobelFront(boundary)"] = sobelfront

    # NOTE(review): show_inference_panel takes a single threshold, so the
    # SobelFront value is applied to every model drawn in this figure.
    show_inference_panel(
        inf_models,
        title=f"BraTS official ValidationData (unlabeled) boundary inference (seed {seed})",
        samples=official_val_samples,
        thr=thr_map.get("SobelFrontCNN", 0.5),
        k=VIS_K,
        save_path=os.path.join(seed_dir, "FIG_infer_OFFICIAL_VAL.png")
    )

    # Build a "master rows" table for this seed: perfect + method summaries
    master_rows = [perfect_row]
    for _, row in summary.iterrows():
        d = row.to_dict()
        # The flattened summary names its key column "method"; mirror it into
        # "method_", the key used by the perfect row and the across-seed
        # aggregation.
        d.setdefault("method_", d.get("method"))
        master_rows.append(d)

    all_seed_masters.append(pd.DataFrame(master_rows))

# MASTER EXPORT (all seeds)
# Stack the per-seed tables into one CSV. An empty frame is written when no
# seed produced a table.

master = pd.concat(all_seed_masters, ignore_index=True) if len(all_seed_masters) else pd.DataFrame()
master_path = os.path.join(OUT_DIR, "embc_paper_master_summary_all_seeds.csv")
master.to_csv(master_path, index=False)

print("\n Saved master summary:", master_path)
print("\n===== MASTER SUMMARY (preview) =====")
print(master.head(20) if len(master) else "Empty master table")

# Across-seed aggregate (mean/std across SEEDS) for each method_
agg_path = os.path.join(OUT_DIR, "embc_summary_across_seeds.csv")
if len(master):
    m = master[master["method_"].notna()].copy()
    # Aggregate every numeric column except the identifiers.
    num_cols = [c for c in m.columns if c not in ["method_", "method", "seed"] and pd.api.types.is_numeric_dtype(m[c])]
    agg = (m.groupby("method_")[num_cols].agg(["mean","std"]).reset_index())
    # Flatten the (column, stat) MultiIndex to "column_stat".
    agg.columns = ["_".join(col).strip("_") for col in agg.columns]
else:
    # An empty master table has no "method_" column to filter or group on;
    # emit an empty aggregate instead of raising KeyError.
    m = master.copy()
    num_cols = []
    agg = pd.DataFrame()
agg.to_csv(agg_path, index=False)
print(" Saved across-seed summary:", agg_path)

# Save reproducibility config
# Snapshot of the run settings, written next to the outputs as run_config.json.
config = {
    "IMG_SIZE": IMG_SIZE,
    "BATCH_SIZE": BATCH_SIZE,
    "LR": LR,
    "SEEDS": SEEDS,
    "PRETRAIN_SOBEL_EPOCHS": PRETRAIN_SOBEL_EPOCHS,
    "FINETUNE_BOUNDARY_EPOCHS": FINETUNE_BOUNDARY_EPOCHS,
    "FINETUNE_UNET_EPOCHS": FINETUNE_UNET_EPOCHS,
    "MAX_SLICES_PER_VOLUME": MAX_SLICES_PER_VOLUME,
    "TRAIN_FRAC": TRAIN_FRAC,
    "VAL_FRAC": VAL_FRAC,
    "THRESHOLDS": THRESHOLDS,
    "TOL_PX_LIST": TOL_PX_LIST,
    "SKIP_EMPTY_MASK_SLICES": SKIP_EMPTY_MASK_SLICES,
    "BOUNDARY_THICKNESS": BOUNDARY_THICKNESS,
    "RUN_SMALL_CNN_BASELINE": RUN_SMALL_CNN_BASELINE,
    "RUN_SOBELFRONT_CNN": RUN_SOBELFRONT_CNN,
    "RUN_TINY_UNET_SEG_BASE": RUN_TINY_UNET_SEG_BASE,
    "BRATS_ZIP_PATH": BRATS_ZIP_PATH,
}
with open(os.path.join(OUT_DIR, "run_config.json"), "w") as f:
    json.dump(config, f, indent=2)

# Zip outputs
zip_out = "/content/embc_edge_outputs.zip"
!zip -qr {zip_out} {OUT_DIR}
files.download(zip_out)

print("\n DONE. Outputs include:")
print(" - per-seed per-sample + summary CSVs")
print(" - master summary across seeds: embc_paper_master_summary_all_seeds.csv")
print(" - across-seed summary: embc_summary_across_seeds.csv")
print(" - overlays on TEST_LABELED + inference on official unlabeled val")
print(" - run_config.json for reproducibility")


TensorFlow: 2.19.0
IMG_SIZE: (256, 256)
Mounted at /content/drive
Unzipping BraTS zip (this can take a bit)...
