### **Vision-Language Medical Foundation Models**

#### **1.2. Contrastive text-image pre-training**
---

**The most popular paradigm for vision-language pre-training was introduced in CLIP [1]**. Given a dataset of **paired images and text descriptions**, a **vision encoder and a text encoder** are pre-trained to **produce a joint embedding space** in which paired samples produce similar representations, while unpaired samples are pushed apart.

Note that pre-training is time-consuming and requires a large dataset of image-text pairs, which are scarce in medical imaging. In this notebook, we will simply work through a toy example that computes the CLIP **contrastive language-image pre-training loss**.
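
For reference, the symmetric loss computed in this notebook can be written as follows. The notation is ours, not the notebook's: $v_i$ and $t_i$ denote the l2-normalized image and text embeddings of the $i$-th pair, $N$ is the batch size, and $\tau$ is the temperature, so that `plip.logit_scale.exp()` in the code plays the role of $1/\tau$.

$$
\mathcal{L}_{i2t} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(v_i^{\top} t_i/\tau)}{\sum_{j=1}^{N}\exp(v_i^{\top} t_j/\tau)},
\qquad
\mathcal{L}_{t2i} = -\frac{1}{N}\sum_{i=1}^{N}\log\frac{\exp(v_i^{\top} t_i/\tau)}{\sum_{j=1}^{N}\exp(v_j^{\top} t_i/\tau)}
$$

$$
\mathcal{L}_{CLIP} = \frac{1}{2}\left(\mathcal{L}_{i2t} + \mathcal{L}_{t2i}\right)
$$

The image-to-text term normalizes over each row of the similarity matrix, and the text-to-image term normalizes over each column.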



In [None]:
# General imports
import warnings
warnings.filterwarnings('ignore')

import copy
import torch

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

# Device for training/inference
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Available device: " + device)

#### **VLM model wrapper**

In [None]:
# Load model and pre-processing tools from huggingface
from transformers import AutoProcessor, AutoModel

# In the Transformers library, models and versions are stored under an ID defining the user and model name.
# For PLIP model, such ID is "vinid/plip"
processor = AutoProcessor.from_pretrained("vinid/plip") # pre-processing image and text
processor.image_processor.do_center_crop = False
plip = AutoModel.from_pretrained("vinid/plip").eval() # model with pre-trained weights
# We set the model to eval mode to disable dropout at inference and avoid batchnorm statistics updates in CNNs

In [None]:
# Again, we will use our PLIP Wrapper for easy interaction
class PLIPWrapper(torch.nn.Module):
    def __init__(self, encoder, proj_layer):
        super().__init__()
        
        self.encoder = encoder         # Take one-modality encoder from VLM.
        self.proj_layer = proj_layer   # Take projection layer into joint embedding space.

    def forward(self, inputs):
        
        # Forward input through the encoder
        features = self.encoder(**inputs).pooler_output # We keep a single global feature for the image/text
        
        # Project features
        projected_features = self.proj_layer(features)  # Apply projection into the joint embedding space
        
        # Ensure the projected features are l2-normalized
        projected_features = projected_features / projected_features.norm(dim=-1, keepdim=True) # l2-normalization

        return projected_features

# Create model wrapper for vision and text encoders
vision_encoder = PLIPWrapper(plip.vision_model, plip.visual_projection)
text_encoder = PLIPWrapper(plip.text_model, plip.text_projection)
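
Because the wrapper l2-normalizes its outputs, a plain dot product between an image feature and a text feature equals their cosine similarity. The minimal sketch below checks this on random tensors; `image_feats` and `text_feats` are hypothetical stand-ins for projected features, not outputs of PLIP.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins for projected features: 4 samples, 8 dimensions each.
image_feats = torch.randn(4, 8)
text_feats = torch.randn(4, 8)

# Same l2-normalization as in the wrapper's forward pass.
image_feats = image_feats / image_feats.norm(dim=-1, keepdim=True)
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)

# Pairwise dot products between unit vectors, shape (4, 4).
dot = image_feats @ text_feats.t()

# Pairwise cosine similarity computed explicitly, shape (4, 4).
cos = torch.nn.functional.cosine_similarity(
    image_feats.unsqueeze(1), text_feats.unsqueeze(0), dim=-1
)

print(torch.allclose(dot, cos, atol=1e-6))
```

This is why the similarity matrix used for the loss can be obtained with a single matrix multiplication.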

#### **Contrastive pre-training**

Now, we are going to compute the CLIP pre-training loss. We will use an example with a few images and naive text descriptions.
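
Before using real images, the loss can be sketched on synthetic features. The function name `clip_loss_sketch`, the feature tensors, and the scale value of 100 are all illustrative assumptions, not part of PLIP or the Transformers API. The sketch uses integer class indices as targets and then checks that one-hot targets (an identity matrix) give the same value; passing class probabilities to `cross_entropy` assumes PyTorch 1.10 or later.

```python
import torch
import torch.nn.functional as F

def clip_loss_sketch(image_feats, text_feats, scale):
    """Symmetric contrastive loss for l2-normalized features of shape (N, D)."""
    logits = scale * image_feats @ text_feats.t()  # (N, N) similarity matrix
    labels = torch.arange(logits.shape[0])         # pair i <-> i sits on the diagonal
    i2t = F.cross_entropy(logits, labels)          # softmax over each row
    t2i = F.cross_entropy(logits.t(), labels)      # softmax over each column
    return (i2t + t2i) / 2

scale = 100.0  # arbitrary value chosen for illustration

# Orthonormal toy features: each row is a distinct unit vector.
feats = torch.eye(4, 16)
aligned_loss = clip_loss_sketch(feats, feats, scale)                    # pairs match: near zero
shuffled_loss = clip_loss_sketch(feats, feats.roll(1, dims=0), scale)   # pairs mismatched: large

# Index targets and one-hot targets yield the same loss.
torch.manual_seed(0)
rand_img = F.normalize(torch.randn(4, 16), dim=-1)
rand_txt = F.normalize(torch.randn(4, 16), dim=-1)
index_loss = clip_loss_sketch(rand_img, rand_txt, scale)
rand_logits = scale * rand_img @ rand_txt.t()
one_hot = torch.eye(4)
one_hot_loss = (F.cross_entropy(rand_logits, one_hot)
                + F.cross_entropy(rand_logits.t(), one_hot)) / 2

print(aligned_loss.item(), shuffled_loss.item())
print(index_loss.item(), one_hot_loss.item())
```

The loss is small when each image is most similar to its own text and grows when the best match lies off the diagonal.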

In [None]:
# First, let's load a few examples, one per category.
nc = Image.open("./local_data/datasets/SICAPv2/images/16B0028148_Block_Region_8_2_14_xini_27858_yini_55555.jpg") 
g3 = Image.open("./local_data/datasets/SICAPv2/images/18B0006169B_Block_Region_6_3_7_xini_33958_yini_90365.jpg") 
g4 = Image.open("./local_data/datasets/SICAPv2/images/17B0032153_Block_Region_10_13_17_xini_23105_yini_15687.jpg") 
g5 = Image.open("./local_data/datasets/SICAPv2/images/16B0008067_Block_Region_0_6_2_xini_10859_yini_103113.jpg") 

# Now, we can pre-process the images and concatenate them as a batched input.
image_inputs = processor.image_processor([nc, g3, g4, g5], return_tensors="pt")
print(image_inputs['pixel_values'].shape)

# Now, we can produce some naive text descriptions for the toy example, which would be "paired" with each image
prompt = ["Healthy prostate tissue",
          "Gleason grade 3",
          "Gleason grade 4",
          "Gleason grade 5"]

# Pre-process text data
text_inputs = processor.tokenizer(prompt, max_length = 77, padding=True, truncation=True, return_tensors="pt") 

# Forward representations into common space
with torch.no_grad():
    image_features = vision_encoder(image_inputs)
    text_features = text_encoder(text_inputs)

# Compute similarity matrix (matrix multiplication v(bs x D) @ t(D x bs) -> (bs x bs))
# Entry (i, j) of this matrix is the similarity between the i-th image and the j-th text.
# Ideally, we want large similarity on the diagonal (image i=0 with text j=0, that is, paired samples).
# Also, we want smaller similarity for the off-diagonal elements.
# The contrastive term does so in both directions: images to texts (i.e. per row) and texts to images (i.e. per column).
with torch.no_grad():
    sim = image_features @ text_features.t() * plip.logit_scale.exp()
print("Predicted similarity matrix")
print(sim)
print(" ")

# One-to-One Target
target = torch.eye(text_features.shape[0]).detach()  # Create target similarity matrix
print("Target similarity matrix")
print(target)
print(" ")

# Image-to-text loss:
# 1. Compute softmax over rows.
# 2. Apply cross-entropy, where the target for each image is the index of its paired text
#    (encoded here as a one-hot row of the identity matrix).
logits_per_image = sim
print("I2T Softmax:")
with torch.no_grad():
    print(str(torch.softmax(logits_per_image, dim=-1).numpy().round(2)))
    print(" ")
i2t_loss = torch.nn.functional.cross_entropy(logits_per_image, target)

# Text-to-image loss:
# 1. Compute softmax over columns.
# 2. Apply cross-entropy, where the target for each text is the index of its paired image
#    (encoded here as a one-hot row of the identity matrix).
logits_per_text = logits_per_image.t()
print("T2I Softmax:")
with torch.no_grad():
    print(str(torch.softmax(logits_per_text, dim=-1).numpy().round(2)))
    print(" ")
t2i_loss = torch.nn.functional.cross_entropy(logits_per_text, target)

# Overall clip loss
clip_loss = (i2t_loss + t2i_loss) / 2
print("CLIP Loss: " + str(clip_loss.item()))

# This loss is computed on mini-batches and is backpropagated through both encoders.
# During training, the encoders are optimized to minimize the CLIP loss, which aligns
# paired concepts from both modalities in the joint embedding space. It is worth mentioning
# that the temperature scaling parameter is also optimized during pre-training.
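
The optimization described in these comments can be illustrated with a minimal sketch. Everything here is a hypothetical stand-in: the two linear layers replace the real encoders, the inputs are random tensors, and the optimizer settings are arbitrary. Only the initial temperature of 0.07 follows the CLIP paper [1]. This is not the PLIP training recipe.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for the two encoders: one linear layer per modality.
toy_image_encoder = torch.nn.Linear(32, 16)
toy_text_encoder = torch.nn.Linear(24, 16)

# Learnable log-scale (inverse temperature), initialised to log(1 / 0.07).
toy_logit_scale = torch.nn.Parameter(torch.tensor(math.log(1 / 0.07)))
initial_scale = toy_logit_scale.item()

params = (list(toy_image_encoder.parameters())
          + list(toy_text_encoder.parameters())
          + [toy_logit_scale])
optimizer = torch.optim.Adam(params, lr=1e-2)

# Random placeholders for one mini-batch of 8 paired inputs.
toy_images = torch.randn(8, 32)
toy_texts = torch.randn(8, 24)
labels = torch.arange(8)  # pair i <-> i

losses = []
for step in range(50):
    img = F.normalize(toy_image_encoder(toy_images), dim=-1)
    txt = F.normalize(toy_text_encoder(toy_texts), dim=-1)
    logits = toy_logit_scale.exp() * img @ txt.t()
    loss = (F.cross_entropy(logits, labels)
            + F.cross_entropy(logits.t(), labels)) / 2
    optimizer.zero_grad()
    loss.backward()   # gradients reach both encoders and the scale
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

Repeating the same mini-batch only demonstrates that gradients flow to both encoders and to the scale parameter. Real pre-training iterates over many different batches.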

## **References**


[1] Radford, A., Kim, J.W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., Clark, J., Krueger, G., & Sutskever, I. (2021). Learning Transferable Visual Models From Natural Language Supervision. International Conference on Machine Learning. 

---