### **Vision-Language Medical Foundation Models**

#### **1.5. Few-shot Parameter Efficient Fine-Tuning**
---

Parameter-Efficient Fine-Tuning (PEFT) is a methodology of increasing interest, popularized in the NLP community as a way to adapt large language models (LLMs).

**Objective**: Given a small set of examples per category, we want to fine-tune parts of the model to specialize it on a particular task. By tuning only a few parameters, we can adapt large models with minimal resources.

**Few-shot**: We use only K images for each new category.

**Why PEFT?**: PEFT methods are efficient (at least more so than full fine-tuning) and can usually run on in-house GPUs. They are fast: you can adapt the model in a matter of minutes. They are also more flexible than black-box Adapters, since they allow refining the deep features.

In [None]:
# General imports.
import warnings
warnings.filterwarnings('ignore')

import copy
import torch
import random

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

# Device for training/inference.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Available device: " + device)

# Seeds for reproducibility
def set_seeds(seed_value, use_cuda):
    np.random.seed(seed_value)     # cpu vars
    torch.manual_seed(seed_value)  # cpu vars
    random.seed(seed_value)        # Python
    if use_cuda:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)     # gpu vars
        torch.backends.cudnn.deterministic = True  # needed
        torch.backends.cudnn.benchmark = False

set_seeds(42, use_cuda=torch.cuda.is_available())

#### **Dataset details**

In [None]:
# SICAPv2 dataset metadata
categories = ["NC", "G3", "G4", "G5"]                                        # List of categories
path_images = "./local_data/datasets/SICAPv2/images/"                        # Folder with the images
dataframe_train = "./local_data/datasets/SICAPv2/partition/Test/Train.xlsx"  # Dataframe (Table) containing train images names and labels
dataframe_test = "./local_data/datasets/SICAPv2/partition/Test/Test.xlsx"    # Dataframe (Table) containing test images names and labels

#### **VLM model wrapper**

In [None]:
# Load model and pre-processing tools from huggingface.
from transformers import AutoProcessor, AutoModel

# In the Transformers library, models and versions are stored under an ID defining the user and model name.
# For PLIP model, such ID is "vinid/plip".
processor = AutoProcessor.from_pretrained("vinid/plip") # pre-processing image and text.
processor.image_processor.do_center_crop = False
plip = AutoModel.from_pretrained("vinid/plip").eval() # model with pre-trained weights.
# We set the model in eval mode to disable dropout at inference and to avoid batchnorm statistics updates in CNNs.

In [None]:
# Again, we will use our PLIP Wrapper for easy interaction
class PLIPWrapper(torch.nn.Module):
    def __init__(self, encoder, proj_layer):
        super().__init__()
        
        self.encoder = encoder         # Take one-modality encoder from VLM.
        self.proj_layer = proj_layer   # Take projection layer into joint embedding space.

    def forward(self, inputs):
        
        # Forward input through encoder
        features = self.encoder(**inputs).pooler_output # Forward through encoder - we keep a global feature for the image/text
        
        # Project features
        projected_features = self.proj_layer(features)  # Apply projection
        
        # Ensure features are l2-normalized (the wrapper is used for both image and text)
        projected_features = projected_features / projected_features.norm(dim=-1, keepdim=True) # l2-normalization

        return projected_features

# Create model wrapper for vision and text encoders
vision_encoder = PLIPWrapper(plip.vision_model, plip.visual_projection)
text_encoder = PLIPWrapper(plip.text_model, plip.text_projection)

#### **Compute text prototypes**
(We will need them later to initialize the classification head)

In [None]:
# Ensemble of templates.
templates = ["a histopathology slide showing [CLS]", "histopathology image of [CLS]",
             "pathology tissue showing [CLS]", "presence of [CLS] tissue on image"]

# Category-wise descriptions, which are more informative than category names. For instance, "atrophic dense glands" better 
# describes the local findings associated with Gleason grade 3.
prompts_dict = {"NC": ["benign glands"],
                "G3": ["atrophic dense glands"],
                "G4": ["cribriform ill-formed fused papillary patterns"],
                "G5": ["isolated nest cells without lumen roseting patterns"]}

# Combine all paired options of templates and description.
prompts = {}
for iCategory in categories:
    prompts[iCategory] = [caption.replace("[CLS]", iDescription) for iDescription in prompts_dict[iCategory]
                          for caption in templates]

# Compute embeddings per category.
class_prototypes = []
for iKey in range(len(categories)):
    with torch.no_grad():
        # Retrieve descriptions for that particular category.
        descriptions = prompts[categories[iKey]]
        # Tokenize text.
        inputs = processor.tokenizer(descriptions, max_length = 77, padding=True, truncation=True, return_tensors="pt") 
        # Forward text encoder.
        text_features_ensemble = text_encoder(inputs)
        # Get class prototypes as average of all text prompts.
        avg_text_features = text_features_ensemble.mean(0).unsqueeze(0)
        # Re-normalize embedding.
        avg_text_features = avg_text_features / avg_text_features.norm(dim=-1, keepdim=True)
        class_prototypes.append(avg_text_features)
                               
# Concatenate all class prototypes.
zero_shot_prot = torch.concat(class_prototypes, dim=0)
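
The average-then-re-normalize step above can be seen in isolation with a few lines of plain Python. This is a minimal sketch with made-up 2-D vectors standing in for text embeddings; the helper names `l2_normalize` and `ensemble_prototype` are hypothetical and are not part of the notebook or of Transformers.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

def ensemble_prototype(embeddings):
    """Average unit-norm prompt embeddings, then re-normalize the mean."""
    dim = len(embeddings[0])
    mean = [sum(e[d] for e in embeddings) / len(embeddings) for d in range(dim)]
    return mean, l2_normalize(mean)

# Toy "prompt embeddings" for one category (illustrative values only).
prompt_embeddings = [l2_normalize([1.0, 0.0]), l2_normalize([0.0, 1.0])]
mean, prototype = ensemble_prototype(prompt_embeddings)

mean_norm = math.sqrt(sum(v * v for v in mean))
proto_norm = math.sqrt(sum(v * v for v in prototype))
print(f"norm of the raw mean: {mean_norm:.3f}")
print(f"norm after re-normalization: {proto_norm:.3f}")
```

The mean of several unit vectors generally has a norm below 1, which is why the prototypes are re-normalized before being used in a cosine-similarity classifier.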

#### **Datasets: few-shot training and test**
We are now going to modify the vision backbone, so **we cannot pre-compute the deep features, since they will change during adaptation**. We therefore define the data loaders, but leave feature extraction to the training loop.

In [None]:
from vlms.data import loader
from vlms.data import few_shot_loader  
from torchvision import transforms

# Pre-processing transforms to apply during data loading.
plip_transforms = transforms.Compose(
    [
    transforms.ToTensor(),                                                 # Move PIL/array image to tensor
    transforms.Normalize(std=processor.image_processor.image_std,
                         mean=processor.image_processor.image_mean),       # Intensity normalization
    transforms.Resize(list(processor.image_processor.crop_size.values()))  # Resize to pre-trained resolution
    ])

# Set test data loader.
test_loader = loader(dataframe_path=dataframe_test, path_images=path_images, categories=categories,
                     transforms=plip_transforms, batch_size=8, num_workers=0)

# Set train data loader.
shots, seed = 16, 1
train_loader = few_shot_loader(dataframe_path=dataframe_train, path_images=path_images, categories=categories, transforms=plip_transforms,
                               shots=shots, batch_size=32, num_workers=0, seed=seed)
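
To make the K-shot idea concrete, here is a minimal sketch of how one could draw a fixed number of examples per category from a list of labels. It is not the implementation of `few_shot_loader` (whose internals are not shown here); the function `sample_k_shot` and the toy labels are assumptions for illustration only.

```python
import random
from collections import defaultdict

def sample_k_shot(labels, k, seed):
    """Return sorted sample indices with at most k examples per category."""
    rng = random.Random(seed)          # local RNG, so the global state is untouched
    by_category = defaultdict(list)
    for index, label in enumerate(labels):
        by_category[label].append(index)
    selected = []
    for label in sorted(by_category):  # sorted for a deterministic category order
        indices = by_category[label]
        selected.extend(rng.sample(indices, min(k, len(indices))))
    return sorted(selected)

# Toy label list: 10 images per category (illustrative only).
toy_labels = ["NC", "G3", "G4", "G5"] * 10
subset = sample_k_shot(toy_labels, k=3, seed=1)
print(len(subset), "images selected")
```

Fixing the seed makes the few-shot subset reproducible, so different adaptation methods can be compared on exactly the same training images.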


#### **Preliminaries: coupling base model with classification head**
First, we have to equip the vision backbone with a classification head. The best option is to re-use the Linear Probe head explored in the previous notebook.

In [None]:
# We define the classification head.
class ZSLinearProbe(torch.nn.Module):
    def __init__(self, zero_shot_prot, logit_scale):
        super().__init__()
        self.logit_scale = logit_scale
        self.logit_scale.requires_grad = False
        self.prototypes = torch.nn.Parameter(zero_shot_prot.clone())
        # move to device.
        self.to(device)

    def forward(self, features):

        # Get trained prototype
        prototypes = self.prototypes.to(device)

        # l2-normalized trained weights.
        prototypes_norm = prototypes / prototypes.norm(dim=-1, keepdim=True)

        # temperature-scaled similarity per class.
        logit_scale = self.logit_scale.exp()
        logits = features @ prototypes_norm.t() * logit_scale

        return logits
    
    def loss(self, logits, y):
        loss = torch.nn.functional.cross_entropy(logits, y)
        return loss
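
What the head computes can be checked by hand on a toy case. The following dependency-free sketch reproduces the temperature-scaled cosine similarity followed by a softmax; the 2-D vectors and the scale value are made up, and unlike the class above it normalizes the feature inside the function for convenience.

```python
import math

def cosine_logits(feature, prototypes, logit_scale):
    """Temperature-scaled cosine similarity between one feature and each prototype."""
    def unit(v):
        n = math.sqrt(sum(x * x for x in v))
        return [x / n for x in v]
    f = unit(feature)
    scale = math.exp(logit_scale)      # the scale is stored in log-space, as in CLIP
    return [scale * sum(a * b for a, b in zip(f, unit(p))) for p in prototypes]

def softmax(logits):
    m = max(logits)                    # subtract the maximum for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

prototypes = [[1.0, 0.0], [0.0, 1.0]]  # two toy class prototypes
feature = [0.8, 0.6]                   # toy image feature
logit_scale = math.log(10.0)           # i.e. a temperature of 1/10

logits = cosine_logits(feature, prototypes, logit_scale)
probs = softmax(logits)
print([round(l, 3) for l in logits], [round(p, 3) for p in probs])
```

A larger scale sharpens the softmax: the same cosine similarities produce more confident predictions.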

In [None]:
# Instantiate the classification head, initialized with the zero-shot text prototypes.
head = ZSLinearProbe(zero_shot_prot=zero_shot_prot, logit_scale=plip.logit_scale.detach().clone())

# Create model combining backbone and classification head - Also, we move the model to gpu.
model = torch.nn.Sequential(copy.deepcopy(vision_encoder),
                            head).to(device).to(torch.float32)
print(model) # Look again at the architecture of the model: it is a ViT-B/32 composed of 12 Transformer blocks.

#### **Preliminaries: counting and freezing parameters**
We present functions to control which parameters are trainable in the network, and which are frozen.

In [None]:
# Auxiliary function to count parameters.
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Auxiliary function to print the number of parameters.
def print_parameters(model):
    print("Number of trainable parameters: " + str(count_parameters(model)))
    for name, param in model.named_parameters():
        if param.requires_grad is True:
            print(name + " " * (70 - len(name)) + " -> Trained:" + str(param.requires_grad))

Let's count the number of trainable parameters currently in the model.

In [None]:
print_parameters(model)

**87.8M parameters!!** And this is a "small" architecture compared with state-of-the-art models. Do we really need to fine-tune the whole model?

I would say no! In this notebook we will learn how to avoid this challenge with PEFT :)
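
To get a feeling for where those parameters live, we can do a back-of-the-envelope count. The sketch below assumes the standard CLIP ViT-B/32 vision configuration (hidden size 768, MLP size 3072, 12 blocks, 32x32 patches on 224x224 images, 512-d projection); these values are assumptions and are not read from the checkpoint, so treat `print_parameters` as the ground truth.

```python
# Assumed configuration: standard CLIP ViT-B/32 vision tower (not read from the checkpoint).
hidden, mlp, layers = 768, 3072, 12
patch, image, proj_dim, n_classes = 32, 224, 512, 4

tokens = (image // patch) ** 2 + 1                   # 49 patches + 1 class token
embeddings = (3 * patch * patch * hidden             # patch convolution (no bias)
              + hidden                               # class token
              + tokens * hidden)                     # position embeddings
layer_norm = 2 * hidden                              # gamma and beta
attention = 4 * (hidden * hidden + hidden)           # q, k, v and output projections
feed_forward = (hidden * mlp + mlp) + (mlp * hidden + hidden)
block = attention + feed_forward + 2 * layer_norm

backbone = embeddings + layers * block + 2 * layer_norm  # plus layer norms before/after the encoder
projection = hidden * proj_dim                       # visual projection, no bias
head = n_classes * proj_dim                          # one prototype per category

total = backbone + projection + head
print(f"{total:,} parameters (~{total / 1e6:.2f}M)")
print(f"share held by the 12 Transformer blocks: {layers * block / total:.1%}")
```

Under these assumptions, almost all parameters sit in the attention and feed-forward weights of the Transformer blocks, while the classification head is tiny in comparison.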

**First, we will freeze all parameters except those of the classification head**.

In [None]:
# Freeze all parameters in backbone.
for name, param in model.named_parameters():
    param.requires_grad = False # Freeze.
# Unfreeze classification head.
for name, param in model[1].named_parameters():
    param.requires_grad = True # Unfreeze.
# Print trainable parameters
print_parameters(model)

Now, **we will explore different alternatives to selectively or additively adapt the model.**

------
#### **1.5.1. Selective PEFT**

These methods **tune only a small subset of the network**. 
- **Advantages**: they usually do not have specific hyper-parameters, and they keep inference times unchanged.
- **Drawbacks**: they might distort the pre-trained representations more severely.


#### **Affine-LN Tuning**

**Tuning the affine parameters of batch [1] or layer [2] normalization layers**, i.e. $\gamma$ and $\beta$. The intuition behind the method is that these parameters scale up the features most relevant to the task at hand and scale down the features unrelated to the downstream task.

$$\text{Affine} \rightarrow  out=\frac{x-E[x]}{\sqrt{Var[x]+\epsilon}}*\gamma +\beta$$
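
The effect of $\gamma$ and $\beta$ can be illustrated on a single feature vector. This is a minimal plain-Python sketch of the formula above with toy values; the function `layer_norm` here is a hypothetical helper, not the PyTorch module.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize x to zero mean / unit variance, then apply the per-feature affine map."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)   # biased variance, as in layer normalization
    return [(v - mean) / math.sqrt(var + eps) * g + b for v, g, b in zip(x, gamma, beta)]

x = [1.0, 2.0, 3.0, 4.0]

# Default affine parameters: gamma = 1, beta = 0 (plain normalization).
identity = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)

# "Tuned" parameters: amplify feature 0, shift feature 2, switch off feature 3.
tuned = layer_norm(x, gamma=[2.0, 1.0, 1.0, 0.0], beta=[0.0, 0.0, 0.5, 0.0])

print([round(v, 3) for v in identity])
print([round(v, 3) for v in tuned])
```

With only two numbers per feature, the affine map can re-weight, shift, or suppress individual channels without touching the large weight matrices.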

In [None]:
# Create a copy of the model to be adapted
model_peft = copy.deepcopy(model).eval().to(device)

In [None]:
print("Unfreeze encoder layer norm affine parameters... ", end="\n")
for m in model_peft.modules():
    for name, param in m.named_parameters():
        if "layer_norm" in name:        # Check parameters called layer norm.
            param.requires_grad = True  # Set trainable to True.
# Print trainable parameters
print_parameters(model_peft)

In [None]:
# Train the selected parameters of the model.
# Meanwhile, open a command window and check GPU usage with nvidia-smi.
from vlms.utils import train_ft # Take a look at this function, for fine-tuning on mini-batches.
# Define training hyper-parameters. Note that we decrease the number of epochs, since convergence is faster
# with respect to the number of forward-backward passes. This is due to the more aggressive update.
epochs, batch_size, learning_rate = 20, 16, 0.001
# Set optimizer: we use Adam optimizer, which provides better convergence in deep architectures.
optimizer = torch.optim.Adam(model_peft.parameters(), lr=learning_rate)
train_ft(loader=train_loader, model=model_peft, optimizer=optimizer, batch_size=batch_size, epochs=epochs)

In [None]:
# Predict on test set - now, we will do it on mini-batches.
from vlms.utils import predict # Take a look at this function! Inference on mini-batches.
prob, Y_test = predict(test_loader, model_peft)

In [None]:
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))
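
For reference, here is a dependency-free sketch of the two metrics on toy predictions. It assumes that the balanced accuracy (`aca`) is the mean of the per-class recalls; the internals of `vlms.utils.evaluate` are not shown here, so this is an illustration rather than its actual code.

```python
def confusion_matrix(y_true, y_pred, n_classes):
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

def balanced_accuracy(cm):
    """Mean per-class recall, ignoring classes without any true sample."""
    recalls = [row[i] / sum(row) for i, row in enumerate(cm) if sum(row) > 0]
    return sum(recalls) / len(recalls)

# Toy, imbalanced example: class 0 dominates (illustrative labels only).
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0, 2, 1]

cm = confusion_matrix(y_true, y_pred, n_classes=3)
aca = balanced_accuracy(cm)
accuracy = sum(cm[i][i] for i in range(3)) / len(y_true)
print(f"accuracy: {accuracy:.3f} | balanced accuracy: {aca:.3f}")
```

On imbalanced test sets the plain accuracy is dominated by the majority class, whereas the balanced accuracy weights every category equally.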

#### **Bias Tuning**

**Tuning only the bias parameters** of the ViT, i.e. BitFit [2], has shown promising performance compared to full fine-tuning.
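
How small is this subset? A rough count, assuming the standard CLIP ViT-B/32 vision layout (hidden size 768, MLP size 3072, 12 blocks, bias-free patch embedding and projection). These values are assumptions and are not read from the checkpoint, so compare them with the output of `print_parameters`.

```python
# Assumed configuration: standard CLIP ViT-B/32 vision tower (not read from the checkpoint).
hidden, mlp, layers = 768, 3072, 12

attention_biases = 4 * hidden        # q, k, v and output projections
mlp_biases = mlp + hidden            # the two feed-forward linear layers
layer_norm_biases = 2 * hidden       # beta of the two layer norms in each block
per_block = attention_biases + mlp_biases + layer_norm_biases

outer_layer_norms = 2 * hidden       # beta of the layer norms before and after the encoder
bias_params = layers * per_block + outer_layer_norms

reported_total = 87.8e6              # full-model figure quoted earlier in this notebook
fraction = bias_params / reported_total
print(f"{bias_params:,} bias parameters (~{fraction:.2%} of the model)")
```

Under these assumptions, bias tuning updates on the order of a tenth of a percent of the backbone, plus the classification head.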


In [None]:
# Create a copy of the model to be adapted.
model_peft = copy.deepcopy(model).eval().to(device)

In [None]:
print("Unfreeze bias parameters... ", end="\n")
for m in model_peft.modules():
    for name, param in m.named_parameters():
        if "bias" in name:              # Check parameters called bias.
            param.requires_grad = True  # Set trainable to True.
# Print trainable parameters
print_parameters(model_peft)

In [None]:
# Train the selected parameters of the model.
# Meanwhile, open a command window and check GPU usage with nvidia-smi.
from vlms.utils import train_ft # Take a look at this function, for fine-tuning on mini-batches.
# Define training hyper-parameters. Note that we decrease the number of epochs, since convergence is faster
# with respect to the number of forward-backward passes. This is due to the more aggressive update.
epochs, batch_size, learning_rate = 20, 16, 0.001
# Set optimizer: we use Adam optimizer, which provides better convergence in deep architectures.
optimizer = torch.optim.Adam(model_peft.parameters(), lr=learning_rate)
train_ft(loader=train_loader, model=model_peft, optimizer=optimizer, batch_size=batch_size, epochs=epochs)

In [None]:
# Predict on test set - now, we will do it on mini-batches
from vlms.utils import predict # Take a look at this function! Inference on mini-batches
prob, Y_test = predict(test_loader, model_peft)

In [None]:
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))

------
#### **1.5.2. Additive PEFT**

These methods add a small set of new parameters, so-called Adapters. Usually, they perform residual modifications of the pre-trained features. Nowadays, the most popular method is LoRA [3], which performs a low-rank adaptation.

- **Advantages**: they are more flexible than selective methods, since you can control how many parameters you introduce. The residual feature modification produces smoother changes.
- **Drawbacks**: you need to set the number of parameters you introduce, which adds extra hyper-parameters. Some Adapters also increase inference time, since they add operations. Note that this is not necessarily the case for LoRA, since its low-rank branch can be computed in parallel with the base layer.

LoRA introduces two new matrices, A and B, which perform a residual modification on the output of a specific weight. Given a base linear weight $W$ and an input feature representation $x$, we can formalize LoRA as:

$$out = W(x) + B(A(x))$$


Where A and B are low-rank matrices.
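
The formula can be traced by hand on a tiny example. The sketch below uses plain Python lists with made-up values ($D=3$, $r=1$); `matvec` and `lora_forward` are hypothetical helpers for illustration, where A maps $D$ features to $r$ and B maps them back to $D$.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x):
    """out = W x + B (A x): base output plus a low-rank residual."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + u for b, u in zip(base, update)]

W = [[1.0, 0.0, 0.0],
     [0.0, 2.0, 0.0],
     [0.0, 0.0, 3.0]]              # frozen base weight (D x D)
A = [[0.5, -0.5, 1.0]]             # down-projection: D -> r
B_zero = [[0.0], [0.0], [0.0]]     # up-projection r -> D, zero-initialized
B_trained = [[1.0], [0.0], [-1.0]] # toy values standing in for a trained B
x = [1.0, 1.0, 1.0]

out_init = lora_forward(W, A, B_zero, x)
out_trained = lora_forward(W, A, B_trained, x)
print(out_init, out_trained)
```

With B set to zero the adapted layer returns exactly the pre-trained output, so adaptation starts from the original model and only departs from it as B is trained.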

**How parameter-efficient are Low-rank Adapters?** Suppose that $x$ has $D$ features, and that $W$ is a linear layer with $D$ input and $D$ output features. If we used a single full-rank layer $W'$ for the residual modification, such that $out = W(x) + W'(x)$, this new layer would introduce $D\times D$ parameters. Instead, the low-rank matrices A and B, with rank $r$ (e.g. $r=4$), have dimensionality $A(D\times r)$ and $B(r \times D)$: A first compresses the embedding into $r$ features, and B returns it to the original dimensionality. Thus, the number of introduced parameters is $2 \cdot r \cdot D$. Imagine $D=128$ and, as is typical, $r=4$. A basic Adapter would introduce 16.4K parameters, while LoRA introduces 1K.
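
The arithmetic is easy to check, and it becomes more striking as $D$ grows. A small sketch follows; the helper names are hypothetical, and $D=768$ is included only as an example of a larger, ViT-B-sized embedding.

```python
def full_rank_params(d):
    """Parameters of one full-rank D x D residual layer (no bias)."""
    return d * d

def lora_params(d, r):
    """Parameters of the two low-rank matrices A (D x r) and B (r x D)."""
    return 2 * r * d

r = 4
for d in (128, 768):
    full, low = full_rank_params(d), lora_params(d, r)
    print(f"D={d}: full-rank {full:,} vs LoRA {low:,} ({full // low}x fewer)")
```

The saving factor is $D / (2r)$, so for a fixed rank the relative benefit of the low-rank parameterization grows linearly with the embedding size.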

**Numbers apart, let's see how it works!**


In [None]:
# Create the LoRA layer, to replace any linear weight
class _LoRALayer(torch.nn.Module):
    def __init__(self, w, w_a, w_b):
        super().__init__()
        self.w = w      # Original weight.
        self.w_a = w_a  # Matrix A.
        self.w_b = w_b  # Matrix B.

    def forward(self, x):
        x = self.w(x) + self.w_b(self.w_a(x)) # Residual modification with Adapter.
        return x

In [None]:
# Create LoRA Wrapper
class LoRAWrapper(torch.nn.Module):
    def __init__(self, vit_model, r=4):
        super(LoRAWrapper, self).__init__()
        # Inits
        self.ViTbase = vit_model # ViT to modify
        self.r = r               # Rank
        # create for storage, then we can init them or load weights.
        self.w_As = []  # Storage for linear layers of A matrices.
        self.w_Bs = []  # Storage for linear layers of B matrices.
        
        # We go through the base encoder, detect Multi-Head Attention blocks, and modify them by adding the Adapters.
        for i, layer in enumerate(list(list(self.ViTbase.encoder.children())[2].modules())):
            if layer._get_name() == 'CLIPAttention':  # Multi-Head Attention Blocks.

                # k_proj (key)
                w_a_linear_qkv = torch.nn.Linear(layer.k_proj.in_features, r, bias=False) # layer for A matrix.
                w_b_linear_qkv = torch.nn.Linear(r, layer.k_proj.in_features, bias=False) # layer for B matrix.
                torch.nn.init.zeros_(w_b_linear_qkv.weight)                               # Set values in B to 0s.
                self.w_As.append(w_a_linear_qkv), self.w_Bs.append(w_b_linear_qkv)        # Store new weights.
                layer.k_proj = _LoRALayer(layer.k_proj, w_a_linear_qkv, w_b_linear_qkv)   # Modify layer with LoRA layer.

                # v_proj (value)
                w_a_linear_qkv = torch.nn.Linear(layer.v_proj.in_features, r, bias=False) # layer for A matrix.
                w_b_linear_qkv = torch.nn.Linear(r, layer.v_proj.in_features, bias=False) # layer for B matrix.
                torch.nn.init.zeros_(w_b_linear_qkv.weight)                               # Set values in B to 0s.
                self.w_As.append(w_a_linear_qkv), self.w_Bs.append(w_b_linear_qkv)        # Store new weights.
                layer.v_proj = _LoRALayer(layer.v_proj, w_a_linear_qkv, w_b_linear_qkv)   # Modify layer with LoRA layer.

                # q_proj (query)
                w_a_linear_qkv = torch.nn.Linear(layer.q_proj.in_features, r, bias=False) # layer for A matrix.
                w_b_linear_qkv = torch.nn.Linear(r, layer.q_proj.in_features, bias=False) # layer for B matrix.
                torch.nn.init.zeros_(w_b_linear_qkv.weight)                               # Set values in B to 0s.
                self.w_As.append(w_a_linear_qkv), self.w_Bs.append(w_b_linear_qkv)        # Store new weights.
                layer.q_proj = _LoRALayer(layer.q_proj, w_a_linear_qkv, w_b_linear_qkv)   # Modify layer with LoRA layer.

    def forward(self, x):
        return self.ViTbase(x)

In [None]:
# Create a copy of the model to be adapted.
model_peft = copy.deepcopy(model).eval().to(device)
# Add LoRA Wrapper to the model.
r=4                                                        # Rank for low-rank adaptation. This is a hyper-parameter.
model_peft[0] = LoRAWrapper(model_peft[0], r=r).to(device) # Modify vision backbone with the new architecture with Adapters
# Print trainable parameters
print_parameters(model_peft)

In [None]:
# Train the selected parameters of the model.
# Meanwhile, open a command window and check GPU usage with nvidia-smi.
from vlms.utils import train_ft # Take a look at this function, for fine-tuning on mini-batches.
# Define training hyper-parameters. Note that we decrease the number of epochs, since convergence is faster
# with respect to the number of forward-backward passes. This is due to the more aggressive update.
epochs, batch_size, learning_rate = 20, 16, 0.001
# Set optimizer: we use Adam optimizer, which provides better convergence in deep architectures.
optimizer = torch.optim.Adam(model_peft.parameters(), lr=learning_rate)
train_ft(loader=train_loader, model=model_peft, optimizer=optimizer, batch_size=batch_size, epochs=epochs)

In [None]:
# Predict on test set - now, we will do it on mini-batches.
from vlms.utils import predict # Take a look at this function! Inference on mini-batches.
prob, Y_test = predict(test_loader, model_peft)

In [None]:
# Compute metrics
from vlms.utils import evaluate
aca, cm = evaluate(Y_test, prob)
print("Balanced accuracy: " + str(aca))
print("Confusion matrix: ")
print(str(cm))

--- 
##### **ACTIVITY**

Well, now you know the basics of PEFT methods. If you want to go further, I recommend the following exercises:

- How does performance change if you do not initialize the B matrix with 0s in LoRA? How does changing the rank in LoRA affect performance?
- Early stopping based on validation data also helps to avoid over-fitting during PEFT. Modify the loader functions to create a few-shot validation set, and modify training to save the best model based on validation loss.
- Explore the comparison with black-box Adapters for more and fewer than 16 shots.
- Try developing the same pipeline for [CONCH](https://huggingface.co/MahmoodLab/CONCH) [4], a recently introduced VLM for histology. Its vision backbone is larger scale (ViT-B/16), takes higher-resolution input images (448x448), and is pre-trained with more data. How does model scaling translate to PEFT? As you will see, the larger the network, the more convenient PEFT is with respect to full fine-tuning.
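
As a starting point for the early-stopping exercise, here is a minimal, framework-free sketch of the bookkeeping involved. The class `EarlyStopping`, its `patience` and `min_delta` arguments, and the validation losses are all hypothetical; plugging it into `train_ft` and saving the best weights is left as part of the activity.

```python
class EarlyStopping:
    """Track the best validation loss and signal when it stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience        # epochs to wait without improvement
        self.min_delta = min_delta      # minimum decrease that counts as an improvement
        self.best_loss = float("inf")
        self.best_epoch = None
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record a validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # New best: this is where the model weights would be saved.
            self.best_loss, self.best_epoch, self.bad_epochs = val_loss, epoch, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses, one per epoch.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.68, 0.71, 0.60]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break

print(f"stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}")
```

Note the trade-off in the patience value: in this toy run a later epoch would have improved again, which a larger patience would have caught at the cost of more training.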


--- 
## **References**


[1] Frankle, J., Schwab, D. J., Morcos, A. S. (2021). Training batchnorm and only batchnorm: On the expressive power of random features in cnns. International Conference on Learning Representations (ICLR). \
[2] Ben-Zaken, E., Ravfogel, S., Goldberg, Y. (2021). Bitfit: Simple parameter efficient fine-tuning for transformer-based masked language-models. Association for Computational Linguistics. \
[3] Hu, E. J., et al. (2022). LoRA: Low-rank adaptation of large language models. International Conference on Learning Representations (ICLR). \
[4] Lu, M.Y., Chen, B., Williamson, D.F.K. et al. (2024) A visual-language foundation model for computational pathology. Nature Medicine.

--- 