In [42]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import os

In [23]:
def delete_files(directory):
    # 获取目录下的所有文件和子目录
    files = os.listdir(directory)
    
    for file in files:
        file_path = os.path.join(directory, file)
        
        # 判断是否为文件
        if os.path.isfile(file_path):
            # 删除文件
            os.remove(file_path)
        elif os.path.isdir(file_path):
            # 如果是目录，则递归调用函数删除子目录中的文件
            delete_files(file_path)
            # 删除空目录
            os.rmdir(file_path)

In [25]:
def clean_up():
    delete_files('.\\preprocess\\train\\positive')
    delete_files('.\\preprocess\\train\\negative')
    delete_files('.\\preprocess\\val\\positive')
    delete_files('.\\preprocess\\val\\negative')

In [31]:
clean_up()

In [44]:
class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.file_list = []
        self.labels = []

        positive_dir = os.path.join(data_dir, 'positive')
        negative_dir = os.path.join(data_dir, 'negative')

        positive_files = os.listdir(positive_dir)
        negative_files = os.listdir(negative_dir)

        self.file_list.extend(positive_files)
        self.file_list.extend(negative_files)

        self.labels.extend(['positive'] * len(positive_files))
        self.labels.extend(['negative'] * len(negative_files))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_name = self.file_list[idx]
        file_path = os.path.join(self.data_dir, self.labels[idx], file_name)
        data = np.load(file_path)['matrix1']
        label = self.labels[idx]

        if self.transform:
            if 'train' in file_path:
                data = self.transform['train'](data)
            elif 'val' in file_path:
                data = self.transform['val'](data)

        return data, label

# 自定义转换函数，将numpy数组转换为PyTorch张量
def numpy_to_tensor(sample):
    return torch.from_numpy(sample)

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        numpy_to_tensor,
        transforms.Normalize((0.5,), (0.5,))
    ]),
    'val': transforms.Compose([
        transforms.Resize(224),
        numpy_to_tensor,
        transforms.Normalize((0.5,), (0.2,))
    ]),
}


In [50]:
data_dir = '.\\preprocess'

image_datasets = {x: CustomDataset(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

class_names = ['positive', 'negative']  # 如果您的数据集有类别标签，可以在这里添加类别名

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [51]:
print(class_names)

['positive', 'negative']


In [52]:
def imshow(inp, title=None):
    """Display image for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloader['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

TypeError: 'DataLoader' object is not subscriptable