<img style="float: right;" src="../../assets/htwlogo.svg">

# Exercise: Using embeddings and implementing a k-nearest neighbour classifier

We have already learned a lot about embeddings and got to know our first classification
model - a k-nearest neighbour classifier. Let's dive into some practice and play around with CLIP embeddings.

**Author**: _Erik Rodner_<br>

In [None]:
# It's time to use pytorch, at least to handle CLIP embeddings,
import torch
from transformers import CLIPProcessor, CLIPModel # HuggingFace module
# and some interesting datasets.
from torchvision import datasets
# For classical splitting, we will still use sklearn, since we
# learned about it. Otherwise, one would rather use pytorch only later on.
from sklearn.model_selection import train_test_split
# This is just to compute pairwise distances for the first task
from scipy.spatial.distance import pdist, squareform
import numpy as np




We do not have to start from scratch. The following functions should be familiar from the lecture
material. There is one slight change: the ``get_text_embeddings`` function now accepts
several strings at once. The ``get_pairwise_distances`` function is provided for your convenience. It computes
all mutual distances between a set of points and returns them as a square matrix.

In [None]:
# Load the CLIP model and processor - this can take a while, since the model weights need
# to be downloaded
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Function to get the text embedding using CLIP
def get_text_embeddings(text_array):
    """ Get CLIP-embeddings for a list of strings, returns a list of numpy arrays """
    
    # Preprocess the text
    inputs = processor(text=text_array, return_tensors="pt", padding=True, truncation=True)
    # Generate the embedding
    with torch.no_grad():
        outputs = model.get_text_features(**inputs)
    # Return the embeddings as an array of numpy arrays
    return [ output.cpu().numpy().flatten() for output in outputs ]

def get_image_embedding(image):
    """ Get a CLIP-embedding for a single PIL image """
    inputs = processor(images=image, return_tensors="pt")
    # Generate the embedding
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
    # Return the embedding as a numpy array
    return outputs[0].cpu().numpy().flatten()

def get_pairwise_distances(vectors):
    """ Compute mutual distances and return them as a quadratic matrix """
    return squareform(pdist(vectors))
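
To make the output format of ``get_pairwise_distances`` concrete, here is a minimal sketch that computes the same kind of matrix with plain numpy on three toy points. The helper name ``pairwise_euclidean`` and the toy data are assumptions for illustration only. ``pdist`` uses the Euclidean metric by default, and ``squareform`` expands its result into a symmetric matrix with zeros on the diagonal, which is the layout reproduced here.

```python
import numpy as np

def pairwise_euclidean(vectors):
    """Return the (n, n) matrix of Euclidean distances between n vectors."""
    X = np.asarray(vectors, dtype=float)
    # Broadcasting: differences[i, j] = X[i] - X[j]
    differences = X[:, None, :] - X[None, :, :]
    return np.linalg.norm(differences, axis=-1)

# Three toy points on a line through the origin (hypothetical data)
toy_points = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
toy_distances = pairwise_euclidean(toy_points)
print(toy_distances)
```

Entry ``[i, j]`` is the distance between point ``i`` and point ``j``, so the matrix is symmetric and each row can be read as "how far is everything else from this point".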

## Task 1: Analyze CLIP embeddings for a collection of texts

The first task is about exploration. Take a list of texts as shown below and try to understand the resulting
pairwise distances. The distances should roughly match your understanding of semantic similarity. What are the dimensions of these vectors?
Can you find counter-examples? How about visualizing them in 2D using PCA or ``umap`` (not officially part of the lecture)?

In [None]:
embedding_vectors = get_text_embeddings(["Hello", "How are you?", "cat", "dog", "butterfly"])

In [None]:
get_pairwise_distances(embedding_vectors)
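
For the 2D visualization mentioned above, one possible starting point is a PCA projection computed directly with numpy. This is a minimal sketch under stated assumptions: the helper name ``pca_project_2d`` is hypothetical, and the random toy vectors (with an arbitrarily chosen dimension of 16) only stand in for the real embeddings.

```python
import numpy as np

def pca_project_2d(vectors):
    """Project vectors onto their first two principal components."""
    X = np.asarray(vectors, dtype=float)
    # PCA requires centered data
    X_centered = X - X.mean(axis=0)
    # Rows of vt are the principal directions, sorted by decreasing singular value
    _, _, vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ vt[:2].T

# Random stand-in for a list of embedding vectors (hypothetical data)
rng = np.random.default_rng(0)
toy_vectors = rng.normal(size=(5, 16))
toy_points_2d = pca_project_2d(toy_vectors)
print(toy_points_2d.shape)
```

Passing ``embedding_vectors`` instead of ``toy_vectors`` would give one 2D point per text, which can then be drawn as a scatter plot with the texts as labels.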

### Dataset preparation for the next tasks

In the following, we will explore CLIP embeddings of a classical computer vision dataset that
existed even before ImageNet: Caltech 101. It has 101 object categories annotated for classification, and most
of the images were collected using Google image search. PyTorch (via ``torchvision``) has some nice helper functions to automatically download the dataset.

In [None]:
dataset = datasets.Caltech101(root="../data", download=True)

The next task will be about the **zero-shot performance** of CLIP embeddings. This is only possible because CLIP provides a joint embedding space for images and text: during training, images and their captions were pushed towards similar embeddings.
Therefore, we can easily compute the semantic similarity of an image and a text by calculating the distance
between their embeddings.

We will apply this principle to the dataset, by computing the similarity of each image with all possible category names. The category with the smallest distance (or highest similarity) to the image will be our prediction.

So let's first compute all embeddings of all category names:

In [None]:
text_embeddings = get_text_embeddings(dataset.categories)

For our distance metric, we could use the standard Euclidean distance. However, it is reasonable to use the
so-called cosine similarity function, defined as follows (can you spot why it is called cosine similarity?):

In [None]:
def cosine_sim(x1, x2):
    return np.dot(x1, x2)/(np.linalg.norm(x1) * np.linalg.norm(x2))
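
To illustrate why cosine similarity is a reasonable choice, the following sketch checks two properties on random toy vectors (hypothetical data, not real embeddings): the similarity does not depend on the length of the vectors, and for unit-length vectors it is tied to the squared Euclidean distance by ``||a - b||^2 = 2 - 2 * cosine_sim(a, b)``. The function ``cosine_sim`` is repeated only to keep the example self-contained.

```python
import numpy as np

def cosine_sim(x1, x2):
    return np.dot(x1, x2) / (np.linalg.norm(x1) * np.linalg.norm(x2))

# Random toy vectors (hypothetical data)
rng = np.random.default_rng(0)
a = rng.normal(size=8)
b = rng.normal(size=8)

# Property 1: rescaling the vectors by positive factors does not change the similarity
original_sim = cosine_sim(a, b)
scaled_sim = cosine_sim(3.0 * a, 0.5 * b)

# Property 2: for unit-length vectors, squared Euclidean distance and
# cosine similarity determine each other
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)
squared_distance = np.sum((a_unit - b_unit) ** 2)

print(original_sim, scaled_sim, squared_distance)
```

A consequence of the second property is that, once all vectors are normalized to unit length, ranking by smallest Euclidean distance and ranking by highest cosine similarity give the same order.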

Let's also select a few training and test images: 404 (101*4) for training and 202 for testing, both stratified by category.

In [None]:
indices = range(len(dataset))
labels = [ example[1] for example in dataset ]
indices_train, indices_temp, _, labels_temp  = train_test_split(indices, labels, shuffle=True, stratify=labels, train_size=101*4, random_state=42)
_, indices_test = train_test_split(indices_temp, test_size=202, stratify=labels_temp, random_state=42)

### Task 2: Zero-shot classification on Caltech 101

For this task, we only use the test set. Proceed as follows:
1. Go through all test images (I helped with that already :P)
2. Compute the image embedding of each test image
3. Calculate the cosine similarity of the image embedding and each category embedding
4. Obtain a class prediction by looking at the maximum similarity
5. Compare it with the ground-truth label

In [None]:
for index_test in indices_test:
    image, label = dataset[index_test]
    # your code comes here :)
    # ...
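
Steps 3 to 5 above can be sketched on toy data as follows. This is only one possible way to organize the computation, not a reference solution: the helper name ``predict_zero_shot`` and all toy vectors and labels are assumptions for illustration.

```python
import numpy as np

def predict_zero_shot(image_embedding, text_embeddings):
    """Return the index of the category whose text embedding is most similar."""
    T = np.asarray(text_embeddings, dtype=float)
    v = np.asarray(image_embedding, dtype=float)
    # Cosine similarity of the image embedding with every category embedding
    similarities = (T @ v) / (np.linalg.norm(T, axis=1) * np.linalg.norm(v))
    # The category with the highest similarity is the prediction
    return int(np.argmax(similarities))

# Toy stand-ins for three category embeddings and one image embedding
toy_text_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
toy_image_embedding = np.array([0.2, 0.9])
toy_prediction = predict_zero_shot(toy_image_embedding, toy_text_embeddings)

# Comparing predictions with ground-truth labels yields the accuracy
toy_predictions = np.array([1, 0, 2, 1])
toy_labels = np.array([1, 0, 1, 1])
toy_accuracy = np.mean(toy_predictions == toy_labels)
print(toy_prediction, toy_accuracy)
```

In the loop above, the image embedding would come from ``get_image_embedding(image)`` and the category embeddings from ``text_embeddings``. The comparison with ``label`` assumes that the integer label is an index into ``dataset.categories``.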

### Task 3: Write your own k-nn classifier

Let's set the text embeddings aside: from here on, we only use image embeddings.
Your task is to implement a k-nearest neighbour classifier and then evaluate it with the training and test sets specified above.

In [None]:
# your code comes here :)
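
If you want something to compare your own implementation against, here is a minimal sketch of the k-nearest neighbour idea on toy 2D points. It is a baseline under stated assumptions rather than the intended solution: the name ``knn_predict``, the choice of Euclidean distance, and the toy data are all hypothetical, and ties between classes are not handled specially.

```python
from collections import Counter

import numpy as np

def knn_predict(query, train_vectors, train_labels, k=3):
    """Predict the label of a query vector by majority vote of its k nearest neighbours."""
    X = np.asarray(train_vectors, dtype=float)
    # Euclidean distance from the query to every training vector
    distances = np.linalg.norm(X - np.asarray(query, dtype=float), axis=1)
    # Indices of the k closest training vectors
    nearest = np.argsort(distances)[:k]
    # Majority vote over the labels of these neighbours
    votes = Counter(train_labels[i] for i in nearest)
    return votes.most_common(1)[0][0]

# Two well-separated toy clusters (hypothetical data)
toy_train = np.array([[0.0, 0.0], [0.1, 0.2], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9]])
toy_train_labels = [0, 0, 0, 1, 1]
pred_a = knn_predict([0.1, 0.1], toy_train, toy_train_labels, k=3)
pred_b = knn_predict([4.9, 5.2], toy_train, toy_train_labels, k=3)
print(pred_a, pred_b)
```

For the actual task, the training vectors would be the image embeddings of the training images, and the same accuracy computation as in the zero-shot task can be reused for evaluation. It is worth trying several values of ``k`` as well as cosine similarity instead of Euclidean distance.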

### Task 4: Can you spot the dataset bias?

Caltech 101 has some severe dataset biases - can you spot them?
Look at the images of the motorbike or airplane category!

In [None]:
# add some visualization code here