# Import necessary libraries

In [32]:
import os, re, json, shutil
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import numpy as np
import nibabel as nib
from collections import defaultdict
from sklearn.model_selection import StratifiedGroupKFold

# ----- CONFIG: set your paths and task info ------

In [37]:
# ---------------- CONFIG: set your paths and task info ----------------
# Target nnU-Net raw task.
# TASK_ID is zero-padded to three digits when the folder name is built, so the
# values below resolve to ".../nnUNet_raw/Task100_DHAI_T1cEnhancingSeg".
TASK_ID = 100
TASK_NAME = "DHAI_T1cEnhancingSeg"
# stage_and_generate() places imagesTr/, labelsTr/ and dataset.json under this folder.
TASK_DIR = Path(f"/Users/chufal/nnUNet_raw/Task{TASK_ID:03d}_{TASK_NAME}")  # change if desired

In [49]:
# Data roots: dataset key -> top-level directory of that dataset on local disk.
# Each key is read by the matching discover_<KEY>() function, and the same key
# is used as the "dataset" field (and case-id prefix) of every discovered item.
ROOTS = {
    "MU": "/Users/chufal/projects/Datasets/PKG-MU-Glioma-Post/MU-Glioma-Post",
    "MET": "/Users/chufal/projects/Datasets/PKG-Pretreat-MetsToBrain-Masks/Pretreat-MetsToBrain-Masks",
    "BCBM": "/Users/chufal/projects/Datasets/PKG-BCBM-RadioGenomics_Images_Masks_Dec2024/BCBM_KSC_curated_data",
    "UCSD": "/Users/chufal/projects/Datasets/PKG-UCSD-PTGBM-v1/UCSD-PTGBM",
    "UCSF": "/Users/chufal/projects/Datasets/PKG - UCSF-PDGM Version 5/UCSF-PDGM-v5",
    "UPENN": "/Users/chufal/projects/Datasets/PKG-UPENN-GBM-NIfTI/UPENN-GBM/NIfTI-files",
    "BRATSAfrica": "/Users/chufal/projects/Datasets/PKG-BraTS-Africa/BraTS-Africa",
    # "YALE": excluded (no segmentations)
}

In [50]:
# Enhancing-label overrides per dataset when masks are multi-class.
# A non-empty list is returned as-is by pick_enhancing_labels(). If the value is
# None, pick_enhancing_labels() auto-detects per case: prefer 3, else 1, and it
# never selects 4 (any other non-zero labels are used only as a last resort).
# These entries are consulted only for items whose mask_kind is "multiclass" or
# "auto"; the "binary" and "multi_file_union" paths in stage_and_generate() do
# not read this table.
# NOTE(review): UCSF and UPENN rely on auto-detection. If their segmentations
# use the older BraTS convention (enhancing tumor = 4), auto-detection would
# pick label 1 instead of the enhancing class — confirm against dataset docs.
ENHANCING_LABEL_OVERRIDES: Dict[str, Optional[List[int]]] = {
    "MU": [3],      # FeTS label 3 = Enhancing Tissue (ET)
    "UCSF": None,    # likely multi-class; auto-detect per-case
    "UPENN": None,   # often binary; if multi-class, auto-detect
    "MET": [1],      # discovered as "binary", so this entry is currently not read (staging hard-codes label 1)
    "UCSD": [1],     # "_enhancing_cellular_tumor_seg"; discovered as "binary", entry currently not read
    "BCBM": [1],     # union of region masks (any voxel > 0); entry currently not read
    "BRATSAfrica": [3],  # discovered as "multiclass"; label 3 is treated as ET
}

In [51]:
# Staging options
# WRITE_FILES gates staging of images/labels and the dataset.json write.
WRITE_FILES = False        # set True to actually write files and dataset.json
# stage_file() symlinks only when this is exactly "symlink"; any other value copies.
LINK_METHOD = "symlink"    # "symlink" or "copy"
# Fraction of each dataset's cases placed in the returned test split;
# stage_and_generate() always holds out at least one case per dataset.
TRAIN_TEST_SPLIT = 0.2     # test fraction per dataset

# Utilities

In [53]:
def load_unique_values(mask_path: str, max_vox: int = 2_000_000) -> List[int]:
    """Return the sorted, de-duplicated integer label values present in a mask.

    The previous implementation drew ``max_vox`` random voxels (with
    replacement) from large volumes, which could miss labels that occupy only
    a handful of voxels and therefore mislead label auto-detection. The whole
    volume is now scanned exactly, ``max_vox`` voxels at a time.

    Args:
        mask_path: Path to an image readable by ``nib.load``.
        max_vox: Number of voxels examined per chunk (values < 1 are treated
            as 1). It no longer limits how much of the volume is inspected.

    Returns:
        Ascending list of distinct label values as Python ints. Non-integer
        data is rounded to the nearest integer; non-finite float voxels are
        ignored.
    """
    img = nib.load(mask_path)
    arr = np.asanyarray(img.dataobj)
    # order="K" flattens in memory order; voxel order is irrelevant for a set of values.
    flat = arr.ravel(order="K")
    step = max(1, int(max_vox))
    found = set()
    for start in range(0, flat.size, step):
        chunk = flat[start:start + step]
        if chunk.dtype.kind not in ("i", "u"):
            if chunk.dtype.kind == "f":
                # NaN/inf cannot be cast to a meaningful integer label.
                chunk = chunk[np.isfinite(chunk)]
            chunk = np.round(chunk).astype(np.int64)
        found.update(np.unique(chunk).tolist())
    return sorted(int(v) for v in found)

def pick_enhancing_labels(dataset_key: str, mask_path: str) -> List[int]:
    """Choose which label values of a mask are treated as enhancing tumor.

    Args:
        dataset_key: Key into ``ENHANCING_LABEL_OVERRIDES``.
        mask_path: Path of the segmentation; only read when no override applies.

    Returns:
        A new list of label values. A non-empty override wins. Otherwise the
        labels found in the mask decide: ``[3]`` if present, else ``[1]`` if
        present, else every non-zero label except 4. The result is empty when
        the mask has no usable non-zero label.
    """
    override = ENHANCING_LABEL_OVERRIDES.get(dataset_key)
    if override:
        # Return a copy so callers cannot mutate the shared configuration list.
        return list(override)
    uniq = load_unique_values(mask_path)
    uniq_nz = [u for u in uniq if u != 0]
    if not uniq_nz:
        return []
    # Prefer 3 (ET), else 1; NEVER pick 4 (RC)
    # NOTE(review): this assumes label 4 is never the enhancing class. Datasets
    # that follow a convention where ET = 4 would fall through to label 1 here —
    # confirm for every dataset that relies on auto-detection.
    if 3 in uniq_nz:
        return [3]
    if 1 in uniq_nz:
        return [1]
    # if nothing matches heuristics, fallback to all non-zero but exclude 4
    return [u for u in uniq_nz if u != 4]

def mask_has_any_labels(mask_path: str, pos_labels: list[int]) -> bool:
    """Report whether any voxel of the mask equals one of ``pos_labels``.

    Args:
        mask_path: Path to an image readable by ``nib.load``.
        pos_labels: Label values that count as foreground.

    Returns:
        A plain Python ``bool`` (the NumPy scalar is converted so the result
        matches the annotation and serialises cleanly).
    """
    img = nib.load(mask_path)
    arr = np.asanyarray(img.dataobj)
    if arr.dtype.kind not in ("i", "u"):
        # Masks stored as floating point are rounded to integer labels first.
        arr = np.round(arr).astype(np.int32)
    return bool(np.isin(arr, pos_labels).any())

def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents; no-op if it already exists."""
    os.makedirs(p, exist_ok=True)

def stage_file(src: Path, dst: Path):
    """Place ``src`` at ``dst`` by symlink or copy, according to ``LINK_METHOD``.

    An existing ``dst`` is left untouched. A dangling symlink at ``dst`` (its
    target no longer exists) is replaced; previously it slipped past the
    ``exists()`` check and made ``os.symlink`` raise ``FileExistsError``.

    Args:
        src: File to stage.
        dst: Destination path; its parent directory is created if needed.
    """
    ensure_dir(dst.parent)
    if dst.is_symlink() and not dst.exists():
        # Path.exists() follows links, so this is a link whose target is gone.
        dst.unlink()
    elif dst.exists():
        return
    if LINK_METHOD == "symlink":
        # Link to an absolute target: a relative one would be resolved against
        # dst's directory rather than the current working directory.
        os.symlink(os.path.abspath(src), dst)
    else:
        shutil.copy2(src, dst)

def write_binary_label_from_multiclass(src_mask: Path, dst_mask: Path, pos_labels: List[int]):
    """Write a 0/1 uint8 label image marking voxels whose label is in ``pos_labels``.

    This is the single definition of the writer; the earlier duplicate that was
    immediately shadowed has been removed. It is robust to rare 4D one-hot
    segmentations.

    Args:
        src_mask: Source segmentation readable by ``nib.load``.
        dst_mask: Output path; its parent directory is created if needed.
        pos_labels: Label values mapped to 1; every other voxel becomes 0.
    """
    ensure_dir(dst_mask.parent)
    img = nib.load(str(src_mask))
    data = np.asanyarray(img.dataobj)
    # If 4D one-hot (at most 10 channels, values exactly {0, 1}), collapse the
    # last axis to a label map so pos_labels can be matched against channel indices.
    if data.ndim == 4 and data.shape[-1] <= 10 and np.array_equal(np.unique(data), [0, 1]):
        data = np.argmax(data, axis=-1)
    if data.dtype.kind not in ("i", "u"):
        # Masks stored as floating point are rounded to integer labels first.
        data = np.round(data).astype(np.int32)
    binmask = np.isin(data, pos_labels).astype(np.uint8)
    # Reuse the source affine and header so the label stays aligned with the source mask.
    out = nib.Nifti1Image(binmask, affine=img.affine, header=img.header)
    out.header.set_data_dtype(np.uint8)
    nib.save(out, str(dst_mask))

# ---------------- Dataset-specific discoverers ----------------
def discover_MET() -> List[dict]:
    """Find MET cases that have both a T1c image and a segmentation.

    Scans ``ROOTS["MET"]`` for sub-directories named ``BraTS-MET-*`` and, in
    each, takes the first file (in sorted name order) ending in
    ``-t1c.nii.gz`` and ``-seg.nii.gz``.

    Returns:
        One dict per complete case with keys ``dataset``, ``id``, ``image``,
        ``mask`` and ``mask_kind`` ("binary"). Empty if the root is missing,
        matching the other discoverers that guard their root.
    """
    root = Path(ROOTS["MET"])
    if not root.is_dir():
        print(f"[WARN] MET root not found: {root}")
        return []
    out = []
    case_dirs = sorted(d for d in root.iterdir() if d.is_dir() and d.name.startswith("BraTS-MET-"))
    for case in case_dirs:
        # List once and sort: os.listdir order is arbitrary, which would make
        # the chosen file non-deterministic when several names match.
        names = sorted(os.listdir(case))
        t1c = next((case / n for n in names if n.endswith("-t1c.nii.gz")), None)
        seg = next((case / n for n in names if n.endswith("-seg.nii.gz")), None)
        if t1c and seg:
            # NOTE(review): "binary" makes staging keep only voxels equal to 1.
            # The recorded dry run skips many MET cases for having no such
            # voxels — confirm these segmentations are really binary.
            out.append({"dataset":"MET","id":case.name,"image":str(t1c),"mask":str(seg),"mask_kind":"binary"})
    return out

def discover_BCBM() -> list[dict]:
    """Find curated BCBM cases that have an MRI volume and at least one mask file.

    Each visible sub-directory of ``ROOTS["BCBM"]`` is one case. The image is a
    file ending in ``_image_ss_n4.nii.gz`` when available, otherwise the first
    ``.nii.gz`` whose name lacks ``_mask_``. Every ``.nii.gz`` whose name
    contains ``_mask_`` is collected as a mask.

    Returns:
        One dict per usable case with keys ``dataset``, ``id``, ``image``,
        ``mask_files`` (sorted path strings) and ``mask_kind``
        ("multi_file_union"). Cases lacking an image or masks are reported
        with a ``[SKIP]`` line and omitted.
    """
    root = Path(ROOTS["BCBM"])
    found: list[dict] = []
    if not root.is_dir():
        print(f"[WARN] Curated BCBM root not found: {root}")
        return found

    def is_nifti_file(entry: Path) -> bool:
        return entry.is_file() and entry.name.endswith(".nii.gz")

    def pick_image(case_dir: Path) -> Path | None:
        # First pass: the preferred file name pattern.
        for entry in case_dir.iterdir():
            if entry.is_file() and entry.name.endswith("_image_ss_n4.nii.gz"):
                return entry
        # Second pass: any NIfTI that is not a mask.
        for entry in case_dir.iterdir():
            if is_nifti_file(entry) and "_mask_" not in entry.name:
                return entry
        return None

    case_dirs = [d for d in root.iterdir() if d.is_dir() and not d.name.startswith(".")]
    case_dirs.sort()
    for case in case_dirs:
        image = pick_image(case)
        if image is None:
            print(f"[SKIP] {case.name}: no MRI found")
            continue
        masks = [entry for entry in case.iterdir() if is_nifti_file(entry) and "_mask_" in entry.name]
        if not masks:
            print(f"[SKIP] {case.name}: no masks in curated folder")
            continue
        masks.sort()
        found.append({
            "dataset": "BCBM",
            "id": case.name,
            "image": str(image),
            "mask_files": [str(m) for m in masks],
            "mask_kind": "multi_file_union"
        })
    return found

def discover_UCSD() -> List[dict]:
    """Find UCSD cases that have a T1-post image and an enhancing-tumor mask.

    Each sub-directory of ``ROOTS["UCSD"]`` is one case. Within it the first
    entry (in sorted order) ending in ``_T1post.nii.gz`` is the image and the
    first ending in ``_enhancing_cellular_tumor_seg.nii.gz`` is the mask.

    Returns:
        One dict per complete case with keys ``dataset``, ``id``, ``image``,
        ``mask`` and ``mask_kind`` ("binary"). Empty if the root is missing,
        matching the other discoverers that guard their root.
    """
    root = Path(ROOTS["UCSD"])
    if not root.is_dir():
        print(f"[WARN] UCSD root not found: {root}")
        return []
    out = []
    for case in sorted(d for d in root.iterdir() if d.is_dir()):
        # Sort once: iterdir order is arbitrary, so an unsorted "first match"
        # is not reproducible when several files share a suffix.
        entries = sorted(case.iterdir())
        t1post = next((e for e in entries if e.name.endswith("_T1post.nii.gz")), None)
        enh = next((e for e in entries if e.name.endswith("_enhancing_cellular_tumor_seg.nii.gz")), None)
        if t1post and enh:
            out.append({"dataset":"UCSD","id":case.name,"image":str(t1post),"mask":str(enh),"mask_kind":"binary"})
    return out

def discover_UCSF() -> List[dict]:
    """Find UCSF cases that have a T1c image and a tumor segmentation.

    Each sub-directory of ``ROOTS["UCSF"]`` is one case. The image is the
    first entry (in sorted order) ending in ``_T1c_bias.nii.gz``, falling back
    to ``_T1c.nii.gz``; the mask ends in ``_tumor_segmentation.nii.gz``.

    Returns:
        One dict per complete case with keys ``dataset``, ``id``, ``image``,
        ``mask`` and ``mask_kind`` ("multiclass"). Empty if the root is
        missing, matching the other discoverers that guard their root.
    """
    root = Path(ROOTS["UCSF"])
    if not root.is_dir():
        print(f"[WARN] UCSF root not found: {root}")
        return []
    out = []
    for case in sorted(d for d in root.iterdir() if d.is_dir()):
        # Sort once so the "first match" does not depend on arbitrary iterdir order.
        entries = sorted(case.iterdir())
        t1c = next((e for e in entries if e.name.endswith("_T1c_bias.nii.gz")), None)
        if t1c is None:
            t1c = next((e for e in entries if e.name.endswith("_T1c.nii.gz")), None)
        seg = next((e for e in entries if e.name.endswith("_tumor_segmentation.nii.gz")), None)
        if t1c and seg:
            out.append({"dataset":"UCSF","id":case.name,"image":str(t1c),"mask":str(seg),"mask_kind":"multiclass"})
    return out

def discover_UPENN() -> List[dict]:
    """Pair UPENN unstripped T1GD images with their segmentations.

    Images are looked up under ``images_structural_unstripped/<case>/`` (file
    ending in ``_T1GD_unstripped.nii.gz``) and segmentations under
    ``images_segm/`` (files ending in ``_segm.nii.gz``). A segmentation belongs
    to the case whose directory name equals the file name with
    ``_segm.nii.gz`` removed.

    Returns:
        One dict per case having both files, with keys ``dataset``, ``id``,
        ``image``, ``mask`` and ``mask_kind`` ("auto"). Empty when the image
        directory does not exist.
    """
    base = Path(ROOTS["UPENN"])
    img_dir = Path(base) / "images_structural_unstripped"
    seg_dir = Path(base) / "images_segm"
    if not img_dir.is_dir():
        return []

    # Case name -> segmentation path; stays empty when images_segm/ is absent.
    segs_by_case = {}
    if seg_dir.is_dir():
        segs_by_case = {
            entry.name.replace("_segm.nii.gz", ""): entry
            for entry in seg_dir.iterdir()
            if entry.name.endswith("_segm.nii.gz")
        }

    records = []
    case_dirs = [d for d in img_dir.iterdir() if d.is_dir()]
    case_dirs.sort()
    for case in case_dirs:
        t1gd = None
        for entry in case.iterdir():
            if entry.name.endswith("_T1GD_unstripped.nii.gz"):
                t1gd = entry
                break
        if t1gd is None:
            continue
        seg = segs_by_case.get(case.name)  # directory name doubles as the segmentation key
        if not seg:
            continue
        records.append({"dataset":"UPENN","id":case.name,"image":str(t1gd),"mask":str(seg),"mask_kind":"auto"})  # auto: detect binary/multi
    return records

def discover_MU() -> List[dict]:
    """Find MU patient time points that have a T1c image and a tumor mask.

    The layout walked is ``ROOTS["MU"]/<patient>/<timepoint>/``. Within a time
    point, the first ``.nii.gz`` (in sorted order) whose name contains
    ``brain_t1c`` is the image and the first ending in ``tumorMask.nii.gz`` is
    the mask.

    Returns:
        One dict per complete time point with keys ``dataset``, ``id``
        (``<patient>_<timepoint>``), ``image``, ``mask`` and ``mask_kind``
        ("multiclass"). Empty if the root is missing, matching the other
        discoverers that guard their root.
    """
    root = Path(ROOTS["MU"])
    if not root.is_dir():
        print(f"[WARN] MU root not found: {root}")
        return []
    out = []
    for patient in sorted(d for d in root.iterdir() if d.is_dir()):
        for tp in sorted(d for d in patient.iterdir() if d.is_dir()):
            # Sorted so the "first match" does not depend on arbitrary iterdir order.
            files = sorted(f for f in tp.iterdir() if f.is_file() and f.name.endswith(".nii.gz"))
            t1c = next((f for f in files if "brain_t1c" in f.name), None)
            seg = next((f for f in files if f.name.endswith("tumorMask.nii.gz")), None)
            if t1c and seg:
                out.append({"dataset":"MU","id":f"{patient.name}_{tp.name}","image":str(t1c),"mask":str(seg),"mask_kind":"multiclass"})
    return out

# 3) Discover BraTS-Africa (only 95_Glioma; ignore 51_OtherNeoplasms)
def discover_BRATSAfrica() -> list[dict]:
    """Find BraTS-Africa glioma cases that have a T1c image and a segmentation.

    Only the ``95_Glioma`` sub-folder of ``ROOTS["BRATSAfrica"]`` is scanned.
    Hidden directories are ignored. In each case directory the first listed
    file ending in ``-t1c.nii.gz`` is the image and the first ending in
    ``-seg.nii.gz`` is the mask.

    Returns:
        One dict per complete case with keys ``dataset``, ``id``, ``image``,
        ``mask`` and ``mask_kind`` ("multiclass"). Empty when ``95_Glioma``
        does not exist.
    """
    glioma_dir = Path(ROOTS["BRATSAfrica"]) / "95_Glioma"
    records = []
    if not glioma_dir.is_dir():
        return records

    def first_with_suffix(case_dir, suffix):
        # Returns the first directory entry whose name ends with `suffix`, else None.
        for fname in os.listdir(case_dir):
            if fname.endswith(suffix):
                return case_dir / fname
        return None

    visible_cases = [d for d in glioma_dir.iterdir() if d.is_dir() and not d.name.startswith(".")]
    visible_cases.sort()
    for case in visible_cases:
        t1c = first_with_suffix(case, "-t1c.nii.gz")
        seg = first_with_suffix(case, "-seg.nii.gz")
        if not (t1c and seg):
            continue
        records.append(dict(
            dataset="BRATSAfrica",
            id=case.name,
            image=str(t1c),
            mask=str(seg),
            mask_kind="multiclass",  # label map with ET=3
        ))
    return records

# --- Collect candidates ---

In [54]:
def is_nonempty_nifti(path: str) -> bool:
    """Return True when ``path`` is a non-empty regular file that ``nib.load`` accepts.

    Any exception raised while checking (missing file, permission problem,
    loader failure) is treated as "not usable" and yields False.
    """
    candidate = Path(path)
    try:
        if not candidate.is_file():
            return False
        if candidate.stat().st_size <= 0:
            return False
        # Loading is the validity probe; the returned image object is not needed.
        nib.load(str(candidate))
    except Exception:
        return False
    return True

def load_unique_values(mask_path: str, max_vox: int = 2_000_000) -> list[int]:
    """Return the sorted, de-duplicated integer label values present in a mask.

    The previous implementation drew ``max_vox`` random voxels (with
    replacement) from large volumes, which could miss labels that occupy only
    a handful of voxels and therefore mislead label auto-detection. The whole
    volume is now scanned exactly, ``max_vox`` voxels at a time.

    Args:
        mask_path: Path to the mask file.
        max_vox: Number of voxels examined per chunk (values < 1 are treated
            as 1). It no longer limits how much of the volume is inspected.

    Returns:
        Ascending list of distinct label values as Python ints. Non-integer
        data is rounded to the nearest integer; non-finite float voxels are
        ignored.

    Raises:
        RuntimeError: If the file fails the ``is_nonempty_nifti`` check or its
            data cannot be read.
    """
    p = Path(mask_path)
    if not is_nonempty_nifti(str(p)):
        raise RuntimeError(f"Mask not readable or empty: {mask_path}")
    try:
        img = nib.load(str(p))
        arr = np.asanyarray(img.dataobj)
    except Exception as e:
        raise RuntimeError(f"Failed to read mask {mask_path}: {e}") from e

    # order="K" flattens in memory order; voxel order is irrelevant for a set of values.
    flat = arr.ravel(order="K")
    step = max(1, int(max_vox))
    found = set()
    for start in range(0, flat.size, step):
        chunk = flat[start:start + step]
        if chunk.dtype.kind not in ("i", "u"):
            if chunk.dtype.kind == "f":
                # NaN/inf cannot be cast to a meaningful integer label.
                chunk = chunk[np.isfinite(chunk)]
            chunk = np.round(chunk).astype(np.int64)
        found.update(np.unique(chunk).tolist())
    return sorted(int(v) for v in found)

In [47]:
def collect_candidates():
    """Run every dataset discoverer and keep only items whose files are loadable.

    An item is dropped when its image, its single ``mask``, or any entry of its
    ``mask_files`` fails ``is_nonempty_nifti``. Every failing file is recorded,
    and up to ten of them are printed as examples.

    Returns:
        The surviving item dicts, in discovery order.
    """
    discoverers = (
        discover_MET,
        discover_BCBM,
        discover_UCSD,
        discover_UCSF,
        discover_UPENN,
        discover_MU,
        discover_BRATSAfrica,
    )
    candidates = []
    for discover in discoverers:
        candidates.extend(discover())

    usable = []
    rejected = []
    for cand in candidates:
        if not is_nonempty_nifti(cand["image"]):
            rejected.append(("image", cand["dataset"], cand.get("id"), cand["image"]))
            continue

        # Single-mask items.
        single_mask = cand.get("mask")
        if single_mask and not is_nonempty_nifti(single_mask):
            rejected.append(("mask", cand["dataset"], cand.get("id"), single_mask))
            continue

        # Multi-mask (union) items: report every unreadable file, then drop the item.
        unreadable = [m for m in (cand.get("mask_files") or []) if not is_nonempty_nifti(m)]
        if unreadable:
            rejected.extend(("mask", cand["dataset"], cand.get("id"), m) for m in unreadable)
            continue

        usable.append(cand)

    if rejected:
        print(f"[WARN] Skipping {len(rejected)} unreadable/empty files. Examples:")
        for kind, ds, cid, path in rejected[:10]:
            print(f"  - {kind} | {ds} | {cid} | {path}")
    return usable


In [55]:
# Discover all datasets, drop items with unreadable files, then print the
# total and a per-dataset count (dataset keys in sorted order).
items = collect_candidates()
print(f"Found candidates: {len(items)} by dataset →",
      {k: sum(1 for it in items if it['dataset']==k) for k in sorted(set(it['dataset'] for it in items))})

[WARN] Skipping 2 unreadable/empty files. Examples:
  - mask | UCSF | UCSF-PDGM-0541_nifti | /Users/chufal/projects/Datasets/PKG - UCSF-PDGM Version 5/UCSF-PDGM-v5/UCSF-PDGM-0541_nifti/UCSF-PDGM-0541_tumor_segmentation.nii.gz
  - mask | MU | PatientID_0275_Timepoint_6 | /Users/chufal/projects/Datasets/PKG-MU-Glioma-Post/MU-Glioma-Post/PatientID_0275/Timepoint_6/PatientID_0275_Timepoint_6_tumorMask.nii.gz
Found candidates: 1983 by dataset → {'BCBM': 264, 'BRATSAfrica': 95, 'MET': 200, 'MU': 593, 'UCSD': 184, 'UCSF': 500, 'UPENN': 147}


In [56]:
def _enhancing_labels_or_skip(mask_kind: str, dataset_key: str, mask_src: Path, case_id: str) -> Optional[List[int]]:
    """Decide which label values are foreground for a single-file mask.

    Returns the label list, or None after printing a ``[SKIP]`` line when no
    labels can be determined or the mask contains none of them.
    """
    if mask_kind == "binary":
        # NOTE(review): binary masks are matched on value 1 exactly, not "> 0",
        # and ENHANCING_LABEL_OVERRIDES is not consulted here — confirm intended.
        pos = [1]
        empty_reason = "binary mask has zero positive voxels"
    else:
        if mask_kind == "multiclass":
            pos = pick_enhancing_labels(dataset_key, str(mask_src))
            if not pos:
                print(f"[SKIP] {case_id}: no enhancing labels determined")
                return None
        else:  # "auto": treat {0, 1}-only masks as binary, otherwise auto-detect
            uniq = load_unique_values(str(mask_src))
            pos = [1] if set(uniq).issubset({0, 1}) else pick_enhancing_labels(dataset_key, str(mask_src))
            if not pos:
                print(f"[SKIP] {case_id}: cannot determine enhancing labels from {uniq}")
                return None
        empty_reason = f"no voxels for enhancing labels {pos}"
    if not mask_has_any_labels(str(mask_src), pos):
        print(f"[SKIP] {case_id}: {empty_reason}")
        return None
    return pos


def _union_mask_or_skip(mask_paths: List[Path], case_id: str):
    """Build the voxel-wise union (> 0) of several mask files.

    Returns ``(union_bool_array, reference_image)`` where the reference image
    is the first mask, or None after printing a ``[SKIP]`` line when a file is
    unreadable or the union is empty.
    """
    for mp in mask_paths:
        if not is_nonempty_nifti(str(mp)):
            print(f"[SKIP] {case_id}: union mask unreadable/empty → {mp}")
            return None
    img_ref = nib.load(str(mask_paths[0]))
    union = None
    for mp in mask_paths:
        arr = np.asanyarray(nib.load(str(mp)).dataobj)
        union = (arr > 0) if union is None else (union | (arr > 0))
    if not np.any(union):
        print(f"[SKIP] {case_id}: union of masks is empty")
        return None
    return union, img_ref


def stage_and_generate(items: list[dict]):
    """Validate every candidate, optionally stage it, and build dataset.json content.

    For each item the image and mask(s) are checked for readability and for
    containing at least one foreground voxel. When ``WRITE_FILES`` is true the
    image is staged into ``TASK_DIR/imagesTr`` and a binary uint8 label is
    written to ``TASK_DIR/labelsTr``, followed by ``dataset.json``. When it is
    false nothing is written — including the output directories, which the
    previous version created even on a dry run.

    Args:
        items: Candidate dicts as produced by the ``discover_*`` functions
            (keys ``dataset``, ``id``, ``image``, ``mask_kind`` and either
            ``mask`` or ``mask_files``).

    Returns:
        ``(dataset_json, train_cases, test_cases)``: the dataset description
        dict (its ``training`` list holds every accepted case) and two lists of
        case ids from a seeded per-dataset shuffle-and-split.
    """
    imagesTr = TASK_DIR / "imagesTr"
    labelsTr = TASK_DIR / "labelsTr"
    if WRITE_FILES:
        # Only touch the filesystem when writing is requested.
        ensure_dir(imagesTr)
        ensure_dir(labelsTr)

    training_entries = []
    meta_rows = []

    for it in items:
        ds = it["dataset"]
        case_id = f"{ds}__{it['id']}"
        img_src = Path(it["image"])
        img_dst = imagesTr / f"{case_id}_0000.nii.gz"
        lbl_dst = labelsTr / f"{case_id}.nii.gz"

        # Validate image first
        if not is_nonempty_nifti(str(img_src)):
            print(f"[SKIP] {case_id}: image unreadable/empty → {img_src}")
            continue

        mk = it["mask_kind"]
        try:
            if mk in ("binary", "multiclass", "auto"):
                mask_src = Path(it["mask"])
                if not is_nonempty_nifti(str(mask_src)):
                    print(f"[SKIP] {case_id}: mask unreadable/empty → {mask_src}")
                    continue
                pos = _enhancing_labels_or_skip(mk, ds, mask_src, case_id)
                if pos is None:
                    continue
                if WRITE_FILES:
                    stage_file(img_src, img_dst)
                    write_binary_label_from_multiclass(mask_src, lbl_dst, pos_labels=pos)

            elif mk == "multi_file_union":
                result = _union_mask_or_skip([Path(p) for p in it["mask_files"]], case_id)
                if result is None:
                    continue
                union, img_ref = result
                if WRITE_FILES:
                    stage_file(img_src, img_dst)
                    # Geometry (affine/header) is taken from the first mask file.
                    out = nib.Nifti1Image(union.astype(np.uint8), affine=img_ref.affine, header=img_ref.header)
                    out.header.set_data_dtype(np.uint8)
                    nib.save(out, str(lbl_dst))

            else:
                print(f"[WARN] {case_id}: unknown mask_kind={mk}; skipping")
                continue

        except Exception as e:
            # Best-effort staging: one bad case must not abort the whole run.
            print(f"[SKIP] {case_id}: error while preparing → {e}")
            continue

        training_entries.append({"image": f"./imagesTr/{img_dst.name}", "label": f"./labelsTr/{lbl_dst.name}"})
        meta_rows.append({"case": case_id, "dataset": ds})

    # simple per-dataset split
    rng_seed = 42
    datasets = np.array([r["dataset"] for r in meta_rows])
    cases = np.array([r["case"] for r in meta_rows])
    train_cases, test_cases = [], []
    for ds_unique in sorted(set(datasets)):
        ds_cases = cases[datasets == ds_unique]
        # At least one test case per dataset, even for very small datasets.
        n_test = max(1, int(len(ds_cases) * TRAIN_TEST_SPLIT))
        # A fresh, identically seeded generator per dataset keeps each
        # dataset's split independent of the others.
        rs = np.random.RandomState(rng_seed)
        ds_cases = ds_cases.copy()
        rs.shuffle(ds_cases)
        test_cases.extend(ds_cases[:n_test].tolist())
        train_cases.extend(ds_cases[n_test:].tolist())

    # NOTE(review): this mixes keys from different nnU-Net dataset.json styles
    # ("modality"/"training"/"tensorImageSize" alongside "file_ending" and
    # name->int "labels") — confirm against the nnU-Net version in use.
    # The held-out test cases are returned separately; "training" lists every
    # accepted case and "test" stays empty.
    dataset_json = {
        "name": TASK_NAME,
        "description": "T1c/T1GD enhancing tumor segmentation across multiple datasets",
        "reference": "DHAI",
        "licence": "private",
        "release": "0.1",
        "tensorImageSize": "3D",
        "modality": {"0": "T1gd"},
        "labels": {"background": 0, "ET": 1},
        "numTraining": len(training_entries),
        "file_ending": ".nii.gz",
        "training": training_entries,
        "test": []
    }

    if WRITE_FILES:
        with open(TASK_DIR / "dataset.json", "w") as f:
            json.dump(dataset_json, f, indent=2)
        print(f"Wrote dataset.json with {len(training_entries)} samples to {TASK_DIR}")

    print("Totals:", len(training_entries), "trainable samples")
    by_ds = {}
    for ds_unique in sorted(set(r["dataset"] for r in meta_rows)):
        by_ds[ds_unique] = sum(1 for e in training_entries if e["image"].startswith(f"./imagesTr/{ds_unique}__"))
    print("By dataset:", by_ds)
    return dataset_json, train_cases, test_cases

In [57]:
# Validate (and, if WRITE_FILES is set, stage) every candidate; keep the
# dataset description plus the per-dataset train/test case-id lists.
dataset_json, train_cases, test_cases = stage_and_generate(items)

# Make it explicit in the cell output whether anything was written to disk.
print("DRY RUN (no writes):", not WRITE_FILES)
# Guarded so an empty result prints a placeholder instead of raising IndexError.
print("Example entry:", dataset_json["training"][0] if dataset_json["training"] else "No entries")

[SKIP] MET__BraTS-MET-00086-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00089-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00090-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00098-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00100-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00106-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00107-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00108-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00109-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00111-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00113-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00119-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00120-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-MET-00124-000: binary mask has zero positive voxels
[SKIP] MET__BraTS-ME

In [59]:
# Inspect which top-level fields the generated dataset description contains.
dataset_json.keys()

dict_keys(['name', 'description', 'reference', 'licence', 'release', 'tensorImageSize', 'modality', 'labels', 'numTraining', 'file_ending', 'training', 'test'])

In [81]:
# Spot-check one late entry; the index is hard-coded and raises IndexError
# when fewer than 1801 cases were accepted.
dataset_json['training'][1800]

{'image': './imagesTr/BRATSAfrica__BraTS-SSA-00140-000_0000.nii.gz',
 'label': './labelsTr/BRATSAfrica__BraTS-SSA-00140-000.nii.gz'}