In [None]:
# ======================== Setup & Initialization ========================
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
from multiprocessing import Pool, cpu_count

# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
# NOTE(review): the MPS fallback variable is set although "mps" is never selected
# below, and it is set after `import torch` — confirm whether it is still needed.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")




In [None]:
# ------------------------ Load SAM2 Predictor --------------------------
# Build the SAM2 video predictor once; `process_one_pid` below uses it as a
# module-level global, so this cell must run before any PID is processed.

from sam2.build_sam import build_sam2_video_predictor

# NOTE: You must configure the following paths manually before running.
# - sam2_checkpoint: Path to the SAM2 model checkpoint (.pt file).
#   Download it from the official release and place it under "checkpoints/" or another folder.
# - model_cfg: Path to the corresponding model config (.yaml file).
#   It must match the checkpoint version.
# NOTE(review): `process_one_pid` calls `propagate_in_video(..., iou_score_return=True)`;
# confirm that the installed `sam2` build supports that keyword.
sam2_checkpoint = "checkpoints/sam2.1_hiera_large.pt"   # <-- Change if your path differs
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"        # <-- Change if your path differs

# Build the predictor on `device` (selected in the setup cell) from the config and checkpoint.
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=device)

In [None]:
# ======================== Utility Functions ========================

def show_mask(mask, ax, obj_id=None, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width).
        ax: Object exposing ``imshow`` (a matplotlib Axes in this notebook).
        obj_id: Index into the "tab10" colormap; ``None`` selects index 0.
        random_color: When true, use a random RGB color instead of the colormap.
    """
    alpha = 0.6
    if not random_color:
        palette = plt.get_cmap("tab10")
        rgb = palette(obj_id if obj_id is not None else 0)[:3]
        rgba = np.array([*rgb, alpha])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4): pixels outside the mask become fully transparent.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def process_frame(args):
    """Write the mask PNGs and the overlay plot for a single frame.

    Takes one tuple so it can be passed directly to ``Pool.map``.

    Args:
        args: ``(key, idx, video_segments_todo, video_dir, frame_names, prompt_dir)``
            key: Suffix of the output folders ``SAM2_seg_mask_<key>`` and
                ``SAM2_seg_plot_<key>`` created under ``prompt_dir``.
            idx: Frame index, used to index both ``frame_names`` and
                ``video_segments_todo``.
            video_segments_todo: Mapping ``frame index -> {obj_id: mask array}``.
            video_dir: Directory that contains the frame images.
            frame_names: Frame file names.
            prompt_dir: Root output directory for this PID.

    Returns:
        None. All results are written to disk.

    Raises:
        KeyError: If ``idx`` is not a key of ``video_segments_todo``.
    """
    key, idx, video_segments_todo, video_dir, frame_names, prompt_dir = args
    seg_dir = os.path.join(prompt_dir, f"SAM2_seg_mask_{key}")
    vis_dir = os.path.join(prompt_dir, f"SAM2_seg_plot_{key}")
    # Create the targets here as well so the worker does not depend on the caller
    # having prepared them; exist_ok makes concurrent workers safe.
    os.makedirs(seg_dir, exist_ok=True)
    os.makedirs(vis_dir, exist_ok=True)

    frame_path = os.path.join(video_dir, frame_names[idx])

    # Drop singleton dimensions once; the result feeds both the PNG export and the plot.
    frame_masks = {
        obj_id: (np.squeeze(mask) if mask.ndim > 2 else mask)
        for obj_id, mask in video_segments_todo[idx].items()
    }

    # The context manager releases the image file handle even if saving or plotting fails.
    with Image.open(frame_path) as frame_image:
        combined_mask = None
        for obj_id, mask in frame_masks.items():
            # Object identity is encoded as gray level obj_id * 50; the PNGs are
            # written as uint8, so levels above 255 are not representable.
            scaled_mask = mask * (obj_id * 50)

            Image.fromarray(scaled_mask.astype("uint8")).save(
                os.path.join(seg_dir, f"frame_{idx}_obj_{obj_id}.png"), "PNG")

            if combined_mask is None:
                combined_mask = scaled_mask
            else:
                combined_mask = np.maximum(combined_mask, scaled_mask)

        if combined_mask is not None:
            Image.fromarray(combined_mask.astype("uint8")).save(
                os.path.join(seg_dir, f"frame_{idx}_combined.png"), "PNG")

        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            ax.set_title(f"Frame {idx} ({key})")
            ax.imshow(frame_image, cmap="gray")
            for obj_id, mask in frame_masks.items():
                show_mask(mask, ax, obj_id=obj_id)
            fig.savefig(os.path.join(vis_dir, f"frame_{idx}_visualization.png"), format="png", dpi=300)
        finally:
            # Always close the figure: the worker handles many frames and open figures accumulate.
            plt.close(fig)

    print(f"🖼️ Processed frame {idx} for {key}")


In [None]:
# ======================== Main PID Processing ========================
def process_one_pid(pid, data_dir, seed_count):
    """Run SAM2 mask propagation for one PID and save IOU scores, masks and plots.

    Uses the module-level ``predictor``. Rows of ``<pid>_mapping.csv`` with
    ``category == "seed"`` select the prompt frames. Their label PNGs are
    thresholded into one binary mask per class and registered as prompts, and
    the predictor then propagates them through the frame sequence. Results are
    written under ``<data_dir>/sam2_results_by_pid_<seed_count>seed/<pid>``:
    an IOU CSV, per-object and combined mask PNGs, and overlay plots, each for
    the overlapping ("yeslap") and non-overlapping ("nolap") variants.

    Args:
        pid: Name of the PID sub-directory.
        data_dir: Root directory holding the ``img_in_jpg_to_sam2_*``,
            ``label_in_png_to_sam2_*`` and ``sam2_results_by_pid_*`` folders.
        seed_count: Number used in the ``*_<seed_count>seed`` folder names.

    Returns:
        None. Returns early when ``SAM2_seg_plot_yeslap`` already contains files.

    Raises:
        RuntimeError: If the propagation yields no frame at all.
    """
    print(f"\n🚀 Processing {pid} with {seed_count}-seed")

    # Directories
    img_dir = os.path.join(data_dir, f"img_in_jpg_to_sam2_{seed_count}seed", pid)
    label_dir = os.path.join(data_dir, f"label_in_png_to_sam2_{seed_count}seed", pid)
    prompt_dir = os.path.join(data_dir, f"sam2_results_by_pid_{seed_count}seed", pid)
    csv_map_path = os.path.join(img_dir, f"{pid}_mapping.csv")

    # Skip if results already exist (any file in the "yeslap" plot folder counts).
    plot_yeslap_dir = os.path.join(prompt_dir, "SAM2_seg_plot_yeslap")
    if os.path.isdir(plot_yeslap_dir) and len(os.listdir(plot_yeslap_dir)) > 0:
        print(f"✅ Results already exist for {pid}, skipping.")
        return

    os.makedirs(prompt_dir, exist_ok=True)

    # Load mapping
    df_map = pd.read_csv(csv_map_path)
    seed_rows = df_map[df_map["category"] == "seed"]
    prompt_idx = sorted(seed_rows["frame_idx"].tolist())

    # Prepare frames: file stems must be integers, they define the frame order.
    frame_names = sorted([p for p in os.listdir(img_dir) if p.lower().endswith(".jpg")],
                         key=lambda p: int(os.path.splitext(p)[0]))

    # Initialize predictor state
    inference_state = predictor.init_state(video_path=img_dir)
    predictor.reset_state(inference_state)

    # Class ranges (example thresholds): class id -> inclusive gray-level range in the label PNG.
    class_ranges = {1: (80, 120), 2: (180, 220)}

    # ---------------------- Add prompt masks ----------------------
    for idx in prompt_idx:
        img_path = os.path.join(img_dir, frame_names[idx])
        seg_path = os.path.join(label_dir, str(idx).zfill(5) + ".png")

        try:
            with Image.open(seg_path) as seg_img:
                seg_array = np.array(seg_img.convert("L"))
            # Only the frame's (width, height) is needed, so read the size from
            # the header instead of decoding the whole image.
            with Image.open(img_path) as frame_img:
                target_size = frame_img.size
        except Exception as e:
            print(f"❌ Error reading image {idx}: {e}")
            continue

        class_masks = {}
        for class_id, (low, high) in class_ranges.items():
            in_range = (seg_array >= low) & (seg_array <= high)
            # Nearest-neighbour resize keeps the mask strictly binary.
            resized = Image.fromarray(in_range.astype(np.uint8)).resize(target_size, Image.NEAREST)
            class_masks[class_id] = np.array(resized)

        for ann_obj_id, class_mask in class_masks.items():
            try:
                predictor.add_new_mask(inference_state, frame_idx=idx, obj_id=ann_obj_id, mask=class_mask)
            except Exception as e:
                print(f"   ❌ Error in add_new_mask: {e}")

        print(f"✅ Added prompts for frame {idx} ({frame_names[idx]})")

    # ======================== Inference & IOU Saving ========================
    video_segments, video_segments_nolap, iou_score = {}, {}, {}

    # NOTE(review): `iou_score_return` and the fourth yielded element are assumed
    # to be provided by the installed SAM2 build — confirm.
    for out in predictor.propagate_in_video(
        inference_state, iou_score_return=True, max_frame_num_to_track=len(frame_names)
    ):
        out_frame_idx, out_obj_ids, out_mask_logits = out[:3]
        ious = out[3]

        binary_masks = {}      # obj_id -> boolean mask (logit > 0)
        positive_logits = {}   # obj_id -> logit where positive, 0 elsewhere
        for i, obj_id in enumerate(out_obj_ids):
            # One device-to-host copy per object; both mask variants derive from it.
            logits = out_mask_logits[i].cpu().numpy()
            binary = logits > 0.0
            binary_masks[obj_id] = binary
            positive_logits[obj_id] = binary * logits

        # Overlapping masks
        video_segments[out_frame_idx] = binary_masks

        # IOU scores
        iou_score[out_frame_idx] = {obj_id: ious[i] for i, obj_id in enumerate(out_obj_ids)}

        # Non-overlapping masks: each pixel goes to the object with the largest
        # positive logit; pixels where no object is positive stay unassigned.
        stacked_logits = np.stack(list(positive_logits.values()), axis=0)
        stacked_masks = stacked_logits > 0
        winner = np.argmax(stacked_logits, axis=0)
        video_segments_nolap[out_frame_idx] = {
            obj_id: (winner == i).astype(np.uint8) & stacked_masks[i]
            for i, obj_id in enumerate(positive_logits)}

    if not iou_score:
        raise RuntimeError(f"SAM2 propagation returned no frames for {pid}")

    # The propagation may not yield every frame, so everything below iterates
    # over the frames that were actually returned.
    missing = [i for i in range(len(frame_names)) if i not in video_segments]
    if missing:
        print(f"⚠️ Propagation returned no output for {len(missing)} of "
              f"{len(frame_names)} frames; they are skipped.")

    # Save IOU results
    iou_csv_path = os.path.join(prompt_dir, pid + "_iou_results.csv")
    results = []
    for f_idx in sorted(iou_score):
        for o_id, iou in iou_score[f_idx].items():
            max_iou = torch.max(iou).item()
            results.append({"frame_idx": f_idx, "obj_id": o_id, "iou": max_iou})
    results_df = pd.DataFrame(results)
    wide_df = results_df.pivot(index="frame_idx", columns="obj_id", values="iou")
    wide_df.columns = [f"iou_id{col}" for col in wide_df.columns]
    wide_df["iou_sum"] = wide_df.sum(axis=1)
    wide_df.to_csv(iou_csv_path, index=True)
    print(f"📄 Saved IOU results to {iou_csv_path}")

    # ======================== Save Masks & Visualizations ========================
    video_segments_dict = {"yeslap": video_segments, "nolap": video_segments_nolap}
    for key in video_segments_dict:
        os.makedirs(os.path.join(prompt_dir, f"SAM2_seg_mask_{key}"), exist_ok=True)
        os.makedirs(os.path.join(prompt_dir, f"SAM2_seg_plot_{key}"), exist_ok=True)

    tasks = []
    for key, video_segments_todo in video_segments_dict.items():
        for idx in range(len(frame_names)):
            if idx not in video_segments_todo:
                continue
            # Each task is pickled separately, so ship only this frame's masks
            # rather than the masks of the whole video.
            tasks.append((key, idx, {idx: video_segments_todo[idx]}, img_dir, frame_names, prompt_dir))

    with Pool(min(2, cpu_count())) as pool:
        pool.map(process_frame, tasks)

    print(f"🎉 Finished processing {pid}")


In [None]:
# ================================
# Batch Processing of All PIDs
# ================================

import os
import re
import torch
from tqdm import tqdm  # progress bar

def sort_key(pid):
    """
    Sort key that orders PIDs numerically if possible, otherwise lexicographically.

    The number is taken from the part before the first underscore; when that
    part has no digits, the first number anywhere in the PID is used instead,
    so that e.g. pid_2 < pid_10. PIDs without any digits sort after all
    numbered ones, by their lower-cased name.

    Returns:
        ``(False, number, pid.lower())`` when a number was found,
        ``(True, 0, pid.lower())`` otherwise.
    """
    head = pid.split('_')[0]
    # Prefer the leading token; fall back to the whole name for PIDs such as "pid_2".
    m = re.search(r'\d+', head) or re.search(r'\d+', pid)
    if m:
        return (False, int(m.group()), pid.lower())
    return (True, 0, pid.lower())

# --- Configuration ---
dataset = "CT_word"  # or "CT_TT"
data_dir = os.path.join("notebooks", "videos", dataset, "data_in_jpg")
seed_counts = [1]  # list of seed numbers to run, e.g. [1, 3, 5]

# --- Collect all PIDs from reference (1-seed folder) ---
# Every sub-directory of the 1-seed image folder is one PID; plain files are ignored.
ref_dir = os.path.join(data_dir, "img_in_jpg_to_sam2_1seed")
pids = [entry for entry in os.listdir(ref_dir)
        if os.path.isdir(os.path.join(ref_dir, entry))]
pids.sort(key=sort_key)

print(f"🔍 Found {len(pids)} PIDs in dataset: {dataset}")

# --- Process each PID across all seeds ---
for pid in tqdm(pids, desc="Processing PIDs"):
    for seed in seed_counts:
        pid_dir = os.path.join(data_dir, f"img_in_jpg_to_sam2_{seed}seed", pid)
        if os.path.exists(pid_dir):
            try:
                process_one_pid(pid, data_dir, seed)
            except Exception as e:
