In [None]:
import os

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from PIL import Image

import pandas as pd
import numpy as np

In [None]:
class MNISTVanilla(Dataset):
    """Map-style dataset yielding normalized RGB image tensors and labels.

    Each row of ``df`` supplies an ``"Image"`` path readable by
    ``PIL.Image.open`` and a ``"Label"`` value.

    Args:
        df: DataFrame with ``"Image"`` and ``"Label"`` columns. Rows are
            addressed positionally via ``iloc``, so the index labels do not
            need to be ``0..n-1``.
        mean: Normalization mean on the 0-255 pixel scale, applied to all
            three channels.
        std: Normalization standard deviation on the 0-255 pixel scale,
            applied to all three channels.

    Each item is a ``(image, label)`` pair. ``image`` is the normalized tensor
    produced by ``ToTensor`` + ``Normalize``. ``label`` is a 1-element tensor
    (shape ``(1,)``), so a DataLoader batch of labels has shape
    ``(batch_size, 1)``.
    """

    # NOTE(review): presumably the pixel mean/std of this dataset on the
    # 0-255 scale (close to, but not exactly, the common MNIST constants).
    # Confirm how these were computed.
    PIXEL_MEAN = 33.79
    PIXEL_STD = 79.17

    def __init__(self, df, mean=PIXEL_MEAN, std=PIXEL_STD):
        self.df = df

        # ToTensor rescales uint8 pixels to [0, 1], so the 0-255 statistics
        # are divided by 255 to match. The same statistic is used for every
        # channel, presumably because the source images are single-channel
        # and replicated by convert("RGB").
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize([mean / 255.0] * 3, [std / 255.0] * 3),
            ]
        )

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Load, convert and normalize the image at position ``idx``.

        Returns:
            A ``(image_tensor, label_tensor)`` tuple. ``label_tensor`` wraps
            the row's ``"Label"`` value in a 1-element tensor.
        """
        row = self.df.iloc[idx]

        # Open inside a context manager so the file handle is released
        # promptly. convert() returns a new, fully loaded image that stays
        # valid after the source file is closed.
        with Image.open(row["Image"]) as raw_img:
            rgb_img = raw_img.convert("RGB")

        return self.transform(rgb_img), torch.tensor([row["Label"]])

In [None]:
# Read the index CSV and rewrite each relative image name into a path under
# ../data/ so the Dataset can open it directly.
df = (
    pd.read_csv("../data/mnist.csv")
    .assign(Image=lambda frame: frame["Image"].map(lambda name: f"../data/{name}"))
)
df.head()

In [None]:
# Wrap the path/label frame in the Dataset (images are loaded lazily per item).
mnist_vanilla = MNISTVanilla(df=df)

In [None]:
mnist_vanilla[0][0].size()

In [None]:
mnist_vanilla[0][1].size()

In [None]:
# Fetch the sample once: each indexing call re-reads and re-transforms the
# image from disk, so indexing twice doubled the work.
sample_image, _ = mnist_vanilla[0]
sample_image.min(), sample_image.max()

In [None]:
batch_size: int = 32  # samples per DataLoader batch

In [None]:
# Cap the worker count at the available CPUs. Requesting more workers than
# cores oversubscribes the machine, and DataLoader warns about it.
num_workers = min(8, os.cpu_count() or 1)

# A seeded generator makes the shuffle order reproducible across runs.
loader = DataLoader(
    mnist_vanilla,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Pull only the first batch to sanity-check the collated tensor shapes.
for X, y in loader:
    print(X.shape, y.shape, sep="\n")
    break