### Vision-Language Medical Foundation Models

#### 1.1. Introduction, application, VLMs, and Transformers library
---

In this tutorial, we will explore the use of **vision-language models for medical image analysis (medVLMs)**. In particular, we will focus on:

- Contrastive Image-Text pre-training (CLIP)
- Zero-shot classification
    - Single prompt
    - Prompt ensemble
- Few-shot adaptation
    - Linear probing
    - CLIP-Adapter
    - Advanced linear probing techniques
- Parameter-Efficient Fine-Tuning
    - Selective methods
    - Additive methods

More concretely, we will explore examples using foundation models specialized in **histology images** (PLIP [2] / CONCH [3]). Nevertheless, note that the introduced methodologies are **applicable to *any* medVLM**.

In this tutorial, we will build upon [huggingface](https://huggingface.co/), using the `transformers` library. This library is becoming increasingly popular and contains a large number of tutorials and examples. I recommend taking a look!

That said, let's start!

First, in this notebook, **we will introduce the application addressed, the main vision-language foundation model employed, and the Transformers library**.

In [None]:
# General imports
import warnings
warnings.filterwarnings('ignore')

import copy
import torch

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

# Device for training/inference
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Available device: " + device)

#### **Preliminaries: Gleason grading application**
In this notebook, we will explore histology image analysis, and in particular, prostate cancer grading. To do so, we will use the SICAPv2 dataset [1]. This dataset contains tissue patches labeled by expert pathologists according to their cancer severity: non-cancerous (NC), Gleason grade 3 (G3), Gleason grade 4 (G4), and Gleason grade 5 (G5). These labels are directly correlated with patient prognosis and measure the degree of differentiation of the glands in the tissue, i.e., less differentiation implies a worse prognosis.

In [None]:
# First, let's visualize a few examples for each category.
nc = Image.open("./local_data/datasets/SICAPv2/images/16B0028148_Block_Region_8_2_14_xini_27858_yini_55555.jpg") 
g3 = Image.open("./local_data/datasets/SICAPv2/images/18B0006169B_Block_Region_6_3_7_xini_33958_yini_90365.jpg") 
g4 = Image.open("./local_data/datasets/SICAPv2/images/17B0032153_Block_Region_10_13_17_xini_23105_yini_15687.jpg") 
g5 = Image.open("./local_data/datasets/SICAPv2/images/16B0008067_Block_Region_0_6_2_xini_10859_yini_103113.jpg") 

fig = plt.figure(figsize=(12,12))
plt.subplot(1, 4, 1)
plt.imshow(nc)
plt.axis("off")
plt.title("Non-cancerous")
plt.subplot(1, 4, 2)
plt.imshow(g3)
plt.title("Gleason grade 3")
plt.axis("off")
plt.subplot(1, 4, 3)
plt.imshow(g4)
plt.axis("off")
plt.title("Gleason grade 4")
plt.subplot(1, 4, 4)
plt.imshow(g5)
plt.axis("off")
plt.title("Gleason grade 5")

The objective is to perform multi-class prediction to automatically grade such images, leveraging foundation models. We will employ the train/test splits provided with the dataset.

In [None]:
# SICAPv2 dataset metadata.
categories = ["NC", "G3", "G4", "G5"]                                        # List of categories.
path_images = "./local_data/datasets/SICAPv2/images/"                        # Folder with the images.
dataframe_train = "./local_data/datasets/SICAPv2/partition/Test/Train.xlsx"  # Dataframe (Table) containing train images names and labels.
dataframe_test = "./local_data/datasets/SICAPv2/partition/Test/Test.xlsx"    # Dataframe (Table) containing test images names and labels.
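
Since the train/test tables store image names together with their labels, a common preparation step is to turn the labels into integer class indices for multi-class prediction. The minimal sketch below assumes (not verified against the actual SICAPv2 files) that each table has one one-hot column per entry of `categories`; the label rows are hypothetical toy values, not real dataset entries.

```python
import numpy as np

categories = ["NC", "G3", "G4", "G5"]  # Same order as the metadata cell above.

# Hypothetical one-hot label rows (one row per image, one column per category).
# ASSUMPTION: the real Excel tables may use different column names or encodings.
one_hot_labels = np.array([
    [1, 0, 0, 0],  # NC
    [0, 0, 1, 0],  # G4
    [0, 0, 0, 1],  # G5
])

# Map each row to an integer class index (0..3) and back to its category name.
label_indices = np.argmax(one_hot_labels, axis=1)
label_names = [categories[i] for i in label_indices]
print(label_indices.tolist(), label_names)  # [0, 2, 3] ['NC', 'G4', 'G5']
```

With pandas, each table could be loaded via `pd.read_excel(dataframe_train)` (reading `.xlsx` files requires an engine such as `openpyxl`); inspect its columns before relying on the one-hot assumption.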

#### **Preliminaries: PLIP, Transformers library and VLMs structure**

In our main experiments, we will use **PLIP**, a **vision-language model specialized in histology images**. Interestingly, this model was pre-trained on Twitter data, leveraging pathologists' comments on shared cases. You can learn more [here](https://www.nature.com/articles/s41591-023-02504-3) [2], or take a look at its demo [here](https://huggingface.co/spaces/vinid/webplip). The architecture follows one of the CLIP configurations, using **ViT-B/32 as vision encoder and a GPT-style Transformer as text encoder, both fine-tuned on histology data**.

First, we will dig into how the **Transformers** library is organized, and how vision-language models are usually implemented there. Even though this library is quite useful to quickly access pre-trained models and perform inference, **we will need to make some adjustments if we want to perform more advanced training or adaptation**.

In [None]:
# Load model and pre-processing tools from huggingface.
from transformers import AutoProcessor, AutoModel

# In the Transformers library, models and their versions are stored under an ID made of
# the owner (user or organization) and the model name. For the PLIP model, this ID is "vinid/plip".
processor = AutoProcessor.from_pretrained("vinid/plip") # pre-processing for image and text.
plip = AutoModel.from_pretrained("vinid/plip").eval()   # model with pre-trained weights.
# We set the model in eval mode to disable dropout (and, in CNNs, batch-norm statistics updates).

In [None]:
# First, let's inspect the pre-processing operations.
print(processor)

In [None]:
# We can see that the image processor contains a set of operations such as image resizing
# and intensity normalization, with the required mean and std values. We will only make a small change:
# disable "do_center_crop", since we want to keep the entire image.
processor.image_processor.do_center_crop = False

# The text pre-processing class contains a tokenizer that converts string inputs into the numerical
# structure that we can feed into the text encoder. For now, we will leave it as it is.
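
To make the printed processor settings more concrete, the following NumPy sketch reproduces the rescaling, per-channel normalization, and channel-transpose steps on a dummy image. It assumes the standard CLIP mean/std values; compare them with the `image_mean`/`image_std` printed above, since PLIP's configuration may differ. Resizing is omitted for brevity, so this only illustrates the idea and does not replace `processor.image_processor`.

```python
import numpy as np

# ASSUMPTION: standard CLIP normalization statistics (check the printed processor config).
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def preprocess(image_hwc: np.ndarray) -> np.ndarray:
    """Rescale to [0, 1], normalize per channel, and convert HxWxC -> 1xCxHxW."""
    x = image_hwc.astype(np.float32) / 255.0   # rescale uint8 intensities to [0, 1]
    x = (x - CLIP_MEAN) / CLIP_STD             # per-channel normalization (broadcast on last axis)
    x = np.transpose(x, (2, 0, 1))             # channels first: C x H x W
    return x[np.newaxis]                       # add batch dimension: 1 x C x H x W

# Dummy 224x224 RGB image with constant intensity 128.
dummy = np.full((224, 224, 3), 128, dtype=np.uint8)
pixel_values = preprocess(dummy)
print(pixel_values.shape)  # (1, 3, 224, 224)
```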

In [None]:
# The PLIP model is not in a very convenient format for our purposes. Since we want to operate
# separately with the vision and text encoders, we will split both backbones, and apply their
# projections into the joint embedding space, followed by l2-normalization.

class PLIPWrapper(torch.nn.Module):
    def __init__(self, encoder, proj_layer):
        super().__init__()
        
        self.encoder = encoder         # Take one-modality encoder from VLM.
        self.proj_layer = proj_layer   # Take projection layer into joint embedding space.

    def forward(self, inputs):
        
        # Forward input through encoder
        features = self.encoder(**inputs).pooler_output # Forward through encoder; we keep a global feature for the image/text
        
        # Project features
        projected_features = self.proj_layer(features)  # Apply projection
        
        # Ensure features are l2-normalized
        projected_features = projected_features / projected_features.norm(dim=-1, keepdim=True) # l2-normalization

        return projected_features

# Create model wrapper for vision and text encoders
vision_encoder = PLIPWrapper(plip.vision_model, plip.visual_projection)
text_encoder = PLIPWrapper(plip.text_model, plip.text_projection)
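
Conceptually, each wrapper maps a pooled encoder feature through a linear projection and then rescales it to unit length. The NumPy sketch below mimics that flow with random weights; the sizes (768-d pooled features projected to 512-d) are assumptions typical of CLIP ViT-B/32, and the bias-free projection mirrors CLIP's projection layers. Check `plip.visual_projection` and `plip.text_projection` for the actual shapes.

```python
import numpy as np

rng = np.random.default_rng(0)

# ASSUMED sizes: 768-d pooled vision features projected into a 512-d joint space.
batch_size, hidden_dim, proj_dim = 4, 768, 512
pooled = rng.normal(size=(batch_size, hidden_dim))       # stand-in for pooler_output
proj_weight = rng.normal(size=(proj_dim, hidden_dim))    # stand-in for a bias-free nn.Linear weight

def project_and_normalize(features, weight):
    """Linear projection (x @ W^T, no bias) followed by l2-normalization."""
    projected = features @ weight.T
    return projected / np.linalg.norm(projected, axis=-1, keepdims=True)

embeddings = project_and_normalize(pooled, proj_weight)
print(embeddings.shape)                      # (4, 512)
print(np.linalg.norm(embeddings, axis=-1))   # all ~1.0
```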

In [None]:
# Let's inspect the outputs on one image

# Read an image (G5) and convert it to an array
im = Image.open("./local_data/datasets/SICAPv2/images/16B0008067_Block_Region_0_6_2_xini_10859_yini_103113.jpg") 
im = np.array(im)
print("Raw image shape:")
print(str(list(im.shape)))
print("", end="\n")

# Pre-process the image, i.e.: resize, channel transpose, intensity normalization, etc.
inputs = processor.image_processor(im, return_tensors="pt")
# The "inputs" in this case are the pixel values of the input image
print("Analyzing the input image.")
print("   Elements after pre-processing:")
print("   " + str(inputs.keys()))
print("   Pre-processed image shape:")
print("   " + str(list(inputs['pixel_values'].shape)))
print("", end="\n")

# Forward image through vision encoder
with torch.no_grad():
    vision_features = vision_encoder(inputs)

# Let's check the characteristics of the feature representation
print("Analyzing the output of the vision encoder.")
print("   Vision embedding shape:")
print("   " + str(list(vision_features.shape)))
print("   Norm of the vector:")
print("   " + str(vision_features.norm(dim=-1)))

In [None]:
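# Let's inspect the outputs on a batch of text prompts

# Define text inputs (note the comma after each string: without it, Python
# silently concatenates adjacent string literals into a single prompt)
prompt = ["Healthy prostate tissue",
          "Non-cancerous prostate tissue",
          "a high resolution medical image",
          "Gleason grade 4",
          "histology tissue with ill-defined glands",
          "Gleason grade 5"]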

# Tokenize text (padding allows batching texts of different lengths in the forward pass)
inputs = processor.tokenizer(prompt, max_length=77, padding=True, truncation=True, return_tensors="pt")

# Inspect inputs
print("Analyzing the input text.")
print("   Elements after pre-processing:")
print("   " + str(inputs.keys())) # More about attention masks: https://lukesalamone.github.io/posts/what-are-attention-masks/
print("   Input token IDs shape:")
print("   " + str(list(inputs['input_ids'].shape)))
print("   Attention mask shape:")
print("   " + str(list(inputs['attention_mask'].shape)))
print("", end="\n")

# Forward text through text encoder
with torch.no_grad():
    text_features = text_encoder(inputs)

# Let's check the characteristics of the feature representation
print("Analyzing the output of the text encoder.")
print("   Text embedding shape:")
print("   " + str(list(text_features.shape)))
print("   Norm of the vector:")
print("   " + str(text_features.norm(dim=-1)))
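
The `padding=True` option pads every sequence in the batch to the length of the longest one, and the attention mask marks which positions hold real tokens (1) and which hold padding (0). The pure-Python sketch below illustrates that idea with made-up token IDs and a hypothetical pad ID of 0; the real PLIP tokenizer uses its own vocabulary, special tokens, and padding token.

```python
def pad_batch(token_ids, pad_id=0):
    """Right-pad token ID lists to a common length and build attention masks."""
    max_len = max(len(ids) for ids in token_ids)
    input_ids, attention_mask = [], []
    for ids in token_ids:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)             # pad with the (hypothetical) pad ID
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 1 = real token, 0 = padding
    return input_ids, attention_mask

# Hypothetical token IDs for two prompts of different lengths.
ids, mask = pad_batch([[5, 17, 42, 6], [5, 23, 6]])
print(ids)   # [[5, 17, 42, 6], [5, 23, 6, 0]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```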


In [None]:
# Finally, we can compute similarities between image and text embeddings
# using the dot product. Remember: since they are l2-normalized, these similarities
# are equivalent to the cosine similarity.

# We multiply by the pre-trained temperature scaling (logit_scale), which calibrates
# what is considered a "high" or "low" similarity: e.g., a cosine similarity of 0.2
# could already mean that both inputs are semantically similar.

# Cosine similarities
with torch.no_grad():
    sim = vision_features @ text_features.t()
print("Cosine similarities: " + str(sim))

# Temperature-calibrated similarities
with torch.no_grad():
    sim = vision_features @ text_features.t() * plip.logit_scale.exp().item()
print("logits: " + str(sim))

# Softmax outputs
with torch.no_grad():
    prob = torch.softmax(vision_features @ text_features.t() * plip.logit_scale.exp(), dim=-1)
print("softmax: " + str(prob))

# Index of predicted category
print("Predicted category: " + prompt[torch.argmax(prob, -1).item()])
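
The same zero-shot computation can be written in a few lines of NumPy, which makes explicit why the dot product of l2-normalized embeddings equals the cosine similarity. The sketch uses random stand-in embeddings (one per SICAPv2 category) and assumes a temperature `exp(logit_scale)` of 100, the cap CLIP applies to its logit scale during training; PLIP's learned value should be read from `plip.logit_scale` instead.

```python
import numpy as np

rng = np.random.default_rng(0)

def l2_normalize(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Random stand-ins: 1 image embedding and 4 class-prompt embeddings (512-d, assumed size).
image_emb = l2_normalize(rng.normal(size=(1, 512)))
text_emb = l2_normalize(rng.normal(size=(4, 512)))

# For unit-norm vectors, the dot product equals the cosine similarity.
cos_sim = image_emb @ text_emb.T  # shape (1, 4)
explicit_cos = (image_emb[0] @ text_emb[1]) / (
    np.linalg.norm(image_emb[0]) * np.linalg.norm(text_emb[1]))
assert np.isclose(cos_sim[0, 1], explicit_cos)

# ASSUMPTION: temperature exp(logit_scale) = 100.
logits = 100.0 * cos_sim

# Numerically stable softmax over the class prompts.
shifted = logits - logits.max(axis=-1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

categories = ["NC", "G3", "G4", "G5"]
print("Predicted category:", categories[int(np.argmax(probs, axis=-1)[0])])
```

Because the temperature and softmax are monotonic, the predicted category is always the prompt with the highest cosine similarity; the temperature only controls how peaked the probabilities are.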

Now that we are familiar with the structure of a VLM and the pre-processing tools each modality requires, we will move on to a toy example of how these models are pre-trained.

--- 
## **References**

[1] Silva-Rodríguez, J., Colomer, A., Sales, M. A., Molina, R., & Naranjo, V. (2020). Going deeper through the Gleason scoring scale: An automatic end-to-end system for histology prostate grading and cribriform pattern detection. Computer Methods and Programs in Biomedicine. \
[2] Huang, Z., Bianchi, F., Yuksekgonul, M., Montine, T. J., & Zou, J. (2023). A visual-language foundation model for pathology image analysis using medical Twitter. Nature Medicine. \
[3] Lu, M. Y., Chen, B., Williamson, D. F. K., et al. (2024). A visual-language foundation model for computational pathology. Nature Medicine.

--- 