In [1]:
!pip install datasets
!pip install transformers
!pip install evaluate
from datasets import load_dataset



In [2]:
# Fixed seed so the train/test split -- and every metric computed on it -- is
# identical across kernel restarts.
SEED = 42

# First 5,000 rows of the Food-101 train split.
food = load_dataset("food101", split="train[:5000]")
# Hold out 20% for evaluation (returns a DatasetDict with "train"/"test").
food = food.train_test_split(test_size=0.2, seed=SEED)
# Peek at one training example.
food["train"][0]

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512>,
 'label': 79}

In [3]:
# Class names of the "label" column, indexed by class id (e.g. 79 -> 'prime_rib').
labels = (
    food["train"]
    .features["label"]
    .names
)

In [4]:
# Lookups between class names and class ids. Ids are kept as strings
# ("0", "1", ...) in both maps; they are handed to the model config below.
label2id = {name: str(idx) for idx, name in enumerate(labels)}
id2label = {str(idx): name for idx, name in enumerate(labels)}

In [5]:
# Decode the label of the training example shown above instead of hard-coding
# its id: the id printed earlier depends on the train/test split.
id2label[str(food["train"][0]["label"])]

'prime_rib'

In [6]:
from transformers import AutoImageProcessor

In [7]:
# Pretrained ViT checkpoint. Its image processor supplies the normalization
# statistics and input size used to build the torchvision transforms below.
checkpoint = "google/vit-base-patch16-224-in21k"
image_processor = AutoImageProcessor.from_pretrained(
    pretrained_model_name_or_path=checkpoint,
)

In [8]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

In [9]:
# Normalization statistics come from the checkpoint's image processor so the
# tensors match what the pretrained model was trained with.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# The processor describes its input size either as {"shortest_edge": n} or as
# {"height": h, "width": w}; RandomResizedCrop accepts an int or an (h, w) tuple.
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])

# Random crop/rescale -> tensor -> normalize.
# NOTE(review): this pipeline is random; using it on the eval split makes
# validation metrics noisy.
_transforms = Compose(
    [
        RandomResizedCrop(size),
        ToTensor(),
        normalize,
    ]
)

In [10]:
def transforms(examples):
    """Batch transform used with ``Dataset.with_transform``.

    Converts every PIL image in ``examples["image"]`` to RGB, runs it through
    ``_transforms`` (random resized crop -> tensor -> normalize) and stores the
    results, in order, under ``"pixel_values"``. The raw ``"image"`` column is
    then removed; every other column (e.g. ``"label"``) is left untouched.

    The batch dict is modified in place and also returned.
    """
    pixel_values = []
    for picture in examples["image"]:
        pixel_values.append(_transforms(picture.convert("RGB")))
    examples["pixel_values"] = pixel_values
    examples.pop("image")
    return examples

In [11]:
from torchvision.transforms import CenterCrop, Resize

# Deterministic preprocessing for evaluation. The random crops in `_transforms`
# would make validation accuracy vary from run to run. They could also mislead
# best-checkpoint selection (`load_best_model_at_end`).
_val_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])


def val_transforms(examples):
    """Eval-split counterpart of `transforms`: same output columns, no randomness."""
    examples["pixel_values"] = [_val_transforms(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


# Transforms are applied lazily on access. Augment the train split only; the
# test split gets the deterministic pipeline above. Re-running this cell is
# safe: `with_transform` replaces, rather than stacks, the previous transform.
food["train"] = food["train"].with_transform(transforms)
food["test"] = food["test"].with_transform(val_transforms)

In [12]:
from transformers import DefaultDataCollator

# Collates the per-example dicts produced by `transforms` (pixel_values + label)
# into PyTorch batches; no padding is needed because every crop has the same size.
data_collator = DefaultDataCollator(return_tensors="pt")

In [13]:
import evaluate

# Accuracy metric used by `compute_metrics` (and, via its "accuracy" key, for
# best-checkpoint selection).
accuracy = evaluate.load(path="accuracy")

In [14]:
import numpy as np


def compute_metrics(eval_pred):
    """Trainer metric hook: accuracy of arg-max predictions.

    Args:
        eval_pred: ``(logits, label_ids)`` pair supplied by the Trainer. Logits
            are assumed to be laid out ``[example, class]``, so the arg-max is
            taken over axis 1.

    Returns:
        Whatever ``accuracy.compute`` returns (a dict keyed by ``"accuracy"``).
    """
    logits, references = eval_pred
    return accuracy.compute(
        predictions=np.argmax(logits, axis=1),
        references=references,
    )

In [15]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Pretrained backbone plus a freshly initialized classification head with one
# output per class. The label maps are stored in the model config, so
# predictions can be decoded to names later. The warning below about newly
# initialized `classifier.*` weights is expected.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    label2id=label2id,
    id2label=id2label,
    num_labels=len(id2label),
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
pip install accelerate -U



In [17]:
!pip install transformers[torch]
training_args = TrainingArguments(
    output_dir="food_model",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
)



In [18]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=food["train"],
    eval_dataset=food["test"],
    # The image processor is saved alongside the model. The pipeline /
    # AutoImageProcessor inference cells load it from the same directory.
    # `processing_class` supersedes the deprecated `tokenizer=` argument.
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

In [19]:
# Fine-tune. Evaluation and checkpointing run every epoch. Because
# load_best_model_at_end=True, the model ends up holding the weights of the
# epoch with the best validation accuracy.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
0,2.673,2.509524,0.81
2,1.6104,1.624572,0.88


TrainOutput(global_step=186, training_loss=2.420728478380429, metrics={'train_runtime': 523.071, 'train_samples_per_second': 22.941, 'train_steps_per_second': 0.356, 'total_flos': 9.232831524962304e+17, 'train_loss': 2.420728478380429, 'epoch': 2.98})

In [20]:
# Persist the (best) fine-tuned model, together with the image processor given
# to the Trainer, to `best_food/`; the inference cells below load from there.
trainer.save_model("best_food")

In [21]:
# A few images from the separate validation split (training used only the
# train split); the first one is used for the inference demos below.
ds = load_dataset("food101", split="validation[:10]")
image = ds[0]["image"]

In [22]:
from transformers import pipeline
# End-to-end inference: the pipeline loads the fine-tuned weights and the saved
# image processor from `best_food/` and returns the top-scoring labels.
classifier = pipeline(task="image-classification", model="best_food")
classifier(image)

[{'score': 0.28554975986480713, 'label': 'beignets'},
 {'score': 0.01596302166581154, 'label': 'chicken_wings'},
 {'score': 0.015536390244960785, 'label': 'bruschetta'},
 {'score': 0.01518327184021473, 'label': 'prime_rib'},
 {'score': 0.014482281170785427, 'label': 'hamburger'}]

In [23]:
import torch
from transformers import AutoImageProcessor

# Manual inference, step 1: preprocess the image with the processor saved
# next to the fine-tuned model, returning PyTorch tensors.
image_processor = AutoImageProcessor.from_pretrained("best_food")
inputs = image_processor(images=image, return_tensors="pt")

In [24]:
from transformers import AutoModelForImageClassification

# Manual inference, step 2: run the fine-tuned model without tracking gradients.
model = AutoModelForImageClassification.from_pretrained("best_food")
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

In [25]:
# Manual inference, step 3: highest-scoring class id -> human-readable name
# via the label map stored in the model config.
predicted_label = logits.argmax(dim=-1).item()
model.config.id2label[predicted_label]

'beignets'