In [None]:
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
def delete_files(directory):
    """Recursively empty ``directory`` while keeping ``directory`` itself.

    Real subdirectories are emptied recursively and then removed with
    ``os.rmdir``. Every other entry -- regular files, symbolic links (whether
    they point at a file, a directory, or nothing) and other special files --
    is removed with ``os.remove``. Symbolic links are never followed, so a
    link to a directory elsewhere deletes only the link, not the target's
    contents.

    Args:
        directory: Path of the directory to empty.

    Raises:
        FileNotFoundError: If ``directory`` does not exist.
        OSError: If an entry cannot be removed.
    """
    # Snapshot the listing first so the directory is not modified while an
    # os.scandir iterator over it is still open.
    with os.scandir(directory) as it:
        entries = list(it)

    for entry in entries:
        # follow_symlinks=False: os.path.isdir/isfile follow links, which would
        # recurse into (and wipe) a linked directory's target and skip dangling
        # links, leaving the parent non-empty for os.rmdir.
        if entry.is_dir(follow_symlinks=False):
            delete_files(entry.path)
            os.rmdir(entry.path)
        else:
            os.remove(entry.path)

In [None]:
def clean_up(root='.'):
    """Empty the generated-data directories under ``root``.

    Empties ``preprocess/{train,val}/{positive,negative}`` followed by
    ``pic_positive`` and ``pic_negative`` (via ``delete_files``, which keeps
    each directory itself). Directories that do not exist are skipped instead
    of aborting the clean-up halfway through.

    Args:
        root: Base directory the paths are resolved against. Defaults to the
            current working directory.
    """
    # os.path.join instead of '.\\a\\b' literals: backslash-separated literals
    # only resolve on Windows (elsewhere they are single odd file names).
    targets = [
        os.path.join(root, 'preprocess', split, label)
        for split in ('train', 'val')
        for label in ('positive', 'negative')
    ]
    targets += [os.path.join(root, 'pic_positive'), os.path.join(root, 'pic_negative')]

    for target in targets:
        if os.path.isdir(target):
            delete_files(target)

In [None]:
#clean_up()

In [None]:
class CustomDataset(Dataset):
    """Dataset of ``.npz`` matrices laid out as ``data_dir/<label>/<file>``.

    Each immediate subdirectory of ``data_dir`` is a class label and every
    regular file inside it is one sample. A sample is the array stored under
    the ``'matrix1'`` key of the file, optionally passed through
    ``transform``; the label is returned as the subdirectory name (a string).

    Attributes:
        data_dir: Root directory passed to the constructor.
        transform: Optional callable applied to each loaded array.
        file_list: File names (not full paths), one per sample.
        labels: Label (subdirectory name) for each entry of ``file_list``.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.file_list = []
        self.labels = []

        # Sort both listings: os.listdir order is arbitrary and filesystem
        # dependent, so without sorting the index -> sample mapping could
        # differ between machines or runs.
        for label in sorted(os.listdir(data_dir)):
            label_dir = os.path.join(data_dir, label)
            if not os.path.isdir(label_dir):
                continue
            files = sorted(
                name for name in os.listdir(label_dir)
                if os.path.isfile(os.path.join(label_dir, name))
            )
            self.file_list.extend(files)
            self.labels.extend([label] * len(files))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(data, label)`` for sample ``idx``.

        ``data`` is the ``'matrix1'`` array (after ``transform`` if one was
        given); ``label`` is the sample's subdirectory name.
        """
        label = self.labels[idx]
        file_path = os.path.join(self.data_dir, label, self.file_list[idx])
        # np.load on an .npz returns an NpzFile that keeps the file open;
        # the context manager closes it once the array has been read.
        with np.load(file_path) as npz:
            data = npz['matrix1']

        if self.transform:
            data = self.transform(data)

        return data, label

def numpy_to_pil(sample):
    """Convert a single-channel array to a PIL image without corrupting it.

    Singleton dimensions are squeezed away first (e.g. ``(1, H, W)`` ->
    ``(H, W)``). 2-D ``uint8`` data becomes an 8-bit grayscale ``'L'`` image.
    Any other dtype is converted to ``float32`` and stored as a 32-bit float
    ``'F'`` image, so the values themselves are preserved.

    Args:
        sample: Array-like image data, assumed single-channel after squeezing.

    Returns:
        A ``PIL.Image.Image``.
    """
    arr = np.asarray(sample).squeeze()
    if arr.dtype == np.uint8:
        # Pillow infers mode 'L' for 2-D uint8 arrays.
        return Image.fromarray(arr)
    # Forcing mode='L' on non-uint8 data makes Pillow reinterpret the raw
    # buffer bytes as uint8 pixels, producing a garbage image instead of an
    # error. float32 maps to mode 'F' instead.
    # NOTE(review): values are passed through unscaled (ToTensor does not
    # divide 'F' images by 255) -- confirm the expected value range.
    return Image.fromarray(arr.astype(np.float32))

# Preprocessing pipeline per dataset split, keyed by split name.
# Each pipeline first turns the raw array into a PIL image (numpy_to_pil) and
# ends with ToTensor followed by Normalize(mean=0.5, std=0.5), i.e.
# x -> (x - 0.5) / 0.5.
data_transforms = dict(
    # Training: random-scale crop resized to 224x224 plus random horizontal
    # flips for augmentation.
    train=transforms.Compose([
        transforms.Lambda(numpy_to_pil),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
    # Validation: deterministic resize of the shorter side to 224.
    # NOTE(review): unlike train, non-square inputs will not come out 224x224;
    # consider adding CenterCrop(224) if the model needs a fixed size.
    val=transforms.Compose([
        transforms.Lambda(numpy_to_pil),
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)

# Root folder holding the 'train' and 'val' splits, each with one
# subdirectory per class. os.path.join keeps it portable: a '.\\preprocess'
# literal only resolves on Windows.
data_dir = os.path.join('.', 'preprocess')

image_datasets = {x: CustomDataset(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=4,
                             shuffle=True, num_workers=0)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Class names for display. NOTE(review): CustomDataset yields the label as the
# subdirectory name string; nothing here maps these names to integer indices.
class_names = ['positive', 'negative']

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Exploratory check: display the training DataLoader's repr as the cell output.
dataloaders['train']

In [None]:
# Create an explicit iterator over the training loader so a later cell can
# pull batches from it with next(); printing shows the iterator object.
iter_dataloader = iter(dataloaders['train'])
print(iter_dataloader)

In [None]:
# Pull (and display) the next batch. Depends on the iter_dataloader cell having
# run; each re-run advances the iterator and raises StopIteration once exhausted.
next(iter_dataloader)

In [None]:
# Show the configured class names.
print(class_names)

In [None]:
def imshow(inp, title=None, vmin=-1.0, vmax=1.0):
    """Display a single-channel image tensor as a grayscale figure.

    Args:
        inp: Image tensor. Singleton dimensions are squeezed away, so a
            ``(1, H, W)`` or ``(H, W)`` tensor is expected.
        title: Optional title passed to ``plt.title``.
        vmin, vmax: Colour-map limits. The defaults match the [-1, 1] range
            that ``Normalize((0.5,), (0.5,))`` produces from ToTensor output
            in [0, 1]; a fixed [-0.5, 0.5] would clip half of that range.
            Pass ``None`` to let matplotlib autoscale.
    """
    # detach()/cpu() so tensors that require grad or live on a GPU can still
    # be converted to NumPy.
    img = inp.detach().cpu().squeeze().numpy()
    plt.imshow(img, cmap='gray', vmin=vmin, vmax=vmax)  # grayscale colour map
    plt.colorbar()
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data and show it as a single tiled grid.
inputs, classes = next(iter(dataloaders['train']))
print(inputs.shape, classes)
# `inputs` is already batched as (batch, channels, H, W) by the DataLoader, so
# pass it to make_grid directly; adding unsqueeze(0) makes it 5-D, and
# make_grid then just squeezes it back and returns the untiled batch.
# Single-channel images are repeated to 3 channels in the grid.
out = torchvision.utils.make_grid(inputs)
# For single-channel input all grid channels are identical; show the first.
imshow(out[0], title=list(classes))

In [None]:
# Load one sample matrix for inspection. np.load needs a file path, not a
# directory, so take the first file (sorted for a deterministic pick) from
# the training 'positive' folder; os.path.join keeps the path portable.
sample_dir = os.path.join('.', 'preprocess', 'train', 'positive')
sample_files = sorted(
    name for name in os.listdir(sample_dir)
    if os.path.isfile(os.path.join(sample_dir, name))
)
if not sample_files:
    raise FileNotFoundError(f'no sample files found in {sample_dir!r}')
# The context manager closes the .npz file handle once the array is read.
with np.load(os.path.join(sample_dir, sample_files[0])) as npz:
    data = npz['matrix1']