In [None]:
# Install MONAI, pinned to 1.3.0 (presumably the version the code below was
# written against). Running pip via `{sys.executable} -m pip` installs into the
# same interpreter the kernel is running on.
import sys
!{sys.executable} -m pip install monai==1.3.0 -q


In [None]:
from monai.utils import first, set_determinism
from monai.transforms import(
    Compose,
    EnsureChannelFirstD,
    LoadImaged,
    DivisiblePadD,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    Activations,
    ConcatItemsd,
    Compose,
    AsDiscrete,
)

from monai.networks.nets import UNet, SegResNet
from monai.networks.layers import Norm
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.metrics import DiceMetric, SurfaceDistanceMetric
from monai.inferers import sliding_window_inference


import numpy as np
import torch
from scipy.ndimage import label as cc_label

import matplotlib.pyplot as plt

import os
from glob import glob

In [None]:
# Root folder containing the data-split sub-folders (OptimTestVol, OptimTestSegTumors, ...).
in_dir = "path/to/data_split"
# Folder holding the best checkpoint and the saved per-epoch training curves.
model_dir = "path/to/results"

In [None]:
# Per-epoch curves saved next to the checkpoint during training.
train_loss, train_metric, val_loss, val_metric = (
    np.load(os.path.join(model_dir, curve_file))
    for curve_file in ("loss_train.npy", "metric_train.npy", "loss_val.npy", "metric_val.npy")
)

In [None]:
# Visualize metric curves from training: a 2x2 grid with one curve per panel,
# plotted against the 1-based epoch number.
# NOTE(review): the bottom row is titled "Test" but shows the arrays loaded from
# loss_val.npy / metric_val.npy — confirm which split these curves describe.
curve_panels = [
    ("Train dice loss", train_loss),
    ("Train metric DICE", train_metric),
    ("Test dice loss", val_loss),
    ("Test metric DICE", val_metric),
]

plt.figure("Metric Results Best SegResNet", (12, 6))
for panel_no, (panel_title, y) in enumerate(curve_panels, start=1):
    x = list(range(1, len(y) + 1))
    plt.subplot(2, 2, panel_no)
    plt.title(panel_title)
    plt.xlabel("epoch")
    plt.plot(x, y)

plt.show()

In [None]:
# Load dataset paths.
# Each folder is sorted independently and the lists are paired by position, so
# all three folders must hold the same number of cases in the same sorted order.
path_test_volumes = sorted(glob(os.path.join(in_dir, "OptimTestVol", "*.nii.gz")))
path_test_seg_tumors = sorted(glob(os.path.join(in_dir, "OptimTestSegTumors", "*.nii.gz")))
path_test_seg_liver = sorted(glob(os.path.join(in_dir, "OptimTestSegLiverPred", "*.nii.gz")))

# zip() silently truncates to the shortest list, and a file missing anywhere in
# the middle shifts every later pairing. Fail loudly instead of mis-pairing cases.
n_found = (len(path_test_volumes), len(path_test_seg_tumors), len(path_test_seg_liver))
assert len(set(n_found)) == 1 and n_found[0] > 0, (
    f"Found {n_found[0]} volumes, {n_found[1]} tumor masks, {n_found[2]} liver masks "
    f"under {in_dir!r}; expected equal, non-zero counts."
)

test_files = [
    {"vol": vol, "seg_tumor": seg_tumor, "seg_liver": seg_liver}
    for vol, seg_tumor, seg_liver in zip(path_test_volumes, path_test_seg_tumors, path_test_seg_liver)
]

In [None]:
# Preprocessing for the 2-channel tumor model: the image volume plus the predicted
# liver mask form the input, and the tumor mask is the target. Geometric
# transforms run on all three keys so the masks stay voxel-aligned with the image.
pixdim = (1.0, 1.0, 2.5)  # target voxel spacing for Spacingd
a_min = -100              # intensity window lower bound (mapped to 0.0)
a_max = 200               # intensity window upper bound (mapped to 1.0)

_all_keys = ("vol", "seg_tumor", "seg_liver")

test_transforms = Compose(
    [
        LoadImaged(keys=_all_keys),
        EnsureChannelFirstD(keys=_all_keys),
        # image resampled bilinearly; both masks use nearest so labels are not blended
        Spacingd(keys=_all_keys, pixdim=pixdim, mode=("bilinear", "nearest", "nearest")),
        Orientationd(keys=_all_keys, axcodes="RAS"),
        ScaleIntensityRanged(keys=["vol"], a_min=a_min, a_max=a_max, b_min=0.0, b_max=1.0, clip=True),
        # crop every key to the bounding box of the predicted liver mask
        CropForegroundd(keys=_all_keys, source_key="seg_liver"),
        # pad spatial dims to multiples of 16
        DivisiblePadD(keys=_all_keys, k=16, mode="constant"),
        # stack image + liver mask along the channel axis -> 2-channel network input
        ConcatItemsd(keys=["vol", "seg_liver"], name="conc_image", dim=0),
        ToTensord(keys=["conc_image", "seg_tumor"]),
    ]
)

In [None]:
# Uncached, unshuffled test pipeline with one case per batch, so the loader's
# iteration order matches the index into test_files / test_ds.
test_ds = Dataset(data=test_files, transform=test_transforms)
test_loader = DataLoader(dataset=test_ds, batch_size=1, shuffle=False)

In [None]:
# Network: 3-D SegResNet with 2 input channels (image + liver mask) and 2 output
# channels (background / tumor logits). These hyper-parameters must match the
# ones used when the checkpoint loaded below was trained.
# Falls back to CPU when no GPU is visible so the notebook still runs (slowly)
# instead of crashing on torch.device("cuda:0").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SegResNet(
    spatial_dims=3,
    init_filters=16,
    in_channels=2,
    out_channels=2,
    dropout_prob=0.1,
).to(device)

In [None]:
# ---- Load best weights & config ----
# map_location remaps tensors saved on another device onto `device`.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
ckpt_path = os.path.join(model_dir, "best_metric_model.pth")
state_dict = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout

# Regular Testing

In [None]:
# TESTING SETUP
# Sliding-window inference parameters.
roi_size = (128, 128, 64)
sw_batch_size = 1
overlap = 0.4

# Per-case Dice on the foreground (tumor) class only. DiceMetric is already
# imported in the imports cell, so it is not re-imported here.
dice_metric_case = DiceMetric(include_background=False, reduction="mean")
# Symmetric average surface distance (ASSD) on the foreground class.
surf_metric = SurfaceDistanceMetric(include_background=False, symmetric=True)
# NOTE(review): identical to `surf_metric` and not referenced by the metric
# cells; kept only so any existing reference to it keeps working.
surface_metric = SurfaceDistanceMetric(include_background=False, symmetric=True)

# Post transforms for DiceMetric (multi-class with background excluded):
# predictions -> argmax -> one-hot over 2 classes; labels -> one-hot over 2 classes.
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

# ---- (Optional) simple recall/precision helpers (voxel-wise, class=1) ----
def _recall_precision_from_masks(pred_bin: torch.Tensor, gt_bin: torch.Tensor):
    """pred_bin/gt_bin: [1,1,H,W,D] binary {0,1} tensors"""
    p = pred_bin > 0.5
    g = gt_bin > 0.5
    tp = (p & g).sum().float()
    fp = (p & ~g).sum().float()
    fn = (~p & g).sum().float()
    recall = (tp / (tp + fn + 1e-8)).item()
    precision = (tp / (tp + fp + 1e-8)).item()
    return recall, precision

def _summ(name, arr):
    arr = np.asarray(arr, dtype=float)
    mean, std = np.nanmean(arr), np.nanstd(arr)
    mn, mx = np.nanmin(arr), np.nanmax(arr)
    print(f"{name:>10}: {mean:.4f} ± {std:.4f} | min={mn:.4f} | max={mx:.4f}")

def dice_from_confusion(tp, fp, fn, eps=1e-8):
    """Dice coefficient 2·TP / (2·TP + FP + FN) computed from confusion counts.

    Accepts Python numbers or tensors. ``eps`` keeps the all-zero case finite
    (it returns 0 instead of dividing by zero).
    """
    doubled_tp = 2 * tp
    return doubled_tp / (doubled_tp + fp + fn + eps)

In [None]:
# MASK COMPARISON - original, liver-pred mask, tumor-pred mask ---
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

idx, z = 4, 51  # test-set case index and axial slice to display

# 1) Load volume (channel 0 of the 2-channel input is the intensity-scaled image)
case = test_ds[idx]
vol = case["conc_image"][0].cpu().numpy()
assert 0 <= z < vol.shape[2], f"Slice {z} out of range 0..{vol.shape[2]-1}"


# 2) Load liver prediction mask
def get_liver_pred_for_index(i, sample=None):
    """Return the binary liver-prediction mask (uint8, values {0, 1}) for test case ``i``.

    Args:
        i: index into ``test_ds`` / ``path_test_seg_liver``.
        sample: the already-transformed ``test_ds[i]`` dict, if the caller has it;
            when omitted, ``test_ds[i]`` is loaded (re-running the transforms).
            Previously the function ignored ``i`` and read the global ``case``.
    """
    if sample is None:
        sample = test_ds[i]
    if sample.get("seg_liver") is not None:
        # Preprocessed mask: same resampling / crop / padding as `vol`.
        return (sample["seg_liver"][0].cpu().numpy() > 0).astype(np.uint8)
    # Fallback: raw file from disk. NOTE: this mask is NOT resampled or cropped
    # like `vol`, so its shape/orientation may differ from the displayed image.
    lp = np.squeeze(nib.load(path_test_seg_liver[i]).get_fdata())
    return (lp > 0).astype(np.uint8)


liver_pred = get_liver_pred_for_index(idx, case)
z_liver = min(z, liver_pred.shape[2] - 1)

# 3) Run tumor prediction for this case
with torch.no_grad():
    inp = case["conc_image"].unsqueeze(0).to(device)  # add batch dim -> [1, 2, H, W, D]
    logits = sliding_window_inference(inp, roi_size, sw_batch_size, predictor=model, overlap=overlap)
    if isinstance(logits, list):  # if the network returns a list of outputs, keep the last
        logits = logits[-1]
    tumor_pred_full = torch.argmax(logits, dim=1)[0].cpu().numpy()

tumor_pred = (tumor_pred_full == 1).astype(np.uint8)
z_tumor = min(z, tumor_pred.shape[2] - 1)

# 4) Plot image, liver mask and tumor mask side by side (masks white on black)
mask_panels = [
    (vol[:, :, z], f"Original (idx={idx}, z={z})", {}),
    (liver_pred[:, :, z_liver], "Liver Prediction Mask", {"vmin": 0, "vmax": 1}),
    (tumor_pred[:, :, z_tumor], "Tumor Prediction Mask", {"vmin": 0, "vmax": 1}),
]
plt.figure(figsize=(12, 4))
for pos, (img2d, panel_title, display_kw) in enumerate(mask_panels, start=1):
    plt.subplot(1, 3, pos)
    plt.imshow(img2d, cmap="gray", **display_kw)
    plt.title(panel_title)
    plt.axis("off")
plt.tight_layout()
plt.show()


In [None]:
#METRIC COMPUTATION
# Per-case Dice / ASSD / recall / precision over the test set, plus a "global"
# Dice from TP/FP/FN summed over all test voxels (as done during training).


def _evaluate_case(logit_i, gt_i):
    """Score one decollated test case.

    Args:
        logit_i: network output for the case, [C, H, W, D] with C == 2.
        gt_i: binarised tumor label for the case, [1, H, W, D] (long, {0, 1}).

    Returns:
        ``(metrics, (tp, fp, fn))``. ``metrics`` maps "dice" / "assd" / "recall" /
        "precision" to floats (NaN where undefined). tp/fp/fn are 0-d tensors of
        tumor-voxel counts, to be accumulated for the global Dice.
    """
    # ---- Dice (matching training): one-hot, background excluded ----
    dice_metric_case.reset()
    dice_metric_case(y_pred=[post_pred(logit_i)], y=[post_label(gt_i)])
    dice_val = dice_metric_case.aggregate().item()
    dice_metric_case.reset()

    # ---- Binary tumor masks [1, 1, H, W, D] for ASSD / recall / precision ----
    pred_bin = (torch.argmax(logit_i, dim=0) == 1).float().unsqueeze(0).unsqueeze(0)
    gt_bin = (gt_i == 1).float().unsqueeze(0)

    # ASSD is only computed when both masks are non-empty; otherwise it is NaN,
    # so missed or spurious tumors are excluded from (not penalised in) the ASSD
    # summary. No `spacing` is passed, so distances are in voxel units, not mm.
    assd_val = np.nan
    if pred_bin.sum() > 0 and gt_bin.sum() > 0:
        surf_metric.reset()
        surf_metric(y_pred=pred_bin, y=gt_bin)
        try:
            assd_val = surf_metric.aggregate().item()
        except (ValueError, AttributeError):
            assd_val = np.nan
        surf_metric.reset()

    recall_val, precision_val = _recall_precision_from_masks(pred_bin, gt_bin)

    # Confusion counts for the global Dice
    p, g = pred_bin.bool(), gt_bin.bool()
    confusion = ((p & g).sum(), (p & ~g).sum(), (~p & g).sum())

    metrics = {
        "dice": float(dice_val),
        "assd": float(assd_val),
        "recall": float(recall_val),
        "precision": float(precision_val),
    }
    return metrics, confusion


# Global Dice accumulators (like training)
test_tp = torch.tensor(0, device=device)
test_fp = torch.tensor(0, device=device)
test_fn = torch.tensor(0, device=device)

per_case_metrics = []
case_idx = 0  # loader is unshuffled with batch_size=1, so this is the test_ds index

with torch.no_grad():
    for batch in test_loader:
        vol = batch["conc_image"].to(device)               # 2-channel input: image + liver mask
        gt = (batch["seg_tumor"] != 0).long().to(device)   # any non-zero label counts as tumor

        logits = sliding_window_inference(vol, roi_size, sw_batch_size, model, overlap=overlap)

        # decollate: logits -> [C, H, W, D] per case, gt -> [1, H, W, D] per case
        for logit_i, gt_i in zip(decollate_batch(logits), decollate_batch(gt)):
            metrics, (tp_i, fp_i, fn_i) = _evaluate_case(logit_i, gt_i)
            test_tp += tp_i
            test_fp += fp_i
            test_fn += fn_i
            per_case_metrics.append({"idx": case_idx, **metrics})
            case_idx += 1


dice_vals = [m["dice"] for m in per_case_metrics]
assd_vals = [m["assd"] for m in per_case_metrics]
recall_vals = [m["recall"] for m in per_case_metrics]
precision_vals = [m["precision"] for m in per_case_metrics]

_summ("Dice", dice_vals)
_summ("ASSD", assd_vals)
_summ("Recall", recall_vals)
_summ("Precision", precision_vals)

final_global_dice = dice_from_confusion(test_tp.float(), test_fp.float(), test_fn.float()).item()
print(f"Global Dice (from TP/FP/FN): {final_global_dice:.4f}")


In [None]:
## OUTLIER BINS - per-patient DICE

import numpy as np
import matplotlib.pyplot as plt

# Dice bins (in %) shared by the histogram and the outlier-exclusion cell.
# Defined up-front so `labels` / `bin_to_patients` exist even when there is
# nothing to plot. Previously the next cell failed with NameError in that case.
bins = [0, 20, 40, 60, 80, 100]
labels = ["0–20", "20–40", "40–60", "60–80", "80–100"]
bin_to_patients = {lab: [] for lab in labels}

valid_items = [(m["idx"], m["dice"]) for m in per_case_metrics if not np.isnan(m["dice"])]
if len(valid_items) == 0:
    print("No valid per-case Dice values . Nothing to plot.")
else:
    idxs = np.array([it[0] for it in valid_items], dtype=int)
    dscs = np.array([it[1] for it in valid_items], dtype=float)
    dscs_pct = np.clip(dscs * 100.0, 0.0, 100.0)

    counts, edges = np.histogram(dscs_pct, bins=bins)

    # Map patients to bins. np.digitize(right=False) uses left-closed bins and
    # returns len(bins) for a value equal to the last edge; clipping folds a Dice
    # of exactly 100 % into the last bin, matching np.histogram.
    bin_indices = np.clip(np.digitize(dscs_pct, bins=bins, right=False), 1, len(bins) - 1)
    for pid, b in zip(idxs, bin_indices):
        bin_to_patients[labels[b - 1]].append(int(pid))

    # --- Plot ---
    plt.figure(figsize=(7, 4))
    plt.bar(labels, counts)
    plt.xlabel("Per-case Dice")
    plt.ylabel("Number of patients")
    plt.title(f"Dice distribution across patients for automated pipeline (N={len(dscs_pct)} included)")

    # Keep the 0–20 y-axis for comparability across runs, but grow it when a
    # bar (plus its count label) would otherwise be clipped.
    label_offset = max(1, 0.02 * counts.max())
    plt.ylim(0, max(20, counts.max() + label_offset + 1))

    # Annotate counts on bars
    for i, c in enumerate(counts):
        plt.text(i, c + label_offset, str(int(c)), ha="center", va="bottom", fontsize=10)

    plt.tight_layout()
    plt.show()

    # Print patient indices per bin with Dice scores (fractions, 3 decimals)
    dice_by_pid = {int(pid): float(pct) / 100.0 for pid, pct in zip(idxs, dscs_pct)}
    print("\nPatients per Dice bin (with Dice scores):")
    for lab in labels:
        pid_scores = [(pid, round(dice_by_pid[pid], 3)) for pid in bin_to_patients[lab]]
        print(f"  {lab}%: {pid_scores}")



In [None]:
# DICE WITHOUT OUTLIERS
# Recompute the mean test Dice after dropping every patient that fell into the
# two lowest bins (labels[0], labels[1]) of the per-patient Dice histogram.
excluded_ids = {pid for low_bin in labels[:2] for pid in bin_to_patients[low_bin]}

valid_cases = [m for m in per_case_metrics if not np.isnan(m["dice"])]
orig_dice = [m["dice"] for m in valid_cases]
filtered_dice = [m["dice"] for m in valid_cases if m["idx"] not in excluded_ids]

if not filtered_dice:
    print("No remaining patients after excluding the first two bins.")
else:
    final_test_dice_excl_low = float(np.mean(filtered_dice))
    print(f"Excluded patient IDs (first two bins): {sorted(excluded_ids)}")
    print(f"N included: {len(filtered_dice)} / {len(orig_dice)} total")
    print(f"Recomputed test Dice (mean) excluding outliers: {final_test_dice_excl_low:.4f}")
    # ddof=1 -> sample std; NaN (with a numpy warning) if only one case remains
    print(f"Median: {np.median(filtered_dice):.4f} | Std: {np.std(filtered_dice, ddof=1):.4f}")


# Visualization

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt

# Option 1 - Select patients
indices = [10, 12]

# Option 2 - Random selection
# num_patients = 5
# total_cases = len(test_ds)
# indices = random.sample(range(total_cases), num_patients)  # unique random indices

slice_stride = 1  # show every k-th slice containing GT or Pred; 1 = show all


def _plot_triplet(vol2d, gt2d, pred2d, left_title, gt_title, pred_title):
    """Show one axial slice three times: plain, with GT contour, with prediction contour.

    Contours are drawn in lime at level 0.5, and only when the corresponding
    2-D mask has a positive voxel. Pass ``None`` as a mask to skip its contour.
    """
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(vol2d, cmap="gray")
    plt.title(left_title)
    for pos, mask2d, panel_title in ((2, gt2d, gt_title), (3, pred2d, pred_title)):
        plt.subplot(1, 3, pos)
        plt.imshow(vol2d, cmap="gray")
        if mask2d is not None and (mask2d > 0).any():
            plt.contour(mask2d, levels=[0.5], linewidths=1.5, colors="lime")
        plt.title(panel_title)
    plt.tight_layout()
    plt.show()


shown = 0
with torch.no_grad():
    for idx in indices:
        case = test_ds[idx]
        inp = case["conc_image"].unsqueeze(0).to(device)  # [1, 2, H, W, D]: image + liver mask
        logits = sliding_window_inference(inp, roi_size, sw_batch_size, predictor=model, overlap=overlap)
        if isinstance(logits, list):
            logits = logits[-1]
        pred = torch.argmax(logits, dim=1)[0].cpu().numpy()  # [H, W, D] label map

        vol = case["conc_image"][0].cpu().numpy()  # [H, W, D] image channel
        gt = case["seg_tumor"][0].cpu().numpy()    # [H, W, D]

        # axial slices where either GT or prediction has at least one tumor voxel
        zs_gt = np.where(gt.sum(axis=(0, 1)) > 0)[0]
        zs_pred = np.where(pred.sum(axis=(0, 1)) > 0)[0]
        zs = np.union1d(zs_gt, zs_pred)

        if len(zs) == 0:
            print(f"\nPatient index: {idx} had no slices with tumor in GT or Pred.")
            zmid = vol.shape[2] // 2
            _plot_triplet(vol[:, :, zmid], None, None,
                          f"Axial z={zmid}", "GT contour (none)", "Pred contour (none)")
            shown += 1
            continue

        print(f"\nPatient index: {idx} | tumor slices (GT ∪ Pred): {len(zs)}")
        for z in zs[::slice_stride]:
            _plot_triplet(vol[:, :, z], gt[:, :, z], pred[:, :, z],
                          f"Axial z={z}", "GT contour", "2-channel SegRes Pred")

        shown += 1

# NOTE: every visited patient increments `shown`, so this only fires when
# `indices` is empty.
if shown == 0:
    print("No patients with tumor slices found in test set.")
