In [1]:
import os
import sys
import torch
import numpy as np
import pandas as pd
from datasets import Dataset
from PIL import Image
import gc
import pickle
from tqdm.auto import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, PeftModel
from dotenv import load_dotenv
from torch.utils.data import DataLoader
import shutil
import json
import matplotlib.pyplot as plt
from pprint import pprint



In [2]:
# Display environment information
print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

# Clean memory
gc.collect()
torch.cuda.empty_cache()

# Set up environment variables and cache directories
os.environ["HF_HOME"] = os.path.join(os.getcwd(), ".hf_cache")
print(f"HF_HOME: {os.getenv('HF_HOME')}")

Python version: 3.10.10 (main, Apr 15 2024, 11:52:16) [GCC 11.4.1 20230605 (Red Hat 11.4.1-2)]
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
CUDA device count: 1
Current CUDA device: 0
CUDA device name: NVIDIA A100 80GB PCIe
HF_HOME: /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.hf_cache


In [8]:
# Set up paths
BASE_DIR = os.path.dirname(os.path.abspath('__file__'))
DATASET_DIR = os.path.join(BASE_DIR, "2025_dataset")
TRAIN_DIR = os.path.join(DATASET_DIR, "train")
VAL_DIR = os.path.join(DATASET_DIR, "valid")
VAL_IMAGES_DIR = os.path.join(VAL_DIR, "images_valid")
OUTPUT_DIR = os.path.join(BASE_DIR, "outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(BASE_DIR)

/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2


In [29]:
train_csv_file = os.path.join(OUTPUT_DIR, "multi_label_dataset.csv")
print(train_csv_file)
train_images_dir = os.path.join(TRAIN_DIR, "images_train")
print(train_images_dir)

train_df = pd.read_csv(train_csv_file)

/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/multi_label_dataset.csv
/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/2025_dataset/train/images_train


Unnamed: 0,encounter_id,base_qid,image_id,image_path,valid_answers,valid_indices,question_text,query_title_en,query_content_en,author_id,options_en,question_type_en,question_category_en,is_multi_label
0,ENC00001,CQID010,IMG_ENC00001_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,['limited area'],[1],How much of the body is affected?,Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,U04473,"['single spot', 'limited area', 'widespread', ...",Site,General,False
1,ENC00001,CQID010,IMG_ENC00001_00002.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,['limited area'],[1],How much of the body is affected?,Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,U04473,"['single spot', 'limited area', 'widespread', ...",Site,General,False
2,ENC00001,CQID011,IMG_ENC00001_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,['back'],[5],1 Where is the affected area?,Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,U04473,"['head', 'neck', 'upper extremities', 'lower e...",Site Location,General,False
3,ENC00001,CQID011,IMG_ENC00001_00002.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,['back'],[5],1 Where is the affected area?,Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,U04473,"['head', 'neck', 'upper extremities', 'lower e...",Site Location,General,False
4,ENC00001,CQID012,IMG_ENC00001_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,['size of palm'],[1],1 How large are the affected areas? Please spe...,Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,U04473,"['size of thumb nail', 'size of palm', 'larger...",Size,General,False


In [33]:
# Confirms the model accepts multiple labels for training

# import ast

# # Ensure 'valid_indices' is a list
# train_df['valid_indices_list'] = train_df['valid_indices'].apply(ast.literal_eval)

# # Filter rows where length of valid indices > 1
# multi_label_rows = train_df[train_df['valid_indices_list'].apply(lambda x: len(x) > 1)]

# # Display a few examples
# multi_label_rows.head()

Unnamed: 0,encounter_id,base_qid,image_id,image_path,valid_answers,valid_indices,question_text,query_title_en,query_content_en,author_id,options_en,question_type_en,question_category_en,is_multi_label,valid_indices_list
59,ENC00003,CQID011,IMG_ENC00003_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,"['back', 'chest/abdomen']","[5, 4]",1 Where is the affected area?,Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",U00780,"['head', 'neck', 'upper extremities', 'lower e...",Site Location,General,True,"[5, 4]"
60,ENC00003,CQID011,IMG_ENC00003_00002.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,"['back', 'chest/abdomen']","[5, 4]",1 Where is the affected area?,Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",U00780,"['head', 'neck', 'upper extremities', 'lower e...",Site Location,General,True,"[5, 4]"
61,ENC00003,CQID011,IMG_ENC00003_00003.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,"['back', 'chest/abdomen']","[5, 4]",1 Where is the affected area?,Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",U00780,"['head', 'neck', 'upper extremities', 'lower e...",Site Location,General,True,"[5, 4]"
62,ENC00003,CQID011,IMG_ENC00003_00004.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,"['back', 'chest/abdomen']","[5, 4]",1 Where is the affected area?,Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",U00780,"['head', 'neck', 'upper extremities', 'lower e...",Site Location,General,True,"[5, 4]"
63,ENC00003,CQID011,IMG_ENC00003_00005.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,"['back', 'chest/abdomen']","[5, 4]",1 Where is the affected area?,Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",U00780,"['head', 'neck', 'upper extremities', 'lower e...",Site Location,General,True,"[5, 4]"


In [None]:
train_df.columns

In [None]:
# train_df = train_df.head(10)  # Start with 10 samples for quick debugging
print(f"Using {len(train_df)} samples for training")

In [10]:
def safe_convert_options(options_str):
    """
    Safely convert a string representation of a list to an actual list.
    """
    if not isinstance(options_str, str):
        return options_str
        
    try:
        # Use ast.literal_eval which is safer than eval()
        import ast
        return ast.literal_eval(options_str)
    except (SyntaxError, ValueError):
        # Try common formats
        if options_str.startswith('[') and options_str.endswith(']'):
            # Strip brackets and split by commas
            return [opt.strip().strip("'\"") for opt in options_str[1:-1].split(',')]
        elif ',' in options_str:
            # Just split by commas
            return [opt.strip() for opt in options_str.split(',')]
        else:
            # Single option
            return [options_str]

In [None]:
def process_batch(batch_df, batch_idx, save_dir, images_dir):
    """
    Process a batch of data samples and save them as a pickle file.
    Includes query title and content as clinical context.
    """
    os.makedirs(save_dir, exist_ok=True)
    batch_data = []
    
    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {batch_idx}"):
        try:
            # Get image path - using image_id instead of image_ids
            image_id = row.get('image_id')
            if not image_id:
                continue
                
            # Use the full image path if it's already in the dataframe
            if 'image_path' in row and os.path.exists(row['image_path']):
                image_path = row['image_path']
            else:
                # Otherwise construct from images_dir and image_id
                image_path = os.path.join(images_dir, image_id)
            
            if not os.path.exists(image_path):
                print(f"Image not found: {image_path}")
                continue

            # Verify the image is valid
            try:
                with Image.open(image_path) as img:
                    img.load()
            except Exception as e:
                print(f"Corrupt or unreadable image at {image_path} — {e}")
                continue
            
            # Get options from options_en
            if 'options_en' in row:
                options = safe_convert_options(row['options_en'])
#                 options = row['options_en']
#                 if isinstance(options, str):
#                     try:
#                         options = eval(options)
#                     except:
#                         options = options.split(',')
            else:
                options = ["Yes", "No", "Not mentioned"]
                
            options_text = ", ".join([f"{i+1}. {opt}" for i, opt in enumerate(options)])
            
            # Create metadata string
            metadata = ""
            if 'question_type_en' in row:
                metadata += f"Type: {row['question_type_en']}"
                
            if 'question_category_en' in row:
                metadata += f", Category: {row['question_category_en']}"
            
            # Get question text
            question = row.get('question_text', 'What do you see in this image?')
            
            # Get clinical context from query title and content
            query_title = row.get('query_title_en', '')
            query_content = row.get('query_content_en', '')
            
            # Create the clinical context section
            clinical_context = ""
            if query_title:
                clinical_context += f"Clinical Context: {query_title}\n"
            if query_content:
                clinical_context += f"{query_content}\n"
            
            # Create the full query text with clinical context
            query_text = (f"Question: Based on the image, {question}\n"
                         f"Question Metadata: {metadata}\n"
                         f"{clinical_context}"
                         f"Options: {options_text}")
            
            # Get answer text - from valid_answers
            if 'valid_answers' in row and row['valid_answers']:
                # For multi-label, join all valid answers
                answers = row['valid_answers']
                if isinstance(answers, list):
                    if len(answers) > 1:
                        # Join multiple answers with commas
                        answer_text = ", ".join(answers)
                    elif len(answers) == 1:
                        answer_text = answers[0]
                    else:
                        answer_text = "Not mentioned"
                else:
                    answer_text = str(answers)
            elif 'multi_label' in row:
                answer_text = row['multi_label']
            else:
                answer_text = "Not mentioned"
            
            batch_data.append({
                "id": row.get('encounter_id', str(idx)),
                "qid": row.get('base_qid', ''),
                "query_text": query_text,
                "image_path": image_path,
                "answer_text": answer_text,
                "question_type": row.get('question_type_en', ''),
                "question_category": row.get('question_category_en', '')
            })
        
        except Exception as e:
            print(f"Error processing row {idx}: {e}")
            import traceback
            traceback.print_exc()
    
    batch_file = os.path.join(save_dir, f"batch_{batch_idx}.pkl")
    with open(batch_file, 'wb') as f:
        pickle.dump(batch_data, f)
    
    return len(batch_data)


def preprocess_dataset(df, batch_size=50, save_dir="outputsprocessed_data", images_dir=None):
    """
    Process the entire dataset in batches
    """
    total_processed = 0
    
    # Use train_images_dir global variable if images_dir is not provided
    if images_dir is None:
        images_dir = train_images_dir
    
    for i in range(0, len(df), batch_size):
        batch_df = df.iloc[i:i+batch_size]
        batch_idx = i // batch_size
        
        print(f"Processing batch {batch_idx+1}/{(len(df)-1)//batch_size + 1}")
        processed = process_batch(batch_df, batch_idx, save_dir, images_dir)
        total_processed += processed
        
        gc.collect()
        
        print(f"Processed {total_processed} examples so far")
    
    return total_processed

In [None]:
# Process and save the dataset
processed_data_dir = "outputs/processed_data"

# Clear any existing processed data (optional)
if os.path.exists(processed_data_dir):
    shutil.rmtree(processed_data_dir)
    
# Process the dataset
total_examples = preprocess_dataset(train_df, batch_size=100, save_dir=processed_data_dir)
print(f"Total processed examples: {total_examples}")

In [None]:
# Check the processed data
batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
with open(batch_file, 'rb') as f:
    batch_data = pickle.load(f)

print("\nSample of processed data (first example):")
sample_data = batch_data[0]
for key, value in sample_data.items():
    print(f"{key}: {value}")
    
print("\nSample of processed data (second example):")
sample_data = batch_data[1]
for key, value in sample_data.items():
    print(f"{key}: {value}")

In [None]:
def inspect_llm_input(processed_data_dir, num_samples=3):
    """
    Load and display what the LLM receives during training, including
    the images, query text with clinical context, and expected answers.
    """
    # Load the first batch file
    batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
    with open(batch_file, 'rb') as f:
        batch_data = pickle.load(f)
    
    # Print info about the number of samples
    print(f"Total samples in batch: {len(batch_data)}")
    
    # Display the requested number of samples
    for i, sample in enumerate(batch_data[:num_samples]):
        print(f"\n{'='*80}")
        print(f"SAMPLE {i+1} of {num_samples}")
        print(f"{'='*80}")
        
        # Print metadata
        print(f"ID: {sample['id']}")
        print(f"Question ID: {sample['qid']}")
        print(f"Type: {sample['question_type']}")
        print(f"Category: {sample['question_category']}")
        
        # Load and display the image
        img = Image.open(sample['image_path'])
        width, height = img.size
        print(f"Image dimensions: {width}x{height}, Format: {img.format}")
        
        # Display the image
        plt.figure(figsize=(10, 8))
        plt.imshow(img)
        plt.axis('off')
        plt.title(f"Sample {i+1}")
        plt.show()
        
        # Display the inputs the LLM receives
        print("\nINPUT TO LLM:")
        print("-" * 80)
        print(sample['query_text'])
        print("-" * 80)
        
        # Display expected answer
        print("\nEXPECTED OUTPUT FROM LLM:")
        print("-" * 80)
        print(sample['answer_text'])
        print("-" * 80)
        
        # Display the system message that sets the context for the LLM
        print("\nSYSTEM MESSAGE (contexts the interaction):")
        print("-" * 80)
        print("You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain.")
        print("-" * 80)

# Run the function to show what the LLM gets
inspect_llm_input(processed_data_dir, num_samples=2)

In [None]:
class MedicalImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, processor):
        self.processor = processor
        self.examples = []
        
        for batch_file in sorted(os.listdir(data_dir)):
            if batch_file.startswith("batch_") and batch_file.endswith(".pkl"):
                with open(os.path.join(data_dir, batch_file), 'rb') as f:
                    batch_data = pickle.load(f)
                    self.examples.extend(batch_data)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        example = self.examples[idx]
        
        # Open image and convert to RGB
        image = Image.open(example['image_path']).convert("RGB")
        
        # Define system message for medical image analysis
        system_message = """You are a medical image analysis assistant. Your task is to examine the provided clinical images along with clinical context, and select the option(s) that best describe what you see. 

        IMPORTANT: You must respond ONLY with the exact text of the option(s) that apply. 
        - Do not provide any explanations
        - Do not include option numbers
        - Do not write "Options:" or similar prefixes
        - Do not write "Answer:" or similar prefixes
        - Multiple answers should be separated by commas
        - If unsure, respond with "Not mentioned
        
        """
        
        # Format as a conversation with system, user, and assistant messages
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": example['query_text']},
                    {"type": "image", "image": image},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": example['answer_text']}],
            },
        ]
        
        return {"messages": messages}

In [4]:
# Load Hugging Face token if needed
load_dotenv()
hf_token = os.getenv("HF_TOKEN")  # Make sure to set this in a .env file or environment

# Set model ID
model_id = "google/gemma-3-4b-it"  # We'll use Gemma 3 4B with image understanding capabilities

# Load processor first to use in the dataset class
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# Create helper functions
def create_dummy_image():
    """Create a small black image as a placeholder."""
    return Image.new('RGB', (224, 224), color='black')

def collate_fn(examples):
    """Custom collate function for batching examples."""
    texts = []
    images = []
    
    for example in examples:
        # Extract image from messages
        image_input = None
        for msg in example["messages"]:
            if msg["role"] == "user":
                for content in msg["content"]:
                    if isinstance(content, dict) and content.get("type") == "image" and "image" in content:
                        image_input = content["image"]
                        break
        
        if image_input is None:
            image_input = create_dummy_image()
            
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        
        texts.append(text.strip())
        images.append([image_input])
    
    batch = processor(
        text=texts, 
        images=images,
        return_tensors="pt", 
        padding=True
    )
    
    labels = batch["input_ids"].clone()
    
    # Get token IDs for special tokens to mask in loss computation
    boi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    eoi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["eoi_token"]
    )
    
    # Mask tokens that shouldn't contribute to the loss
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == boi_token_id] = -100
    labels[labels == eoi_token_id] = -100
    
    batch["labels"] = labels
    return batch

In [None]:
# Create and test the dataset
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"Dataset size: {len(dataset)}")

if len(dataset) == 0:
    print("ERROR: Dataset is empty! Check data loading process.")
else:
    sample_size = min(3, len(dataset))
    print(f"Sampling {sample_size} examples from dataset")
    
    sample_examples = [dataset[i] for i in range(sample_size)]
    
    print(f"Sample size: {len(sample_examples)}")
    print("First example keys:", list(sample_examples[0].keys()))
    
    # Display the message structure for each sample
    for i in range(sample_size):
        example = dataset[i]
        print(f"\nExample {i+1} message structure:")
        print(f"System role: {example['messages'][0]}")
        print(f"User role content:")
        print(f"  Text: {example['messages'][1]['content'][0]['text'][:200]}...")  # Show first 200 chars
        print(f"  Image: {type(example['messages'][1]['content'][1]['image'])}")
        print(f"Assistant role: {example['messages'][2]}")
    
    # Test the collate function
    batch = collate_fn(sample_examples)
    print("\nCollated batch contains:", list(batch.keys()))
    print(f"Input_ids shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")

In [None]:
# Test with a simple dataloader
dataloader = DataLoader(
    dataset,
    batch_size=8,  # Adjust based on GPU memory
    shuffle=True,
    collate_fn=collate_fn
)

total_examples = 0
for batch in dataloader:
    # Process each batch. Note: in training, pass this to model.forward()
    batch_size = len(batch["input_ids"])
    total_examples += batch_size
    print(f"Processed batch with {batch_size} examples")

print(f"Processed all {total_examples} examples")

In [None]:
# Check if GPU can support bfloat16
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8:
    print("WARNING: GPU may not fully support bfloat16. Consider using float16 instead.")

# Configure the model parameters
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Configure quantization for memory efficiency
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load the model
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs, token=hf_token)

# Print info about the chat template
print(f"Default chat template: {processor.tokenizer.chat_template}")
print(f"Special tokens map: {processor.tokenizer.special_tokens_map}")

In [None]:
# Configure LoRA for efficient fine-tuning
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",  # Apply LoRA to all linear layers
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"],  # Save these modules fully
)

In [None]:
# Set up training configuration
output_dir = "outputs/finetuned-model"

training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,  # Adjust based on dataset size
    per_device_train_batch_size=1,  # Adjust based on GPU memory
    gradient_accumulation_steps=8,  # Accumulate gradients to simulate larger batch
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,  # Use bfloat16 precision
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=False,  # Set to True if you want to push to Hub
    report_to="tensorboard",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,  # Critical for custom datasets
    label_names=["labels"],  # Explicitly setting label_names
)

In [None]:
# Initialize the trainer with all components
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

# Start training
trainer.train()

In [None]:
# # Commenting this out as my training didn't complete

# # Save the trained model
# trainer.save_model()
# print("Training complete and model saved!")

In [None]:
# # Commenting this out as my training didn't complete

# # Clean up memory first
# del model
# del trainer
# torch.cuda.empty_cache()

In [None]:
checkpoint_path = "outputs/finetuned-model/checkpoint-1974"

# Load base model
# model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto",
                                                    low_cpu_mem_usage=True, token=hf_token)

# Merge LoRA weights into base model
# peft_model = PeftModel.from_pretrained(model, output_dir)
peft_model = PeftModel.from_pretrained(model, checkpoint_path)
merged_model = peft_model.merge_and_unload()

# Save the merged model
merged_dir = "outputs/merged_model"
merged_model.save_pretrained(merged_dir, safe_serialization=True, max_shard_size="2GB")

# Save the processor alongside the model
# processor = AutoProcessor.from_pretrained(output_dir)
processor = AutoProcessor.from_pretrained(checkpoint_path)
processor.save_pretrained(merged_dir)

print(f"Merged model saved to {merged_dir}")

In [5]:
# Function to load and prepare validation data
def prepare_validation_data():
    """
    Create a validation dataframe similar to the training dataframe.
    """
    print("Preparing validation data...")
    
    # Load question definitions
    questions_path = os.path.join(TRAIN_DIR, "closedquestions_definitions_imageclef2025.json")
    with open(questions_path, 'r') as f:
        questions = json.load(f)
        
    # Convert to DataFrame for easier manipulation
    questions_df = pd.json_normalize(questions)[["qid", "question_en", "options_en", "question_type_en", "question_category_en"]]
    
    # Load validation data with query information
    val_json_path = os.path.join(VAL_DIR, "valid.json")
    val_df = pd.read_json(val_json_path)
    
    # Extract relevant columns including query content and title
    query_info_df = val_df[["encounter_id", "image_ids", "query_title_en", "query_content_en", "author_id"]]
    
    # Load CVQA data (ground truth answers)
    cvqa_path = os.path.join(VAL_DIR, "valid_cvqa.json")
    with open(cvqa_path, 'r') as f:
        cvqa_data = json.load(f)
    cvqa_df = pd.json_normalize(cvqa_data)
    
    # Melt to get one row per question
    cvqa_long = cvqa_df.melt(id_vars=["encounter_id"], 
                             var_name="qid", 
                             value_name="answer_index")
    
    # Filter out encounter_id rows
    cvqa_long = cvqa_long[cvqa_long["qid"] != "encounter_id"]
    
    # Merge CVQA with questions
    cvqa_merged = cvqa_long.merge(questions_df, on="qid", how="left")
    
    # Get answer text
    def get_answer_text(row):
        try:
            return row["options_en"][row["answer_index"]]
        except (IndexError, TypeError):
            return None
    
    cvqa_merged["answer_text"] = cvqa_merged.apply(get_answer_text, axis=1)
    
    # Merge with validation data
    final_df = cvqa_merged.merge(query_info_df, on="encounter_id", how="left")
    
    # Extract the base CQID code
    final_df['base_qid'] = final_df['qid'].str.extract(r'(CQID\d+)')
    
    # Group by encounter_id and base_qid to see all answers for each question family
    grouped_by_family = final_df.groupby(['encounter_id', 'base_qid']).agg({
        'qid': list,
        'question_en': list,
        'answer_text': list,
        'answer_index': list,
        'image_ids': 'first',
        'options_en': 'first',
        'question_type_en': 'first',
        'question_category_en': 'first',
        'query_title_en': 'first',
        'query_content_en': 'first',
        'author_id': 'first'
    })
    
    # Reset index for easier manipulation
    grouped_by_family = grouped_by_family.reset_index()
    
    # Modified function to extract all valid answers (treating "Not mentioned" appropriately)
    def get_valid_answers(row):
        """
        Extract all valid answers, with special handling for "Not mentioned".
        If "Not mentioned" is the only answer for all slots, we keep it.
        Otherwise, we collect all non-"Not mentioned" answers.
        """
        answers = row['answer_text']
        answer_indices = row['answer_index']
        
        if all(ans == "Not mentioned" for ans in answers):
            return ["Not mentioned"], [answer_indices[0]]  # If all are "Not mentioned", return it as valid
        
        valid_answers = []
        valid_indices = []
        
        for i, ans in enumerate(answers):
            if ans != "Not mentioned" and ans not in valid_answers:
                valid_answers.append(ans)
                valid_indices.append(answer_indices[i])
        
        return valid_answers, valid_indices
    
    # Apply to all question families
    grouped_by_family[['valid_answers', 'valid_indices']] = grouped_by_family.apply(
        lambda row: pd.Series(get_valid_answers(row)), axis=1)
    
    # Create the multi-label validation dataset
    multi_label_data = []
    
    # Process all validation encounters
    for _, row in tqdm(grouped_by_family.iterrows(), desc="Creating validation dataset"):
        encounter_id = row['encounter_id']
        base_qid = row['base_qid']
        valid_answers = row['valid_answers']
        valid_indices = row['valid_indices']
        image_ids = row['image_ids']
        question_text = row['question_en'][0]  # Taking the first question as reference
        query_title = row['query_title_en']
        query_content = row['query_content_en']
        author_id = row['author_id']
        options_en = row['options_en']
        question_type_en = row['question_type_en']
        question_category_en = row['question_category_en']
        
        # For each image in the encounter
        for img_id in image_ids:
            img_path = os.path.join(VAL_IMAGES_DIR, img_id)
            
            # Skip if image doesn't exist
            if not os.path.exists(img_path):
                print(f"Warning: Image {img_id} not found at {img_path}")
                continue
                
            multi_label_data.append({
                'encounter_id': encounter_id,
                'base_qid': base_qid,
                'image_id': img_id,
                'image_path': img_path,
                'valid_answers': valid_answers,
                'valid_indices': valid_indices,
                'question_text': question_text,
                'query_title_en': query_title,
                'query_content_en': query_content,
                'author_id': author_id,
                'options_en': options_en,
                'question_type_en': question_type_en, 
                'question_category_en': question_category_en,
                'is_multi_label': len(valid_answers) > 1
            })
    
    # Convert to DataFrame
    val_dataset = pd.DataFrame(multi_label_data)
    
    # Save the dataset
    val_dataset.to_csv(os.path.join(OUTPUT_DIR, "val_dataset.csv"), index=False)
    
    print(f"Validation dataset created with {len(val_dataset)} entries")
    
    return val_dataset

# Function to process a batch for inference
def process_inference_batch(batch_df, batch_idx, save_dir, images_dir):
    """
    Process a batch of data samples for inference and save them as a pickle file.
    """
    os.makedirs(save_dir, exist_ok=True)
    batch_data = []
    
    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {batch_idx}"):
        try:
            # Get image path
            image_id = row.get('image_id')
            if not image_id:
                continue
                
            # Use the full image path if it's already in the dataframe
            if 'image_path' in row and os.path.exists(row['image_path']):
                image_path = row['image_path']
            else:
                # Otherwise construct from images_dir and image_id
                image_path = os.path.join(images_dir, image_id)
            
            if not os.path.exists(image_path):
                print(f"Image not found: {image_path}")
                continue

            # Verify the image is valid
            try:
                with Image.open(image_path) as img:
                    img.load()
            except Exception as e:
                print(f"Corrupt or unreadable image at {image_path} — {e}")
                continue
            
            # Get options from options_en
            if 'options_en' in row:
                options = safe_convert_options(row['options_en'])
#                 options = row['options_en']
#                 if isinstance(options, str):
#                     try:
#                         options = eval(options)
#                     except:
#                         options = options.split(',')
            else:
                options = ["Yes", "No", "Not mentioned"]
                
            options_text = ", ".join([f"{i+1}. {opt}" for i, opt in enumerate(options)])
            
            # Create metadata string
            metadata = ""
            if 'question_type_en' in row:
                metadata += f"Type: {row['question_type_en']}"
                
            if 'question_category_en' in row:
                metadata += f", Category: {row['question_category_en']}"
            
            # Get question text
            question = row.get('question_text', 'What do you see in this image?')
            
            # Get clinical context from query title and content
            query_title = row.get('query_title_en', '')
            query_content = row.get('query_content_en', '')
            
            # Create the clinical context section
            clinical_context = ""
            if query_title:
                clinical_context += f"Clinical Context: {query_title}\n"
            if query_content:
                clinical_context += f"{query_content}\n"
            
            # Create the full query text with clinical context
            query_text = (f"Question: Based on the image, {question}\n"
                         f"Question Metadata: {metadata}\n"
                         f"{clinical_context}"
                         f"Options: {options_text}")
            
            batch_data.append({
                "id": row.get('encounter_id', str(idx)),
                "qid": row.get('base_qid', ''),
                "query_text": query_text,
                "image_path": image_path,
                "question_type": row.get('question_type_en', ''),
                "question_category": row.get('question_category_en', '')
            })
        
        except Exception as e:
            print(f"Error processing row {idx}: {e}")
            import traceback
            traceback.print_exc()
    
    batch_file = os.path.join(save_dir, f"val_batch_{batch_idx}.pkl")
    with open(batch_file, 'wb') as f:
        pickle.dump(batch_data, f)
    
    return len(batch_data)

def preprocess_validation_dataset(df, batch_size=50, save_dir="outputs/processed_val_data", images_dir=None):
    """
    Process the entire validation dataset in batches
    """
    total_processed = 0
    
    # Use VAL_IMAGES_DIR global variable if images_dir is not provided
    if images_dir is None:
        images_dir = VAL_IMAGES_DIR
    
    for i in range(0, len(df), batch_size):
        batch_df = df.iloc[i:i+batch_size]
        batch_idx = i // batch_size
        
        print(f"Processing batch {batch_idx+1}/{(len(df)-1)//batch_size + 1}")
        processed = process_inference_batch(batch_df, batch_idx, save_dir, images_dir)
        total_processed += processed
        
        gc.collect()
        
        print(f"Processed {total_processed} examples so far")
    
    return total_processed

In [62]:
# Class for inference
class MedicalImageInference:
    def __init__(self, model_path, device="cuda" if torch.cuda.is_available() else "cpu"):
        self.device = device
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None
        )
        self.model.eval()
        
    def predict(self, query_text, image_path, max_new_tokens=100):
        try:
            # Load the image
            image = Image.open(image_path).convert("RGB")

            # Create the system message
#             system_message = "You are a medical image analysis assistant. Your task is to examine the provided clinical images along with any provided clinical context, and select the option(s) that best describe what you see. Multiple answers may be correct for some questions. Respond with the exact text of the option(s) that apply. Given the medical context, err on the side of caution when uncertain."

            system_message = """You are a medical image analysis assistant. Your task is to examine the provided clinical images along with clinical context, and select the option(s) that best describe what you see. 

            IMPORTANT: You must respond ONLY with the exact text of the option(s) that apply. 
            - Do not provide any explanations
            - Do not include option numbers
            - Do not write "Options:" or similar prefixes
            - Do not write "Answer:" or similar prefixes
            - Multiple answers should be separated by commas
            - If unsure, respond with "Not mentioned

            """
            
            # Format as a conversation with system and user messages
            messages = [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_message}],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": query_text},
                        {"type": "image", "image": image},
                    ],
                }
            ]

            # Create model inputs
            inputs = self.processor(
                text=self.processor.apply_chat_template(messages, tokenize=False),
                images=image,
                return_tensors="pt"
            ).to(self.device)

            # Generate prediction
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False
                )

            # Get only the new tokens (the model's answer)
            input_length = inputs.input_ids.shape[1]
            new_tokens = generated_ids[0][input_length:]

            # Decode only the new tokens
            prediction = self.processor.decode(new_tokens, skip_special_tokens=True)

            # Clean the prediction - remove any remaining template artifacts
            prediction = prediction.strip()
            if prediction.startswith("model\n"):
                prediction = prediction[len("model\n"):]
            
            # Extract just the answer text
            if "Answer:" in prediction:
                parts = prediction.split("Answer:")
                if len(parts) > 1:
                    prediction = parts[1].strip()
                
            if prediction.startswith("<start_of_turn>model") or prediction.startswith("<start_of_turn>assistant"):
                prediction = prediction.split("\n", 1)[1] if "\n" in prediction else ""
            if prediction.endswith("<end_of_turn>"):
                prediction = prediction[:-len("<end_of_turn>")]

            return prediction.strip()
        except Exception as e:
            print(f"Error during prediction for {image_path}: {e}")
            import traceback
            traceback.print_exc()
            return "Not mentioned"  # Default to not mentioned in case of errors
    
    def batch_predict(self, processed_data_dir, output_file, max_samples=None):
        """
        Run inference on a batch of preprocessed data
        """
        results = []
        sample_count = 0
        
        # Load all batch files
        batch_files = sorted([f for f in os.listdir(processed_data_dir) if f.startswith("val_batch_") and f.endswith(".pkl")])
        
        # Process each batch file
        for batch_file in tqdm(batch_files, desc="Processing batches"):
            with open(os.path.join(processed_data_dir, batch_file), 'rb') as f:
                batch_data = pickle.load(f)
            
            # Process each sample in the batch
            for sample in tqdm(batch_data, desc=f"Predicting {batch_file}", leave=False):
                # Get prediction
                prediction = self.predict(sample["query_text"], sample["image_path"])
                
                # Save results
                results.append({
                    "encounter_id": sample["id"],
                    "base_qid": sample["qid"],
                    "image_id": os.path.basename(sample["image_path"]),
                    "prediction": prediction
                })
                
                sample_count += 1
                if max_samples and sample_count >= max_samples:
                    break
                
            if max_samples and sample_count >= max_samples:
                break
        
        # Convert to DataFrame and save
        results_df = pd.DataFrame(results)
        results_df.to_csv(output_file, index=False)
        
        return results_df
    
#     def aggregate_predictions(self, predictions_df):
#         """
#         Aggregate predictions for each encounter and question ID
#         For each encounter-question pair, collect unique predictions across all images
#         """
#         # Group by encounter_id and base_qid
#         grouped = predictions_df.groupby(['encounter_id', 'base_qid'])
        
#         aggregated_results = []
        
#         for (encounter_id, base_qid), group in tqdm(grouped, desc="Aggregating predictions"):
#             # Extract all predictions for this group
#             predictions = group['prediction'].tolist()
#             image_ids = group['image_id'].tolist()
            
#             # Process predictions to standardize format
#             cleaned_predictions = []
#             for pred in predictions:
#                 # Handle predictions that might be in a list format
#                 if pred.startswith('[') and pred.endswith(']'):
#                     try:
#                         # Try to evaluate as a Python list
#                         pred_list = eval(pred)
#                         if isinstance(pred_list, list):
#                             cleaned_predictions.extend(pred_list)
#                             continue
#                     except:
#                         pass
                
#                 # Handle comma-separated values
#                 if ',' in pred:
#                     cleaned_predictions.extend([p.strip() for p in pred.split(',')])
#                 else:
#                     cleaned_predictions.append(pred.strip())
            
#             # Get unique predictions
#             unique_predictions = list(set(cleaned_predictions))
            
#             # If "Not mentioned" is in predictions but there are other predictions,
#             # remove "Not mentioned"
#             if len(unique_predictions) > 1 and "Not mentioned" in unique_predictions:
#                 unique_predictions.remove("Not mentioned")
            
#             # Create a single, combined prediction
#             combined_prediction = ", ".join(unique_predictions)
            
#             aggregated_results.append({
#                 "encounter_id": encounter_id,
#                 "base_qid": base_qid,
#                 "image_ids": image_ids,
#                 "unique_predictions": unique_predictions,
#                 "combined_prediction": combined_prediction
#             })
        
#         # Convert to DataFrame
#         aggregated_df = pd.DataFrame(aggregated_results)
        
#         return aggregated_df

    # Update the aggregate_predictions method in the MedicalImageInference class
    def aggregate_predictions(self, predictions_df, validation_df=None):
        """
        Aggregate predictions for each encounter and question ID
        For each encounter-question pair, collect unique predictions across all images,
        respecting the maximum allowed answers for each question type.

        Parameters:
        - predictions_df: DataFrame with prediction results
        - validation_df: Optional DataFrame containing validation data with options_en
        """
        # Define maximum allowed answers for each question type
        max_answers = {
            'CQID010': 1,  # Single answer
            'CQID011': 6,  # Up to 6 answers
            'CQID012': 6,  # Up to 6 answers
            'CQID015': 1,  # Single answer
            'CQID020': 9,  # Up to 9 answers
            'CQID025': 1,  # Single answer
            'CQID034': 1,  # Single answer
            'CQID035': 1,  # Single answer
            'CQID036': 1   # Single answer
        }

        # Set default max_answers for any question type not explicitly listed
        default_max_answers = 1

        # Group by encounter_id and base_qid
        grouped = predictions_df.groupby(['encounter_id', 'base_qid'])

        aggregated_results = []

        for (encounter_id, base_qid), group in tqdm(grouped, desc="Aggregating predictions"):
            # Extract all predictions for this group
            predictions = group['prediction'].tolist()
            image_ids = group['image_id'].tolist()

            # Process predictions to standardize format
            cleaned_predictions = []
            for pred in predictions:
                # Handle predictions that might be in a list format
                if isinstance(pred, str):
                    if pred.startswith('[') and pred.endswith(']'):
                        try:
                            # Try to evaluate as a Python list
                            pred_list = safe_convert_options(pred)
    #                         pred_list = eval(pred)
                            if isinstance(pred_list, list):
                                cleaned_predictions.extend(pred_list)
                                continue
                        except:
                            pass

                # Handle comma-separated values
                if isinstance(pred, str) and ',' in pred:
#                 if ',' in pred:
                    cleaned_predictions.extend([p.strip() for p in pred.split(',')])
                else:
                    cleaned_predictions.append(str(pred).strip())
#                     cleaned_predictions.append(pred.strip())

            all_cleaned_predictions = cleaned_predictions.copy()

            # Count frequencies of each prediction
            from collections import Counter
            prediction_counts = Counter(cleaned_predictions)

            # Get question type for determining max allowed answers
            question_type = base_qid.split('-')[0] if '-' in base_qid else base_qid

            # Determine max allowed answers for this question type
            allowed_max = max_answers.get(question_type, default_max_answers)

            # Sort predictions by frequency (most common first)
            sorted_predictions = sorted(prediction_counts.items(), 
                                       key=lambda x: x[1], 
                                       reverse=True)

            all_sorted_predictions = sorted_predictions.copy()
            
            # Get top N predictions where N is the max allowed
            top_predictions = [p[0] for p in sorted_predictions[:allowed_max]]

            # If there are ties at the cutoff point, randomly select to meet the max limit
            if len(sorted_predictions) > allowed_max:
                # Check if there's a tie at the cutoff
                cutoff_count = sorted_predictions[allowed_max-1][1]
                tied_predictions = [p[0] for p in sorted_predictions if p[1] == cutoff_count]

                # If we have more tied predictions than slots available
                if len(tied_predictions) > 1 and len(top_predictions) > allowed_max - len(tied_predictions):
                    # Remove all tied predictions from top_predictions
                    top_predictions = [p for p in top_predictions if p not in tied_predictions]

                    # Randomly select from tied predictions to fill remaining slots
                    import random
                    random.seed(42)  # For reproducibility
                    slots_remaining = allowed_max - len(top_predictions)
                    selected_tied = random.sample(tied_predictions, slots_remaining)

                    # Add the randomly selected tied predictions
                    top_predictions.extend(selected_tied)

            # If "Not mentioned" is in predictions but there are other predictions,
            # remove "Not mentioned" (unless it's the only prediction)
            if len(top_predictions) > 1 and "Not mentioned" in top_predictions:
                top_predictions.remove("Not mentioned")

            # Create a single, combined prediction
            combined_prediction = ", ".join(top_predictions)

            # Initialize options_en as None
            options_en = None

            # If validation_df is provided, try to get options_en from it
            if validation_df is not None:
                # Find matching rows in validation_df
                matching_rows = validation_df[(validation_df['encounter_id'] == encounter_id) & 
                                             (validation_df['base_qid'] == base_qid)]
                if not matching_rows.empty:
                    # Get options_en from the first matching row
                    options_en = matching_rows.iloc[0].get('options_en')

            result_dict = {
                "encounter_id": encounter_id,
                "base_qid": base_qid,
                "image_ids": image_ids,
                "unique_predictions": top_predictions,  # Now limited to max allowed
                "combined_prediction": combined_prediction,
                "all_raw_predictions": all_cleaned_predictions,
                "all_sorted_predictions": all_sorted_predictions
            }

            # Add options_en only if it's available
            if options_en is not None:
                result_dict["options_en"] = options_en

            aggregated_results.append(result_dict)

        # Convert to DataFrame
        aggregated_df = pd.DataFrame(aggregated_results)

        return aggregated_df

In [103]:
# Create a simple class to simulate command-line arguments
class Args:
    def __init__(self):
#         self.test = True  # Set to True to run in test mode
#         self.skip_data_prep = False  # Set to True to skip data preparation
#         self.batch_size = 100
#         self.max_samples = 50  # Limit number of samples for inference
#         self.model_path = "outputs/merged_model"
        self.test = False  # Set to False to run on full dataset
        self.skip_data_prep = False  # Set to False to process all data
        self.batch_size = 100
        self.max_samples = None  # No limit on samples
        self.model_path = "outputs/merged_model"
        
args = Args()
processed_val_dir = "outputs/processed_val_data"
os.makedirs(processed_val_dir, exist_ok=True)

if not args.skip_data_prep:
    # Prepare data for validation
    print("Preparing validation dataset...")
    val_df = prepare_validation_data()
    
    # Subset for testing if requested
    if args.test:
        print("Running in test mode with a small subset of data...")
        test_size = min(500, len(val_df))
        val_df = val_df.head(test_size)
    
    # Process validation data
    # Clear any existing processed data
    import shutil
    if os.path.exists(processed_val_dir):
        shutil.rmtree(processed_val_dir)
        os.makedirs(processed_val_dir)
    
    total_examples = preprocess_validation_dataset(val_df, batch_size=args.batch_size, save_dir=processed_val_dir)
    print(f"Total processed validation examples: {total_examples}")
else:
    print("Skipping data preparation...")

Preparing validation dataset...
Preparing validation data...


Creating validation dataset: 0it [00:00, ?it/s]

Validation dataset created with 1413 entries
Processing batch 1/15


Batch 0:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 100 examples so far
Processing batch 2/15


Batch 1:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 200 examples so far
Processing batch 3/15


Batch 2:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 300 examples so far
Processing batch 4/15


Batch 3:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 400 examples so far
Processing batch 5/15


Batch 4:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 500 examples so far
Processing batch 6/15


Batch 5:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 600 examples so far
Processing batch 7/15


Batch 6:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 700 examples so far
Processing batch 8/15


Batch 7:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 800 examples so far
Processing batch 9/15


Batch 8:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 900 examples so far
Processing batch 10/15


Batch 9:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 1000 examples so far
Processing batch 11/15


Batch 10:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 1100 examples so far
Processing batch 12/15


Batch 11:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 1200 examples so far
Processing batch 13/15


Batch 12:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 1300 examples so far
Processing batch 14/15


Batch 13:   0%|          | 0/100 [00:00<?, ?it/s]

Processed 1400 examples so far
Processing batch 15/15


Batch 14:   0%|          | 0/13 [00:00<?, ?it/s]

Processed 1413 examples so far
Total processed validation examples: 1413


In [104]:
val_df.head()

Unnamed: 0,encounter_id,base_qid,image_id,image_path,valid_answers,valid_indices,question_text,query_title_en,query_content_en,author_id,options_en,question_type_en,question_category_en,is_multi_label
0,ENC00852,CQID010,IMG_ENC00852_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[limited area],[1],How much of the body is affected?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[single spot, limited area, widespread, Not me...",Site,General,False
1,ENC00852,CQID010,IMG_ENC00852_00002.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[limited area],[1],How much of the body is affected?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[single spot, limited area, widespread, Not me...",Site,General,False
2,ENC00852,CQID011,IMG_ENC00852_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,"[upper extremities, head]","[2, 0]",1 Where is the affected area?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[head, neck, upper extremities, lower extremit...",Site Location,General,True
3,ENC00852,CQID011,IMG_ENC00852_00002.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,"[upper extremities, head]","[2, 0]",1 Where is the affected area?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[head, neck, upper extremities, lower extremit...",Site Location,General,True
4,ENC00852,CQID012,IMG_ENC00852_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[size of palm],[1],1 How large are the affected areas? Please spe...,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[size of thumb nail, size of palm, larger area...",Size,General,False


In [106]:
not_mentioned_rows = val_df[val_df['valid_answers'].apply(lambda x: 'Not mentioned' in x)]
not_mentioned_rows.head(5) # works appropriately

Unnamed: 0,encounter_id,base_qid,image_id,image_path,valid_answers,valid_indices,question_text,query_title_en,query_content_en,author_id,options_en,question_type_en,question_category_en,is_multi_label
6,ENC00852,CQID015,IMG_ENC00852_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[Not mentioned],[6],When did the patient first notice the issue?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[within hours, within days, within weeks, with...",Onset,General,False
7,ENC00852,CQID015,IMG_ENC00852_00002.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[Not mentioned],[6],When did the patient first notice the issue?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[within hours, within days, within weeks, with...",Onset,General,False
16,ENC00852,CQID036,IMG_ENC00852_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[Not mentioned],[2],What is the skin lesion texture?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[smooth, rough, Not mentioned]",Texture,Skin Specific,False
17,ENC00852,CQID036,IMG_ENC00852_00002.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[Not mentioned],[2],What is the skin lesion texture?,Is this Vitiligo? Please see picture.,"The patient is a middle age female, about 50 y...",U00295,"[smooth, rough, Not mentioned]",Texture,Skin Specific,False
73,ENC00855,CQID025,IMG_ENC00855_00001.jpg,/storage/coda1/p-dsgt_clef2025/0/kthakrar3/med...,[Not mentioned],[2],Is there any associated itching with the skin ...,Sharing a disease commonly seen,"Patient: male, 32 years old. Got the disease ...",U00904,"[yes, no, Not mentioned]",Itch,Skin Specific,False


In [107]:
# Load model and run inference
print(f"Loading model from {args.model_path}...")
inference = MedicalImageInference(args.model_path)

# Run inference on processed data
predictions_file = os.path.join(OUTPUT_DIR, 
                               f"val_predictions{'_test' if args.test else ''}.csv")
print(f"Running inference (max_samples={args.max_samples if args.max_samples else 'all'})...")
predictions_df = inference.batch_predict(processed_val_dir, predictions_file, max_samples=args.max_samples)

Loading model from outputs/merged_model...


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

Running inference (max_samples=all)...


Processing batches:   0%|          | 0/15 [00:00<?, ?it/s]

Predicting val_batch_0.pkl:   0%|          | 0/100 [00:00<?, ?it/s]



Predicting val_batch_1.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_10.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_11.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_12.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_13.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_14.pkl:   0%|          | 0/13 [00:00<?, ?it/s]

Predicting val_batch_2.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_3.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_4.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_5.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_6.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_7.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_8.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

Predicting val_batch_9.pkl:   0%|          | 0/100 [00:00<?, ?it/s]

In [108]:
# Aggregate predictions
print("Aggregating predictions...")

# Load the validation dataset
val_dataset = pd.read_csv(os.path.join(OUTPUT_DIR, "val_dataset.csv"))

# Run the inference

# Aggregate predictions with options_en
aggregated_df = inference.aggregate_predictions(predictions_df, validation_df=val_dataset)

# Save aggregated results
aggregated_file = os.path.join(OUTPUT_DIR, 
                              f"aggregated_predictions{'_test' if args.test else ''}.csv")
aggregated_df.to_csv(aggregated_file, index=False)

print(f"Inference complete. Results saved to {predictions_file} and {aggregated_file}")

Aggregating predictions...


Aggregating predictions:   0%|          | 0/504 [00:00<?, ?it/s]

Inference complete. Results saved to /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/val_predictions.csv and /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/aggregated_predictions.csv


In [109]:
# Print sample of predictions for inspection
print("\nSample of raw predictions:")
predictions_df.head(3)


Sample of raw predictions:


Unnamed: 0,encounter_id,base_qid,image_id,prediction
0,ENC00852,CQID010,IMG_ENC00852_00001.jpg,widespread
1,ENC00852,CQID010,IMG_ENC00852_00002.jpg,widespread
2,ENC00852,CQID011,IMG_ENC00852_00001.jpg,upper extremities


In [110]:
print("\nSample of aggregated predictions:")
aggregated_df.head(3)


Sample of aggregated predictions:


Unnamed: 0,encounter_id,base_qid,image_ids,unique_predictions,combined_prediction,all_raw_predictions,all_sorted_predictions,options_en
0,ENC00852,CQID010,"[IMG_ENC00852_00001.jpg, IMG_ENC00852_00002.jpg]",[widespread],widespread,"[widespread, widespread]","[(widespread, 2)]","['single spot', 'limited area', 'widespread', ..."
1,ENC00852,CQID011,"[IMG_ENC00852_00001.jpg, IMG_ENC00852_00002.jpg]",[upper extremities],upper extremities,"[upper extremities, upper extremities]","[(upper extremities, 2)]","['head', 'neck', 'upper extremities', 'lower e..."
2,ENC00852,CQID012,"[IMG_ENC00852_00001.jpg, IMG_ENC00852_00002.jpg]",[size of palm],size of palm,"[size of palm, size of palm]","[(size of palm, 2)]","['size of thumb nail', 'size of palm', 'larger..."


In [127]:
# # Load raw predictions
# # predictions = pd.read_csv("/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/val_predictions_test.csv")
# predictions = pd.read_csv("/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/val_predictions.csv")

# # Load aggregated predictions
# # aggregated = pd.read_csv("/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/aggregated_predictions_test.csv")
# aggregated = pd.read_csv("/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/aggregated_predictions.csv")

# Load raw predictions
predictions = pd.read_csv(f"/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/val_predictions{'_test' if args.test else ''}.csv")

# Load aggregated predictions
aggregated = pd.read_csv(f"/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/aggregated_predictions{'_test' if args.test else ''}.csv")

In [128]:
# Display sample predictions
print("Sample of individual predictions:")
predictions.head(5)[["encounter_id", "base_qid", "image_id", "prediction"]]

Sample of individual predictions:


Unnamed: 0,encounter_id,base_qid,image_id,prediction
0,ENC00852,CQID010,IMG_ENC00852_00001.jpg,widespread
1,ENC00852,CQID010,IMG_ENC00852_00002.jpg,widespread
2,ENC00852,CQID011,IMG_ENC00852_00001.jpg,upper extremities
3,ENC00852,CQID011,IMG_ENC00852_00002.jpg,upper extremities
4,ENC00852,CQID012,IMG_ENC00852_00001.jpg,size of palm


In [129]:
# Filter predictions that have commas (indicating multiple answers)
multi_answer_preds = predictions[predictions['prediction'].str.contains(',', na=False)]

# Display sample of multi-answer predictions
print("Sample of predictions with multiple answers:")
multi_answer_preds[["encounter_id", "base_qid", "image_id", "prediction"]].head(20)

Sample of predictions with multiple answers:


Unnamed: 0,encounter_id,base_qid,image_id,prediction
12,ENC00852,CQID034,IMG_ENC00852_00001.jpg,"red, white"
13,ENC00852,CQID034,IMG_ENC00852_00002.jpg,"red, white"
14,ENC00852,CQID035,IMG_ENC00852_00001.jpg,"Leukoplakia, erythma"
15,ENC00852,CQID035,IMG_ENC00852_00002.jpg,"Leukoplakia, erythma"
16,ENC00852,CQID036,IMG_ENC00852_00001.jpg,"Leukoplakia, rough"
17,ENC00852,CQID036,IMG_ENC00852_00002.jpg,"Leukoplakia, rough"
20,ENC00853,CQID011,IMG_ENC00853_00001.jpg,"* recurrent,\n* upper extremities,\n* lo..."
21,ENC00853,CQID011,IMG_ENC00853_00002.jpg,"çeşitli yerlerde olduğu gibi, nereye olursa ol..."
22,ENC00853,CQID012,IMG_ENC00853_00001.jpg,"çeşitli yerlerde olduğu gibi, daha çok ten ren..."
26,ENC00853,CQID020,IMG_ENC00853_00001.jpg,"Based on the image, 1 What label best describe..."


In [130]:
# Count how many multi-answer predictions we have
print(f"\nTotal multi-answer predictions: {len(multi_answer_preds)}")

# See which questions tend to have multiple answers
print("\nMulti-answer predictions by question type:")
multi_answer_preds['base_qid'].value_counts()


Total multi-answer predictions: 335

Multi-answer predictions by question type:


base_qid
CQID034    77
CQID020    68
CQID012    50
CQID011    45
CQID036    41
CQID025    23
CQID035    13
CQID015    11
CQID010     7
Name: count, dtype: int64

In [131]:
print("\nSample of aggregated predictions:")
aggregated.head(5)


Sample of aggregated predictions:


Unnamed: 0,encounter_id,base_qid,image_ids,unique_predictions,combined_prediction,all_raw_predictions,all_sorted_predictions,options_en
0,ENC00852,CQID010,"['IMG_ENC00852_00001.jpg', 'IMG_ENC00852_00002...",['widespread'],widespread,"['widespread', 'widespread']","[('widespread', 2)]","['single spot', 'limited area', 'widespread', ..."
1,ENC00852,CQID011,"['IMG_ENC00852_00001.jpg', 'IMG_ENC00852_00002...",['upper extremities'],upper extremities,"['upper extremities', 'upper extremities']","[('upper extremities', 2)]","['head', 'neck', 'upper extremities', 'lower e..."
2,ENC00852,CQID012,"['IMG_ENC00852_00001.jpg', 'IMG_ENC00852_00002...",['size of palm'],size of palm,"['size of palm', 'size of palm']","[('size of palm', 2)]","['size of thumb nail', 'size of palm', 'larger..."
3,ENC00852,CQID015,"['IMG_ENC00852_00001.jpg', 'IMG_ENC00852_00002...",['Not mentioned'],Not mentioned,"['Not mentioned', 'Not mentioned']","[('Not mentioned', 2)]","['within hours', 'within days', 'within weeks'..."
4,ENC00852,CQID020,"['IMG_ENC00852_00001.jpg', 'IMG_ENC00852_00002...",['Leukoplakia'],Leukoplakia,"['Leukoplakia', 'Leukoplakia']","[('Leukoplakia', 2)]","['raised or bumpy', 'flat', 'skin loss or sunk..."


In [132]:
# Look up the aggregated result for this question
agg_result = aggregated[
    (aggregated['encounter_id'] == 'ENC00852') & 
    (aggregated['base_qid'] == 'CQID034')
]
agg_result

Unnamed: 0,encounter_id,base_qid,image_ids,unique_predictions,combined_prediction,all_raw_predictions,all_sorted_predictions,options_en
6,ENC00852,CQID034,"['IMG_ENC00852_00001.jpg', 'IMG_ENC00852_00002...",['red'],red,"['red', 'white', 'red', 'white']","[('red', 2), ('white', 2)]","['normal skin color', 'pink', 'red', 'brown', ..."


In [133]:
# Look up the aggregated result for this question
agg_result = aggregated[
    (aggregated['encounter_id'] == 'ENC00852') & 
    (aggregated['base_qid'] == 'CQID035')
]
agg_result

Unnamed: 0,encounter_id,base_qid,image_ids,unique_predictions,combined_prediction,all_raw_predictions,all_sorted_predictions,options_en
7,ENC00852,CQID035,"['IMG_ENC00852_00001.jpg', 'IMG_ENC00852_00002...",['Leukoplakia'],Leukoplakia,"['Leukoplakia', 'erythma', 'Leukoplakia', 'ery...","[('Leukoplakia', 2), ('erythma', 2)]","['single', 'multiple (please specify)', 'Not m..."


In [134]:
# Look up the aggregated result for this question
agg_result = aggregated[
    (aggregated['encounter_id'] == 'ENC00853') & 
    (aggregated['base_qid'] == 'CQID011')
]
agg_result

Unnamed: 0,encounter_id,base_qid,image_ids,unique_predictions,combined_prediction,all_raw_predictions,all_sorted_predictions,options_en
10,ENC00853,CQID011,"['IMG_ENC00853_00001.jpg', 'IMG_ENC00853_00002...","['çeşitli yerlerde olduğu gibi', '* recurren...","çeşitli yerlerde olduğu gibi, * recurrent, n...","['* recurrent', '* upper extremities', '* ...","[('* recurrent', 1), ('* upper extremities...","['head', 'neck', 'upper extremities', 'lower e..."


In [135]:
# Count frequency of different answers
answer_counts = predictions["prediction"].value_counts().head(10)
print("\nMost common predictions:")
print(answer_counts)


Most common predictions:
prediction
Not mentioned                140
widespread                    82
multiple (please specify)     76
multiple years                37
yes                           36
size of palm                  24
size of palm, larger area     22
rough                         21
upper extremities             19
limited area                  17
Name: count, dtype: int64


In [136]:
def format_predictions_for_official_eval_with_display(aggregated_df, output_file):
    """
    Format predictions as expected by the official evaluation script,
    mapping text answers to indices and distributing multiple answers
    across question variants when appropriate.
    Also displays the text values alongside their indices for verification.
    """
    # Define the question IDs and their allowed variants
    QIDS = [
        "CQID010-001",  # how much of body is affected (single answer)
        "CQID011-001", "CQID011-002", "CQID011-003", "CQID011-004", "CQID011-005", "CQID011-006",  # multiple answers allowed
        "CQID012-001", "CQID012-002", "CQID012-003", "CQID012-004", "CQID012-005", "CQID012-006",  # multiple answers allowed
        "CQID015-001",  # single answer
        "CQID020-001", "CQID020-002", "CQID020-003", "CQID020-004", "CQID020-005", 
        "CQID020-006", "CQID020-007", "CQID020-008", "CQID020-009",  # multiple answers allowed
        "CQID025-001",  # single answer
        "CQID034-001",  # single answer
        "CQID035-001",  # single answer
        "CQID036-001",  # single answer
    ]
    
    # Create a mapping of question base IDs to their allowed variants
    qid_variants = {}
    for qid in QIDS:
        base_qid, variant = qid.split('-')
        if base_qid not in qid_variants:
            qid_variants[base_qid] = []
        qid_variants[base_qid].append(qid)
    
    # Get all required base QIDs for a complete encounter
    required_base_qids = set(qid.split('-')[0] for qid in QIDS)
    
    formatted_predictions = []
    display_info = []
    
    # Group by encounter_id
    for encounter_id, group in aggregated_df.groupby('encounter_id'):
        # Get all base_qids for this encounter
        encounter_base_qids = set(group['base_qid'].unique())
        
        # Skip encounters that don't have all required questions
        if not required_base_qids.issubset(encounter_base_qids):
            print(f"Skipping encounter {encounter_id} - missing required questions")
            continue
        
        # Create a prediction entry for this encounter
        pred_entry = {'encounter_id': encounter_id}
        encounter_display = {'encounter_id': encounter_id, 'questions': []}
        
        # Process each question for this encounter
        for _, row in group.iterrows():
            base_qid = row['base_qid']
            
            # Skip if we don't have variants defined for this question
            if base_qid not in qid_variants:
                continue
            
            # Get the options list for this question
            options = safe_convert_options(row['options_en'])

#             options = row['options_en']
#             if isinstance(options, str):
#                 try:
#                     options = eval(options)
#                 except:
#                     options = options.split(',')
            
            # Find the index of "Not mentioned" in the options
            not_mentioned_index = None
            for i, opt in enumerate(options):
                if opt == "Not mentioned":
                    not_mentioned_index = i
                    break
            
            # If "Not mentioned" is not in the options, default to the last option
            if not_mentioned_index is None:
                not_mentioned_index = len(options) - 1
            
            # Get predictions
            if isinstance(row['unique_predictions'], list):
                predictions = row['unique_predictions']
            else:
                try:
                    predictions = eval(row['unique_predictions'])
                except:
                    predictions = [row['unique_predictions']]
            
            # Map text predictions to indices
            prediction_indices = []
            prediction_texts = []
            
            for pred in predictions:
                pred_text = str(pred).strip()
                prediction_texts.append(pred_text)
                
                # Find index of the prediction in options
                found = False
                for i, option in enumerate(options):
                    if pred_text.lower() == option.lower():
                        prediction_indices.append(i)
                        found = True
                        break
                
                # If prediction not found in options, use index 100
                if not found:
                    prediction_indices.append(100)
            
            # Remove duplicates while preserving order
            unique_indices = []
            unique_texts = []
            for idx, text in zip(prediction_indices, prediction_texts):
                if idx not in unique_indices:
                    unique_indices.append(idx)
                    unique_texts.append(text)
            
            # If 100 is in the list along with valid indices, remove 100
            if len(unique_indices) > 1 and 100 in unique_indices:
                idx_to_remove = unique_indices.index(100)
                unique_indices.remove(100)
                unique_texts.pop(idx_to_remove)
            
            # Get the available variants for this question
            available_variants = qid_variants[base_qid]
            
            # Store info for display
            question_display = {
                'base_qid': base_qid,
                'predicted_texts': unique_texts,
                'predicted_indices': unique_indices,
                'options': options,
                'not_mentioned_index': not_mentioned_index,
                'variant_assignments': {}
            }
            
            # For single-answer questions (with only one variant)
            if len(available_variants) == 1:
                if unique_indices:
                    # Store as a single integer, not a list
                    pred_entry[available_variants[0]] = unique_indices[0]
                    question_display['variant_assignments'][available_variants[0]] = {
                        'index': unique_indices[0],
                        'text': unique_texts[0] if unique_texts else "None"
                    }
                else:
                    # Default to "Not mentioned" if no prediction
                    pred_entry[available_variants[0]] = not_mentioned_index
                    question_display['variant_assignments'][available_variants[0]] = {
                        'index': not_mentioned_index,
                        'text': "Not mentioned"
                    }
            
            # For multi-answer questions
            else:
                # Distribute answers across available variants
                for i, idx in enumerate(unique_indices):
                    if i < len(available_variants):
                        # Store each answer as a single integer, not a list
                        pred_entry[available_variants[i]] = idx
                        question_display['variant_assignments'][available_variants[i]] = {
                            'index': idx,
                            'text': unique_texts[i] if i < len(unique_texts) else "None"
                        }
                
                # Fill remaining variants with a default value (usually "Not mentioned")
                for i in range(len(unique_indices), len(available_variants)):
                    # Use correct "Not mentioned" index for this question
                    pred_entry[available_variants[i]] = not_mentioned_index
                    question_display['variant_assignments'][available_variants[i]] = {
                        'index': not_mentioned_index,
                        'text': "Not mentioned"
                    }
            
            encounter_display['questions'].append(question_display)
        
        formatted_predictions.append(pred_entry)
        display_info.append(encounter_display)
    
    if not formatted_predictions:
        print("Warning: No complete encounters found in the data!")
    
    # Save to JSON file
    with open(output_file, 'w') as f:
        json.dump(formatted_predictions, f, indent=2)
    
    # Display information about the predictions
    for encounter in display_info:
        print(f"\nEncounter: {encounter['encounter_id']}")
        for question in encounter['questions']:
            print(f"  Question: {question['base_qid']}")
            print(f"  Predicted texts: {question['predicted_texts']}")
            print(f"  Predicted indices: {question['predicted_indices']}")
            print(f"  'Not mentioned' index: {question['not_mentioned_index']}")
            print("  Variant assignments:")
            for variant, assignment in question['variant_assignments'].items():
                print(f"    {variant}: index={assignment['index']} ({assignment['text']})")
            print(f"  Available options: {question['options']}")
            print()
    
    print(f"Formatted predictions saved to {output_file} ({len(formatted_predictions)} complete encounters)")
    return formatted_predictions

In [137]:
# Format and save predictions for official evaluation
predictions_json = os.path.join(OUTPUT_DIR, f"data_cvqa_sys{'_test' if args.test else ''}.json")
format_predictions_for_official_eval_with_display(aggregated_df, predictions_json)

print(f"Formatted predictions saved to {predictions_json}")


Encounter: ENC00852
  Question: CQID010
  Predicted texts: ['widespread']
  Predicted indices: [2]
  'Not mentioned' index: 3
  Variant assignments:
    CQID010-001: index=2 (widespread)
  Available options: ['single spot', 'limited area', 'widespread', 'Not mentioned']

  Question: CQID011
  Predicted texts: ['upper extremities']
  Predicted indices: [2]
  'Not mentioned' index: 7
  Variant assignments:
    CQID011-001: index=2 (upper extremities)
    CQID011-002: index=7 (Not mentioned)
    CQID011-003: index=7 (Not mentioned)
    CQID011-004: index=7 (Not mentioned)
    CQID011-005: index=7 (Not mentioned)
    CQID011-006: index=7 (Not mentioned)
  Available options: ['head', 'neck', 'upper extremities', 'lower extremities', 'chest/abdomen', 'back', 'other (please specify)', 'Not mentioned']

  Question: CQID012
  Predicted texts: ['size of palm']
  Predicted indices: [1]
  'Not mentioned' index: 3
  Variant assignments:
    CQID012-001: index=1 (size of palm)
    CQID012-002: inde

In [138]:
# Filter validation DataFrame to get the specific question and encounter
specific_question = val_dataset[(val_dataset['encounter_id'] == 'ENC00853') & 
                               (val_dataset['base_qid'] == 'CQID012')]

# Display all relevant columns
print("Question Information:")
print(f"Question text: {specific_question['question_text'].values[0]}")
print(f"Question type: {specific_question['question_type_en'].values[0]}")
print(f"Question category: {specific_question['question_category_en'].values[0]}")
print(f"Options: {specific_question['options_en'].values[0]}")
print(f"Multi-label: {specific_question['is_multi_label'].values[0]}")
print("\nClinical Context:")
print(f"Query title: {specific_question['query_title_en'].values[0]}")
print(f"Query content: {specific_question['query_content_en'].values[0]}")
print("\nImage information:")
print(f"Image ID: {specific_question['image_id'].values[0]}")
print(f"Image path: {specific_question['image_path'].values[0]}")

Question Information:
Question text: 1 How large are the affected areas? Please specify which affected area for each selection.
Question type: Size
Question category: General
Options: ['size of thumb nail', 'size of palm', 'larger area', 'Not mentioned']
Multi-label: False

Clinical Context:
Query title: Please help take a look, what kind of skin disease is this?
Query content: Suffering from the disease for more than 10 years.  It is recurrent and is vey itchy!  It happens wherever I scratch in some places.

Image information:
Image ID: IMG_ENC00853_00001.jpg
Image path: /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/2025_dataset/valid/images_valid/IMG_ENC00853_00001.jpg


In [139]:
# allows us to double check the saved file

# Load the formatted predictions JSON
# with open('/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/data_cvqa_sys_test.json', 'r') as f:
#     formatted_preds = json.load(f)
    
# with open('/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/data_cvqa_sys.json', 'r') as f:
#     formatted_preds = json.load(f)

# Load the formatted predictions JSON
with open(f'/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/data_cvqa_sys{"_test" if args.test else ""}.json', 'r') as f:
    formatted_preds = json.load(f)
    
# Display the first 3 entries
print("First 3 prediction entries:")
for i in range(min(3, len(formatted_preds))):
    print(f"\nPrediction {i+1}:")
    pprint(formatted_preds[i])

# Show an example of answers not in options (if any)
print("\nLooking for predictions with index 100 (not in options):")
found = False
for entry in formatted_preds:
    for key, value in entry.items():
        if key != 'encounter_id':  # Skip the encounter_id
            if (isinstance(value, list) and 100 in value) or value == 100:
                print(f"\nFound prediction not in options:")
                print(f"Encounter: {entry['encounter_id']}")
                print(f"Question: {key}")
                print(f"Prediction indices: {value}")
                
                # Load original predictions for this encounter
                agg_df = pd.read_csv('/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/aggregated_predictions_test.csv')
                base_qid = key.split('-')[0]
                encounter = entry['encounter_id']
                match = agg_df[(agg_df['encounter_id'] == encounter) & (agg_df['base_qid'] == base_qid)]
                if not match.empty:
                    print(f"Original prediction text: {match['combined_prediction'].values[0]}")
                    print(f"Available options: {match['options_en'].values[0]}")
                found = True
                break
    if found:
        break

if not found:
    print("No predictions with index 100 found in the first few entries.")

# Show statistics
question_counts = {}
for entry in formatted_preds:
    qid_count = len(entry) - 1  # Subtract 1 for encounter_id
    if qid_count in question_counts:
        question_counts[qid_count] += 1
    else:
        question_counts[qid_count] = 1

print("\nNumber of questions per encounter:")
for count, num_entries in sorted(question_counts.items()):
    print(f"{count} questions: {num_entries} encounters")

First 3 prediction entries:

Prediction 1:
{'CQID010-001': 2,
 'CQID011-001': 2,
 'CQID011-002': 7,
 'CQID011-003': 7,
 'CQID011-004': 7,
 'CQID011-005': 7,
 'CQID011-006': 7,
 'CQID012-001': 1,
 'CQID012-002': 3,
 'CQID012-003': 3,
 'CQID012-004': 3,
 'CQID012-005': 3,
 'CQID012-006': 3,
 'CQID015-001': 6,
 'CQID020-001': 100,
 'CQID020-002': 9,
 'CQID020-003': 9,
 'CQID020-004': 9,
 'CQID020-005': 9,
 'CQID020-006': 9,
 'CQID020-007': 9,
 'CQID020-008': 9,
 'CQID020-009': 9,
 'CQID025-001': 2,
 'CQID034-001': 2,
 'CQID035-001': 100,
 'CQID036-001': 100,
 'encounter_id': 'ENC00852'}

Prediction 2:
{'CQID010-001': 2,
 'CQID011-001': 100,
 'CQID011-002': 7,
 'CQID011-003': 7,
 'CQID011-004': 7,
 'CQID011-005': 7,
 'CQID011-006': 7,
 'CQID012-001': 100,
 'CQID012-002': 3,
 'CQID012-003': 3,
 'CQID012-004': 3,
 'CQID012-005': 3,
 'CQID012-006': 3,
 'CQID015-001': 5,
 'CQID020-001': 4,
 'CQID020-002': 9,
 'CQID020-003': 9,
 'CQID020-004': 9,
 'CQID020-005': 9,
 'CQID020-006': 9,
 'CQID020-

Run evaluation in terminal with the following command: 

```
python evaluate_new.py outputs/data_cvqa_sys_test.json
OR python evaluate_new.py outputs/data_cvqa_sys.json
```

Steps for submitting on platform: 
```
# Create empty masks_preds directory in the output folder
mkdir -p /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs/masks_preds

# Create the zip file with both components
cd /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/outputs
zip -r mysubmission.zip data_cvqa_sys.json masks_preds/
```