# References
* https://www.pinecone.io/learn/series/image-search/vision-transformers/
* https://huggingface.co/google/vit-base-patch16-224
* https://huggingface.co/docs/datasets/image_load
* https://towardsdatascience.com/image-classification-with-vision-transformer-8bfde8e541d4

In [None]:
from datasets import load_dataset
import evaluate
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer

In [None]:
# Prefer CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Print the selected device so it is obvious whether a GPU was picked up.
print(device)

# Create dataset
Here, we load the images from a local folder with the `datasets` library and split them into train and test sets.

In [None]:
# Resolve the current user's home directory, then point at <home>/datasets/cats_and_dogs.
user = os.path.expanduser("~")
dataset_path = os.path.join(os.path.join(user, "datasets"), "cats_and_dogs")

In [None]:
# Load the images found under dataset_path using the "imagefolder" loader.
dataset = load_dataset("imagefolder", data_dir=dataset_path)

# Hold out 10% of the "train" split for evaluation. A fixed seed keeps the
# split identical across notebook re-runs, so metrics stay comparable and a
# run resumed from a checkpoint is never evaluated on images it trained on.
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)

In [None]:
# Display the dataset object produced by the train/test split.
dataset

In [None]:
# Class names as declared by the 'label' feature of the train split.
labels = dataset["train"].features['label'].names
# Derive the class count from the declared names instead of from the label
# values that happen to occur in the train split: this avoids reading the
# whole label column and guarantees num_classes always matches the
# id2label/label2id mappings handed to the model.
num_classes = len(labels)
# Two-way mapping between class names and integer ids.
label2id = {label: i for i, label in enumerate(labels)}
id2label = dict(enumerate(labels))
# Show the class count and the class names.
num_classes, labels

# Processor

In [None]:
# Checkpoint identifier shared by the image processor and the model.
model_name_or_path = "google/vit-base-patch16-224-in21k"
# Load the image processor that belongs to that checkpoint.
processor = ViTImageProcessor.from_pretrained(
    model_name_or_path,
)

In [None]:
# Display the loaded image processor to inspect its configuration.
processor

In [None]:
def preprocess(batch):
    """Turn a batch of PIL images into model-ready pixel values.

    Args:
        batch: Mapping with an 'image' list of PIL images and a matching
            'label' list.

    Returns:
        The processor output (tensors requested as PyTorch, 'pt') with the
        batch labels added under the 'label' key.
    """
    # Image folders frequently mix grayscale, palette, RGBA and CMYK files.
    # Force every image to three-channel RGB so the processor always receives
    # the channel layout it expects.
    images = [image.convert("RGB") for image in batch['image']]
    inputs = processor(images, return_tensors='pt')
    # Carry the labels along so the collate function can pick them up.
    inputs['label'] = batch['label']
    return inputs

In [None]:
# Attach preprocess as a transform to both splits; "train" is handled first, then "test".
prepared_train, prepared_test = (
    dataset[split_name].with_transform(preprocess)
    for split_name in ("train", "test")
)

# Loading Model

In [None]:
# Load the checkpoint with a classification head sized for this dataset and
# with the name <-> id mappings stored in the model config.
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    label2id=label2id,
    id2label=id2label,
    num_labels=num_classes,
)

# Collate function

In [None]:
def collate_fn(batch):
    """Merge a list of examples into one batch dictionary.

    Args:
        batch: List of examples, each with a 'pixel_values' tensor and a
            'label' value.

    Returns:
        Dict with 'pixel_values' (the per-example tensors stacked along a new
        leading dimension) and 'labels' (a tensor built from the labels).
    """
    pixel_values = torch.stack([example['pixel_values'] for example in batch])
    labels = torch.tensor([example['label'] for example in batch])
    return {'pixel_values': pixel_values, 'labels': labels}

# Define metrics

In [None]:
# Accuracy metric object, loaded once and reused by compute_metrics.
metric = evaluate.load("accuracy")


def compute_metrics(p):
    """Compute accuracy for one evaluation pass.

    Args:
        p: Object exposing `predictions` (per-class scores, reduced with
            argmax over axis 1) and `label_ids` (reference labels).

    Returns:
        Whatever `metric.compute` returns for the predicted class ids.
    """
    predicted_ids = np.argmax(p.predictions, axis=1)
    return metric.compute(predictions=predicted_ids, references=p.label_ids)

# Training Arguments

In [None]:
# Core hyperparameters.
lr = 1e-4
batch = 8
epochs = 10

training_args = TrainingArguments(
    # Where outputs are written.
    output_dir="models",
    # Optimisation settings.
    learning_rate=lr,
    per_device_train_batch_size=batch,
    num_train_epochs=epochs,
    # Evaluate and save on the same 100-step cadence; log every 10 steps.
    # NOTE(review): newer transformers releases spell this `eval_strategy` -
    # confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=100,
    logging_steps=10,
    save_total_limit=5,
    load_best_model_at_end=True,
    # Presumably needed so the 'image' column survives for the on-the-fly
    # preprocess transform - confirm.
    remove_unused_columns=False,
    push_to_hub=False,
)

# Train Model

In [None]:
# Wire the model, arguments, data and metric function into a Trainer.
# NOTE(review): recent transformers releases prefer `processing_class` over
# `tokenizer` for passing the processor - confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=prepared_train,
    eval_dataset=prepared_test,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

In [None]:
# Run training and keep the returned result; its `.metrics` are logged and saved afterwards.
train_results = trainer.train()

In [None]:
# Save the trained model through the Trainer (presumably into
# training_args.output_dir - confirm).
trainer.save_model()
# Log the training metrics, then save them under the "train" split name.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
# Save the Trainer's own state as well.
trainer.save_state()