# Unified dataset builder for dual-task (segmentation + classification)
- Outputs nnU-Net-style folders: imagesTr/ (with _0000 channel suffix), labelsTr/
- dataset.json includes classes per case (non-standard nnU-Net extension)
- train.csv / val.csv / test.csv with columns: case_id, class_label, image_path, label_path
#
## NOTES:
- Reorients to RAS+ via nibabel.as_closest_canonical
- Binary conversions are applied per dataset-specific rules below
- Splits are patient-disjoint and stratified by class

### CONFIGURE these before running:

In [None]:
BASE_PATH = "/Users/chufal/projects/Datasets"
OUT_ROOT = "/Users/chufal/projects/DHAI-Brain-Segmentation/derived/unified_dualtask"
SPLIT_RATIOS = (0.7, 0.15, 0.15)  # (train, val, test) for the unified dataset
SEED = 42

### Importing libraries

In [None]:
from pathlib import Path
import json, csv, re, random
from typing import List, Dict, Tuple, Optional
import numpy as np
import nibabel as nib
import os
import matplotlib.pyplot as plt
from ipywidgets import (
    Dropdown, IntSlider, FloatSlider, Checkbox, RadioButtons, HBox, VBox, Output, HTML, Layout,
    Button, Text, Textarea
)
from IPython.display import display as _display
from matplotlib.colors import ListedColormap, to_rgba

# Data Curation Function and helpers

In [None]:
# ---------- IO & helpers ----------
def as_ras(img: nib.Nifti1Image) -> nib.Nifti1Image:
    return nib.as_closest_canonical(img)

def save_nifti(path: Path, data: np.ndarray, affine: np.ndarray) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    nib.save(nib.Nifti1Image(data, affine), str(path))

def load_arr_ras(path: Path, is_seg: bool = False) -> Tuple[np.ndarray, np.ndarray]:
    img = nib.load(str(path))
    img_ras = as_ras(img)
    arr = img_ras.get_fdata(dtype=np.float32)
    if is_seg:
        arr = np.rint(arr).astype(np.int16)
    return arr, img_ras.affine

def ensure_binary(arr: np.ndarray) -> np.ndarray:
    return (arr > 0).astype(np.uint8)

# ---------- Binary conversion rules (per dataset) ----------
def bin_ucsf_pdgm(lbl: np.ndarray) -> np.ndarray:
    # If 1 or 4 present -> those -> 1, else >0 -> 1
    uniq = set(np.unique(lbl).tolist())
    if 1 in uniq or 4 in uniq:
        out = ((lbl == 1) | (lbl == 4)).astype(np.uint8)
    else:
        out = (lbl > 0).astype(np.uint8)
    return out

def bin_brats_africa(lbl: np.ndarray) -> np.ndarray:
    uniq = set(np.unique(lbl).tolist())
    if 1 in uniq or 3 in uniq:
        out = ((lbl == 1) | (lbl == 3)).astype(np.uint8)
    else:
        out = (lbl > 0).astype(np.uint8)
    return out

def bin_mu_glioma_post(lbl: np.ndarray) -> np.ndarray:
    uniq = set(np.unique(lbl).tolist())
    if 1 in uniq or 3 in uniq:
        out = ((lbl == 1) | (lbl == 3)).astype(np.uint8)
    else:
        out = (lbl > 0).astype(np.uint8)
    return out

def bin_ucsd_ptgbm(lbl: np.ndarray) -> np.ndarray:
    uniq = set(np.unique(lbl).tolist())
    if 1 in uniq or 3 in uniq:
        out = ((lbl == 1) | (lbl == 3)).astype(np.uint8)
    else:
        out = (lbl > 0).astype(np.uint8)
    return out

def bin_upenn(lbl: np.ndarray) -> np.ndarray:
    uniq = set(np.unique(lbl).tolist())
    if 1 in uniq or 4 in uniq:
        out = ((lbl == 1) | (lbl == 4)).astype(np.uint8)
    else:
        out = (lbl > 0).astype(np.uint8)
    return out

def bin_pretreat_mets(lbl: np.ndarray) -> np.ndarray:
    uniq = set(np.unique(lbl).tolist())
    if 3 in uniq:
        out = (lbl == 3).astype(np.uint8)
    else:
        out = (lbl > 0).astype(np.uint8)
    return out

# ---------- Dataset scanners ----------

def scan_ucsf_pdgm(base: Path) -> List[Dict]:
    # Base_path/UCSF-PDGM-v5/ contains per-patient folders.
    root = base / "PKG - UCSF-PDGM Version 5"  # adjust name if directory differs
    if not root.exists():
        # Try alternative common folder name
        root = base / "UCSF-PDGM-v5"
    results = []
    if not root.exists():
        return results
    for case_dir in sorted([p for p in root.iterdir() if p.is_dir()]):
        files = list(case_dir.glob("*.nii.gz"))
        img = next((f for f in files if re.search(r"_T1c_bias\.nii\.gz$", f.name, re.IGNORECASE)), None)
        seg = next((f for f in files if re.search(r"_tumor_segmentation\.nii\.gz$", f.name, re.IGNORECASE)), None)
        if img is None or seg is None:
            continue
        results.append({
            "case_id": case_dir.name,
            "image": img,
            "label": seg,
            "class_label": 0,  # Glioma
            "bin_fn": bin_ucsf_pdgm,
        })
    return results

def scan_brats_africa(base: Path) -> List[Dict]:
    root = base / "PKG-BraTS-Africa" / "95_Glioma"
    results = []
    if not root.exists():
        return results
    for case_dir in sorted([p for p in root.iterdir() if p.is_dir()]):
        img = next((f for f in case_dir.glob("*-t1c.nii.gz")), None)
        seg = next((f for f in case_dir.glob("*-seg.nii.gz")), None)
        if img is None or seg is None:
            continue
        results.append({
            "case_id": case_dir.name,
            "image": img,
            "label": seg,
            "class_label": 0,
            "bin_fn": bin_brats_africa,
        })
    return results

def scan_mu_glioma_post(base: Path) -> List[Dict]:
    root = base / "PKG-MU-Glioma-Post" / "MU-Glioma-Post"
    results = []
    if not root.exists():
        return results
    for patient in sorted([p for p in root.iterdir() if p.is_dir()]):
        for tp in sorted([p for p in patient.iterdir() if p.is_dir() and p.name.startswith("Timepoint_")]):
            img = next((f for f in tp.glob("*brain_t1c.nii.gz")), None)
            seg = next((f for f in tp.glob("*tumorMask.nii.gz")), None)
            if img is None or seg is None:
                continue
            case_id = f"{patient.name}_{tp.name}"
            results.append({
                "case_id": case_id,
                "image": img,
                "label": seg,
                "class_label": 0,
                "bin_fn": bin_mu_glioma_post,
            })
    return results

def scan_ucsd_ptgbm(base: Path) -> List[Dict]:
    root = base / "PKG-UCSD-PTGBM-v1" / "UCSD-PTGBM"
    results = []
    if not root.exists():
        return results
    for case_dir in sorted([p for p in root.iterdir() if p.is_dir()]):
        img = next((f for f in case_dir.glob("*_T1post.nii.gz")), None)
        # Prefer BraTS-style multi-class seg, if available
        seg = next((f for f in case_dir.glob("*BraTS_tumor_seg.nii.gz")), None)
        if img is None or seg is None:
            continue
        results.append({
            "case_id": case_dir.name,
            "image": img,
            "label": seg,
            "class_label": 0,
            "bin_fn": bin_ucsd_ptgbm,
        })
    return results

def scan_upenn(base: Path) -> List[Dict]:
    root = base / "PKG-UPENN-GBM-NIfTI" / "UPENN-GBM" / "NIfTI-files"
    images_root = root / "images_structural_unstripped"
    masks_root = root / "images_segm"
    results = []
    if not images_root.exists() or not masks_root.exists():
        return results
    id_img_re = re.compile(r"(UPENN-GBM-\d{5}_\d{2})_T1GD_unstripped\.nii\.gz$", re.IGNORECASE)
    id_mask_re = re.compile(r"(UPENN-GBM-\d{5}_\d{2})_segm\.nii\.gz$", re.IGNORECASE)
    # Index masks
    id_to_mask = {}
    for m in masks_root.glob("*.nii.gz"):
        mm = id_mask_re.match(m.name)
        if mm:
            id_to_mask[mm.group(1)] = m
    # Walk images
    for case_dir in sorted([d for d in images_root.iterdir() if d.is_dir()]):
        img = next((f for f in case_dir.glob("*.nii.gz") if id_img_re.match(f.name)), None)
        if img is None:
            continue
        cid = id_img_re.match(img.name).group(1)
        seg = id_to_mask.get(cid)
        if seg is None:
            continue
        results.append({
            "case_id": cid,
            "image": img,
            "label": seg,
            "class_label": 0,
            "bin_fn": bin_upenn,
        })
    return results

def scan_bcbm_radiogenomics(base: Path) -> List[Dict]:
    root = base / "PKG-BCBM-RadioGenomics_Images_Masks_Dec2024" / "BCBM_KSC_curated_data"
    results = []
    if not root.exists():
        return results
    for case_dir in sorted([p for p in root.iterdir() if p.is_dir()]):
        img = next((f for f in case_dir.glob("*_image_ss_n4.nii.gz")), None)
        masks = [f for f in case_dir.glob("*_mask_*.nii.gz")]
        if img is None or not masks:
            continue
        results.append({
            "case_id": case_dir.name,
            "image": img,
            "mask_list": masks,  # union required
            "class_label": 1,  # Metastatic
            "bin_fn": None,    # handled via union
        })
    return results

def scan_pretreat_mets(base: Path) -> List[Dict]:
    root = base / "PKG-Pretreat-MetsToBrain-Masks" / "Pretreat-MetsToBrain-Masks"
    results = []
    if not root.exists():
        return results
    for case_dir in sorted([p for p in root.iterdir() if p.is_dir()]):
        img = next((f for f in case_dir.glob("*-t1c.nii.gz")), None)
        seg = next((f for f in case_dir.glob("*-seg.nii.gz")), None)
        if img is None or seg is None:
            continue
        results.append({
            "case_id": case_dir.name,
            "image": img,
            "label": seg,
            "class_label": 1,
            "bin_fn": bin_pretreat_mets,
        })
    return results

# ---------- Build unified dataset ----------

def union_masks(mask_paths: List[Path]) -> np.ndarray:
    accum = None
    for mp in mask_paths:
        arr, _ = load_arr_ras(mp, is_seg=True)
        binm = (arr > 0).astype(np.uint8)
        accum = binm if accum is None else np.logical_or(accum, binm).astype(np.uint8)
    return accum

def to_unified_case_id(source: str, case_id: str) -> str:
    return f"{source}_{case_id}"

def write_case(out_root: Path, unified_id: str, img_arr: np.ndarray, img_aff: np.ndarray, lbl_arr: np.ndarray) -> Tuple[Path, Path]:
    # nnU-Net naming: image ends with _0000.nii.gz; label uses base name
    images_dir = out_root / "imagesTr"
    labels_dir = out_root / "labelsTr"
    images_dir.mkdir(parents=True, exist_ok=True)
    labels_dir.mkdir(parents=True, exist_ok=True)
    img_path = images_dir / f"{unified_id}_0000.nii.gz"
    lbl_path = labels_dir / f"{unified_id}.nii.gz"
    save_nifti(img_path, img_arr.astype(np.float32), img_aff)
    save_nifti(lbl_path, lbl_arr.astype(np.uint8), img_aff)
    return img_path, lbl_path

def main_build(base_path: str, out_root: str, splits: Tuple[float,float,float]=(0.7,0.15,0.15), seed: int=42):
    base = Path(base_path)
    out = Path(out_root)
    out.mkdir(parents=True, exist_ok=True)

    # Discover all sources
    entries: List[Dict] = []
    entries += [{"source":"UCSF_PDGM", **e} for e in scan_ucsf_pdgm(base)]
    entries += [{"source":"BRATS_AFRICA", **e} for e in scan_brats_africa(base)]
    entries += [{"source":"MU_GLIOMA_POST", **e} for e in scan_mu_glioma_post(base)]
    entries += [{"source":"UCSD_PTGBM", **e} for e in scan_ucsd_ptgbm(base)]
    entries += [{"source":"UPENN_GBM", **e} for e in scan_upenn(base)]
    entries += [{"source":"BCBM_RADIOGENOMICS", **e} for e in scan_bcbm_radiogenomics(base)]
    entries += [{"source":"PRETREAT_METS", **e} for e in scan_pretreat_mets(base)]

    if not entries:
        raise RuntimeError("No cases found across the specified datasets")

    # Build unified cases
    index_rows = []  # for dataset.json + CSVs
    for e in entries:
        unified_id = to_unified_case_id(e["source"], e["case_id"])
        # Load image
        img_arr, img_aff = load_arr_ras(e["image"], is_seg=False)

        # Build label
        if e["source"] == "BCBM_RADIOGENOMICS":
            lbl_arr = union_masks(e["mask_list"])
        else:
            lbl_raw, _ = load_arr_ras(e["label"], is_seg=True)
            lbl_arr = e["bin_fn"](lbl_raw) if e.get("bin_fn") else ensure_binary(lbl_raw)

        # Write
        img_path, lbl_path = write_case(out, unified_id, img_arr, img_aff, lbl_arr)
        index_rows.append({
            "case_id": unified_id,
            "class_label": int(e["class_label"]),
            "image": str(img_path),
            "label": str(lbl_path),
            "source": e["source"],
            "orig_case_id": e["case_id"],
        })

    # Write dataset.json (nnU-Net style + extra 'class_label' in training entries)
    ds_json = {
        "name": "UnifiedDualTask",
        "description": "Unified T1c + binary tumor dataset (glioma vs metastatic) built from multiple sources",
        "tensorImageSize": "3D",
        "modality": {"0": "t1gd"},
        "labels": {"0": "background", "1": "tumor"},
        "numTraining": len(index_rows),
        "training": [
            {"image": os.path.relpath(r["image"], out_root), "label": os.path.relpath(r["label"], out_root), "class_label": r["class_label"]}
            for r in index_rows
        ],
        "numTest": 0,
        "test": []
    }
    with open(out / "dataset.json", "w") as f:
        json.dump(ds_json, f, indent=2)

    # Make splits (patient/case disjoint, stratified by class)
    # Group by case_id -> class_label
    class_to_cases: Dict[int, List[str]] = {}
    for r in index_rows:
        class_to_cases.setdefault(int(r["class_label"]), []).append(r["case_id"])
    rnd = random.Random(seed)
    train_ids, val_ids, test_ids = [], [], []
    trf, vf, tf = splits
    for lab, cids in class_to_cases.items():
        rnd.shuffle(cids)
        n = len(cids)
        n_train = int(round(trf * n))
        n_val = int(round(vf * n))
        n_test = max(0, n - n_train - n_val)
        train_ids += cids[:n_train]
        val_ids += cids[n_train:n_train + n_val]
        test_ids += cids[n_train + n_val:]

    def _write_split(name: str, ids: List[str]):
        p = out / f"{name}.csv"
        with p.open("w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["case_id", "class_label", "image_path", "label_path"])
            for cid in ids:
                r = next(rr for rr in index_rows if rr["case_id"] == cid)
                w.writerow([cid, r["class_label"], r["image"], r["label"]])
        return p

    p_train = _write_split("train", train_ids)
    p_val = _write_split("val", val_ids)
    p_test = _write_split("test", test_ids)

    print(f"[DONE] Unified dataset written to: {out_root}")
    print(f" - imagesTr: {len(list((out/'imagesTr').glob('*.nii*')))}")
    print(f" - labelsTr: {len(list((out/'labelsTr').glob('*.nii*')))}")
    print(f" - train: {p_train}")
    print(f" - val:   {p_val}")
    print(f" - test:  {p_test}")

# Run the builder (do not run in this message; copy/paste into your environment and execute)
# main_build(BASE_PATH, OUT_ROOT, splits=SPLIT_RATIOS, seed=SEED)

In [None]:
# main_build(BASE_PATH, OUT_ROOT, splits=SPLIT_RATIOS, seed=SEED)

# Review curated dataset

In [3]:
# ----------------------------------------------------------------------
#  Dataset review / curation helper (nnU‑Net style)
#  ---------------------------------------------------------------
#  * Works with a `dataset.json` under <root_path>
#  * Optional train/val/test CSV split files are respected
#  * Unsatisfactory cases can be flagged with a note and saved to CSV
#
#  Requirements: nibabel, numpy, matplotlib, ipywidgets
# ----------------------------------------------------------------------
from pathlib import Path
import json, csv
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, to_rgba
from ipywidgets import (
    Dropdown, IntSlider, RadioButtons, FloatSlider,
    Checkbox, Textarea, Button, HBox, VBox, HTML, Output, Layout, Text
)
from IPython.display import display

# ----------------------------------------------------------------------
#  Utility helpers
# ----------------------------------------------------------------------
def _as_ras(img: nib.Nifti1Image) -> nib.Nifti1Image:
    """Return a NIfTI image in RAS orientation."""
    return nib.as_closest_canonical(img)

def _load_arr_ras(path: Path, is_seg: bool = False) -> tuple[np.ndarray, np.ndarray]:
    """Load a nifti file, force it to RAS, and return its 3‑D array + affine."""
    img = nib.load(str(path))
    img = _as_ras(img)
    arr = img.get_fdata(dtype=np.float32)
    if is_seg:
        arr = np.rint(arr).astype(np.int16)        # make segmentation integers
    return arr, img.affine

def _normalize(img2d: np.ndarray, lo: int = 1, hi: int = 99) -> np.ndarray:
    """Clip the image to the 1‑99 percentile window and scale to 0‑1."""
    valid = np.isfinite(img2d)
    if not np.any(valid):
        return img2d
    a, b = np.percentile(img2d[valid], [lo, hi])
    if b <= a:
        return img2d
    img2d = np.clip(img2d, a, b)
    return (img2d - a) / (b - a + 1e-6)

def _apply_window_percentile(img2d: np.ndarray, p_low: float, p_high: float) -> np.ndarray:
    valid = np.isfinite(img2d)
    if not np.any(valid): return img2d
    lo, hi = np.percentile(img2d[valid], [p_low, p_high])
    if hi <= lo: return img2d
    img2d = np.clip(img2d, lo, hi)
    return (img2d - lo) / (hi - lo + 1e-6)

def _apply_window_center_width(img2d: np.ndarray, center: float, width: float) -> np.ndarray:
    lo, hi = center - width/2.0, center + width/2.0
    if hi <= lo: return img2d
    img2d = np.clip(img2d, lo, hi)
    return (img2d - lo) / (hi - lo + 1e-6)

def _extract_slice(vol: np.ndarray, axis: int, idx: int) -> np.ndarray:
    """Return a 2‑D slice of a 3‑D volume."""
    return vol.take(indices=idx, axis=axis)

# ----------------------------------------------------------------------
#  Dataset parsing helpers
# ----------------------------------------------------------------------
def _parse_dataset(root: Path) -> list[dict]:
    """Return a list of cases from `dataset.json` (or by convention)."""
    root = Path(root)
    dsj = root / "dataset.json"

    if not dsj.exists():
        # fall back to imagesTr/labelsTr convention
        images = sorted((root / "imagesTr").glob("*.nii*"))
        labels = {p.name.replace("_0000", ""): p for p in (root / "labelsTr").glob("*.nii*")}
        rows = []
        for im in images:
            lbl_name = im.name.replace("_0000", "")
            lp = labels.get(lbl_name)
            if lp:
                rows.append({
                    "case_id": im.stem.split("_")[0],
                    "image": str(im),
                    "label": str(lp),
                    "class_label": None,
                    "source": None
                })
        return rows

    # normal nnU‑Net dataset.json
    with dsj.open() as f:
        meta = json.load(f)

    rows = []
    for t in meta.get("training", []):
        im = (root / t["image"]).resolve()
        lb = (root / t["label"]).resolve()
        rows.append({
            "case_id": Path(lb).stem,
            "image": str(im),
            "label": str(lb),
            "class_label": t.get("class_label"),
            "source": t.get("source")
        })
    return rows

def _load_split_csvs(root: Path) -> dict[str, str]:
    """Return a dict {case_id : 'train'/'val'/'test'} if split CSVs exist."""
    cid2split = {}
    for nm in ("train", "val", "test"):
        p = root / f"{nm}.csv"
        if p.exists():
            with p.open() as f:
                for r in csv.DictReader(f):
                    cid = r.get("case_id")
                    if cid:
                        cid2split[cid] = nm
    return cid2split


# ----------------------------------------------------------------------
#  Main interactive UI
# ----------------------------------------------------------------------
def launch_curation_tool(root_path: str | Path,
                         output_csv: str | Path | None = None,
                         preload_csv: str | Path | None = None) -> None:
    """Show an interactive curation UI."""
    root = Path(root_path)
    cases = _parse_dataset(root)
    cid2split = _load_split_csvs(root)

    # -------------  add split info to each case  ---------------------
    for r in cases:
        r["split"] = cid2split.get(r["case_id"], "NA")

    # -------------  preload any previously flagged cases -----------
    flagged = {}
    if preload_csv:
        p = Path(preload_csv)
        if p.exists():
            with p.open() as f:
                for r in csv.DictReader(f):
                    flagged[r["case_id"]] = {
                        "note": r.get("note", ""),
                        "image_path": r.get("image_path", ""),
                        "label_path": r.get("label_path", ""),
                        "split": r.get("split", "")
                    }

    # ------------------------------------------------------------------
    #  UI widgets
    # ------------------------------------------------------------------
    split_opts = ["All"] + sorted(
        {r["split"] for r in cases if r["split"] != "NA"}
    )
    split_dd = Dropdown(
        options=split_opts,
        value="All",
        description="Split:",
        layout=Layout(width="180px")
    )

    def _filter_cases():
        s = split_dd.value
        return cases if s == "All" else [r for r in cases if r["split"] == s]

    case_dd = Dropdown(
        options=[r["case_id"] for r in _filter_cases()],
        description="Case:",
        layout=Layout(width="360px")
    )

    plane_rb = RadioButtons(
        options=[("Axial", 2), ("Coronal", 1), ("Sagittal", 0)],
        value=2,
        description="Plane:",
        layout=Layout(width="220px")
    )

    slice_slider = IntSlider(
        description="Slice:",
        min=0,
        max=1,
        value=0,
        continuous_update=False,
        layout=Layout(width="420px")
    )

    alpha_slider = FloatSlider(
        description="Alpha:",
        min=0.1,
        max=1.0,
        step=0.05,
        value=0.5,
        readout_format=".2f",
        layout=Layout(width="220px")
    )

    contour_cb = Checkbox(value=False, description="Contour only")
    unsat_cb = Checkbox(value=False, description="Unsatisfactory")

    # Zoom/magnifier
    fig_size = FloatSlider(description="FigSize:", min=5.0, max=12.0, step=0.5, value=6.5, layout=Layout(width="240px"))
    zoom_factor = FloatSlider(description="Zoom×", min=1.0, max=6.0, step=0.1, value=1.0, readout_format=".1f", layout=Layout(width="240px"))
    zoom_cx = FloatSlider(description="Cx", min=0.0, max=1.0, step=0.01, value=0.5, readout_format=".2f", layout=Layout(width="220px"))
    zoom_cy = FloatSlider(description="Cy", min=0.0, max=1.0, step=0.01, value=0.5, readout_format=".2f", layout=Layout(width="220px"))

    # Status (Reviewed vs Unsatisfactory). Replaces/augments your unsat_cb.
    status_rb = RadioButtons(
        options=[("Reviewed", "reviewed"), ("Unsatisfactory", "unsatisfactory")],
        value="reviewed",
        description="Status:",
        layout=Layout(width="220px")
    )

    # Counters
    count_html = HTML(value="", layout=Layout(width="100%"))

    

    window_mode = Dropdown(
    options=["Percentile", "Center/Width"], value="Percentile",
    description="Window:", layout=Layout(width="200px")
    )
    p_low = IntSlider(description="P_low:", min=0, max=20, value=1, layout=Layout(width="250px"))
    p_high = IntSlider(description="P_high:", min=80, max=100, value=99, layout=Layout(width="250px"))
    center_slider = FloatSlider(description="Center:", min=-500.0, max=500.0, step=1.0, value=0.0, readout_format=".1f", layout=Layout(width="300px"))
    width_slider = FloatSlider(description="Width:", min=1.0, max=2000.0, step=1.0, value=200.0, readout_format=".1f", layout=Layout(width="300px"))

    note_txt = Textarea(
        value="",
        placeholder="Optional note/reason ...",
        description="Note:",
        layout=Layout(width="420px", height="60px")
    )

    save_btn = Button(description="Save CSV", button_style="warning")
    next_btn = Button(description="Next ▶")
    prev_btn = Button(description="◀ Prev")

    info_html = HTML(layout=Layout(width="100%"))
    out = Output(layout=Layout(border="1px solid #ddd"))
    csv_path_txt = Text(
        value=str(output_csv or root / "unsatisfactory.csv"),
        description="CSV:",
        layout=Layout(width="520px")
    )

    stats_html = HTML(layout=Layout(width="100%"))

    # ------------------------------------------------------------------
    #  Internal state
    # ------------------------------------------------------------------
    current_img = None          # 3‑D np.ndarray
    current_lbl = None          # 3‑D np.ndarray
    _is_rendering = False

    # ------------------------------------------------------------------
    #  Helper functions
    # ------------------------------------------------------------------
    def _persist_note_for_case(case_id: str | None):
        if not case_id:
            return
        row = next((r for r in _filter_cases() if r["case_id"] == case_id), None)
        if not row:
            return
        # always persist current status; note can be empty
        flagged[case_id] = {
            "status": status_rb.value,        # "reviewed" or "unsatisfactory"
            "note": note_txt.value,
            "image_path": row["image"],
            "label_path": row["label"],
            "split": row.get("split", "NA")
        }
        _update_counters()

    def _set_slice_slider_safely(max_val: int, value: int):
        """Update the slider without firing its observers."""
        with slice_slider.hold_trait_notifications():
            slice_slider.max = max_val
            slice_slider.value = value

    def _load_case():
        """Load the 3‑D volumes for the currently selected case."""
        nonlocal current_img, current_lbl
        cid = case_dd.value
        row = next(r for r in _filter_cases() if r["case_id"] == cid)

        current_img, img_aff = _load_arr_ras(Path(row["image"]))
        current_lbl, lbl_aff = _load_arr_ras(Path(row["label"]), is_seg=True)

        valid = np.isfinite(current_img)
        if np.any(valid):
            lo, hi = np.percentile(current_img[valid], [1, 99])
            # Percentile sliders
            p_low.value = 1
            p_high.value = 99
            # Center/Width sliders
            center_slider.min = float(current_img[valid].min())
            center_slider.max = float(current_img[valid].max())
            center_slider.value = float((lo + hi) / 2.0)
            width_slider.min = 1.0
            width_slider.max = float(max(10.0, current_img[valid].max() - current_img[valid].min()))
            width_slider.value = float(max(10.0, hi - lo))

        # reset slider to the middle of the volume
        ax = plane_rb.value
        max_idx = max(0, min(current_img.shape[ax], current_lbl.shape[ax]) - 1)
        _set_slice_slider_safely(max_idx, max(0, max_idx // 2))

        # update flag / note widgets
        unsat_cb.unobserve(_on_unsat_change, names="value")
        note_txt.unobserve(_on_note_change, names="value")
        if cid in flagged:
            note_txt.value = flagged[cid]["note"]
            unsat_cb.value = True
        else:
            note_txt.value = ""
            unsat_cb.value = False
        unsat_cb.observe(_on_unsat_change, names="value")
        note_txt.observe(_on_note_change, names="value")

        # restore status/note safely
        status_rb.unobserve(_on_status_change, names="value")
        note_txt.unobserve(_on_note_change, names="value")
        try:
            if cid in flagged:
                status_rb.value = flagged[cid].get("status", "reviewed")
                note_txt.value = flagged[cid].get("note", "")
            else:
                status_rb.value = "reviewed"
                note_txt.value = ""
        finally:
            status_rb.observe(_on_status_change, names="value")
            note_txt.observe(_on_note_change, names="value")

        # also update zoom center defaults for this case
        zoom_cx.value = 0.5
        zoom_cy.value = 0.5

        # refresh counters (in case filter changed)
        _update_counters()

        # info line
        info_html.value = (
            f"Class: <b>{row.get('class_label','NA')}</b> | "
            f"Source: <b>{row.get('source','NA')}</b> | "
            f"Split: <b>{row.get('split','NA')}</b>"
        )

        # spacing from affines (mm)
        def _spacing_from_aff(aff):
            return (abs(float(aff[0, 0])), abs(float(aff[1, 1])), abs(float(aff[2, 2])))

        sp_img = _spacing_from_aff(img_aff)
        sp_lbl = _spacing_from_aff(lbl_aff)

        # intensity stats (raw 3D image)
        valid_img = np.isfinite(current_img)
        vmin = float(np.nanmin(current_img[valid_img])) if np.any(valid_img) else float("nan")
        vmax = float(np.nanmax(current_img[valid_img])) if np.any(valid_img) else float("nan")
        vmean = float(np.nanmean(current_img[valid_img])) if np.any(valid_img) else float("nan")
        vstd = float(np.nanstd(current_img[valid_img])) if np.any(valid_img) else float("nan")
        p1, p99 = (np.percentile(current_img[valid_img], [1, 99]) if np.any(valid_img) else (float("nan"), float("nan")))

        # label stats
        lbl_nz = int((current_lbl > 0).sum())
        lbl_total = int(np.prod(current_lbl.shape))
        lbl_cov = (lbl_nz / lbl_total) if lbl_total > 0 else 0.0
        lbl_uniq = ", ".join(str(int(x)) for x in np.unique(current_lbl))

        stats_html.value = (
            f"Image shape: <b>{current_img.shape}</b> | Label shape: <b>{current_lbl.shape}</b> | "
            f"Spacing img (mm): <b>{sp_img[0]:.3f}, {sp_img[1]:.3f}, {sp_img[2]:.3f}</b> | "
            f"Spacing lbl (mm): <b>{sp_lbl[0]:.3f}, {sp_lbl[1]:.3f}, {sp_lbl[2]:.3f}</b><br/>"
            f"Intensity (raw): min <b>{vmin:.3f}</b>, max <b>{vmax:.3f}</b>, mean <b>{vmean:.3f}</b>, std <b>{vstd:.3f}</b>, "
            f"p1 <b>{p1:.3f}</b>, p99 <b>{p99:.3f}</b><br/>"
            f"Labels: unique <b>{lbl_uniq}</b> | voxels >0: <b>{lbl_nz}</b> "
            f"({lbl_cov*100:.2f}% of volume)"
        )

    def _render(*_):
        """Draw a single frame inside the `Output` widget."""
        nonlocal _is_rendering
        if _is_rendering:
            return
        _is_rendering = True
        try:
            if current_img is None or current_lbl is None:
                _load_case()

            with out:
                out.clear_output(wait=True)

                ax_val = plane_rb.value
                # Guard for very thin volumes and keep label/image shapes in sync
                max_idx = max(0, min(current_img.shape[ax_val], current_lbl.shape[ax_val]) - 1)
                idx = min(slice_slider.value, max_idx)

                img2d = _extract_slice(current_img, ax_val, idx)
                lbl2d = _extract_slice(current_lbl, ax_val, idx)

                # windowing (reuse your existing windowing selection, else default percentile)
                if window_mode.value == "Percentile":
                    shown = _apply_window_percentile(img2d, float(p_low.value), float(p_high.value))
                else:
                    shown = _apply_window_center_width(img2d, float(center_slider.value), float(width_slider.value))

                # zoom crop (center in [0,1] units)
                shown = _crop_for_zoom(shown, float(zoom_factor.value), float(zoom_cx.value), float(zoom_cy.value))

                with plt.ioff():
                    fig, ax = plt.subplots(1, 1, figsize=(float(fig_size.value), float(fig_size.value)), dpi=100)
                    ax.imshow(shown, cmap="gray")
                    # mask computed from original lbl2d resized by cropping with same window
                    mask = _crop_for_zoom((lbl2d > 0).astype(np.uint8), float(zoom_factor.value), float(zoom_cx.value), float(zoom_cy.value)) > 0
                    if np.any(mask):
                        col = to_rgba("#d62728", alpha=alpha_slider.value)
                        if contour_cb.value:
                            ax.contour(mask.astype(float), levels=[0.5], colors=[col], linewidths=1.5)
                        else:
                            ax.imshow(mask.astype(int), cmap=ListedColormap([(0, 0, 0, 0), col]), interpolation="none")
                    ax.axis("off")
                    title = f"{case_dd.value} | {['Sagittal','Coronal','Axial'][ax_val]} | z={idx}"
                    ax.set_title(title)
                    out.append_display_data(fig)
                    plt.close(fig)

        finally:
            _is_rendering = False

    def _crop_for_zoom(img2d: np.ndarray, zoom: float, cx: float, cy: float) -> np.ndarray:
        if zoom <= 1.0:
            return img2d
        h, w = img2d.shape
        win_h = max(1, int(h / zoom))
        win_w = max(1, int(w / zoom))
        cy_px = int(np.clip(cy, 0.0, 1.0) * (h - 1))
        cx_px = int(np.clip(cx, 0.0, 1.0) * (w - 1))
        y0 = np.clip(cy_px - win_h // 2, 0, h - win_h)
        x0 = np.clip(cx_px - win_w // 2, 0, w - win_w)
        return img2d[y0:y0 + win_h, x0:x0 + win_w]

    def _update_counters():
        fc = _filter_cases()
        total = len(fc)
        # statuses we’ve saved so far
        reviewed_set = {cid for cid, info in flagged.items() if info.get("status") in {"reviewed", "unsatisfactory"}}
        unsat_set = {cid for cid, info in flagged.items() if info.get("status") == "unsatisfactory"}
        reviewed = len([r for r in fc if r["case_id"] in reviewed_set])
        unsat = len([r for r in fc if r["case_id"] in unsat_set])
        remaining = max(0, total - reviewed)
        count_html.value = f"Reviewed: <b>{reviewed}</b> | Unsatisfactory: <b>{unsat}</b> | Remaining: <b>{remaining}</b> | Total: <b>{total}</b>"

    # ------------------------------------------------------------------
    #  Event handlers
    # ------------------------------------------------------------------
    def _on_split_change(change):
        if change["name"] != "value":
            return
        _persist_note_for_case(case_dd.value)
        case_dd.options = [r["case_id"] for r in _filter_cases()]
        if case_dd.options:
            case_dd.value = case_dd.options[0]
            _load_case()
            _render()

    def _on_case_change(change):
        if change["name"] != "value":
            return
        _persist_note_for_case(change.get("old"))
        _load_case()
        _render()

    def _on_unsat_change(change):
        if change["name"] != "value":
            return
        cid = case_dd.value
        row = next(r for r in _filter_cases() if r["case_id"] == cid)

        if change["new"]:                       # checked
            if cid not in flagged:
                flagged[cid] = {
                    "note": note_txt.value,
                    "image_path": row["image"],
                    "label_path": row["label"],
                    "split": row.get("split", "NA")
                }
            else:
                # keep the existing note if the user edited it
                if note_txt.value:
                    flagged[cid]["note"] = note_txt.value
                flagged[cid]["image_path"] = row["image"]
                flagged[cid]["label_path"] = row["label"]
                flagged[cid]["split"] = row.get("split", "NA")
        else:                                   # unchecked
            flagged.pop(cid, None)

    def _on_note_change(change):
        cid = case_dd.value
        if unsat_cb.value:
            if cid not in flagged:
                flagged[cid] = {"note": "", "image_path": "", "label_path": "", "split": "NA"}
            flagged[cid]["note"] = note_txt.value

    def _save_csv(_):
        p = Path(csv_path_txt.value)
        p.parent.mkdir(parents=True, exist_ok=True)
        with p.open("w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["case_id", "image_path", "label_path", "split", "status", "note"])
            for cid, info in flagged.items():
                w.writerow([
                    cid,
                    info.get("image_path", ""),
                    info.get("label_path", ""),
                    info.get("split", ""),
                    info.get("status", "reviewed"),
                    info.get("note", "")
                ])
        print(f"[SAVED] {p} ({len(flagged)} flagged)")

    def _goto_next(_):
        _persist_note_for_case(case_dd.value)
        idx = list(case_dd.options).index(case_dd.value)
        case_dd.value = case_dd.options[(idx + 1) % len(case_dd.options)]

    def _goto_prev(_):
        _persist_note_for_case(case_dd.value)
        idx = list(case_dd.options).index(case_dd.value)
        case_dd.value = case_dd.options[(idx - 1) % len(case_dd.options)]

    def _on_plane_change(change):
        if change["name"] != "value" or current_img is None or current_lbl is None:
            return
        ax = plane_rb.value
        max_idx = max(0, min(current_img.shape[ax], current_lbl.shape[ax]) - 1)
        _set_slice_slider_safely(max_idx, max(0, max_idx // 2))
        _render()

    def _on_slice_change(change):
        if change["name"] != "value" or _is_rendering:
            return
        _render()

    def _on_status_change(change):
        if change["name"] != "value":
            return
        _persist_note_for_case(case_dd.value)

    def _on_window_change(change):
        if change["name"] != "value" or _is_rendering:
            return
        _render()

    # ------------------------------------------------------------------
    #  Wire everything together
    # ------------------------------------------------------------------
    split_dd.observe(_on_split_change, names="value")
    case_dd.observe(_on_case_change, names="value")
    plane_rb.observe(_on_plane_change, names="value")
    slice_slider.observe(_on_slice_change, names="value")
    alpha_slider.observe(_render, names="value")
    contour_cb.observe(_render, names="value")
    unsat_cb.observe(_on_unsat_change, names="value")
    note_txt.observe(_on_note_change, names="value")
    save_btn.on_click(_save_csv)
    next_btn.on_click(_goto_next)
    prev_btn.on_click(_goto_prev)
    window_mode.observe(_render, names="value")
    p_low.observe(_render, names="value")
    p_high.observe(_render, names="value")
    center_slider.observe(_render, names="value")
    width_slider.observe(_render, names="value")
    status_rb.observe(_on_status_change, names="value")
    fig_size.observe(_on_window_change, names="value")
    zoom_factor.observe(_on_window_change, names="value")
    zoom_cx.observe(_on_window_change, names="value")
    zoom_cy.observe(_on_window_change, names="value")

    # ------------------------------------------------------------------
    #  Build the UI layout
    # ------------------------------------------------------------------
    row0 = HBox([split_dd, case_dd, prev_btn, next_btn])
    row1 = HBox([plane_rb, slice_slider])
    row2 = HBox([alpha_slider, contour_cb, status_rb])
    rowW1 = HBox([window_mode, p_low, p_high])
    rowW2 = HBox([center_slider, width_slider])
    rowZ  = HBox([fig_size, zoom_factor, zoom_cx, zoom_cy])
    row3 = HBox([HTML("CSV Path:"), csv_path_txt, save_btn])

    ui = VBox([row0, count_html, row1, row2, rowW1, rowW2, rowZ, note_txt, info_html, stats_html, out, row3])

    # ------------------------------------------------------------------
    #  Initial draw
    # ------------------------------------------------------------------
    if case_dd.options:
        # resume: go to the first unreviewed in current filter
        fc = _filter_cases()
        reviewed_or_unsat = {cid for cid, info in flagged.items() if info.get("status") in {"reviewed", "unsatisfactory"}}
        next_cid = next((r["case_id"] for r in fc if r["case_id"] not in reviewed_or_unsat), fc[0]["case_id"])
        case_dd.value = next_cid
        _load_case()
        _render()

    display(ui)

# ----------------------------------------------------------------------
#  Example usage
# ----------------------------------------------------------------------
# root = "/path/to/your/nnunet/dataset"
# launch_curation_tool(root, output_csv=None, preload_csv=None)

In [4]:
root = "/Users/chufal/projects/DHAI-Brain-Segmentation/derived/unified_dualtask"
launch_curation_tool(root, output_csv=None, preload_csv="/Users/chufal/projects/DHAI-Brain-Segmentation/derived/unified_dualtask/unsatisfactory.csv")

VBox(children=(HBox(children=(Dropdown(description='Split:', layout=Layout(width='180px'), options=('All',), v…