In [15]:
# NOTE: Make sure you install pytorch, torchvision etc. before running this script
!pip install -r requirements.txt

Collecting evaluate (from -r requirements.txt (line 3))
  Using cached evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Collecting responses<0.19 (from evaluate->-r requirements.txt (line 3))
  Using cached responses-0.18.0-py3-none-any.whl (38 kB)
Using cached evaluate-0.4.1-py3-none-any.whl (84 kB)
Installing collected packages: responses, evaluate
Successfully installed evaluate-0.4.1 responses-0.18.0


In [None]:
import evaluate
import numpy as np
import torch
from datasets import load_dataset, load_metric
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomHorizontalFlip,
    Resize,
    ToTensor,
)
from transformers import AutoImageProcessor, AutoModelForImageClassification, Trainer, TrainingArguments

In [None]:
# Run configuration
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"  # checkpoint id passed to from_pretrained
batch_size = 8  # per-device batch size for both training and evaluation
out_dir = "./out"  # where checkpoints and metrics are written

In [None]:
# Load the first 3,000 Food-101 training images and hold out 20% of them for validation.
ds = load_dataset("food101", split="train[:3000]")
# Fixed seed so the train/validation split is identical on every Restart & Run All
# (same value as the seed passed to TrainingArguments).
splits = ds.train_test_split(test_size=0.2, seed=42)

train_ds = splits["train"]
val_ds = splits["test"]

# Class name <-> class index maps for the model config. Indices are stored as
# strings, so id2label is looked up with str(label).
labels = train_ds.features["label"].names
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

In [None]:
# Summarise both splits (features and number of rows).
for split_ds in (train_ds, val_ds):
    print(split_ds)

In [None]:
# Peek at one raw training example (no transform has been attached yet).
first_example = train_ds[0]
print(first_example)

In [None]:
# Visual sanity check: show the first training image (resized for display only)
# and, as the cell's output, its human-readable class name.
first_example = train_ds[0]
display(first_example.get("image").resize((224, 224)))
id2label[str(first_example.get("label"))]

In [None]:
# Define the image transforms, using the normalisation statistics and input
# resolution of the checkpoint's image processor.
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

try:
    size = (image_processor.size["height"], image_processor.size["width"])
except KeyError:
    # Some processors only specify {"shortest_edge": N}. Resize to an N x N square
    # so every tensor has the same shape and collate_fn can torch.stack them.
    edge = image_processor.size["shortest_edge"]
    size = (edge, edge)

# Training: resize, plus a random horizontal flip for augmentation.
train_transforms = Compose(
    [
        Resize(size),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

# Validation: deterministic (no random augmentation), so evaluation metrics are
# reproducible and comparable across epochs.
val_transforms = Compose(
    [
        Resize(size),
        ToTensor(),
        normalize,
    ]
)

def preprocess_train(example_batch):
    """Add ``pixel_values`` built with ``train_transforms`` to a batch of rows.

    Each entry of ``example_batch["image"]`` is converted to RGB and passed through
    ``train_transforms``. The batch dict is updated in place and returned, so the
    function can be registered with ``Dataset.set_transform``.
    """
    pixel_values = []
    for image in example_batch["image"]:
        pixel_values.append(train_transforms(image.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

def preprocess_val(example_batch):
    """Add ``pixel_values`` built with ``val_transforms`` to a batch of rows.

    Each entry of ``example_batch["image"]`` is converted to RGB and passed through
    ``val_transforms``. The batch dict is updated in place and returned, so the
    function can be registered with ``Dataset.set_transform``.
    """
    pixel_values = []
    for image in example_batch["image"]:
        pixel_values.append(val_transforms(image.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

# Register the preprocessing functions as on-the-fly transforms; they are applied
# whenever rows of the corresponding split are indexed.
for dataset, transform_fn in ((train_ds, preprocess_train), (val_ds, preprocess_val)):
    dataset.set_transform(transform_fn)


In [None]:
# Sanity check: with set_transform active, indexing a row now also returns the
# "pixel_values" tensor produced by preprocess_train.
train_ds[0]

In [None]:
# Load the pre-trained Swin-Tiny checkpoint as an image classifier for our label set.
# ignore_mismatched_sizes=True lets from_pretrained re-initialise any weights whose
# shapes differ from the checkpoint (presumably the classification head, whose size
# follows id2label here).
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
# Set training arguments
train_args = TrainingArguments(
    output_dir=out_dir,  # configured output directory, not a repeated literal
    seed=42,
    # Evaluate and checkpoint once per epoch. The two strategies must match for
    # load_best_model_at_end. NOTE(review): newer transformers releases rename
    # `evaluation_strategy` to `eval_strategy`; switch if your version rejects it.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    logging_steps=10,
    weight_decay=0.01,
    logging_dir="./logs",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # key in the dict returned by compute_metrics
    # Keep the raw "image" column: preprocess_train / preprocess_val read it to build
    # pixel_values, so the Trainer must not drop it as an unused column.
    remove_unused_columns=False,
)


In [None]:
# Batch collation used by the Trainer's data loaders.
def collate_fn(examples):
    """Combine transformed rows into a single batch dict.

    Args:
        examples: sequence of rows, each holding a ``pixel_values`` tensor and a
            ``label``.

    Returns:
        dict with ``"pixel_values"`` (the per-row tensors stacked along a new leading
        dimension) and ``"labels"`` (a 1-D tensor of the row labels).
    """
    images, targets = [], []
    for example in examples:
        images.append(example["pixel_values"])
        targets.append(example["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

In [None]:
# Accuracy metric from the `evaluate` library (installed via requirements.txt).
# `datasets.load_metric` is deprecated and was removed in datasets 3.0.
metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Compute accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: the named tuple the Trainer passes in, with ``predictions`` (the
            model logits as a NumPy array) and ``label_ids`` (the ground-truth labels
            as a NumPy array).

    Returns:
        The dict produced by ``metric.compute``.
    """
    logits = eval_pred.predictions
    predicted_ids = np.argmax(logits, axis=1)  # highest-scoring class per example
    return metric.compute(predictions=predicted_ids, references=eval_pred.label_ids)

In [None]:
# Wire model, arguments, data and metrics together. collate_fn builds each batch
# and compute_metrics scores every evaluation run.
trainer = Trainer(
    model=model,
    args=train_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

In [None]:
# Baseline: evaluate the model before any fine-tuning.
baseline_metrics = trainer.evaluate(val_ds)
trainer.log_metrics("eval", baseline_metrics)

# Fine-tune, then persist the model, training metrics and trainer state.
train_results = trainer.train()
trainer.save_model()
# Save the image processor next to the weights so the output directory can be
# reloaded with AutoImageProcessor.from_pretrained for inference.
image_processor.save_pretrained(train_args.output_dir)
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Final evaluation. With load_best_model_at_end=True this scores the best checkpoint.
metrics = trainer.evaluate(val_ds)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)