Tutorial 3 (ViT)
======================


## About

For this part of the assignment, you will gain some experience working with Vision Transformers (ViT).
The main activities revolve around fine-tuning a ViT model using the Hugging Face libraries.

* **Fine-tuning ViT model**:

    Fine-tune the ViT model on the CIFAR-10, DTD, and COCO-O datasets.


<hr> 

* The <b><font color="red">red</font></b> color indicates the task that should be done, like <b><font color="red">[TODO]</font></b>: ...
* Additional comments and hints are in <b><font color="blue">blue</font></b>. For example <b><font color="blue">[HINT]</font></b>: ...

## Preliminaries

In [None]:
# !pip install datasets
# !pip install fiftyone
# !pip install scikit-learn
# !pip install tqdm

In [None]:
import os
import gdown
import zipfile

import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

from torchvision import transforms

from datasets import load_dataset
from datasets import Dataset, DatasetDict

from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import TrainingArguments, Trainer

import fiftyone as fo

from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Larger serif font so that figure labels stay readable
plt.rcParams.update({'font.size': 18, 'font.family': 'serif'})

## Auxiliary functions

In [None]:
def create_hf_cocoo_dataset(path_coco_o: str, path_data: str, seed: int = 42, test_ratio: float = 0.3):
    """Build a train/test ``DatasetDict`` from the COCO-O folder tree.

    The folder ``path_coco_o`` is expected to contain one sub-folder per
    class, each holding a ``val2017`` directory with the images. If
    ``path_coco_o`` does not exist, the archive is downloaded into
    ``path_data`` and extracted there first.

    Args:
        path_coco_o: Root folder of the extracted COCO-O data.
        path_data: Folder used for the downloaded zip file and as the
            extraction target.
        seed: Seed for the shuffle that defines the train/test split.
            Note that this seeds NumPy's global random state.
        test_ratio: Fraction of all images assigned to the test split.

    Returns:
        A tuple ``(datasets, classes)`` where ``datasets`` is a
        ``DatasetDict`` with ``'train'`` and ``'test'`` splits (columns
        ``'image'`` and ``'label'``) and ``classes`` is the list of class
        folder names; a label is the index of its class in that list.
    """
    def load_image(example):
        example['image'] = Image.open(example['image_path'])
        return example

    if not os.path.exists(path_coco_o):
        # The download target folder must exist before gdown writes into it.
        os.makedirs(path_data, exist_ok=True)
        url = 'https://drive.google.com/uc?id=1aBfIJN0zo_i80Hv4p7Ch7M8pRzO37qbq'
        zip_file_path = os.path.join(path_data, 'ood_coco.zip')
        gdown.download(url, zip_file_path, quiet=False)
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(path_data)

    # os.listdir returns entries in arbitrary order, so sort to make the
    # label indices (and therefore the seeded split) reproducible across
    # machines. Entries without a 'val2017' folder (stray files, archive
    # metadata) are skipped instead of crashing the image listing below.
    cocoo_classes_list = sorted(
        name for name in os.listdir(path_coco_o)
        if os.path.isdir(os.path.join(path_coco_o, name, 'val2017'))
    )
    all_elements_coco = [
        (os.path.join(path_coco_o, label, 'val2017', img), index)
        for index, label in enumerate(cocoo_classes_list)
        for img in sorted(os.listdir(os.path.join(path_coco_o, label, 'val2017')))
    ]

    np.random.seed(seed)
    indices = np.arange(len(all_elements_coco))
    np.random.shuffle(indices)
    n_test = int(len(indices) * test_ratio)

    # The first n_test shuffled indices form the test split, the rest train.
    train_indices, test_indices = indices[n_test:], indices[:n_test]
    datasets = {}

    for split, split_indices in (('train', train_indices), ('test', test_indices)):
        # Two plain lists avoid the unpacking error that zip(*pairs)
        # raises when a split happens to be empty.
        image_paths = [all_elements_coco[i][0] for i in split_indices]
        labels = [all_elements_coco[i][1] for i in split_indices]
        dataset = Dataset.from_dict({'image_path': image_paths, 'label': labels})
        datasets[split] = dataset.map(load_image, remove_columns=['image_path'])

    return DatasetDict(datasets), cocoo_classes_list

In [None]:
def collate_fn(batch):
    """Combine a list of per-example dicts into a single batch dict.

    Each element of ``batch`` must provide a ``'pixel_values'`` tensor and a
    ``'labels'`` value. The tensors are stacked along a new first dimension
    and the labels are gathered into one tensor.
    """
    pixel_values = torch.stack([sample['pixel_values'] for sample in batch])
    labels = torch.tensor([sample['labels'] for sample in batch])
    return dict(pixel_values=pixel_values, labels=labels)

def compute_metrics(eval_pred):
    """Compute classification accuracy for the Trainer's evaluation loop.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` holds
            per-class scores; the arg-max over axis 1 is taken as the
            predicted class.

    Returns:
        A dict with a single ``'accuracy'`` entry.
    """
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    # scikit-learn's signature is accuracy_score(y_true, y_pred); pass the
    # ground truth first so the call stays correct if another, asymmetric
    # metric is added next to it later.
    return dict(accuracy=accuracy_score(labels, predictions))

def processor_transform(processor, image_key='img'):
    """Create a batch transform that runs ``processor`` on a dataset's images.

    Args:
        processor: Callable invoked as ``processor(images, return_tensors='pt')``;
            its result must support item assignment.
        image_key: Name of the batch column that holds the images. The
            default ``'img'`` matches CIFAR-10; pass ``'image'`` for datasets
            that store images under that name, such as the COCO-O dataset
            built by ``create_hf_cocoo_dataset``.

    Returns:
        A function that maps a batch dict (with ``image_key`` and ``'label'``
        columns) to the processor output extended with a ``'labels'`` entry.
    """
    def curry(example_batch):
        inputs = processor(list(example_batch[image_key]), return_tensors='pt')
        # The collate function and the model expect the key 'labels'.
        inputs['labels'] = example_batch['label']
        return inputs
    return curry

## Load data

In [None]:
# Local folder used as the dataset cache and download target below;
# it is created here if it does not exist yet.
path_data = "./data"
os.makedirs(path_data, exist_ok=True)

In [None]:
# Load the CIFAR-10 dataset, caching it under path_data
cifar10_dataset = load_dataset('cifar10', cache_dir=path_data)
# Class names, taken from the 'label' feature of the train split
cifar10_classes_list = cifar10_dataset['train'].features['label'].names

In [None]:
# Load the DTD dataset, caching it under path_data
dtd_dataset = load_dataset("tanganke/dtd", cache_dir=path_data)
# Class names, taken from the 'label' feature of the train split
dtd_classes_list = dtd_dataset['train'].features['label'].names

In [None]:
# Load the COCO-O dataset: the helper downloads and extracts the archive into
# path_data when the 'ood_coco' folder is missing, then returns the
# train/test DatasetDict together with the list of class names.
path_coco_o = os.path.join(path_data, 'ood_coco')
cocoo_dataset, cocoo_classes_list = create_hf_cocoo_dataset(path_coco_o, path_data)

## Training

### cifar10

In [None]:
# We will use the 'base' version of the ViT family
# (patch size 16 and 224x224 input, judging by the checkpoint name).
model_name = "google/vit-base-patch16-224"

In [None]:
# Image processor that belongs to the chosen checkpoint; it is used below to
# turn raw images into model inputs (see processor_transform).
processor = ViTImageProcessor.from_pretrained(model_name)

In [None]:
# An explicit torchvision pipeline built from the processor's statistics:
# resize to 224x224, convert to a tensor, normalize with the processor's
# mean/std.
# NOTE(review): transform_vit is not referenced anywhere else in the visible
# notebook; the training cells preprocess via processor_transform instead.
transform_vit = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])

In [None]:
# Load the pretrained model as published. No num_labels is passed here, so
# the checkpoint's own classification head is kept; it is replaced in a
# later cell.
model = ViTForImageClassification.from_pretrained(model_name)

In [None]:
# Inspect the model: uncomment the first line for the full architecture,
# otherwise only the classification head is printed.
#print(model)
print(model.classifier)

In [None]:
# Set the correct number of classes and store the class names in the model
# config, so saved checkpoints and predictions carry readable labels rather
# than generic LABEL_i names.
# Note: ignore the warning about mismatched / newly initialized classifier
# weights; the head is replaced on purpose and trained during fine-tuning.
num_classes = len(cifar10_classes_list)
id2label = dict(enumerate(cifar10_classes_list))
label2id = {name: idx for idx, name in id2label.items()}
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
print(model.classifier)

In [None]:
# Configuration of the fine-tuning run
args = TrainingArguments(
    # Where checkpoints and logs are written
    output_dir='./results',
    logging_dir='./logs',
    # Optimisation settings
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint
    # with the best accuracy when training ends
    eval_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="accuracy",
    load_best_model_at_end=True,
    # Presumably needed so the raw image column stays available to the
    # on-the-fly transform (verify against the Trainer documentation)
    remove_unused_columns=False,
)

In [None]:
# Work on a small slice of CIFAR-10 so that fine-tuning finishes quickly.
# Feel free to experiment with other (n_train, n_test) values.
n_train, n_test = 2000, 1000

# Preprocessing that is applied on the fly whenever examples are read
transform_func = processor_transform(processor)

# First n_train training examples, wrapped for the Trainer
train_subset = cifar10_dataset['train'].select(range(n_train))
ds_train = train_subset.with_transform(transform_func)

# First n_test test examples, used for evaluation during training
test_subset = cifar10_dataset['test'].select(range(n_test))
ds_test = test_subset.with_transform(transform_func)

# The complete test split, used for the final prediction run
ds_test_full = cifar10_dataset['test'].with_transform(transform_func)

In [None]:
# Build the Hugging Face Trainer from the model, the training configuration,
# the two dataset views, and the helper functions defined earlier
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    tokenizer=processor,
)

In [None]:
# Run the fine-tuning with the configuration defined in `args`
trainer.train()

In [None]:
# Run the fine-tuned model on the full CIFAR-10 test split and print the
# metrics reported for that run
outputs = trainer.predict(ds_test_full)
print(outputs.metrics)

In [None]:
# Predicted class = index of the highest score along axis 1.
# NOTE(review): the name 'predictations' is a misspelling of 'predictions';
# it is left unchanged because the report and confusion-matrix cells use it.
predictations = outputs.predictions.argmax(1)
# Ground-truth labels of the same (full) test split that was predicted on
true_labels = cifar10_dataset['test']['label']

In [None]:
# Detailed analysis (report): per-class metrics, labelled with the CIFAR-10
# class names
print(classification_report(true_labels, predictations, target_names=cifar10_classes_list))

In [None]:
# Detailed analysis (confusion matrix)

# Confusion matrix computed from the true and the predicted labels
cm = confusion_matrix(true_labels, predictations)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=cifar10_classes_list)

# Draw on an explicitly sized axes; x tick labels are rotated by 90 degrees.
# The trailing semicolon keeps the notebook from echoing the returned object.
fig, ax = plt.subplots(figsize=(10, 8))
disp.plot(cmap='Blues', ax=ax, xticks_rotation=90);

In [None]:
# Show the container type of the COCO-O dataset built by create_hf_cocoo_dataset
type(cocoo_dataset)

### Another dataset

<b><font color="red">[TODO]</font></b>: Conduct fine-tuning experiments for DTD dataset or COCO-O dataset or both. What is the accuracy, how does it compare to the cnn-based experiments?

## Analysis of the results with FiftyOne lib

<b><font color="red">[TODO]</font></b>: Using the example from the previous 'Practice (Lecture 1)' session and the guidance from the provided [LINK](https://docs.voxel51.com/recipes/adding_classifications.html), analyze the COCO-O results using the FiftyOne tool. Specifically, focus on examining instances where the predictions do not align with the ground truth labels.