### **Vision-Language Medical Foundation Models**

#### **1.3. Zero-shot classification**
---

Pre-trained vision-language models are capable of performing so-called **zero-shot predictions**. These image-level predictions are driven by the language encoder, thanks to the multi-modal alignment between images and text. **Given one or more text descriptions per target category, we can compute a text prototype for each category** in the shared embedding space. For a new image, **the assigned class is the one whose text prototype is most similar** to the image embedding.
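The nearest-prototype rule can be sketched in a few lines of NumPy. This is a minimal illustration with hypothetical toy 3-D embeddings standing in for real PLIP features. Both image embeddings and text prototypes are L2-normalized, so the dot product is a cosine similarity, and each image gets the class of its most similar prototype.

```python
import numpy as np

def l2_normalize(x, axis=-1):
    """Scale vectors to unit L2 norm along `axis`."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

def zero_shot_predict(image_embeddings, text_prototypes):
    """Assign each image to the class whose text prototype is most similar.

    image_embeddings: (N, D) image features in the joint space.
    text_prototypes:  (C, D) one text embedding per class.
    Returns an (N,) array of predicted class indices.
    """
    img = l2_normalize(image_embeddings)
    txt = l2_normalize(text_prototypes)
    similarity = img @ txt.T          # (N, C) cosine similarities
    return similarity.argmax(axis=1)  # index of the most similar prototype

# Hypothetical toy embeddings (illustration only, not PLIP outputs)
prototypes = np.array([[1.0, 0.0, 0.0],   # class 0
                       [0.0, 1.0, 0.0],   # class 1
                       [0.0, 0.0, 1.0]])  # class 2
images = np.array([[0.9, 0.1, 0.0],
                   [0.1, 0.2, 0.8]])
preds = zero_shot_predict(images, prototypes)
print(preds)  # [0 2]
```

Because of the normalization, rescaling an image embedding does not change its prediction. Only its direction in the shared space matters.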

In this notebook, we will explore two popular ways to perform zero-shot predictions: a single prompt per class, and a prompt ensemble.


In [None]:
# General imports
import warnings
warnings.filterwarnings('ignore')

import copy
import torch

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

# Device for training/inference
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Available device: " + device)

#### **Dataset details**

In [None]:
# SICAPv2 dataset metadata
categories = ["NC", "G3", "G4", "G5"]                                        # List of categories
path_images = "./local_data/datasets/SICAPv2/images/"                        # Folder with the images
dataframe_train = "./local_data/datasets/SICAPv2/partition/Test/Train.xlsx"  # Spreadsheet listing training image names and labels
dataframe_test = "./local_data/datasets/SICAPv2/partition/Test/Test.xlsx"    # Spreadsheet listing test image names and labels

#### **Load model, pre-processing, and wrapper**

In [None]:
# Load model and pre-processing tools from huggingface
from transformers import AutoProcessor, AutoModel

# In the Transformers library, models are stored on the Hub under an ID made of the owner and the model name.
# For the PLIP model, this ID is "vinid/plip"
processor = AutoProcessor.from_pretrained("vinid/plip") # pre-processing for images and text
processor.image_processor.do_center_crop = False        # keep the whole image (no center crop)
plip = AutoModel.from_pretrained("vinid/plip").eval()   # model with pre-trained weights
# We set the model in eval mode to disable dropout and, in CNNs, batch-norm statistics updates

In [None]:
# Again, we will use our PLIP Wrapper for easy interaction
class PLIPWrapper(torch.nn.Module):
    def __init__(self, encoder, proj_layer):
        super().__init__()
        
        self.encoder = encoder         # Take one-modality encoder from VLM.
        self.proj_layer = proj_layer   # Take projection layer into joint embedding space.

    def forward(self, inputs):
        
        # Forward input through the encoder
        features = self.encoder(**inputs).pooler_output # keep a global (pooled) feature for the image/text
        
        # Project features into the joint embedding space
        projected_features = self.proj_layer(features)
        
        # Ensure features are L2-normalized (unit norm)
        projected_features = projected_features / projected_features.norm(dim=-1, keepdim=True)

        return projected_features

# Create model wrapper for vision and text encoders
vision_encoder = PLIPWrapper(plip.vision_model, plip.visual_projection)
text_encoder = PLIPWrapper(plip.text_model, plip.text_projection)
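The wrapper L2-normalizes its outputs so that a plain dot product between an image feature and a text feature equals their cosine similarity. The NumPy check below uses random vectors as stand-ins for projected features. The dimension 512 is only illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=512)   # stand-in for a projected image feature
b = rng.normal(size=512)   # stand-in for a projected text feature

# Cosine similarity computed directly from the raw vectors
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Same quantity after L2-normalizing each vector, as the wrapper does
a_n = a / np.linalg.norm(a)
b_n = b / np.linalg.norm(b)
dot = a_n @ b_n

print(np.isclose(cosine, dot))  # True
```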

#### **Feature extraction**
First, we need to extract all feature representations from the test subset.

In [None]:
# To process the whole dataset efficiently, we now move the vision model to the selected device (GPU if available)
vision_encoder.to(device)

# We re-express the pre-processing as torchvision transforms, which plug directly into torch DataLoaders
from torchvision import transforms

# Pre-processing transforms to apply during data loading.
plip_transforms = transforms.Compose(
    [
    transforms.ToTensor(),                                                 # Move PIL/array image to tensor
    transforms.Normalize(std=processor.image_processor.image_std,
                         mean=processor.image_processor.image_mean),       # Intensity normalization
    transforms.Resize(list(processor.image_processor.crop_size.values()))  # Resize to pre-trained resolution
    ])
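To see what the first two transforms do numerically, here is a NumPy sketch of `ToTensor()` followed by `Normalize(mean, std)`. We assume the CLIP default statistics below match PLIP's `processor.image_processor.image_mean` and `image_std`; print those attributes to confirm. The final `Resize` step is omitted.

```python
import numpy as np

# Assumed CLIP-style statistics; verify against processor.image_processor.image_mean/std
MEAN = np.array([0.48145466, 0.4578275, 0.40821073])
STD = np.array([0.26862954, 0.26130258, 0.27577711])

def to_tensor_and_normalize(image_hwc_uint8, mean=MEAN, std=STD):
    """Mimic ToTensor() followed by Normalize(mean, std) with NumPy.

    ToTensor: HxWxC uint8 in [0, 255] -> CxHxW float in [0, 1].
    Normalize: per-channel (x - mean[c]) / std[c].
    """
    x = image_hwc_uint8.astype(np.float32) / 255.0     # scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))                     # HWC -> CHW
    return (x - mean[:, None, None]) / std[:, None, None]

# Hypothetical 2x2 RGB image filled with mid-gray (value 128)
img = np.full((2, 2, 3), 128, dtype=np.uint8)
out = to_tensor_and_normalize(img)
print(out.shape)  # (3, 2, 2)
```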

In [None]:
# Create dataloader with the whole testing subset
from vlms.data import loader # Take a look at this function to see how the data loader is built
test_loader = loader(dataframe_path=dataframe_test, path_images=path_images, categories=categories,
                     transforms=plip_transforms, batch_size=8, num_workers=0)

# We can check the dataset format and available samples
print("Samples available for testing: " + str(len(test_loader.dataset.data)))
print(test_loader.dataset.data[0])

In [None]:
# Extract features
from vlms.utils import extract_features
X_test, Y_test = extract_features(test_loader, vision_encoder)

# Let's check the test features and labels
print("Test features: " + str(X_test.shape))
print("Test labels: " + str(Y_test.shape))
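We have not reproduced `vlms.utils.extract_features` here. The function below is a hypothetical NumPy stand-in that shows the usual pattern: encode each batch, then concatenate features and labels into `(N, D)` and `(N,)` arrays. A PyTorch version would typically also run under `torch.no_grad()` and move each batch to `device`.

```python
import numpy as np

def extract_features_sketch(batches, encoder):
    """Hypothetical stand-in for vlms.utils.extract_features.

    batches: iterable of (images, labels) pairs, as a DataLoader would yield.
    encoder: callable mapping a batch of images to (B, D) embeddings.
    Returns stacked features (N, D) and labels (N,).
    """
    feats, labels = [], []
    for images, y in batches:
        feats.append(encoder(images))
        labels.append(y)
    return np.concatenate(feats, axis=0), np.concatenate(labels, axis=0)

# Toy data: batches of 8, 8 and 4 "images" (6-D vectors) and a random
# linear projection to a 4-D "embedding" space (illustration only)
rng = np.random.default_rng(0)
proj = rng.normal(size=(6, 4))
batches = [(rng.normal(size=(n, 6)), rng.integers(0, 4, size=n)) for n in (8, 8, 4)]
X, Y = extract_features_sketch(batches, lambda imgs: imgs @ proj)
print(X.shape, Y.shape)  # (20, 4) (20,)
```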

#### **1.3.1. Single prompt**

Once the vision-language model is pre-trained, **the vision and text feature spaces are aligned**. That is, **a text description of a concept should produce a representation in the shared embedding space similar to that of an image of the same concept**. This property is exploited to perform **zero-shot prediction without any adaptation** of the VLM. **Class-wise embeddings (class prototypes)** are computed from text descriptions of each category. The prototypes of $C$ categories with $D$ features each can be stacked into a matrix $W \in \mathbb{R}^{C \times D}$. **Note that this is analogous to the weights of the output linear layer of a classical MLP (without bias term)**. Thus, **class predictions (logits) can be computed from the vision features** $F_v$ **by matrix multiplication**, as in a fully-connected layer: $F_v W^\top = o$, i.e. $(1 \times D)(D \times C) = (1 \times C)$. Since $F_v$ and the rows of $W$ are L2-normalized, each logit is a cosine similarity; in the code below, the logits are also multiplied by PLIP's learned scale `plip.logit_scale.exp()` before the softmax.
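The matrix view of zero-shot prediction can be checked with toy shapes. In this NumPy sketch, `D`, `C` and the random vectors are illustrative. The constant `scale = 100.0` is an assumed value standing in for `plip.logit_scale.exp()`. Logits come from $F_v W^\top$ and probabilities from a softmax over the scaled logits.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

D, C = 4, 3  # toy embedding size and number of classes
rng = np.random.default_rng(0)

# Class prototypes stacked row-wise into W (C x D), each row L2-normalized
W = rng.normal(size=(C, D))
W /= np.linalg.norm(W, axis=1, keepdims=True)

# One L2-normalized image feature F_v (1 x D)
F_v = rng.normal(size=(1, D))
F_v /= np.linalg.norm(F_v, axis=1, keepdims=True)

logits = F_v @ W.T            # (1 x D) @ (D x C) -> (1 x C) cosine similarities
scale = 100.0                 # assumed value, stands in for plip.logit_scale.exp()
probs = softmax(scale * logits)

print(logits.shape, probs.shape)           # (1, 3) (1, 3)
print(probs.argmax() == logits.argmax())   # True
```

The positive scale only sharpens or flattens the probability distribution. It never changes which class has the highest score.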

In [None]:
# Define text input
class_prompts = ["non-cancerous",
                 "Gleason grade 3",
                 "Gleason grade 4",
                 "Gleason grade 5"]

# Tokenize text
inputs = processor.tokenizer(class_prompts, max_length = 77, padding=True, truncation=True, return_tensors="pt") 

# Compute text prototypes (one per class)
with torch.no_grad():
    class_prototypes = text_encoder(inputs)

In [None]:
# Compute predictions: scaled cosine similarities -> softmax over classes
with torch.no_grad():
    logits = torch.as_tensor(X_test, dtype=torch.float32) @ class_prototypes.t() * plip.logit_scale.exp()
    prob = torch.softmax(logits, dim=-1).numpy()
    
print("Prediction shape: " + str(prob.shape))
print("Example: " + str(prob[0,:]))

In [None]:
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))
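Balanced accuracy is the mean of the per-class recalls, so each Gleason grade counts equally regardless of how many test patches it has. We have not checked how `vlms.utils.evaluate` is implemented. The function below is a hypothetical NumPy sketch that takes the argmax of the probabilities and computes the same metric as scikit-learn's `balanced_accuracy_score`. It assumes every class appears in `y_true`; otherwise a recall would divide by zero.

```python
import numpy as np

def balanced_accuracy_from_probs(y_true, prob):
    """Balanced accuracy (mean per-class recall) from class probabilities.

    y_true: (N,) integer labels; prob: (N, C) predicted probabilities.
    Returns (balanced_accuracy, confusion_matrix); rows are true classes.
    Assumes every class occurs at least once in y_true.
    """
    n_classes = prob.shape[1]
    y_pred = prob.argmax(axis=1)
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)     # count (true, predicted) pairs
    recall = np.diag(cm) / cm.sum(axis=1)  # per-class recall
    return recall.mean(), cm

# Toy imbalanced example: 4 samples of class 0, 2 of class 1
y_true = np.array([0, 0, 0, 0, 1, 1])
prob = np.array([[0.9, 0.1], [0.8, 0.2], [0.7, 0.3],
                 [0.6, 0.4], [0.55, 0.45], [0.2, 0.8]])
aca, cm = balanced_accuracy_from_probs(y_true, prob)
print(aca)  # 0.75
print(cm)
```

Here plain accuracy would be 5/6 (about 0.83). Balanced accuracy drops to 0.75 because half of the minority class is misclassified.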

#### **1.3.2. Prompt ensemble**

A popular option to refine the text prototypes is to **combine multiple prompts (prompt ensemble)** and average their embeddings per class. In this way, **noise specific to a single prompt is attenuated**, and performance usually improves. The prompt ensemble typically comes from combining **different templates**, e.g. "A photo of [CLS]", with **different descriptions** of the target class.
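The two ingredients of a prompt ensemble are expanding templates × descriptions into prompts, and averaging plus re-normalizing their embeddings. Both can be sketched in NumPy. The prompt strings reuse templates and descriptions from this notebook. The "embeddings" are hypothetical noisy copies of one direction, since no text encoder is run here.

```python
import numpy as np

def ensemble_prototype(prompt_embeddings):
    """Average L2-normalized prompt embeddings and re-normalize the mean.

    prompt_embeddings: (P, D) array, one row per prompt of a single class.
    Returns a (D,) unit-norm class prototype.
    """
    e = prompt_embeddings / np.linalg.norm(prompt_embeddings, axis=1, keepdims=True)
    mean = e.mean(axis=0)
    return mean / np.linalg.norm(mean)   # the mean of unit vectors is usually shorter than 1

# Expand templates x descriptions into prompts (strings only)
templates = ["histopathology image of [CLS]", "pathology tissue showing [CLS]"]
descriptions = ["benign glands", "non-cancerous tissue"]
prompts = [t.replace("[CLS]", d) for d in descriptions for t in templates]
print(len(prompts))  # 4

# Hypothetical prompt embeddings: noisy copies of a shared direction
rng = np.random.default_rng(0)
direction = np.array([1.0, 0.0, 0.0, 0.0])
noisy = direction + 0.5 * rng.normal(size=(len(prompts), 4))
proto = ensemble_prototype(noisy)
print(np.isclose(np.linalg.norm(proto), 1.0))  # True
```

The re-normalization step matters. Averaging unit vectors yields a shorter vector, and without rescaling the prototype would no longer produce cosine similarities.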

In [None]:
# Ensemble of templates
templates = ["a histopathology slide showing [CLS]", "histopathology image of [CLS]",
             "pathology tissue showing [CLS]", "presence of [CLS] tissue on image"]

# Category-wise descriptions, which are more informative than category names. For instance, "atrophic dense glands" better 
# describes the local findings associated with Gleason grade 3.
prompts_dict = {"NC": ["benign glands"],
                "G3": ["atrophic dense glands"],
                "G4": ["cribriform ill-formed fused papillary patterns"],
                "G5": ["isolated nest cells without lumen roseting patterns"]}

# Combine all paired options of templates and descriptions
prompts = {}
for iCategory in categories:
    prompts[iCategory] = [caption.replace("[CLS]", iDescription) for iDescription in prompts_dict[iCategory]
                          for caption in templates]

# Compute embeddings per category
class_prototypes = []
for iKey in range(len(categories)):
    with torch.no_grad():
        # Retrieve descriptions for that particular category
        descriptions = prompts[categories[iKey]]
        # Tokenize text
        inputs = processor.tokenizer(descriptions, max_length = 77, padding=True, truncation=True, return_tensors="pt") 
        # Forward text encoder
        text_features_ensemble = text_encoder(inputs)
        # Get class prototypes as average of all text prompts
        avg_text_features = text_features_ensemble.mean(0).unsqueeze(0)
        # Re-normalize embedding
        avg_text_features = avg_text_features / avg_text_features.norm(dim=-1, keepdim=True)
        class_prototypes.append(avg_text_features)
                               
# Concatenate all class prototypes
class_prototypes = torch.concat(class_prototypes, dim=0)

In [None]:
# Compute predictions: scaled cosine similarities -> softmax over classes
with torch.no_grad():
    logits = torch.as_tensor(X_test, dtype=torch.float32) @ class_prototypes.t() * plip.logit_scale.exp()
    prob = torch.softmax(logits, dim=-1).numpy()
    
print("Prediction shape: " + str(prob.shape))
print("Example: " + str(prob[0,:]))

In [None]:
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))
# As you can see, prompt ensembling boosts performance! +6.7% balanced accuracy

--- 
##### **ACTIVITY**

Well, now you know everything you need about zero-shot predictions. If you want to go further, I recommend:

- Try developing the same pipeline for [CONCH](https://huggingface.co/MahmoodLab/CONCH) [1], a recently introduced VLM for histology. Its vision backbone is large-scale, takes high-resolution input images, and is pre-trained on more data. How does model scaling translate to black-box adaptation?
- CONCH also uses a different set of templates and class descriptions for zero-shot Gleason grading (shared below). How does prompt selection affect performance? How realistic is it to design prompts that optimize test performance?


In [None]:
# Text ensemble for CONCH
templates = ["[CLS].",
            "a photomicrograph showing [CLS].",
            "a photomicrograph of [CLS].",
            "an image of [CLS].",
            "an image showing [CLS].",
            "an example of [CLS].",
            "[CLS] is shown.",
            "this is [CLS].",
            "there is [CLS].",
            "a histopathological image showing [CLS].",
            "a histopathological image of [CLS].",
            "a histopathological photograph of [CLS].",
            "a histopathological photograph showing [CLS].",
            "shows [CLS].",
            "presence of [CLS].",
            "[CLS] is present.",
            "an H&E stained image of [CLS].",
            "an H&E stained image showing [CLS].",
            "an H&E image showing [CLS].",
            "an H&E image of [CLS].",
            "[CLS], H&E stain.",
            "[CLS], H&E."]

prompts_dict = {"NC": ["non-cancerous tissue", "non-cancerous prostate tissue", "benign tissue", "benign glands", 
                       "benign prostate glands", "benign prostate tissue"],
                "G3": ["gleason grade 3", "gleason pattern 3", "prostate cancer, gleason grade 3", 
                       "prostate cancer, gleason pattern 3", "prostate adenocarcinoma, well-differentiated",
                       "well-differentiated prostatic adenocarcinoma"],
                "G4": ["gleason grade 4", "gleason pattern 4", "prostate cancer, gleason grade 4", 
                       "prostate cancer, gleason pattern 4", "prostate adenocarcinoma, moderately differentiated",  
                       "moderately differentiated prostatic adenocarcinoma"],
                "G5": ["gleason grade 5", "gleason pattern 5", "prostate cancer, gleason grade 5",
                       "prostate cancer, gleason pattern 5", "prostate adenocarcinoma, poorly differentiated",
                       "poorly differentiated prostatic adenocarcinoma"]}


--- 
## **References**

[1] Lu, M.Y., Chen, B., Williamson, D.F.K. et al. (2024) A visual-language foundation model for computational pathology. Nature Medicine.

--- 