# In [2]:
import pickle
import numpy as np
from PIL import Image
import os

# CIFAR-10 label indices treated as animals: the contiguous range 2..7
# (bird, cat, deer, dog, frog, horse in the standard CIFAR-10 class order).
animal_labels = set(range(2, 8))

def unpickle(file):
    """Load and return the object stored in the pickle file at *file*.

    The file is read with ``encoding='bytes'``, so 8-bit strings written by
    Python 2 come back as ``bytes`` (which is why callers index the result
    with ``b"..."`` keys).

    NOTE: unpickling can execute arbitrary code; only use this on files
    from a trusted source.
    """
    with open(file, mode='rb') as handle:
        payload = pickle.load(handle, encoding='bytes')
    return payload

# Read the dataset metadata and turn the class names, which are stored as
# bytes, into text; the list is indexed by numeric label.
meta = unpickle("batches.meta")
raw_label_names = meta[b"label_names"]
label_names = [raw_name.decode("utf-8") for raw_name in raw_label_names]

def extract_animals(batch_file, output_dir):
    """Write every animal image of one batch file to a per-class folder.

    Each sample whose label is in ``animal_labels`` is saved as
    ``<output_dir>/<class name>/<stored filename>``; all other samples are
    skipped. Class directories are created on demand.

    Args:
        batch_file: Path of the pickled batch, loaded through ``unpickle``.
            It must provide the ``b"data"``, ``b"labels"`` and
            ``b"filenames"`` entries.
        output_dir: Root directory under which the class folders are placed.

    Returns:
        None. The images are written to disk as a side effect.
    """
    batch = unpickle(batch_file)
    labels = batch[b"labels"]
    filenames = batch[b"filenames"]

    # Every row is one flattened image laid out channel-first (3, 32, 32).
    # Reshape the whole batch once and move the channel axis last, giving
    # height x width x channel views without copying per image.
    images = np.asarray(batch[b"data"]).reshape(-1, 3, 32, 32)
    images = images.transpose(0, 2, 3, 1)

    # label -> class directory; each directory is created only once instead
    # of calling os.makedirs for every single image.
    class_dirs = {}

    for label, raw_name, img in zip(labels, filenames, images):
        if label not in animal_labels:
            continue  # skip non-animals

        class_dir = class_dirs.get(label)
        if class_dir is None:
            class_dir = os.path.join(output_dir, label_names[label])
            os.makedirs(class_dir, exist_ok=True)
            class_dirs[label] = class_dir

        fname = raw_name.decode("utf-8")
        Image.fromarray(img).save(os.path.join(class_dir, fname))

# Process the five training batches followed by the test batch, writing
# everything below a single output directory.
batches = ["data_batch_%d" % number for number in range(1, 6)]
batches.append("test_batch")

output = "cifar10_animals"

for bf in batches:
    extract_animals(bf, output)
