Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that comprise an image. There are many applications for image classification, such as detecting damage after a natural disaster, monitoring crop health, or helping screen medical images for signs of disease.

This guide illustrates how to:

1. Fine-tune a Vision Transformer (ViT) on the Food-101 dataset to classify the food item in an image.
2. Use your fine-tuned model for inference.

# Libraries

In [1]:
%pip install transformers datasets evaluate accelerate

Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset

# Load Data

In [3]:
# Load only the first 5,000 of the 75,750 training images to keep experiments fast
food = load_dataset("food101", split="train[:5000]")

# Split this subset into train and test sets (80% / 20%).
# No seed is set, so the split can differ between runs.
food = food.train_test_split(test_size=0.2)
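`train_test_split(test_size=0.2)` shuffles the rows and holds out 20% of them, so the 5,000-image subset becomes 4,000 training and 1,000 test examples. The stdlib sketch below mimics that idea on plain indices; `split_indices` is a hypothetical helper, not part of `datasets`, and its shuffle order will not match the library's. It also shows why a fixed seed matters: the `datasets` method accepts a `seed` argument for the same reason.

```python
import math
import random


def split_indices(n, test_size=0.2, seed=None):
    """Hypothetical helper: shuffle indices 0..n-1 and hold out a test fraction.

    Mirrors the idea behind Dataset.train_test_split, not its exact shuffle.
    """
    n_test = math.ceil(test_size * n)  # size of the held-out test set
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # seeded RNG gives a reproducible split
    return indices[n_test:], indices[:n_test]  # (train, test)


train_idx, test_idx = split_indices(5000, test_size=0.2, seed=42)
print(len(train_idx), len(test_idx))  # 4000 1000

# The two parts are disjoint and together cover every example.
assert set(train_idx).isdisjoint(test_idx)
assert sorted(train_idx + test_idx) == list(range(5000))

# Same seed, same split: this is what passing seed= buys you.
assert split_indices(5000, seed=42) == split_indices(5000, seed=42)
```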


In [4]:
# Inspect the first training example.
# Each example in the dataset has two fields:
#   image: a PIL image of the food item
#   label: the integer class id of the food item
food["train"][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
 'label': 53}

In [5]:
# Create two dictionaries: label2id maps each label name to its id (stored as a string),
# and id2label maps the id back to the name. The model uses them to turn a predicted
# label id into a human-readable label name.
labels = food["train"].features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = str(i)
    id2label[str(i)] = label
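The same two mappings can be built with dict comprehensions. To run without downloading anything, the sketch below uses a hand-typed list of ten class names; that list is an assumption standing in for `food["train"].features["label"].names` (it agrees with the `id2label[str(9)]` sanity check below). As in the loop above, ids are stored as strings, so every lookup must go through `str()`.

```python
# Hand-typed stand-in for food["train"].features["label"].names
# (assumed to be the first ten Food-101 classes, in alphabetical order).
labels = [
    "apple_pie", "baby_back_ribs", "baklava", "beef_carpaccio", "beef_tartare",
    "beet_salad", "beignets", "bibimbap", "bread_pudding", "breakfast_burrito",
]

# Name -> id and id -> name, with ids stored as strings as in the loop above.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

# Round trip: every name maps to an id that maps back to the same name.
assert all(id2label[label2id[name]] == name for name in labels)

print(id2label["9"])                  # breakfast_burrito
print(label2id["breakfast_burrito"])  # 9
```

Because the keys are strings, `id2label[9]` raises `KeyError` on these dictionaries; only `id2label["9"]` works.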

In [7]:
# sanity checks: convert the label id to a label name and vice-versa
id2label[str(9)]

'breakfast_burrito'

In [8]:
label2id['breakfast_burrito']

'9'
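At inference time these mappings turn a model's output back into a readable label: the predicted class is the index of the largest logit, and `id2label` translates that index into a name. The sketch below uses a hypothetical three-class mapping and made-up logits instead of a real ViT forward pass, so only the lookup logic is meaningful; the ids here are not Food-101's.

```python
# Hypothetical 3-class mapping (ids are made up, not Food-101's) and illustrative logits.
id2label = {"0": "pizza", "1": "sushi", "2": "ramen"}
logits = [0.4, 2.1, -0.7]  # made-up scores, not real model output

# argmax over the logits: the index of the highest score.
predicted_id = max(range(len(logits)), key=lambda i: logits[i])

# The mapping uses string keys, so convert the integer index before the lookup.
predicted_label = id2label[str(predicted_id)]
print(predicted_id, predicted_label)  # 1 sushi
```

If you look labels up through a loaded model's `config.id2label` instead, check the key type first: `transformers` configs typically store those ids as integers.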