In [2]:
%pip install -q "monai[all]" ipywidgets nibabel

from google.colab import output, drive
# Turn on Colab's custom widget manager; a later cell builds its viewer with ipywidgets' `interact`.
output.enable_custom_widget_manager()
# Mount Google Drive at /content/drive; WORK_DIR_DRIVE below points inside this mount.
drive.mount("/content/drive")

import sys, importlib.util
from pathlib import Path
import numpy as np
import torch
import nibabel as nib
from monai.data import MetaTensor

# --- Inputs -----------------------------------------------------------------
# Directory on the mounted Drive that holds one sub-directory per bundle.
WORK_DIR_DRIVE = Path("/content/drive/MyDrive/mmwhs2017/auto3dseg")
BUNDLE_NAMES = ["segresnet_0", "dints_0", "swinunetr_0"]

IMAGE_PATH = Path("/content/ct_test_2001_image.nii.gz") # change inference target

# --- Outputs (written by the inference cell) ----------------------------------
OUT_MASK_PATH = Path("/content/ct_test_2001_prediction.nii.gz")
OUT_PROB_PATH = Path("/content/ct_test_2001_prob.nii.gz")
OUT_VAR_PATH  = Path("/content/ct_test_2001_var.nii.gz")

# Fail fast if any expected input is missing. The checks run in this order:
# working directory, input image, then each bundle directory; the first
# missing path is reported as a FileNotFoundError carrying that path.
_required = [(WORK_DIR_DRIVE, Path.is_dir), (IMAGE_PATH, Path.is_file)]
_required.extend((WORK_DIR_DRIVE / _bundle, Path.is_dir) for _bundle in BUNDLE_NAMES)
for _path, _exists in _required:
    if not _exists(_path):
        raise FileNotFoundError(str(_path))

Mounted at /content/drive




In [3]:
"""
Runs MONAI bundle inferers in-place from Drive (no bundle copying).
Writes:
- OUT_MASK_PATH (uint8, argmax)
- OUT_PROB_PATH (float32, channels-last)
- OUT_VAR_PATH  (float32, channels-last) variance across models
"""

def _maybe_patch_cfgs(bundle_root: Path):
    """Copy a bundle's YAML configs to a scratch directory, fixing the model path.

    Every ``.yaml``/``.yml`` entry in ``bundle_root / "configs"`` is written to
    ``/content/_tmp_cfgs/<bundle name>/`` under its original file name. While
    copying, occurrences of ``/content/<bundle name>/model`` are replaced by
    ``<bundle_root>/model``, but only when the text mentions the former and
    does not already mention the latter.

    Args:
        bundle_root: Directory of one bundle (must contain ``configs``).

    Returns:
        The scratch-copy paths as strings, ordered by sorted source path.

    Raises:
        FileNotFoundError: If ``configs`` holds no YAML entry (the message is
            the ``configs`` path), or if ``configs`` itself cannot be listed.
    """
    config_dir = bundle_root / "configs"
    yaml_suffixes = (".yaml", ".yml")
    config_paths = sorted(
        entry for entry in config_dir.iterdir() if entry.suffix.lower() in yaml_suffixes
    )
    if len(config_paths) == 0:
        raise FileNotFoundError(str(config_dir))

    model_on_drive = str(bundle_root / "model")
    model_on_local = str(Path("/content") / bundle_root.name / "model")

    scratch_dir = Path("/content/_tmp_cfgs") / bundle_root.name
    scratch_dir.mkdir(parents=True, exist_ok=True)

    def _copy_with_patch(source: Path) -> str:
        # Rewrite only configs that still point at the local model directory.
        text = source.read_text()
        needs_patch = (model_on_local in text) and (model_on_drive not in text)
        if needs_patch:
            text = text.replace(model_on_local, model_on_drive)
        target = scratch_dir / source.name
        target.write_text(text)
        return str(target)

    return [_copy_with_patch(source) for source in config_paths]

def _load_inferer(bundle_root: Path):
    """Import a bundle's ``scripts/infer.py`` and build its ``InferClass``.

    The script is loaded under the module name ``_infer_<bundle name>`` and
    registered in ``sys.modules`` before it is executed. ``InferClass`` is
    then constructed with the config paths returned by ``_maybe_patch_cfgs``.

    Args:
        bundle_root: Directory of one bundle.

    Returns:
        The object returned by ``InferClass(cfg_files)``.

    Raises:
        FileNotFoundError: If ``scripts/infer.py`` is missing (the message is
            that path), or if config patching raises it.
    """
    script_path = bundle_root / "scripts" / "infer.py"
    if not script_path.is_file():
        raise FileNotFoundError(str(script_path))

    patched_cfgs = _maybe_patch_cfgs(bundle_root)

    # One module name per bundle so several infer.py files can coexist.
    module_name = "_infer_" + bundle_root.name
    module_spec = importlib.util.spec_from_file_location(module_name, str(script_path))
    infer_module = importlib.util.module_from_spec(module_spec)
    sys.modules[module_name] = infer_module
    module_spec.loader.exec_module(infer_module)

    inferer_cls = getattr(infer_module, "InferClass")
    return inferer_cls(patched_cfgs)

def _save_nifti_like(data, reference, dtype, out_path):
    """Write ``data`` as NIfTI, reusing the affine and header of ``reference``.

    Args:
        data: NumPy array to store.
        reference: Loaded nibabel image whose ``affine`` and ``header`` are reused.
        dtype: NumPy dtype set explicitly as the stored data type, rather than
            leaving whatever the reused header specifies.
        out_path: Destination file path.
    """
    nii = nib.Nifti1Image(data, reference.affine, reference.header)
    nii.set_data_dtype(dtype)
    nib.save(nii, str(out_path))


inferers = [_load_inferer(WORK_DIR_DRIVE / n) for n in BUNDLE_NAMES]
sample = {"image": str(IMAGE_PATH)}

with torch.no_grad():
    outs = []
    for bundle_name, inf in zip(BUNDLE_NAMES, inferers):
        out = inf.infer(sample)
        out = out.as_tensor() if isinstance(out, MetaTensor) else torch.as_tensor(out)
        out = out.detach().float().cpu()
        # Name the offending bundle instead of relying on torch.stack's generic error.
        if outs and out.shape != outs[0].shape:
            raise RuntimeError(
                f"{bundle_name} returned shape {tuple(out.shape)}, "
                f"but {BUNDLE_NAMES[0]} returned {tuple(outs[0].shape)}"
            )
        outs.append(out)
    # Axis 0 indexes the models; the mean is the ensemble output and the
    # population variance (unbiased=False) measures model disagreement.
    all_probs = torch.stack(outs, dim=0)
    prob = all_probs.mean(dim=0)
    var = all_probs.var(dim=0, unbiased=False)

if prob.ndim != 4:
    raise RuntimeError(f"Expected (C,H,W,D), got {tuple(prob.shape)}")

img = nib.load(str(IMAGE_PATH))

# The outputs below reuse the input image's affine and header, which is only
# valid when the prediction lives on the same voxel grid as the input.
if tuple(prob.shape[1:]) != tuple(img.shape[:3]):
    raise RuntimeError(
        f"Prediction grid {tuple(prob.shape[1:])} does not match "
        f"input image grid {tuple(img.shape[:3])}"
    )

# Hard labels: arg-max over the channel axis, stored as uint8.
label_np = torch.argmax(prob, dim=0).numpy().astype(np.uint8)
_save_nifti_like(label_np, img, np.uint8, OUT_MASK_PATH)

# Mean output and across-model variance, stored channels-last as float32.
prob_4d = np.moveaxis(prob.numpy(), 0, -1).astype(np.float32)
_save_nifti_like(prob_4d, img, np.float32, OUT_PROB_PATH)

var_4d = np.moveaxis(var.numpy(), 0, -1).astype(np.float32)
_save_nifti_like(var_4d, img, np.float32, OUT_VAR_PATH)

monai.metrics.meandice DiceHelper.__init__:sigmoid: Argument `sigmoid` has been deprecated since version 1.5. It will be removed in version 1.7. Use `threshold` instead.
monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


2025-12-16 04:20:33.011531 - Length of input patch is recommended to be a multiple of 32.


Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:345.)
Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an erro

In [7]:
"""
Interactive viewer without ground truth.

Left:  CT + predicted mask (red)
Right: CT + disagreement heatmap (Std(P) by default)

Flagging (per class c, per slice z):

1) Compute per-voxel disagreement for each class:
   - If your file stores Var(P), convert to Std(P) via sqrt(var).

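2) Slice-level summary (optionally only inside the predicted mask for class c; when masked, slices with fewer than MIN_MASK_VOXELS predicted voxels are left unscored):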
   m[z,c] = mean_{voxels in slice}(disagree[x,y,z,c])

3) Score slice means across slices for that class:
   zscore[z,c] = (m[z,c] - mean_z(m[:,c])) / std_z(m[:,c])

4) Flag if zscore >= Z_THRESH
"""

import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, FloatSlider, Dropdown, Checkbox

# -----------------------------
# User inputs (must exist)
# -----------------------------
# IMAGE_PATH = ...
# OUT_MASK_PATH = ...
# OUT_VAR_PATH = ...

# -----------------------------
# Load data
# -----------------------------
def _load_float32(path):
    """Read a NIfTI file and return its data array as float32."""
    return nib.load(str(path)).get_fdata(dtype=np.float32)


# CT volume; a trailing singleton 4th axis is squeezed away.
ct = _load_float32(IMAGE_PATH)
if ct.ndim == 4 and ct.shape[-1] == 1:
    ct = np.squeeze(ct)

# Predicted label map, converted to integer labels after squeezing.
seg = np.squeeze(_load_float32(OUT_MASK_PATH)).astype(np.int16)

# Across-model variance; expected Var(P) with shape (H,W,D,C).
var = _load_float32(OUT_VAR_PATH)

if (ct.ndim, seg.ndim, var.ndim) != (3, 3, 4):
    raise RuntimeError(f"Shapes: ct={ct.shape}, seg={seg.shape}, var={var.shape}")

# All three volumes must share the CT's spatial grid.
H, W, D = ct.shape
if not (seg.shape == (H, W, D) and var.shape[:3] == (H, W, D)):
    raise RuntimeError(f"Mismatch: ct={ct.shape}, seg={seg.shape}, var={var.shape}")

# Number of class channels in the variance volume.
C = int(var.shape[-1])

# -----------------------------
# Settings
# -----------------------------
# If var is Var(P), convert to Std(P) (recommended if you mean "std(P)")
USE_STD = True              # True: display sqrt(Var(P)); False: display Var(P) as stored

USE_PRED_MASK_ONLY = True   # compute slice means only inside predicted class mask
MIN_MASK_VOXELS = 50        # ignore slices with too few predicted voxels for that class
DEFAULT_Z_THRESH = 1.0      # initial value of the z-score threshold slider

eps = 1e-8                  # added to the per-class std so the z-score never divides by zero

# -----------------------------
# Choose disagreement volume
# -----------------------------
if not USE_STD:
    disagree_label = "Var(P)"
    disagree = var.astype(np.float32)
else:
    # Var(P) -> Std(P); negative values are floored at 0 before the square root.
    disagree_label = "Std(P)"
    disagree = np.sqrt(np.maximum(var, 0.0)).astype(np.float32)

# -----------------------------
# Precompute slice means per class
# -----------------------------
# mean_slice[z, c]: mean disagreement of class c on slice z (NaN when unscored).
# nvox_slice[z, c]: number of voxels that went into that mean.
mean_slice = np.full((D, C), np.nan, dtype=np.float32)
nvox_slice = np.zeros((D, C), dtype=np.int32)

for z_idx, cls in np.ndindex(D, C):
    heat_2d = disagree[:, :, z_idx, cls]

    if not USE_PRED_MASK_ONLY:
        # Unmasked mode: average over the whole slice.
        nvox_slice[z_idx, cls] = H * W
        mean_slice[z_idx, cls] = float(heat_2d.mean())
        continue

    # Masked mode: average only over voxels predicted as this class, and
    # leave the slice as NaN when too few such voxels exist.
    in_class = seg[:, :, z_idx] == cls
    count = int(np.count_nonzero(in_class))
    nvox_slice[z_idx, cls] = count
    if count >= MIN_MASK_VOXELS:
        mean_slice[z_idx, cls] = float(heat_2d[in_class].mean())

# z-score across slices (per class) based on slice-means; NaN slices are ignored
# by the nan-aware statistics and stay NaN in the result.
mu_c = np.nanmean(mean_slice, axis=0)  # (C,)
sd_c = np.nanstd(mean_slice, axis=0)   # (C,)
zscore = (mean_slice - mu_c) / (sd_c + eps)

# -----------------------------
# Helpers
# -----------------------------
# Dropdown options as (label, value) pairs, one per class channel.
class_opts = []
for _cls in range(C):
    class_opts.append((f"class {_cls}", _cls))

def flagged_slices_for_class(cls_idx: int, z_thresh: float):
    """Return the slice indices of one class whose z-score reaches a threshold.

    Args:
        cls_idx: Column of the module-level ``zscore`` array to inspect.
        z_thresh: Minimum z-score (inclusive) for a slice to be flagged.

    Returns:
        Ascending list of Python ints; slices with a non-finite z-score are
        never included.
    """
    scores = zscore[:, cls_idx]
    hits = np.isfinite(scores) & (scores >= z_thresh)
    return [int(i) for i in np.flatnonzero(hits)]

def flagged_slices_all(z_thresh: float):
    """Collect the flagged slices of every class at one threshold.

    Args:
        z_thresh: Threshold forwarded to ``flagged_slices_for_class``.

    Returns:
        Dict mapping each class index in ``range(C)`` to that class's list of
        flagged slice indices, inserted in ascending class order.
    """
    return {cls: flagged_slices_for_class(cls, z_thresh) for cls in range(C)}

# Print initial flagged slices at DEFAULT_Z_THRESH
# Report, per class, which slices are flagged at the default threshold.
print(f"Initial flags using Z_THRESH={DEFAULT_Z_THRESH} on slice-mean {disagree_label}:")
init_flags = flagged_slices_all(DEFAULT_Z_THRESH)
for _cls, _slices in init_flags.items():
    print(_cls, _slices)

# -----------------------------
# Interactive viewer
# -----------------------------
@interact(
    z=IntSlider(min=0, max=D - 1, step=1, value=D // 2, description="Slice"),
    cls_idx=Dropdown(options=class_opts, value=0, description="Class"),
    alpha_seg=FloatSlider(min=0.0, max=1.0, step=0.1, value=0.5, description="Seg α"),
    alpha_heat=FloatSlider(min=0.0, max=1.0, step=0.1, value=0.6, description="Heat α"),
    z_thresh=FloatSlider(min=0.0, max=5.0, step=0.1, value=DEFAULT_Z_THRESH, description="Z thr"),
    show_only_flagged=Checkbox(value=False, description="Show only flagged slices"),
)
def view_pred_vs_disagree(z, cls_idx, alpha_seg, alpha_heat, z_thresh, show_only_flagged):
    """Draw one slice: predicted mask on the left, disagreement heatmap on the right.

    Args:
        z: Slice index along the third axis of the volumes.
        cls_idx: Class whose mask and disagreement channel are shown.
        alpha_seg: Opacity of the red mask tint (applied to mask voxels only).
        alpha_heat: Opacity of the disagreement heatmap.
        z_thresh: z-score threshold used to decide which slices are flagged.
        show_only_flagged: If set and ``z`` is not flagged, the nearest flagged
            slice is displayed instead (the slider value itself is not moved).
    """
    # Computed once and reused for both the slice jump and the title.
    flagged_cls = flagged_slices_for_class(cls_idx, z_thresh)
    if show_only_flagged and flagged_cls and (z not in flagged_cls):
        # pick nearest flagged slice to current z
        z = min(flagged_cls, key=lambda zz: abs(zz - z))

    slice_ct = ct[:, :, z]
    mask = (seg[:, :, z] == cls_idx)
    heat = disagree[:, :, z, cls_idx]

    # Mean shown in the right-hand title: inside the mask when masking is on
    # (NaN for an empty mask), otherwise over the whole slice.
    if USE_PRED_MASK_ONLY:
        mean_heat = float(heat[mask].mean()) if mask.any() else float("nan")
    else:
        mean_heat = float(heat.mean())

    # Slice statistics come from the precomputed tables so they match the flags.
    if np.isfinite(mean_slice[z, cls_idx]):
        zs = float(zscore[z, cls_idx])
        is_flag = bool(zs >= z_thresh)
        info = (
            f"slice_mean({disagree_label})={float(mean_slice[z, cls_idx]):.6f}  "
            f"z={zs:.2f}  n={int(nvox_slice[z, cls_idx])}  FLAG={is_flag}"
        )
    else:
        info = f"slice_mean({disagree_label})=NA (n<{MIN_MASK_VOXELS} if masked)  FLAG=False"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    ax1.imshow(slice_ct, cmap="gray")
    # RGBA overlay: red everywhere, but only mask voxels get a non-zero alpha.
    # An RGB overlay with a single global alpha would also blend black over
    # every non-mask voxel and darken the whole CT.
    overlay = np.zeros((H, W, 4), dtype=np.float32)
    overlay[..., 0] = 1.0
    overlay[..., 3] = mask.astype(np.float32) * float(alpha_seg)
    ax1.imshow(overlay)
    ax1.set_title(f"Pred mask (class {cls_idx})")
    ax1.axis("off")

    ax2.imshow(slice_ct, cmap="gray")
    im = ax2.imshow(heat, alpha=alpha_heat)
    ax2.set_title(f"{disagree_label} heatmap (class {cls_idx}), mean {mean_heat:.6f}")
    ax2.axis("off")
    fig.colorbar(im, ax=ax2, fraction=0.046, pad=0.04, label=disagree_label)

    # show which slices are flagged for this class at current threshold
    flagged_str = flagged_cls[:30]
    more = "" if len(flagged_cls) <= 30 else f" ... (+{len(flagged_cls)-30} more)"
    fig.suptitle(
        f"Slice {z} | class {cls_idx} | {info}\n"
        f"Flagged slices for class {cls_idx} @ z>={z_thresh:.1f}: {flagged_str}{more}"
    )
    fig.tight_layout()
    plt.show()

Initial flags using Z_THRESH=1.0 on slice-mean Std(P):
0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
1 [46, 47, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 156, 157, 158, 159, 160]
2 [54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102]
3 [98, 99, 100, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174]
4 [170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185]
5 [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 172, 173, 174, 175]
6 [128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 223]
7 [151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168]


interactive(children=(IntSlider(value=112, description='Slice', max=223), Dropdown(description='Class', option…