In [None]:
!pip install -q transformers datasets

## Preprocessing the data

In [14]:
from transformers import ViTImageProcessor 

# Load the image processor published alongside the ViT checkpoint. Its
# normalisation statistics and target size are reused by the torchvision
# transforms defined further down.
feature_extractor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

# Per-channel mean / std used for normalisation.
image_mean, image_std = feature_extractor.image_mean, feature_extractor.image_std

# Target resolution; `size` is read as a mapping with 'height' and 'width' keys.
image_height, image_width = (
    feature_extractor.size['height'],
    feature_extractor.size['width'],
)

In [15]:
from torchvision import transforms, datasets


# Training pipeline: random crop resized to the processor's resolution and a
# random horizontal flip as augmentation, then tensor conversion and
# normalisation with the processor's mean / std.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=(image_height, image_width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(image_mean, image_std),
    ]
)

# Evaluation pipeline: no random augmentation, only a plain resize followed
# by the same tensor conversion and normalisation.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=(image_height, image_width)),
        transforms.ToTensor(),
        transforms.Normalize(image_mean, image_std),
    ]
)

# CIFAR-10 splits, downloaded into ./data if necessary; each split is bound
# to its own transform pipeline.
train_dataset = datasets.CIFAR10(
    './data', train=True, transform=train_transform, download=True
)
test_dataset = datasets.CIFAR10(
    './data', train=False, transform=test_transform, download=True
)

# Show the number of examples in each split.
(len(train_dataset), len(test_dataset))

Files already downloaded and verified
Files already downloaded and verified


(50000, 10000)

## Define the model

In [17]:
# Class names; the position of a name in this list is used as its integer id.
label_name = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]

# Forward (id -> name) and reverse (name -> id) lookups, later handed to the
# model so its config carries human-readable labels.
id2label = dict(enumerate(label_name))
label2id = {name: idx for idx, name in id2label.items()}
id2label

{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}

In [18]:
from transformers import ViTForImageClassification

# Instantiate the ViT classifier from the pretrained checkpoint, sizing the
# classification head to our label set and attaching the id <-> name maps.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(label_name),
    id2label=id2label,
    label2id=label2id,
)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
from transformers import TrainingArguments, Trainer

import inspect

# Metric used to pick the best checkpoint (see `metric_for_best_model`).
metric_name = "accuracy"

# The per-epoch evaluation switch is called `evaluation_strategy` in older
# transformers releases and `eval_strategy` in newer ones (which reject the
# old spelling). Pick whichever keyword the installed version accepts so this
# cell runs on both.
_training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
_eval_strategy_key = (
    "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"
)

args = TrainingArguments(
    "test-cifar-10",  # first positional argument: output directory
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
    # Reload the checkpoint with the best `metric_name` once training ends;
    # this needs evaluation and saving to happen on the same schedule.
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    # The datasets here are torchvision datasets yielding (image, label)
    # tuples, so there are no named columns for the Trainer to prune.
    remove_unused_columns=False,
    **{_eval_strategy_key: "epoch"},
)

In [20]:
import numpy as np


def compute_metrics(eval_pred):
    """Compute top-1 accuracy for the Trainer's evaluation loop.

    `datasets.load_metric` is no longer available in recent `datasets`
    releases, so accuracy is computed directly with numpy instead of going
    through a metric object.

    Args:
        eval_pred: A pair ``(predictions, labels)``. ``predictions`` holds the
            per-class scores and is reduced with ``argmax`` over axis 1
            (assumed shape ``(num_examples, num_classes)``); ``labels`` holds
            the reference class ids.

    Returns:
        ``{"accuracy": float}`` - the fraction of examples whose highest
        scoring class equals the reference label.
    """
    predictions, labels = eval_pred
    predicted_ids = np.argmax(predictions, axis=1)
    accuracy = np.mean(predicted_ids == np.asarray(labels))
    # Cast to a plain Python float so the value serialises cleanly in logs.
    return {"accuracy": float(accuracy)}

In [36]:
import torch

def collate_fn(examples):
    """Assemble a list of ``(image, label)`` pairs into one model-ready batch.

    Args:
        examples: Sequence of items whose element 0 is an image tensor and
            whose element 1 is the class label.

    Returns:
        Dict with ``"pixel_values"`` (the images stacked along a new leading
        dimension) and ``"labels"`` (a tensor built from the labels).
    """
    images = []
    targets = []
    for sample in examples:
        images.append(sample[0])
        targets.append(sample[1])
    return {
        "pixel_values": torch.stack(images),
        "labels": torch.tensor(targets),
    }


import inspect

# Older transformers releases take the image processor via `tokenizer=`;
# newer ones renamed that argument to `processing_class` and deprecate the
# old name. Use whichever keyword the installed Trainer exposes.
_processor_key = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

# NOTE(review): `eval_dataset` is the CIFAR-10 test split, and `args` enables
# `load_best_model_at_end`, so checkpoint selection sees the test data.
# Consider carving a validation split out of the training set instead.
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    **{_processor_key: feature_extractor},
)

## Train the model

In [37]:
# Run fine-tuning with the Trainer configured above; evaluation is scheduled
# once per epoch, which produces the per-epoch table shown in the output.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.3951,0.166551,0.9793
2,0.2733,0.099381,0.9834
3,0.2345,0.083667,0.9855


TrainOutput(global_step=4689, training_loss=0.4003512040209633, metrics={'train_runtime': 2975.2438, 'train_samples_per_second': 50.416, 'train_steps_per_second': 1.576, 'total_flos': 1.16246318856192e+19, 'train_loss': 0.4003512040209633, 'epoch': 3.0})

## Evaluation

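Finally, let's evaluate the model on the test set. Note that this is the same split that was passed as `eval_dataset` during training (and used to select the best checkpoint), so the numbers below match the last validation row above rather than being a fully independent estimate: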

In [38]:
# Run prediction over the test split; the returned object carries a
# `.metrics` attribute, which is printed in the following cell.
outputs = trainer.predict(test_dataset)

In [39]:
# Show the metrics computed during `trainer.predict` (loss, accuracy, timing).
print(outputs.metrics)

{'test_loss': 0.08366703242063522, 'test_accuracy': 0.9855, 'test_runtime': 65.0171, 'test_samples_per_second': 153.806, 'test_steps_per_second': 4.814}
