In [None]:
import os
import glob
import numpy as np
import pandas as pd
import SimpleITK as sitk
from medpy.metric import binary
from scipy.ndimage import label

# Save NumPy mask array as NIfTI file
def save_array_to_nii(mask_array, nii_template_path, nii_save_path):
    """Write ``mask_array`` to ``nii_save_path`` as a compressed image with the template's geometry.

    The array is cast to ``uint8``. Spacing, direction and origin are copied
    from the image read at ``nii_template_path``. The array is not checked
    against the template's size.

    Args:
        mask_array: NumPy array to save. Presumably it is in SimpleITK's array
            axis order (as returned by ``GetArrayFromImage``); TODO confirm.
        nii_template_path: path of the reference image whose geometry is reused.
        nii_save_path: destination file path.
    """
    template = sitk.ReadImage(nii_template_path)

    mask_image = sitk.GetImageFromArray(mask_array.astype(np.uint8))
    mask_image.SetSpacing(template.GetSpacing())
    mask_image.SetDirection(template.GetDirection())
    mask_image.SetOrigin(template.GetOrigin())

    # The third positional argument of WriteImage turns compression on.
    sitk.WriteImage(mask_image, nii_save_path, True)
    print(f"Saved: {nii_save_path}")

# Keep only the largest connected component in a binary mask
def largest_connected_component(binary_mask: np.ndarray) -> np.ndarray:
    """Return a boolean mask that keeps only the largest connected component.

    Components are labelled with SimpleITK's ``ConnectedComponentImageFilter``
    using its default settings. They are ranked by
    ``LabelShapeStatisticsImageFilter.GetPhysicalSize``. The image is built
    directly from the array, so it carries SimpleITK's default spacing;
    presumably 1.0, making physical size equal to voxel count (TODO confirm).

    If the mask has no foreground, an all-False array of the input's shape is
    returned.
    """
    labeled = sitk.ConnectedComponentImageFilter().Execute(
        sitk.GetImageFromArray(binary_mask.astype(np.uint8))
    )

    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(labeled)

    # No foreground voxels at all -> nothing to keep.
    if shape_stats.GetNumberOfLabels() == 0:
        return np.zeros_like(binary_mask, dtype=bool)

    biggest = max(shape_stats.GetLabels(), key=lambda lbl: shape_stats.GetPhysicalSize(lbl))
    return sitk.GetArrayFromImage(labeled == biggest).astype(bool)


In [None]:
import os
import glob
import SimpleITK as sitk
import numpy as np
import pandas as pd
from medpy.metric import binary

# Restore the `np.bool` alias, which NumPy 1.24 removed. This is presumably
# needed because medpy's binary metrics still reference it; verify before removing.
np.bool = bool

# `largest_connected_component()` is defined in the first cell of this notebook.

# --- Directories ---
seg_dir = "/home/zhongyi/Desktop/nnunetv1/nnUNet_raw_data_base/nnUNet_test_data/test_seg_in_nii_raw_910_curated"    # segmentation results
label_dir = "/home/zhongyi/Desktop/nnunetv1/nnUNet_raw_data_base/nnUNet_test_data/test_seg_in_nii_raw_912_curated"  # ground truth

# --- Settings ---
all_classes = [1, 2]     # muscle classes
max_slice_growth = 1.25  # smoothness: max allowed area ratio between consecutive slices
records = []


def extract_pid(seg_name: str) -> str:
    """Derive the patient ID from a segmentation file name.

    The naming rules are applied in this order:
      * start from the first three "_"-separated tokens;
      * IDs containing "ID" drop the "water_L4L5" substring;
      * IDs containing "_T" or "word" keep only the first two tokens;
      * IDs starting with "s" keep only their first token;
      * finally, every "_L4L5" occurrence is removed.
    """
    tokens = seg_name.split("_")
    pid = "_".join(tokens[:3])
    if "ID" in pid:
        pid = pid.replace("water_L4L5", "")
    if "_T" in pid or "word" in pid:
        pid = "_".join(tokens[:2])
    # startswith() instead of pid[0], so an empty pid cannot raise IndexError
    if pid.startswith("s"):
        pid = pid.split("_")[0]
    return pid.replace("_L4L5", "")


def rank_label_candidates(pid: str, paths: list) -> list:
    """Order ground-truth candidates so the most plausible match for ``pid`` comes first.

    ``glob(f"{pid}*")`` is a plain prefix match, so "s1*" also hits "s10_...".
    Files where ``pid`` is followed by a non-alphanumeric character are ranked
    ahead of the rest. Ties are broken alphabetically, so the pick is
    deterministic.
    """
    def continues_id(path):
        next_char = os.path.basename(path)[len(pid):len(pid) + 1]
        return next_char.isalnum()

    return sorted(paths, key=lambda p: (continues_id(p), p))


def is_area_profile_smooth(mask: np.ndarray, max_growth: float = 1.25) -> bool:
    """Return True if the foreground area never grows by more than ``max_growth``x between slices.

    Slices are taken along axis 0. A non-empty slice that follows an empty one
    fails the check; the 1e-8 only prevents division by zero. Masks with fewer
    than two slices trivially pass.
    """
    slice_areas = mask.sum(axis=tuple(range(1, mask.ndim)))
    ratios = slice_areas[1:] / (slice_areas[:-1] + 1e-8)
    return bool(np.all(ratios <= max_growth))


def format_metric(value: float):
    """Round a metric to 4 decimals for the results table, mapping NaN to None."""
    return None if np.isnan(value) else round(float(value), 4)


# --- Iterate over segmentation results ---
seg_files = sorted(glob.glob(os.path.join(seg_dir, "*.nii.gz")))

for seg_path in seg_files:
    seg_name = os.path.basename(seg_path)
    pid = extract_pid(seg_name)

    # Find corresponding ground truth label (deterministic, boundary-aware pick)
    label_candidates = rank_label_candidates(
        pid, glob.glob(os.path.join(label_dir, f"{pid}*.nii.gz"))
    )
    if not label_candidates:
        print(f"❌ No ground truth label found for {pid}")
        continue
    label_path = label_candidates[0]
    if len(label_candidates) > 1:
        print(f"⚠️ {pid} - multiple ground truth labels, using {os.path.basename(label_path)}")

    # Load segmentation and label arrays
    seg_arr = sitk.GetArrayFromImage(sitk.ReadImage(seg_path))
    label_arr = sitk.GetArrayFromImage(sitk.ReadImage(label_path))

    dsc_values = {}
    decrease_values = {}

    # --- Per-class evaluation ---
    for cls in all_classes:
        seg_bin = seg_arr == cls
        label_bin = label_arr == cls

        seg_bin_lcc = largest_connected_component(seg_bin)
        label_bin_lcc = largest_connected_component(label_bin)

        # Smoothness is judged on the full prediction, not only its largest component
        decreased = is_area_profile_smooth(seg_bin, max_slice_growth)

        if not label_bin_lcc.any():
            dsc = np.nan
            print(f"⚠️ {pid} - Class {cls} not present in ground truth. DSC: NaN")
        else:
            try:
                dsc = binary.dc(seg_bin_lcc, label_bin_lcc)
                print(f"✅ {pid} - Class {cls} - DSC: {dsc:.4f} - Smooth: {decreased}")
            except Exception as e:
                dsc = np.nan
                print(f"❌ Error computing DSC for {pid} - Class {cls}: {e}")

        dsc_values[cls] = dsc
        decrease_values[cls] = decreased

    # Mean DSC over the classes that could be scored. This avoids np.nanmean's
    # "Mean of empty slice" warning when every class is NaN.
    valid_dsc = [v for v in dsc_values.values() if not np.isnan(v)]
    mean_dsc = float(np.mean(valid_dsc)) if valid_dsc else np.nan

    # Column layout: pid, dsc_class_*, mean_dsc, decrease_class_*
    record = {"pid": pid}
    record.update({f"dsc_class_{cls}": format_metric(dsc_values[cls]) for cls in all_classes})
    record["mean_dsc"] = format_metric(mean_dsc)
    record.update({f"decrease_class_{cls}": decrease_values[cls] for cls in all_classes})
    records.append(record)


In [None]:
# --- Collect results into DataFrame ---
df = pd.DataFrame(records)

n_evaluated = len(df)
print(f"\n✅ Total PIDs evaluated: {n_evaluated}")

# A case is smooth overall only if the per-class smoothness check passed for both classes
both_smooth = df["decrease_class_1"] & df["decrease_class_2"]
df["decrease_all"] = both_smooth

# Ranking: smooth cases first, then by mean DSC, highest first
sort_keys = ["decrease_all", "mean_dsc"]
df_sorted = df.sort_values(by=sort_keys, ascending=[False] * len(sort_keys))
df_sorted = df_sorted.reset_index(drop=True)

df_sorted


In [None]:
# --- Assign dataset category based on PID naming rules ---
def assign_dataset(pid: str) -> str:
    """Map a patient ID to its source dataset name.

    Rules are tried in a fixed order and the first hit wins:
    substring markers first ("ID" -> bedrest, "T1" -> T1W, "T2" -> T2W),
    then name prefixes ("s" -> ct_tt, "word" -> ct_word).
    Any other ID is labelled "AFL".
    """
    for marker, dataset in (("ID", "bedrest"), ("T1", "T1W"), ("T2", "T2W")):
        if marker in pid:
            return dataset
    for prefix, dataset in (("s", "ct_tt"), ("word", "ct_word")):
        if pid.startswith(prefix):
            return dataset
    return "AFL"

# Tag every row with the dataset its PID belongs to
df_sorted = df_sorted.assign(dataset=df_sorted["pid"].map(assign_dataset))
df_sorted


In [None]:
# --- Selection criteria ---
min_mean_dsc = 0.95   # a case must have mean DSC strictly above this
top_fraction = 0.2    # keep the best 20% of each dataset ...
min_top_n = 2         # ... but allow at least this many (if that many qualify)

# --- Ensure required columns ---
df["decrease_all"] = df["decrease_class_1"] & df["decrease_class_2"]

if "dataset" not in df.columns:
    df["dataset"] = df["pid"].apply(assign_dataset)

# --- Collect top cases per dataset ---
df_top_list = []

for dataset, group in df.groupby("dataset"):
    # Only keep cases with smoothness constraint + high DSC.
    # mean_dsc holds None for unscorable cases. If every value is None the column
    # is object dtype and ">" would raise, so cast to float (None -> NaN, never passes).
    group_mean_dsc = group["mean_dsc"].astype(float)
    df_true = group[group["decrease_all"] & (group_mean_dsc > min_mean_dsc)]

    # The quota is sized from the whole dataset, not from the qualifying subset
    top_n = max(min_top_n, int(len(group) * top_fraction))

    df_top = df_true.sort_values(by="mean_dsc", ascending=False).head(top_n)
    df_top_list.append(df_top)

# --- Combine all datasets ---
df_top_all = pd.concat(df_top_list).reset_index(drop=True)
df_top_all


In [None]:
import os
import glob
import shutil

# =====================================
# Copy the selected high-DSC cases (df_top_all) to the stage dataset
# =====================================


def find_pid_files(src_dir, pid):
    """Return the sorted NIfTI files in ``src_dir`` that belong to ``pid``.

    ``glob(f"{pid}*")`` is a plain prefix match, so "s1" would also pick up
    "s10_...". Files where ``pid`` is followed by a non-alphanumeric character
    are preferred. If none qualify, every prefix match is returned, as before.
    """
    matches = sorted(glob.glob(os.path.join(src_dir, f"{pid}*.nii*")))
    bounded = [m for m in matches if not os.path.basename(m)[len(pid):len(pid) + 1].isalnum()]
    return bounded or matches


def copy_single_match(matches, out_dir, pid, copied_label, missing_label, multiple_label):
    """Copy ``matches[0]`` into ``out_dir`` only when exactly one file matched; otherwise report why.

    The three label arguments are the words used in the success, missing and
    ambiguous messages respectively.
    """
    if len(matches) == 1:
        shutil.copy(matches[0], out_dir)
        print(f"✅ Copied {copied_label}: {matches[0]}")
    elif not matches:
        print(f"❌ Missing {missing_label} for {pid}")
    else:
        # Ambiguous: copy nothing rather than guess
        print(f"⚠️ Multiple {multiple_label} for {pid}: {matches}")


# --- Input dirs ---
img_dir = "/home/zhongyi/Desktop/nnunetv1/nnUNet_raw_data_base/nnUNet_test_data/test_img_in_nii"
seg_dir = "/home/zhongyi/Desktop/nnunetv1/nnUNet_raw_data_base/nnUNet_test_data/test_seg_in_nii_raw_912_curated"

# --- Output dirs (update per stage) ---
stage = "stage3"   # e.g., "stage2", "stage3"
base_out = f"/home/zhongyi/Desktop/nnunetv1/nnUNet_raw_data_base/nnUNet_train_data_raw/MRI_highIOU_2class_nnsam2_{stage}"
out_img_dir = os.path.join(base_out, "img_in_nii_L4-L5")
out_seg_dir = os.path.join(base_out, "seg_in_nii_L4-L5")

# --- Create output dirs if not exist ---
os.makedirs(out_img_dir, exist_ok=True)
os.makedirs(out_seg_dir, exist_ok=True)

# --- Iterate through selected PIDs ---
for pid in df_top_all["pid"]:
    copy_single_match(find_pid_files(img_dir, pid), out_img_dir, pid,
                      "image", "image", "images")
    copy_single_match(find_pid_files(seg_dir, pid), out_seg_dir, pid,
                      "seg", "segmentation", "segmentations")
