In [None]:
from sklearn.model_selection import train_test_split
from collections import defaultdict
import os
from PIL import Image
from pathlib import Path
from datasets import load_dataset


In [None]:
# Images are written to <output_root>/<split>/<class>/<img_id>.jpg
output_root = Path("data")

# Food-101 label names to extract.
# NOTE(review): "pretzel" and "croissant" produced 0 images in the recorded run and do not
# appear to be Food-101 labels -- verify against the printed label list.
selected_classes = [
    "french_toast",
    "garlic_bread",
    "pretzel",
    "croissant",
]

splits = ["train", "val", "test"]

# Fraction of each class that goes to "train"; the remainder is halved into "val" and "test".
split_ratio = 0.7

In [None]:
# Load the Food-101 "train" split through Hugging Face `datasets`.
dataset = load_dataset("food101", split="train")

# The "label" column is a class-label feature; its `.names` maps integer label i -> class name.
label_feature = dataset.features["label"]
label_names = label_feature.names

 

In [18]:
# Show the dataset's label vocabulary; entries of `selected_classes` must match these names exactly.
print(f"Available classes ({len(label_names)}): {label_names}")

# Surface typos / non-existent classes here rather than as empty splits much later.
unknown_classes = [cls for cls in selected_classes if cls not in label_names]
if unknown_classes:
    print(f"Selected classes not in this dataset: {unknown_classes}")

Available classes: ['apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla', 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'm

In [13]:
# Collect up to `max_per_class` (image, img_id) pairs for each selected class.
class_images = defaultdict(list)  # class name -> [(image, "<class>_<row index>"), ...]
class_counts = defaultdict(int)   # class name -> number of images collected so far
max_per_class = 200

# Map label ids to the selected class names that actually exist in this dataset.
# A requested class that is absent could never reach `max_per_class`, which would
# disable the early `break` below and force a pass over every row of the split.
wanted_labels = {
    label_id: name
    for label_id, name in enumerate(label_names)
    if name in selected_classes
}
missing_classes = [cls for cls in selected_classes if cls not in wanted_labels.values()]
if missing_classes:
    print(f"Not found in dataset, skipping: {missing_classes}")

# Scan only the integer label column; a row (and its image) is fetched only when
# it is actually kept, instead of materializing every example in the split.
for i, label_id in enumerate(dataset["label"]):
    label_name = wanted_labels.get(label_id)
    if label_name is not None and class_counts[label_name] < max_per_class:
        img_id = f"{label_name}_{i}"
        class_images[label_name].append((dataset[i]["image"], img_id))
        class_counts[label_name] += 1

    # Stop as soon as every existing selected class is full.
    if all(class_counts[cls] >= max_per_class for cls in wanted_labels.values()):
        break




In [None]:
# Split each class's collected images into train/val/test and write them as
# <output_root>/<split>/<class>/<img_id>.jpg.
max_per_class = 100  # use only the first 100 collected images of each class

for cls in selected_classes:
    items = class_images[cls][:max_per_class]

    # train_test_split raises when a resulting split would be empty, so we need
    # at least two items for the train / remainder split...
    if len(items) < 2:
        print(f"Skipped {cls} — not enough items ({len(items)})")
        continue

    train, temp = train_test_split(items, train_size=split_ratio, random_state=42)

    # ...and at least two leftover items to halve into val/test. With split_ratio=0.7,
    # 2 or 3 items leave a single leftover item, which would make the next call raise.
    if len(temp) < 2:
        print(f"Skipped {cls} — not enough items ({len(items)})")
        continue

    val, test = train_test_split(temp, test_size=0.5, random_state=42)

    split_data = {"train": train, "val": val, "test": test}

    for split in splits:
        out_dir = output_root / split / cls
        out_dir.mkdir(parents=True, exist_ok=True)

        for img, img_id in split_data[split]:
            img_path = out_dir / f"{img_id}.jpg"
            try:
                # JPEG cannot store alpha/palette modes (e.g. RGBA, P); normalize them to RGB
                # so such images are saved instead of falling into the except branch.
                if img.mode not in ("RGB", "L"):
                    img = img.convert("RGB")
                img.save(img_path)
            except Exception as e:
                # Best-effort export: report the failing file and keep going.
                print(f"Warning: could not save {img_path}: {e}")


Skipped pretzel — not enough items (0)
Skipped croissant — not enough items (0)
