### **Vision-Language Medical Foundation Models**

#### **1.4. Few-shot black-box Adapters**
---

**Objective**: Given a small set of examples per category, we want to efficiently use the vision features to perform classification on a downstream task, without fine-tuning the base model.

**Few-shot**: We use only K images for each new category.

**Why black-box Adapters?**: They are efficient and usually run on CPU. They are fast: you can transfer the model in a matter of minutes. They are backbone-agnostic, that is, they work the same with any vision encoder.

In [None]:
# General imports
import warnings
warnings.filterwarnings('ignore')

import copy
import torch
import random

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

# Device for training/inference
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Available device: " + device)

# Seeds for reproducibility
def set_seeds(seed_value, use_cuda):
    np.random.seed(seed_value)     # cpu vars
    torch.manual_seed(seed_value)  # cpu  vars
    random.seed(seed_value)        # Python
    if use_cuda:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)     # gpu vars
        torch.backends.cudnn.deterministic = True  # needed
        torch.backends.cudnn.benchmark = False

set_seeds(42, use_cuda=torch.cuda.is_available())

#### **Dataset details**

In [None]:
# SICAPv2 dataset metadata
categories = ["NC", "G3", "G4", "G5"]                                        # List of categories
path_images = "./local_data/datasets/SICAPv2/images/"                        # Folder with the images
dataframe_train = "./local_data/datasets/SICAPv2/partition/Test/Train.xlsx"  # Dataframe (Table) containing train images names and labels
dataframe_test = "./local_data/datasets/SICAPv2/partition/Test/Test.xlsx"    # Dataframe (Table) containing test images names and labels

#### **VLM model wrapper**

In [None]:
# Load model and pre-processing tools from huggingface
from transformers import AutoProcessor, AutoModel

# In the Transformers library, models are stored under an ID that identifies the user and the model name.
# For the PLIP model, that ID is "vinid/plip"
processor = AutoProcessor.from_pretrained("vinid/plip") # pre-processing image and text
processor.image_processor.do_center_crop = False
plip = AutoModel.from_pretrained("vinid/plip").eval() # model with pre-trained weights
# We set the model in eval mode to disable dropout at inference and avoid batchnorm stats updates in CNNs

In [None]:
# Again, we will use our PLIP Wrapper for easy interaction
class PLIPWrapper(torch.nn.Module):
    def __init__(self, encoder, proj_layer):
        super().__init__()
        
        self.encoder = encoder         # Take one-modality encoder from VLM.
        self.proj_layer = proj_layer   # Take projection layer into joint embedding space.

    def forward(self, inputs):
        
        # Forward input through encoder
        features = self.encoder(**inputs).pooler_output # Forward through encoder - we keep a global feature for the image/text

        # Project features
        projected_features = self.proj_layer(features)  # Apply projection

        # Ensure projected features are l2-normalized
        projected_features = projected_features / projected_features.norm(dim=-1, keepdim=True) # l2-normalization

        return projected_features

# Create model wrapper for vision and text encoders
vision_encoder = PLIPWrapper(plip.vision_model, plip.visual_projection)
text_encoder = PLIPWrapper(plip.text_model, plip.text_projection)

#### **Test features extraction**
First, we need to extract all feature representations from the test subset.

In [None]:
# To run over the whole dataset, we now move the vision model to the selected device (GPU if available)
vision_encoder.to(device)

# We need to format the pre-processing transforms in a more friendly format for torch Dataloaders
from torchvision import transforms

# Pre-processing transforms to apply during data loading.
plip_transforms = transforms.Compose(
    [
    transforms.ToTensor(),                                                 # Move PIL/array image to tensor
    transforms.Normalize(std=processor.image_processor.image_std,
                         mean=processor.image_processor.image_mean),       # Intensity normalization
    transforms.Resize(list(processor.image_processor.crop_size.values()))  # Resize to pre-trained resolution
    ])

In [None]:
# Create dataloader with the whole testing subset
from vlms.data import loader
test_loader = loader(dataframe_path=dataframe_test, path_images=path_images, categories=categories,
                     transforms=plip_transforms, batch_size=8, num_workers=0)

# We can check the dataset format and available samples
print("Samples available for testing: " + str(len(test_loader.dataset.data)))
print(test_loader.dataset.data[0])


In [None]:
# Extract features
from vlms.utils import extract_features
X_test, Y_test = extract_features(test_loader, vision_encoder)

# Let's check the extracted test features and labels
print("Test features: " + str(X_test.shape))
print("Test labels: " + str(Y_test.shape))
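
The `extract_features` helper lives in the `vlms.utils` module and is not shown in this notebook. To give a rough idea of what such a helper does, here is a minimal sketch. It assumes that each batch is an `(images, labels)` tuple and that the encoder takes a dict with a `pixel_values` entry, as `PLIPWrapper` does. The names `extract_features_sketch` and `DummyEncoder`, as well as the toy batches, are hypothetical, and the real implementation may differ.

```python
import numpy as np
import torch


def extract_features_sketch(loader, encoder, device="cpu"):
    """Hypothetical stand-in for vlms.utils.extract_features.

    Assumes (images, labels) batches and an encoder that accepts a dict
    with a "pixel_values" entry. Returns NumPy arrays (features, labels).
    """
    encoder.eval()
    features, labels = [], []
    with torch.no_grad():  # the backbone stays frozen, so no gradients are needed
        for images, targets in loader:
            out = encoder({"pixel_values": images.to(device)})
            features.append(out.cpu().numpy())
            labels.append(torch.as_tensor(targets).cpu().numpy())
    return np.concatenate(features, axis=0), np.concatenate(labels, axis=0)


class DummyEncoder(torch.nn.Module):
    """Toy encoder: flattens the image and l2-normalizes it."""

    def forward(self, inputs):
        x = inputs["pixel_values"].flatten(1)
        return x / x.norm(dim=-1, keepdim=True)


# Two hand-made batches play the role of a DataLoader
toy_loader = [(torch.rand(2, 3, 4, 4), torch.tensor([0, 1])),
              (torch.rand(1, 3, 4, 4), torch.tensor([2]))]
X_toy, Y_toy = extract_features_sketch(toy_loader, DummyEncoder())
print(X_toy.shape, Y_toy.shape)
```

Because the features are computed once and stored, every Adapter below trains on small arrays and never touches the backbone again.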

#### **Compute text prototypes**
(We will need them later to initialize the classification head)

In [None]:
# Ensemble of templates
templates = ["a histopathology slide showing [CLS]", "histopathology image of [CLS]",
             "pathology tissue showing [CLS]", "presence of [CLS] tissue on image"]

# Category-wise descriptions, which are more informative than category names. For instance, "atrophic dense glands" better 
# describes the local findings associated with Gleason grade 3.
prompts_dict = {"NC": ["benign glands"],
                "G3": ["atrophic dense glands"],
                "G4": ["cribriform ill-formed fused papillary patterns"],
                "G5": ["isolated nest cells without lumen roseting patterns"]}

# Combine all paired options of templates and descriptions
prompts = {}
for iCategory in categories:
    prompts[iCategory] = [caption.replace("[CLS]", iDescription) for iDescription in prompts_dict[iCategory]
                          for caption in templates]

# Compute embeddings per category
class_prototypes = []
for iKey in range(len(categories)):
    with torch.no_grad():
        # Retrieve descriptions for that particular category
        descriptions = prompts[categories[iKey]]
        # Tokenize text
        inputs = processor.tokenizer(descriptions, max_length = 77, padding=True, truncation=True, return_tensors="pt") 
        # Forward text encoder
        text_features_ensemble = text_encoder(inputs)
        # Get class prototypes as average of all text prompts
        avg_text_features = text_features_ensemble.mean(0).unsqueeze(0)
        # Re-normalize embedding
        avg_text_features = avg_text_features / avg_text_features.norm(dim=-1, keepdim=True)
        class_prototypes.append(avg_text_features)
                               
# Concatenate all class prototypes
zero_shot_prot = torch.concat(class_prototypes, dim=0)
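
The text prototypes alone already define a zero-shot classifier, which is the baseline the Adapters below are compared against. The following sketch shows how zero-shot probabilities can be obtained from image features and prototypes. The function name `zero_shot_probabilities` and the toy tensors are hypothetical; the formula is the usual softmax over temperature-scaled cosine similarities.

```python
import torch


def zero_shot_probabilities(features, prototypes, logit_scale):
    """Softmax over temperature-scaled cosine similarities.

    features:    (N, D) array or tensor of image embeddings
    prototypes:  (C, D) tensor of text prototypes
    logit_scale: scalar tensor holding the log-temperature (CLIP convention)
    """
    features = torch.as_tensor(features, dtype=torch.float32)
    prototypes = prototypes.to(torch.float32)
    # l2-normalize both sides so the dot product is a cosine similarity
    features = features / features.norm(dim=-1, keepdim=True)
    prototypes = prototypes / prototypes.norm(dim=-1, keepdim=True)
    logits = logit_scale.exp() * (features @ prototypes.t())
    return torch.softmax(logits, dim=-1)


# Toy example: 3 classes with one-hot prototypes, 2 image embeddings
toy_prototypes = torch.eye(3)
toy_features = torch.tensor([[2.0, 0.1, 0.0],
                             [0.0, 0.2, 3.0]])
toy_probs = zero_shot_probabilities(toy_features, toy_prototypes, torch.tensor(10.0).log())
print(toy_probs.argmax(dim=-1))
```

In this notebook, the same function could be called as `zero_shot_probabilities(X_test, zero_shot_prot, plip.logit_scale.detach())`, and the result passed to the evaluation function to obtain a zero-shot reference score.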

#### **Few-shot training dataset**

In [None]:
from vlms.data import few_shot_loader  # Take a look at this new function. We will randomly retrieve a few samples for each class.
shots, seed = 16, 1 # Define the number of shots per class for training and set reproducibility seed
# Set data loader
train_loader = few_shot_loader(dataframe_path=dataframe_train, path_images=path_images, categories=categories, transforms=plip_transforms,
                               shots=shots, batch_size=32, num_workers=0, seed=seed)
# Extract features
X_train, Y_train = extract_features(train_loader, vision_encoder)
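
The `few_shot_loader` function is not shown here. The core idea, drawing K random samples per class with a fixed seed, can be sketched as follows. This is an illustration only: `sample_few_shot_indices` is a hypothetical name, and the actual loader may select samples differently.

```python
import numpy as np


def sample_few_shot_indices(labels, shots, seed):
    """Return sorted indices with at most `shots` random samples per class."""
    rng = np.random.default_rng(seed)  # seeded generator for reproducibility
    labels = np.asarray(labels)
    selected = []
    for c in np.unique(labels):
        candidates = np.flatnonzero(labels == c)
        k = min(shots, len(candidates))  # a class may have fewer than `shots` samples
        selected.extend(rng.choice(candidates, size=k, replace=False).tolist())
    return sorted(selected)


# Toy labels: class 0 has 4 samples, class 1 has 3, class 2 has 1
toy_labels = np.array([0, 0, 0, 0, 1, 1, 1, 2])
toy_idx = sample_few_shot_indices(toy_labels, shots=2, seed=1)
print(toy_idx, toy_labels[toy_idx])
```

Changing the seed changes which images form the support set, which is why results in the low-shot regime can vary noticeably between runs.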

---
#### **1.4.1. Linear Probing**

**The most straightforward adaptation strategy is to train a logistic regression classifier that learns new class prototypes from the few available shots**. This method is commonly called **Linear Probe** in the literature, and it is used to compare the transferability of pre-trained models. It is the adaptation strategy explored in the seminal CLIP publication [1]. In the following, we implement and train it.

In [None]:
# Our Adapter class will be a module composed of: initialization, forward definition, and loss computation.
# Init: we store the logit scale, and initialize a learnable set of class prototypes.
# Forward: compute the temperature-scaled cosine similarity (logits) between current weights and input features.
# Loss: we minimize the categorical cross-entropy as the objective function during training.

class LinearProbe(torch.nn.Module):
    def __init__(self, input_features, number_classes, logit_scale):
        super().__init__()
        self.logit_scale = logit_scale
        self.logit_scale.requires_grad = False
        self.prototypes = torch.nn.Parameter(
        torch.nn.init.kaiming_normal_(torch.empty((number_classes, input_features))))
        # move to device
        self.to(device)

    def forward(self, features):

        # Get trained prototype
        prototypes = self.prototypes.to(device)

        # l2-normalized trained weights
        prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)

        # temperature-scaled similarity per class
        logit_scale = self.logit_scale.exp()
        logits = features @ prototypes_norm.t() * logit_scale

        return logits
    
    def loss(self, logits, y):
        loss = torch.nn.functional.cross_entropy(logits, y)
        return loss
    

In [None]:
# Create the instance of linear probe with the number of features used and number of classes for prototypes.
lp_adapter = LinearProbe(input_features=X_train.shape[-1], number_classes=len(categories),
logit_scale=plip.logit_scale.detach().clone()) # Also, we need the same temperature scaling as in pre-training!

In [None]:
# We train our adapter using few-shot data
from vlms.utils import train_adapter # Take a look at this function, which trains the Adapter in mini-batches.
epochs, batch_size, learning_rate = 100, 32, 0.001 # Define training hyper-parameters
optimizer = torch.optim.SGD(lp_adapter.parameters(), lr=learning_rate, momentum=0.9) # Define optimizer
train_adapter(X_train, Y_train, lp_adapter, optimizer, batch_size, epochs) # Train adapter
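
The `train_adapter` function is imported from `vlms.utils` and not shown in this notebook. A minimal mini-batch loop in the same spirit could look like the sketch below. It assumes only what the Adapter classes here expose, a `forward` returning logits and a `loss(logits, y)` method. The names `train_adapter_sketch` and `TinyProbe`, the `seed` argument, and the toy data are hypothetical.

```python
import torch


def train_adapter_sketch(X, Y, adapter, optimizer, batch_size, epochs, seed=0):
    """Hypothetical stand-in for vlms.utils.train_adapter. Returns the loss per epoch."""
    device = next(adapter.parameters()).device
    X = torch.as_tensor(X, dtype=torch.float32).to(device)
    Y = torch.as_tensor(Y, dtype=torch.long).to(device)
    generator = torch.Generator().manual_seed(seed)  # reproducible shuffling
    losses = []
    for _ in range(epochs):
        order = torch.randperm(X.shape[0], generator=generator)
        epoch_loss = 0.0
        for start in range(0, X.shape[0], batch_size):
            idx = order[start:start + batch_size].to(device)
            logits = adapter(X[idx])
            loss = adapter.loss(logits, Y[idx])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * idx.numel()
        losses.append(epoch_loss / X.shape[0])
    return losses


class TinyProbe(torch.nn.Module):
    """Toy adapter with the same interface as the Adapters in this notebook."""

    def __init__(self, input_features, number_classes):
        super().__init__()
        self.prototypes = torch.nn.Parameter(torch.zeros(number_classes, input_features))

    def forward(self, features):
        return features @ self.prototypes.t()

    def loss(self, logits, y):
        return torch.nn.functional.cross_entropy(logits, y)


toy_X = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]])
toy_Y = torch.tensor([0, 0, 1, 1])
toy_probe = TinyProbe(input_features=2, number_classes=2)
toy_optimizer = torch.optim.SGD(toy_probe.parameters(), lr=0.5)
toy_losses = train_adapter_sketch(toy_X, toy_Y, toy_probe, toy_optimizer, batch_size=2, epochs=50)
print(toy_losses[0], toy_losses[-1])
```

Since the inputs are pre-computed feature vectors, each epoch is only a handful of small matrix products, which is what makes black-box adaptation so cheap.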

In [None]:
# Now, we can test the resulting Adapter on test data. Since Adapters are light, we can do
# a full-batch forward pass on test data.
with torch.no_grad():
    prob = torch.softmax(lp_adapter(torch.tensor(X_test).to(device)), axis=-1).cpu().numpy()
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))
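
The `evaluate` function is also part of `vlms.utils` and not shown here. As an illustration of the two reported metrics, the sketch below computes a confusion matrix and the balanced accuracy, taken here as the mean of the per-class recalls. Whether `evaluate` computes them in exactly this way is an assumption; `evaluate_sketch` and the toy values are hypothetical.

```python
import numpy as np


def evaluate_sketch(y_true, prob):
    """Hypothetical stand-in for vlms.utils.evaluate.

    Returns (balanced accuracy, confusion matrix), with rows of the matrix
    indexing the true class and columns the predicted class.
    """
    y_true = np.asarray(y_true).astype(int)
    y_pred = np.argmax(prob, axis=-1)
    n_classes = prob.shape[-1]
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # accumulate one count per sample
    support = cm.sum(axis=1)            # number of true samples per class
    present = support > 0               # ignore classes absent from y_true
    recall = np.diag(cm)[present] / support[present]
    return float(recall.mean()), cm


# Toy example with imbalanced classes: 2 samples of class 0, 4 of class 1
toy_y = np.array([0, 0, 1, 1, 1, 1])
toy_prob = np.array([[0.9, 0.1], [0.2, 0.8], [0.3, 0.7],
                     [0.4, 0.6], [0.1, 0.9], [0.2, 0.8]])
toy_aca, toy_cm = evaluate_sketch(toy_y, toy_prob)
print(toy_aca)  # 0.75: recall is 0.5 for class 0 and 1.0 for class 1
print(toy_cm)
```

Plain accuracy on this toy example would be 5/6, so balanced accuracy is the more conservative number when classes are imbalanced.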

##### **ACTIVITY**
First, try training the Adapter with 16 shots per class, and then decrease the number to 1 shot: what do you observe? **What is the improvement of 1-shot adaptation with respect to zero-shot?** - **Try using different seeds**. In some datasets, the performance in the low-shot regime (K<4) was below zero-shot. Recently, Adapters that also exploit text information, going beyond a randomly initialized Linear Probe, have been proposed to solve this issue.

------
#### **1.4.2. CLIP-Adapter**

In initial studies, the basic Linear Probe showed a performance drop with respect to zero-shot. Since Linear Probe does not exploit the text knowledge, some works explored more advanced options for black-box adaptation.

Concretely, CLIP-Adapter [2] proposed keeping the text embeddings as class prototypes and residually modifying the vision features so that they better approximate the representation of their corresponding category. This residual modification is driven by a low-rank (bottleneck) MLP architecture and a blending hyper-parameter, alpha, that controls how far you deviate from the initial representations. The CLIP-Adapter feature update is:

$$v' = (1-\alpha) \cdot v + \alpha \cdot \mathrm{mlp}(v)$$
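
To make the role of alpha concrete, here is a small numerical sketch of the blending step, followed by the l2-normalization that is applied before computing similarities. The function name `blend_features` and the toy vectors are hypothetical, and the `adapted` tensor stands in for the MLP output.

```python
import torch


def blend_features(v, adapted, alpha):
    """Compute v' = (1 - alpha) * v + alpha * adapted, then l2-normalize v'."""
    mixed = (1.0 - alpha) * v + alpha * adapted
    return mixed / mixed.norm(dim=-1, keepdim=True)


toy_v = torch.tensor([[1.0, 0.0]])        # original vision feature
toy_adapted = torch.tensor([[0.0, 1.0]])  # stand-in for mlp(v)

only_original = blend_features(toy_v, toy_adapted, alpha=0.0)  # keeps v unchanged
half_and_half = blend_features(toy_v, toy_adapted, alpha=0.5)  # halfway between both
only_adapted = blend_features(toy_v, toy_adapted, alpha=1.0)   # discards v entirely
print(only_original, half_and_half, only_adapted)
```

With alpha = 0 the adapter reduces to the zero-shot classifier, and with alpha = 1 the original features are ignored. Intermediate values trade off prior knowledge against what is learned from the few shots.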


In [None]:
class CLIPAdapter(torch.nn.Module):
    def __init__(self, zero_shot_prot, logit_scale, alpha=0.5):
        super().__init__()
        self.logit_scale = logit_scale
        self.logit_scale.requires_grad = False
        self.zero_shot_prot = zero_shot_prot.clone().to(device) # Since it is not a parameter, we need to move it to the device ourselves
        # The mlp residual Adapter that modifies the vision features:
        self.mlp = torch.nn.Sequential(torch.nn.Linear(zero_shot_prot.shape[-1], 4, bias=False),
                                       torch.nn.ReLU(inplace=True),
                                       torch.nn.Linear(4, zero_shot_prot.shape[-1], bias=False),
                                       torch.nn.ReLU(inplace=True),)
        # Alpha value for blending zero-shot and learned information on few shots.
        self.alpha = alpha
        # move to device
        self.to(device)

    def forward(self, features):
        
        # Residual adapter features: weighted residual modification of original features
        features = (1-self.alpha) * features + self.alpha * self.mlp(features)

        # Normalize output of feature adaptation and class prototype into an l2-norm space
        image_features_norm = features / features.norm(dim=-1, keepdim=True)
        prototypes_norm = self.zero_shot_prot / self.zero_shot_prot.norm(dim=-1, keepdim=True)

        # Logits: note that we keep the text prototypes as they were
        logit_scale = self.logit_scale.exp()
        logits = image_features_norm @ prototypes_norm.t() * logit_scale
        
        return logits
    
    def loss(self, logits, y):
        loss = torch.nn.functional.cross_entropy(logits, y)
        return loss

In [None]:
# Create the instance of CLIP-Adapter with a concrete alpha value
alpha = 0.5
CLIPAd_adapter = CLIPAdapter(zero_shot_prot=zero_shot_prot, logit_scale=plip.logit_scale.detach().clone(), alpha=alpha)

In [None]:
# We train our adapter using few-shot data
from vlms.utils import train_adapter # Take a look at this function, which trains the Adapter in mini-batches.
epochs, batch_size, learning_rate = 100, 32, 0.001 # Define training hyper-parameters
optimizer = torch.optim.SGD(CLIPAd_adapter.parameters(), lr=learning_rate, momentum=0.9)
train_adapter(X_train, Y_train, CLIPAd_adapter, optimizer, batch_size, epochs)

In [None]:
# Now, we can test the resulting Adapter on test data. Since Adapters are light, we can do
# a full-batch forward pass on test data.
with torch.no_grad():
    prob = torch.softmax(CLIPAd_adapter(torch.tensor(X_test).to(device)), axis=-1).cpu().numpy()
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))

##### **NOTE**
As you can see, **CLIP-Adapter prevents the performance drop when K=1**. Now, try using **different numbers of shots and different values for the alpha hyper-parameter**. What limitations do you observe? **How can you properly set the alpha value in a few-shot setting** without using test data feedback? As demonstrated in [3], the best value of this hyper-parameter is dataset-dependent, and tuning it per dataset is unrealistic in the few-shot setting.

------
#### **1.4.3. Zero-shot initialized Linear Probe**

Motivated by the **absence of model selection strategies in CLIP-Adapter and other methods**, the work in [3] revisits few-shot adaptation of vision-language models.

Concretely, one observation is that **the limited performance of Linear Probing was explained by the random initialization of the new class prototypes**. Indeed, employing the **text-driven class prototypes as initial trained weights** is competitive with more convoluted methods.

Let's train ZS-LP [3], this well-initialized Linear Probe. Remember to check the 1-shot case and compare its performance with the previously introduced Linear Probing.


In [None]:
# Remember the shape of the text-driven prototypes: classes x features.
print(zero_shot_prot.shape)

In [None]:
# This Linear Probing class takes zero-shot prototypes as input to initialize the prototypes
class ZSLinearProbe(torch.nn.Module):
    def __init__(self, zero_shot_prot, logit_scale):
        super().__init__()
        # We keep the same temperature scaling, but we do not train it any more
        self.logit_scale = logit_scale
        self.logit_scale.requires_grad = False
        # Initialize prototypes with zero-shot weights
        self.prototypes = torch.nn.Parameter(zero_shot_prot.clone())
        # move to device
        self.to(device)

    def forward(self, features):

        # Get trained prototype
        prototypes = self.prototypes.to(device)

        # l2-normalized trained weights
        prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)

        # temperature-scaled similarity per class
        logit_scale = self.logit_scale.exp()
        logits = features @ prototypes_norm.t() * logit_scale

        return logits
    
    def loss(self, logits, y):
        loss = torch.nn.functional.cross_entropy(logits, y)
        return loss

In [None]:
# Create the instance of linear probe initialized with zero-shot weights
zslp_adapter = ZSLinearProbe(zero_shot_prot=zero_shot_prot, logit_scale=plip.logit_scale.detach().clone())

In [None]:
# We train our adapter using few-shot data
from vlms.utils import train_adapter # Take a look at this function, which trains the Adapter in mini-batches.
epochs, batch_size, learning_rate = 100, 32, 0.001 # Define training hyper-parameters
optimizer = torch.optim.SGD(zslp_adapter.parameters(), lr=learning_rate, momentum=0.9)
train_adapter(X_train, Y_train, zslp_adapter, optimizer, batch_size, epochs)

In [None]:
# Now, we can test the resulting Adapter on test data. Since Adapters are light, we can do
# a full-batch forward pass on test data.
with torch.no_grad():
    prob = torch.softmax(zslp_adapter(torch.tensor(X_test).to(device)), axis=-1).cpu().numpy()
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))

------
#### **1.4.4. Class-Adaptive Linear Probing (CLAP)**

Finally, we will check the **state-of-the-art method for black-box adaptation of vision-language models, CLAP [3]**. The motivation of CLAP is to adaptively **retain the robust zero-shot prototypes when updating the new ones using few shots**. The idea is quite straightforward: **if the zero-shot prototypes perform well, why would you want to move far from them?** This is solved by constraining the learned prototypes to stay close to the initial solution. The overall loss function is defined as follows:

$$
\phantom{.}\min_{\mathcal{W}}  \quad
\sum\limits_{i \in \mathcal{S}} \mathcal{H}_{ce}({\mathbf{y}^{(i)},\hat{\mathbf{y}}^{(i)}}) +
\sum_{c=1}^C \lambda_{c} \; ||\mathbf{t}_c - \mathbf{w}_{c}||_{2}^{2}
$$

where $\mathcal{H}_{ce}$ is the cross-entropy loss on predictions with the learned prototypes, and the second term is an **l2-penalty**, which takes **large values if the learned prototypes, $\mathbf{w}_c$, deviate from the text prototypes, $\mathbf{t}_c$**. Note that we are minimizing this loss function, so we also minimize the penalty/deviation. Importantly, this is done class-wise (for each $c$ of the $C$ categories).

The zero-shot prototypes might be stronger for some categories than for others. To account for this, **CLAP uses class-wise weights, $\lambda_c$, that control the relevance of the penalty**. These weights are directly **estimated by quantifying the quality of the zero-shot prototypes on the support data** (few shots) before training.
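
A small numerical sketch may help to see how the class-wise weights and the penalty interact. All values below are made up for illustration: `probs` plays the role of the zero-shot softmax outputs on the support set, `t` of the text prototypes and `w` of the learned prototypes.

```python
import torch

# Toy support set: 4 samples, 2 classes (zero-shot softmax outputs, made up)
probs = torch.tensor([[0.9, 0.1],
                      [0.8, 0.2],
                      [0.6, 0.4],
                      [0.5, 0.5]])
labels = torch.tensor([0, 0, 1, 1])

# Class-wise weight: average softmax assigned to the correct class
one_hot = torch.nn.functional.one_hot(labels, num_classes=2).float()
lambdas = (probs * one_hot).sum(0) / one_hot.sum(0)

# Text prototypes t_c and learned prototypes w_c (also made up)
t = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
w = torch.tensor([[1.0, 0.5], [0.5, 1.0]])

penalty = (t - w).pow(2).sum(-1)              # squared l2 distance per class
weighted_penalty = (lambdas * penalty).sum()  # sum over classes, as in the equation

print(lambdas)           # class 0 is predicted well (0.85), class 1 poorly (0.45)
print(penalty)           # both prototypes moved by the same amount (0.25)
print(weighted_penalty)  # 0.85 * 0.25 + 0.45 * 0.25 = 0.325
```

Both prototypes deviate equally from their anchors, but the deviation of class 0 costs almost twice as much because its zero-shot prototype was already reliable. Note that the `CLAP` class in this notebook averages the weighted penalty over classes instead of summing it, which only rescales the term by 1/C.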

You can learn more about this and black-box Adapters here: [https://github.com/jusiro/CLAP](https://github.com/jusiro/CLAP).


In [None]:
class CLAP(torch.nn.Module):
    def __init__(self, zero_shot_prot, logit_scale):
        super().__init__()
        # We use the same temperature scaling value as in pre-training
        self.logit_scale = logit_scale
        self.logit_scale.requires_grad = False
        # Trained weights W, which we initialize with zero-shot weights.
        self.prototypes = torch.nn.Parameter(zero_shot_prot.clone())
        # Zero-shot prototypes, t, which we will use as anchor for the penalty term.
        self.anchors = torch.nn.Parameter(zero_shot_prot.clone())
        self.anchors.requires_grad = False
        # Init penalty weights (we will set them later, once we get train data)
        self.lambdas = torch.zeros((zero_shot_prot.shape[0])).to(device)
        # move to device
        self.to(device)
    
    def init_multipliers(self, X_train, Y_train):
        # Function to compute the initial multipliers value:
        # 1. Get predictions (softmax outputs) from train data
        # 2. Take average softmax value for each category
        # (Idea): the larger the avg. softmax for the correct category, the better the zero-shot
        # prototype is and the larger the penalty if you deviate (lambda -> 1). The lower the softmax
        # output on average, the worse the zero-shot prototype for this category is, and the lower
        # the penalty if you deviate (lambda -> 0)
        
        
        with torch.no_grad():
            # Move to device inputs
            X_train = torch.tensor(X_train).to(device)
            Y_train = torch.tensor(Y_train).to(device)

            # Compute logits in train data
            logits = self.forward(X_train)

            # Convert Y_train to one-hot to compute the per-class average, e.g. label 3 -> [0, 0, 0, 1]
            # (num_classes is set explicitly so the shape does not depend on which labels are present)
            labels_one_hot = torch.nn.functional.one_hot(Y_train.long(), num_classes=logits.shape[-1])

            # Estimate the quality of the zero-shot prototypes per class / average per class
            anchors_q = torch.diag(torch.softmax(logits, -1).t() @ labels_one_hot.to(torch.float32)) / \
                    labels_one_hot.sum(0)

            # Init new lambdas
            self.lambdas = torch.clone(anchors_q).to(device)

    def forward(self, features):
        # Note that the forward pass is the same as in Linear Probing!!
        
        # Get trained prototype
        prototypes = self.prototypes.to(device)

        # l2-normalized trained weights
        prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)

        # temperature-scaled similarity per class
        logit_scale = self.logit_scale.exp()
        logits = features @ prototypes_norm.t() * logit_scale

        return logits
    
    def loss(self, logits, y):
        # Cross-entropy on labels and predictions
        ce_loss = torch.nn.functional.cross_entropy(logits, y)
        # L2-penalty (distance between vectors) for base class prototypes
        penalty = (self.prototypes - self.anchors).pow(2).sum(-1)
        # Weight with class-wise multipliers
        weighted_penalty = torch.mean(self.lambdas * penalty)
        # Compute overall loss as the sum of both terms
        loss = ce_loss + weighted_penalty
        return loss

In [None]:
# Create the instance of the CLAP adapter, initialized with zero-shot weights
clap_adapter = CLAP(zero_shot_prot=zero_shot_prot, logit_scale=plip.logit_scale.detach().clone())
# Init multipliers
clap_adapter.init_multipliers(X_train, Y_train)
print("Lambda multipliers: " + str(clap_adapter.lambdas))

In [None]:
# We train our adapter using few-shot data
from vlms.utils import train_adapter # Take a look at this function, which trains the Adapter in mini-batches.
epochs, batch_size, learning_rate = 100, 32, 0.001 # Define training hyper-parameters
optimizer = torch.optim.SGD(clap_adapter.parameters(), lr=learning_rate, momentum=0.9)
train_adapter(X_train, Y_train, clap_adapter, optimizer, batch_size, epochs)

In [None]:
# Now, we can test the resulting Adapter on test data. Since Adapters are light, we can do
# a full-batch forward pass on test data.
with torch.no_grad():
    prob = torch.softmax(clap_adapter(torch.tensor(X_test).to(device)), axis=-1).cpu().numpy()
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))

--- 
##### **ACTIVITY**

Well, now you know everything you need about black-box Adapters. If you want to know more, I recommend:

- Try different random seeds, since few-shot transferability might present large variability in performance depending on the chosen samples.
- Try developing the same pipeline for [CONCH](https://huggingface.co/MahmoodLab/CONCH) [4], a recently introduced VLM for histology. Its vision backbone is large-scale, takes high-resolution input images, and is pre-trained with more data. How does model scaling translate to black-box adaptation?
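
For the seed experiment, one convenient pattern is to wrap a full run (sampling the shots, training the Adapter, evaluating) in a function of the seed and then aggregate the scores. The sketch below shows only the aggregation part. `summarize_over_seeds` and `fake_run` are hypothetical names, and `fake_run` returns made-up numbers in place of a real balanced accuracy.

```python
import numpy as np


def summarize_over_seeds(run_fn, seeds):
    """Call run_fn(seed) for each seed and return (mean, std) of the scores."""
    scores = np.array([run_fn(seed) for seed in seeds], dtype=float)
    return float(scores.mean()), float(scores.std())


def fake_run(seed):
    """Placeholder for 'sample shots, train, evaluate'; returns a made-up score."""
    return 0.5 + 0.1 * seed


mean_score, std_score = summarize_over_seeds(fake_run, seeds=[0, 1, 2])
print("Balanced accuracy: %.3f +/- %.3f" % (mean_score, std_score))
```

Reporting the mean together with the standard deviation gives a fairer picture than a single run, especially for K=1.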


--- 
## **References**


[1] Radford, A., Kim, J.W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., Krueger, G., & Sutskever, I. (2021). Learning Transferable Visual Models From Natural Language Supervision. International Conference on Machine Learning. \
[2] Gao, P., Geng, S., Zhang, R. et al. (2024). CLIP-Adapter: Better Vision-Language Models with Feature Adapters. Int J Comput Vis. \
[3] Silva-Rodriguez, J., Hajimiri, S., Ben Ayed, I., Dolz, J. (2024). A Closer Look at the Few-Shot Adaptation of Large Vision-Language Models. IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR). \
[4] Lu, M.Y., Chen, B., Williamson, D.F.K. et al. (2024). A visual-language foundation model for computational pathology. Nature Medicine.
--- 