# Data Processing for Simpsons-MNIST
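This notebook loads the Simpsons-MNIST dataset, scales pixel values to [0, 1], flattens each image into a vector, and carves a stratified validation split out of the training set. The test split is taken as-is from the dataset's own `test` folders.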

In [5]:
import os
from typing import Tuple, List, Dict, Union

import numpy as np
from PIL import Image


In [6]:
def _load_images_from_folder(folder: Union[str, Dict[str, str]],
                             mode: str) -> Union[np.ndarray, Dict[str, np.ndarray]]:
    """Load every ``.jpg`` in *folder* as a flattened float32 row scaled to [0, 1].

    Args:
        folder: Directory path, or (only when ``mode == 'both'``) a mapping of
            mode name to directory path.
        mode: ``'grayscale'`` converts each image to single-channel ``'L'``;
            ``'both'`` recurses once per entry of *folder*; any other value
            converts to ``'RGB'``.

    Returns:
        A 2-D array with one row per image (files are visited in sorted
        filename order), or, for ``mode == 'both'``, a dict of such arrays
        keyed like *folder*.

    Raises:
        ValueError: If the directory contains no ``.jpg`` files, or (from
            ``np.stack``) if the images do not all flatten to the same length.
    """
    if mode == 'both':
        return {m: _load_images_from_folder(p, m) for m, p in folder.items()}

    # The target PIL mode is the same for every file, so decide it once.
    pil_mode = 'L' if mode == 'grayscale' else 'RGB'

    images: List[np.ndarray] = []
    for file in sorted(os.listdir(folder)):
        if not file.lower().endswith('.jpg'):
            continue
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it deterministically once the pixels are copied.
        with Image.open(os.path.join(folder, file)) as img:
            arr = np.asarray(img.convert(pil_mode), dtype=np.float32) / 255.0
        images.append(arr.flatten())

    if not images:
        # np.stack([]) would raise an opaque "need at least one array" error;
        # keep the exception type but say which folder is the problem.
        raise ValueError(f"No .jpg images found in '{folder}'")
    return np.stack(images, axis=0)


def _stratified_split(y: np.ndarray, val_ratio: float, rng: np.random.Generator):
    """Return ``(train_indices, val_indices)``, splitting each class of *y* on its own."""
    train_parts, val_parts = [], []
    for label in np.unique(y):
        idx = np.where(y == label)[0]
        rng.shuffle(idx)
        # Truncating: the first int(n * (1 - val_ratio)) shuffled samples of
        # the class go to train, the remainder to validation.
        split = int(len(idx) * (1 - val_ratio))
        train_parts.append(idx[:split])
        val_parts.append(idx[split:])
    if not train_parts:
        empty = np.empty(0, dtype=np.intp)
        return empty, empty.copy()
    return np.concatenate(train_parts), np.concatenate(val_parts)


def load_simpsons_mnist(base_dir: str,
                        mode: str = 'rgb',
                        val_ratio: float = 0.2,
                        seed: int = 42):
    """Load the dataset under *base_dir* and split train into train/validation.

    The directory layout read is ``<base_dir>/<mode>/{train,test}/<class>/*.jpg``.
    Class names are the sorted, non-hidden entries of the train directory (the
    ``rgb`` one when ``mode == 'both'``); a class's label is its index in that
    list.

    Args:
        base_dir: Root directory containing one sub-directory per colour mode.
        mode: Sub-directory / colour mode to load. ``'both'`` loads ``'rgb'``
            and ``'grayscale'`` together.
        val_ratio: Fraction of each class's training images moved to the
            validation split. Must lie in ``[0, 1]``.
        seed: Seed for the NumPy generator that shuffles each class before
            splitting.

    Returns:
        ``((X_train, y_train), (X_val, y_val), (X_test, y_test), classes)``.
        Labels are int32 arrays. Each ``X_*`` is a 2-D array of flattened
        images, or for ``mode == 'both'`` a dict of such arrays keyed
        ``'rgb'`` / ``'grayscale'`` that share one label vector.

    Raises:
        ValueError: If *val_ratio* is outside ``[0, 1]``, if no class
            directories are found, or if (in ``'both'`` mode) a class has a
            different number of images in the two colour modes.
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio!r}")

    rng = np.random.default_rng(seed)
    modes = ('rgb', 'grayscale') if mode == 'both' else (mode,)

    class_root = os.path.join(base_dir, modes[0], 'train')
    classes = sorted(d for d in os.listdir(class_root) if not d.startswith('.'))
    if not classes:
        raise ValueError(f"No class directories found in '{class_root}'")

    train_data: Dict[str, List[np.ndarray]] = {m: [] for m in modes}
    test_data: Dict[str, List[np.ndarray]] = {m: [] for m in modes}
    train_labels, test_labels = [], []
    for label, cls in enumerate(classes):
        counts = set()
        for m in modes:
            imgs_train = _load_images_from_folder(os.path.join(base_dir, m, 'train', cls), m)
            imgs_test = _load_images_from_folder(os.path.join(base_dir, m, 'test', cls), m)
            train_data[m].append(imgs_train)
            test_data[m].append(imgs_test)
            counts.add((imgs_train.shape[0], imgs_test.shape[0]))
        if len(counts) != 1:
            # One label vector is shared by every mode, so per-class counts
            # must agree or rows and labels would silently go out of step.
            # NOTE(review): row pairing across modes additionally assumes both
            # trees hold the same files in the same sorted order - not checked.
            raise ValueError(
                f"Class '{cls}' has a different number of images across modes {modes}"
            )
        n_train, n_test = counts.pop()
        train_labels.append(np.full(n_train, label, dtype=np.int32))
        test_labels.append(np.full(n_test, label, dtype=np.int32))

    X = {m: np.vstack(train_data[m]) for m in modes}
    y = np.concatenate(train_labels)
    X_test = {m: np.vstack(test_data[m]) for m in modes}
    y_test = np.concatenate(test_labels)

    train_indices, val_indices = _stratified_split(y, val_ratio, rng)
    X_train = {m: X[m][train_indices] for m in modes}
    X_val = {m: X[m][val_indices] for m in modes}
    y_train, y_val = y[train_indices], y[val_indices]

    if mode != 'both':
        # Single-mode callers receive plain arrays rather than one-entry dicts.
        X_train, X_val, X_test = X_train[mode], X_val[mode], X_test[mode]
    return (X_train, y_train), (X_val, y_val), (X_test, y_test), classes


In [7]:
# Load both colour modes at once, holding out 10% of each class for validation.
(train_X, train_y), (val_X, val_y), (test_X, test_y), classes = load_simpsons_mnist(
    'simpsons-mnist-0.1-rgb/dataset',
    mode='both',
    val_ratio=0.1,
)

# Report the shape of every split for each colour mode, then the class count.
for _split_name, _split_X in (('Train', train_X), ('Validation', val_X), ('Test', test_X)):
    for _mode_key, _mode_title in (('rgb', 'RGB'), ('grayscale', 'Grayscale')):
        print(f"{_split_name} {_mode_title} shape: {_split_X[_mode_key].shape}")
print(f"Number of classes: {len(classes)}")


Train RGB shape: (7200, 2352)
Train Grayscale shape: (7200, 784)
Validation RGB shape: (800, 2352)
Validation Grayscale shape: (800, 784)
Test RGB shape: (2000, 2352)
Test Grayscale shape: (2000, 784)
Number of classes: 10
