In [1]:
import pandas as pd
import sys
import os
import numpy as np
import base64
from sklearn.metrics.pairwise import cosine_similarity
from PIL import Image
import matplotlib.pyplot as plt

sys.path.append("/mnt/data2/datasets_lfay/MedImageInsights")
from MedImageInsight.medimageinsightmodel import MedImageInsight
sys.path.append("/mnt/data2/datasets_lfay/MedImageInsights/predictions")
from utils import read_image, zero_shot_prediction, extract_findings_and_impressions, create_wandb_run_name


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
classes = ["pneumonia", "no finding"]
views = ["frontal", "lateral"]

# Prompt templates
prompt_templates = [
    "",
    "chest ",
    "x-ray ",
    "chest x-ray ",
    "chest x-ray anteroposterior ",
    "chest x-ray posteroanterior "
]


In [3]:
# generate prompts for each class

template_disease = [p+ "Pneumonia" for p in prompt_templates]
template_no_disease = [p + "No Pneumonia" for p in prompt_templates]

In [4]:
PATH_TO_DATA = "/mnt/data2/datasets_lfay/MedImageInsights/data"
read_path = PATH_TO_DATA+"/MIMIC-v1.0-512/"

df_train = pd.read_csv(read_path + "train.csv")

df_disease = df_train[df_train["Pneumonia"] == 1]
df_no_disease = df_train[df_train["Pneumonia"] == 0]

In [5]:
# Extract full report
report_disease = df_disease["report"].copy()
report_disease = report_disease.dropna()
report_disease = report_disease.sample(10, random_state=42).tolist()

report_no_disease = df_no_disease["report"].copy()
report_no_disease = report_no_disease.dropna()
report_no_disease = report_no_disease.sample(10, random_state=42).tolist()

In [6]:
# Extract Findings 
findings_disease = df_disease.section_findings.copy()
findings_disease = findings_disease.dropna().reset_index(drop=True)
findings_disease = findings_disease.sample(10, random_state=42).tolist()

findings_no_disease = df_no_disease.section_findings.copy()
findings_no_disease = findings_no_disease.dropna().reset_index(drop=True)
findings_no_disease = findings_no_disease.sample(10, random_state=42).tolist()

In [7]:
# Extract Impression
impression_disease = df_disease.section_impression.copy()
impression_disease = impression_disease.dropna().reset_index(drop=True)
impression_disease = impression_disease.sample(10, random_state=42).tolist()

impression_no_disease = df_no_disease.section_impression.copy()
impression_no_disease = impression_no_disease.dropna().reset_index(drop=True)
impression_no_disease = impression_no_disease.sample(10, random_state=42).tolist()

In [8]:
df_findings_impression_disease = df_disease[["section_findings", "section_impression"]].copy()
df_findings_impression_disease.dropna(inplace=True)

findings_impression_disease = (
    df_findings_impression_disease.section_findings.str.cat(
        df_findings_impression_disease.section_impression, sep=" ", na_rep=""
    )
    .sample(10, random_state=42).tolist()
)

df_findings_impression_no_disease = df_no_disease[["section_findings", "section_impression"]].copy()
df_findings_impression_no_disease.dropna(inplace=True)

findings_impression_no_disease = (
    df_findings_impression_no_disease.section_findings.str.cat(
        df_findings_impression_no_disease.section_impression, sep=" ", na_rep=""
    )
    .sample(10, random_state=42).tolist()
)

In [9]:
prompts_disease = template_disease + report_disease + findings_disease + impression_disease + findings_impression_disease
prompts_no_disease = template_no_disease + report_no_disease + findings_no_disease + impression_no_disease + findings_impression_no_disease

print (len(prompts_disease))
print (len(prompts_no_disease))

46
46


In [10]:
prompts_disease

['Pneumonia',
 'chest Pneumonia',
 'x-ray Pneumonia',
 'chest x-ray Pneumonia',
 'chest x-ray anteroposterior Pneumonia',
 'chest x-ray posteroanterior Pneumonia',
 '                                 FINAL REPORT\n CHEST RADIOGRAPH.\n \n INDICATION:  Evaluation for interval change in pulmonary edema.\n \n COMPARISON:  ___.\n \n FINDINGS:  As compared to the previous radiograph, there is no relevant\n change.  The pleural effusions have slightly increased.  The signs indicative\n of pulmonary edema are stable.  Stable is the moderate cardiomegaly.  Areas of\n atelectasis are unchanged.  No evidence of new parenchymal opacity suggesting\n pneumonia.\n',
 '                                 FINAL REPORT\n AP CHEST, 10:18 A.M., ___ \n \n HISTORY:  Check Dobbhoff tube placement.\n \n IMPRESSION:  AP chest compared to ___, 7:32 a.m.:\n \n Dobbhoff feeding tube with a wire stylet in place ends in the upper stomach. \n Tracheostomy tube is still canted and should be evaluated clinically to see if

In [11]:
## 3. Initialize model
classifier = MedImageInsight(
    model_dir="/mnt/data2/datasets_lfay/MedImageInsights/MedImageInsight/2024.09.27",
    vision_model_name="medimageinsigt-v1.0.0.pt",
    language_model_name="language_model.pth"
)
classifier.load_model()


Model loaded successfully on device: cuda


In [12]:
embeddings_disease = classifier.encode(texts=prompts_disease)["text_embeddings"]
embeddings_no_disease = classifier.encode(texts=prompts_no_disease)["text_embeddings"]
embeddings = np.concatenate([embeddings_no_disease, embeddings_disease], axis=0)
print(embeddings_disease.shape)
print(embeddings_no_disease.shape)
print(embeddings.shape)

(46, 1024)
(46, 1024)
(92, 1024)


In [13]:
# save embeddings
np.save("/mnt/data2/datasets_lfay/MedImageInsights/data/text_embeddings/embeddings_92.npy", embeddings)

In [14]:
# Step 2: Compute average embeddings for each class
average_embedding_disease = np.mean(embeddings_disease, axis=0)
average_embedding_no_disease = np.mean(embeddings_no_disease, axis=0)

# concat embeddings
average_embedding = np.stack([average_embedding_disease, average_embedding_no_disease], axis=0)
print(average_embedding_disease.shape)
print(average_embedding_no_disease.shape)
print(average_embedding.shape)

(1024,)
(1024,)
(2, 1024)


In [15]:
# save average embeddings
np.save("/mnt/data2/datasets_lfay/MedImageInsights/data/text_embeddings/average_embeddings_2.npy", average_embedding)

In [16]:
path_to_img =PATH_TO_DATA+df_train["Path"].values[0]
image_base64 = base64.encodebytes(read_image(path_to_img)).decode("utf-8")
image_embedding = classifier.encode(images=[image_base64])["image_embeddings"]
image_embedding = image_embedding.reshape(1, -1)


In [17]:
# Compute cosine similarity
similarities = cosine_similarity(image_embedding, average_embedding)
print(similarities)
# Predict class based on maximum similarity
predicted_class_idx = np.argmax(similarities)
predicted_class = "Disease" if predicted_class_idx == 0 else "No Disease"
print(f"Predicted Class: {predicted_class}")

# Apply softmax to convert similarities to probabilities
def softmax(x):
    e_x = np.exp(x - np.max(x))  # Stability trick to prevent overflow
    return e_x / e_x.sum(axis=1, keepdims=True)

# Compute probabilities
probabilities = softmax(similarities)
print(f"Probability of Disease: {probabilities[0][0]:.4f}")
print(f"Probability of No Disease: {probabilities[0][1]:.4f}")
print(df_train['Pneumonia'].values[0])
gt = "Disease" if df_train['Pneumonia'].values[0]==1 else "NoDisease"
print(f"Ground Truth: {gt}")



[[0.30274254 0.26232612]]
Predicted Class: Disease
Probability of Disease: 0.5101
Probability of No Disease: 0.4899
0.0
Ground Truth: NoDisease


In [18]:
# Compute cosine similarity
similarities = cosine_similarity(image_embedding,embeddings)
# Predict class based on maximum similarity
predicted_class_idx = np.argmax(similarities)
predicted_class = "Disease" if predicted_class_idx == 0 else "No Disease"
print(f"Predicted Class: {predicted_class}")

# Compute probabilities
probabilities = softmax(similarities)

# get k=5 max similarities and their indices
max_indices = np.argsort(similarities[0])[::-1][:5]
max_similarities = similarities[0][max_indices]

# if most of k=5 indices are < len(embeddings_disease)/0 then it is disease else no disease
if len([i for i in max_indices if i < len(embeddings_disease)]) >= 3:
    predicted_class = "Disease"
else:
    predicted_class = "No Disease"

print(f"Predicted Class: {predicted_class}")

Predicted Class: No Disease
Predicted Class: No Disease


# Prediction with augmentation

In [19]:
def chest_xray_augmentations(image, num_views=3):
    """
    Generate a specified number of augmented views for a chest X-ray image 
    while preserving the original shape.

    Args:
        image (PIL.Image): Input chest X-ray image (grayscale).
        num_views (int): Number of augmented views to generate.

    Returns:
        List[PIL.Image]: A list of augmented views with the same shape as the original.
    """
    # Define shape-preserving augmentations
    augmentation_transforms = [
        transforms.RandomRotation(degrees=10, fill=0),  # Small rotation with padding
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), fill=0),  # Translation
        transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Brightness & contrast
        transforms.GaussianBlur(kernel_size=(5, 5)),  # Simulate blurring
        transforms.Lambda(lambda img: ImageOps.autocontrast(img)),  # Enhance contrast
    ]

    # Generate augmented views
    augmented_views = []
    for _ in range(num_views):
        # Randomly select and apply augmentations
        augment_pipeline = transforms.Compose(random.sample(augmentation_transforms, k=2))
        augmented_view = augment_pipeline(image)
        augmented_views.append(augmented_view)

    return augmented_views

In [20]:
def images_are_different(image_list):
    """
    Check if all images in the list are different based on pixel-wise difference.

    Args:
        image_list (List[PIL.Image]): List of images to compare.

    Returns:
        bool: True if all images are unique, False otherwise.
    """
    hashes = [np.array(img).sum() for img in image_list]  # Simplified uniqueness check
    return len(hashes) == len(set(hashes))

In [21]:
from PIL import Image, ImageOps
from torchvision import transforms
from io import BytesIO
import base64
import random

def chest_xray_augmentations(image, num_views=3):
    """
    Generate a specified number of augmented views for a chest X-ray image.

    Args:
        image (PIL.Image): Input chest X-ray image (grayscale).
        num_views (int): Number of augmented views to generate.

    Returns:
        List[PIL.Image]: A list of augmented views.
    """
    # Define shape-preserving augmentations
    augmentation_transforms = [
        transforms.RandomRotation(degrees=10, fill=0),  # Small rotation
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), fill=0),  # Translation
        transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Brightness & contrast
        transforms.GaussianBlur(kernel_size=(5, 5)),  # Simulate blurring
        transforms.Lambda(lambda img: ImageOps.autocontrast(img)),  # Enhance contrast
    ]

    augmented_views = []
    for _ in range(num_views):
        # Randomly select and apply augmentations
        augment_pipeline = transforms.Compose(random.sample(augmentation_transforms, k=2))
        augmented_view = augment_pipeline(image)
        augmented_views.append(augmented_view)

    return augmented_views

def augment_image_to_base64(image, num_views=3):
    """
    Generate base64-encoded strings for augmented views of the input image.

    Args:
        image (PIL.Image): Input image.
        num_views (int): Number of augmented views to generate.

    Returns:
        List[str]: Base64-encoded strings of augmented images.
    """
    augmented_views = chest_xray_augmentations(image, num_views=num_views)
    all_views = [image] + augmented_views
    print(images_are_different(augmented_views))
    base64_encoded_images = []

    for augmented_image in all_views:
        # Save augmented image to an in-memory buffer as .jpg
        buffer = BytesIO()
        augmented_image.save(buffer, format="JPEG")
        buffer.seek(0)  # Move to the start of the buffer

        # Encode to base64
        base64_image = base64.encodebytes(buffer.read()).decode("utf-8")
        base64_encoded_images.append(base64_image)

        buffer.close()

    return base64_encoded_images

# Example Usage
original_image = Image.open(path_to_img).convert("L")  # Convert to grayscale

# Generate base64-encoded augmented images
aug_image_base64_list = augment_image_to_base64(original_image, num_views=63)

len(aug_image_base64_list)


True


64

In [23]:
aug_image_base64_list

['/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0a\nHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAInAgABAREA/8QAHwAAAQUBAQEB\nAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1Fh\nByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZ\nWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXG\nx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+iiiijFamg6Be+I9R\nFjYeV5xUt+9kCjH9fwrr5vhZf6RqWltqbfaNOuJliuJLTOYixwM5HTpzW38Rvh/pWheFUu9KtTG9\nvKPNdnLMynjnPvivIqKWilxTgKeF9qnjUeX0HWrCRccJn8KtIpEeGjHtxUjqjQI2xeuDxTVVFzuV\ncEelU32GEYXlW649f/1Ves441gcsoJxxx3qCcxZAKr9cVEUjdd20D6CgRIoyUX8qhdFyflGPpVcq\nvYUq7WwpUflUvlpjoPyo8pGXoOKikWOM/dzUZ2NwF5pPLFJ5YpwjWlEIPenCBaXyBjgDNRmMA4xT\nSgHakKj0pu3mrENjLMeFwPU1ZW0iiBBAdvemG3B6IM/Sj7OqYBRSTTmtkAyUGT0GKrtBhj8vH0pr\nxoo6c1Ht/wBml8vvijyT2FPFq5GduPrQbVgO1NMDD+Gm+UfSkK+1Nx7UbfagL7U7y2IziozRRRRW\n3a+FdSnt0uZ/s9jbyDKS3sywhx6qDyR9BUzeD76TIsbzTdQkH/LK0u1dz9FOCf

In [22]:
image_embeddings = classifier.encode(images=aug_image_base64_list)["image_embeddings"]
image_embeddings.shape

OutOfMemoryError: CUDA out of memory. Tried to allocate 3.96 GiB. GPU 0 has a total capacity of 23.59 GiB of which 3.38 GiB is free. Process 687943 has 8.28 GiB memory in use. Including non-PyTorch memory, this process has 11.72 GiB memory in use. Of the allocated memory 11.26 GiB is allocated by PyTorch, and 162.78 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Confidence Score


In [None]:
# Generate augmented embeddings
augmented_embeddings = []
for aug_image_base64 in aug_image_base64_list:
    embedding = classifier.encode(images=[aug_image_base64])["image_embeddings"]
    augmented_embeddings.append(embedding[0])  # Extract single embedding

# Convert to NumPy array
augmented_embeddings = np.array(augmented_embeddings)
augmented_embeddings.shape

(64, 1024)

: 

: 

In [None]:
aug_image_base64_list

NameError: name 'aug_image_base64_list' is not defined

: 

In [None]:
similarities = cosine_similarity(augmented_embeddings, average_embedding)  # Shape: (B, K)


(64, 2)

: 

In [None]:
def select_confident_samples(logits, top):
    batch_entropy = -(logits.softmax(1) * logits.log_softmax(1)).sum(1)
    idx = torch.argsort(batch_entropy, descending=False)[:int(batch_entropy.size()[0] * top)]
    return idx

: 

In [None]:
# Step 2: Compute entropy for confidence-based filtering
logits = torch.tensor(similarities)  # Convert similarities to PyTorch tensor

# Select the top 10% most confident samples
top_confidence_ratio = 0.1  # 10%
confident_indices = select_confident_samples(logits, top_confidence_ratio)
print(confident_indices) # Indices of the most confident samples

# Extract the embeddings of the most confident samples
filtered_embeddings = augmented_embeddings_torch[confident_indices]

filtered_embeddings.shape


tensor([11, 17, 20, 50, 42,  5])


torch.Size([6, 1024])

: 

In [None]:
# Step 3: Compute the averaged embedding
average_filtered_embedding = filtered_embeddings.mean(dim=0).numpy()

# Print results
print(f"Shape of Averaged Filtered Embedding: {average_filtered_embedding.shape}")


Shape of Averaged Filtered Embedding: (1024,)


: 

In [None]:
# Compute similarity scores for classification
final_scores = cosine_similarity([average_filtered_embedding], average_embedding)  # Shape: (1, K)

# Apply softmax for probabilities as torch tensor
final_probabilities = torch.nn.functional.softmax(torch.tensor(final_scores), dim=1)

# Predicted class
predicted_class = torch.argmax(final_probabilities).item()
print(f"Predicted Class: {predicted_class}")
print(f"Probability of Disease: {final_probabilities} ")

Predicted Class: 0
Probability of Disease: tensor([[0.5129, 0.4871]]) 


: 

: 