# In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# In [None]:
class CustomDataset(Dataset):
    """Map-style dataset wrapping an indexable collection of samples.

    Args:
        data: Any object supporting ``len()`` and integer indexing; each
            ``data[idx]`` is one sample.
        transform: Optional callable applied to every sample on access.
        classes: Optional sequence of class names. It is stored unchanged as
            ``self.classes`` and is ``None`` when not given.
    """

    def __init__(self, data, transform=None, classes=None):
        self.data = data
        self.transform = transform
        self.classes = classes

    def __len__(self):
        """Return the number of samples in ``data``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``data[idx]``, passed through ``transform`` when one is set."""
        sample = self.data[idx]
        # Compare against None explicitly: a transform object that happens to
        # be falsy (e.g. something defining __len__ == 0) is still a transform.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

# Custom transform: convert a numpy array into a PyTorch tensor.
def numpy_to_tensor(sample):
    """Convert a numpy array to a ``torch.Tensor``.

    No dtype conversion, value scaling, or axis reordering is performed.
    The result shares memory with *sample* (``torch.from_numpy`` semantics),
    except for arrays with negative strides, which are copied first.

    Args:
        sample: A numpy ``ndarray``. A ``torch.Tensor`` is returned unchanged.

    Returns:
        torch.Tensor with the same shape as *sample* and the corresponding
        torch dtype.
    """
    if isinstance(sample, torch.Tensor):
        return sample
    # torch.from_numpy rejects arrays with negative strides (e.g. views made
    # by np.flip or arr[::-1]); a .copy() yields a C-contiguous array.
    if any(stride < 0 for stride in sample.strides):
        sample = sample.copy()
    return torch.from_numpy(sample)

# Preprocessing pipelines per split.
# The tensor conversion must come FIRST: torchvision's geometric transforms
# (RandomResizedCrop, Resize, CenterCrop, RandomHorizontalFlip) accept only
# PIL Images or tensors laid out as [..., C, H, W], so a raw numpy array
# cannot be fed to them directly. ToTensor accepts HxWxC / HxW numpy arrays
# and PIL Images, moves channels first and, for uint8 input, scales values
# to [0.0, 1.0]. It is also picklable, which matters with num_workers > 0.
# NOTE(review): assumes samples are HxWxC (or HxW) numpy arrays or PIL
# Images, as the original conversion helper implies -- confirm against the data.
data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]),
    'val': transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]),
}

# Create the custom dataset objects, each with its split-specific pipeline.
# NOTE(review): train_data / val_data are not defined in this part of the
# file -- presumably built in an earlier notebook cell; confirm.
train_dataset = CustomDataset(train_data, transform=data_transforms['train'])
val_dataset = CustomDataset(val_data, transform=data_transforms['val'])

# Create the data loaders. Only the training split is shuffled: evaluation
# gains nothing from random order, and a fixed order keeps validation
# output deterministic between runs.
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=4)

dataset_sizes = {'train': len(train_dataset), 'val': len(val_dataset)}
# Requires the dataset object to expose a `classes` attribute.
class_names = train_dataset.classes

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")