In [15]:
import os, glob, re
import numpy as np
import pandas as pd
import cv2
import SimpleITK as sitk
from tqdm import tqdm
from scipy.ndimage import label, binary_fill_holes

def save_array_to_nii(mask_array, nii_template_path, nii_save_path):
    """
    Write a NumPy array to disk as a compressed image, borrowing geometry from a template.

    Parameters
    ----------
    mask_array : array passed to ``sitk.GetImageFromArray``.
    nii_template_path : path of the image whose origin, spacing and direction are copied.
    nii_save_path : destination file name handed to the SimpleITK writer.
    """
    reference = sitk.ReadImage(nii_template_path)
    volume = sitk.GetImageFromArray(mask_array)

    # Transfer the spatial metadata of the template onto the new image.
    volume.SetOrigin(reference.GetOrigin())
    volume.SetSpacing(reference.GetSpacing())
    volume.SetDirection(reference.GetDirection())

    nii_writer = sitk.ImageFileWriter()
    nii_writer.SetUseCompression(True)
    nii_writer.SetFileName(nii_save_path)
    nii_writer.Execute(volume)
    print(f"💾 Saved: {nii_save_path}")

def keep_largest_connected_component(mask):
    """
    Reduce a binary mask to its single biggest connected region, with enclosed holes filled.

    Works on 2D and 3D arrays alike (labelling uses scipy's default connectivity).
    When the mask has no foreground component the input object is returned untouched;
    otherwise the result is a uint8 array of 0/1 values.
    """
    components, n_components = label(mask)
    if n_components == 0:
        return mask

    # Element count per label; index 0 is background, so zero it out of the contest.
    sizes = np.bincount(components.ravel())
    sizes[0] = 0
    winner = np.argmax(sizes)

    biggest_region = components == winner
    filled = binary_fill_holes(biggest_region)
    return filled.astype(np.uint8)



In [None]:
# ---- Experiment selection ----
dataset = "CT_word"   # e.g. "AFL_MRI", "Bedrest", "CT_TT"
seed_count = 1

# ---- Folder layout ----
base_dir = f"videos/{dataset}"
data_dir = os.path.join(base_dir, "data_in_jpg")
csv_path = os.path.join(
    data_dir,
    f"{dataset}_4class_IOU_all_pid_with_full_mapping_{seed_count}seed.csv",
)

# Input volumes: the Bedrest dataset reads from a differently named folder.
img_dir = os.path.join(
    base_dir,
    "data_in_nii",
    "img_in_nii_water_72_L4L5" if dataset == "Bedrest" else "img_in_nii_L4L5",
)

# Output folder for the assembled segmentations (created on first run).
save_seg_dir = os.path.join(base_dir, "data_in_nii", f"SAM2_auto_seg_nii_{seed_count}shot")
os.makedirs(save_seg_dir, exist_ok=True)

# Per-PID SAM2 prediction folders live under this root.
sam2_root = os.path.join(data_dir, f"sam2_results_by_pid_{seed_count}seed")

# Grey level each class is matched against in the PNG masks (list position + 1 = class id).
class_values = [100, 200]

# Mapping table: the processing cell reads its pid / slice_id / frame_idx columns.
df = pd.read_csv(csv_path)


In [None]:
# Build one multi-class segmentation volume per PID from the per-frame SAM2 PNG masks
# and save it next to the template image's geometry.
for pid, group in tqdm(df.groupby("pid"), desc="Processing PIDs"):
    # Group keys keep the CSV dtype; os.path.join and the "ID" substring test need text.
    pid = str(pid)
    group = group.sort_values("slice_id").reset_index(drop=True)
    frame_ids = group["frame_idx"].tolist()
    slice_ids = group["slice_id"].tolist()

    # Locate input NIfTI (sorted so the pick is deterministic if several files share the prefix)
    nii_files = sorted(glob.glob(os.path.join(img_dir, f"{pid}*.nii.gz")))
    if not nii_files:
        print(f"❌ Image not found for {pid}")
        continue
    img_path = nii_files[0]
    img_array = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
    # The last two axes of the image array give the in-plane size masks are resized to.
    H, W = img_array.shape[1:]

    sam2_pred_dir = os.path.join(sam2_root, pid, "SAM2_seg_mask_nolap")

    seg_array = []
    for frame_id, slice_id in zip(frame_ids, slice_ids):
        per_class_masks = []
        for i, class_val in enumerate(class_values):
            mask_path = os.path.join(sam2_pred_dir, f"frame_{frame_id}_obj_{i+1}.png")
            if not os.path.exists(mask_path):
                print(f"⚠️ Missing mask: {mask_path}")
                continue

            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                # imread signals an unreadable/corrupt file with None instead of raising.
                print(f"⚠️ Unreadable mask: {mask_path}")
                continue

            # Widen before subtracting: in uint8 the difference wraps around, so grey
            # levels slightly BELOW class_val would never satisfy the +/-10 tolerance.
            mask_bin = (np.abs(mask.astype(np.int16) - class_val) < 10).astype(np.uint8)
            mask_clean = keep_largest_connected_component(mask_bin)
            mask_resized = cv2.resize(mask_clean, (W, H), interpolation=cv2.INTER_NEAREST)
            per_class_masks.append(mask_resized)

        # A frame is only usable when every class mask was found and readable.
        if len(per_class_masks) != len(class_values):
            print(f"❌ Incomplete masks for {pid} frame {frame_id}, skipping")
            continue

        # Merge class masks (on overlap the later class wins)
        multi_mask = np.zeros((H, W), dtype=np.uint8)
        for cls_idx, binary_mask in enumerate(per_class_masks):
            multi_mask[binary_mask > 0] = cls_idx + 1

        # Fix orientation for some datasets (180-degree in-plane rotation)
        if "ID" in pid or "CT_TT" in dataset:
            multi_mask = np.rot90(multi_mask, 2)

        seg_array.append(multi_mask)

    if not seg_array:
        print(f"⚠️ No valid masks for {pid}, skipping saving")
        continue

    # Stack slices into 3D
    seg_stack = np.stack(seg_array)

    # Skipped frames (or a CSV that lists a different number of slices) leave the stack
    # with a different depth than the template whose geometry is copied on save.
    if seg_stack.shape[0] != img_array.shape[0]:
        print(
            f"⚠️ {pid}: {seg_stack.shape[0]} mask slices vs {img_array.shape[0]} image slices; "
            "saved volume will not align slice-for-slice with the template"
        )

    # Clean each class with largest connected component
    for class_idx in range(1, len(class_values) + 1):
        class_mask_3d = (seg_stack == class_idx).astype(np.uint8)
        class_mask_3d_clean = keep_largest_connected_component(class_mask_3d)
        # Clear the class first, then write back only the cleaned (hole-filled) region.
        seg_stack[seg_stack == class_idx] = 0
        seg_stack[class_mask_3d_clean > 0] = class_idx

    # Save NIfTI
    save_path = os.path.join(save_seg_dir, f"{pid}_sam2_seg.nii.gz")
    save_array_to_nii(seg_stack, img_path, save_path)


In [None]:
# ================================
# Dice Coefficient Evaluation
# ================================

import os, glob
import numpy as np
import pandas as pd
import SimpleITK as sitk
from tqdm import tqdm

def dice_coefficient(pred, gt, class_id):
    """
    Compute the Dice similarity coefficient (DSC) for one class label.

    Parameters
    ----------
    pred, gt : array-like label maps of identical shape.
    class_id : label value whose overlap is measured.

    Returns
    -------
    float
        ``2 * |pred∩gt| / (|pred| + |gt|)`` for the given class, or ``1.0`` when the
        class is absent from both arrays.

    Raises
    ------
    ValueError
        If the two arrays differ in shape. Without this check, broadcastable shapes
        (e.g. a single slice vs. a full volume) would silently yield a meaningless score.
    """
    pred = np.asarray(pred)
    gt = np.asarray(gt)
    if pred.shape != gt.shape:
        raise ValueError(
            f"Shape mismatch: pred {pred.shape} vs gt {gt.shape}"
        )

    pred_bin = pred == class_id
    gt_bin = gt == class_id
    intersection = np.count_nonzero(pred_bin & gt_bin)
    volume_sum = np.count_nonzero(pred_bin) + np.count_nonzero(gt_bin)
    if volume_sum == 0:  # class absent from both masks
        return 1.0
    return 2.0 * intersection / volume_sum


# --- Directories ---
base_dir = f"videos/{dataset}"
label_dir = os.path.join(base_dir, "data_in_nii", "label_in_nii_L4L5_2class")
if dataset == "Bedrest":
    label_dir = os.path.join(base_dir, "data_in_nii", "label_in_nii_72_water_72_L4L5")

# --- Class IDs to evaluate ---
class_ids = [1, 2]

# --- Collect results ---
results = []
# Sorted listing keeps the row order of the summary stable between runs.
for fname in tqdm(sorted(os.listdir(save_seg_dir)), desc="Evaluating DSC"):
    if not fname.endswith(".nii.gz"):
        continue

    # Strip the known prediction suffixes to recover the PID used for label lookup.
    pid = (
        fname.replace("_sam2_seg.nii.gz", "")
             .replace("_L4L5_lcc_combined.nii.gz", "")
    )
    pred_path = os.path.join(save_seg_dir, fname)
    # Prefix match; sorted so the chosen label is deterministic if several files match.
    label_candidates = sorted(glob.glob(os.path.join(label_dir, f"{pid}*.nii.gz")))

    if not label_candidates:
        print(f"⚠️ Label not found for {pid}, skipping.")
        continue

    label_path = label_candidates[0]

    # Load prediction & ground truth
    pred = sitk.GetArrayFromImage(sitk.ReadImage(pred_path))
    gt = sitk.GetArrayFromImage(sitk.ReadImage(label_path))

    # Element-wise comparison is only meaningful on identical grids; a mismatch would
    # otherwise raise mid-loop or, for broadcastable shapes, give a bogus score.
    if pred.shape != gt.shape:
        print(f"⚠️ Shape mismatch for {pid}: pred {pred.shape} vs label {gt.shape}, skipping.")
        continue

    # Compute Dice per class
    dscs = [dice_coefficient(pred, gt, c) for c in class_ids]
    results.append([pid, *dscs])

# --- Save summary ---
# NOTE: this rebinds `df`, which the processing cell reads as the mapping table;
# re-run the configuration cell before re-running that cell.
df = pd.DataFrame(results, columns=["pid"] + [f"dsc_class{c}" for c in class_ids])
out_csv = os.path.join(base_dir, f"{dataset}_SAM2_auto_seg_DSC_summary_{seed_count}seed.csv")
df.to_csv(out_csv, index=False)

print(f"📄 DSC summary saved to: {out_csv}")


In [None]:
# ================================
# DSC Statistics (Mean ± Std)
# ================================

print(f"📊 Dataset: {dataset}")

# Per-class mean and standard deviation over the evaluated cases
selected_cols = ["dsc_class1", "dsc_class2"]
dsc_subset = df[selected_cols]
mean_values = dsc_subset.mean()
std_values = dsc_subset.std()

# Same report layout for both statistics: a heading, then one line per class column.
for heading, stats in (
    ("\n📈 Mean DSC per class:", mean_values),
    ("\n📉 Standard Deviation of DSC per class:", std_values),
):
    print(heading)
    for cls, val in stats.items():
        print(f"  {cls}: {val:.4f}")

# Number of rows in the summary table
print(f"\n✅ Total cases evaluated: {len(df)}")

# Rich preview of the first rows (kept as the cell's final expression)
df.head()
