# Sub-CIFAR-10 dataset
This notebook creates a balanced subset of the CIFAR-10 test split, with 10 randomly sampled images per class (100 images in total).

The purpose of this dataset is to test and analyse the classification method proposed in *Your Diffusion Model is Secretly a Zero-Shot Classifier*.

In [None]:
# HF login
# The token is never written into the notebook (where it would be committed
# and shared). It is read from the HF_TOKEN environment variable, or typed
# into a hidden prompt when that variable is not set.
import os
from getpass import getpass

from huggingface_hub import login

huggingface_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")

# Login using the token
login(token=huggingface_token)

In [None]:
from diffusion.utils import DATASET_ROOT
from torchvision import datasets
from collections import defaultdict
import random

# Fixed seed so that Restart & Run All always selects the same images.
SEED = 0
NUM_CLASSES = 10   # CIFAR-10
SUBSET_SIZE = 100  # total number of images in the subset

# Load the CIFAR-10 test split (train=False); download=True fetches it into
# DATASET_ROOT if it is not already there.
c10 = datasets.CIFAR10(root=DATASET_ROOT, train=False, transform=None,
                       target_transform=None, download=True)

# Group dataset indices by label. `c10.targets` holds the integer labels
# directly, so we avoid decoding every image just to read its label.
class_indices = defaultdict(list)
for idx, target in enumerate(c10.targets):
    class_indices[target].append(idx)

# Set the number of examples per class
examples_per_class = SUBSET_SIZE // NUM_CLASSES

# Draw `examples_per_class` distinct indices from each class. Classes are
# visited in label order (0..9), so the subset is laid out as consecutive
# blocks of `examples_per_class` items, with class 0 first. The display and
# export cells rely on that layout.
rng = random.Random(SEED)
balanced_sample_indices = []
for target in sorted(class_indices):
    balanced_sample_indices.extend(rng.sample(class_indices[target], examples_per_class))

# Retrieve the balanced sample as (image, label) pairs.
balanced_sample = [c10[idx] for idx in balanced_sample_indices]
assert len(balanced_sample) == examples_per_class * len(class_indices)

In [None]:
# Human-readable CIFAR-10 label names, keyed by integer class id (0-9).
# These strings also become the per-class folder names when the subset is
# written to disk.
cifar10_class_dict = dict(enumerate([
    "Airplane", "Automobile", "Bird", "Cat", "Deer",
    "Dog", "Frog", "Horse", "Ship", "Truck",
]))

In [None]:
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from torchvision.utils import make_grid


def image_grid(imgs, rows, cols):
    """Tile equally sized PIL images into one PIL image of `rows` x `cols` tiles.

    Args:
        imgs: sequence of PIL images, all with the same size and mode.
        rows: number of tile rows.
        cols: number of tiles per row.

    Raises:
        ValueError: if len(imgs) != rows * cols.
    """
    if len(imgs) != rows * cols:
        raise ValueError(f"expected {rows * cols} images, got {len(imgs)}")
    # Keep uint8 tensors (no float rescaling) so the pixels are shown unchanged.
    tensors = [pil_to_tensor(img) for img in imgs]
    return to_pil_image(make_grid(tensors, nrow=cols, padding=0))


# `balanced_sample` is laid out as consecutive blocks of `examples_per_class`
# items, one block per class. Show each block as a single row of upscaled images.
for start in range(0, len(balanced_sample), examples_per_class):
    class_examples = balanced_sample[start:start + examples_per_class]
    class_name = cifar10_class_dict[class_examples[0][1]]
    print(f'Class: {class_name}')
    display(image_grid([x[0].resize((128, 128)) for x in class_examples], 1, examples_per_class))

In [None]:
import os
import shutil

# Root of the exported subset, laid out as <dataset_dir>/<ClassName>/<position>.png
dataset_dir = 'sub_cifar10'

# Remove any previous export so that stale images from an earlier run
# never mix with the current sample.
if os.path.exists(dataset_dir):
    shutil.rmtree(dataset_dir)
os.makedirs(dataset_dir, exist_ok=True)

# Each file is named after the image's position within `balanced_sample`.
for position, sample in enumerate(balanced_sample):
    image, label = sample
    target_dir = os.path.join(dataset_dir, cifar10_class_dict[label])
    os.makedirs(target_dir, exist_ok=True)
    image.save(os.path.join(target_dir, f'{position}.png'))

In [None]:
from huggingface_hub import HfApi

# Hub client. Presumably authenticates with the token cached by the login
# cell at the top of the notebook — confirm if running this cell standalone.
api = HfApi()
repo_name = "sub-cifar10"

# Create the dataset repository. exist_ok=True makes re-running the cell a no-op
# instead of an error when the repo already exists.
# NOTE(review): this cell only creates the repo; no upload of `sub_cifar10`
# is visible in this notebook.
repo_url = api.create_repo(
    repo_name,
    repo_type="dataset",
    exist_ok=True,
)

print(f"Dataset created at: {repo_url}")