In [None]:
# ==============================================================
# nnsam2 Post-processing Pipeline (SAM2 → nnU-Net pseudo-labels)
# --------------------------------------------------------------
# 1. Merge IOU and mapping results
# 2. Select best pid_folder per (pid, frame_idx) using Dice
# 3. Copy masks to canonical pid folder
# 4. Reconstruct NIfTI from per-slice PNGs
# 5. Evaluate Dice similarity vs ground truth
#
# Supports multiple refinement stages via REFINE_PROFILES
# ==============================================================

import os, glob, shutil, json, re
import numpy as np
import pandas as pd
import cv2
import SimpleITK as sitk

from tqdm import tqdm
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from scipy.ndimage import label

# ==============================================================
# Configuration: Dataset & Stage Profiles
# ==============================================================

# Per-dataset folder layout. Both CT datasets share the same 2-class label
# sub-folder and differ only in their root directory.
DATA_PROFILES = {
    "CT_word": dict(
        base_dir="videos/CT_word",
        label_dir_2c="data_in_nii/label_in_nii_L4L5_2class",
    ),
    "CT_TT": dict(
        base_dir="videos/CT_TT",
        label_dir_2c="data_in_nii/label_in_nii_L4L5_2class",
    ),
}

# Refinement parameters for Stage 1 -> Stage 3. A value of None means the
# criterion is not used in that stage.
#   iou_dataset_top : fraction of rows kept dataset-wide, ranked by IoU
#   iou_slice_top   : fraction of rows kept per slice, ranked by IoU
#   dice_thresh     : minimum DSC required (vs SAM2 in stage2, vs Stage 1 in stage3)
#   area_ratio      : smoothness bound, area <= ratio x the superior slice's area
REFINE_PROFILES = {
    "stage1": dict(iou_dataset_top=0.10, iou_slice_top=0.02, dice_thresh=None, area_ratio=None),
    "stage2": dict(iou_dataset_top=0.10, iou_slice_top=None, dice_thresh=0.90, area_ratio=1.5),
    "stage3": dict(iou_dataset_top=0.20, iou_slice_top=None, dice_thresh=0.90, area_ratio=1.25),
}

# ----------------------------
# Active run selection
# ----------------------------
ACTIVE_DATASET = "CT_word"
ACTIVE_STAGE = "stage1"
SEED_COUNT = 1
SLICE_PROMPT = 1

P = DATA_PROFILES[ACTIVE_DATASET]
R = REFINE_PROFILES[ACTIVE_STAGE]

# Derived input/output locations; file names encode seed count and slice prompt.
BASE_DIR = P["base_dir"]
DATA_DIR = os.path.join(BASE_DIR, "data_in_jpg_2class")
RESULT_ROOT = os.path.join(DATA_DIR, f"sam2_results_by_pid_{SEED_COUNT}seed_{SLICE_PROMPT}_slice_prompt")
IMG_MAPPING_ROOT = os.path.join(DATA_DIR, f"img_in_jpg_to_sam2_{SEED_COUNT}seed")
MERGED_CSV = os.path.join(DATA_DIR, f"{ACTIVE_DATASET}_2class_IOU_all_pid_with_full_mapping_{SEED_COUNT}seed_{SLICE_PROMPT}_slice_prompt.csv")
# NOTE(review): SAVE_SEG_DIR does not include ACTIVE_STAGE, and Step 3 skips
# volumes that already exist, so switching stages reuses a previous stage's
# outputs unless this directory is cleared first.
SAVE_SEG_DIR = os.path.join(BASE_DIR, f"data_in_nii/SAM2_auto_seg_nii_{SEED_COUNT}shot_2class_{SLICE_PROMPT}_slice_prompt")
os.makedirs(SAVE_SEG_DIR, exist_ok=True)

# Echo the effective run configuration.
print(json.dumps(
    dict(
        dataset=ACTIVE_DATASET,
        stage=ACTIVE_STAGE,
        seed_count=SEED_COUNT,
        slice_prompt=SLICE_PROMPT,
        result_root=RESULT_ROOT,
        merged_csv=MERGED_CSV,
        save_seg_dir=SAVE_SEG_DIR,
    ),
    indent=2,
))


# ==============================================================
# Utility Functions
# ==============================================================

def keep_largest_connected_component(mask):
    """Return a uint8 mask containing only the largest foreground component.

    Any non-zero value counts as foreground. 2-D masks use OpenCV with
    8-connectivity; 3-D masks use ``scipy.ndimage.label`` with a full 3x3x3
    structuring element (26-connectivity). A mask with no foreground is
    returned unchanged (as uint8).

    Raises:
        ValueError: if ``mask`` is neither 2-D nor 3-D.
    """
    binary = mask.astype(np.uint8)

    if binary.ndim == 2:
        n_labels, label_map, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
        if n_labels <= 1:  # only the background label was found
            return binary
        # Row 0 of ``stats`` is the background, so component ids are offset by 1.
        areas = stats[1:, cv2.CC_STAT_AREA]
        return (label_map == 1 + np.argmax(areas)).astype(np.uint8)

    if binary.ndim != 3:
        raise ValueError("Unsupported mask dimension")

    label_map, n_components = label(binary, structure=np.ones((3, 3, 3), np.uint8))
    if n_components == 0:
        return binary
    voxel_counts = np.bincount(label_map.ravel())
    voxel_counts[0] = 0  # never pick the background
    return (label_map == np.argmax(voxel_counts)).astype(np.uint8)

def save_array_to_nii(arr, ref_path, save_path):
    """Write ``arr`` to ``save_path`` using the geometry of the image at ``ref_path``.

    Args:
        arr: NumPy array converted with ``sitk.GetImageFromArray``.
        ref_path: Reference image whose origin, spacing and direction are
            copied onto the output via ``CopyInformation`` (SimpleITK raises
            if the image sizes differ).
        save_path: Destination path; a confirmation line is printed after writing.
    """
    volume = sitk.GetImageFromArray(arr)
    volume.CopyInformation(sitk.ReadImage(ref_path))
    sitk.WriteImage(volume, save_path)
    print(f"💾 Saved NIfTI: {save_path}")

def dice_binary(m1, m2):
    """Dice similarity ``2|A∩B| / (|A| + |B|)`` between two masks.

    Any value > 0 is treated as foreground. When both masks are empty, the
    result is defined as 1.0 (perfect agreement).
    """
    fg_a = m1 > 0
    fg_b = m2 > 0
    total = fg_a.sum() + fg_b.sum()
    if total == 0:
        return 1.0
    overlap = np.logical_and(fg_a, fg_b).sum()
    return 2 * overlap / total


# ==============================================================
# Step 1: Merge IOU & Mapping
# ==============================================================

# For every pid_folder under RESULT_ROOT, left-join its mapping CSV with its
# IoU CSV on ``frame_idx``. Every mapping row is kept; IoU columns are NaN
# where no score exists. Folders missing either CSV are skipped.
all_records = []
for pid_folder in sorted(os.listdir(RESULT_ROOT)):
    iou_csv = os.path.join(RESULT_ROOT, pid_folder, f"{pid_folder}_iou_results.csv")
    map_csv = os.path.join(IMG_MAPPING_ROOT, pid_folder, f"{pid_folder}_mapping.csv")
    if not (os.path.exists(iou_csv) and os.path.exists(map_csv)):
        continue
    df_iou = pd.read_csv(iou_csv)
    df_map = pd.read_csv(map_csv)
    df = pd.merge(df_map, df_iou, on="frame_idx", how="left")
    df["pid_folder"] = pid_folder  # record which SAM2 result folder each row came from
    all_records.append(df)

# pd.concat([]) fails with an opaque "No objects to concatenate"; report what
# was searched instead so a wrong path/config is easy to spot.
if not all_records:
    raise FileNotFoundError(
        f"No folder under {RESULT_ROOT!r} has both '<pid_folder>_iou_results.csv' "
        f"and a matching '<pid_folder>_mapping.csv' under {IMG_MAPPING_ROOT!r}"
    )

df_all = pd.concat(all_records, ignore_index=True)
df_all.to_csv(MERGED_CSV, index=False)
print(f"📄 Merged CSV saved: {MERGED_CSV}")


# ==============================================================
# Step 2: Apply Refinement Filtering
# ==============================================================

# Keep the top ``iou_dataset_top`` fraction of rows by iou_sum, dataset-wide.
# Rows whose iou_sum is NaN fail the >= comparison and are dropped.
if R["iou_dataset_top"]:
    cutoff = df_all["iou_sum"].quantile(1 - R["iou_dataset_top"])
    df_all = df_all[df_all["iou_sum"] >= cutoff]

# Within each slice_id, keep the top ``iou_slice_top`` fraction of rows by
# iou_sum (at least one row per slice).
# A per-group rank mask replaces groupby().apply(nlargest). apply() on the
# grouping column is deprecated since pandas 2.2 and will drop ``slice_id``
# from the result, which Step 3 still needs.
# method="first" breaks ties by row order, like nlargest(keep="first"). NaN
# iou_sum / NaN slice_id rows get a NaN rank and are excluded, as before.
if R["iou_slice_top"]:
    by_slice = df_all.groupby("slice_id")["iou_sum"]
    n_keep = np.floor(by_slice.transform("size") * R["iou_slice_top"]).clip(lower=1)
    rank_in_slice = by_slice.rank(method="first", ascending=False)
    df_all = df_all[rank_in_slice <= n_keep].reset_index(drop=True)

# The Dice (dice_thresh) and area-smoothness (area_ratio) constraints of
# Stage 2/3 need SAM2 vs nnU-Net predictions and are not implemented here.
# Warn instead of silently running those stages with IoU filtering only.
_unapplied = [key for key in ("dice_thresh", "area_ratio") if R.get(key) is not None]
if _unapplied:
    print(f"⚠️ {ACTIVE_STAGE}: {', '.join(_unapplied)} configured but not applied (not implemented)")

print(f"✅ Filtered data count after {ACTIVE_STAGE}: {len(df_all)}")


# ==============================================================
# Step 3: Reconstruct NIfTI per PID
# ==============================================================

def _find_pid_file(directory, pid):
    """Return the ``.nii.gz`` file in *directory* belonging to *pid*, or None.

    A bare ``f"{pid}*.nii.gz"`` glob also matches longer IDs that share the
    prefix (``case_1`` matches ``case_10.nii.gz``) and returns files in
    arbitrary filesystem order.

    Candidates are ranked, with ties broken alphabetically for a
    deterministic choice:
      1. the exact ``{pid}.nii.gz``;
      2. names where the ID is followed by ``_``, ``-`` or ``.``;
      3. any other prefix match (the previous behavior).
    """
    pid = str(pid)
    candidates = glob.glob(os.path.join(directory, f"{glob.escape(pid)}*.nii.gz"))
    if not candidates:
        return None

    def match_quality(path):
        rest = os.path.basename(path)[len(pid):]
        if rest == ".nii.gz":
            return 0
        return 1 if rest[:1] in ("_", "-", ".") else 2

    return min(candidates, key=lambda path: (match_quality(path), path))


def _load_slice_mask(pred_dir, frame_idx, H, W):
    """Build one (H, W) uint8 label slice from a frame's per-object SAM2 PNGs.

    For obj in (1, 2), reads ``frame_{frame_idx}_obj_{obj}.png``. Missing,
    unreadable or all-zero PNGs are skipped. Pixels equal to the smallest
    non-zero intensity form the foreground. That foreground is reduced to its
    largest connected component, resized to (H, W) with nearest-neighbour
    interpolation, and written as label ``obj``. Object 2 overwrites object 1
    where they overlap.
    """
    slice_mask = np.zeros((H, W), np.uint8)
    for obj in (1, 2):
        png = os.path.join(pred_dir, f"frame_{frame_idx}_obj_{obj}.png")
        if not os.path.exists(png):
            continue
        m = cv2.imread(png, cv2.IMREAD_GRAYSCALE)
        if m is None or m.max() == 0:
            continue
        # Same value as np.unique(m)[1] whenever a 0 pixel exists, but does
        # not raise IndexError for a fully-foreground PNG.
        fg_value = m[m > 0].min()
        m_bin = keep_largest_connected_component((m == fg_value).astype(np.uint8))
        m_bin = cv2.resize(m_bin, (W, H), interpolation=cv2.INTER_NEAREST)
        slice_mask[m_bin > 0] = obj
    return slice_mask


def process_pid(item):
    """Reconstruct and save the 2-class SAM2 label volume for one PID.

    Parameters
    ----------
    item : tuple
        ``(pid, group)`` as yielded by ``df_all.groupby("pid")``. ``group``
        holds that PID's rows that survived Step 2 filtering.

    Returns None. Writes ``{pid}_sam2_seg_2class.nii.gz`` into SAVE_SEG_DIR
    with the geometry of the PID's reference image. Does nothing if that file
    already exists (resumable runs), if no reference image is found, or if no
    row could be placed into the volume.

    Each kept row's masks go to z-index ``slice_id``; slices without a kept
    row stay background (0). Stacking only the kept rows would yield fewer
    slices than the reference volume, which CopyInformation rejects, and
    would misplace the surviving slices.
    NOTE(review): assumes ``slice_id`` is the 0-based index along array axis 0
    of the reference image -- confirm against the mapping CSV.
    """
    pid, group = item
    save_path = os.path.join(SAVE_SEG_DIR, f"{pid}_sam2_seg_2class.nii.gz")
    if os.path.exists(save_path):
        return
    ref_path = _find_pid_file(os.path.join(BASE_DIR, "data_in_nii/img_in_nii_L4L5"), pid)
    if ref_path is None:
        return
    # GetSize() is (x, y, z); the NumPy volume built below is (z, y, x).
    W, H, depth = sitk.ReadImage(ref_path).GetSize()

    # If several rows map to the same slice_id (e.g. from different
    # pid_folders), keep the highest-iou_sum row rather than stacking duplicates.
    group = group.dropna(subset=["slice_id"])
    if "iou_sum" in group.columns:
        group = group.sort_values("iou_sum", ascending=False, kind="mergesort")
    group = group.drop_duplicates(subset="slice_id", keep="first")

    seg_volume = np.zeros((depth, H, W), np.uint8)
    n_placed = 0
    for _, r in group.iterrows():
        z = int(r["slice_id"])
        if not 0 <= z < depth:
            print(f"⚠️ {pid}: slice_id {r['slice_id']} outside [0, {depth}); skipped")
            continue
        # frame_idx indexes the SAM2 run stored in this row's pid_folder.
        pred_dir = os.path.join(RESULT_ROOT, str(r.get("pid_folder", pid)), "SAM2_seg_mask_nolap")
        seg_volume[z] = _load_slice_mask(pred_dir, r["frame_idx"], H, W)
        n_placed += 1
    if n_placed == 0:
        return
    save_array_to_nii(seg_volume, ref_path, save_path)

# Reconstruct one NIfTI per PID in parallel. Results are collected with
# as_completed so a single failing PID is recorded and reported while the
# remaining PIDs (and Step 4) still run. exe.map would re-raise the first
# worker exception and abort the cell.
failed_pids = {}
with ProcessPoolExecutor(max_workers=8) as exe:
    futures = {exe.submit(process_pid, item): item[0] for item in df_all.groupby("pid")}
    for future in tqdm(as_completed(futures), total=len(futures)):
        try:
            future.result()
        except Exception as err:  # batch boundary: record and report below
            failed_pids[futures[future]] = f"{type(err).__name__}: {err}"

if failed_pids:
    print(f"⚠️ Reconstruction failed for {len(failed_pids)} PID(s):")
    for failed_pid, reason in failed_pids.items():
        print(f"   {failed_pid}: {reason}")


# ==============================================================
# Step 4: Evaluate DSC
# ==============================================================

label_dir = os.path.join(BASE_DIR, P["label_dir_2c"])


def _find_label_path(directory, pid):
    """Return the ground-truth ``.nii.gz`` in *directory* for *pid*, or None.

    ``glob(f"{pid}*.nii.gz")`` alone also matches longer IDs sharing the
    prefix (``case_1`` matches ``case_10.nii.gz``) and returns files in
    arbitrary order.

    Candidates are ranked, with ties broken alphabetically:
      1. the exact ``{pid}.nii.gz``;
      2. names where the ID is followed by ``_``, ``-`` or ``.``;
      3. any other prefix match.
    """
    candidates = glob.glob(os.path.join(directory, f"{glob.escape(pid)}*.nii.gz"))
    if not candidates:
        return None

    def match_quality(path):
        rest = os.path.basename(path)[len(pid):]
        if rest == ".nii.gz":
            return 0
        return 1 if rest[:1] in ("_", "-", ".") else 2

    return min(candidates, key=lambda path: (match_quality(path), path))


# Score every reconstructed volume in SAVE_SEG_DIR against its ground truth:
# one Dice per foreground label (1 and 2). Files are sorted so the summary
# CSV has a reproducible row order.
# NOTE(review): SAVE_SEG_DIR is not keyed by ACTIVE_STAGE, so volumes left
# over from another stage's run are evaluated as well.
results = []
for pred_path in tqdm(sorted(glob.glob(os.path.join(SAVE_SEG_DIR, "*.nii.gz"))), desc="Eval"):
    pid = os.path.basename(pred_path).split("_sam2_seg_2class")[0]
    gt_path = _find_label_path(label_dir, pid)
    if gt_path is None:
        continue
    pred = sitk.GetArrayFromImage(sitk.ReadImage(pred_path))
    gt = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))
    # Mismatched shapes would either raise or silently broadcast inside Dice.
    if pred.shape != gt.shape:
        print(f"⚠️ {pid}: prediction shape {pred.shape} != label shape {gt.shape}; skipped")
        continue
    results.append([pid, dice_binary(pred == 1, gt == 1), dice_binary(pred == 2, gt == 2)])

df_dsc = pd.DataFrame(results, columns=["pid", "dsc_class1", "dsc_class2"])
df_dsc.to_csv(os.path.join(BASE_DIR, f"{ACTIVE_DATASET}_SAM2_auto_seg_DSC_summary_{SEED_COUNT}seed_2class.csv"), index=False)

print("✅ DSC summary saved")
print(df_dsc.describe())


In [None]:
# Display the per-PID DSC table (pid, dsc_class1, dsc_class2) as the cell output.
df_dsc