In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import os

class CustomDataset(Dataset):
    """Dataset that serves one array per file found in a directory.

    Every regular file directly inside ``data_dir`` is one sample. A sample
    is read with ``np.load`` and the array stored under the ``'arr_0'`` key
    (the name ``np.savez`` assigns to its first positional array) is
    returned.

    Args:
        data_dir: Directory containing the sample files.
        transform: Optional callable applied to each loaded array before it
            is returned.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # os.listdir yields entries in arbitrary order, so sort to make an
        # index refer to the same file on every run and in every worker.
        # Subdirectories (e.g. Jupyter's .ipynb_checkpoints) are skipped
        # because np.load cannot open them.
        self.file_list = sorted(
            name
            for name in os.listdir(data_dir)
            if os.path.isfile(os.path.join(data_dir, name))
        )

    def __len__(self):
        """Return the number of sample files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load the sample at position ``idx`` and apply the transform.

        Raises:
            IndexError: If ``idx`` is outside the range of ``file_list``.
            KeyError: If the archive has no ``'arr_0'`` entry.
        """
        file_path = os.path.join(self.data_dir, self.file_list[idx])
        # The archive object keeps a file handle open; the context manager
        # closes it once the array has been read into memory.
        with np.load(file_path) as archive:
            data = archive['arr_0']
        # Compare against None so a callable that happens to be falsy
        # (e.g. an empty container-like transform) is still applied.
        if self.transform is not None:
            data = self.transform(data)
        return data

# Custom transform: convert a NumPy array into a PyTorch tensor.
def numpy_to_tensor(sample):
    """Convert a NumPy array to a PyTorch tensor.

    The tensor shares memory with ``sample`` whenever ``torch.from_numpy``
    can wrap it directly. Arrays with a negative stride (reversed or flipped
    views) are copied first, because ``torch.from_numpy`` rejects them.

    Args:
        sample: The ``numpy.ndarray`` to convert.

    Returns:
        A ``torch.Tensor`` holding the same values as ``sample``.
    """
    # Only copy when it is required, so ordinary arrays keep the zero-copy,
    # shared-memory behaviour of torch.from_numpy.
    if isinstance(sample, np.ndarray) and any(s < 0 for s in sample.strides):
        sample = sample.copy()
    return torch.from_numpy(sample)

# Placeholder: replace with the directory that holds the sample files.
data_dir = 'your_data_directory'

# The torchvision crop/resize/flip transforms operate on PIL Images or
# tensors, not on NumPy arrays, so the array -> tensor conversion has to be
# the first step of each pipeline rather than the last.
# NOTE(review): on tensors these transforms treat the last two dimensions as
# (H, W); this assumes the stored arrays are laid out (..., H, W) — confirm.
data_transforms = {
    'train': transforms.Compose([
        numpy_to_tensor,
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]),
    'val': transforms.Compose([
        numpy_to_tensor,
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]),
}

# Build the training dataset and wrap it in a shuffling, multi-worker loader.
dataset = CustomDataset(data_dir, transform=data_transforms['train'])
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

dataset_size = len(dataset)
class_names = []  # Add class names here if the dataset has class labels.

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")