In [1]:
# Start from a clean slate: deletes everything under /kaggle/working (destructive, not reversible).
!rm -rf /kaggle/working/*

In [None]:
# Clones into the current working directory; later cells expect the checkout at
# /kaggle/working/LabelBench. NOTE(review): no commit/tag is pinned, so the cloned code may change between runs.
!git clone https://github.com/EfficientTraining/LabelBench.git

In [None]:
!pip install -r /kaggle/working/LabelBench/requirements.txt

In [None]:
import kagglehub

# Fetch the CUB-200-2011 "by classes folder" dataset; `path` is the location kagglehub reports.
path = kagglehub.dataset_download("cyizhuo/cub-200-2011-by-classes-folder")
print(path)

# NOTE(review): the copy below reads from the hard-coded /kaggle/input mount and does not use
# `path` — confirm the two locations agree when running outside Kaggle.
# The cub200 loader written later in this notebook looks for
# <data_dir>/cub-200-2011-by-classes-folder/train and .../test.
!mkdir -p ./data
!cp -r /kaggle/input/cub-200-2011-by-classes-folder ./data/

In [None]:
!mkdir -p ./data/notmnist
!cp -r /kaggle/input/notmnist/notMNIST_small ../data/notmnist/


In [6]:
# Work from the repository root: the file paths written below (LabelBench/dataset/...) are
# relative to it, and the smoke test's "../data" resolves to /kaggle/working/data from here.
%cd /kaggle/working/LabelBench


/kaggle/working/LabelBench


In [7]:
ls

[0m[01;34mconfigs[0m/        [01;34mLabelBench[0m/  mp_eval_launcher.py  README.md
[01;34mdocs[0m/           LICENSE      mp_launcher.py       requirements.txt
example_run.sh  main.py      point_evaluation.py  [01;34mresults[0m/


### TINYIMAGENET

In [8]:
!cat << 'EOF' > LabelBench/dataset/dataset_impl/tinyimagenet_dataset.py
import os
import zipfile
import urllib.request
import shutil
from torchvision import transforms
from torchvision.datasets import ImageFolder
from LabelBench.skeleton.dataset_skeleton import register_dataset, LabelType, TransformDataset
import torch
import torch.nn.functional as F


URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"

def one_hot_target(y):
    """Encode class index `y` as a float one-hot vector with 200 entries."""
    index = torch.tensor(y)
    encoded = F.one_hot(index, num_classes=200)
    return encoded.float()


def is_image_file(path):
    """Return True when `path` ends in .jpg, .jpeg or .png, ignoring case."""
    lowered = path.lower()
    return any(lowered.endswith(ext) for ext in (".jpg", ".jpeg", ".png"))


def _flatten_train_dirs(train_dir):
    """Move train/<class>/images/* up into train/<class>/; return True if anything was changed."""
    if not os.path.isdir(train_dir):
        return False

    changed = False
    for cls in os.listdir(train_dir):
        cls_path = os.path.join(train_dir, cls)
        images_dir = os.path.join(cls_path, "images")
        if not os.path.isdir(images_dir):
            # Already flattened (or not a class folder): nothing to do.
            continue
        if not changed:
            print("Fixing train folder structure...")
            changed = True
        for img in os.listdir(images_dir):
            shutil.move(
                os.path.join(images_dir, img),
                os.path.join(cls_path, img),
            )
        shutil.rmtree(images_dir)
    return changed


def _regroup_val_images(val_dir):
    """Sort val/images/* into val/<class>/ using val_annotations.txt; return True if anything was changed."""
    img_dir = os.path.join(val_dir, "images")
    ann_file = os.path.join(val_dir, "val_annotations.txt")

    # The annotation file is deleted as the very last step, so its presence
    # marks a val split that has not been (completely) regrouped yet.
    if not os.path.isfile(ann_file):
        return False

    print("Fixing val folder structure...")
    with open(ann_file, "r") as f:
        # Tab-separated; first field is the image name, second the class id.
        annotations = [line.strip().split("\t") for line in f if line.strip()]

    for img, cls, *_ in annotations:
        src = os.path.join(img_dir, img)
        if not os.path.exists(src):
            # Already moved by an earlier, interrupted run.
            continue
        cls_dir = os.path.join(val_dir, cls)
        os.makedirs(cls_dir, exist_ok=True)
        shutil.move(src, os.path.join(cls_dir, img))

    shutil.rmtree(img_dir, ignore_errors=True)
    os.remove(ann_file)
    return True


def download_and_prepare(data_dir):
    """
    Make sure TinyImageNet exists under `data_dir` in a class-per-folder layout.

    If `data_dir/tiny-imagenet-200` is missing, the archive is downloaded from
    URL and extracted. The train and val folders are then restructured so that
    images sit directly inside one folder per class. The restructuring steps
    only act on what is still in the original layout, so a run that was
    interrupted half-way is completed on the next call instead of being
    skipped because the root folder already exists.

    Args:
        data_dir: directory that holds (or will hold) `tiny-imagenet-200`.

    Returns:
        Path of the `tiny-imagenet-200` folder.
    """
    root = os.path.join(data_dir, "tiny-imagenet-200")
    zip_path = os.path.join(data_dir, "tiny-imagenet-200.zip")

    if not os.path.exists(root):
        os.makedirs(data_dir, exist_ok=True)

        print("Downloading TinyImageNet...")
        urllib.request.urlretrieve(URL, zip_path)

        print("Extracting TinyImageNet...")
        with zipfile.ZipFile(zip_path, "r") as z:
            z.extractall(data_dir)

    # Both helpers are no-ops on an already prepared tree.
    train_changed = _flatten_train_dirs(os.path.join(root, "train"))
    val_changed = _regroup_val_images(os.path.join(root, "val"))

    if train_changed or val_changed:
        print("TinyImageNet ready at:", root)
    return root


@register_dataset("tinyimagenet", LabelType.MULTI_CLASS)
def get_tinyimagenet_dataset(data_dir, *args):
    """
    Build the TinyImageNet datasets registered under the name "tinyimagenet".

    Args:
        data_dir: directory in which TinyImageNet is stored (downloaded on first use).
        *args: accepted and ignored.

    Returns:
        An 8-tuple: (train dataset, val dataset, test dataset, None, None, None,
        number of classes, class names). The official val split is used for both
        val and test. Labels of all three datasets are one-hot float vectors of
        length 200.
    """
    root = download_and_prepare(data_dir)

    # Random crop + flip for training only.
    train_tf = transforms.Compose([
        transforms.RandomResizedCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # Deterministic preprocessing for evaluation.
    test_tf = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
    ])

    train_ds = ImageFolder(os.path.join(root, "train"))
    val_ds = ImageFolder(os.path.join(root, "val"))

    return (
        TransformDataset(train_ds, transform=train_tf, target_transform=one_hot_target),
        # Evaluation labels are one-hot too, so every split has the same label format.
        TransformDataset(val_ds, transform=test_tf, target_transform=one_hot_target),
        TransformDataset(val_ds, transform=test_tf, target_transform=one_hot_target),
        None, None, None,
        200,
        train_ds.classes
    )



### CUB200

In [9]:
!cat << 'EOF' > LabelBench/dataset/dataset_impl/cub200_dataset.py
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from LabelBench.skeleton.dataset_skeleton import register_dataset, LabelType, TransformDataset
NUM_CLASSES = 200  # length of every one-hot label vector


def one_hot_target(y):
    """Encode class index `y` as a float one-hot vector of length NUM_CLASSES."""
    index = torch.tensor(y)
    return F.one_hot(index, NUM_CLASSES).to(torch.float32)

@register_dataset("cub200", LabelType.MULTI_CLASS)
def get_cub200_dataset(data_dir, *args):
    """
    Build the CUB-200 datasets from `data_dir/cub-200-2011-by-classes-folder`.

    The folder must contain `train` and `test` sub-directories; a RuntimeError
    is raised when either is missing. Extra positional arguments are ignored.

    Returns:
        An 8-tuple: (train dataset, val dataset, test dataset, None, None, None,
        NUM_CLASSES, class names). The test folder backs both val and test.
    """
    root = os.path.join(os.path.abspath(data_dir), "cub-200-2011-by-classes-folder")
    train_dir = os.path.join(root, "train")
    test_dir = os.path.join(root, "test")

    # Fail early with a readable message when the expected layout is absent.
    if not os.path.isdir(train_dir):
        raise RuntimeError(f"Train directory not found: {train_dir}")
    if not os.path.isdir(test_dir):
        raise RuntimeError(f"Test directory not found: {test_dir}")

    augment = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    evaluate = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])

    train_ds = ImageFolder(train_dir)
    test_ds = ImageFolder(test_dir)

    def wrap(folder, tf):
        # Every split gets one-hot labels.
        return TransformDataset(folder, transform=tf, target_transform=one_hot_target)

    train_split = wrap(train_ds, augment)
    val_split = wrap(test_ds, evaluate)
    test_split = wrap(test_ds, evaluate)

    return (
        train_split,
        val_split,
        test_split,
        None, None, None,
        NUM_CLASSES,
        train_ds.classes,
    )



### SPLITCIFAR100

In [10]:
!cat << 'EOF' > LabelBench/dataset/dataset_impl/splitcifar100_dataset.py
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset
from LabelBench.skeleton.dataset_skeleton import register_dataset, LabelType, TransformDataset

# CIFAR-100 split into 20 tasks of 5 consecutive class ids each:
# task 0 -> [0..4], task 1 -> [5..9], ..., task 19 -> [95..99]
SPLITS = {}
for _task in range(20):
    SPLITS[_task] = [_task * 5 + _offset for _offset in range(5)]


class SplitCIFAR100(Dataset):
    """
    Pure filter over a labelled base dataset.

    Keeps only the samples whose label is in `allowed_classes` and remaps
    those labels to 0..len(allowed_classes)-1 (position in `allowed_classes`).

    - returns (x, y)
    - NO transforms
    - NO one-hot
    """

    def __init__(self, base_ds, allowed_classes):
        """
        Args:
            base_ds: indexable dataset yielding (x, y) pairs.
            allowed_classes: labels to keep; their order defines the new label ids.
        """
        self.base_ds = base_ds
        self.allowed = allowed_classes
        self.class_map = {c: i for i, c in enumerate(allowed_classes)}

        # Fast path: read labels from `base_ds.targets` instead of loading every
        # sample. Only taken when no target_transform is set, so that the stored
        # targets are the labels __getitem__ would yield.
        # NOTE(review): assumes `targets` is index-aligned with __getitem__ — true
        # for torchvision's CIFAR datasets; confirm for other base datasets.
        targets = getattr(base_ds, "targets", None)
        use_targets = (
            targets is not None
            and getattr(base_ds, "target_transform", None) is None
            and len(targets) == len(base_ds)
        )
        if use_targets:
            # .item() turns 0-d tensors / numpy scalars into plain Python values.
            labels = (t.item() if hasattr(t, "item") else t for t in targets)
        else:
            labels = (y for _, y in base_ds)

        self.indices = [i for i, y in enumerate(labels) if y in self.class_map]

    def __len__(self):
        """Number of samples that survived the class filter."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return (x, remapped_label) for the idx-th kept sample."""
        x, y = self.base_ds[self.indices[idx]]
        y = self.class_map[y]
        return x, y


def one_hot(y, n):
    """Encode class index `y` as a float one-hot vector of length `n`."""
    encoded = F.one_hot(torch.tensor(y), n)
    return encoded.to(torch.float32)


@register_dataset("splitcifar100", LabelType.MULTI_CLASS)
def get_splitcifar100(_):
    """Placeholder for the bare name: always raises and points to the per-split names."""
    hint = "Use splitcifar100_<id>, e.g. splitcifar100_0"
    raise RuntimeError(hint)


# One registered loader per task: splitcifar100_0 ... splitcifar100_19
for split_id, classes in SPLITS.items():

    @register_dataset(f"splitcifar100_{split_id}", LabelType.MULTI_CLASS)
    def _make_split(data_dir, classes=classes):
        """
        Build the datasets for one CIFAR-100 task.

        `classes` is bound as a default argument so each registered function
        keeps the class list of its own loop iteration.

        Returns:
            An 8-tuple: (train dataset, val dataset, test dataset, None, None,
            None, number of classes, class ids as strings). The CIFAR-100 test
            subset backs both val and test.
        """
        n_cls = len(classes)  # = 5

        def to_one_hot(y):
            return one_hot(y, n_cls)

        def wrap(subset, tf):
            return TransformDataset(
                subset,
                transform=tf,
                target_transform=to_one_hot,
            )

        # Augmentation for training only; evaluation just converts to a tensor.
        augment = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        plain = transforms.Compose([
            transforms.ToTensor(),
        ])

        base_train = datasets.CIFAR100(root=data_dir, train=True, download=True)
        base_test = datasets.CIFAR100(root=data_dir, train=False, download=True)

        train_subset = SplitCIFAR100(base_train, classes)
        eval_subset = SplitCIFAR100(base_test, classes)

        class_names = [str(c) for c in classes]

        return (
            wrap(train_subset, augment),
            wrap(eval_subset, plain),
            wrap(eval_subset, plain),
            None, None, None,
            n_cls,
            class_names,
        )




### FASHIONMNIST

In [11]:
!cat << 'EOF' > LabelBench/dataset/dataset_impl/fashionmnist_dataset.py
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from LabelBench.skeleton.dataset_skeleton import register_dataset, LabelType, TransformDataset

def one_hot(y, n):
    """Return a float vector of length `n` that is 1 at position `y` and 0 elsewhere."""
    label = torch.tensor(y)
    return F.one_hot(label, num_classes=n).to(torch.float32)

@register_dataset("fashionmnist", LabelType.MULTI_CLASS)
def get_fashionmnist(data_dir):
    """
    Build the FashionMNIST datasets, downloading them into `data_dir` if needed.

    Returns:
        An 8-tuple: (train dataset, val dataset, test dataset, None, None, None,
        10, class names). The official test split backs both val and test;
        images are only converted to tensors and labels are one-hot encoded.
    """
    n_cls = 10

    def encode(y):
        return one_hot(y, n_cls)

    to_tensor = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_ds = datasets.FashionMNIST(root=data_dir, train=True, download=True)
    test_ds = datasets.FashionMNIST(root=data_dir, train=False, download=True)

    def wrap(base):
        return TransformDataset(base, transform=to_tensor, target_transform=encode)

    return (
        wrap(train_ds),
        wrap(test_ds),
        wrap(test_ds),
        None, None, None,
        n_cls,
        train_ds.classes,
    )




### NOTMNIST

In [12]:
!cat << 'EOF' > LabelBench/dataset/dataset_impl/notmnist_dataset.py
import os
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import ImageFolder
from LabelBench.skeleton.dataset_skeleton import register_dataset, LabelType, TransformDataset


def one_hot(y, n):
    """One-hot encode class index `y` into a float vector with `n` entries."""
    as_tensor = torch.tensor(y)
    one_hot_int = F.one_hot(as_tensor, num_classes=n)
    return one_hot_int.float()


@register_dataset("notmnist", LabelType.MULTI_CLASS)
def get_notmnist_dataset(data_dir):
    """
    Build the notMNIST datasets from an image-folder tree on disk.

    Expected structure:
    data_dir/notmnist/notMNIST_small/
        A/
        B/
        ...
        J/

    The class folders are read from `notMNIST_small` when that folder exists.
    Rooting ImageFolder one level higher would treat `notMNIST_small` itself as
    the only class. If the nested folder is absent, `data_dir/notmnist` is used
    directly, so a tree with the letter folders placed there still works.

    Raises:
        RuntimeError: when the dataset folder does not exist.

    Returns:
        An 8-tuple: (train dataset, val dataset, test dataset, None, None, None,
        number of classes, class names).
    """
    base = os.path.join(data_dir, "notmnist")
    nested = os.path.join(base, "notMNIST_small")
    root = nested if os.path.isdir(nested) else base

    if not os.path.isdir(root):
        raise RuntimeError(f"NOTMNIST not found at {root}")

    # Same deterministic preprocessing for every split (no augmentation).
    preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    base_ds = ImageFolder(root)
    num_classes = len(base_ds.classes)

    def encode(y):
        return one_hot(y, num_classes)

    def wrap():
        return TransformDataset(
            base_ds,
            transform=preprocess,
            target_transform=encode,
        )

    # NOTE(review): train, val and test all wrap the same images — there is no
    # held-out split here, so evaluation numbers are measured on training data.
    return (
        wrap(),
        wrap(),
        wrap(),
        None, None, None,
        num_classes,
        base_ds.classes,
    )



In [13]:
# Register the dataset modules written above by appending their imports to datasets.py.
# Done in Python rather than `echo >>` so that re-running this cell never appends
# duplicate lines, and so a missing trailing newline in datasets.py cannot glue an
# import onto the file's last statement.
from pathlib import Path

datasets_py = Path("LabelBench/dataset/datasets.py")
dataset_modules = [
    # "splitcifar10_dataset",  # disabled: no such module is written by this notebook
    "tinyimagenet_dataset",
    "cub200_dataset",
    "splitcifar100_dataset",
    "fashionmnist_dataset",
    "notmnist_dataset",
]

existing_source = datasets_py.read_text()
existing_lines = set(existing_source.splitlines())
missing_imports = [
    statement
    for statement in (
        f"import LabelBench.dataset.dataset_impl.{module}" for module in dataset_modules
    )
    if statement not in existing_lines
]

if missing_imports:
    separator = "" if (not existing_source or existing_source.endswith("\n")) else "\n"
    datasets_py.write_text(existing_source + separator + "\n".join(missing_imports) + "\n")

### TEST SCRIPT OF DATALOADER

In [14]:
# test_all_datasets.py
!cat << 'EOF' > test_all_datasets.py
from LabelBench.dataset.datasets import get_dataset
from torch.utils.data import DataLoader


def test_dataset(name, data_dir="../data"):
    print("\n" + "=" * 50)
    print(f"Testing dataset: {name}")
    print("=" * 50)

    try:
        dataset = get_dataset(name, data_dir)

        print("ALDataset created ✓")
        print("Number of classes:", dataset.get_num_classes())

        train_ds = dataset.train_dataset
        val_ds   = dataset.val_dataset
        test_ds  = dataset.test_dataset

        print("Train size:", len(train_ds))
        print("Val size:", len(val_ds))
        print("Test size:", len(test_ds))

        # ---- single sample check ----
        x, y = train_ds[0]
        print("Single image shape:", x.shape)
        print("Single label shape:", y.shape)

        # ---- dataloader check ----
        loader = DataLoader(
            train_ds,
            batch_size=8,
            shuffle=True,
            num_workers=2
        )
        bx, by = next(iter(loader))
        print("Batch image shape:", bx.shape)
        print("Batch label shape:", by.shape)

        print(f"{name} ✅ DATASET + DATALOADER WORKING")

    except Exception as e:
        print(f"{name} ❌ FAILED")
        print("Reason:", e)


# ==============================
# Run all dataset checks, in this order
# ==============================
for dataset_name in (
    "cifar10",
    "cifar100",
    "tinyimagenet",
    "cub200",
    "fashionmnist",
    "notmnist",
    "splitcifar100_0",
    "splitcifar100_1",
    "splitcifar100_19",
):
    test_dataset(dataset_name)



Testing dataset: cifar10


100%|██████████| 170M/170M [00:06<00:00, 28.2MB/s] 


ALDataset created ✓
Number of classes: 10
Train size: 50000
Val size: 5000
Test size: 5000
Single image shape: torch.Size([3, 32, 32])
Single label shape: torch.Size([10])
Batch image shape: torch.Size([8, 3, 32, 32])
Batch label shape: torch.Size([8, 10])
cifar10 ✅ DATASET + DATALOADER WORKING

Testing dataset: cifar100


100%|██████████| 169M/169M [00:08<00:00, 18.9MB/s] 


ALDataset created ✓
Number of classes: 100
Train size: 50000
Val size: 5000
Test size: 5000
Single image shape: torch.Size([3, 32, 32])
Single label shape: torch.Size([100])
Batch image shape: torch.Size([8, 3, 32, 32])
Batch label shape: torch.Size([8, 100])
cifar100 ✅ DATASET + DATALOADER WORKING

Testing dataset: tinyimagenet
Downloading TinyImageNet...
Extracting TinyImageNet...
Fixing train folder structure...
Fixing val folder structure...
TinyImageNet ready at: ../data/tiny-imagenet-200
ALDataset created ✓
Number of classes: 200
Train size: 100000
Val size: 10000
Test size: 10000
Single image shape: torch.Size([3, 64, 64])
Single label shape: torch.Size([200])
Batch image shape: torch.Size([8, 3, 64, 64])
Batch label shape: torch.Size([8, 200])
tinyimagenet ✅ DATASET + DATALOADER WORKING

Testing dataset: cub200
ALDataset created ✓
Number of classes: 200
Train size: 5994
Val size: 5794
Test size: 5794
Single image shape: torch.Size([3, 224, 224])
Single label shape: torch.Size([

100%|██████████| 26.4M/26.4M [00:00<00:00, 113MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 3.94MB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 57.7MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 7.25MB/s]


ALDataset created ✓
Number of classes: 10
Train size: 60000
Val size: 10000
Test size: 10000
Single image shape: torch.Size([1, 28, 28])
Single label shape: torch.Size([10])
Batch image shape: torch.Size([8, 1, 28, 28])
Batch label shape: torch.Size([8, 10])
fashionmnist ✅ DATASET + DATALOADER WORKING

Testing dataset: notmnist
notmnist ❌ FAILED
Reason: Couldn't find any class folder in ../data/notmnist.

Testing dataset: splitcifar100_0
ALDataset created ✓
Number of classes: 5
Train size: 2500
Val size: 500
Test size: 500
Single image shape: torch.Size([3, 32, 32])
Single label shape: torch.Size([5])
Batch image shape: torch.Size([8, 3, 32, 32])
Batch label shape: torch.Size([8, 5])
splitcifar100_0 ✅ DATASET + DATALOADER WORKING

Testing dataset: splitcifar100_1
ALDataset created ✓
Number of classes: 5
Train size: 2500
Val size: 500
Test size: 500
Single image shape: torch.Size([3, 32, 32])
Single label shape: torch.Size([5])
Batch image shape: torch.Size([8, 3, 32, 32])
Batch label s

In [15]:
# from LabelBench.dataset.datasets import get_dataset
# from torch.utils.data import DataLoader

# dataset = get_dataset("notmnist", "../data")

# print("ALDataset created ✓")
# print("Number of classes:", dataset.get_num_classes())

# train_ds = dataset.train_dataset
# x, y = train_ds[0]
# print("Single image shape:", x.shape)
# print("Single label shape:", y.shape)

# bx, by = next(iter(DataLoader(train_ds, batch_size=8)))
# print("Batch image shape:", bx.shape)
# print("Batch label shape:", by.shape)

# print("notmnist ✅ DATASET + DATALOADER WORKING")

In [16]:
# from LabelBench.dataset.datasets import get_dataset
# from torch.utils.data import DataLoader

# ds = get_dataset("fashionmnist", "./data")

# print("Classes:", ds.get_num_classes())
# x, y = ds.train_dataset[0]
# print(x.shape, y.shape)

# bx, by = next(iter(DataLoader(ds.train_dataset, batch_size=8)))
# print(bx.shape, by.shape)

# print("fashionmnist ✅ WORKING")

In [17]:
# from LabelBench.dataset.datasets import get_dataset
# from torch.utils.data import DataLoader

# dataset = get_dataset("cub200", "../data")
# train_ds = dataset.train_dataset

# print("Train length:", len(train_ds))
# x, y = train_ds[0]
# print("Single image:", x.shape)
# print("Single label:", y.shape)

# bx, by = next(iter(DataLoader(train_ds, batch_size=8)))
# print("Batch images:", bx.shape)
# print("Batch labels:", by.shape)

# print("✅ CUB200 DATASET + DATALOADER WORKING")

In [18]:
# !python - << 'EOF'
# from LabelBench.dataset.datasets import get_dataset
# from torch.utils.data import DataLoader

# dataset = get_dataset("tinyimagenet", "./data")
# train_ds = dataset.train_dataset

# print("Length:", len(train_ds))

# x, y = train_ds[0]
# print("Single image:", x.shape)
# print("Single label:", y.shape)

# bx, by = next(iter(DataLoader(train_ds, batch_size=8)))
# print("Batch images:", bx.shape)
# print("Batch labels:", by.shape)

# print("✅ TinyImageNet COMPLETELY FIXED")

In [19]:
# from LabelBench.dataset.datasets import get_dataset
# from torch.utils.data import DataLoader

# dataset = get_dataset("splitcifar10_0", "./data")

# print("ALDataset created ✓")
# print("Number of classes:", dataset.get_num_classes())
# train_ds = dataset.train_dataset
# x, y = train_ds[0]
# print("Single image:", x.shape)
# print("Single label:", y.shape)

# bx, by = next(iter(DataLoader(train_ds, batch_size=8)))
# print("Batch images:", bx.shape)
# print("Batch labels:", by.shape)

# print("✅ SPLIT CIFAR-10 WORKING")