Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that comprise an image. There are many applications for image classification, such as detecting damage after a natural disaster, monitoring crop health, or helping screen medical images for signs of disease.

This guide illustrates how to:

1. Fine-tune ViT on the Food-101 dataset to classify a food item in an image.
2. Use your fine-tuned model for inference.

# Libraries

In [None]:
pip install transformers datasets evaluate accelerate torchvision

In [None]:
from datasets import load_dataset
from transformers import AutoImageProcessor, DefaultDataCollator
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Load Data

In [None]:
# Load a smaller subset for experimentation
food = load_dataset("food101", split="train[:5000]")

# Split this 5,000-example subset into train (80%) and test (20%) sets
food = food.train_test_split(test_size=0.2)
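
To make the split sizes concrete, here is a minimal plain-Python sketch of an 80/20 shuffle-and-split over 5,000 indices. It only illustrates the arithmetic and is not how `datasets` implements `train_test_split`. The fixed seed is an assumption added for reproducibility; the real method accepts a `seed` argument for the same purpose.

```python
import random

n_examples = 5000   # size of the "train[:5000]" subset
test_size = 0.2     # same fraction passed to train_test_split

# Shuffle the example indices with a fixed seed (assumption, for repeatability).
indices = list(range(n_examples))
random.Random(0).shuffle(indices)

# Carve off the test portion, keep the rest for training.
n_test = round(n_examples * test_size)
test_idx, train_idx = indices[:n_test], indices[n_test:]

print(len(train_idx), len(test_idx))  # 4000 1000
```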

In [None]:
# Inspect example
# Each example in the dataset has two fields:
# image: a PIL image of the food item
# label: the label class of the food item
food["train"][0]

In [None]:
# Create dictionaries that map each label name to an integer id and vice versa,
# so the model can look up a label name from a label id.
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label

In [None]:
# Sanity checks: convert a label id to a label name, and a label name back to an id
id2label[str(9)]

In [None]:
label2id['breakfast_burrito']
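
The mapping above stores ids as strings, so an integer class index (for example, the argmax of a model's logits) has to go through `str()` before the lookup. The sketch below shows the round trip on a three-name label list. The list is a hypothetical stand-in for `food["train"].features["label"].names`, and `predicted_index` is an illustrative value, not model output.

```python
# Hypothetical label names standing in for the dataset's real label list.
labels = ["apple_pie", "baby_back_ribs", "baklava"]

# Same construction as the loop above, written as dict comprehensions.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

# Every name should survive a name -> id -> name round trip.
round_trip_ok = all(id2label[label2id[name]] == name for name in labels)

# An integer class index must be converted to a string key first.
predicted_index = 2  # illustrative value
predicted_name = id2label[str(predicted_index)]

print(round_trip_ok, predicted_name)  # True baklava
```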

# Preprocessing

In [None]:
# Load the image processor for the ViT checkpoint; it stores the image size, mean, and std the model expects
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [None]:
# Apply a random resized crop as augmentation to make the model more robust against overfitting,
# then normalize with the image processor's mean and standard deviation.
# torchvision.transforms is used here, but any suitable image library would work.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])

# Create a preprocessing function to apply transforms and return pixel_values to be used as model inputs
def transforms(examples):
    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples

# Apply to dataset on the fly using with_transform
food = food.with_transform(transforms)
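
The cell above does two things that are easy to check by hand: it picks a crop size from the processor's `size` dictionary, and it normalizes each channel as `(value - mean) / std`. The plain-Python sketch below mirrors both steps. The `size_config`, `image_mean`, and `image_std` values are assumptions for illustration (read the real ones from `image_processor`), and `resolve_size` is a hypothetical helper name.

```python
def resolve_size(size_config):
    """Mirror the `size` expression above for a processor's size dict."""
    if "shortest_edge" in size_config:
        return size_config["shortest_edge"]
    return (size_config["height"], size_config["width"])

# Assumed values for illustration; the real ones live on image_processor.
size_config = {"height": 224, "width": 224}
image_mean = [0.5, 0.5, 0.5]
image_std = [0.5, 0.5, 0.5]

crop_size = resolve_size(size_config)                # (224, 224)
square_size = resolve_size({"shortest_edge": 256})   # 256

# One RGB pixel after ToTensor, which scales 8-bit values into [0, 1].
pixel = [0.0, 0.5, 1.0]

# Per-channel normalization, the same formula Normalize applies.
normalized = [(v - m) / s for v, m, s in zip(pixel, image_mean, image_std)]

print(crop_size, square_size, normalized)  # (224, 224) 256 [-1.0, 0.0, 1.0]
```

With a mean and standard deviation of 0.5, normalization maps the `[0, 1]` range onto `[-1, 1]`.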

In [None]:
# create a batch of examples using DefaultDataCollator
# NB: DefaultDataCollator does not perform additional preprocessing such as dynamic padding
data_collator = DefaultDataCollator()
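
To see what the collator returns, the sketch below passes it two hand-built examples shaped like the output of `transforms`. The zero and one tensors are placeholders rather than real images, and the `(3, 224, 224)` shape is an assumption about the crop size. It requires `torch` and `transformers` to be installed.

```python
import torch
from transformers import DefaultDataCollator

data_collator = DefaultDataCollator()

# Placeholder examples: a (channels, height, width) tensor plus an integer label.
examples = [
    {"pixel_values": torch.zeros(3, 224, 224), "label": 0},
    {"pixel_values": torch.ones(3, 224, 224), "label": 1},
]

batch = data_collator(examples)

# Tensors are stacked along a new batch dimension, and "label" becomes "labels".
print(batch["pixel_values"].shape)
print(batch["labels"])
```

Because every image is already cropped to the same size, stacking is enough and no padding is needed.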