<a href="https://colab.research.google.com/github/mosh98/ViT_fine_tuing/blob/main/ViT_finetuing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Simplest way to fine-tune a ViT**

In [None]:
!pip install -q transformers datasets

# Loading the dataset
In this case we are using the CIFAR-10 dataset:
essentially images, each labelled with one of 10 classes.


In [None]:
from datasets import load_dataset

# Load only a small slice of CIFAR-10 (5000 train / 2000 test images) so the
# demonstration runs quickly.
train_ds, test_ds = load_dataset('cifar10', split=['train[:5000]', 'test[:2000]'])

# Hold out 10% of the training slice for validation. The split is seeded so that
# "Restart & Run All" reproduces the same train/validation partition.
splits = train_ds.train_test_split(test_size=0.1, seed=42)
train_ds = splits['train']
val_ds = splits['test']


Downloading:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/799 [00:00<?, ?B/s]

Downloading and preparing dataset cifar10/plain_text (download: 162.60 MiB, generated: 130.30 MiB, post-processed: Unknown size, total: 292.90 MiB) to /root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4...


Downloading:   0%|          | 0.00/170M [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset cifar10 downloaded and prepared to /root/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Display a summary of the training split (its features and number of rows).
train_ds

Dataset({
    features: ['img', 'label'],
    num_rows: 4500
})

In [None]:
# Display a summary of the validation split (its features and number of rows).
val_ds

Dataset({
    features: ['img', 'label'],
    num_rows: 500
})

In [None]:
# Integer class id -> class name, numbered in the order the dataset lists its label names.
id2label = dict(enumerate(train_ds.features['label'].names))

In [None]:
# Reverse lookup (class name -> integer id), built by inverting id2label.
label2id = {name: idx for idx, name in id2label.items()}
id2label

{0: 'airplane',
 1: 'automobile',
 2: 'bird',
 3: 'cat',
 4: 'deer',
 5: 'dog',
 6: 'frog',
 7: 'horse',
 8: 'ship',
 9: 'truck'}

# Extract features from the images and spice them up with some classic augmentations

In [None]:
from transformers import ViTFeatureExtractor
# Preprocessing settings shipped with the checkpoint; the cells below read its
# `size`, `image_mean` and `image_std`.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)

Downloading:   0%|          | 0.00/160 [00:00<?, ?B/s]

In [None]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

In [None]:
# Normalise with the same per-channel statistics the checkpoint's preprocessor declares.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# `feature_extractor.size` is a plain int in older transformers releases and a dict
# ({"shortest_edge": ...} or {"height": ..., "width": ...}) in newer ones, while the
# torchvision transforms expect an int or an (h, w) pair. Normalise it here.
train_crop_size = feature_extractor.size
if isinstance(train_crop_size, dict):
    if "shortest_edge" in train_crop_size:
        train_crop_size = train_crop_size["shortest_edge"]
    else:
        train_crop_size = (train_crop_size["height"], train_crop_size["width"])

# Training-time pipeline: random crop + random flip for augmentation, then
# tensor conversion and normalisation.
_train_transforms = Compose(
        [
            RandomResizedCrop(train_crop_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )


In [None]:
# `feature_extractor.size` is a plain int in older transformers releases and a dict
# ({"shortest_edge": ...} or {"height": ..., "width": ...}) in newer ones, while the
# torchvision transforms expect an int or an (h, w) pair. Normalise it here.
eval_size = feature_extractor.size
if isinstance(eval_size, dict):
    if "shortest_edge" in eval_size:
        eval_size = eval_size["shortest_edge"]
    else:
        eval_size = (eval_size["height"], eval_size["width"])

# Evaluation-time pipeline: no randomness, just resize + centre crop, then
# tensor conversion and normalisation.
_val_transforms = Compose(
        [
            Resize(eval_size),
            CenterCrop(eval_size),
            ToTensor(),
            normalize,
        ]
    )

In [None]:
def train_transforms(examples):
    """Apply the training pipeline to a batch of examples.

    Each image under ``examples['img']`` is converted to RGB and passed through
    ``_train_transforms``; the results are stored under ``'pixel_values'``.

    Returns:
        The same ``examples`` mapping, with ``'pixel_values'`` added.
    """
    transformed = []
    for picture in examples['img']:
        transformed.append(_train_transforms(picture.convert("RGB")))
    examples['pixel_values'] = transformed
    return examples

def val_transforms(examples):
    """Apply the evaluation pipeline to a batch of examples.

    Each image under ``examples['img']`` is converted to RGB and passed through
    ``_val_transforms``; the results are stored under ``'pixel_values'``.

    Returns:
        The same ``examples`` mapping, with ``'pixel_values'`` added.
    """
    transformed = []
    for picture in examples['img']:
        transformed.append(_val_transforms(picture.convert("RGB")))
    examples['pixel_values'] = transformed
    return examples

In [None]:
# Register the transforms on each split: the augmenting pipeline for training,
# the evaluation pipeline for both validation and test.
for _dataset, _transform in (
    (train_ds, train_transforms),
    (val_ds, val_transforms),
    (test_ds, val_transforms),
):
    _dataset.set_transform(_transform)

In [None]:
from torch.utils.data import DataLoader
import torch

def collate_fn(examples):
    """Merge a list of per-example dicts into a single batch dict.

    Args:
        examples: sequence of mappings, each with a ``"pixel_values"`` tensor
            and a ``"label"`` value.

    Returns:
        A dict with ``"pixel_values"`` (the example tensors stacked along a new
        leading dimension) and ``"labels"`` (a tensor built from the labels).
    """
    batch = {}
    batch["pixel_values"] = torch.stack([item["pixel_values"] for item in examples])
    batch["labels"] = torch.tensor([item["label"] for item in examples])
    return batch

# Small loader over the training split, batching four examples through collate_fn.
train_dataloader = DataLoader(
    dataset=train_ds,
    batch_size=4,
    collate_fn=collate_fn,
)


In [None]:
# Pull one batch and print the shape of every tensor in it as a sanity check.
batch = next(iter(train_dataloader))
for name, value in batch.items():
  if not isinstance(value, torch.Tensor):
    continue
  print(name, value.shape)
    

pixel_values torch.Size([4, 3, 224, 224])
labels torch.Size([4])


# Downloading your model


In [None]:
from transformers import ViTForImageClassification
# Load the pretrained ViT checkpoint for image classification.
# The number of labels is derived from the label mapping instead of being
# hard-coded, so it cannot drift out of sync with the dataset's classes.
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                  num_labels=len(id2label),
                                                  id2label=id2label,
                                                  label2id=label2id)

Downloading:   0%|          | 0.00/502 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/330M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Putting it to train!

In [None]:
from transformers import TrainingArguments, Trainer

# Name of the metric used to pick the best checkpoint.
metric_name = "accuracy"

args = TrainingArguments(
    output_dir="test-cifar-10",
    # Optimisation settings.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=4,
    # Evaluate and checkpoint once per epoch, then reload the best checkpoint.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    # The on-the-fly transforms read the 'img' column, so dataset columns
    # must not be dropped before they run.
    remove_unused_columns=False,
)

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Compute top-1 accuracy for the evaluation loop.

    Args:
        eval_pred: a pair ``(predictions, labels)``. ``predictions`` holds
            per-class scores with the classes on axis 1; ``labels`` holds the
            reference class ids.

    Returns:
        ``{"accuracy": value}`` where ``value`` is the fraction of examples
        whose highest-scoring class equals the label, as a Python float.
    """
    predictions, labels = eval_pred
    predicted_ids = np.argmax(predictions, axis=1)
    # Computed directly with NumPy: `datasets.load_metric` is deprecated and
    # no longer available in newer `datasets` releases.
    return {"accuracy": float(np.mean(predicted_ids == np.asarray(labels)))}
    

Downloading:   0%|          | 0.00/1.41k [00:00<?, ?B/s]

In [None]:
import torch

# Wire the model, training arguments, data splits, batching function and metric together.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    # The feature extractor is handed over through the `tokenizer` argument.
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; the returned training summary is displayed as the cell output.
trainer.train()

***** Running training *****
  Num examples = 4500
  Num Epochs = 3
  Instantaneous batch size per device = 10
  Total train batch size (w. parallel, distributed & accumulation) = 10
  Gradient Accumulation steps = 1
  Total optimization steps = 1350


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.793335,0.97
2,1.468500,0.465967,0.98
3,0.729300,0.396102,0.98


***** Running Evaluation *****
  Num examples = 500
  Batch size = 4
Saving model checkpoint to test-cifar-10/checkpoint-450
Configuration saved in test-cifar-10/checkpoint-450/config.json
Model weights saved in test-cifar-10/checkpoint-450/pytorch_model.bin
Configuration saved in test-cifar-10/checkpoint-450/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 500
  Batch size = 4
Saving model checkpoint to test-cifar-10/checkpoint-900
Configuration saved in test-cifar-10/checkpoint-900/config.json
Model weights saved in test-cifar-10/checkpoint-900/pytorch_model.bin
Configuration saved in test-cifar-10/checkpoint-900/preprocessor_config.json
***** Running Evaluation *****
  Num examples = 500
  Batch size = 4
Saving model checkpoint to test-cifar-10/checkpoint-1350
Configuration saved in test-cifar-10/checkpoint-1350/config.json
Model weights saved in test-cifar-10/checkpoint-1350/pytorch_model.bin
Configuration saved in test-cifar-10/checkpoint-1350/preprocessor_

TrainOutput(global_step=1350, training_loss=0.9612960702401621, metrics={'train_runtime': 225.9452, 'train_samples_per_second': 59.749, 'train_steps_per_second': 5.975, 'total_flos': 1.046216869705728e+18, 'train_loss': 0.9612960702401621, 'epoch': 3.0})

# Test

In [None]:
# Run the fine-tuned model over the held-out test split.
outputs = trainer.predict(test_dataset=test_ds)

***** Running Prediction *****
  Num examples = 2000
  Batch size = 4


In [None]:
# Print the metrics gathered during prediction on the test split.
print(outputs.metrics)

{'test_loss': 0.470125675201416, 'test_accuracy': 0.973, 'test_runtime': 11.7443, 'test_samples_per_second': 170.295, 'test_steps_per_second': 42.574}
