In [2]:
import os
import torchvision
import torchvision.transforms as transforms

# Download the MNIST train and test splits into ./data_1224.
transform = transforms.Compose([transforms.ToTensor()])
mnist_trainset = torchvision.datasets.MNIST(root='./data_1224', train=True, download=True, transform=transform)
mnist_testset = torchvision.datasets.MNIST(root='./data_1224', train=False, download=True, transform=transform)


def _export_split(dataset, out_dir, ext='jpg'):
    """Write every sample of ``dataset`` to ``out_dir/<label>/<index>.<ext>``.

    Args:
        dataset: iterable of ``(image_tensor, label)`` pairs, such as the
            MNIST datasets above (``ToTensor`` output).
        out_dir: destination root. One subdirectory per label is created
            under it.
        ext: file extension; PIL picks the encoder from it. ``'jpg'`` keeps
            the original output but is lossy (compression artifacts). Use
            ``'png'`` to preserve pixel values exactly.
    """
    os.makedirs(out_dir, exist_ok=True)
    # Build the converter once instead of once per image.
    to_pil = transforms.ToPILImage()
    created_dirs = set()
    for idx, (image, label) in enumerate(dataset):
        label_dir = os.path.join(out_dir, str(label))
        # There are only a handful of labels, so create each directory
        # once rather than calling makedirs for every image.
        if label_dir not in created_dirs:
            os.makedirs(label_dir, exist_ok=True)
            created_dirs.add(label_dir)
        to_pil(image).save(os.path.join(label_dir, f'{idx}.{ext}'))


# Export each split into a per-class directory tree (one folder per label).
_export_split(mnist_trainset, './mnist_images/train')
_export_split(mnist_testset, './mnist_images/test')

print("All images have been saved successfully.")

All images have been saved successfully.


In [16]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import cv2 as cv


class MNISTDataset(Dataset):
    """Image-folder dataset laid out as ``root_dir/<class_name>/<image file>``.

    Each immediate subdirectory of ``root_dir`` is one class. Class names are
    sorted, so the class -> label-id mapping is the same on every platform
    and filesystem. For MNIST folders ``'0'``..``'9'`` the id equals the digit.
    Files placed directly in ``root_dir`` (outside any class folder) are
    ignored.

    Args:
        root_dir: directory containing one subdirectory per class.
        image_size: ``(width, height)`` passed to ``cv.resize`` as ``dsize``.
            Defaults to MNIST's 28x28.

    Attributes:
        file_list: full path of every image file found.
        name_list: sorted class (subdirectory) names; a name's position is
            its label id.
        id_list: label id for each entry of ``file_list``.
    """

    def __init__(self, root_dir, image_size=(28, 28)):
        self.root_dir = root_dir
        self.image_size = tuple(image_size)
        # Sort the class names: os.listdir/os.walk return entries in
        # filesystem order, which would make the label ids
        # platform-dependent. Only top-level directories are classes.
        self.name_list = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.file_list = []
        self.id_list = []
        for class_id, class_name in enumerate(self.name_list):
            class_dir = os.path.join(root_dir, class_name)
            # The label comes from the enumerated class folder, not from
            # parsing the path string, so this works with any OS separator.
            for dirpath, _, files in os.walk(class_dir):
                for file_name in sorted(files):
                    self.file_list.append(os.path.join(dirpath, file_name))
                    self.id_list.append(class_id)

    def __len__(self):
        """Number of image files discovered under ``root_dir``."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is a float32 tensor of shape ``(height, width)``. It holds
        raw 0-255 grayscale values and is not normalized. ``label`` is a
        scalar integer tensor holding the class id.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        path = self.file_list[idx]
        img = cv.imread(path, cv.IMREAD_GRAYSCALE)
        if img is None:
            # cv.imread reports failure by returning None instead of raising.
            raise OSError(f"Failed to read image: {path}")
        img = cv.resize(img, dsize=self.image_size)
        # Convert to float for later floating-point computation.
        img = torch.from_numpy(img).float()
        label = torch.tensor(self.id_list[idx])
        return img, label

# Build the training set from the exported folder tree and peek at one batch.
# A seeded generator makes the shuffled batch order (and thus the printed
# labels) reproducible across Restart & Run All.
TRAIN_SHUFFLE_SEED = 0
my_dataset_train = MNISTDataset('./mnist_images/train')
my_dataloader_train = DataLoader(
    my_dataset_train,
    batch_size=10,
    shuffle=True,
    generator=torch.Generator().manual_seed(TRAIN_SHUFFLE_SEED),
)
print("Read all train dataset")
x, y = next(iter(my_dataloader_train))
print(x.type(), x.shape, y)

Read all train dataset
torch.FloatTensor torch.Size([10, 28, 28]) tensor([0, 0, 4, 1, 8, 2, 1, 7, 8, 2])
