# transformers: Image classification

In [None]:
# render matplotlib figures inline, directly below the cell that creates them
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    pipeline,
    AutoImageProcessor,
    AutoModelForImageClassification
)

## Load data

In [None]:
# load data: the first 5000 training images of Food-101, split 80/20 into train/test
# (alternative: ds = load_dataset('beans') -- presumably needs adapting, since loading
#  without `split=` returns a DatasetDict rather than a single Dataset)
ds = load_dataset('food101', split='train[:5000]')
# a fixed seed makes the random train/test split, and every result derived from it,
# reproducible across runs
ds = ds.train_test_split(test_size=0.2, seed=42)

print(ds)

In [None]:
# human-readable class names, indexed by the integer values of the 'label' column
label_feature = ds['train'].features['label']
label_names = label_feature.names

print(label_names)

In [None]:
# show example images: a random sample of training examples titled with their
# ground-truth class names
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(6, 5))
random_ids = np.random.choice(len(ds['train']), size=axes.size, replace=False).tolist()
for random_idx, ax in zip(random_ids, axes.ravel()):
    # fetch the row once: every ds['train'][idx] access materializes the whole
    # example (image included), so indexing it twice decoded the image twice
    example = ds['train'][random_idx]
    pil_image = example['image']
    label_idx = example['label']
    ax.imshow(np.asarray(pil_image))
    ax.set_title(label_names[label_idx])
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Load model

In [None]:
# Hub checkpoint shared by the preprocessor, the model and the pipeline below
model_name = "google/vit-base-patch16-224"
# alternative checkpoint: model_name = "facebook/dinov2-small-imagenet1k-1-layer"

In [None]:
# preprocessor that prepares raw images the way this checkpoint expects
processor = AutoImageProcessor.from_pretrained(model_name)

# classification head comes from the checkpoint, i.e. it was trained on a
# different dataset than the one loaded above; eval() switches off
# training-only layer behaviour before inference
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    device_map='auto',
).eval()

In [None]:
# load pipeline (preprocessor, model and postprocessor); reusing the objects
# loaded above, instead of passing `model_name`, avoids downloading and
# holding a second copy of the same weights in memory. No device argument:
# placement is presumably taken from the already-loaded model -- confirm on
# the installed transformers version
pipe = pipeline('image-classification', model=model, image_processor=processor)

## Run model

In [None]:
# take the first `batch_size` training examples; slicing the split returns
# a dict of columns (one list per column) rather than a list of rows
batch_size = 16
batch_dict = ds['train'][:batch_size]
images, labels = batch_dict['image'], batch_dict['label']  # PIL images, integer class ids

In [None]:
# turn the list of PIL images into one batched pixel tensor
preprocessed_images = processor(images, return_tensors='pt')
x = preprocessed_images['pixel_values']

# forward pass on the model's device, without tracking gradients
with torch.no_grad():
    outputs = model(x.to(model.device))
# move the logits back to the CPU for post-processing and plotting
logits = outputs.logits.cpu()

print(f'Images shape: {x.shape}')
print(f'Logits shape: {logits.shape}')

In [None]:
# get predicted labels
label_ids = logits.argmax(dim=-1)
labels = [model.config.id2label[lidx.item()] for lidx in label_ids]

print(labels)

In [None]:
# show predictions
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(6, 5))
for idx, ax in enumerate(axes.ravel()):
    image = np.asarray(images[idx])
    label = labels[idx]
    ax.imshow(image)
    ax.set_title(label)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.suptitle('Predictions')
fig.tight_layout()

## Run pipeline

In [None]:
# classify the whole list of images with a single pipeline call
# (preprocessing, forward pass and postprocessing all happen inside)
results = pipe(images)
print(results)

In [None]:
# run the pipeline's three stages by hand on the first image;
# preprocess accepts only a single input, not a batch
preprocessed_image = pipe.preprocess(images[0])
output = pipe.forward(preprocessed_image)        # model forward pass
postprocessed_result = pipe.postprocess(output)  # raw outputs -> readable predictions

print(postprocessed_result)