# nnU-Net v2 — Inference & Visualization Cockpit (Patched)

This notebook is a **cockpit** (QC + visualization).  
- **Training** remains headless (terminal).  
- This notebook focuses on: **(1) picking cases, (2) running ensemble inference, (3) radiology-style visualization, (4) saving QC images**.

✅ Key features in this patched version
- Plane-aware slicing (**axial/coronal/sagittal**) with correct slider bounds  
- Voxel-spacing–aware aspect ratio  
- CT window/level sliders  
- Optional rotate/flip (display only)  
- **Save current view to PNG** button (timestamped)


## 0) Configuration (edit these once)

In [45]:
from pathlib import Path
import os

# --- Machine-specific location of the repository: edit to match your checkout ---
REPO_ROOT = Path(r"C:\Users\hyeon\Documents\miniconda_medimg_env\abdomen-multiorgan-segmentation")
_INFERENCE_ROOT = REPO_ROOT / "inference"

# Inference inputs (cases named <CASE_ID>_0000.nii.gz) and the directory predictions are written to
IN_DIR = _INFERENCE_ROOT / "inputs" / "my_ct_cases"
OUT_DIR = _INFERENCE_ROOT / "outputs" / "my_ct_cases_pred"

# Destination for exported QC images; created up front so later save calls have a target directory
QC_DIR = _INFERENCE_ROOT / "qc_exports"
QC_DIR.mkdir(parents=True, exist_ok=True)

# nnU-Net model selection: dataset id, configuration, and the folds ensembled at prediction time
DATASET_ID = 701
CONFIG = "3d_fullres"
FOLDS = list(range(5))   # ensemble of folds 0..4


## 1) Environment check (nnU-Net paths)

In [46]:
import os

# Environment variables nnU-Net v2 reads to locate raw data, preprocessed data and trained results.
NNUNET_ENV_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")

for k in NNUNET_ENV_VARS:
    print(f"{k} =", os.environ.get(k))

missing = [k for k in NNUNET_ENV_VARS if not os.environ.get(k)]
if missing:
    # Name the unset variables explicitly so the fix is obvious from the error alone.
    raise RuntimeError(
        f"Missing nnU-Net env vars: {', '.join(missing)}. Launch Jupyter like:\n"
        "  conda activate medimg\n"
        "  call baseline_nnunet\\set_nnunet_paths.bat\n"
        "  jupyter lab"
    )
print("✅ nnU-Net paths look set.")


nnUNet_raw = C:\Users\hyeon\Documents\miniconda_medimg_env\data\nnunet_amos22\nnunet_raw
nnUNet_preprocessed = C:\Users\hyeon\Documents\miniconda_medimg_env\data\nnunet_amos22\nnunet_preprocessed
nnUNet_results = C:\Users\hyeon\Documents\miniconda_medimg_env\data\nnunet_amos22\nnunet_results
✅ nnU-Net paths look set.


## 2) Optional helper — copy a random AMOS22 training case into IN_DIR

In [47]:
import random, shutil
from pathlib import Path
import os

# Copy one randomly chosen training image of the dataset into IN_DIR for a quick demo run.
nnunet_raw = os.environ.get("nnUNet_raw")
if not nnunet_raw:
    # Path(None) would fail with an opaque TypeError; give an actionable message instead.
    raise RuntimeError(
        "nnUNet_raw is not set. Run the environment-check cell / set_nnunet_paths.bat first."
    )
dataset_name = f"Dataset{DATASET_ID:03d}_AMOS22"
imagesTr = Path(nnunet_raw) / dataset_name / "imagesTr"

if not imagesTr.exists():
    raise RuntimeError(f"Cannot find imagesTr at: {imagesTr}")

# Only files ending in _0000.nii.gz are considered candidates.
cases = sorted(imagesTr.glob("*_0000.nii.gz"))
print("Total candidates:", len(cases))
if not cases:
    # random.choice([]) would raise a bare IndexError.
    raise RuntimeError(f"No *_0000.nii.gz images found in {imagesTr}")

case = random.choice(cases)
IN_DIR.mkdir(parents=True, exist_ok=True)

dest = IN_DIR / case.name
if not dest.exists():
    shutil.copy(case, dest)
    print("Copied:", case.name, "→", dest)
else:
    print("Already exists:", dest)

print("\nNow run inference (next cell) to produce predictions in OUT_DIR.")


Total candidates: 240
Copied: AMOS22_amos_0088_0000.nii.gz → C:\Users\hyeon\Documents\miniconda_medimg_env\abdomen-multiorgan-segmentation\inference\inputs\my_ct_cases\AMOS22_amos_0088_0000.nii.gz

Now run inference (next cell) to produce predictions in OUT_DIR.


## 3) Run ensemble inference (long-running: executing this cell launches `nnUNetv2_predict`)

In [None]:
import shutil
import subprocess

OUT_DIR.mkdir(parents=True, exist_ok=True)

# Set to False to only print the command (e.g. to run it from a terminal instead).
RUN_INFERENCE = True

# Argument list for subprocess: no shell is involved, so paths need no quoting.
predict_args = [
    "-i", str(IN_DIR),
    "-o", str(OUT_DIR),
    "-d", str(DATASET_ID),
    "-c", CONFIG,
    "-f", *map(str, FOLDS),
]

# Equivalent shell command, printed for copy/paste into a terminal.
cmd = (
    f'nnUNetv2_predict '
    f'-i "{IN_DIR}" '
    f'-o "{OUT_DIR}" '
    f'-d {DATASET_ID} '
    f'-c {CONFIG} '
    f'-f {" ".join(map(str, FOLDS))}'
)
print("Command:\n", cmd)

if RUN_INFERENCE:
    predict_exe = shutil.which("nnUNetv2_predict")
    if predict_exe is None:
        raise RuntimeError(
            "nnUNetv2_predict was not found on PATH. "
            "Activate the nnU-Net environment before launching Jupyter."
        )
    # Stream the tool's output into the cell (an un-piped child process would write to the
    # Jupyter server console) and fail loudly on a non-zero exit code; the previous
    # `!{cmd}` form ignored failures and carried on as if inference had succeeded.
    with subprocess.Popen(
        [predict_exe, *predict_args],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",
    ) as proc:
        for line in proc.stdout:
            print(line, end="")
    if proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, [predict_exe, *predict_args])

print("\nAfter inference, OUT_DIR should contain files like: <CASE_ID>.nii.gz")


## 4) Pick a case from IN_DIR and load prediction from OUT_DIR

In [81]:
# Choose the case to view by its position in the sorted IN_DIR listing.
# The listing is built here so the notebook runs top-to-bottom; the loading cell
# below builds the same list again.
CASE_INDEX = 9  # change if desired
case_ids = sorted(p.name.replace("_0000.nii.gz", "") for p in IN_DIR.glob("*_0000.nii.gz"))
if not case_ids:
    raise RuntimeError(f"No cases found in {IN_DIR}. Copy one in, or add your own CT as *_0000.nii.gz")
# Clamp so a short listing selects the last case instead of raising IndexError.
CASE_ID = case_ids[min(CASE_INDEX, len(case_ids) - 1)]
print("Selected case:", CASE_ID)

In [78]:
import nibabel as nib
import numpy as np
from pathlib import Path

def load_nifti_canonical(path: Path):
    """Load a NIfTI file reoriented to its closest canonical (RAS+) axis ordering.

    Returns a ``(image, data)`` pair: the reoriented nibabel image (whose affine
    reflects the new orientation) and its voxel data as a float32 array. Forcing the
    canonical orientation makes every volume share the same axis layout for viewing.
    """
    canonical_img = nib.as_closest_canonical(nib.load(str(path)))
    return canonical_img, canonical_img.get_fdata(dtype=np.float32)

def find_pred_file(out_dir: Path, case_id: str):
    """Return the prediction file for ``case_id`` inside ``out_dir``, or ``None``.

    Exact names are tried first, in order: ``<id>.nii.gz``, ``<id>.nii``,
    ``<id>_seg.nii.gz``, ``<id>_seg.nii``. Otherwise the first file (sorted by name)
    whose name is ``case_id`` followed by a separator (``.``, ``_`` or ``-``) and that
    contains ``.nii`` is returned. Requiring the separator prevents ``case_1`` from
    picking up ``case_10.nii.gz``; plain string matching (rather than glob) also keeps
    ids containing ``[``/``*`` from being interpreted as patterns.

    A missing ``out_dir`` yields ``None``.
    """
    out_dir = Path(out_dir)
    for suffix in (".nii.gz", ".nii", "_seg.nii.gz", "_seg.nii"):
        candidate = out_dir / f"{case_id}{suffix}"
        if candidate.exists():
            return candidate

    if not out_dir.is_dir():
        return None

    for path in sorted(out_dir.iterdir()):
        name = path.name
        if not name.startswith(case_id):
            continue
        rest = name[len(case_id):]
        if rest[:1] in (".", "_", "-") and ".nii" in rest and path.is_file():
            return path
    return None

# List the cases available in IN_DIR (inputs are named <CASE_ID>_0000.nii.gz).
case_ids = sorted(p.name.replace("_0000.nii.gz", "") for p in IN_DIR.glob("*_0000.nii.gz"))
print("Cases in IN_DIR:", len(case_ids))
print("First 10:", case_ids[:10])

if not case_ids:
    raise RuntimeError(f"No cases found in {IN_DIR}. Copy one in, or add your own CT as *_0000.nii.gz")

# CASE_ID is normally chosen in the selection cell; fall back to the first case so this
# cell also works when that cell has not been run.
if "CASE_ID" not in globals():
    CASE_ID = case_ids[0]
    print("CASE_ID was not set; defaulting to", CASE_ID)

img_path = IN_DIR / f"{CASE_ID}_0000.nii.gz"
pred_path = find_pred_file(OUT_DIR, CASE_ID)

print("Image:", img_path.exists(), img_path)
print("Pred :", (pred_path.exists() if pred_path else False), pred_path)

if not img_path.exists():
    raise FileNotFoundError(f"No input image for {CASE_ID}: {img_path}")
if pred_path is None:
    raise FileNotFoundError(
        f"No prediction found for {CASE_ID}. Run the inference cell first. Expected e.g.\n{OUT_DIR / (CASE_ID + '.nii.gz')}"
    )

img_nii, img = load_nifti_canonical(img_path)
pred_nii, pred = load_nifti_canonical(pred_path)
# get_fdata returns floats (after any header scaling); round before the integer cast so
# round-off such as 5.9999 maps to label 6 rather than being truncated to 5.
pred = np.rint(pred).astype(np.int16)

print("Image shape:", img.shape, "dtype:", img.dtype)
print("Pred  shape:", pred.shape, "dtype:", pred.dtype)
# Explicit check instead of `assert` so it cannot be disabled and reports both shapes.
if img.shape != pred.shape:
    raise ValueError(
        f"Image/pred shape mismatch {img.shape} vs {pred.shape} "
        "(should not happen with nnU-Net output)."
    )
if not np.allclose(img_nii.affine, pred_nii.affine, atol=1e-3):
    print("⚠️ Image and prediction affines differ; the overlay may be misregistered.")

# Labels (AMOS22)
LABELS = {
  0: "background",
  1: "spleen",
  2: "right kidney",
  3: "left kidney",
  4: "gall bladder",
  5: "esophagus",
  6: "liver",
  7: "stomach",
  8: "aorta",
  9: "postcava",
  10: "pancreas",
  11: "right adrenal gland",
  12: "left adrenal gland",
  13: "duodenum",
  14: "bladder",
  15: "prostate/uterus",
}

# Voxel spacing per array axis = length of each affine column (after canonical reorientation).
spacing = np.sqrt((img_nii.affine[:3, :3] ** 2).sum(0))  # (sx, sy, sz) in mm
print("Voxel spacing (mm):", spacing)


Cases in IN_DIR: 10
First 10: ['AMOS22_amos_0088', 'AMOS22_amos_0254', 'AMOS22_amos_0259', 'AMOS22_amos_0263', 'AMOS22_amos_0264', 'AMOS22_amos_0268', 'AMOS22_amos_0272', 'AMOS22_amos_0273', 'AMOS22_amos_0274', 'AMOS22_amos_0276']
Image: True C:\Users\hyeon\Documents\miniconda_medimg_env\abdomen-multiorgan-segmentation\inference\inputs\my_ct_cases\AMOS22_amos_0276_0000.nii.gz
Pred : True C:\Users\hyeon\Documents\miniconda_medimg_env\abdomen-multiorgan-segmentation\inference\outputs\my_ct_cases_pred\AMOS22_amos_0276.nii.gz
Image shape: (512, 512, 192) dtype: float32
Pred  shape: (512, 512, 192) dtype: int16
Voxel spacing (mm): [0.61132812 0.61132812 2.        ]


## 5) Radiology-style interactive viewer (plane-aware) + Save PNG

In [None]:
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import datetime

def window_ct(x, wl=40, ww=400):
    """Apply a window/level to intensities and rescale the window to [0, 1].

    Values at or below ``wl - ww/2`` map to 0, values at or above ``wl + ww/2`` map
    to 1, and values in between are scaled linearly.
    """
    half_width = ww / 2
    floor, ceiling = wl - half_width, wl + half_width
    return (np.clip(x, floor, ceiling) - floor) / (ceiling - floor)

def get_slice(vol, idx, plane, spacing_xyz=None):
    """Extract a 2D slice from a 3D volume together with its pixel spacing.

    The volume is expected to be indexed (x, y, z) after canonical (RAS+) reorientation:
    x runs left→right, y posterior→anterior, z inferior→superior.

    Parameters
    ----------
    vol : 3D array indexed (x, y, z).
    idx : index along the axis fixed by ``plane``.
    plane : "axial" (fixes z), "coronal" (fixes y) or "sagittal" (fixes x).
    spacing_xyz : optional (sx, sy, sz) voxel spacing; defaults to the notebook's
        global ``spacing``.

    Returns
    -------
    ``(slice2d, (row_sp, col_sp))`` where ``row_sp`` is the physical size of one step
    along the slice's rows (array axis 0) and ``col_sp`` along its columns (axis 1).

    Raises
    ------
    ValueError if ``plane`` is not one of the three names above.
    """
    sx, sy, sz = spacing if spacing_xyz is None else spacing_xyz
    # The two remaining volume axes keep their order: the first becomes the slice rows,
    # the second the columns, so the spacings must follow the same order.
    if plane == "axial":     # fix Z -> rows X, cols Y
        return vol[:, :, idx], (sx, sy)
    if plane == "coronal":   # fix Y -> rows X, cols Z
        return vol[:, idx, :], (sx, sz)
    if plane == "sagittal":  # fix X -> rows Y, cols Z
        return vol[idx, :, :], (sy, sz)
    raise ValueError("plane must be axial/coronal/sagittal")

# --- widgets ---
# Slice navigation: the plane, and the slice index along the axis that plane fixes.
# The slider starts on the middle axial slice.
plane_dd = widgets.Dropdown(
    options=["axial", "coronal", "sagittal"],
    value="axial",
    description="Plane",
)
idx_slider = widgets.IntSlider(
    value=img.shape[2] // 2,
    min=0,
    max=img.shape[2] - 1,
    step=1,
    description="Index",
)

# Label filter: -1 means every non-zero label; otherwise one label id (background excluded).
label_options = [("All foreground", -1)] + [
    (f"{label_id:02d} — {label_text}", label_id)
    for label_id, label_text in LABELS.items()
    if label_id != 0
]
label_dd = widgets.Dropdown(
    options=label_options,
    value=-1,
    description="Label",
)

# Overlay opacity and intensity window (level / width).
alpha_slider = widgets.FloatSlider(
    value=0.35, min=0.0, max=1.0, step=0.05, description="Alpha",
)
wl_slider = widgets.IntSlider(
    value=40, min=-200, max=200, step=5, description="WL",
)
ww_slider = widgets.IntSlider(
    value=400, min=50, max=2000, step=50, description="WW",
)

# Display-only orientation toggles.
rot_chk = widgets.Checkbox(value=True, description="Rotate 90° left")
flip_lr_chk = widgets.Checkbox(value=False, description="Flip LR (display)")

# Overlay style.
contour_chk = widgets.Checkbox(value=True, description="Contour")
fill_chk = widgets.Checkbox(value=False, description="Fill mask")

# Export controls.
save_btn = widgets.Button(description="Save PNG", button_style="success")
save_status = widgets.HTML(value="")

# Layout: three rows of controls.
ui1 = widgets.HBox(children=[plane_dd, idx_slider, label_dd])
ui2 = widgets.HBox(children=[alpha_slider, wl_slider, ww_slider])
ui3 = widgets.HBox(children=[rot_chk, flip_lr_chk, contour_chk, fill_chk, save_btn])

# Output area the viewer draws into.
out = widgets.Output()

# Most recently rendered figure and its metadata, read by the Save PNG handler.
_last_fig = dict(fig=None, meta=None)

# Volume axis fixed by each plane (matches get_slice: sagittal fixes x, coronal y, axial z).
_PLANE_AXIS = {"sagittal": 0, "coronal": 1, "axial": 2}

def update_slider_for_plane(*args):
    """Resize the index slider for the selected plane and move it to the middle slice.

    A slice index on one axis has no meaning on another, so after a plane switch the
    slider is re-centred instead of keeping (or being clamped to) the old index.
    ``*args`` absorbs the change notification passed by ``observe``.
    """
    n_slices = img.shape[_PLANE_AXIS[plane_dd.value]]
    new_max, center = n_slices - 1, n_slices // 2
    # Order the two updates so the value never falls outside [min, max] midway.
    if new_max >= idx_slider.max:
        idx_slider.max = new_max
        idx_slider.value = center
    else:
        idx_slider.value = center
        idx_slider.max = new_max

plane_dd.observe(update_slider_for_plane, names="value")
update_slider_for_plane()

def render(plane, idx, label_id, alpha, wl, ww, rot90_left, flip_lr, contour, fill_mask):
    """Draw the current view into the ``out`` widget and remember it for saving.

    The CT slice is windowed with (wl, ww). It can optionally be rotated 90°
    counter-clockwise and/or mirrored left-right; these are display-only changes.
    The slice is shown with a physically correct pixel aspect. The selected label
    (``-1`` = every non-zero label) is drawn as a contour and/or a semi-transparent
    fill. The figure and its metadata are stored in ``_last_fig`` for the Save PNG button.
    """
    import re

    with out:
        clear_output(wait=True)

        ct2d, (row_sp, col_sp) = get_slice(img, idx, plane)
        # `pred` is cast to int16 when loaded; re-casting the whole volume here would copy
        # it on every redraw.
        seg2d, _ = get_slice(pred, idx, plane)

        ct2d = window_ct(ct2d, wl=wl, ww=ww)

        # Display-only transforms; the label slice is transformed identically to the CT.
        if rot90_left:
            ct2d = np.rot90(ct2d, k=1)
            seg2d = np.rot90(seg2d, k=1)
            row_sp, col_sp = col_sp, row_sp  # rows and columns trade places
        if flip_lr:
            ct2d = np.fliplr(ct2d)
            seg2d = np.fliplr(seg2d)

        # imshow's aspect is displayed pixel height / width = physical row step / column step.
        aspect = row_sp / col_sp

        # Release the previous figure so non-inline backends do not accumulate open figures.
        if _last_fig.get("fig") is not None:
            plt.close(_last_fig["fig"])

        fig, ax = plt.subplots(figsize=(7, 7))
        ax.imshow(ct2d, cmap="gray", aspect=aspect)
        ax.axis("off")

        if label_id == -1:
            mask = (seg2d != 0)
            label_name = "AllForeground"
            title = f"{CASE_ID} — {plane} idx={idx} — All foreground"
        else:
            label_id = int(label_id)
            mask = (seg2d == label_id)
            label_text = LABELS.get(label_id, str(label_id))
            # Filename-safe tag: label names may contain "/" (e.g. "prostate/uterus").
            label_name = f"label{label_id:02d}_" + re.sub(r"[^\w.\-]+", "_", label_text)
            title = f"{CASE_ID} — {plane} idx={idx} — {LABELS.get(label_id, label_id)} (id={label_id})"

        has_label = bool(mask.any())

        if fill_mask and has_label:
            # Mask out the background so only labelled pixels are tinted (an unmasked 0/1 image
            # tints the whole slice). Reuse the CT aspect: the last imshow sets the axes aspect,
            # so any other value here would distort the CT as well.
            overlay = np.ma.masked_where(~mask, mask.astype(np.float32))
            ax.imshow(overlay, alpha=alpha, aspect=aspect, vmin=0.0, vmax=1.0)

        # Skip slices without the label: contouring a constant image only produces a warning.
        if contour and has_label:
            ax.contour(mask.astype(np.float32), levels=[0.5], linewidths=1.0)

        ax.set_title(title)

        plt.show()

        # stash for saving
        _last_fig["fig"] = fig
        _last_fig["meta"] = dict(case_id=CASE_ID, plane=plane, idx=idx, label_name=label_name, wl=wl, ww=ww)

def on_save_clicked(_):
    """Save the most recently rendered view to QC_DIR as a timestamped PNG.

    The outcome is reported in ``save_status``. Save errors are shown there too,
    because exceptions raised inside widget callbacks may not be visible in the
    notebook (JupyterLab routes them to its log console).
    """
    import html
    import re

    meta = _last_fig.get("meta")
    fig = _last_fig.get("fig")
    if meta is None or fig is None:
        save_status.value = "<b style='color:red'>Nothing to save yet. Render a view first.</b>"
        return
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    fname = f"{meta['case_id']}_{meta['plane']}_idx{meta['idx']:04d}_{meta['label_name']}_WL{meta['wl']}_WW{meta['ww']}_{ts}.png"
    # Label names can contain path separators (e.g. "prostate/uterus"), which would point
    # savefig at a non-existent sub-directory; keep the name to one safe path component.
    fname = re.sub(r"[^\w.\-]+", "_", fname)
    out_path = QC_DIR / fname
    try:
        QC_DIR.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=200, bbox_inches="tight")
    except OSError as err:
        save_status.value = f"<b style='color:red'>Save failed:</b> {html.escape(str(err))}"
        return
    save_status.value = f"<b style='color:green'>Saved:</b> {out_path}"

save_btn.on_click(on_save_clicked)

# render() parameter name -> widget that supplies it; any change to these re-renders.
_controls = {
    "plane": plane_dd,
    "idx": idx_slider,
    "label_id": label_dd,
    "alpha": alpha_slider,
    "wl": wl_slider,
    "ww": ww_slider,
    "rot90_left": rot_chk,
    "flip_lr": flip_lr_chk,
    "contour": contour_chk,
    "fill_mask": fill_chk,
}
widgets.interactive_output(render, _controls)

display(ui1, ui2, ui3, save_status, out)

# Draw the initial view from the widgets' current values.
render(**{param: control.value for param, control in _controls.items()})


## 6) Optional: batch-export a montage (evenly spaced slices spanning the selected label)

In [15]:
import matplotlib.pyplot as plt
import datetime

def export_montage(plane="axial", label_id=6, n=12, wl=40, ww=400, rot90_left=True, flip_lr=False):
    """Save and show a montage of up to ``n`` slices spanning the extent of one label.

    Slices are evenly spaced along ``plane``'s fixed axis, between the first and last
    index where the label occurs; ``label_id=-1`` uses every non-zero label. If the
    label spans fewer than ``n`` slices, duplicates are dropped. Each panel shows the
    windowed CT with the label contour and the same display-only rotate/flip options
    as the interactive viewer. The PNG is written to QC_DIR under a timestamped name.

    Returns
    -------
    Path of the saved PNG.

    Raises
    ------
    ValueError if ``plane`` is unknown or ``n < 1``.
    RuntimeError if the label does not occur in the prediction.
    """
    import re

    if n < 1:
        raise ValueError("n must be >= 1")
    # Axes to collapse so that only the plane's fixed axis remains.
    reduce_axes = {"axial": (0, 1), "coronal": (0, 2), "sagittal": (1, 2)}
    if plane not in reduce_axes:
        raise ValueError("plane must be axial/coronal/sagittal")

    if label_id == -1:
        vol_mask = (pred != 0)
        label_name = "AllForeground"
    else:
        vol_mask = (pred == int(label_id))
        label_name = f"label{int(label_id):02d}_{LABELS.get(int(label_id), str(label_id)).replace(' ','_')}"
    # Filename-safe variant: names such as "prostate/uterus" contain a path separator.
    file_tag = re.sub(r"[^\w.\-]+", "_", label_name)

    # Indices along the fixed axis that contain at least one labelled voxel.
    idxs = np.flatnonzero(vol_mask.any(axis=reduce_axes[plane]))
    if idxs.size == 0:
        raise RuntimeError("No voxels found for selected label/plane.")

    # Evenly spaced picks, de-duplicated so short labels do not repeat panels.
    sel = np.unique(np.linspace(idxs.min(), idxs.max(), n).astype(int))

    nrows = min(3, len(sel))
    ncols = int(np.ceil(len(sel) / nrows))
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 9), squeeze=False)
    axes = axes.ravel()

    for ax, idx in zip(axes, sel):
        ct2d, (row_sp, col_sp) = get_slice(img, idx, plane)
        # `pred` is already an integer label map; avoid a full-volume copy per panel.
        seg2d, _ = get_slice(pred, idx, plane)
        ct2d = window_ct(ct2d, wl=wl, ww=ww)

        if rot90_left:
            ct2d = np.rot90(ct2d, k=1)
            seg2d = np.rot90(seg2d, k=1)
            row_sp, col_sp = col_sp, row_sp
        if flip_lr:
            ct2d = np.fliplr(ct2d)
            seg2d = np.fliplr(seg2d)

        # imshow aspect = physical row step / column step.
        aspect = row_sp / col_sp

        ax.imshow(ct2d, cmap="gray", aspect=aspect)
        mask = (seg2d != 0) if label_id == -1 else (seg2d == int(label_id))
        # Evenly spaced picks can land in gaps of the label; contouring an empty mask only warns.
        if mask.any():
            ax.contour(mask.astype(np.float32), levels=[0.5], linewidths=1.0)
        ax.set_title(f"idx={idx}")
        ax.axis("off")

    for ax in axes[len(sel):]:
        ax.axis("off")

    fig.suptitle(f"{CASE_ID} — {plane} — {label_name}")
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = QC_DIR / f"{CASE_ID}_{plane}_{file_tag}_montage_{ts}.png"
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.show()
    # Free the figure; repeated calls would otherwise accumulate figures on non-inline backends.
    plt.close(fig)
    print("Saved montage:", out_path)
    return out_path

# Example:
# export_montage(plane="axial", label_id=6, n=12, wl=40, ww=400)
