In [None]:
import os
from langchain_community.document_loaders import DirectoryLoader
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
import numpy as np
import pandas as pd
import nltk
nltk.download('punkt')

In [None]:
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path


from health_multimodal.image.model.pretrained import get_biovil_image_encoder
from health_multimodal.image.data.transforms import (
    create_chest_xray_transform_for_inference,
)

from health_multimodal.image import ImageInferenceEngine
from health_multimodal.text import TextInferenceEngine

from health_multimodal.text.utils import BertEncoderType, get_bert_inference


In [None]:
# Image preprocessing sizes (pixels); both are passed to
# create_chest_xray_transform_for_inference when the image engine is built.
RESIZE = 512
CENTER_CROP_SIZE = 512

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)



First, we can instantiate both the image and text inference engines from Microsoft's hi-ml library.

In [None]:
# BioViL image encoder, moved to the device selected in the config cell.
image_encoder = get_biovil_image_encoder().to(device)

# Inference-time chest X-ray preprocessing, sized by the config constants.
xray_transform = create_chest_xray_transform_for_inference(
    resize=RESIZE,
    center_crop_size=CENTER_CROP_SIZE,
)

# Image engine: pairs the encoder with its preprocessing transform.
img_model = ImageInferenceEngine(image_model=image_encoder, transform=xray_transform)

# Text engine built from the BIOVIL_T_BERT encoder type.
# NOTE(review): unlike the image encoder, no explicit .to(device) is applied here — confirm
# where get_bert_inference places the model.
text_model = get_bert_inference(BertEncoderType.BIOVIL_T_BERT)

For the images, we need to load and transform the image and then retrieve both the patch and global embeddings. 

I.e. for one image:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def repeated_arange(n, repeats):
    """Return the indices ``0 .. n-1`` with every index repeated ``repeats`` times in a row.

    Example: ``repeated_arange(3, 2)`` gives ``tensor([0, 0, 1, 1, 2, 2])``.
    """
    indices = torch.arange(n)
    return torch.repeat_interleave(indices, repeats)

First, we can write the contrastive loss function for the local-to-local alignment. This means that we compare every single image patch embedding with every single text embedding for one specific note-and-image pair.

In [None]:
class LocalContrastiveLoss(nn.Module):
    """Image-to-text contrastive loss between image patches and text prompts.

    Two (image, prompt list) pairs are treated as a batch of size two. Every patch
    embedding is compared with every prompt embedding of both pairs; the similarities
    are averaged over the prompts of each pair, and cross-entropy rewards each patch
    for being most similar to the prompts of its own pair.

    Args:
        image_model: Object exposing ``get_projected_patch_embeddings(path)`` that
            returns a 2-tuple whose first item is a 3-D tensor
            ``(n_patches_h, n_patches_w, embed_dim)``; the second item is ignored.
        text_model: Object exposing ``get_embeddings_from_prompt(prompts)`` that
            returns a 2-D tensor ``(num_text, embed_dim)``.
        temperature: Divisor applied to the similarities before the softmax.

    Note:
        Tensors are moved to the notebook-level ``device`` before the loss is computed.
    """

    def __init__(self, image_model, text_model, temperature=0.1):
        super(LocalContrastiveLoss, self).__init__()
        self.image_model = image_model
        self.text_model = text_model
        self.temperature = temperature

    def forward(self, image1, image2, queries1, queries2):
        """Compute the patch-to-text loss for two image/prompt pairs.

        Args:
            image1, image2: Paths (``str`` or ``Path``) of the two images.
            queries1, queries2: Prompt lists for the first and second image. Both
                lists must yield the same number of embeddings, because the two
                embedding tensors are combined with ``torch.stack``.

        Returns:
            Scalar tensor with the mean cross-entropy over all patches.
        """
        # Per-pair embeddings; the calls are kept in the order image1, text1, image2, text2.
        patch_embeddings1, _ = self.image_model.get_projected_patch_embeddings(Path(image1))
        text_embeddings1 = self.text_model.get_embeddings_from_prompt(queries1)
        patch_embeddings2, _ = self.image_model.get_projected_patch_embeddings(Path(image2))
        text_embeddings2 = self.text_model.get_embeddings_from_prompt(queries2)

        # Add the batch dimension: (2, h, w, dim) and (2, num_text, dim).
        patch_embeddings = torch.stack((patch_embeddings1, patch_embeddings2), dim=0)
        text_embeddings = torch.stack((text_embeddings1, text_embeddings2), dim=0)

        # Same device for both modalities, then unit-normalise so dot products are cosines.
        patch_embeddings = F.normalize(patch_embeddings.to(device), dim=-1)
        text_embeddings = F.normalize(text_embeddings.to(device), dim=-1)

        batch_size, n_patches_h, n_patches_w, embed_dim = patch_embeddings.shape
        _, num_text, text_dim = text_embeddings.shape
        num_patches = n_patches_h * n_patches_w

        # Flatten batch and patch/prompt axes. The embedding width is read from the
        # tensors instead of being hard-coded, so encoders of any width work.
        patch_flat = patch_embeddings.reshape(batch_size * num_patches, embed_dim)
        text_flat = text_embeddings.reshape(batch_size * num_text, text_dim)

        # (batch * num_patches, batch * num_text) temperature-scaled similarities.
        similarities = torch.einsum('pd,td->pt', patch_flat, text_flat) / self.temperature

        # Group the text axis by pair and average over prompts:
        # (batch * num_patches, batch, num_text) -> (batch * num_patches, batch).
        logits = similarities.reshape(batch_size * num_patches, batch_size, num_text).mean(dim=-1)

        # Each patch's target class is the index of the pair it came from.
        targets = torch.arange(batch_size, device=logits.device).repeat_interleave(num_patches)

        return F.cross_entropy(logits, targets)


In [None]:
# First image/prompt pair: prompts stating that pneumonia is present.
# NOTE(review): 'penumonia' and 'peumonia' look like misspellings of 'pneumonia' —
# confirm whether they are intentional before relying on these prompts.
image1 = '/opt/gpudata/mimic-cxr/jpg/p10/p10011365/s53459647/f6fccc21-ded29731-2a7419a6-961566fe-710630d3.jpg'
queries1 = [
    'penumonia',
    'pneumonia seen',
    'pneumonia is seen',
    'peumonia present',
]

# Second image/prompt pair: mostly negated prompts. Both lists hold four prompts.
image2 = '/opt/gpudata/mimic-cxr/jpg/p10/p10000980/s50985099/6ad03ed1-97ee17ee-9cf8b320-f7011003-cd93b42d.jpg'
queries2 = [
    'penumonia',
    'pneumonia not seen',
    'pneumonia is not seen',
    'peumonia not present',
]

In [None]:
# Build the loss module around the two inference engines and evaluate it once
# on the sample image/prompt pairs defined above.
model = LocalContrastiveLoss(image_model=img_model, text_model=text_model)
model = model.to(device)
loss = model(image1, image2, queries1, queries2)
print(loss)

In [None]:
class LocalContrastiveLoss(nn.Module):
    """Symmetric contrastive loss between image patches and text prompts.

    Two (image, prompt list) pairs are treated as a batch of size two. The loss is
    the average of two cross-entropy terms:

    * image-to-text: each patch should be most similar, on average, to the prompts
      of its own pair;
    * text-to-image: each prompt should be most similar, on average, to the patches
      of its own pair.

    Args:
        image_model: Object exposing ``get_projected_patch_embeddings(path)`` that
            returns a 2-tuple whose first item is a 3-D tensor
            ``(n_patches_h, n_patches_w, embed_dim)``; the second item is ignored.
        text_model: Object exposing ``get_embeddings_from_prompt(prompts)`` that
            returns a 2-D tensor ``(num_text, embed_dim)``.
        temperature: Divisor applied to the similarities before the softmax.

    Note:
        Tensors are moved to the notebook-level ``device`` before the loss is computed.
    """

    def __init__(self, image_model, text_model, temperature=0.1):
        super(LocalContrastiveLoss, self).__init__()
        self.image_model = image_model
        self.text_model = text_model
        self.temperature = temperature

    def forward(self, image1, image2, queries1, queries2):
        """Compute the symmetric patch/text loss for two image/prompt pairs.

        Args:
            image1, image2: Paths (``str`` or ``Path``) of the two images.
            queries1, queries2: Prompt lists for the first and second image. Both
                lists must yield the same number of embeddings, because the two
                embedding tensors are combined with ``torch.stack``.

        Returns:
            Scalar tensor: mean of the image-to-text and text-to-image losses.
        """
        # Per-pair embeddings; the calls are kept in the order image1, text1, image2, text2.
        patch_embeddings1, _ = self.image_model.get_projected_patch_embeddings(Path(image1))
        text_embeddings1 = self.text_model.get_embeddings_from_prompt(queries1)
        patch_embeddings2, _ = self.image_model.get_projected_patch_embeddings(Path(image2))
        text_embeddings2 = self.text_model.get_embeddings_from_prompt(queries2)

        # Add the batch dimension: (2, h, w, dim) and (2, num_text, dim).
        patch_embeddings = torch.stack((patch_embeddings1, patch_embeddings2), dim=0)
        text_embeddings = torch.stack((text_embeddings1, text_embeddings2), dim=0)

        # Same device for both modalities, then unit-normalise so dot products are cosines.
        patch_embeddings = F.normalize(patch_embeddings.to(device), dim=-1)
        text_embeddings = F.normalize(text_embeddings.to(device), dim=-1)

        batch_size, n_patches_h, n_patches_w, embed_dim = patch_embeddings.shape
        _, num_text, text_dim = text_embeddings.shape
        num_patches = n_patches_h * n_patches_w

        # Flatten batch and patch/prompt axes. The embedding width is read from the
        # tensors instead of being hard-coded, so encoders of any width work.
        patch_flat = patch_embeddings.reshape(batch_size * num_patches, embed_dim)
        text_flat = text_embeddings.reshape(batch_size * num_text, text_dim)

        # (batch * num_patches, batch * num_text); the text-to-patch matrix is its transpose.
        patch_to_text = torch.einsum('pd,td->pt', patch_flat, text_flat) / self.temperature
        text_to_patch = patch_to_text.t()

        # Image-to-text: group the text axis by pair, average over prompts
        # -> (batch * num_patches, batch).
        logits_p = patch_to_text.reshape(batch_size * num_patches, batch_size, num_text).mean(dim=-1)

        # Text-to-image: group the patch axis by pair, average over patches
        # -> (batch * num_text, batch). The group size here must be num_patches
        # (not num_text), otherwise the columns no longer correspond to pairs.
        logits_t = text_to_patch.reshape(batch_size * num_text, batch_size, num_patches).mean(dim=-1)

        # Target class of every row is the index of the pair it came from.
        targets_p = torch.arange(batch_size, device=logits_p.device).repeat_interleave(num_patches)
        targets_t = torch.arange(batch_size, device=logits_t.device).repeat_interleave(num_text)

        loss_i_to_t = F.cross_entropy(logits_p, targets_p)
        loss_t_to_i = F.cross_entropy(logits_t, targets_t)

        return (loss_i_to_t + loss_t_to_i) / 2


Make the function generalizable.

In [None]:
class LocalContrastiveLoss(nn.Module):
    """Average of eight InfoNCE-style losses between global and local image/text embeddings.

    The loss combines every direction of local-to-local, local-to-global,
    global-to-local and global-to-global alignment between image embeddings,
    patch embeddings, text embeddings and text-chunk embeddings.

    Args:
        image_model: Stored on the module; not used by the methods of this class.
        text_model: Stored on the module; not used by the methods of this class.
        temperature: Divisor applied to the similarities before the softmax.

    Note:
        ``forward`` moves its inputs to the notebook-level ``device``.
    """

    def __init__(self, image_model, text_model, temperature=1.0):
        super(LocalContrastiveLoss, self).__init__()
        self.image_model = image_model
        self.text_model = text_model
        self.temperature = temperature

    def info_nce(self, first_embeddings, second_embeddings):
        """Cross-entropy of ``first`` items against per-sample averaged similarities to ``second``.

        Args:
            first_embeddings: Tensor ``(batch, first_num, dim)``.
            second_embeddings: Tensor ``(batch, second_num, dim)``.

        Returns:
            Scalar tensor. Every item of ``first_embeddings`` is scored against each
            batch element of ``second_embeddings`` (mean similarity over that
            element's ``second_num`` items); the target is the item's own batch index.

        Raises:
            ValueError: If the two inputs disagree on batch size or embedding width.
        """
        batch_size, first_num, dim = first_embeddings.shape
        second_batch_size, second_num, second_dim = second_embeddings.shape

        # Validate up front; otherwise a mismatch only surfaces later as an opaque shape error.
        if second_batch_size != batch_size:
            raise ValueError(
                f"Batch size mismatch: {batch_size} vs {second_batch_size}"
            )
        if second_dim != dim:
            raise ValueError(
                f"Embedding dimension mismatch: {dim} vs {second_dim}"
            )

        # Flatten batch and item axes; reshape (unlike view) also accepts non-contiguous inputs.
        first_flat = first_embeddings.reshape(batch_size * first_num, dim)
        second_flat = second_embeddings.reshape(batch_size * second_num, dim)

        # (batch * first_num, batch * second_num) temperature-scaled similarities.
        similarities = torch.einsum('fd,sd->fs', first_flat, second_flat) / self.temperature

        # Group the second axis by batch element and average within each group:
        # (batch * first_num, batch, second_num) -> (batch * first_num, batch).
        logits = similarities.reshape(batch_size * first_num, batch_size, second_num).mean(dim=-1)

        # Targets live on the same device as the logits, whatever device the inputs were on.
        targets = torch.arange(batch_size, device=logits.device).repeat_interleave(first_num)

        return F.cross_entropy(logits, targets)

    def forward(self, image_embeddings, patch_embeddings, text_embeddings, chunk_embeddings):
        """Combine all eight global/local alignment losses.

        Args:
            image_embeddings: Tensor ``(batch, n, dim)`` of global image embeddings.
            patch_embeddings: Tensor ``(batch, n_patches_h, n_patches_w, dim)``.
            text_embeddings: Tensor ``(batch, n, dim)`` of global text embeddings.
            chunk_embeddings: Tensor ``(batch, n_chunks, dim)`` of local text embeddings.

        Returns:
            Scalar tensor: the mean of the eight directional losses.
        """
        # Normalize and ensure all tensors are on the same device.
        image_embeddings = F.normalize(image_embeddings.to(device), dim=-1)
        patch_embeddings = F.normalize(patch_embeddings.to(device), dim=-1)
        text_embeddings = F.normalize(text_embeddings.to(device), dim=-1)
        chunk_embeddings = F.normalize(chunk_embeddings.to(device), dim=-1)

        # Flatten the spatial grid of patches into one axis: (batch, n_patches, dim).
        batch_size, n_patches_h, n_patches_w, embed_dim = patch_embeddings.shape
        patch_embeddings = patch_embeddings.reshape(batch_size, n_patches_h * n_patches_w, embed_dim)

        loss_local_i_to_local_t = self.info_nce(patch_embeddings, chunk_embeddings)
        loss_local_t_to_local_i = self.info_nce(chunk_embeddings, patch_embeddings)

        loss_local_i_to_global_t = self.info_nce(patch_embeddings, text_embeddings)
        loss_local_t_to_global_i = self.info_nce(chunk_embeddings, image_embeddings)
        loss_global_i_to_local_t = self.info_nce(image_embeddings, chunk_embeddings)
        loss_global_t_to_local_i = self.info_nce(text_embeddings, patch_embeddings)

        loss_global_t_to_global_i = self.info_nce(text_embeddings, image_embeddings)
        loss_global_i_to_global_t = self.info_nce(image_embeddings, text_embeddings)

        # Combined loss: unweighted mean of all eight directions.
        combined_loss = (
            loss_local_i_to_local_t + loss_local_t_to_local_i
            + loss_local_i_to_global_t + loss_local_t_to_global_i
            + loss_global_i_to_local_t + loss_global_t_to_local_i
            + loss_global_t_to_global_i + loss_global_i_to_global_t
        ) / 8

        return combined_loss


In [None]:
# Dummy inputs for the generalized loss: all-zero image-side tensors and
# all-one text-side tensors, each with a leading batch axis of 2 and width 128.
test_image_batch = torch.zeros((2, 1, 128))      # one global image embedding per sample
test_patch_batch = torch.zeros((2, 2, 9, 128))   # 2 x 9 grid of patch embeddings per sample
test_text_batch = torch.ones((2, 1, 128))        # one global text embedding per sample
test_chunks_batch = torch.ones((2, 2, 128))      # two text-chunk embeddings per sample

In [None]:
# Instantiate the generalized loss and evaluate it once on the dummy batches.
gen_model = LocalContrastiveLoss(image_model=img_model, text_model=text_model)
gen_model = gen_model.to(device)
loss = gen_model(
    test_image_batch,
    test_patch_batch,
    test_text_batch,
    test_chunks_batch,
)
print(loss)