In [3]:
import os
import re
import numpy as np
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModel
from sklearn.metrics.pairwise import cosine_similarity

# Set the input path and the compute device.
# NOTE: the second assignment overrides the first, so the CT_word folder is the
# one actually used; the CT_TT line only has an effect if the CT_word line is removed.
infer_dir = 'videos/CT_TT/data_in_jpg_2class/img_in_jpg'
infer_dir = 'videos/CT_word/data_in_jpg_2class/img_in_jpg'

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialise the DINOv2 model: the image processor and the backbone are both
# loaded from the 'facebook/dinov2-large' checkpoint, and the backbone is moved to `device`.
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')
model = AutoModel.from_pretrained('facebook/dinov2-large').to(device)
# Switch the model to evaluation mode; it is only used for inference in this cell.
model.eval()

def extract_slice_number(filename):
    """Return the slice index encoded in a file name.

    The index is the run of digits that follows the first ``slice_`` in
    ``filename``. When the name contains no such pattern, -1 is returned so
    that callers can still use the result as a sort key.
    """
    match = re.search(r'slice_(\d+)', filename)
    if match is None:
        return -1
    return int(match.group(1))

def get_image_embedding(image_path):
    """Compute a feature embedding for a single image file.

    The image is opened, converted to RGB, preprocessed with the module-level
    ``processor`` and run through the module-level ``model`` without gradient
    tracking. The model's ``last_hidden_state`` is averaged over dimension 1
    and returned as a NumPy array on the CPU.

    NOTE(review): dimension 1 is presumably the token axis, which would make
    the result one row of shape (1, hidden_size) per image - confirm against
    the model's output layout.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    model_inputs = processor(rgb_image, return_tensors="pt").to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs)
    pooled = model_outputs.last_hidden_state.mean(dim=1)
    return pooled.cpu().numpy()

# Walk the dataset directory and record, for every PID folder, the path of its
# middle slice image.
pid_list = []
middle_img_paths = []

# os.listdir returns entries in arbitrary order; iterate in sorted order so that
# pid_list (and therefore any tie-breaking in the later ranking) is reproducible.
for pid in sorted(os.listdir(infer_dir)):
    pid_path = os.path.join(infer_dir, pid)
    if not os.path.isdir(pid_path):
        continue

    img_files = [f for f in os.listdir(pid_path) if f.lower().endswith('.jpg')]
    if not img_files:
        # A PID folder without any .jpg slice contributes nothing.
        continue

    # Order by slice number; the file name is a secondary key so that files
    # sharing a slice number (or lacking one, key -1) are ordered deterministically.
    img_files.sort(key=lambda f: (extract_slice_number(f), f))
    # For an even number of slices this picks the upper of the two central ones.
    mid_img = img_files[len(img_files)//2]
    mid_img_path = os.path.join(pid_path, mid_img)

    pid_list.append(pid)
    middle_img_paths.append(mid_img_path)

print(f"共找到 {len(pid_list)} 个PID。正在提取中间切片图像的嵌入向量...")

# Extract one embedding per middle-slice image and stack them row-wise.
if not middle_img_paths:
    # Fail with an actionable message instead of np.vstack's generic error.
    raise ValueError(f"No PID folder containing .jpg slices was found under {infer_dir!r}")

embeddings = [get_image_embedding(path) for path in middle_img_paths]
embeddings = np.vstack(embeddings)

# Distance matrix: 1 - cosine similarity.
sim_matrix = cosine_similarity(embeddings)
dist_matrix = 1 - sim_matrix
# The self-distance must be exactly zero for the "average over the other PIDs"
# below to be correct; floating-point error can leave a tiny non-zero diagonal.
np.fill_diagonal(dist_matrix, 0.0)

# Average distance from each PID to all the other PIDs. With a single PID there
# are no others, so the divisor is clamped to 1 to avoid a division by zero.
avg_distances = dist_matrix.sum(axis=1) / max(len(pid_list) - 1, 1)

# Rank PIDs from smallest to largest average distance. A stable sort keeps
# pid_list order among exact ties, so the ranking is deterministic.
top_indices = np.argsort(avg_distances, kind="stable")
top1_pids = [pid_list[top_indices[0]]]
top3_pids = [pid_list[i] for i in top_indices[:3]]
top5_pids = [pid_list[i] for i in top_indices[:5]]

# Print the Top-1, Top-3 and Top-5 PIDs separately.
print("\n最具代表性的 Top-1 PID:")
print(f"- {top1_pids[0]}")

print("\n最具代表性的 Top-3 PIDs:")
for p in top3_pids:
    print(f"- {p}")

print("\n最具代表性的 Top-5 PIDs:")
for p in top5_pids:
    print(f"- {p}")


共找到 95 个PID。正在提取中间切片图像的嵌入向量...

最具代表性的 Top-1 PID:
- word_0081_L4L5

最具代表性的 Top-3 PIDs:
- word_0081_L4L5
- word_0046_L4L5
- word_0026_L4L5

最具代表性的 Top-5 PIDs:
- word_0081_L4L5
- word_0046_L4L5
- word_0026_L4L5
- word_0070_L4L5
- word_0117_L4L5


In [None]:
# Seed PID selection. The first assignment is immediately overwritten by the
# last one, so 'word_0081_L4L5' is the value that remains after this cell runs.
# NOTE(review): no other cell visible here reads `seed_pid` (the frame-building
# cell uses `full_seed_pids`) - confirm whether this cell is still needed.
seed_pid = ['s1382']
# seed_pid = ['s1382']
seed_pid = ['word_0081_L4L5']

In [15]:
import os
import shutil
import pandas as pd
import re

# === Path configuration ===
# Active dataset and the PID(s) used as seeds for it.
# NOTE(review): only one seed PID is listed, while process_by_seed_count accepts
# seed counts of 1, 3 and 5; with this list, 3 and 5 would slice to the same single seed.
dataset = 'CT_word'
full_seed_pids = ['word_0081_L4L5']

# Alternative configuration, kept so the dataset can be switched by hand:
# dataset = 'CT_TT'
# full_seed_pids = ['s1382']

# Root folders holding one sub-folder per PID: image slices and label masks.
# Both paths are derived from `dataset`.
infer_dir = f'videos/{dataset}/data_in_jpg_2class/img_in_jpg'
label_dir = f'videos/{dataset}/data_in_jpg_2class/label_in_png'



def get_sorted_files(pid_dir):
    """List the .jpg/.png files of ``pid_dir`` ordered by descending slice index.

    The slice index is the number following ``slice_`` in the file name; names
    without that pattern are given index -1 and therefore end up last. The
    extension check is case-insensitive. Returns full paths (``pid_dir`` joined
    with each file name).
    """
    slice_pattern = re.compile(r'slice_(\d+)')

    def slice_key(fname):
        found = slice_pattern.search(fname)
        return int(found.group(1)) if found else -1

    image_names = [
        entry for entry in os.listdir(pid_dir)
        if entry.lower().endswith(('.jpg', '.png'))
    ]
    # Highest slice index first.
    ordered_names = sorted(image_names, key=slice_key, reverse=True)
    return [os.path.join(pid_dir, entry) for entry in ordered_names]

def process_by_seed_count(seed_count):
    """Build interleaved seed + inference frame folders for one seed-count scheme.

    The first ``seed_count`` entries of ``full_seed_pids`` are used as seeds.
    Every other PID folder under ``infer_dir`` is treated as an inference PID.
    For each inference PID, the slices of the seed PIDs and of that PID are
    interleaved layer by layer (layer = position in the descending-slice
    ordering returned by ``get_sorted_files``) and copied, together with the
    label at the same position, into per-PID output folders under sequential
    zero-padded names (``00000.jpg``, ``00001.jpg``, ...). A
    ``<infer_pid>_mapping.csv`` describing every copied frame is written into
    the image output folder.

    Args:
        seed_count: number of seed PIDs to use; must be 1, 3 or 5.

    Raises:
        AssertionError: if ``seed_count`` is not 1, 3 or 5.

    Side effects:
        Creates directories, copies files and writes one CSV per inference PID.
        Existing output folders are reused as-is (``exist_ok=True``), so frames
        left over from an earlier run are not removed.
    """
    assert seed_count in [1, 3, 5], "Only seed counts of 1, 3, or 5 are supported"
    seed_pids = full_seed_pids[:seed_count]
    if len(seed_pids) < seed_count:
        # The output folders are named after seed_count, so make a shortfall visible
        # instead of silently producing an "N-seed" folder built from fewer seeds.
        print(f"⚠️ Requested {seed_count} seed PID(s) but only {len(seed_pids)} configured: {seed_pids}")
    seed_pid_set = set(seed_pids)

    output_img_base_dir = f'videos/{dataset}/data_in_jpg_2class/img_in_jpg_to_sam2_{seed_count}seed'
    output_label_base_dir = f'videos/{dataset}/data_in_jpg_2class/label_in_png_to_sam2_{seed_count}seed'

    all_pids = [d for d in os.listdir(infer_dir) if os.path.isdir(os.path.join(infer_dir, d))]
    infer_pids = [pid for pid in all_pids if pid not in seed_pid_set]

    print(f"\n=== 🔁 处理 {seed_count}-seed 方案 ===")
    print(f"使用的 seed_pids: {seed_pids}")
    print(f"排除后的推理 PID 列表: {infer_pids}")

    def _slice_id(path):
        """Slice index parsed from the file name of ``path``, or -1 if absent."""
        m = re.search(r'slice_(\d+)', os.path.basename(path))
        return int(m.group(1)) if m else -1

    def _load_pid_slices(pid):
        """Return (image paths, label paths) for ``pid``, warning if they do not line up."""
        img_paths = get_sorted_files(os.path.join(infer_dir, pid))
        label_paths = get_sorted_files(os.path.join(label_dir, pid))
        # Frames are paired by sorted position below; a missing or extra file on
        # either side would shift every later pair, so report any disagreement.
        if [_slice_id(p) for p in img_paths] != [_slice_id(p) for p in label_paths]:
            print(f"⚠️ {pid}: image/label slice ids differ "
                  f"({len(img_paths)} images, {len(label_paths)} labels); "
                  f"frames are paired by sorted position, so labels may not match their images")
        return img_paths, label_paths

    # The seed listings do not depend on the inference PID: read them once
    # instead of once per inference PID. Skipped when there is nothing to process.
    seed_slices = {pid: _load_pid_slices(pid) for pid in seed_pids} if infer_pids else {}

    for infer_pid in infer_pids:
        print(f"\n🚀 处理推理PID: {infer_pid}")

        img_targ_dir = os.path.join(output_img_base_dir, infer_pid)
        label_targ_dir = os.path.join(output_label_base_dir, infer_pid)
        os.makedirs(img_targ_dir, exist_ok=True)
        os.makedirs(label_targ_dir, exist_ok=True)

        # Seeds come first within each layer, followed by the inference PID.
        pids_all = seed_pids + [infer_pid]

        pid_to_slices = dict(seed_slices)
        pid_to_slices[infer_pid] = _load_pid_slices(infer_pid)

        global_id = 0
        max_slices = max(len(img_paths) for img_paths, _ in pid_to_slices.values())
        mapping_records = []

        for layer in range(max_slices):
            for pid in pids_all:
                img_slices, label_slices = pid_to_slices[pid]

                # A PID with fewer slices simply stops contributing at deeper layers.
                if layer >= len(img_slices) or layer >= len(label_slices):
                    continue

                # Image and label share the same sequential frame name (extension kept).
                frame_stem = str(global_id).zfill(5)

                src_img_path = img_slices[layer]
                img_ext = os.path.splitext(src_img_path)[1][1:]
                dst_img_name = f"{frame_stem}.{img_ext}"
                dst_img_path = os.path.join(img_targ_dir, dst_img_name)
                shutil.copy(src_img_path, dst_img_path)

                src_label_path = label_slices[layer]
                label_ext = os.path.splitext(src_label_path)[1][1:]
                dst_label_name = f"{frame_stem}.{label_ext}"
                dst_label_path = os.path.join(label_targ_dir, dst_label_name)
                shutil.copy(src_label_path, dst_label_path)

                src_img_file = os.path.basename(src_img_path)
                category = 'seed' if pid in seed_pid_set else 'infer'

                mapping_records.append({
                    'frame_idx': global_id,
                    'pid': pid,
                    'slice_id': _slice_id(src_img_path),
                    'src_img_file': src_img_file,
                    'dst_img_file': dst_img_name,
                    'layer': layer,
                    'category': category
                })

                global_id += 1

        mapping_df = pd.DataFrame(mapping_records)
        # The mapping table is stored next to the copied frames.
        mapping_csv_path = os.path.join(img_targ_dir, f'{infer_pid}_mapping.csv')
        mapping_df.to_csv(mapping_csv_path, index=False)

        print(f"✅ 图片和标签已复制至: {img_targ_dir} 和 {label_targ_dir}")
        print(f"✅ 映射表保存至: {mapping_csv_path}")


# === Run the seed-count configurations ===
# Only the 1-seed scheme is invoked here.
process_by_seed_count(1)



=== 🔁 处理 1-seed 方案 ===
使用的 seed_pids: ['word_0081_L4L5']
排除后的推理 PID 列表: ['word_0056_L4L5', 'word_0138_L4L5', 'word_0018_L4L5', 'word_0004_L4L5', 'word_0126_L4L5', 'word_0100_L4L5', 'word_0121_L4L5', 'word_0029_L4L5', 'word_0114_L4L5', 'word_0122_L4L5', 'word_0036_L4L5', 'word_0010_L4L5', 'word_0135_L4L5', 'word_0142_L4L5', 'word_0109_L4L5', 'word_0003_L4L5', 'word_0090_L4L5', 'word_0044_L4L5', 'word_0065_L4L5', 'word_0064_L4L5', 'word_0128_L4L5', 'word_0002_L4L5', 'word_0130_L4L5', 'word_0042_L4L5', 'word_0150_L4L5', 'word_0047_L4L5', 'word_0087_L4L5', 'word_0009_L4L5', 'word_0012_L4L5', 'word_0070_L4L5', 'word_0055_L4L5', 'word_0068_L4L5', 'word_0115_L4L5', 'word_0020_L4L5', 'word_0049_L4L5', 'word_0104_L4L5', 'word_0145_L4L5', 'word_0143_L4L5', 'word_0091_L4L5', 'word_0006_L4L5', 'word_0107_L4L5', 'word_0079_L4L5', 'word_0078_L4L5', 'word_0146_L4L5', 'word_0105_L4L5', 'word_0058_L4L5', 'word_0086_L4L5', 'word_0148_L4L5', 'word_0125_L4L5', 'word_0123_L4L5', 'word_0072_L4L5', 'word_00

In [14]:
from PIL import Image
import numpy as np

# Sanity check: list the distinct pixel values stored in one label PNG.
# NOTE(review): this is an absolute, machine-specific path, and it points at the
# CT_TT dataset while the frame-building cell above is configured for CT_word -
# confirm this is the file you intend to inspect.
# Load the image
img_path = "/home/zhongyi/Desktop/SAM2_SEG/notebooks/videos/CT_TT/data_in_jpg_2class/label_in_png_to_sam2_1seed/s0010/00000.png"
# img_path = '/home/zhongyi/Desktop/SAM2_SEG/notebooks/videos/CT_word/data_in_jpg_2class/label_in_png/word_0002_L4L5/word_0002_L4L5_slice_000_seg.png'
img = Image.open(img_path)

# Convert to numpy array
img_array = np.array(img)

# Get unique values (np.unique returns them sorted in ascending order)
unique_values = np.unique(img_array)

# Print result: the number of distinct values followed by the values themselves
print(f"Unique values ({len(unique_values)}):", unique_values)


Unique values (3): [  0 100 200]
