In [1]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import shutil

In [2]:

# ------------------ 1. Dataset definition ------------------
class FundusOCTDataset(Dataset):
    """Dataset of fundus images listed in a CSV file.

    The CSV must contain a ``CF_path`` column holding each fundus image's file
    name relative to ``cf_root_dir``.

    Args:
        csv_file: Path of the CSV file, loaded with ``pandas.read_csv``.
        cf_root_dir: Directory that the ``CF_path`` entries are joined onto.
        oct_root_dir: Stored on the instance; not read by this class.
        transform: Optional callable applied to the loaded RGB image.
        is_train: Stored on the instance; not read by this class.

    Each item is a ``(image, cf_filename)`` pair.
    """

    def __init__(self, csv_file, cf_root_dir, oct_root_dir=None, transform=None, is_train=True):
        self.data = pd.read_csv(csv_file)
        self.cf_root_dir, self.oct_root_dir = cf_root_dir, oct_root_dir
        self.transform, self.is_train = transform, is_train

    def __len__(self):
        """Number of rows in the CSV."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Load the fundus image of row ``idx``.

        Returns:
            ``(image, cf_filename)`` where ``image`` is the RGB image after the
            optional transform and ``cf_filename`` is the row's ``CF_path`` value.

        Raises:
            FileNotFoundError: If the image path does not exist on disk.
        """
        record = self.data.iloc[idx]
        cf_filename = record['CF_path']
        image_path = os.path.join(self.cf_root_dir, cf_filename)

        if os.path.exists(image_path):
            image = Image.open(image_path).convert("RGB")
            # Apply the transform only when one was supplied.
            image = self.transform(image) if self.transform else image
            return image, cf_filename

        raise FileNotFoundError(f"Eye fundus image not found: {image_path}")



In [3]:
# ------------------ 2. Define the image transform ------------------
# Channel statistics used to train the ImageNet-pretrained torchvision models.
# The feature extractor built later in this notebook is an ImageNet-pretrained
# ResNet-18, so its inputs should be normalized the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize((768, 496)),  # fixed (height, width) so every image has the same shape
    transforms.ToTensor(),          # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])



In [None]:
# ------------------ 3. Load the training and validation datasets ------------------
# Training split: fundus images, OCT folders, and the CSV that lists them.
train_cf_root = 'D:\\Project\\train\\train_CF'
train_oct_root = 'D:\\Project\\train\\OCT'
train_csv_path = 'D:\\Project\\train\\train.csv'

# Validation split: fundus images and their CSV only (no OCT root is passed).
val_cf_root = 'D:\\Project\\val\\val_CF'
val_csv_path = 'D:\\Project\\val\\val.csv'

train_dataset = FundusOCTDataset(train_csv_path, train_cf_root, train_oct_root, transform, True)
val_dataset = FundusOCTDataset(val_csv_path, val_cf_root, transform=transform, is_train=False)

# One image per batch, iterated in CSV row order: shuffle=False keeps loader
# order aligned with `train_dataset.data`, which the matching step relies on
# when it looks rows up by position.
train_loader, val_loader = (
    DataLoader(dataset, batch_size=1, shuffle=False)
    for dataset in (train_dataset, val_dataset)
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



In [None]:
# ------------------ 4. Match validation fundus images to training fundus images ------------------
# For every validation fundus image, find the training fundus image whose
# ResNet-18 feature vector has the highest cosine similarity, then copy that
# training sample's OCT images into the validation output folder.

# Number of OCT images copied per matched sample: "<folder>_0.jpg" .. "<folder>_5.jpg".
NUM_OCT_IMAGES = 6

feature_extractor = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
feature_extractor.fc = torch.nn.Identity()  # drop the classification head so the model outputs feature vectors
feature_extractor = feature_extractor.to(device)
feature_extractor.eval()


def _embed(images):
    """Return the feature vectors of a batch of image tensors as a NumPy array.

    Gradients are disabled because the model is only used for inference.
    """
    with torch.no_grad():
        return feature_extractor(images.to(device)).cpu().numpy()


# Extract features for every training fundus image.
train_feature_batches = []
train_cf_paths = []
for fundus_imgs, cf_filenames in train_loader:
    train_feature_batches.append(_embed(fundus_imgs))
    train_cf_paths.extend(cf_filenames)  # one entry per image, whatever the batch size

if not train_feature_batches:
    raise ValueError("Training loader yielded no images; nothing to match against.")
train_features = np.vstack(train_feature_batches)

# The OCT folder of a match is looked up by row position, so row i of
# `train_features` must correspond to row i of the training CSV.
assert train_cf_paths == train_dataset.data['CF_path'].tolist(), (
    "train_loader order does not match train_dataset.data; "
    "iterate the training set with shuffle=False"
)
train_oct_folders = train_dataset.data['OCT_path'].tolist()

# Match the validation fundus images.
output_dir = 'val_output/generator_OCT_images'
os.makedirs(output_dir, exist_ok=True)

missing_oct_images = []  # expected source OCT files that were not found on disk

for val_fundus_imgs, val_cf_filenames in val_loader:
    # similarities[r, c]: cosine similarity between validation image r of this
    # batch and training image c.
    similarities = cosine_similarity(_embed(val_fundus_imgs), train_features)
    best_match_indices = similarities.argmax(axis=1)

    for val_cf_filename, best_match_idx in zip(val_cf_filenames, best_match_indices):
        best_match_idx = int(best_match_idx)
        best_match_cf_path = train_cf_paths[best_match_idx]

        # Locate the OCT folder that belongs to the matched training image.
        best_match_oct_folder_name = train_oct_folders[best_match_idx]
        best_match_oct_folder_path = os.path.join(train_oct_root, best_match_oct_folder_name)

        # Create the output folder for this validation image.
        val_folder_name = os.path.splitext(val_cf_filename)[0]
        output_path = os.path.join(output_dir, val_folder_name)
        os.makedirs(output_path, exist_ok=True)

        # Copy the OCT images, renaming them after the validation image.
        for j in range(NUM_OCT_IMAGES):
            src_oct_img_path = os.path.join(best_match_oct_folder_path, f"{best_match_oct_folder_name}_{j}.jpg")
            dst_oct_img_path = os.path.join(output_path, f'{val_folder_name}_{j}.jpg')
            if os.path.isfile(src_oct_img_path):
                shutil.copy(src_oct_img_path, dst_oct_img_path)
            else:
                # Stay best-effort, but never skip a missing file silently.
                missing_oct_images.append(src_oct_img_path)
                print(f"Warning: OCT image not found for {val_cf_filename} "
                      f"(matched {best_match_cf_path}): {src_oct_img_path}")

if missing_oct_images:
    print(f"Warning: {len(missing_oct_images)} OCT image(s) were missing and were not copied.")
print("验证集OCT图像生成完成！")