In [1]:
import os
import numpy as np
from PIL import Image
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

import ssl
import certifi

# Make HTTPS connections that rely on Python's default-context hook verify
# certificates against certifi's CA bundle (presumably so the CIFAR-10 download
# below works where the system CA store is missing/outdated). NOTE: this
# replaces the *private* ``ssl._create_default_https_context`` hook, so it
# affects every later connection in this kernel that goes through that hook.
def create_custom_https_context(*args, **kw):
    """Return ``ssl.create_default_context(...)`` trusting certifi's CA bundle.

    Accepts the same arguments as ``ssl.create_default_context``. If the caller
    does not pass ``cafile``, certifi's bundle (``certifi.where()``) is used; a
    caller-supplied ``cafile`` is honoured instead of causing a duplicate
    keyword-argument ``TypeError``.
    """
    kw.setdefault('cafile', certifi.where())
    return ssl.create_default_context(*args, **kw)

ssl._create_default_https_context = create_custom_https_context

# Load CIFAR-10 dataset using torchvision
def load_cifar10(root='./data', download=True):
    """Load the CIFAR-10 train and test splits as NumPy arrays.

    Args:
        root: Directory where torchvision stores / looks for the dataset.
        download: Forwarded to ``torchvision.datasets.CIFAR10``; fetch the
            archive into ``root`` when it is not already present.

    Returns:
        ``(train_images, train_labels, test_images, test_labels)``. Image arrays
        are ``uint8`` with shape ``(N, H, W, C)``; label arrays are 1-D integer
        arrays aligned with the images.

    Raises:
        ValueError: If the dataset's raw arrays do not have the expected
            ``(N, H, W, 3)`` image layout or image/label counts differ.
    """
    def _split_arrays(train):
        # Read the dataset's raw pixel buffer directly. The previous approach ran
        # every sample through ToTensor (uint8 -> float in [0, 1]) in a Python
        # loop and scaled back with ``(x * 255).astype(uint8)``. That is slow, and
        # the cast truncates, so a float product just below an integer would be
        # stored one intensity level too low.
        dataset = torchvision.datasets.CIFAR10(root=root, train=train, download=download)
        images = np.asarray(dataset.data, dtype=np.uint8)
        labels = np.asarray(dataset.targets)
        # Guard the HWC assumption that downstream PIL saving relies on.
        if images.ndim != 4 or images.shape[-1] != 3 or len(images) != len(labels):
            raise ValueError(
                f"unexpected CIFAR-10 layout: images {images.shape}, {len(labels)} labels")
        return images, labels

    train_images, train_labels = _split_arrays(train=True)
    test_images, test_labels = _split_arrays(train=False)
    return train_images, train_labels, test_images, test_labels

# Function to create a prototype dataset
def create_prototype_dataset(images, labels, prototype_size, num_classes=10, rng=None):
    """Draw a class-balanced random subset of ``images`` / ``labels``.

    ``prototype_size // num_classes`` samples are drawn without replacement from
    each class id in ``range(num_classes)``. The combined selection is then
    shuffled so classes are interleaved.

    Args:
        images: Array indexable along axis 0 by an integer index array.
        labels: 1-D array of integer class ids aligned with ``images``.
        prototype_size: Requested total sample count. The remainder
            ``prototype_size % num_classes`` is dropped, so the result holds
            ``num_classes * (prototype_size // num_classes)`` samples.
        num_classes: Number of classes; ids are ``0 .. num_classes - 1``.
        rng: Object providing ``choice`` and ``shuffle``, e.g.
            ``np.random.default_rng(seed)``, for reproducible sampling.
            Defaults to NumPy's global random state.

    Returns:
        ``(prototype_images, prototype_labels)``.

    Raises:
        ValueError: If a class has fewer than ``prototype_size // num_classes``
            samples.
    """
    rng = np.random if rng is None else rng
    labels = np.asarray(labels)
    per_class = prototype_size // num_classes

    chosen = []
    for class_id in range(num_classes):
        class_indices = np.flatnonzero(labels == class_id)
        if len(class_indices) < per_class:
            raise ValueError(
                f"class {class_id} has only {len(class_indices)} samples; "
                f"cannot draw {per_class} without replacement")
        chosen.append(rng.choice(class_indices, per_class, replace=False))

    indices = np.concatenate(chosen)
    rng.shuffle(indices)  # interleave classes instead of grouping them
    return images[indices], labels[indices]

# Function to create directories for dataset
def create_directories(base_dir, class_names):
    """Create ``base_dir`` (with any missing parents) and one subdirectory per class.

    Safe to call repeatedly: directories that already exist are left untouched.

    Args:
        base_dir: Root directory for one dataset split.
        class_names: Iterable of subdirectory names to create under ``base_dir``.
    """
    # exist_ok=True replaces the racy "os.path.exists then os.makedirs" pattern.
    os.makedirs(base_dir, exist_ok=True)
    for class_name in class_names:
        os.makedirs(os.path.join(base_dir, class_name), exist_ok=True)

# Function to save images
def save_images(images, labels, base_dir, class_names):
    """Write each image to ``<base_dir>/<class name>/<position>.png``.

    Args:
        images: Sequence of pixel arrays accepted by ``PIL.Image.fromarray``.
        labels: Integer class ids aligned with ``images``; each one indexes
            ``class_names``.
        base_dir: Split directory. This function does not create the per-class
            subdirectories; they are expected to exist already.
        class_names: Maps a label to its subdirectory name.

    File names are the sample's position in ``images``, so they are unique
    within one call but are not tied to the class.
    """
    for position, (pixels, class_id) in enumerate(zip(images, labels)):
        target = os.path.join(base_dir, class_names[class_id], f'{position}.png')
        Image.fromarray(pixels).save(target)

# Main code: build the on-disk prototype train / validation / test folders.
if __name__ == "__main__":
    train_images, train_labels, test_images, test_labels = load_cifar10()

    # Seed both the prototype sampling (global NumPy RNG) and the split so a
    # re-run reproduces the same files. Without this, a re-run can write the same
    # file name (e.g. '5.png') into a different class folder. save_images names
    # files by position and nothing clears old output, so the previous run's file
    # would be left behind as a stale or duplicate sample.
    SEED = 42
    np.random.seed(SEED)

    # Create prototype datasets (class-balanced subsets)
    prototype_train_size = 5000
    prototype_test_size = 1000
    proto_train_images, proto_train_labels = create_prototype_dataset(train_images, train_labels, prototype_train_size)
    proto_test_images, proto_test_labels = create_prototype_dataset(test_images, test_labels, prototype_test_size)

    # Split the prototype training set into stratified training and validation sets
    proto_train_images, proto_val_images, proto_train_labels, proto_val_labels = train_test_split(
        proto_train_images, proto_train_labels, test_size=0.2,
        stratify=proto_train_labels, random_state=SEED)

    # Define class names (index == CIFAR-10 label id) and output directories
    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    train_dir = './prototype_generated/train'
    val_dir = './prototype_generated/validation'
    test_dir = './prototype_generated/test'

    # Create directories and save prototype images, one split at a time
    splits = [
        (train_dir, proto_train_images, proto_train_labels),
        (val_dir, proto_val_images, proto_val_labels),
        (test_dir, proto_test_images, proto_test_labels),
    ]
    for split_dir, split_images, split_labels in splits:
        create_directories(split_dir, class_names)
        save_images(split_images, split_labels, split_dir, class_names)


  from .autonotebook import tqdm as notebook_tqdm


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 54858389.23it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
