Tutorial 4 (Image embeddings for Classification)
======================


## About

This part of the assignment is dedicated to different image embeddings (DINO, CLIP, ResNet) and to using them as features for image classification.


<hr> 

* The <b><font color="red">red</font></b> color indicates the task that should be done, like <b><font color="red">[TODO]</font></b>: ...
* Additional comments and hints are in <b><font color="blue">blue</font></b>. For example <b><font color="blue">[HINT]</font></b>: ...

## Preliminaries

In [None]:
# !pip install datasets
# !pip install fiftyone
# !pip install scikit-learn
# !pip install tensorboard jupyter-tensorboard
# !pip install tqdm

In [None]:
import os
import gdown
import zipfile
from tqdm import tqdm
from copy import deepcopy

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader


from datasets import load_dataset
from datasets import Dataset, DatasetDict

from transformers import AutoImageProcessor, AutoModel

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay


In [None]:
# Make plots a bit nicer: larger serif fonts everywhere
plt.rcParams.update({'font.size': 18, 'font.family': 'serif'})

## Auxiliary functions

### Training related

We will reuse the scripts from previous tutorials. This is a bit redundant, but it keeps every notebook self-contained. 

In [None]:
def train_and_validate(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epoch: int,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
    max_iter: int | None = None
) -> nn.Module:
    """Train ``model`` and restore the weights of the best validation epoch.

    Each epoch runs one optimisation pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``. Accuracy is computed from the arg-max
    of the model outputs over dimension 1. After the last epoch the state dict
    of the epoch with the highest validation accuracy is loaded back into
    ``model``.

    Args:
        model: Classifier to train; it is moved to ``device``.
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        num_epoch: Number of epochs.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device used for training and evaluation.
        max_iter: If not ``None``, process at most this many training
            batches per epoch (useful for quick smoke tests).

    Returns:
        ``model`` itself, with the best-validation weights loaded.
    """
    model.to(device)

    # Start below any reachable accuracy so the first epoch is always
    # snapshotted, even if its validation accuracy is 0.
    best_val_accuracy = float('-inf')
    best_model_state = None

    n_train_batches = len(train_loader)
    if max_iter is not None:
        n_train_batches = min(n_train_batches, max_iter)

    for epoch in range(num_epoch):
        model.train()
        train_loss_sum, train_correct, train_seen = 0.0, 0, 0

        for batch_idx, (inputs, labels) in tqdm(enumerate(train_loader), 'training', total=n_train_batches):
            # Process exactly ``max_iter`` batches when a limit is given.
            if max_iter is not None and batch_idx >= max_iter:
                break
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            preds = output.argmax(dim=1)
            train_seen += batch_size
            train_loss_sum += loss.item() * batch_size
            train_correct += (preds == labels).sum().item()

        # Normalise by the samples actually seen; this differs from the
        # dataset size when ``max_iter`` cuts the epoch short.
        train_loss = train_loss_sum / max(train_seen, 1)
        train_acc = train_correct / max(train_seen, 1)

        model.eval()
        val_loss_sum, val_correct, val_seen = 0.0, 0, 0
        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, 'evaluation', total=len(val_loader)):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                batch_size = inputs.size(0)
                preds = outputs.argmax(dim=1)
                val_seen += batch_size
                val_loss_sum += loss.item() * batch_size
                val_correct += (preds == labels).sum().item()

        val_loss = val_loss_sum / max(val_seen, 1)
        val_acc = val_correct / max(val_seen, 1)

        if val_acc > best_val_accuracy:
            # deepcopy: state_dict() returns references to live tensors that
            # later optimizer steps would overwrite.
            best_model_state = deepcopy(model.state_dict())
            best_val_accuracy = val_acc

        print(
            f'Epoch [{epoch + 1}/{num_epoch}], '
            f'train loss: {train_loss:.4f}, train acc: {train_acc:.4f}, '
            f'val loss: {val_loss:.4f}, val acc: {val_acc:.4f}'
        )

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model

In [None]:
def predict(
    model: nn.Module,
    data_loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``data_loader`` and collect predictions and labels.

    Args:
        model: Classifier; the predicted class is the arg-max of its output
            over dimension 1.
        data_loader: Loader yielding ``(inputs, labels)`` batches.
        device: Device on which inference runs.

    Returns:
        ``(predictions, ground_truth_labels)`` as 1-D NumPy arrays in the
        order the loader yields them. Both come from the same pass, so they
        stay aligned even if the loader shuffles.
    """
    model.to(device)
    model.eval()
    predictions: list[np.ndarray] = []
    ground_truth_labels: list[np.ndarray] = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            outputs = model(inputs.to(device))
            predictions.append(outputs.argmax(dim=1).cpu().numpy())
            # Convert labels explicitly (and move them to CPU) instead of
            # relying on NumPy's per-element conversion of 0-d tensors.
            ground_truth_labels.append(torch.as_tensor(labels).cpu().numpy().reshape(-1))
    if not predictions:
        return np.array([], dtype=np.int64), np.array([], dtype=np.int64)
    return np.concatenate(predictions), np.concatenate(ground_truth_labels)

### Models related

In [None]:
class SimpleMLP(nn.Module):
    """Three-layer perceptron: two ReLU hidden layers and a linear output layer.

    Args:
        input_size: Size of the last dimension of the input.
        hidden_size: Width of both hidden layers.
        n_classes: Number of outputs (logits) per sample.
    """

    def __init__(self, input_size: int, hidden_size: int, n_classes: int):
        super().__init__()
        # Creation order and attribute names are kept: they determine how the
        # RNG is consumed during init and the state_dict keys.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        """Map ``(..., input_size)`` inputs to ``(..., n_classes)`` logits."""
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

### Data related

In [None]:
def create_hf_cocoo_dataset(path_coco_o: str, path_data: str, seed: int = 42, test_ratio: float = 0.3) -> tuple[DatasetDict, list[str]]:
    """Build a train/test ``DatasetDict`` from the COCO-O folder layout.

    Images are read from ``<path_coco_o>/<class_name>/val2017/``. If
    ``path_coco_o`` does not exist, the archive is downloaded to
    ``<path_data>/ood_coco.zip`` with ``gdown`` and extracted into
    ``path_data`` first.

    Args:
        path_coco_o: Root folder of the extracted COCO-O data.
        path_data: Folder for the downloaded archive and its extraction.
        seed: Seed for the train/test shuffle.
        test_ratio: Fraction of all images assigned to the ``'test'`` split.

    Returns:
        ``(dataset_dict, class_names)``. ``dataset_dict`` has ``'train'`` and
        ``'test'`` splits with ``'image'`` and ``'label'`` columns, and
        ``class_names[label]`` is the folder name of that label.
    """
    def load_image(example):
        example['image'] = Image.open(example['image_path'])
        return example

    if not os.path.exists(path_coco_o):
        url = 'https://drive.google.com/uc?id=1aBfIJN0zo_i80Hv4p7Ch7M8pRzO37qbq'
        zip_file_path = os.path.join(path_data, 'ood_coco.zip')
        gdown.download(url, zip_file_path, quiet=False)
        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
            zip_ref.extractall(path_data)

    # os.listdir order is filesystem-dependent: sort so class indices and the
    # seeded split are reproducible across machines. Entries without a
    # val2017 sub-folder (e.g. stray files such as .DS_Store) are skipped.
    cocoo_classes_list = sorted(
        name for name in os.listdir(path_coco_o)
        if os.path.isdir(os.path.join(path_coco_o, name, 'val2017'))
    )
    all_elements_coco = [
        (os.path.join(path_coco_o, label, 'val2017', img), index)
        for index, label in enumerate(cocoo_classes_list)
        for img in sorted(os.listdir(os.path.join(path_coco_o, label, 'val2017')))
    ]

    # A local RandomState yields the same permutation as seeding the global
    # NumPy RNG, without mutating global state.
    indices = np.random.RandomState(seed).permutation(len(all_elements_coco))
    n_test = int(len(indices) * test_ratio)
    split_indices = {'train': indices[n_test:], 'test': indices[:n_test]}

    splits = {}
    for split, idx in split_indices.items():
        # Separate comprehensions (rather than zip(*...)) keep an empty split
        # from failing on unpacking.
        image_paths = [all_elements_coco[i][0] for i in idx]
        labels = [all_elements_coco[i][1] for i in idx]
        dataset = Dataset.from_dict({'image_path': image_paths, 'label': labels})
        splits[split] = dataset.map(load_image, remove_columns=['image_path'])

    return DatasetDict(splits), cocoo_classes_list

### Embedding related

In [None]:
def extract_embedding(processor, model, image, dino_mode: str = 'clc', device: torch.device = None) -> np.ndarray:
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_name = model.__class__.__name__.lower()

    with torch.no_grad():
        inputs = processor(images=image, return_tensors='pt').to(device)
        outputs = model(**inputs)

        if 'dino' in model_name:
            image_features = outputs.last_hidden_state
            image_features = (
                image_features.mean(dim=1) if dino_mode == 'mean'
                else image_features[:, 0, :]
                if dino_mode == 'clc'
                else ValueError("Unsupported 'mode': choose 'mean' or 'clc'")
            )
        elif 'clip' in model_name:
            image_features = model.get_image_features(**inputs)
        elif 'resnet' in model_name:
            image_features = outputs.pooler_output
        else:
            raise ValueError("Unknown 'model_type': choose 'dino', 'clip', or 'resnet'")

    return np.float32(image_features.detach().cpu().numpy().squeeze())

In [None]:
def save_embeddings_and_labels(data_set, processor, model, path_data, dataset_name, model_name, parts=('train', 'test'), device=None):
    """Embed every image of the selected splits and save features and labels.

    For each split ``ds_part`` in ``parts``, two files are written:
    ``<path_data>/<dataset_name>_<model_name>_<ds_part>_features.npy``, with
    one ``extract_embedding`` output per image, and the matching
    ``..._labels.npy``. Images whose mode is not ``'RGB'`` are converted to
    RGB first.

    Args:
        data_set: Mapping from split name to an iterable of items with
            ``'image'`` (PIL image) and ``'label'`` fields.
        processor: Forwarded to ``extract_embedding``.
        model: Forwarded to ``extract_embedding``.
        path_data: Output folder; created if missing.
        dataset_name: Used only to build the file names.
        model_name: Used only to build the file names.
        parts: Names of the splits to process.
        device: Forwarded to ``extract_embedding``; ``None`` keeps its default.
    """
    os.makedirs(path_data, exist_ok=True)
    for ds_part in parts:
        fname = f'{dataset_name}_{model_name}_{ds_part}'
        embeddings, labels = [], []
        for dinfo in tqdm(data_set[ds_part]):
            image = dinfo['image']
            if image.mode != 'RGB':
                image = image.convert('RGB')
            embeddings.append(extract_embedding(processor, model, image, device=device))
            labels.append(dinfo['label'])

        # Building the lists explicitly (instead of zip(*...)) means an empty
        # split is saved as empty arrays rather than failing on unpacking.
        np.save(os.path.join(path_data, f'{fname}_features.npy'), np.array(embeddings))
        np.save(os.path.join(path_data, f'{fname}_labels.npy'), np.array(labels))

def load_saved_data(path_data, dataset_name, model_name, parts=('train', 'test')):
    """Load embeddings and labels written by ``save_embeddings_and_labels``.

    Args:
        path_data: Folder containing the ``.npy`` files.
        dataset_name: Dataset part of the file names.
        model_name: Model part of the file names.
        parts: Split names to load.

    Returns:
        ``(embeddings, labels)``: two dicts keyed by split name. A split whose
        features or labels file is missing is reported with a printed message
        and left out of both dicts.
    """
    def load_part(ds_part):
        fname = f'{dataset_name}_{model_name}_{ds_part}'
        features_path = os.path.join(path_data, f'{fname}_features.npy')
        labels_path = os.path.join(path_data, f'{fname}_labels.npy')
        if os.path.exists(features_path) and os.path.exists(labels_path):
            return np.load(features_path), np.load(labels_path)
        else:
            print(f"Files for {ds_part} not found at {features_path} and/or {labels_path}")
            return None, None

    all_embeddings, all_labels = {}, {}
    for part in parts:
        embeddings, labels = load_part(part)
        if embeddings is not None and labels is not None:
            all_embeddings[part] = embeddings
            all_labels[part] = labels

    return all_embeddings, all_labels

In [None]:
def collate_fn(batch):
    """Stack a list of ``{'data': ..., 'labels': ...}`` rows into two tensors.

    Returns:
        ``(data, labels)``. ``data`` stacks each row's ``'data'`` along a new
        first dimension; ``labels`` is a tensor of the ``'labels'`` values.
    """
    features = [torch.tensor(sample['data']) for sample in batch]
    targets = [sample['labels'] for sample in batch]
    return torch.stack(features), torch.tensor(targets)

## Load data

In [None]:
# Local folder for datasets, downloaded archives and cached embeddings
path_data = "./data"
os.makedirs(name=path_data, exist_ok=True)

In [None]:
# Load CIFAR-10 and rename its 'img' column to 'image' to match the other datasets
cifar10_dataset = load_dataset('cifar10', cache_dir=path_data).rename_column('img', 'image')
cifar10_classes_list = cifar10_dataset['train'].features['label'].names

In [None]:
# Load the DTD dataset; class names come from the 'label' feature
dtd_dataset = load_dataset(path="tanganke/dtd", cache_dir=path_data)
dtd_classes_list = dtd_dataset['train'].features['label'].names

In [None]:
# Load COCO-O (downloaded and extracted into path_data if not present yet)
path_coco_o = os.path.join(path_data, 'ood_coco')
cocoo_dataset, cocoo_classes_list = create_hf_cocoo_dataset(path_coco_o=path_coco_o, path_data=path_data)

In [None]:
# Map each dataset key to (dataset splits, list of class names)
datasets = dict(
    cocoo=(cocoo_dataset, cocoo_classes_list),
    cifar10=(cifar10_dataset, cifar10_classes_list),
    dtd=(dtd_dataset, dtd_classes_list),
)

## Models

In [None]:
# We will consider three different embeddings - DINO, CLIP, ResNet50.
# Each entry is (processor checkpoint, model checkpoint); both come from the same repo.
models = {
    key: (checkpoint, checkpoint)
    for key, checkpoint in (
        ('dino', 'facebook/dinov2-base'),
        ('clip', 'openai/clip-vit-base-patch32'),
        ('resnet', 'microsoft/resnet-50'),
    )
}

## Select dataset (COCO-O) and model (DINO)

In [None]:
# Select dataset
# We will start with 'cocoo' dataset
dataset_name = 'cocoo'  # e.g., 'cocoo', 'cifar10', 'dtd'
if dataset_name not in datasets:
    # Fail here with a clear message instead of crashing later on a None dataset.
    raise ValueError(f"...unknown dataset '{dataset_name}'; choose one of {list(datasets)}")
data_set, data_classes_list = datasets[dataset_name]

# Select embedding model
# We will start with 'DINOv2' model
model_name = 'dino'  # e.g., 'dino', 'clip', 'resnet'
if model_name not in models:
    raise ValueError(f"...unknown model '{model_name}'; choose one of {list(models)}")
model_info = models[model_name]
processor = AutoImageProcessor.from_pretrained(model_info[0])
model = AutoModel.from_pretrained(model_info[1])

In [None]:
# Inspect the first training sample and print its class name
sample = data_set['train'][0]
img, label = sample['image'], sample['label']
print(data_classes_list[label])
# img  # uncomment to display the image

## Generate embeddings

In [None]:
# Embed a single image and check the shape of the result
img_features = extract_embedding(processor=processor, model=model, image=img)
print(img_features.shape)

In [None]:
# Embed the whole dataset (train and test splits) and save the results to disk.
# Re-run with other dataset/model selections to cache their embeddings too.
save_embeddings_and_labels(
    data_set=data_set,
    processor=processor,
    model=model,
    path_data=path_data,
    dataset_name=dataset_name,
    model_name=model_name,
)

In [None]:
# Once saved, load one specific set of embeddings from disk.
# NOTE: this overrides the selection above, while data_classes_list still refers
# to the dataset selected earlier. Keep the two consistent.
dataset_name = 'cocoo'
model_name = 'dino'
embeddings_preloaded, labels_preloaded = load_saved_data(
    path_data, dataset_name=dataset_name, model_name=model_name
)

## Training a classifier

In [None]:
# Dataloaders for training
# Shuffle the training data every epoch; keep the test order fixed so that
# evaluation is deterministic.
dataset_train = Dataset.from_dict({'data': embeddings_preloaded['train'], 'labels': labels_preloaded['train']})
dataset_test = Dataset.from_dict({'data': embeddings_preloaded['test'], 'labels': labels_preloaded['test']})
trainloader = DataLoader(dataset_train, batch_size=32, shuffle=True, collate_fn=collate_fn)
testloader = DataLoader(dataset_test, batch_size=32, shuffle=False, collate_fn=collate_fn)

In [None]:
# Set the device: prefer CUDA, then Apple MPS, otherwise fall back to CPU
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [None]:
# Loss for multi-class classification on raw logits (batch-averaged)
criterion = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# Other learning settings
# NOTE: batch_size is not used by the DataLoaders above, which use batch_size=32.
num_epoch, learning_rate, batch_size = 10, 0.001, 64

In [None]:
# Number of batches in the train and test loaders (one per line)
print(*map(len, (trainloader, testloader)), sep='\n')

# Shape and type of one test batch: embeddings and their labels
images, labels = next(iter(testloader))
print(images.shape, type(images))
print(labels.shape, type(labels))

In [None]:
# Set-up model and optimizer
# Derive the layer sizes from the data instead of hard-coding them. The
# embedding width differs between backbones (the original setup hard-coded 768
# for DINO), and the number of outputs must match the selected dataset's classes.
input_size = embeddings_preloaded['train'].shape[1]
n_classes = len(data_classes_list)
model_mlp = SimpleMLP(input_size, 128, n_classes)
optimizer = torch.optim.Adam(model_mlp.parameters(), lr=learning_rate)

In [None]:
# Let's train!

In [None]:
%%time
best_model = train_and_validate(
    model=model_mlp, 
    train_loader=trainloader, 
    val_loader=testloader, 
    num_epoch=num_epoch, 
    criterion=criterion, 
    optimizer=optimizer, 
    device=device
)

In [None]:
# Let's do another round of training (with smaller LR).
# A new Adam optimizer is created, so its internal state starts fresh.
learning_rate = 0.0001
optimizer = torch.optim.Adam(model_mlp.parameters(), lr=learning_rate)
best_model = train_and_validate(
    model_mlp, trainloader, testloader, num_epoch, criterion, optimizer, device
)

In [None]:
# Make predictions on the test set (predictions and labels stay aligned)
predictations, true_labels = predict(best_model, testloader, device)

In [None]:
# Detailed analysis (report)
# Pass the full label range explicitly so that target_names always lines up,
# even if some classes never appear in the test labels or predictions.
print(classification_report(
    true_labels,
    predictations,
    labels=np.arange(len(data_classes_list)),
    target_names=data_classes_list,
    zero_division=0,
))

In [None]:
# Detailed analysis (confusion matrix)
# Fix the label set so the matrix is always n_classes x n_classes and matches
# display_labels, even if some classes are absent from the test set.
cm = confusion_matrix(true_labels, predictations, labels=np.arange(len(data_classes_list)))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=data_classes_list)

fig, ax = plt.subplots(figsize=(10, 8))
disp.plot(cmap='Blues', ax=ax, xticks_rotation=90);

### Another embedding

<b><font color="red">[TODO]</font></b>: For the same dataset, apply different embedding models (CLIP, ResNet) and compare their results to the current ones.

### Another dataset

<b><font color="red">[TODO]</font></b>: Conduct fine-tuning experiments on the DTD dataset, the CIFAR-10 dataset, or both. What accuracy do you obtain, and how does it compare to the CNN-based and ViT experiments?