In [None]:
import json
import pandas as pd
import os
import time
import google.generativeai as genai
import seaborn as sns
from PIL import Image
from IPython.display import display, Image as IPImage
from dotenv import load_dotenv
from tqdm import tqdm

# 0) Import Data Sets

In [None]:
# Import JSON datasets
def _load_json(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return its contents."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Image folders, one per dataset split
train_images_dir = os.path.join("2024_dataset", "images", "train")
val_images_dir = os.path.join("2024_dataset", "images", "valid")
test_images_dir = os.path.join("2024_dataset", "images", "test")

# Annotation files, one per dataset split
train_json_file = os.path.join("2024_dataset", "train_downloaded.json")
val_json_file = os.path.join("2024_dataset", "valid_downloaded.json")
test_json_file = os.path.join("2024_dataset", "test_downloaded.json")

# Parsed annotations, loaded in train -> valid -> test order
train_data = _load_json(train_json_file)
val_data = _load_json(val_json_file)
test_data = _load_json(test_json_file)

In [None]:
# Filter the dataset for valid image paths
def filter_valid_entries(data, images_dir, verbose=False):
    """Keep only the entries whose ``<encounter_id>.jpg`` exists in ``images_dir``.

    Args:
        data: Iterable of entries; each must provide an ``'encounter_id'`` key.
        images_dir: Directory that is searched for the image files.
        verbose: When true, print a line for every entry that is dropped.

    Returns:
        A new list with the retained entries, in their original order.
    """
    kept = []
    for record in data:
        candidate = os.path.normpath(
            os.path.join(images_dir, f"{record['encounter_id']}.jpg")
        )
        if not os.path.exists(candidate):
            # Missing image: optionally report it, then move on
            if verbose:
                print(f"Skipping entry with missing image: {record['encounter_id']}")
            continue
        kept.append(record)
    return kept

In [None]:
# Per-split datasets restricted to entries whose image file is present on disk
train_valid_data = filter_valid_entries(data=train_data, images_dir=train_images_dir)
val_valid_data = filter_valid_entries(data=val_data, images_dir=val_images_dir)
test_valid_data = filter_valid_entries(data=test_data, images_dir=test_images_dir)

# 1) Retrieve Captions + Images for Single Training Data Example

In [None]:
# Peek at the first validation entry: its image ID and first English response
if not val_valid_data:
    print("The dataset is empty.")
else:
    first_entry = val_valid_data[0]
    image_id = first_entry["encounter_id"]
    caption = first_entry["responses"][0]["content_en"]
    print(f"Image ID: {image_id}")
    print(f"Caption: {caption}")

In [None]:
# Pair each validation image path with its first English response and its query title
data_combined = []
for entry in val_valid_data:
    entry_image_path = os.path.join(val_images_dir, f"{entry['encounter_id']}.jpg")
    if not os.path.exists(entry_image_path):
        # Image is not on disk: leave this entry out
        continue
    data_combined.append(
        {
            "image_path": os.path.join(val_images_dir, f"{entry['encounter_id']}.jpg"),
            "caption": entry["responses"][0]["content_en"],
            "query": entry["query_title_en"]
        }
    )

# Tabulate the records and preview the first rows
df = pd.DataFrame(data_combined)
print(df.head())

In [None]:
# Show the first image in the table together with its query and caption
first_row = df.iloc[0]
image_path, caption, query = (
    first_row["image_path"],
    first_row["caption"],
    first_row["query"],
)

# Shrink in place so neither side exceeds 300 pixels, then render inline
img = Image.open(image_path)
img.thumbnail((300, 300))
display(img)

# Query first, then the reference caption
print(f"Query: {query}", f"Caption: {caption}", sep="\n")

# 2) Pass it through an initial baseline architecture

<img src="image.png" alt="image.png" width="300"/>

In [None]:
# Step 1: Load environment variables and configure Gemini
# load_dotenv() runs first so that API_KEY can be read from the environment below
# (presumably populated from a local .env file — confirm for your setup).
load_dotenv()
# NOTE(review): os.getenv returns None when API_KEY is unset; the value is passed
# to genai.configure unchanged, so a missing key is not reported in this cell.
api_key = os.getenv('API_KEY')
genai.configure(api_key=api_key)
# Base model used for image description; later cells reference `model` directly.
model_name = 'gemini-1.5-flash'
model = genai.GenerativeModel(model_name)

QUESTION: Is there a Med-Gemini model we can use? See https://research.google/blog/advancing-medical-ai-with-med-gemini/

In [None]:
# Step 2: Define a function to process the image with Gemini
def process_image_with_gemini(image_path, model):
    """Ask a Gemini model for a plain-text medical description of one image.

    Args:
        image_path: Path of the image file to describe.
        model: Object exposing ``generate_content``; it is called with a
            two-item list containing the prompt text and the opened image.

    Returns:
        The ``text`` attribute of the model response, stripped of leading and
        trailing whitespace.
    """
    # Nothing is interpolated into the prompt, so a plain string literal is used.
    image_prompt = """
    Describe the following image in detail for a medical context.
    Provide a comprehensive description including any visible abnormalities, patterns, or other notable observations.
    Output the description in plain text without any additional formatting.
    """

    # Open the image in a context manager so its file handle is released even
    # when the request raises (this function is called once per dataset entry).
    with Image.open(image_path) as image:
        response = model.generate_content([image_prompt, image])
    return response.text.strip()

In [None]:
# Describe the sample image selected earlier using the base Gemini model.
# NOTE: process_image_with_gemini calls model.generate_content, so this cell issues a request.
gemini_description = process_image_with_gemini(image_path, model)

In [None]:
# Show the query, the reference caption, and Gemini's image description, one per line
print(
    f"Query: {query}",
    f"Caption: {caption}",
    f"Gemini Description: {gemini_description}",
    sep="\n",
)

In [None]:
# Step 3: Combine Gemini's output with the query and generate a response
def generate_prompt(query, image_description):
    """Build the text prompt that asks for a medical response.

    Args:
        query: The question text to embed after ``Query:``.
        image_description: The description to embed after ``Image Description:``.

    Returns:
        The multi-line prompt string with both values interpolated.
    """
    prompt_text = f"""
    Based on the following query and image description, provide a detailed and helpful medical response:

    Query: {query}
    Image Description: {image_description}

    Output the response in plain text without any additional formatting.
    """
    return prompt_text

In [None]:
# Step 4: Generate a fine-tuned dataset and fine-tune the Gemini model
# Source: https://ai.google.dev/gemini-api/docs/model-tuning/tutorial?lang=python
def generate_fine_tuning_dataset():
    """Build ``text_input``/``output`` training pairs from the filtered training set.

    Iterates over the notebook-level ``train_valid_data``. For each entry whose
    image exists under ``train_images_dir``, the image is described with the
    notebook-level ``model``, the description is combined with the entry's
    query into a prompt, and that prompt is paired with the entry's first
    English response.

    Returns:
        A list of dicts with the keys ``"text_input"`` and ``"output"``.
        Entries whose processing raised are reported and left out.
    """
    examples = []
    for record in tqdm(train_valid_data):
        img_file = os.path.normpath(
            os.path.join(train_images_dir, f"{record['encounter_id']}.jpg")
        )

        if os.path.exists(img_file):
            # Query title (with fallback text) and the reference answer
            user_query = record.get("query_title_en", "No query provided.")
            reference_answer = record["responses"][0]["content_en"]

            try:
                # Describe the image, then wrap the description into the prompt
                description = process_image_with_gemini(img_file, model)
                tuning_prompt = generate_prompt(user_query, description)

                examples.append({"text_input": tuning_prompt, "output": reference_answer})
            except Exception as err:
                print("Skipping entry due to following error: ", err)

            # Pause after each attempted entry to keep the requests less than 15 RPM
            time.sleep(5)
    return examples

print("Generating fine-tuning dataset...")
train_finetuning_data = generate_fine_tuning_dataset()

print("Fine-tuning model...")
# Tuning hyperparameters (epoch_count, batch_size, learning_rate) are set explicitly below.
operation = genai.create_tuned_model(
    display_name="mediqa",
    source_model="models/gemini-1.5-flash-001-tuning",
    epoch_count=20,
    batch_size=4,
    learning_rate=0.001,
    training_data=train_finetuning_data,
)

# Step through the operation's progress iterator, pausing 10 seconds between updates.
for status in operation.wait_bar():
    time.sleep(10)

In [None]:
# Plot Learning Curve from Fine-Tuning
# Take the result of the tuning operation started above and tabulate its
# tuning_task.snapshots (presumably one record per training step — confirm).
result = operation.result()
snapshots = pd.DataFrame(result.tuning_task.snapshots)
# Line plot of the 'mean_loss' column against the 'epoch' column.
sns.lineplot(data=snapshots, x='epoch', y='mean_loss')

In [None]:
# Step 5: Get the mediqa fine-tuned model
# Lazily scan the tuned models and stop at the first name that matches
matching_names = (
    tuned.name
    for tuned in genai.list_tuned_models()
    if "tunedModels" in tuned.name and "mediqa" in tuned.name
)
first_match = next(matching_names, None)

fine_tuned_model = (
    genai.GenerativeModel(model_name=first_match) if first_match is not None else None
)

if fine_tuned_model is None:
    print("ERROR: Unable to find fine-tuned model")
        

In [None]:
# Step 6: Generate the response from the fine-tuned model
def generate_response(query, image_description, model):
    """Generate a response for a query and an image description.

    Args:
        query: Question text passed on to ``generate_prompt``.
        image_description: Image description passed on to ``generate_prompt``.
        model: Object exposing ``generate_content``; it receives the prompt text.

    Returns:
        The ``text`` of the model's reply with surrounding whitespace removed.
    """
    reply = model.generate_content(generate_prompt(query, image_description))
    return reply.text.strip()

# Answer the sample query with the fine-tuned model, reusing the Gemini description from earlier.
# NOTE(review): fine_tuned_model stays None when the Step 5 lookup finds no match; this call would then fail.
generated_response = generate_response(query, gemini_description, fine_tuned_model)

In [None]:
# Step 7: Show the query, reference caption, Gemini description, and generated response
print(
    f"Query: {query}",
    f"Caption: {caption}",
    f"Gemini Description: {gemini_description}",
    f"Generated Response: {generated_response}",
    sep="\n",
)

# 3) Process Entire Dataset

QUESTION: How does the above process compare to simply using a single prompt in which both the image and the query are provided to retrieve the image description? Judging from these outputs, the model seems less confident in its understanding of what the condition could be. What else can we do to improve the quality of the response, in addition to changing the process?

In [None]:
# Step 4: Workflow to process an image, query, and caption
def process_entry(entry, images_dir, model):
    """Run one dataset entry through the describe-then-respond pipeline.

    Args:
        entry: Mapping with an ``'encounter_id'`` key, a ``'responses'`` list
            whose first item has ``'content_en'``, and optionally
            ``'query_title_en'``.
        images_dir: Directory expected to contain ``<encounter_id>.jpg``.
        model: Model object handed to both the image-description step and the
            response-generation step.

    Returns:
        A dict with the keys ``image_path``, ``query``, ``original_caption``,
        ``image_description`` and ``response``; ``None`` when the image file
        does not exist.
    """
    image_file = os.path.normpath(
        os.path.join(images_dir, f"{entry['encounter_id']}.jpg")
    )

    # Debug trace of the resolved path
    print(f"Constructed image path: {image_file}")

    if os.path.exists(image_file):
        # Query title (with fallback text) and the first English response
        question = entry.get("query_title_en", "No query provided.")
        reference_caption = entry["responses"][0]["content_en"]

        # Describe the image first, then answer the query using that description
        description = process_image_with_gemini(image_file, model)
        answer = generate_response(question, description, model)

        return {
            "image_path": image_file,
            "query": question,
            "original_caption": reference_caption,
            "image_description": description,
            "response": answer,
        }

    # No image on disk: report it and signal the caller to skip this entry
    print(f"Image does not exist: {image_file}")
    return None

In [None]:
# Load the JSON data
json_file = os.path.join("2024_dataset", "train_downloaded.json")
with open(json_file, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Directory holding the images that accompany the training JSON loaded above.
# Assigned in this cell so that it runs on a fresh kernel (Restart & Run All)
# instead of depending on an `images_dir` left over from an earlier session.
images_dir = os.path.join("2024_dataset", "images", "train")

# Filter the dataset to only include entries with valid image paths
filtered_data = filter_valid_entries(data, images_dir)

# Process the filtered dataset
if filtered_data:
    # Example: Process only the first valid entry
    first_entry = filtered_data[0]
    processed_entry = process_entry(first_entry, images_dir, model)

    # Debugging: Check the processed entry
    if processed_entry:
        print("Processed Entry:")
        print(processed_entry)
    else:
        print("Processing failed for the first valid entry.")
else:
    print("No valid entries with images were found.")

QUESTION: There seem to be guardrails in place. How can we prompt-engineer this so that the output is as meaningful as possible for the patient and similar to the results needed for the conference (i.e., without the guardrails)?

In [None]:
# Step 5: Example Data Processing for all entries in the dataset
def process_dataset(data, images_dir, model):
    """Process every entry in ``data`` and tabulate the successful results.

    Args:
        data: Iterable of dataset entries, each passed to ``process_entry``.
        images_dir: Image directory forwarded to ``process_entry``.
        model: Model object forwarded to ``process_entry``.

    Returns:
        A ``pandas.DataFrame`` with one row per truthy ``process_entry`` result.
    """
    # Entries are processed lazily, one at a time and in order; falsy results are dropped
    outcomes = (process_entry(item, images_dir, model) for item in data)
    return pd.DataFrame([outcome for outcome in outcomes if outcome])

In [None]:
# Use this code when we are ready to process the entire dataset

# # Load the dataset
# with open(os.path.join('2024_dataset', 'train_downloaded.json'), 'r', encoding='utf-8') as f:
#     data = json.load(f)

# # Define the images directory
# images_dir = os.path.join("2024_dataset", "images", "train")

# # Configure the Gemini model
# load_dotenv()
# api_key = os.getenv('API_KEY')
# genai.configure(api_key=api_key)
# model = genai.GenerativeModel('gemini-1.5-flash')

# # Filter valid entries (optional but recommended to avoid missing images)
# data_filtered = [entry for entry in data if os.path.exists(os.path.join(images_dir, f"{entry['encounter_id']}.jpg"))]

# # Process the dataset
# processed_df = process_dataset(data_filtered, images_dir, model)

# # Inspect the output
# print(processed_df.head())

In [None]:
# Step 6: Display an example in the dataset with side-by-side comparison
def display_example(row):
    """Render one processed example: a thumbnail followed by its text fields.

    Args:
        row: Mapping-like record with the keys ``image_path``, ``query``,
            ``original_caption``, ``image_description`` and ``response``.
    """
    # Shrink in place so neither side exceeds 300 pixels, then render inline
    preview = Image.open(row["image_path"])
    preview.thumbnail((300, 300))
    display(preview)

    # Each field is printed on its own line, followed by a blank line
    labelled_fields = (
        ("Query", "query"),
        ("Original Caption", "original_caption"),
        ("Gemini Image Description", "image_description"),
        ("Generated Response", "response"),
    )
    for label, key in labelled_fields:
        print(f"{label}: {row[key]}\n")

In [None]:
# Use this code when we are ready to process the entire dataset

# # Ensure the DataFrame has been processed
# processed_df = process_dataset(filtered_data, images_dir, model)

# # Check if the DataFrame is not empty
# if not processed_df.empty:
#     # Display the first example in the dataset
#     display_example(processed_df.iloc[0])
# else:
#     print("No valid entries found in the dataset to display.")

QUESTIONS: How can we add in the following: 
- Medical chain-of-thought - see https://arxiv.org/abs/2412.13736v1
- Figure out how to map generative text to multiple choice
- Test out augmenting the queries? or image descriptions? with ShareCaptioner
- Is there a way to extract high- and low-level image features? Taking inspiration from the Flickr30k dataset.
- What would LLM finetuning look like here? Can we finetune multiple LLMs (region specific) and employ weight-merging or multitask learning?
- Could we leverage multimodal explainability here to provide transparency in model's reasoning? https://jayneelparekh.github.io/LMM_Concept_Explainability