In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

In [None]:
class MNISTDataset(Dataset):
    """Image-folder dataset whose integer labels are encoded in the file names.

    Images are read from ``<root_dir>/<split>/``. Every file whose name ends in
    ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive) is one sample. Its label
    is the part of the file name (without extension) before the first ``_``,
    e.g. ``"7_00123.png"`` -> 7 and ``"7.png"`` -> 7.

    Samples are kept in sorted file-name order so that indexing is
    reproducible regardless of the order the filesystem lists files in.

    Args:
        root_dir: Directory containing one sub-directory per split.
        split: Name of the sub-directory to load (default ``"train"``).
        transform: Optional callable applied to each grayscale PIL image in
            ``__getitem__`` (e.g. a torchvision transform).

    Raises:
        FileNotFoundError: If ``<root_dir>/<split>`` does not exist.
        ValueError: If an image file name does not start with an integer label.
    """

    # Accepted image extensions; compared against the lower-cased file name.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, root_dir, split="train", transform=None):
        self.root_dir = root_dir
        self.split = split
        self.transform = transform

        self.data_path = os.path.join(self.root_dir, self.split)

        if not os.path.exists(self.data_path):
            raise FileNotFoundError(f"{self.data_path} bulunamadı.")

        # List of (image_path, label) pairs, filled by _load_data().
        self.samples = []
        self._load_data()

    def _load_data(self):
        """Populate ``self.samples`` with ``(image_path, label)`` pairs.

        File names are sorted because ``os.listdir`` returns them in an
        arbitrary, platform-dependent order.
        """
        for img_name in sorted(os.listdir(self.data_path)):
            if not img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                continue
            img_path = os.path.join(self.data_path, img_name)
            self.samples.append((img_path, self._parse_label(img_name)))

    @staticmethod
    def _parse_label(img_name):
        """Return the integer label written before the first ``_`` of *img_name*.

        Raises:
            ValueError: If that prefix is not an integer.
        """
        stem = os.path.splitext(img_name)[0]
        label_text = stem.split("_", 1)[0]
        try:
            return int(label_text)
        except ValueError as err:
            raise ValueError(
                f"Cannot parse an integer label from file name {img_name!r}; "
                "expected '<label>_<anything>.<ext>' or '<label>.<ext>'."
            ) from err

    def __len__(self):
        """Return the number of image samples found in the split directory."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*.

        The image is converted to single-channel grayscale ("L" mode) and then
        passed through ``self.transform`` when one was given.
        """
        img_path, label = self.samples[idx]
        # The context manager closes the file; convert() returns a converted
        # copy that remains usable after the source image is closed.
        with Image.open(img_path) as img:
            image = img.convert("L")

        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Preprocessing pipeline: only ToTensor() is applied to each PIL image —
# no normalization or augmentation happens here.
to_tensor = transforms.ToTensor()
transform = transforms.Compose([to_tensor])

In [None]:
# Root folder holding the split sub-folders ("." = current working
# directory; use it when train/ and test/ sit in the same folder).
DATA_PATH = "."

train_dataset = MNISTDataset(
    root_dir=DATA_PATH, split="train", transform=transform
)

n_train_samples = len(train_dataset)
print("Toplam örnek:", n_train_samples)

In [None]:
# Sanity-check one sample: the transformed image's shape and its parsed label.
first_sample = train_dataset[0]
image, label = first_sample

print("Image shape:", image.shape)
print("Label:", label)

In [None]:
# Loader settings. The shuffle order is drawn from a seeded generator so the
# batches (and anything derived from the first batch) are reproducible on
# Restart & Run All; change SEED to get a different ordering.
SEED = 42
BATCH_SIZE = 32

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,  # load data in the main process (no worker subprocesses)
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
# Pull a single batch to confirm the DataLoader collates samples as expected.
first_batch = next(iter(train_loader))
images, labels = first_batch

print("Batch image shape:", images.shape)
print("Batch label shape:", labels.shape)

In [None]:
# Show up to GRID_ROWS * GRID_COLS images from the batch, each titled with its
# label. The count is clamped to the batch size so a short batch (e.g. a tiny
# dataset) does not raise an IndexError; unused grid cells stay blank.
GRID_ROWS, GRID_COLS = 4, 4
n_shown = min(GRID_ROWS * GRID_COLS, len(images))

# squeeze=False keeps `axes` a 2-D array for every grid size (including 1x1).
fig, axes = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(6, 6), squeeze=False)
for idx, ax in enumerate(axes.flat):
    ax.axis("off")  # hide ticks/frame; also blanks cells with no image
    if idx >= n_shown:
        continue
    # squeeze() drops size-1 axes (the single grayscale channel) so imshow
    # receives a 2-D array.
    ax.imshow(images[idx].squeeze().numpy(), cmap="gray")
    ax.set_title(str(labels[idx].item()))

fig.tight_layout()
plt.show()