In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.datasets import make_moons

In [2]:
import os
from torchvision.io import read_image

class MnistDataset(Dataset):
    """Map-style dataset over a directory of PNG digit images.

    The class label of every image is taken from the first character of its
    file name, so files are expected to be named ``<digit>....png`` (anything
    after the first character is not inspected).

    Args:
        path: Directory that contains the ``.png`` files.
        transform: Optional callable applied to the tensor returned by
            ``read_image`` before ``__getitem__`` hands it back.

    Raises:
        ValueError: If a ``.png`` file name does not start with a digit 0-9.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transform = transform
        self.image_files = []  # file names (not full paths), sorted
        self.labels = []       # integer class per entry of image_files

        # os.listdir() yields entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        for filename in sorted(os.listdir(path)):
            if not filename.endswith(".png"):
                continue

            prefix = filename[0]
            if prefix not in "0123456789":
                raise ValueError(
                    f"Cannot derive a digit label from file name {filename!r} "
                    f"in {path!r}: expected it to start with 0-9."
                )

            self.image_files.append(filename)
            # Store an int (not the raw character) so that the default
            # DataLoader collation produces an integer label tensor.
            self.labels.append(int(prefix))

    def __len__(self):
        """Return the number of PNG images found in ``path``."""
        return len(self.image_files)

    def __getitem__(self, index):
        """Load one sample from disk.

        Args:
            index: Position into the sorted list of image files.

        Returns:
            Tuple ``(image, label)``: the (optionally transformed) image read
            with ``read_image`` and its integer class label.
        """
        image_path = os.path.join(self.path, self.image_files[index])
        image = read_image(image_path)
        label = self.labels[index]

        if self.transform:
            image = self.transform(image)

        return image, label



In [4]:
from torchvision.transforms import v2

# Preprocessing applied to every image returned by the dataset:
#   1. wrap the tensor as a torchvision Image,
#   2. cast to float32 (scale=True also rescales the pixel values),
#   3. normalize the single channel as (x - 0.5) / 0.5.
transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=(0.5,), std=(0.5,)),
])

In [5]:
# Training split: indexes the PNG files under the given directory and applies
# the preprocessing pipeline to each image when it is fetched.
train_data = MnistDataset(
    path="../data/mnist/train",
    transform=transforms,
)

In [17]:
# Sanity check: fetch one sample and show the image's type, shape and dtype
# followed by its label.
image, label = train_data[4783]

print('img:')
for detail in (type(image), image.shape, image.dtype):
    print(f"     {detail}")
print("label:")
print(f"     {label}")

img:
     <class 'torchvision.tv_tensors._image.Image'>
     torch.Size([1, 28, 28])
     torch.float32
label:
     0
