TO DO:
- [ ] Add example of dataset element
- [X] Write documentation
- [ ] Retrain the model
- [ ] Add an evaluation function
- [ ] Experiment with some values for the learning rate etc.

# Image Captioning using ViT and GPT-2
The second part of the pipeline takes the extracted images as input and transforms them into human-readable text. This problem is commonly known as image captioning, and many pre-trained models for it already exist. One of the most common approaches uses an image encoder to transform the image into an embedding, which is then used as input for a language model. The language model decodes this embedding back into natural language. A commonly used encoder/decoder combination is the ViT/GPT-2 pair.

### Sources
The sources used in this notebook include:

- https://huggingface.co/nlpconnect/vit-gpt2-image-captioning
- https://huggingface.co/docs/transformers/model_doc/vit
- https://huggingface.co/docs/transformers/model_doc/gpt2
- https://github.com/NjtechCVLab/RSTPReid-Dataset
- https://vision.cornell.edu/se3/wp-content/uploads/2018/03/1501.pdf
- https://en.wikipedia.org/wiki/BLEU
- https://en.wikipedia.org/wiki/METEOR
- https://huggingface.co/spaces/evaluate-metric/rouge
- The courses in the postgraduate programme "AI and ML in business and engineering" at KU Leuven.

In [None]:
%env TF_ENABLE_ONEDNN_OPTS=0

# Imports
from pathlib import Path
import torch
from datetime import datetime
from torch.utils import tensorboard
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel

from libs.data import RSTPReid
from libs.engine import train_one_epoch

In [1]:
# Define constants
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
DATASET_PATH = Path().resolve().parent / 'RSTPReid'

## Load the model
ViT (Vision Transformer) is a Transformer encoder that was pre-trained on ImageNet. According to the original paper, "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" by Dosovitskiy, Beyer, Kolesnikov, Weissenborn, Zhai, Unterthiner, Dehghani, Minderer, Heigold, Gelly, Uszkoreit, and Houlsby, it attains "excellent results compared to state-of-the-art convolutional networks while requiring substantially fewer computational resources to train".
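
To make "Transformer encoder" concrete for images: ViT cuts an image into fixed-size square patches, flattens each patch, projects it to an embedding and prepends one extra [CLS] token. The sketch below only computes the resulting sequence length, assuming a 224x224 RGB input and 16x16 patches (the ViT-Base/16 configuration); the actual values for this checkpoint should be checked in its `config.json`.

```python
def vit_token_stats(image_size: int = 224, patch_size: int = 16, channels: int = 3) -> dict:
    """Sequence length seen by a ViT encoder (defaults assume ViT-Base/16 at 224x224)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    return {
        "num_patches": num_patches,                                  # 14 * 14 = 196 for the defaults
        "sequence_length": num_patches + 1,                          # +1 for the prepended [CLS] token
        "flattened_patch_dim": patch_size * patch_size * channels,   # raw pixel values per patch before projection
    }

print(vit_token_stats())
# {'num_patches': 196, 'sequence_length': 197, 'flattened_patch_dim': 768}
```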

GPT-2 is a large transformer-based language model; its largest variant has 1.5 billion parameters, and it was trained on a dataset of 8 million web pages. It is trained with a simple objective: predict the next word, given all of the previous words within a text. The decoder in this checkpoint is the smaller, 12-layer base version of GPT-2.
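
The next-word objective can be written as the average negative log-likelihood of each token given all preceding tokens. The sketch below computes it for a hypothetical toy model whose probabilities are hard-coded; a real GPT-2 produces these probabilities with its transformer layers and works on sub-word tokens rather than whole words. The function names and the toy table are illustrative only.

```python
import math
from typing import Callable, Sequence

def next_token_nll(tokens: Sequence[str],
                   next_token_prob: Callable[[Sequence[str], str], float]) -> float:
    """Average -log p(token_t | tokens_<t) over every position after the first."""
    losses = [-math.log(next_token_prob(tokens[:t], tokens[t]))
              for t in range(1, len(tokens))]
    return sum(losses) / len(losses)

# Hypothetical toy "language model": probabilities only depend on the previous token.
TOY_PROBS = {
    ("<s>", "a"): 0.5,
    ("a", "man"): 0.25,
    ("man", "walking"): 0.5,
}

def toy_model(prefix: Sequence[str], token: str) -> float:
    return TOY_PROBS.get((prefix[-1], token), 1e-6)  # tiny probability for unseen pairs

loss = next_token_nll(["<s>", "a", "man", "walking"], toy_model)
print(f"{loss:.4f}")  # 0.9242
```

Lower is better: a model that assigned probability 1 to every observed next token would reach a loss of 0.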

In [2]:
%%capture
# Load the models
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
image_processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
# Place the model on the correct device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

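
Before any finetuning, the loaded model can already caption images. The helper below is a sketch based on the usage shown on the nlpconnect model card: it preprocesses a batch of PIL images, generates token ids with beam search and decodes them back to text. The name `caption_images` and the `max_length=16`, `num_beams=4` defaults are choices for this example, not tuned values.

```python
from typing import List, Sequence

def caption_images(images: Sequence, model, image_processor, tokenizer, device,
                   max_length: int = 16, num_beams: int = 4) -> List[str]:
    """Generate one caption per image with a VisionEncoderDecoderModel (hypothetical helper)."""
    model.eval()  # disable dropout; training code may have left the model in train mode
    # Resize and normalize the images into a (batch, 3, H, W) tensor on the model's device
    pixel_values = image_processor(images=images, return_tensors="pt").pixel_values.to(device)
    # generate() encodes the images and decodes the caption autoregressively with beam search
    output_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)
    # Drop special tokens such as <|endoftext|> and strip surrounding whitespace
    return [text.strip() for text in tokenizer.batch_decode(output_ids, skip_special_tokens=True)]
```

For example, `caption_images([Image.open("example.jpg").convert("RGB")], model, image_processor, tokenizer, device)` (with `from PIL import Image`, and `example.jpg` standing in for any image) returns a list with one caption.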

## Load the data
Since the Televic-provided images are reserved as the test set, a publicly available dataset is required to validate the performance of the model and to finetune it if necessary. The dataset used in this notebook is the RSTPReid (Real Scenario Text-based Person Re-identification) dataset.  
This dataset consists of 20,505 images of 4,101 persons captured by 15 cameras. Each person has 5 images taken by different cameras, with complex scene transformations and backgrounds, and each image is annotated with 2 textual descriptions. The dataset comes pre-split into training, validation, and test sets in a roughly 90/5/5 ratio (18,505, 1,000, and 1,000 images respectively).
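
The `RSTPReid` class in `libs.data` is not shown here. As a hedged sketch of what such a loader has to do, the function below groups the annotations per split. It assumes the annotation file is `data_captions.json`, a JSON list of records with `id`, `img_path`, `captions` and `split` fields, and that images live in an `imgs/` folder; these names should be checked against your copy of the dataset.

```python
import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

def load_rstpreid_annotations(dataset_root: Path) -> Dict[str, List[dict]]:
    """Group RSTPReid annotation records by split (file and field names are assumptions)."""
    dataset_root = Path(dataset_root)
    with open(dataset_root / "data_captions.json", encoding="utf-8") as f:
        records = json.load(f)  # assumed: a list of dicts, one per image
    splits: Dict[str, List[dict]] = defaultdict(list)
    for record in records:
        splits[record["split"]].append({
            "image_path": dataset_root / "imgs" / record["img_path"],  # assumed image folder
            "person_id": record["id"],
            "captions": record["captions"],  # 2 textual descriptions per image
        })
    return dict(splits)
```

Printing `load_rstpreid_annotations(DATASET_PATH)["train"][0]` would show one dataset element, and the per-split lengths should match the element counts printed by the cell below.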

In [3]:
# Load the datasets
train_set = RSTPReid(DATASET_PATH, 'train', image_processor)
val_set = RSTPReid(DATASET_PATH, 'val', image_processor)
test_set = RSTPReid(DATASET_PATH, 'test', image_processor)

print(f"Number of elements in the training set: {len(train_set)}")
print(f"Number of elements in the validation set: {len(val_set)}")
print(f"Number of elements in the test set: {len(test_set)}")

Number of elements in the training set: 18505
Number of elements in the validation set: 1000
Number of elements in the test set: 1000


## Defining an evaluation metric
To evaluate how well our model performs on the dataset defined above, an evaluation metric is required. Commonly used metrics include CIDEr, METEOR, ROUGE, and BLEU. These metrics are known not to always correlate well with human judgement, and each has well-known blind spots. However, they are easy to use and are often implemented in popular machine learning libraries. Since our use case is relatively simple, they will suffice.  

- CIDEr (Consensus-based Image Description Evaluation) is a simple metric that was designed specifically for evaluating image descriptions. However, verified implementations are not as readily available in common libraries as for the other metrics.
- METEOR (Metric for Evaluation of Translation with Explicit ORdering) is a metric for evaluating machine translation output. Since machine translation is not our use case, this metric might be less suitable.
- ROUGE (Recall-Oriented Understudy for Gisting Evaluation) is a set of metrics used for evaluating automatic summarization and machine translation. Since our use case can (generously) be described as summarizing an image in a short sentence, this metric might be suitable.
- BLEU (Bilingual Evaluation Understudy) is an algorithm for evaluating the quality of text that has been machine-translated from one natural language to another. Since the use case here is not machine translation, this metric might not be well suited either.
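
To illustrate what ROUGE measures, the sketch below computes a simplified, sentence-level ROUGE-L F1 score from scratch: the longest common subsequence (LCS) of words between a generated and a reference caption, turned into precision and recall. It uses plain whitespace tokenization and no stemming, so its scores can differ slightly from library implementations such as the `rouge` metric in Hugging Face `evaluate`. Because the LCS only requires words to appear in the same order, not contiguously, extra words in the reference lower recall without breaking the match.

```python
from typing import List

def lcs_length(a: List[str], b: List[str]) -> int:
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    prev = [0] * (len(b) + 1)  # LCS lengths for the previous prefix of a
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(candidate: str, reference: str) -> float:
    """Simplified sentence-level ROUGE-L F1 (whitespace tokens, lowercase, no stemming)."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    lcs = lcs_length(cand, ref)
    if lcs == 0:
        return 0.0  # also covers empty inputs
    precision, recall = lcs / len(cand), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

# LCS = "a man a black jacket" (5 words): precision 5/6, recall 5/8
print(round(rouge_l_f1("a man in a black jacket", "a man wearing a black jacket and jeans"), 3))  # 0.714
```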

A more naive approach would be to simply use the Word Error Rate (WER). However, this is a poor choice: WER only counts word-level edits, so it does not account for synonyms.
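
The synonym problem is easy to demonstrate. The sketch below computes WER as the word-level Levenshtein distance (substitutions, insertions and deletions) divided by the number of reference words; two captions with the same meaning still receive a substantial error.

```python
def word_error_rate(hypothesis: str, reference: str) -> float:
    """Word-level Levenshtein distance divided by the number of reference words."""
    hyp, ref = hypothesis.lower().split(), reference.lower().split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # d[i][j] = minimum number of edits between the first i reference words and the first j hypothesis words
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i  # empty hypothesis: i deletions
    for j in range(len(hyp) + 1):
        d[0][j] = j  # empty reference prefix: j insertions
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            substitution = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,                 # deletion
                d[i][j - 1] + 1,                 # insertion
                d[i - 1][j - 1] + substitution,  # substitution or match
            )
    return d[len(ref)][len(hyp)] / len(ref)

# Same meaning, different words: "lady"/"woman" and "jacket"/"coat" count as 2 substitutions out of 6 words
print(word_error_rate("a lady wearing a red jacket", "a woman wearing a red coat"))
```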

## Retraining the model

In [4]:
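# Freeze all parameters that are part of the decoder (GPT-2, including its cross-attention layers),
# so that only the ViT encoder is finetuned
for param in model.decoder.parameters():
    param.requires_grad = False
params = [p for p in model.parameters() if p.requires_grad]
# Note: these counts are parameter tensors (weight matrices, bias vectors, ...), not individual weights
print(f"We're finetuning {len(params)} of the {len(list(model.parameters()))} parameters.")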

We're finetuning 200 of the 444 parameters.
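
The counts printed above are numbers of parameter tensors, not individual weights. A small helper such as the hypothetical `count_parameters` below sums `numel()` over the tensors instead, which gives a better picture of how much of the model is actually being finetuned.

```python
from typing import Iterable, Tuple

def count_parameters(parameters: Iterable) -> Tuple[int, int]:
    """Return (trainable, total) numbers of scalar weights for an iterable of torch Parameters."""
    trainable = total = 0
    for p in parameters:  # single pass, so a generator such as model.parameters() works
        n = p.numel()
        total += n
        if p.requires_grad:
            trainable += n
    return trainable, total
```

Usage: `trainable, total = count_parameters(model.parameters())`, then for instance `print(f"{trainable:,} of {total:,} weights ({trainable / total:.1%}) are trainable")`.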


In [5]:
# Convert the dataset to a dataloader (batch_size defaults to 1) and reshuffle the training data every epoch
dataloader = torch.utils.data.DataLoader(train_set, shuffle=True)

In [6]:
# Define an optimizer and a learning rate scheduler
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
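
`StepLR` multiplies the learning rate by `gamma` every `step_size` scheduler steps, so after `k` steps the learning rate is `lr * gamma ** (k // step_size)`. The sketch below reproduces that closed form in plain Python for the values used here, assuming `lr_scheduler.step()` is called once per epoch. With 5 epochs only one decay happens: 0.005 for epochs 0 to 2 and 0.0005 for epochs 3 and 4.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Learning rate after `epoch` scheduler steps, using StepLR's closed-form rule."""
    return base_lr * gamma ** (epoch // step_size)

# Values used in this notebook: lr=0.005, step_size=3, gamma=0.1, 5 epochs
for epoch in range(5):
    print(f"epoch {epoch}: lr = {step_lr(0.005, 0.1, 3, epoch):.4g}")
```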

In [7]:
writername = datetime.now().strftime("%Y%m%d_%H%M%S")
writer = tensorboard.SummaryWriter(writername)

In [8]:
# Start finetuning
num_epochs = 5
for epoch in range(num_epochs):
    print(f" === EPOCH {epoch} === ")
    train_one_epoch(
        model=model, 
        optimizer=optimizer, 
        tokenizer=tokenizer, 
        data_loader=dataloader, 
        epoch=epoch, 
        device=device,
        writer=writer
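    )
    # Decay the learning rate once per epoch (StepLR: multiply by 0.1 every 3 epochs)
    lr_scheduler.step()
    model.save_pretrained(f"retrained_temp_epoch_{epoch}")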

 === EPOCH 0 === 
 === EPOCH 1 === 
 === EPOCH 2 === 
 === EPOCH 3 === 
 === EPOCH 4 === 

In [9]:
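# Save the finetuned model together with its preprocessing components, so it can be reloaded with from_pretrained
model.save_pretrained("baseline_model")
image_processor.save_pretrained("baseline_model")
tokenizer.save_pretrained("baseline_model")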