# Heidelberg Workshop number 2

<a target="_blank" href="https://colab.research.google.com/github/etienneguevel/heidelberg/blob/main/notebooks/TD.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

##  Setup
If running locally using Jupyter, first install the necessary libraries in your environment using the installation instructions in the repository.

If running from Google Colab, set `using_colab=True` below and run the cell.
In Colab, be sure to select 'GPU' under 'Runtime' -> 'Change runtime type'.

In [None]:
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !git clone https://github.com/etienneguevel/heidelberg.git
    !{sys.executable} -m pip install -q -r ./heidelberg/requirements_colab.txt
    !{sys.executable} -m pip install --no-deps ./heidelberg/
    %cd heidelberg/notebooks

# Load the dataset
from medmnist import BloodMNIST

# We need to download the dataset; it might take a while
train_dataset = BloodMNIST(split="train", size=64, download=True)
val_dataset = BloodMNIST(split="val", size=64, download=True)
test_dataset = BloodMNIST(split="test", size=64, download=True)

## White cells detection

In [None]:
import bioformats
import javabridge

from bioformats import ImageReader, get_omexml_metadata

def setup_javabridge(log_level: str = "ERROR"):
    javabridge.start_vm(class_path=bioformats.JARS)
    logger_name = javabridge.get_static_field(
        "org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
    logger = javabridge.static_call(
        "org/slf4j/LoggerFactory", "getLogger",
        "(Ljava/lang/String;)Lorg/slf4j/Logger;", logger_name)
    level = javabridge.get_static_field(
        "ch/qos/logback/classic/Level", log_level,
        "Lch/qos/logback/classic/Level;")
    javabridge.call(logger, "setLevel",
                     "(Lch/qos/logback/classic/Level;)V", level)

setup_javabridge()

In [None]:
# Open the image file using ImageReader
file_path = "../data/blood_sample.ome.tif"

# Load the OME-XML metadata
omexml_metadata = get_omexml_metadata(file_path)

# Use ImageReader to read the image
with ImageReader(file_path) as reader:
    # Optionally print OME metadata for debugging
    print(omexml_metadata)
    x0, y0, x1, y1 = 10000, 10000, 12048, 12048

    image = reader.read(series=0, z=0, t=0, XYWH=(x0, y0, x1 - x0, y1 - y0))
    # Check image shape and other details
    print("Image shape:", image.shape) 

In [None]:
import matplotlib.pyplot as plt

plt.imshow(image)
plt.axis('off')
plt.show()

In [None]:
image.shape

In [None]:
import torch
from ultralytics import YOLO

model = YOLO("yolov8n.pt")

results = model.predict(image)

for result in results:
    # Detection
    result.boxes.xyxy  # box with xyxy format, (N, 4)
    result.boxes.xywh  # box with xywh format, (N, 4)
    result.boxes.xyxyn  # box with xyxy format but normalized, (N, 4)
    result.boxes.xywhn  # box with xywh format but normalized, (N, 4)
    result.boxes.conf  # confidence score, (N, 1)
    result.boxes.cls  # cls, (N, 1)

    # Segmentation (masks is None when the model only does detection, as yolov8n.pt does)
    if result.masks is not None:
        result.masks.data  # masks, (N, H, W)
        result.masks.xy  # x,y segments (pixels), List[segment] * N
        result.masks.xyn  # x,y segments (normalized), List[segment] * N

    # Classification (probs is None unless a classification model is used)
    if result.probs is not None:
        result.probs  # cls prob, (num_class, )

# Each result is composed of torch.Tensor by default,
# so it can easily be moved between devices:
if torch.cuda.is_available():
    result = result.cuda()
result = result.cpu()
result = result.to("cpu")

## Classification

Classification is a key task, as the relative proportions of the different white blood cell
categories can indicate whether a pathology is present.

This part's goal is to leverage an open-source dataset of white blood cells in order to
train a Deep Learning model to perform classification.

In [None]:
# Let's load the dataset
train_dataset.info

This dataset contains images of individual blood cells, divided into 8 categories:

- basophil
- eosinophil
- erythroblast
- immature granulocytes
- lymphocyte
- monocyte
- neutrophil
- platelet

### Understand an image

In [None]:
import random
import matplotlib.pyplot as plt

train_size = len(train_dataset)

label_dict = train_dataset.info["label"]
img, label = train_dataset[random.randrange(train_size)]  # randrange excludes train_size, so the index is always valid

print(f"Image's class is: {label_dict.get(str(label[0]))}\n")
plt.imshow(img)
plt.show()

The images are made of pixels arranged in a grid of size H\*W (height by width), and each pixel has
3 channels: Red, Green and Blue.  

This makes images **3-dimensional** objects (H\*W\*3); here the images are of size
64\*64\*3.

In [None]:
import numpy as np

image = np.array(img.convert("RGB"))
figure, plots = plt.subplots(ncols=3, nrows=1)
for i, subplot in zip(range(3), plots):
    temp = np.zeros(image.shape, dtype='uint8')
    temp[:,:,i] = image[:,:,i]
    subplot.imshow(temp)
    subplot.set_axis_off()
plt.show()

### Create the data objects

`torch` is a popular Python framework for building Deep Learning models. Among its functionalities, it offers ways
to facilitate data usage through `Dataset` and `DataLoader` objects.  
The first step is to create a `Dataset` ([here](https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html)) object.  

You should implement three methods for that:
- `__init__` -> create the 'attributes' of our object (kind of a way of storing data)
- `__len__` -> return the number of elements of our dataset (allows our object to get called with the `len` function)
- `__getitem__` -> return the element at an index (like for a `list`). In our case an item is made of an image transformed into a `tensor` + its label
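
To make the role of these three methods concrete, here is a minimal sketch in plain Python. `ToyDataset` and its toy samples are hypothetical and only mimic the interface; a real implementation would subclass `torch.utils.data.Dataset` and return image tensors.

```python
class ToyDataset:
    """Hypothetical stand-in for a torch Dataset: same three methods, no torch needed."""

    def __init__(self, data, transform):
        # Store the raw samples and the transform as attributes
        self.data = data
        self.transform = transform

    def __len__(self):
        # Called by len(dataset)
        return len(self.data)

    def __getitem__(self, idx):
        # Called by dataset[idx]; returns (transformed sample, label)
        sample, label = self.data[idx]
        return self.transform(sample), label


# Toy samples: (value, label) pairs, with a transform that doubles the value
toy = ToyDataset([(1, "a"), (2, "b"), (3, "c")], transform=lambda x: 2 * x)
print(len(toy))  # 3
print(toy[1])    # (4, 'b')
```

The point of the sketch is that `len(toy)` and `toy[1]` are ordinary Python syntax that dispatch to `__len__` and `__getitem__`, which is what lets a `DataLoader` iterate over any object following this protocol.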

In [None]:
# The first step is to make a dataset, for this we need to create our own custom object
# Below is the "backbone" of a Dataset object, with all the necessary methods that need to be implemented
# Uncomment and execute the cell below to see the answer
from torch.utils.data import Dataset

class CustomDataset(Dataset):

    def __init__(self, data, transform):
        # initialization method, you should store data and transform as attributes
        pass

    def __len__(self):
        # should return the number of elements of the dataset
        pass

    def __getitem__(self, idx):
        # should return the element of the dataset at index idx (image and label)
        # (Don't forget to transform the image)
        pass

In [None]:
# %load ../src/custom_dataset.py

In [None]:
# In this cell we use the custom dataset we just created with our imported datasets
import torchvision.transforms as transforms

# Here we need a transform to convert the data, which are PIL images, into tensors
transform = transforms.ToTensor()

training_dataset = CustomDataset(train_dataset, transform)
validation_dataset = CustomDataset(val_dataset, transform)

What we will want to do afterwards is to loop over the dataset, to make 'batches'
of data for the training protocol.   

`torch` provides the `DataLoader` object, let's create those data loaders!

In [None]:
# Make a DataLoader from the custom dataset created above
from torch.utils.data import DataLoader

train_loader = DataLoader(training_dataset, batch_size=64, shuffle=True)  # reshuffle the training data every epoch
valid_loader = DataLoader(validation_dataset, batch_size=64)

# Let's check that the dataloader works
for images, labels in train_loader:
    print(f"Batch size: {images.shape[0]}")
    print(f"Image shape: {images.shape[1:]}")
    print(f"Labels shape: {labels.shape}")
    break

### Create the model

The models we are going to use for this task are Convolutional Neural Networks
([CNN](https://poloclub.github.io/cnn-explainer/)), and more specifically a family
of them called [ResNets](https://arxiv.org/abs/1512.03385).  

Training an efficient model can be time-consuming. However, models for image
processing have already been created and trained for similar tasks, and are
available for reuse.  

Those models are called **pretrained**, and can be adapted to our
dataset (fitting them on our custom classes of white blood cells).  
This process is called **fine-tuning**.

In [None]:
# We create a model from a pretrained point
from torchvision import models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # `pretrained=True` is deprecated in recent torchvision

# We can see the blocks in our model like this
print(model)

In [None]:
# We can see the number of parameters of our model like this
print(f"There are {sum([p.numel() for p in model.parameters()]):.2g} parameters in the model used.")

In [None]:
# Bonus: You can also make your own model architecture!
# Uncomment and execute the cell below to see a simple CNN appear
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Net, self).__init__()
        # create layers here
        # self.l1 = ...

    def forward(self, x):
        # here the output should be a vector of size num_classes
        pass

In [None]:
# %load ../src/net.py

### Training loop

The next step of our process is the actual training of our model. The steps are
as follows:

- Load a batch from our dataloader
- Get the output of our model
- Compute the loss with the labels
- Backpropagate the loss and update the weights

For this we are going to need the following objects:

- A criterion (i.e. a way to calculate the loss); for multi-class classification,
    the **cross-entropy** loss is almost always used
- An optimizer (i.e. an algorithm to update the weights of our model). Several
    popular options exist (Adam, Stochastic Gradient Descent...)
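
As an illustration of what the criterion computes, the sketch below evaluates the cross-entropy by hand with NumPy on two made-up samples. The function name `cross_entropy` and the example values are assumptions for illustration; the computation (log-softmax of the raw scores, then the mean negative log-probability of the true class) mirrors what `nn.CrossEntropyLoss` does with its default mean reduction.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy between raw scores (N, C) and integer labels (N,)."""
    # Subtract the row-wise max for numerical stability before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability assigned to the true class of each sample
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.0, 0.0],   # sample 0: highest score on class 0
                   [0.0, 0.0, 0.0]])  # sample 1: uniform over the 3 classes
labels = np.array([0, 1])
loss = cross_entropy(logits, labels)
print(f"loss = {loss:.4f}")
```

The second sample contributes exactly log(3) to the average, which is the loss of a model that has no preference between 3 classes; the first sample contributes less because the true class already has the highest score.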


Then we will implement the training loop and make the model train!
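
Before writing the loop with `torch`, it can help to see the four steps on a toy problem. The sketch below is a simplified stand-in, not the workshop solution: it uses synthetic data, a single linear layer instead of a ResNet, and plain SGD with a hand-written gradient in place of `loss.backward()` and `optimizer.step()`. All sizes and hyperparameters are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data (assumption): 256 samples, 10 features, 3 classes
n_samples, n_features, n_classes = 256, 10, 3
true_w = rng.normal(size=(n_features, n_classes))
X = rng.normal(size=(n_samples, n_features))
y = (X @ true_w).argmax(axis=1)

W = np.zeros((n_features, n_classes))  # model weights
lr, batch_size, n_epochs = 0.5, 64, 20
epoch_losses = []

for epoch in range(n_epochs):
    order = rng.permutation(n_samples)  # shuffle once per epoch
    running = 0.0
    for start in range(0, n_samples, batch_size):
        idx = order[start:start + batch_size]
        xb, yb = X[idx], y[idx]                                # 1. load a batch
        logits = xb @ W                                        # 2. model output
        shifted = logits - logits.max(axis=1, keepdims=True)
        probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
        loss = -np.log(probs[np.arange(len(yb)), yb]).mean()   # 3. cross-entropy loss
        grad_logits = probs.copy()                             # 4. backward pass:
        grad_logits[np.arange(len(yb)), yb] -= 1.0             #    d(loss)/d(logits) = probs - one_hot
        grad_W = xb.T @ grad_logits / len(yb)
        W -= lr * grad_W                                       #    weight update (plain SGD)
        running += loss * len(yb)
    epoch_losses.append(running / n_samples)

accuracy = ((X @ W).argmax(axis=1) == y).mean()
print(f"first epoch loss: {epoch_losses[0]:.3f}, "
      f"last epoch loss: {epoch_losses[-1]:.3f}, accuracy: {accuracy:.2f}")
```

In the `torch` version, steps 2 and 3 become `model(images)` and `criterion(outputs, labels)`, and the gradient lines are replaced by `loss.backward()` followed by `optimizer.step()`.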

In [None]:
# Now we need to make an optimizer + a loss
import torch.nn as nn
from torch.optim import Adam, SGD, AdamW

# cross-entropy loss is the one used for multi-class classification tasks
criterion = nn.CrossEntropyLoss()

# Adam is a popular optimizer, but others could be used (SGD, AdamW...)
optimizer = Adam(model.parameters(), lr=0.001, weight_decay=1e-4) # you can also use SGD or AdamW, also try different learning rates and weight decay values

Now we have everything necessary in order to launch the training process.  

In [None]:
# Implement the training loop here
# Uncomment and execute the cell below to see the answer
from tqdm import tqdm # tqdm is a library to display progress bars while looping

n_epochs = 10 # Number of epochs to train the model (ie number of times the model will see the whole dataset)

for epoch in range(n_epochs):
    model.train()  # Set the model to training mode

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs}"):
        optimizer.zero_grad()  # Zero the gradients -> essential step before the backward pass

        # ToDo: Calculate the outputs of the model
        
        # ToDo: Compute the loss
        loss = 0

        # Make the backward pass and update the weights
        loss.backward()  # Backward pass
        optimizer.step()  # Update the weights

    # Bonus: Implement the validation loop here
    # You need to set the model to evaluation mode and use the validation dataloader
    # to compute the validation loss and accuracy of the model every k epochs (e.g. every 5 epochs)
    if epoch % 5 == 0:
        model.eval()

In [None]:
# %load ../src/training.py

### Visualize the results

Our dataset was split into three parts: train, val and test. While we have used
the train split for backpropagation and the val split for monitoring, the test split is still
unseen at this point.

Its purpose is to measure the metrics at the end of the training pipeline.

In [None]:
# Execute this cell to see the results of your trained model on the test dataset
from tqdm import tqdm

def find_accuracy(model, dataloader_test, device):
    correct = 0
    total = 0
    label_test = []
    predicted_test = []

    model.eval()  # evaluation mode: fixed batch-norm statistics, no dropout
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader_test):
            # Move the data to the device 
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Predict the label with the trained model
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)
            
            # Calculate the performance
            labels = labels.squeeze(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # Add the predictions & labels in the list
            label_test.extend(labels.to('cpu').tolist())
            predicted_test.extend(preds.to('cpu').tolist())

    print('Accuracy of the network on the test images: %d %%' % (
        100 * correct / total))
    return predicted_test, label_test, correct / total

# Make the test dataset
testing_dataset = CustomDataset(test_dataset, transform)
test_loader = DataLoader(testing_dataset, batch_size=64)

# Specify the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is available
model.to(device)

# Calculate the results
predictions, labels, acc = find_accuracy(model, test_loader, device)

In classification problems, it is also important to check that the algorithm
performs well on every class, and not only on the most dominant ones.  

Indeed, with an imbalanced dataset, training can be biased and the less populated
classes can end up with degraded performance.
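
A small made-up example shows why overall accuracy alone can be misleading. The labels below are hypothetical (95 samples of one class, 5 of another), and the "model" simply predicts the majority class every time.

```python
import numpy as np

# Hypothetical test labels: 95 samples of class 0, 5 samples of class 1
y_true = np.array([0] * 95 + [1] * 5)
# A degenerate model that always predicts the majority class
y_pred = np.zeros_like(y_true)

accuracy = (y_pred == y_true).mean()

# Confusion matrix: rows are true classes, columns are predicted classes
n_classes = 2
cm = np.zeros((n_classes, n_classes), dtype=int)
np.add.at(cm, (y_true, y_pred), 1)

# Recall per class = correct predictions / number of true samples of that class
recall = cm.diagonal() / cm.sum(axis=1)

print(f"accuracy: {accuracy:.2f}")    # accuracy: 0.95
print(f"recall per class: {recall}")  # recall per class: [1. 0.]
```

The accuracy looks excellent even though the minority class is never detected, which is why a confusion matrix and per-class metrics are worth inspecting.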

In [None]:
# Execute this cell to visualize the results of your model on the test set
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from sklearn.metrics import classification_report, confusion_matrix

names = [n[:20] for n in train_dataset.info.get("label").values()]

C_matrix = pd.DataFrame(confusion_matrix(labels, predictions))
C_matrix.index = names
C_matrix.columns = names

fig = plt.figure(figsize=(12, 8))
sns.heatmap(C_matrix, annot=True, cmap="flare", vmax=100, fmt='.3g')
plt.show()

print(classification_report(labels, predictions, target_names=names))

## Foundation models

*"Foundation models"* are named this way because their outputs, called **embeddings**,
are not predictions but vectors in $\mathbb{R}^n$.  
The embeddings are the cornerstone on which other applications are built, such as classification,
segmentation, multimodal models...

In theory any architecture can be used to make a foundation model, but in practice
[Transformers](https://poloclub.github.io/transformer-explainer/) are THE type of
model commonly used,  
and have been the workhorse of the AI ecosystem for the last 8 years.

### Vision Transformers (ViT)

[ViTs](https://arxiv.org/abs/2010.11929) are the adaptation of Transformers to images; they take the image,
split it into square patches that are transformed into vectors (called **embeddings**),  
and then pass them through a Transformer Neural Network.  

![ViT_tokens](./images/ViT_token.png)

A special token, named **[CLS]** token, is added to the sequence of tokens,
whose state at the output of the model serves as an **image representation**.  
It is then used for the downstream tasks, like classification.  
> The [CLS] token is an artificial construction; it doesn't represent anything at the beginning, but its job is to aggregate the
information of the image patches into one **embedding**.
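
The tokenization step can be sketched with NumPy. The sizes below (224x224 input, 14x14 patches, embedding dimension 384) are assumptions chosen for illustration, the projection weights are random rather than learned, and positional embeddings are left out for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed sizes for illustration: a 224x224 RGB image and 14x14 patches
img_size, patch_size, channels, embed_dim = 224, 14, 3, 384
image = rng.random((img_size, img_size, channels))

# Cut the image into a grid of non-overlapping patches and flatten each patch
grid = img_size // patch_size  # 16 patches per side
patches = image.reshape(grid, patch_size, grid, patch_size, channels)
patches = patches.transpose(0, 2, 1, 3, 4).reshape(grid * grid, -1)

# Linear projection of each flattened patch to an embedding (random weights here)
projection = rng.normal(size=(patches.shape[1], embed_dim)) * 0.02
patch_tokens = patches @ projection

# Prepend the [CLS] token (a learned vector in a real model; zeros here)
cls_token = np.zeros((1, embed_dim))
tokens = np.concatenate([cls_token, patch_tokens], axis=0)

print(patches.shape)  # (256, 588)
print(tokens.shape)   # (257, 384)
```

The Transformer then processes this sequence of 257 tokens, and the output at position 0 (the [CLS] position) is the one used as the image representation.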

The [model](https://arxiv.org/abs/2404.05022) we are going to use in this part has been
trained on White Blood Cell images. It:
- Uses the Vision Transformer ([ViT](https://arxiv.org/abs/2010.11929)) architecture
- Uses [DINOv2](https://github.com/facebookresearch/dinov2) as training framework
- Was trained on ~300k images from open-source datasets
- Comes in 4 sizes, ranging from 22M to 1.1B parameters


To use it, we are going to use [HuggingFace](https://huggingface.co/), a popular
platform and set of Python libraries for sharing and using pretrained models.  
For this you will have to:
- Create a HuggingFace account (or log in if you already have one)
- Create a token to get access to the model
- Execute the cell below and enter the created token

In [None]:
from huggingface_hub import notebook_login

notebook_login()

In [None]:
import timm

# Load model from the hub
model = timm.create_model(
  model_name="hf-hub:1aurent/vit_small_patch14_224.dinobloom", # you can change the size of the loaded model here
  pretrained=True,
).eval()

# Get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transform = timm.data.create_transform(**data_config, is_training=False)

print("Below is the list of the layers contained within our model:\n")
model

Here we can see the architecture of the model that we use, which is very long.

Indeed it is made of 12 blocks, each having:
- An Attention layer
- An MLP layer with one hidden layer
- Normalizing and Scaling layers in between

> Can you find the dimension of the embeddings of the model that you use? What about its number of elements?

In [None]:
# Try here, or uncomment and execute the cell below to see the answer

In [None]:
# %load ../src/model_information.py

In [None]:
# Let's visualize a random image from the training set and use the model on it
img, label = train_dataset[random.randrange(train_size)]  # randrange excludes train_size, so the index is always valid

print(f"Image's class is: {label_dict.get(str(label[0]))}\n")
plt.imshow(img)
plt.show()
# Below is the code to use the model on a single image

data = transform(img).unsqueeze(0) # input is a (batch_size, num_channels, img_size, img_size) shaped tensor
output = model(data)
output.shape

### Probe the embeddings quality

Now that we have easy access to the model embeddings, we can test their quality on the dataset
that we have used before.

Popular techniques are:
- k-[Nearest-Neighbour](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier) (k-NN), with k usually being 1 or 20, which consists of making a prediction according
    to the most similar points of the training dataset  
    (k being the number of neighbours to take into account).
- [Linear Probing](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html),
    which consists of fitting a linear classifier (here a logistic regression) on the embeddings, and then
    evaluating its performance.
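
To make the k-NN idea explicit, here is a minimal NumPy sketch of the prediction rule. It is an illustration on synthetic "embeddings" (two Gaussian clusters, an assumption), not a replacement for `KNeighborsClassifier`; the helper name `knn_predict` is hypothetical.

```python
import numpy as np

def knn_predict(train_emb, train_labels, test_emb, k=1):
    """Predict each test label by majority vote among its k nearest training points."""
    # Pairwise Euclidean distances, shape (n_test, n_train)
    dists = np.linalg.norm(test_emb[:, None, :] - train_emb[None, :, :], axis=2)
    # Indices of the k closest training points for each test point
    nearest = np.argsort(dists, axis=1)[:, :k]
    votes = train_labels[nearest]
    # Majority vote (ties go to the smallest label)
    return np.array([np.bincount(row).argmax() for row in votes])

rng = np.random.default_rng(0)
# Synthetic "embeddings" (assumption): two well-separated Gaussian clusters in R^16
centers = np.stack([np.full(16, -3.0), np.full(16, 3.0)])
train_labels = rng.integers(0, 2, size=200)
train_emb = centers[train_labels] + rng.normal(size=(200, 16))
test_labels = rng.integers(0, 2, size=50)
test_emb = centers[test_labels] + rng.normal(size=(50, 16))

for k in (1, 20):
    preds = knn_predict(train_emb, train_labels, test_emb, k=k)
    print(f"{k}-NN accuracy: {(preds == test_labels).mean():.2f}")
```

No model is trained here: the quality of the prediction depends only on how well the embeddings separate the classes, which is exactly why k-NN is used as a probe of embedding quality.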

#### Make the dataset

Calculating the embeddings can be compute-intensive, as we use transformer models
with huge numbers of parameters.  
However, we only need to do that once, as we use them as mathematical objects, and
**we do not modify** the foundation model afterwards.  

To do that we will create a Dataset object that will contain the embeddings calculated
by a model on a Dataset (the same as the one we used before).

In [None]:
# We are going to make a new dataset object in order to calculate the embeddings of the images.
# This allows us to calculate the embeddings once and then reuse them for later applications
from heidelberg.embedding_dataset import EmbeddingDataset

# Make the embedding dataset and test shapes
emb_train = EmbeddingDataset(training_dataset, model, transform)
emb_test = EmbeddingDataset(testing_dataset, model, transform)

Implement here the evaluation of the techniques mentioned above!  

> Tip: You can use the sklearn library for k-NN and Linear probing

In [None]:
# Try to make function implementing the k-NN and linear probing of the embedding
# Uncomment and execute the cell below to get the answer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.neighbors import KNeighborsClassifier


def k_nearest_neighbor_eval(train_embedding, test_embedding, k=1, target_names=names):
    """
    Train a k-NN classifier on the training embeddings
    and evaluate it on the test embeddings.
    """
    # ToDo: Initialize the classifier
    

    # Unpack the embeddings & labels
    train_array = np.array([emb for emb, _ in train_embedding])
    train_labels = np.array([lab for _, lab in train_embedding])

    test_array = np.array([emb for emb, _ in test_embedding])
    test_labels = np.array([lab for _, lab in test_embedding])
    
    # ToDo: Fit the model on train_array and train_labels
    

    # ToDo: Make the predictions on test_array
    preds = None  # Replace with your predictions

    # Calculate the metrics of the predictions
    print(classification_report(test_labels, preds, target_names=target_names))

    return


def linear_probing_eval(train_embedding, test_embedding, target_names=names):
    """
    Train a linear classifier (logistic regression) on the training embeddings
    and evaluate it on the test embeddings.
    """
    # ToDo: Initialize the classifier
    

    # Unpack the embeddings & labels
    train_array = np.array([emb for emb, _ in train_embedding])
    train_labels = np.array([lab for _, lab in train_embedding])

    test_array = np.array([emb for emb, _ in test_embedding])
    test_labels = np.array([lab for _, lab in test_embedding])
    
    # ToDo: Fit the model on train_array and train_labels
    

    # ToDo: Make the predictions on test_array
    preds = None  # Replace with your predictions

    # Calculate the metrics of the predictions
    print(classification_report(test_labels, preds, target_names=target_names))

    return


# Use the functions defined above to evaluate the embeddings
print('1-NN evaluation:\n')
_ = k_nearest_neighbor_eval(emb_train, emb_test, k=1)
print('-' * 75)

print('\n20-NN evaluation:\n')
_ = k_nearest_neighbor_eval(emb_train, emb_test, k=20)
print('-' * 75)

print('\nLinear probing:\n')
_ = linear_probing_eval(emb_train, emb_test)
print('-' * 75)

In [None]:
# %load ../src/embedding_evaluation.py

#### Visualize the embeddings

Visualising the embeddings in a 2D (x - y) plane is a good way to check the quality of our embeddings.  
Indeed, if the embeddings we produced are of good quality, there should be clusters corresponding to each
of our classes.  

Among popular dimensionality reduction techniques there are:
- [UMAP](https://umap-learn.readthedocs.io/en/latest/basic_usage.html), a popular method for high dimensional biological datasets that captures well the clusters within a dataset
- [t-SNE](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html),
    a nonlinear method that places points that are close in the original space close together in the 2D map,
    preserving local neighbourhoods rather than global distances


In [None]:
# Try implementing one of the two techniques, and use the plot function below to get the 2D representation
# Uncomment the cell below for the answer

In [None]:
# %load ../src/umap_tsne.py

In [None]:
import pandas as pd
import seaborn as sns

translate_dict = train_dataset.info.get("label")

def plot_embeddings(embs, labels):
    
    # Make a dataframe with the embeddings
    data = pd.DataFrame(
        [
            [e[0], e[1], translate_dict.get(str(lab[0]))[:20]]
            for e, lab in zip(embs, labels)
        ],
        columns=["x", "y", "class"]
    )

    sns.relplot(
        data=data,
        x="x", y="y", hue="class", style="class", height=8, aspect=1.5
    )
    plt.show()

In [None]:
uemb, labels = get_umap(emb_train)
plot_embeddings(uemb, labels)

In [None]:
temb, labels = get_tsne(emb_train)
plot_embeddings(temb, labels)

### Bonus : Attention map

So far we have almost only used the [CLS] token, and left aside the tokens of the image patches.  

Embeddings of an image evolve from one layer to another, and are calculated through the attention mechanism.  
In the attention mechanism, tokens are updated according to their similarity with other tokens (the more similar
two tokens are, the more one contributes to the update of the other).


As such, one interesting thing to look at is the **attention map of the [CLS] token in the last layer**,
as it indicates which parts of the image are most used to create the image's embedding.
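
The mechanism can be sketched for a single attention head with NumPy. Everything here is an assumption made for illustration (a 4x4 patch grid, head dimension 8, random weights instead of learned ones); it is not how `get_attn` is implemented, only the standard scaled dot-product attention that such a map comes from.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

# Assumed sizes: 1 [CLS] token + a 4x4 grid of patch tokens, head dimension 8
grid, dim = 4, 8
n_tokens = 1 + grid * grid
tokens = rng.normal(size=(n_tokens, dim))

# Random projections stand in for the learned query/key/value weights
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))
q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v

# Attention weights: similarity of every query with every key, normalized per row
attn = softmax(q @ k.T / np.sqrt(dim), axis=-1)  # (n_tokens, n_tokens)
updated = attn @ v                                # new token representations

# Row 0 is the [CLS] query; dropping its own column leaves one weight per patch
cls_map = attn[0, 1:].reshape(grid, grid)
print(cls_map.shape)  # (4, 4)
```

Reshaping the [CLS] row to the patch grid is what turns a vector of attention weights into an image-like map that can be displayed next to the original picture.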

In [None]:
from heidelberg.attn import get_attn

patch_h, patch_w = 37, 37
model.cpu()

def plot_attention_map(img, layer, model, save=False):
    img_tensor = transform(img)
    cls_attn = get_attn(img_tensor, model, layer)
    num_heads, _ = cls_attn.shape
    
    cls_tot = torch.sum(cls_attn, dim=0).reshape((patch_h, patch_w))
    
    img_or = np.array(img) 
    # Plot the total attention (summed over the heads)
    fig, ax = plt.subplots(1, 2, figsize=(8, 4))
    
    ax[0].imshow(img_or)
    ax[0].axis("off")
    ax[0].set_title("Original image")
    ax[1].imshow(cls_tot)
    ax[1].axis("off")
    ax[1].set_title("Attention map <cls> token")
    
    if save:
        plt.savefig("attention_map_cls.png")
    plt.show()

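    # Plot the attention of each head on a grid with 3 columns
    n_cols = 3
    n_rows = (num_heads + n_cols - 1) // n_cols
    # squeeze=False keeps `ax` two-dimensional even when there is a single row
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(8, 4), squeeze=False)

    for head in range(n_rows * n_cols):
        axis = ax[head // n_cols][head % n_cols]
        axis.axis("off")
        if head >= num_heads:
            continue  # leave unused grid cells empty
        attn_map = cls_attn[head, :].reshape((patch_h, patch_w))
        axis.imshow(attn_map)
        axis.set_title(f"Attention map of Head {head}")
    plt.show()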

last_layer = model.blocks[-1]
plot_attention_map(img, last_layer, model)