In [1]:
from datasets import load_dataset, Dataset, DatasetDict
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader

def get_cifar_label_dicts(cifar_dataset: "DatasetDict", label_column="coarse_label", split="train"):
    """
    Build label-name <-> id lookup tables from a CIFAR-100 DatasetDict.

    Args:
        cifar_dataset: the DatasetDict returned by ``load_dataset("uoft-cs/cifar100")``.
            It is indexed by split name, so a single ``Dataset`` will not work here.
        label_column: ClassLabel feature whose ``names`` are used. Defaults to
            "coarse_label" (20 classes); pass another label column, e.g. the
            fine-grained one, to get its mapping instead.
        split: split whose feature schema is read. Defaults to "train".

    returns
        label2id: dict mapping label name -> id as a string ("0", "1", ...)
        id2label: dict mapping id string -> label name
    """
    # Label names live in the ClassLabel feature of the chosen column.
    labels = cifar_dataset[split].features[label_column].names

    # Ids are kept as strings on both sides, as the original mapping did.
    label2id = {str(label): str(i) for i, label in enumerate(labels)}
    id2label = {str(i): str(label) for i, label in enumerate(labels)}

    return label2id, id2label

def get_cifar_dataloaders(batch_size=2, image_size=None, shuffle_test=False):
    """
    Load CIFAR-100 from the Hugging Face Hub and wrap both splits in DataLoaders.

    Each batch is a dict whose "pixel_values" entry holds the images converted by
    ``ToTensor``. The dataset's other columns (e.g. the label columns) are passed
    through unchanged and collated by DataLoader's default collate_fn.

    Args:
        batch_size: number of examples per batch, for both loaders.
        image_size: if not None, every image is resized to ``(image_size, image_size)``
            before ``ToTensor``. CIFAR images are 32x32, so a model built for larger
            inputs needs this set to its expected resolution.
        shuffle_test: whether to shuffle the test loader. Defaults to False so that
            evaluation iterates the test split in a fixed order.

    Returns:
        (train_dataloader, test_dataloader)
    """
    # Imported locally so the optional resize step needs no change to the file header.
    from torchvision.transforms import Resize

    # Load the CIFAR dataset
    cifar = load_dataset("uoft-cs/cifar100")
    cifar_train = cifar["train"]
    cifar_test = cifar["test"]

    # Per-image pipeline. It is shared by BOTH splits: random augmentations added
    # here would also perturb the test set, so give the train split its own
    # pipeline if you add any (e.g. to fight overfitting).
    transform_steps = []
    if image_size is not None:
        transform_steps.append(Resize((image_size, image_size)))
    transform_steps.append(ToTensor())
    _transforms = Compose(transform_steps)

    def preprocess_transforms(data_examples):
        # Applied on the fly to each batch of rows: replace the decoded "img"
        # column with tensor "pixel_values"; other columns pass through.
        data_examples["pixel_values"] = [_transforms(img) for img in data_examples["img"]]
        del data_examples["img"]
        return data_examples

    # Apply transformations lazily (nothing is materialised up front)
    cifar_train = cifar_train.with_transform(preprocess_transforms)
    cifar_test = cifar_test.with_transform(preprocess_transforms)

    # Create dataloaders
    train_dataloader = DataLoader(
        dataset=cifar_train,
        batch_size=batch_size,
        shuffle=True
    )

    # Shuffling the evaluation split gains nothing and makes runs harder to compare.
    test_dataloader = DataLoader(
        dataset=cifar_test,
        batch_size=batch_size,
        shuffle=shuffle_test
    )

    return train_dataloader, test_dataloader


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the CIFAR-100 loaders once; the cells below draw a batch from train_dataloader.
(train_dataloader, test_dataloader) = get_cifar_dataloaders(batch_size=4)

Found cached dataset parquet (/zhome/57/8/181461/.cache/huggingface/datasets/uoft-cs___parquet/cifar100-775d1ef257a5c668/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|██████████| 2/2 [00:00<00:00, 208.32it/s]


In [7]:
# get_cifar_dataloaders does not hand back the DatasetDict, so load it again
# here to read the label names from its feature schema.
cifar = load_dataset("uoft-cs/cifar100")
label2id, id2label = get_cifar_label_dicts(cifar_dataset=cifar)

Found cached dataset parquet (/zhome/57/8/181461/.cache/huggingface/datasets/uoft-cs___parquet/cifar100-775d1ef257a5c668/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)
100%|██████████| 2/2 [00:00<00:00, 498.14it/s]


In [8]:
# Number of distinct labels in the mapping (shown as 20 for the coarse labels).
len(label2id)

20

In [3]:
# Pull one training batch to inspect tensor shapes and smoke-test the model.
train_iterator = iter(train_dataloader)
batch = next(train_iterator)

In [6]:
# (batch, channels, height, width) of the image tensor
batch["pixel_values"].size()

torch.Size([4, 3, 32, 32])

In [9]:
from vit import VisionTransformer

# Size the classification head from the label mapping instead of hard-coding 20,
# so it stays in sync if a different label column is used.
# NOTE(review): VisionTransformer comes from the project-local `vit` module and its
# expected input resolution is not visible here. The RuntimeError in the forward-pass
# cell (size 4 vs 196 at dim 1) suggests it expects 196 patches, while the 32x32
# CIFAR images yield only 4. Confirm, then either resize the inputs or configure the
# model for 32x32 images.
vit = VisionTransformer(use_linear_patch=True, num_classes=len(label2id))

In [10]:
# NOTE(review): this call raised "size of tensor a (4) must match ... tensor b (196)".
# The model presumably expects larger inputs than these 32x32 images -- resize the
# images or reconfigure the model before relying on this cell.
pixel_values = batch["pixel_values"]
pred = vit(pixel_values)

RuntimeError: The size of tensor a (4) must match the size of tensor b (196) at non-singleton dimension 1

In [18]:
from transformers import AutoModelForImageClassification

# Load the pretrained backbone with a classification head sized to this notebook's
# labels. Without these kwargs the checkpoint keeps its original head (ImageNet-1k
# per the model card), which does not match the 20 coarse labels built above.
# ignore_mismatched_sizes=True lets the differently shaped classifier weights be
# re-initialised instead of raising; the new head must then be fine-tuned.
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
from transformers import AutoImageProcessor

# Image processor paired with the same ViT checkpoint as the model above.
vit_checkpoint = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(vit_checkpoint)