In [1]:
!pip install -Uq datasets transformers[torch]

In [2]:
from functools import partial
from torchvision.datasets import Food101
from transformers import ViTFeatureExtractor

# Checkpoint identifier; also reused when the classification model is loaded.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

# Per-image transform: call the feature extractor and request PyTorch tensors.
preprocessor = partial(feature_extractor, return_tensors='pt')

# Both splits share the same root directory and transform; only the train
# split is constructed with download=True.
_food101_kwargs = {'root': "food101_dataset", 'transform': preprocessor}
train_ds = Food101(split='train', download=True, **_food101_kwargs)
test_ds = Food101(split='test', **_food101_kwargs)

# Class names as exposed by the training split.
labels = train_ds.classes



In [3]:
import torch

def collate_fn(batch):
    """Assemble a list of dataset samples into one model-ready batch.

    Args:
        batch: Sequence of samples. Element 0 of each sample is a mapping
            with a 'pixel_values' entry; element 1 is the sample's label.

    Returns:
        dict with 'pixel_values' (the first entry of every sample's
        'pixel_values', stacked along a new leading axis) and 'labels'
        (a tensor built from the per-sample labels).
    """
    images = []
    targets = []
    for sample in batch:
        # Index [0] takes the first entry of this sample's pixel_values, so
        # stacking below yields one row per sample.
        images.append(sample[0]['pixel_values'][0])
        targets.append(sample[1])

    pixel_values = torch.stack(images)
    label_tensor = torch.tensor(targets)
    return {'pixel_values': pixel_values, 'labels': label_tensor}

In [4]:
import numpy as np


def compute_metrics(p):
    """Compute top-1 accuracy for one evaluation pass.

    Args:
        p: Object exposing ``predictions`` (per-class scores, one row per
            sample) and ``label_ids`` (the reference class indices).

    Returns:
        dict: ``{"accuracy": value}`` where ``value`` is the fraction of rows
        whose highest-scoring class equals the reference label.
    """
    # Accuracy is computed directly with NumPy instead of going through
    # `datasets.load_metric`, which is deprecated and no longer available in
    # datasets 3.0+ (the install cell upgrades to the latest release).
    predicted = np.argmax(p.predictions, axis=1)
    references = np.asarray(p.label_ids)
    return {"accuracy": float(np.mean(predicted == references))}


  metric = load_metric("accuracy")


In [5]:
from transformers import ViTForImageClassification

# String form of each class index, in the same order as `labels`.
_class_ids = [str(index) for index in range(len(labels))]

# Load the checkpoint with a classification head sized to the number of
# classes, and attach both directions of the id <-> class-name mapping.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label=dict(zip(_class_ids, labels)),
    label2id=dict(zip(labels, _class_ids)),
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./vit-base-food101",
    # Optimisation
    num_train_epochs=1,
    learning_rate=2e-4,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    fp16=True,
    # Evaluation and checkpointing share the same 100-step cadence
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    # Logging and reporting
    logging_steps=10,
    report_to='tensorboard',
    push_to_hub=False,
    # Miscellaneous
    remove_unused_columns=False,
)


In [7]:
from transformers import Trainer

trainer = Trainer(
    # What to train and with which settings.
    model=model,
    args=training_args,
    # Data: training split, evaluation split, and the batching function.
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=collate_fn,
    # Callback that turns predictions into reported metrics.
    compute_metrics=compute_metrics,
    # NOTE(review): the feature extractor is passed in the `tokenizer` slot,
    # presumably so it is stored with the checkpoints; confirm against the
    # Trainer documentation for the installed version.
    tokenizer=feature_extractor,
)

In [8]:
# Baseline: evaluate on the Trainer's eval_dataset before trainer.train() has
# been called, to get a reference point for the fine-tuned result.
trainer.evaluate()

{'eval_loss': 4.633031368255615,
 'eval_accuracy': 0.009346534653465346,
 'eval_runtime': 440.1755,
 'eval_samples_per_second': 57.363,
 'eval_steps_per_second': 0.897}

In [None]:
# For interactive debugging
# (optional: uncomment the three lines below to install ipdb and switch on
# automatic post-mortem debugging before running the training call)
# !pip install -Uqq ipdb
# import ipdb
# %pdb on

# Run training; the returned object's `.metrics` is logged and saved below.
train_results = trainer.train()
# Save the model first, then log and save the training metrics.
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
# Finally save the trainer's own state.
trainer.save_state()

Step,Training Loss,Validation Loss


In [None]:
# Evaluate again after training; compare against the pre-training evaluation
# run earlier in the notebook.
trainer.evaluate()