In [None]:
import os
import re
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

# ==========================
# 1. Setup
# ==========================
# Root folder holding one subfolder per patient, each containing .jpg slice images.
# ⚠️ Adjust this path so it points at your own dataset.
infer_dir = 'videos/CT_word/data_in_jpg_2class/img_in_jpg'

# Run on CUDA when torch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the DINOv2 image processor and model, move the model to the chosen
# device, and switch it to evaluation mode.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
model = AutoModel.from_pretrained("facebook/dinov2-large")
model = model.to(device)
model.eval()

# ==========================
# 2. Helper functions
# ==========================
def extract_slice_number(filename):
    """Return the integer that follows 'slice_' in *filename*, or -1 if there is none."""
    match = re.search(r"slice_(\d+)", filename)
    if match is None:
        # No 'slice_<digits>' token in the name.
        return -1
    return int(match.group(1))

def get_image_embedding(image_path):
    """
    Compute a DINOv2 embedding for a single image.

    The image is converted to RGB, run through the module-level `processor`
    and `model` on `device` with gradients disabled, and the model's
    `last_hidden_state` is averaged over dim 1.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The averaged hidden state as a NumPy array on the CPU
        (presumably shape (1, hidden_size) — confirm against the model config).
    """
    # Open inside a context manager so the file handle is released
    # deterministically; convert() returns an independent in-memory image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).cpu().numpy()

def get_patient_embedding(img_files, pid_path):
    """
    Compute a representative embedding for a patient by averaging three key slices.

    The slices are ordered by the numeric index in their file names and the
    first, middle and last are embedded. According to the original notes these
    are meant to approximate:
      - First slice  → the L3/L4 disc level
      - Middle slice → the L4/L5 disc level
      - Last slice   → the L5/S1 disc level

    Args:
        img_files: List of slice file names (not full paths) for one patient.
            The list is not modified.
        pid_path: Directory that contains those files.

    Returns:
        The element-wise mean of the three slice embeddings, or None when
        `img_files` is empty.
    """
    if not img_files:
        return None

    # sorted() instead of list.sort(): the caller's list must stay untouched.
    ordered = sorted(img_files, key=extract_slice_number)
    key_files = [ordered[0], ordered[len(ordered) // 2], ordered[-1]]

    # With fewer than three slices the picks repeat. Embed each distinct file
    # once and reuse the result, which keeps the original weighting.
    computed = {}
    embeddings = []
    for name in key_files:
        if name not in computed:
            computed[name] = get_image_embedding(os.path.join(pid_path, name))
        embeddings.append(computed[name])

    return np.mean(embeddings, axis=0)

# ==========================
# 3. Collect embeddings
# ==========================
pid_list = []
embeddings = []

# os.listdir() returns entries in arbitrary order. Iterating in sorted order
# keeps pid_list (and any tie-breaking on it later) reproducible across runs.
for pid in sorted(os.listdir(infer_dir)):
    pid_path = os.path.join(infer_dir, pid)
    if not os.path.isdir(pid_path):
        # Skip stray files that sit next to the patient folders.
        continue

    img_files = [f for f in os.listdir(pid_path) if f.lower().endswith(".jpg")]
    if not img_files:
        # Patient folder without any .jpg slices.
        continue

    emb = get_patient_embedding(img_files, pid_path)
    if emb is None:
        continue

    pid_list.append(pid)
    embeddings.append(emb)

# NOTE: this is printed after the embeddings above have been computed.
print(f"✅ Found {len(pid_list)} patients. Extracting embeddings...")

# ==========================
# 4. Compute distances
# ==========================
# Fail early with a clear message instead of an opaque np.vstack error.
if len(embeddings) == 0:
    raise ValueError(
        f"No patient embeddings were computed from {infer_dir!r}; "
        "check the path and that patient folders contain .jpg slices."
    )

embeddings = np.vstack(embeddings)
sim_matrix = cosine_similarity(embeddings)
dist_matrix = 1 - sim_matrix

# Average distance to all other patients. Each row sum is divided by the
# number of other patients; max(..., 1) avoids dividing by zero when only
# one patient is available.
avg_distances = dist_matrix.sum(axis=1) / max(len(pid_list) - 1, 1)

# Select the most representative patient (lowest average distance)
top_index = np.argmin(avg_distances)
top_pid = pid_list[top_index]

# ==========================
# 5. Output result
# ==========================
print("\n🎯 Most representative patient (Top-1):")
print(f"- {top_pid}")


In [None]:
import os
import re
import shutil
import pandas as pd

# ==========================
# Configuration
# ==========================
# Dataset folder name; interpolated into every path below.
dataset = "CT_word"

# Top-1 seed patient ID (selected as the most representative patient).
# By default this reuses `top_pid` from the previous step. When this step is
# run on its own, assign the ID string directly instead.
# Example: seed_pid = "word_0081_L4L5"
seed_pid = top_pid   # requires the previous step to have defined `top_pid`

# Input directories (update according to your dataset structure).
# Each is expected to hold one subfolder per patient ID.
infer_dir = f"videos/{dataset}/data_in_jpg_2class/img_in_jpg"
label_dir = f"videos/{dataset}/data_in_jpg_2class/label_in_png"

# Output directories (new datasets for SAM2 training with one seed).
# One subfolder per inference patient is created under each of these.
output_img_base_dir = f"videos/{dataset}/data_in_jpg_2class/img_in_jpg_to_sam2_1seed"
output_label_base_dir = f"videos/{dataset}/data_in_jpg_2class/label_in_png_to_sam2_1seed"


def get_sorted_files(pid_dir):
    """
    List the .jpg/.png files in a patient directory as full paths, ordered by
    slice index from highest to lowest.

    The index is the number following 'slice_' in the file name (for example
    'slice_012.jpg'); names without one are treated as index -1.
    """
    def slice_index(name):
        found = re.search(r"slice_(\d+)", name)
        if found is None:
            return -1
        return int(found.group(1))

    image_names = []
    for entry in os.listdir(pid_dir):
        # Extension match is case-insensitive.
        if entry.lower().endswith((".jpg", ".png")):
            image_names.append(entry)

    ordered = sorted(image_names, key=slice_index, reverse=True)
    return [os.path.join(pid_dir, name) for name in ordered]


def _copy_slice_pair(src_img_path, src_label_path, frame_idx, img_targ_dir, label_targ_dir):
    """Copy one image/label pair under a shared zero-padded frame name; return the new image file name."""
    frame_stem = str(frame_idx).zfill(5)

    # Each file keeps its own extension (e.g. .jpg for images, .png for labels).
    img_ext = os.path.splitext(src_img_path)[1][1:]
    dst_img_name = f"{frame_stem}.{img_ext}"
    shutil.copy(src_img_path, os.path.join(img_targ_dir, dst_img_name))

    label_ext = os.path.splitext(src_label_path)[1][1:]
    dst_label_name = f"{frame_stem}.{label_ext}"
    shutil.copy(src_label_path, os.path.join(label_targ_dir, dst_label_name))

    return dst_img_name


def _warn_if_counts_differ(pid, img_slices, label_slices):
    """Print a warning when a patient's image and label slice counts disagree."""
    if len(img_slices) != len(label_slices):
        print(
            f"⚠️ {pid}: {len(img_slices)} images vs {len(label_slices)} labels; "
            "slices are paired by sorted position and unmatched ones are skipped."
        )


def process_with_single_seed(seed_pid):
    """
    Combine slices from one seed patient and each inference patient to create
    paired datasets for SAM2 training.

    For every patient folder in `infer_dir` other than the seed, the seed's
    and that patient's slices are interleaved layer by layer (seed first),
    copied to `output_img_base_dir/<pid>` and `output_label_base_dir/<pid>`
    under consecutive 5-digit frame names, and a `<pid>_mapping.csv` is
    written next to the copied images. Slices are taken in the order
    returned by `get_sorted_files`; images and labels are paired by position
    in that order.

    Args:
        seed_pid (str): Patient ID selected as the seed (reference).

    Raises:
        FileNotFoundError: If the seed's or a patient's image or label
            folder does not exist.
    """
    # Exclude the seed from the inference list. os.listdir() order is
    # arbitrary, so sort to keep the processing order reproducible.
    all_pids = sorted(
        d for d in os.listdir(infer_dir) if os.path.isdir(os.path.join(infer_dir, d))
    )
    infer_pids = [pid for pid in all_pids if pid != seed_pid]

    print("\n=== 🔁 Processing with single seed ===")
    print(f"Using seed_pid: {seed_pid}")
    print(f"Inference PIDs: {infer_pids}")

    if not infer_pids:
        return

    # The seed's slice lists are the same for every pairing: read them once.
    # This also surfaces a wrong seed_pid before any output is created.
    seed_img_slices = get_sorted_files(os.path.join(infer_dir, seed_pid))
    seed_label_slices = get_sorted_files(os.path.join(label_dir, seed_pid))
    _warn_if_counts_differ(seed_pid, seed_img_slices, seed_label_slices)

    slice_pattern = re.compile(r"slice_(\d+)")

    for infer_pid in infer_pids:
        print(f"\n🚀 Processing inference PID: {infer_pid}")

        infer_img_slices = get_sorted_files(os.path.join(infer_dir, infer_pid))
        infer_label_slices = get_sorted_files(os.path.join(label_dir, infer_pid))
        _warn_if_counts_differ(infer_pid, infer_img_slices, infer_label_slices)

        # Create target directories for this patient
        img_targ_dir = os.path.join(output_img_base_dir, infer_pid)
        label_targ_dir = os.path.join(output_label_base_dir, infer_pid)
        os.makedirs(img_targ_dir, exist_ok=True)
        os.makedirs(label_targ_dir, exist_ok=True)

        # Merge seed + current inference patient (seed comes first in each layer)
        pids_all = [seed_pid, infer_pid]
        pid_to_img_slices = {seed_pid: seed_img_slices, infer_pid: infer_img_slices}
        pid_to_label_slices = {seed_pid: seed_label_slices, infer_pid: infer_label_slices}

        global_id = 0
        max_slices = max(len(slices) for slices in pid_to_img_slices.values())
        mapping_records = []

        # Interleave slices layer by layer
        for layer in range(max_slices):
            for pid in pids_all:
                img_slices = pid_to_img_slices[pid]
                label_slices = pid_to_label_slices[pid]

                # A patient with fewer slices simply contributes nothing here.
                if layer >= len(img_slices) or layer >= len(label_slices):
                    continue

                src_img_path = img_slices[layer]
                dst_img_name = _copy_slice_pair(
                    src_img_path, label_slices[layer], global_id, img_targ_dir, label_targ_dir
                )

                # Record mapping metadata
                src_img_file = os.path.basename(src_img_path)
                m = slice_pattern.search(src_img_file)
                mapping_records.append({
                    "frame_idx": global_id,
                    "pid": pid,
                    "slice_id": int(m.group(1)) if m else -1,
                    "src_img_file": src_img_file,
                    "dst_img_file": dst_img_name,
                    "layer": layer,
                    "category": "seed" if pid == seed_pid else "infer",
                })

                global_id += 1

        # Save mapping table (for traceability)
        mapping_df = pd.DataFrame(mapping_records)
        mapping_csv_path = os.path.join(img_targ_dir, f"{infer_pid}_mapping.csv")
        mapping_df.to_csv(mapping_csv_path, index=False)

        print(f"✅ Copied images and labels to: {img_targ_dir} / {label_targ_dir}")
        print(f"✅ Mapping table saved to: {mapping_csv_path}")


# ==========================
# Run single-seed processing
# ==========================
# Writes one interleaved seed + patient dataset (images, labels, mapping CSV)
# for every patient folder other than the seed.
process_with_single_seed(seed_pid)
