In [None]:
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

# ------------------------------------------------------------------
# Load the pre-downloaded CIFAR-10 batches (NumPy only)
# ------------------------------------------------------------------
# Folder that holds the extracted CIFAR-10 python batch files. The path is
# relative, so it is resolved against the current working directory.
data_root = Path("Dataset", "cifar-10-python", "cifar-10-batches-py")

def _load_cifar_batch(file_path):
    """Read one pickled CIFAR-10 batch file and return ``(data, labels)``.

    The file holds a dict with ``bytes`` keys; ``b"data"`` is an array of raw
    pixel rows and ``b"labels"`` a sequence of integer class ids.

    Returns:
        data: float32 array with the pixel rows divided by 255, giving values
            in [0, 1] for uint8 input (one flat row per image, (N, 3072) for
            CIFAR-10).
        labels: int64 array of shape (N,).
    """
    import pickle

    # NOTE: pickle can run arbitrary code while loading; only point this at
    # trusted dataset files. encoding="bytes" keeps the batch's byte-string keys.
    with open(file_path, "rb") as handle:
        raw = pickle.load(handle, encoding="bytes")

    pixel_rows = raw[b"data"]
    scaled = pixel_rows.astype(np.float32) / 255.0
    class_ids = np.array(raw[b"labels"], dtype=np.int64)
    return scaled, class_ids

# Load the five training batches (data_batch_1 ... data_batch_5) in file
# order, then split the (images, labels) pairs into two parallel lists.
train_batches = [
    _load_cifar_batch(data_root / f"data_batch_{batch_no}")
    for batch_no in range(1, 6)
]
train_images_list, train_labels_list = map(list, zip(*train_batches))

# Stack the batches along the sample axis: (50000, 3072) and (50000,).
x_train_full = np.concatenate(train_images_list, axis=0)
y_train_full = np.concatenate(train_labels_list, axis=0)
print("full train images", x_train_full.shape)
print("full train labels", y_train_full.shape)

# The single held-out test batch: (10000, 3072) images, (10000,) labels.
test_file = data_root / "test_batch"
x_test_full, y_test_full = _load_cifar_batch(test_file)
print("full test images", x_test_full.shape)
print("full test labels", y_test_full.shape)
# ------------------------------------------------------------------
# Create subsets: 40,000 for training, 10,000 for validation (from 50,000 train)
# Same style as Fashion-MNIST notebook
# ------------------------------------------------------------------
# Number of leading examples kept for training; every example after them
# becomes the validation set.
n_train = 40000

# NOTE(review): the split is positional, not shuffled. It assumes the batch
# files are not ordered by class. Confirm, or permute before slicing.
x_train = x_train_full[:n_train]
y_train = y_train_full[:n_train]

# Open-ended slice: take all remaining examples rather than hard-coding the
# total (50,000), so the split stays consistent if the full set changes size.
x_valid = x_train_full[n_train:]
y_valid = y_train_full[n_train:]

# The test batch is used as-is.
x_test = x_test_full
y_test = y_test_full

# Number of distinct labels in the training subset (should be 10 for CIFAR-10)
num_classes = int(np.unique(y_train).size)
print("Splitting CIFAR-10 data:")
print("  x_train:", x_train.shape, x_train.dtype)
print("  y_train:", y_train.shape, y_train.dtype)
print("  x_valid:", x_valid.shape, x_valid.dtype)
print("  y_valid:", y_valid.shape, y_valid.dtype)
print("  x_test :", x_test.shape, x_test.dtype)
print("  y_test :", y_test.shape, y_test.dtype)
print("Number of classes:", num_classes)

# Optional: quick visualization of a 10x10 grid like Fashion-MNIST
rows, columns = 10, 10
n = x_train.shape[0]
# Sample as many distinct images as the grid has cells. Deriving the count
# from rows/columns keeps the two in sync; capping at n avoids an IndexError
# when the training subset is smaller than the grid.
num_shown = min(n, rows * columns)
idx = np.random.permutation(n)[:num_shown]

# Each flat row is laid out channel-first (3 planes of 32x32), so reshape to
# (N, 3, 32, 32) and then move the channel axis last to get the
# (N, 32, 32, 3) layout that imshow expects.
images_flat = x_train[idx]              # (num_shown, 3072)
labels = y_train[idx]                   # (num_shown,)
images = images_flat.reshape(-1, 3, 32, 32)
images = np.transpose(images, (0, 2, 3, 1))  # (num_shown, 32, 32, 3)

height, width = images.shape[1], images.shape[2]  # 32, 32
grid = np.zeros((rows * height, columns * width, 3), dtype=images.dtype)

# Tile the sampled images row-major into the grid; any unused cells stay zero.
for i, image in enumerate(images):
    r, c = divmod(i, columns)
    grid[r*height:(r+1)*height, c*width:(c+1)*width, :] = image

plt.figure(figsize=(7, 7))
plt.imshow(grid)
plt.axis("off")
plt.show()

print("images_flat shape:", images_flat.shape)
print("images shape:", images.shape)
print("labels shape:", labels.shape)