In [25]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms

In [26]:
def unpickle(file):
    """Deserialize one pickled CIFAR-10 batch file and return its contents.

    The file is read in binary mode with ``encoding='bytes'`` so that any
    Python 2 byte strings stored in the pickle (such as the dictionary keys)
    come back as ``bytes`` -- which is why callers index the result with
    ``b"data"`` / ``b"labels"``.

    Security note: ``pickle.load`` can execute arbitrary code while loading,
    so only call this on trusted files (e.g. the official CIFAR-10 archive).

    Args:
        file: path to the pickled batch file.

    Returns:
        The unpickled object (indexed as a mapping by the caller).
    """
    print(f"Unpickling file: {file}")
    with open(file, 'rb') as pickled_stream:
        contents = pickle.load(pickled_stream, encoding='bytes')
    print("Finished unpickling.")
    return contents

In [27]:
# Relative to the notebook's working directory; adjust if the archive lives elsewhere.
data_dir = "/".join(["..", "data", "cifar-10-python", "cifar-10-batches-py"])

In [28]:
print("Data directory:", data_dir)

Data directory: ../data/cifar-10-python/cifar-10-batches-py


### Load each of the five batches

In [29]:
# Accumulate the image rows and labels from data_batch_1 ... data_batch_5.
data_list = []
labels_list = []

for batch_num in range(1, 6):
    batch_path = os.path.join(data_dir, f"data_batch_{batch_num}")
    print(f"Loading batch file: {batch_path}")
    raw_batch = unpickle(batch_path)
    batch_rows = raw_batch[b"data"]      # expected shape (10000, 3072): one flattened image per row
    batch_labels = raw_batch[b"labels"]  # expected: list of 10000 labels, one per row
    data_list.append(batch_rows)
    labels_list.extend(batch_labels)
    print(f"Loaded batch {batch_num}: data shape {batch_rows.shape}, labels count {len(batch_labels)}")

Loading batch file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_1
Unpickling file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_1
Finished unpickling.
Loaded batch 1: data shape (10000, 3072), labels count 10000
Loading batch file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_2
Unpickling file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_2
Finished unpickling.
Loaded batch 2: data shape (10000, 3072), labels count 10000
Loading batch file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_3
Unpickling file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_3
Finished unpickling.
Loaded batch 3: data shape (10000, 3072), labels count 10000
Loading batch file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_4
Unpickling file: ../data/cifar-10-python/cifar-10-batches-py/data_batch_4
Finished unpickling.
Loaded batch 4: data shape (10000, 3072), labels count 10000
Loading batch file: ../data/cifar-10-python/cifar-10-batches-py/data

In [30]:
# Stack the per-batch arrays row-wise into a single array of flattened images.
all_data = np.concatenate(data_list, axis=0)
print(f"Concatenated data shape (before reshaping): {all_data.shape}")

Concatenated data shape (before reshaping): (50000, 3072)


In [31]:
# Interpret each 3072-value row as 3 planes of 32x32 values, giving a
# channel-first (N, 3, 32, 32) array.
all_data = np.reshape(all_data, (-1, 3, 32, 32))
print(f"Data reshaped to: {all_data.shape}")

Data reshaped to: (50000, 3, 32, 32)


In [32]:
# Pin the dtype to int64: without it NumPy < 2.0 on Windows defaults to int32,
# which collates into int32 label tensors that integer-target losses such as
# nll_loss / CrossEntropyLoss may reject.
all_labels = np.array(labels_list, dtype=np.int64)
# Invariant: exactly one label per image.
assert len(all_labels) == len(all_data), "image and label counts differ"
print("All labels shape:", all_labels.shape)

All labels shape: (50000,)


In [33]:
class CIFAR10Dataset(Dataset):
    """Map-style dataset over in-memory CIFAR-10 images and labels.

    Args:
        data: NumPy array indexed along its first axis; each item is converted
            to a float32 tensor on access. In this notebook it is the
            (N, 3, 32, 32) array built by the reshape cell.
        labels: sequence of integer class labels, one per item of ``data``.
        transform: optional callable applied to each float image tensor.

    Raises:
        ValueError: if ``data`` and ``labels`` have different lengths.
    """

    def __init__(self, data, labels, transform=None):
        # Fail early instead of silently mis-pairing images and labels.
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # from_numpy shares memory with the stored array; .float() makes a new
        # float32 tensor, so transforms never modify self.data.
        image = torch.from_numpy(self.data[idx]).float()
        if self.transform is not None:
            image = self.transform(image)
        # A plain Python int collates to an int64 tensor on every platform,
        # regardless of the dtype of the labels container.
        label = int(self.labels[idx])
        return image, label

### Define a per-channel normalization transform

In [34]:
# Per-channel statistics on the 0-255 pixel scale, matching the un-rescaled
# float tensors produced by CIFAR10Dataset (channel order presumably R, G, B).
transform = transforms.Compose(
    [
        transforms.Normalize(
            mean=(125.3, 123.0, 113.9),
            std=(63.0, 62.1, 66.7),
        )
    ]
)
print("Normalization transform created.")

Normalization transform created.


In [35]:
# Wrap the in-memory arrays; normalization is applied lazily per item in __getitem__.
dataset = CIFAR10Dataset(data=all_data, labels=all_labels, transform=transform)
print(f"CIFAR10Dataset created with length: {len(dataset)}")

CIFAR10Dataset created with length: 50000


### Split the dataset into training (45,000) and validation (5,000) sets

In [36]:
# Seed the split so the train/validation membership is identical on every
# Restart & Run All (random_split otherwise draws from the global RNG).
SPLIT_SEED = 42
train_size = 45000
val_size = len(dataset) - train_size
assert 0 < train_size < len(dataset), (
    f"train_size={train_size} must be between 0 and len(dataset)={len(dataset)}"
)
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print("Dataset split:")
print("  Training samples:", len(train_dataset))
print("  Validation samples:", len(val_dataset))

Dataset split:
  Training samples: 45000
  Validation samples: 5000


### Create the DataLoaders for the training and validation sets

In [37]:
# Only the training loader shuffles; validation batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)
print("DataLoaders created for training and validation.")

DataLoaders created for training and validation.


In [38]:
# Pull exactly one batch to sanity-check the loader output. next(iter(...))
# raises if the loader is empty, whereas a for/break loop would print nothing.
images, labels = next(iter(train_loader))
print("First batch images shape:", images.shape)
print("First batch labels shape:", labels.shape)
assert tuple(images.shape[1:]) == (3, 32, 32), "unexpected image dimensions"
assert labels.shape[0] == images.shape[0], "one label expected per image"

First batch images shape: torch.Size([64, 3, 32, 32])
First batch labels shape: torch.Size([64])
