In [3]:
import os
from torch.utils.data import Dataset

class CustomPatientDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.samples = self._load_samples()

    def _load_samples(self):
        samples = []
        for filename in os.listdir(self.root):
            if filename.endswith('.png'):
                patient_id = filename.split('_')[1]  # Extract patient ID from filename
                image_path = os.path.join(self.root, filename)
                samples.append((patient_id, image_path))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        patient_id, image_path = self.samples[idx]
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, patient_id


In [4]:
from torch.utils.data import random_split
from torchvision import transforms
from PIL import Image

# 数据预处理和增强
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 创建数据集实例
dataset = CustomPatientDataset(root='/local/data1/honzh073/data/8bit_image', transform=transform)

# 划分数据集
total_size = len(dataset)
train_size = int(0.8 * total_size)
val_size = int(0.1 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=True)


ValueError: num_samples should be a positive integer value, but got num_samples=0