In [15]:
import pandas as pd
import sys
import os
import numpy as np
import base64
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import matplotlib.pyplot as plt

sys.path.append("/mnt/data2/datasets_lfay/MedImageInsights")
from MedImageInsight.medimageinsightmodel import MedImageInsight
sys.path.append("/mnt/data2/datasets_lfay/MedImageInsights/predictions")
from utils import read_image, zero_shot_prediction, extract_findings_and_impressions, create_wandb_run_name, balance_dataset

### Prompt shuffeling

In [16]:
prompt_items = ["chest", "x-ray", "anteroposterior"]

# create random prompts by shuffling the prompt items. The word "Pneumonia" is always included in the prompt, positioned randomly
templates_pneumonia = ["Pneumonia"]
for i in range(1000):
    # select random amount of prompt items
    n = np.random.randint(1, 4)
    prompt = np.random.choice(prompt_items, n, replace=False)
    # add "Pneumonia" to the prompt
    prompt = np.insert(prompt, np.random.randint(0, n+1), "Pneumonia")
    prompt = " ".join(prompt)
    templates_pneumonia.append(prompt)

print(len(templates_pneumonia))
print(len(set(templates_pneumonia)))

templates_no_finding = ["No Finding"]
for i in range(1000):
    # select random amount of prompt items
    n = np.random.randint(1, 4)
    prompt = np.random.choice(prompt_items, n, replace=False)
    # add "Pneumonia" to the prompt
    prompt = np.insert(prompt, np.random.randint(0, n+1), "No Finding")
    prompt = " ".join(prompt)
    templates_no_finding.append(prompt)

print(len(templates_no_finding))


1001
49
1001


In [17]:
prompts_disease = list(set(templates_pneumonia))
prompts_no_disease = list(set(templates_no_finding))

In [18]:
classifier = MedImageInsight(
    model_dir="/mnt/data2/datasets_lfay/MedImageInsights/MedImageInsight/2024.09.27",
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth"
)
classifier.load_model()

Model loaded successfully on device: cuda


In [19]:
embeddings_disease = classifier.encode(texts=prompts_disease)["text_embeddings"]
embeddings_no_disease = classifier.encode(texts=prompts_no_disease)["text_embeddings"]

print(embeddings_disease.shape)
print(embeddings_no_disease.shape)

(49, 1024)
(49, 1024)


In [20]:
dataset = "MIMIC"
if dataset =="MIMIC":
    read_path = "/mnt/data2/datasets_lfay/MedImageInsights/data/MIMIC-v1.0-512"
elif dataset == "CheXpert":
    read_path = "/mnt/data2/datasets_lfay/MedImageInsights/data/CheXpert-v1.0-512"
elif dataset == "VinDr":
    read_path = "/mnt/data2/datasets_lfay/MedImageInsights/data/vindr-pcxr"

df_train = pd.read_csv(read_path + "/train.csv")
df_train = df_train[(df_train["No Finding"] == 1) | (df_train["Pneumonia"] == 1)]
df_train = balance_dataset(df_train, "Pneumonia", 1, True)
len(df_train)
print(dataset)

MIMIC


In [21]:
df_disease = df_train[df_train.Pneumonia==1]
print(len(df_disease))
df_no_disease = df_train[df_train.Pneumonia==0]
print(len(df_no_disease))

# get embeddings for the images from the last 1024 columns of df_train
image_embeddings_disease = df_disease.iloc[:, -1024:].values
image_embeddings_no_disease = df_no_disease.iloc[:, -1024:].values

print(image_embeddings_disease.shape)
print(image_embeddings_no_disease.shape)


372
372
(372, 1024)
(372, 1024)


In [22]:
# compute cosine similarity between each image embedding and each text embedding
cosine_similarities_disease = cosine_similarity(image_embeddings_disease, embeddings_disease)
cosine_similarities_no_disease = cosine_similarity(image_embeddings_no_disease, embeddings_no_disease)

print(cosine_similarities_disease.shape)
print(cosine_similarities_no_disease.shape)

# get the overall top k most similar texts for each image
k = 5
top_k_disease = np.argsort(cosine_similarities_disease, axis=1)[:, -k:][:, ::-1]
top_k_no_disease = np.argsort(cosine_similarities_no_disease, axis=1)[:, -k:][:, ::-1]

print(top_k_disease.shape)
print(top_k_no_disease.shape)


(372, 49)
(372, 49)
(372, 5)
(372, 5)


In [23]:
unique_k_disease, counts_k_disease = np.unique(top_k_disease, return_counts=True)
unique_k_no_disease, counts_k_no_disease = np.unique(top_k_no_disease, return_counts=True)

# Sort by counts in descending order
sorted_k_disease = np.argsort(-counts_k_disease)  # Negative sign for descending sort
sorted_k_no_disease = np.argsort(-counts_k_no_disease)  # Negative sign for descending sort

sorted_values_k_disease = unique_k_disease[sorted_k_disease]
sorted_counts_k_disease = counts_k_disease[sorted_k_disease]

sorted_values_k_no_disease = unique_k_no_disease[sorted_k_no_disease]
sorted_counts_k_no_disease = counts_k_no_disease[sorted_k_no_disease]

print(sorted_values_k_disease)
print(sorted_values_k_no_disease)


[10 30 42 11 27 37  0 32 26 46  6  9 21  1 40 22  8  2 12 17]
[17 46 31 37  6 34 39 33 13 20 45 25 23  4  1 14 12 48 27]


In [24]:
# extract elements from the list of prompts at index sorted_values[:10]
top_k_texts_disease=[prompts_disease[i] for i in sorted_values_k_disease[:10]]
top_k_texts_no_disease=[prompts_no_disease[i] for i in sorted_values_k_no_disease[:10]]

print(top_k_texts_disease)
print(top_k_texts_no_disease)

['x-ray Pneumonia anteroposterior chest', 'x-ray chest Pneumonia anteroposterior', 'x-ray chest anteroposterior Pneumonia', 'x-ray anteroposterior chest Pneumonia', 'x-ray Pneumonia chest anteroposterior', 'x-ray anteroposterior Pneumonia chest', 'chest Pneumonia x-ray anteroposterior', 'Pneumonia x-ray chest anteroposterior', 'chest x-ray Pneumonia anteroposterior', 'Pneumonia x-ray anteroposterior chest']
['anteroposterior No Finding chest', 'x-ray anteroposterior No Finding chest', 'x-ray No Finding anteroposterior chest', 'No Finding chest anteroposterior x-ray', 'anteroposterior chest No Finding', 'No Finding anteroposterior x-ray', 'No Finding anteroposterior chest x-ray', 'No Finding anteroposterior', 'x-ray anteroposterior chest No Finding', 'chest anteroposterior No Finding']


In [25]:
[print(i) for i in top_k_texts_disease]
print("***"*20)
[print(i) for i in top_k_texts_no_disease]
print("***"*20)

x-ray Pneumonia anteroposterior chest
x-ray chest Pneumonia anteroposterior
x-ray chest anteroposterior Pneumonia
x-ray anteroposterior chest Pneumonia
x-ray Pneumonia chest anteroposterior
x-ray anteroposterior Pneumonia chest
chest Pneumonia x-ray anteroposterior
Pneumonia x-ray chest anteroposterior
chest x-ray Pneumonia anteroposterior
Pneumonia x-ray anteroposterior chest
************************************************************
anteroposterior No Finding chest
x-ray anteroposterior No Finding chest
x-ray No Finding anteroposterior chest
No Finding chest anteroposterior x-ray
anteroposterior chest No Finding
No Finding anteroposterior x-ray
No Finding anteroposterior chest x-ray
No Finding anteroposterior
x-ray anteroposterior chest No Finding
chest anteroposterior No Finding
************************************************************


In [26]:
# Create text embeddings for the top k texts

embeddings_top_k_disease = classifier.encode(texts=top_k_texts_disease)["text_embeddings"]
embeddings_top_k_no_disease = classifier.encode(texts=top_k_texts_no_disease)["text_embeddings"]

print(embeddings_top_k_disease.shape)
print(embeddings_top_k_no_disease.shape)

(10, 1024)
(10, 1024)


In [27]:
# genearte averaged embeddings for the top k texts
average_embeddings_top_k_disease = np.mean(embeddings_top_k_disease, axis=0)
average_embeddings_top_k_no_disease = np.mean(embeddings_top_k_no_disease, axis=0)

print(average_embeddings_top_k_disease.shape)
print(average_embeddings_top_k_no_disease.shape)

averaged_embeddings = np.vstack([average_embeddings_top_k_no_disease, average_embeddings_top_k_disease])
print(averaged_embeddings.shape)

(1024,)
(1024,)
(2, 1024)


In [28]:
np.save("/mnt/data2/datasets_lfay/MedImageInsights/data/text_embeddings"+"/"+dataset+"/filtered_averaged_embeddings_"+dataset+".npy", averaged_embeddings)