# 6. Vision Transformer

## Introduction

In this notebook, we are going to fine-tune a pre-trained Vision Transformer (from [🤗 Transformers](https://github.com/huggingface/transformers)) for art classification. We will train the model using [PyTorch Lightning ⚡](https://github.com/PyTorchLightning/pytorch-lightning). 

HuggingFace 🤗 is a leading open-source software library and community that has gained significant attention in recent years for its contributions to democratizing AI. The library provides pre-trained models, datasets, and a suite of tools that make it easier for developers to build and deploy AI applications. One of the most significant contributions of HuggingFace is the development of the Transformers library, which provides an easy-to-use interface for working with Transformer-based models such as BERT and GPT.

PyTorch Lightning is an open-source Python library that provides a high-level interface for PyTorch. This lightweight and high-performance framework organizes PyTorch code to decouple the research from the engineering, making Deep Learning experiments easier to read and reproduce.

**Source:** Rogge, N. (2021) [Fine-tuning the Vision Transformer on CIFAR-10 with PyTorch Lightning - GitHub](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Fine_tuning_the_Vision_Transformer_on_CIFAR_10_with_PyTorch_Lightning.ipynb).

![vit.png](./docs/Vision_Transformer.png)

## What are Transformers?

The Transformer architecture, which was introduced in the seminal paper "Attention is All You Need" in 2017, has taken the world of Deep Learning by storm, particularly in the field of Natural Language Processing (NLP). As a large language model first based on the GPT-3.5 architecture, ChatGPT is a prime example of an application based on the Transformer architecture that is now in the public eye. In addition to ChatGPT, many other popular applications, such as Google's BERT, OpenAI's GPT series, and Facebook's RoBERTa, also rely on the Transformer architecture to achieve state-of-the-art results in NLP tasks. Furthermore, the Transformer architecture has also made significant inroads in the field of computer vision, as evidenced by the success of models such as ViT and DeiT on ImageNet and other visual recognition benchmarks.

The major innovation of the Transformer architecture is that it relies entirely on attention, dispensing with the recurrence and convolutions used by earlier sequence models. Unlike traditional convolutional neural networks (CNNs), which rely on convolutional layers to extract features from images, Transformers use attention mechanisms (self-attention, multi-head attention) to selectively focus on different parts of an input sequence.
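To make the idea of self-attention more concrete, here is a minimal sketch of scaled dot-product attention in plain Python. It is an illustration only: a real Transformer first maps the tokens to queries, keys and values with learned projection matrices, which are omitted here, and the function names and toy token values are our own.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(queries, keys, values):
    """Scaled dot-product attention for small lists of vectors.

    Assumption: learned projections are skipped, so the caller passes
    the same token vectors as queries, keys and values.
    """
    d_k = len(keys[0])
    outputs, weights = [], []
    for q in queries:
        # Similarity of this query with every key, scaled by sqrt(d_k)
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys
        ]
        w = softmax(scores)
        weights.append(w)
        # Output is the attention-weighted average of the value vectors
        outputs.append(
            [sum(wj * v[c] for wj, v in zip(w, values)) for c in range(len(values[0]))]
        )
    return outputs, weights


# Three toy 2-dimensional "tokens" (hypothetical values)
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = self_attention(tokens, tokens, tokens)

for row in weights:
    print([round(w, 3) for w in row])
```

Each row of `weights` sums to 1 and says how much one token draws from every other token, regardless of how far apart they are in the sequence.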

The main advantage of Transformers over traditional CNNs is that they can capture long-range dependencies in data more effectively. This is especially useful in Computer Vision tasks where an image may contain objects that are spread out across the image, and where the relationships between objects may be more important than the objects themselves. By attending to different parts of the input image, Transformers can effectively learn to extract these relationships and improve performance on tasks such as object detection and segmentation.
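For an image, the "input sequence" is a list of patches. The short sketch below works out the sequence length for a 224x224 input split into 16x16 patches, the configuration suggested by the name of the `google/vit-base-patch16-224-in21k` checkpoint used later in this notebook. The helper names are hypothetical, and the extra classification token follows the ViT paper.

```python
def vit_sequence_length(image_size=224, patch_size=16, add_cls_token=True):
    """Number of tokens a ViT sees for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    # ViT prepends one learnable [CLS] token used for classification
    return num_patches + (1 if add_cls_token else 0)


def flattened_patch_size(patch_size=16, channels=3):
    """Length of one flattened RGB patch before the linear projection."""
    return patch_size * patch_size * channels


print(vit_sequence_length())                      # 14 * 14 patches + 1 = 197
print(vit_sequence_length(add_cls_token=False))   # 196
print(flattened_patch_size())                     # 16 * 16 * 3 = 768
```

Every one of these tokens can attend to every other token in each layer, which is where the long-range modelling ability comes from.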


**Sources:**

+ Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). [Attention is all you need.](https://arxiv.org/abs/1706.03762) - arXiv preprint arXiv:1706.03762. 
+ Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2020). [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) - arXiv preprint arXiv:2010.11929.
+ Google Research. (2021). [Vision Transformer and MLP-Mixer Architectures  - GitHub](https://github.com/google-research/vision_transformer)

In [None]:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install -q transformers datasets pytorch-lightning

In [1]:
import os
from src.utils import *
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from src.vit_fine_tune import ViTLightningModule



## Load data

Our data was downloaded from the Kaggle URL mentioned earlier and saved in a directory called `data`, which must be located in the same folder as the notebooks. Within `data` there are three separate directories, `train`, `validation` and `test`, each containing an equal number of randomly selected images from each of the selected art classes.

In [2]:
data_dir = "./data"

# Train folder
train_dir = os.path.join(data_dir, "train")
# Validation folder
validation_dir = os.path.join(data_dir, "validation")
# Test folder
test_dir = os.path.join(data_dir, "test")

dataset_stats(train_dir, validation_dir, test_dir)

Number of classes: 4
Existing classes: ['Baroque', 'Realism', 'Renaissance', 'Romanticism']

----------------------------------------
Number of images per class and dataset:
----------------------------------------
             Train  Validation  Test
Style                               
Baroque       4000         500   500
Realism       4000         500   500
Renaissance   4000         500   500
Romanticism   4000         500   500


We can see that all classes are well balanced, and that we have a fair amount of data for training and validation.
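The `dataset_stats` helper lives in `src.utils` and is not shown here. As a rough sketch of how such per-class counts could be obtained, assuming each split folder contains one subfolder per class, something like the following would work. The function name `count_images_per_class` and the extension list are our own assumptions, and the demo runs on a throwaway directory rather than the real data.

```python
import os
import tempfile

# Assumed set of image file extensions
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(split_dir):
    """Map each class subfolder of split_dir to its number of image files."""
    counts = {}
    for class_name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, class_name)
        if not os.path.isdir(class_dir):
            continue  # ignore stray files next to the class folders
        counts[class_name] = sum(
            1
            for file_name in os.listdir(class_dir)
            if os.path.splitext(file_name)[1].lower() in IMAGE_EXTENSIONS
        )
    return counts


# Demo on a temporary directory with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    for class_name, n_images in [("Baroque", 3), ("Realism", 2)]:
        os.makedirs(os.path.join(tmp, class_name))
        for i in range(n_images):
            open(os.path.join(tmp, class_name, f"img_{i}.jpg"), "w").close()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

Calling the same function on `train_dir`, `validation_dir` and `test_dir` would give the three columns of the table above.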

We now create the directory where the models will be saved:

In [4]:
# Create directory where to save the models created
models_dir = "./models"
os.makedirs(models_dir, exist_ok=True)

## Activating CUDA for GPU processing

GPUs (Graphics Processing Units) are specialized processors designed to handle the complex computations involved in rendering graphics and images. However, due to their parallel processing capabilities, they are also useful for a wide range of other applications, including Machine Learning and Scientific Computing. Unlike traditional CPUs (Central Processing Units), which are designed to handle a few tasks very quickly, GPUs can handle many smaller tasks simultaneously, making them ideal for computationally-intensive applications.

CUDA (Compute Unified Device Architecture) is a parallel computing platform and programming model developed by NVIDIA, designed to harness the power of GPUs for general-purpose computing tasks. CUDA allows developers to write programs that run on the GPU, taking advantage of its parallel processing capabilities to accelerate performance significantly.

In order to significantly speed up the training of the model, we will use GPU acceleration. We will first check if CUDA is available in our system.

In [3]:
import torch
print(f"Is CUDA supported by this system? {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")
  
if torch.cuda.is_available():
    # Store the ID of the current CUDA device and reuse it below
    cuda_id = torch.cuda.current_device()
    print(f"ID of current CUDA device: {cuda_id}")
    print(f"Name of current CUDA device: {torch.cuda.get_device_name(cuda_id)}")

Is CUDA supported by this system? True
CUDA version: 11.8
ID of current CUDA device: 0
Name of current CUDA device: NVIDIA GeForce GTX 1060 with Max-Q Design


CUDA is supported by our system, so we will train the models using GPU.

## Training

TensorBoard is a web-based visualization tool provided by TensorFlow for visualizing and analyzing various aspects of machine learning experiments.

The `%load_ext tensorboard` command loads the TensorBoard extension in Jupyter Notebook. The `%tensorboard --logdir lightning_logs/` command starts TensorBoard and specifies the directory where the logs are stored, in this case `./lightning_logs/`. TensorBoard reads the events and metrics logged during training and provides visualizations to analyze the model's performance, such as loss and accuracy curves.

In [10]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

We use early stopping to stop training when the validation loss stops improving.
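The patience-based rule behind early stopping can be sketched in plain Python. This is a simplified illustration, not Lightning's actual `EarlyStopping` implementation; the function name `early_stopping_epoch` and the loss values are made up for the example.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the epoch index at which training would stop, or None.

    Training stops once the validation loss has failed to improve on the
    best value seen so far for `patience` consecutive epochs.
    """
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation losses: the best value is reached at epoch 2
losses = [0.90, 0.70, 0.60, 0.62, 0.61, 0.65, 0.50]
stop_epoch = early_stopping_epoch(losses, patience=3)
print(stop_epoch)  # 5: epochs 3, 4 and 5 did not beat 0.60
```

Note that the last value (0.50) is never reached: with `patience=3` the run ends before it, which is the trade-off between saving compute and possibly missing a late improvement.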

In [12]:
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

model = ViTLightningModule()

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    patience=3,
    strict=False,
    verbose=False,
    mode='min'
)

trainer = Trainer(
    accelerator='gpu',
    devices=1,
    callbacks=[
        early_stop_callback
    ]
)

trainer.fit(model)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used

Epoch 7: 100%|██████████| 1000/1000 [18:37<00:00,  1.12s/it, v_num=7]      


## Validation

We first evaluate the trained model on the test set:

In [13]:
trainer.test()

Restoring states from the checkpoint path at d:\Estudios\Masters\MBD_ICAI\Cuatri_2\NoEstruc\IMAGES\Practica\lightning_logs\version_7\checkpoints\epoch=7-step=8000.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at d:\Estudios\Masters\MBD_ICAI\Cuatri_2\NoEstruc\IMAGES\Practica\lightning_logs\version_7\checkpoints\epoch=7-step=8000.ckpt


Testing DataLoader 0: 100%|██████████| 500/500 [00:56<00:00,  8.83it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss           0.5024473667144775
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.5024473667144775}]

Let's look at the classification reports for the training, test and validation sets.

In [2]:
# Load best model from the latest checkpoint
best_model = load_latest_checkpoint(ViTLightningModule)
# Get best model metrics
get_vit_metrics(best_model, train=True)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using device: cuda


100%|██████████| 125/125 [17:48<00:00,  8.55s/it]


Training set classification report:
              precision    recall  f1-score   support

     Baroque       0.95      0.91      0.93      4000
     Realism       0.92      0.94      0.93      4000
 Renaissance       0.96      0.91      0.93      4000
 Romanticism       0.88      0.94      0.91      4000

    accuracy                           0.93     16000
   macro avg       0.93      0.93      0.93     16000
weighted avg       0.93      0.93      0.93     16000

Using device: cuda


100%|██████████| 63/63 [02:26<00:00,  2.32s/it]


Test set classification report:
              precision    recall  f1-score   support

     Baroque       0.87      0.83      0.85       500
     Realism       0.82      0.82      0.82       500
 Renaissance       0.89      0.85      0.87       500
 Romanticism       0.74      0.81      0.78       500

    accuracy                           0.83      2000
   macro avg       0.83      0.83      0.83      2000
weighted avg       0.83      0.83      0.83      2000

Using device: cuda


100%|██████████| 63/63 [02:24<00:00,  2.29s/it]

Validation set classification report:
              precision    recall  f1-score   support

     Baroque       0.89      0.82      0.85       500
     Realism       0.83      0.85      0.84       500
 Renaissance       0.90      0.83      0.86       500
 Romanticism       0.76      0.86      0.81       500

    accuracy                           0.84      2000
   macro avg       0.84      0.84      0.84      2000
weighted avg       0.84      0.84      0.84      2000






Regarding the **training set**, the model exhibits excellent precision, recall, and F1-score for all four classes, with values ranging from 0.88 to 0.96, showing that it classifies the training artworks accurately. Both the macro-average and weighted-average F1-scores are 0.93, matching the overall accuracy of 0.93.

Moving on to the **test set**, precision, recall, and F1-score all decrease compared to the training set. The model still performs reasonably well, with per-class F1-scores between 0.78 and 0.87, which indicates that it generalizes to unseen examples. Both the macro-average and weighted-average F1-scores are 0.83, about 0.10 below the training set. The overall test accuracy is also 0.83, considerably better than the previous models, whose accuracy was around 60%.

Finally, the **validation set** shows performance comparable to the test set, with per-class F1-scores between 0.81 and 0.86. Both the macro-average and weighted-average F1-scores are 0.84, slightly higher than on the test set.

In conclusion, the fine-tuned Vision Transformer performs well on the art classification task. It reaches high precision and recall on the training set and retains reasonably high F1-scores on the test and validation sets, which shows that it generalizes to new examples. The gap of roughly 0.10 between training and held-out accuracy does suggest some overfitting.
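For reference, the per-class numbers in the reports follow the standard definitions of precision, recall and F1-score. The sketch below computes them from scratch on a handful of made-up labels. It is not the code behind `get_vit_metrics`, which is defined in `src.utils` and not shown here, and the function name and toy labels are ours.

```python
def per_class_metrics(y_true, y_pred, label):
    """Precision, recall and F1-score for one class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == label and p == label)
    fp = sum(1 for t, p in pairs if t != label and p == label)
    fn = sum(1 for t, p in pairs if t == label and p != label)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if precision + recall
        else 0.0
    )
    return precision, recall, f1


# Hypothetical labels, for illustration only
y_true = ["Baroque", "Baroque", "Realism", "Realism", "Realism", "Baroque"]
y_pred = ["Baroque", "Realism", "Realism", "Realism", "Baroque", "Baroque"]

classes = sorted(set(y_true))
metrics = {c: per_class_metrics(y_true, y_pred, c) for c in classes}

# Macro average: unweighted mean of the per-class F1-scores
macro_f1 = sum(m[2] for m in metrics.values()) / len(classes)

for c in classes:
    precision, recall, f1 = metrics[c]
    print(f"{c}: precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
print(f"macro F1: {macro_f1:.2f}")
```

Because every class in our datasets has the same support, the macro and weighted averages coincide, which is why the two rows agree in each report.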

We save the final model in `vit_model.pt`.

In [5]:
# Save the model in the models directory
torch.save(best_model.state_dict(), os.path.join(models_dir, "vit_model.pt"))

## Conclusions

In summary, the Vision Transformer outperforms the previous models by a significant margin. Its test accuracy is roughly 0.2 higher than that of the other models (0.83 versus around 0.60), a relative improvement of roughly one third. The superior performance of the Vision Transformer can be attributed to its ability to capture global dependencies and interactions between features. While traditional CNN models rely on convolutional and pooling operations to extract local features and flatten them into a vector, Transformers use self-attention mechanisms that enable global interactions among all the features. This allows them to model complex relationships between features and identify long-range dependencies, which makes them effective for tasks such as image classification.

In addition, the patch-based architecture of the Vision Transformer may also contribute to its success in the art style classification task. The image is split into fixed-size patches, and every self-attention layer lets each patch exchange information with all the others, so the model can combine local detail with whole-image context from the first layer onward. This enables the model to learn representations that are well suited to image classification, and could explain its strong performance in this project.