In [1]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

def unpickle(path: str) -> dict:
    """
    Read one pickled file from disk and return the deserialized object.

    The file is opened in binary mode and loaded with ``encoding="bytes"``,
    which makes 8-bit strings written by Python 2 come back as ``bytes``
    (the CIFAR10 class in this file indexes the result with keys such as
    ``b"data"`` and ``b"labels"``).

    NOTE: unpickling can execute arbitrary code, so only point this at
    files from a trusted source.

    Args:
        path: Location of the pickle file.

    Returns:
        Whatever object the file contains (annotated as a dict).
    """
    with open(path, mode="rb") as stream:
        contents = pickle.load(stream, encoding="bytes")
    return contents

class CIFAR10(Dataset):
    """
    Map-style dataset over the python-pickled CIFAR-10 batches.

    Reads ``<root>/cifar-10-batches-py/``: ``data_batch_1`` .. ``data_batch_5``
    for ``split="train"`` and ``test_batch`` for ``split="test"``. All pixel
    rows are held in memory as one uint8 array and converted per sample.

    Indexing yields ``(image, label)``. ``image`` is a float32 tensor of shape
    (3, 32, 32) scaled to [0, 1] before the optional transform / normalization
    are applied; ``label`` is a 0-dim int64 tensor.
    """

    def __init__(self, root: str, split: str = "train", transform=None, normalize: bool = False):
        """
        Load one split into memory.

        Args:
            root: Directory that contains the ``cifar-10-batches-py`` folder.
            split: ``"train"`` or ``"test"``; case and surrounding whitespace
                are ignored.
            transform: Optional callable applied to the [0, 1] image tensor.
            normalize: If True, subtract ``self.mean`` and divide by
                ``self.std`` after the transform.

        Raises:
            FileNotFoundError: The ``cifar-10-batches-py`` folder is missing.
            ValueError: ``split`` is neither ``"train"`` nor ``"test"``.
        """
        self.root = root
        self.split = split.lower().strip()
        self.transform = transform
        self.normalize = normalize

        cifar_dir = os.path.join(root, "cifar-10-batches-py")
        if not os.path.isdir(cifar_dir):
            raise FileNotFoundError(f"Folder not found: {cifar_dir}")

        # Reject unknown splits before any batch file is opened.
        if self.split not in ("train", "test"):
            raise ValueError("split must be 'train' or 'test'")

        if self.split == "test":
            held_out = unpickle(os.path.join(cifar_dir, "test_batch"))
            self.x = held_out[b"data"].astype(np.uint8)
            self.y = np.array(held_out[b"labels"], dtype=np.int64)
        else:
            batch_paths = [os.path.join(cifar_dir, f"data_batch_{i}") for i in range(1, 6)]
            pixel_chunks = []
            targets = []
            for batch_path in batch_paths:
                chunk = unpickle(batch_path)
                pixel_chunks.append(chunk[b"data"])
                targets.extend(chunk[b"labels"])
            # Stack the five batches row-wise into a single (N, 3072) uint8 array.
            self.x = np.vstack(pixel_chunks).astype(np.uint8)
            self.y = np.array(targets, dtype=np.int64)

        # TODO: cite the source of these per-channel statistics.
        # Shaped (3, 1, 1) so they broadcast over height and width;
        # only used when normalize=True.
        self.mean = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
        self.std = torch.tensor([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)

    def __len__(self):
        """Number of samples, i.e. the number of stored labels."""
        return len(self.y)

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for sample ``idx`` as described on the class."""
        pixels = torch.from_numpy(self.x[idx])          # flat uint8 row of 3072 values
        img = pixels.view(3, 32, 32).float() / 255.0    # channels-first, scaled to [0, 1]

        transform = self.transform
        if transform is not None:
            img = transform(img)

        # Normalization runs after the user transform.
        if self.normalize:
            img = (img - self.mean) / self.std

        target = torch.tensor(int(self.y[idx]), dtype=torch.long)
        return img, target

def get_cifar10_dataloaders(root: str, batch_size: int = 128, num_workers: int = 2,
                           normalize: bool = False, train_transform=None, test_transform=None,
                           pin_memory: bool = True):
    """
    Build the CIFAR-10 train and test DataLoaders.

    Args:
        root: Directory that contains the ``cifar-10-batches-py`` folder.
        batch_size: Batch size of the train loader. The test loader uses
            twice this value.
        num_workers: Number of worker processes per loader (0 loads in the
            main process).
        normalize: Forwarded to both CIFAR10 datasets.
        train_transform: Optional callable applied to each training image.
        test_transform: Optional callable applied to each test image.
        pin_memory: Forwarded to both DataLoaders. Defaults to True, the
            value that used to be hard-coded; pass False when batches are
            consumed on the CPU, where page-locked memory brings no benefit.

    Returns:
        ``(train_loader, test_loader)``. The train loader reshuffles every
        epoch; the test loader iterates in dataset order.

    Raises:
        FileNotFoundError / ValueError: Propagated from the CIFAR10 constructor.
    """
    train_ds = CIFAR10(root=root, split="train", transform=train_transform, normalize=normalize)
    test_ds = CIFAR10(root=root, split="test", transform=test_transform, normalize=normalize)

    # Settings shared by both loaders; only batch size and shuffling differ.
    shared = dict(num_workers=num_workers, pin_memory=pin_memory)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, **shared)
    test_loader = DataLoader(test_ds, batch_size=batch_size * 2, shuffle=False, **shared)
    return train_loader, test_loader

In [4]:
# Quick smoke test: build the loaders (the test loader is discarded) and
# pull a single training batch to check shapes and labels.
train_loader, _ = get_cifar10_dataloaders(
    root="./data/",
    batch_size=128,
    num_workers=0,      # important on Windows (presumably avoids worker-process issues in notebooks; confirm)
    normalize=False,
    train_transform=None,
    test_transform=None,
)

# The train loader shuffles, so the labels printed here differ between runs.
x, y = next(iter(train_loader))
print(x.shape, y[:10])


torch.Size([128, 3, 32, 32]) tensor([3, 9, 8, 1, 4, 4, 0, 8, 1, 4])


In [5]:
# usage_example_cifar_loader.py
# Usage example: build both loaders, inspect one batch from each, and report
# how many batches each loader yields. Requires get_cifar10_dataloaders to be
# defined already (it is not imported here).

import torch

# No augmentation, no extra transforms
train_loader, test_loader = get_cifar10_dataloaders(
    root="./data/",          # folder that contains ./cifar-10-batches-py/
    batch_size=128,
    num_workers=0,
    normalize=False,   # keep raw [0,1] tensors
    train_transform=None,
    test_transform=None,
)

# Show one train batch (the train loader shuffles, so labels vary between runs)
x, y = next(iter(train_loader))
print("TRAIN BATCH")
print("x shape:", x.shape)          # (B, 3, 32, 32)
print("x dtype:", x.dtype)          # torch.float32
print("x min/max:", float(x.min()), float(x.max()))
print("y shape:", y.shape)          # (B,)
print("y dtype:", y.dtype)          # torch.int64
print("first 10 labels:", y[:10].tolist())

# Show one test batch (the test loader uses twice the train batch size and does not shuffle)
x2, y2 = next(iter(test_loader))
print("\nTEST BATCH")
print("x shape:", x2.shape)
print("x min/max:", float(x2.min()), float(x2.max()))
print("first 10 labels:", y2[:10].tolist())

# Show loader lengths: len() of a DataLoader is the number of batches, not the number of samples
print("\nDATASET SIZES")
print("train steps per epoch:", len(train_loader))
print("test steps:", len(test_loader))


TRAIN BATCH
x shape: torch.Size([128, 3, 32, 32])
x dtype: torch.float32
x min/max: 0.0 1.0
y shape: torch.Size([128])
y dtype: torch.int64
first 10 labels: [9, 1, 6, 4, 1, 9, 2, 2, 7, 5]

TEST BATCH
x shape: torch.Size([256, 3, 32, 32])
x min/max: 0.0 1.0
first 10 labels: [3, 8, 8, 0, 6, 6, 1, 6, 3, 1]

DATASET SIZES
train steps per epoch: 391
test steps: 40


In [None]:
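# One-time environment setup, intentionally left commented out.
# These commands replace the installed torch / torchvision / torchaudio with the
# CUDA 11.8 wheels. To use them: uncomment both lines, run the cell once, then
# restart the kernel so the new build is picked up.
# !python -m pip uninstall -y torch torchvision torchaudio
# !python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118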


Found existing installation: torch 2.10.0
Uninstalling torch-2.10.0:
  Successfully uninstalled torch-2.10.0
Found existing installation: torchvision 0.25.0
Uninstalling torchvision-0.25.0:
  Successfully uninstalled torchvision-0.25.0


You can safely remove it manually.
You can safely remove it manually.


Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp313-cp313-win_amd64.whl.metadata (27 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.22.1%2Bcu118-cp313-cp313-win_amd64.whl.metadata (6.3 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.7.1%2Bcu118-cp313-cp313-win_amd64.whl.metadata (6.8 kB)
Downloading https://download.pytorch.org/whl/cu118/torch-2.7.1%2Bcu118-cp313-cp313-win_amd64.whl (2817.2 MB)
   ---------------------------------------- 0.0/2.8 GB ? eta -:--:--
   ---------------------------------------- 0.0/2.8 GB 27.0 MB/s eta 0:01:45
   ---------------------------------------- 0.0/2.8 GB 27.7 MB/s eta 0:01:42
   ---------------------------------------- 0.0/2.8 GB 26.2 MB/s eta 0:01:47
   ---------------------------------------- 0.0/2.8 GB 24.8 MB/s eta 0:01:53
   ----------------------------