In [35]:
# Medical Image Inference with Baseline Model

## Imports and Setup
import ast
import gc
import os
import pickle
import shutil
import sys

import pandas as pd
import torch
from dotenv import load_dotenv
from PIL import Image
from tqdm.auto import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

In [36]:
# Free memory: collect unreachable Python objects first, then hand PyTorch's
# cached-but-unused CUDA blocks back to the driver.
for _release in (gc.collect, torch.cuda.empty_cache):
    _release()

In [37]:
# Set cache directory: point the Hugging Face cache at the project's shared storage.
HF_HOME_DIR = "/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.hf_cache"

print("TRANSFORMERS_CACHE (before):", os.getenv("TRANSFORMERS_CACHE"))
# Remove the legacy variable so HF_HOME is the only cache setting in effect
# (assumption: TRANSFORMERS_CACHE would otherwise take precedence — confirm).
os.environ.pop("TRANSFORMERS_CACHE", None)
# NOTE(review): huggingface_hub reads HF_HOME when it is first imported, and
# transformers is already imported in the setup cell, so this assignment may not
# redirect downloads in the current kernel — confirm, or set HF_HOME before those imports.
os.environ["HF_HOME"] = HF_HOME_DIR

# Print environment info
print("TRANSFORMERS_CACHE (after):", os.getenv("TRANSFORMERS_CACHE"))
print("HF_HOME:", os.getenv("HF_HOME"))
print("Python executable:", sys.executable)

print(f"Python version: {sys.version}")
print(f"Python version info: {sys.version_info}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

TRANSFORMERS_CACHE: None
TRANSFORMERS_CACHE: None
HF_HOME: /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.hf_cache
/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.venv/bin/python
Python version: 3.10.10 (main, Apr 15 2024, 11:52:16) [GCC 11.4.1 20230605 (Red Hat 11.4.1-2)]
Python version info: sys.version_info(major=3, minor=10, micro=10, releaselevel='final', serial=0)
PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
CUDA device count: 1
Current CUDA device: 0
CUDA device name: NVIDIA A100 80GB PCIe


In [11]:
train_csv_file = os.path.join("2025_dataset", "train", "final_df_2.csv")
train_images_dir = os.path.join("2025_dataset", "train", "images_train")

train_df = pd.read_csv(train_csv_file)

# These columns store Python list literals as text. ast.literal_eval only accepts
# literals, so a malformed or malicious cell cannot execute code the way eval() could.
train_df['image_ids'] = train_df['image_ids'].apply(ast.literal_eval)
train_df['responses_en'] = train_df['responses_en'].apply(ast.literal_eval)

train_df.head()

Unnamed: 0,encounter_id,qid,answer_index,question_en,options_en,question_type_en,question_category_en,answer_text,author_id,image_ids,responses,query_title_en,query_content_en,image_paths,responses_en
0,ENC00001,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U04473,"[IMG_ENC00001_00001.jpg, IMG_ENC00001_00002.jpg]","[{'author_id': 'U00217', 'content_zh': '银屑病，似与...",Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,[Psoriasis seems to have no relation to pleura...
1,ENC00002,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U06063,"[IMG_ENC00002_00001.jpg, IMG_ENC00002_00002.jp...","[{'author_id': 'U11305', 'content_zh': '脚气', '...",What is on the bottom of the right foot?,"The patient is a 50-year-old male, who has bee...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,[Beriberi]
2,ENC00003,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U00780,"[IMG_ENC00003_00001.jpg, IMG_ENC00003_00002.jp...","[{'author_id': 'U01131', 'content_zh': '瘙痒症，有无...",Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,"[Pruritus, is there any other special medical ..."
3,ENC00004,CQID010-001,2,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,widespread,U00209,"[IMG_ENC00004_00001.jpg, IMG_ENC00004_00002.jpg]","[{'author_id': 'U06715', 'content_zh': '肢端角化病？...",Skin Disease,"Male, 15 years old, keratosis on both palms, s...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,"[Acrokeratosis?, Progressive Symmetrical Eryth..."
4,ENC00005,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U09050,[IMG_ENC00005_00001.jpg],"[{'author_id': 'U09402', 'content_zh': '是否神经性皮...",Perifollicular atrophy?,"Young female, silver-gray dot-like atrophy spo...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,"[Is it neurodermatitis?, Impotence?, Lichen Sc..."


In [12]:
# Cap the number of rows for quick debugging; set to None to use every row.
N_DEBUG_SAMPLES = 10
if N_DEBUG_SAMPLES is not None:
    train_df = train_df.head(N_DEBUG_SAMPLES)
print(f"Using {len(train_df)} samples for training")

Using 10 samples for training


In [13]:
# Keep only the fields the prompt builder reads.
train_df = train_df.loc[:, [
    'encounter_id',
    'qid',
    'question_en',
    'options_en',
    'answer_text',
    'image_ids',
    'question_type_en',
    'question_category_en',
]]

print(f"Filtered dataframe shape: {train_df.shape}")
print("Columns:", list(train_df.columns))

display("Sample row:", train_df.head(3))

Filtered dataframe shape: (10, 8)
Columns: ['encounter_id', 'qid', 'question_en', 'options_en', 'answer_text', 'image_ids', 'question_type_en', 'question_category_en']


'Sample row:'

Unnamed: 0,encounter_id,qid,question_en,options_en,answer_text,image_ids,question_type_en,question_category_en
0,ENC00001,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00001_00001.jpg, IMG_ENC00001_00002.jpg]",Site,General
1,ENC00002,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00002_00001.jpg, IMG_ENC00002_00002.jp...",Site,General
2,ENC00003,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00003_00001.jpg, IMG_ENC00003_00002.jp...",Site,General


In [14]:
def process_batch(batch_df, batch_idx, save_dir, images_dir=None):
    """Build prompt records for one slice of the dataset and pickle them to ``save_dir``.

    NOTE(review): this definition is redefined verbatim by the next cell (In [15]),
    which shadows it; only one copy should be kept.

    Only the first image listed in ``image_ids`` is used. A row is skipped when it
    has no images, its first image file is missing or cannot be decoded, or any other
    error occurs while building its record (the last three cases are reported).

    Args:
        batch_df: DataFrame slice with ``encounter_id``, ``qid``, ``question_en``,
            ``options_en`` (a list, or its Python-literal string form),
            ``answer_text``, ``image_ids`` (a list of file names) and optionally
            ``question_type_en`` / ``question_category_en``.
        batch_idx: Index used in the progress-bar label and the output file name.
        save_dir: Output directory (created if needed); receives ``batch_{batch_idx}.pkl``.
        images_dir: Directory holding the image files. Defaults to the notebook-level
            ``train_images_dir`` so existing calls keep working.

    Returns:
        Number of records written to the pickle file.
    """
    if images_dir is None:
        images_dir = train_images_dir
    os.makedirs(save_dir, exist_ok=True)
    batch_data = []

    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {batch_idx}"):
        try:
            image_ids = row['image_ids']
            if not image_ids:
                continue

            # Only the first image of the encounter is used.
            image_path = os.path.join(images_dir, image_ids[0])
            if not os.path.exists(image_path):
                print(f"Warning: Image not found at {image_path}")
                continue

            # Decode the image now so corrupt files are dropped here rather than
            # failing later when the record is used.
            try:
                with Image.open(image_path) as img:
                    img.load()
            except Exception as e:
                print(f"Corrupt or unreadable image at {image_path} — {e}")
                continue

            # options_en is stored as a Python list literal; literal_eval parses it
            # without executing code (eval() would run anything in the file).
            options = row['options_en']
            if isinstance(options, str):
                options = ast.literal_eval(options)
            options_text = ", ".join(f"{i}. {opt}" for i, opt in enumerate(options, start=1))

            question_type = row.get('question_type_en', '')
            question_category = row.get('question_category_en', '')
            metadata = f"Type: {question_type}, Category: {question_category}"
            query_text = f"Question: Based on the image, {row['question_en']}\nQuestion Metadata: {metadata}\nOptions: {options_text}"

            batch_data.append({
                "encounter_id": row['encounter_id'],
                "qid": row['qid'],
                "query_text": query_text,
                "image_path": image_path,
                "answer_text": row['answer_text'],
                "question_type": question_type,
                "question_category": question_category,
            })

        except Exception as e:
            print(f"Error processing row {idx}: {e}")

    batch_file = os.path.join(save_dir, f"batch_{batch_idx}.pkl")
    with open(batch_file, 'wb') as f:
        pickle.dump(batch_data, f)

    return len(batch_data)

In [15]:
def process_batch(batch_df, batch_idx, save_dir, images_dir=None):
    """Build prompt records for one slice of the dataset and pickle them to ``save_dir``.

    Only the first image listed in ``image_ids`` is used. A row is skipped when it
    has no images, its first image file is missing or cannot be decoded, or any other
    error occurs while building its record (the last three cases are reported).

    Args:
        batch_df: DataFrame slice with ``encounter_id``, ``qid``, ``question_en``,
            ``options_en`` (a list, or its Python-literal string form),
            ``answer_text``, ``image_ids`` (a list of file names) and optionally
            ``question_type_en`` / ``question_category_en``.
        batch_idx: Index used in the progress-bar label and the output file name.
        save_dir: Output directory (created if needed); receives ``batch_{batch_idx}.pkl``.
        images_dir: Directory holding the image files. Defaults to the notebook-level
            ``train_images_dir`` so existing calls keep working.

    Returns:
        Number of records written to the pickle file.
    """
    if images_dir is None:
        images_dir = train_images_dir
    os.makedirs(save_dir, exist_ok=True)
    batch_data = []

    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {batch_idx}"):
        try:
            image_ids = row['image_ids']
            if not image_ids:
                continue

            # Only the first image of the encounter is used.
            image_path = os.path.join(images_dir, image_ids[0])
            if not os.path.exists(image_path):
                print(f"Warning: Image not found at {image_path}")
                continue

            # Decode the image now so corrupt files are dropped here rather than
            # failing later when the record is used.
            try:
                with Image.open(image_path) as img:
                    img.load()
            except Exception as e:
                print(f"Corrupt or unreadable image at {image_path} — {e}")
                continue

            # options_en is stored as a Python list literal; literal_eval parses it
            # without executing code (eval() would run anything in the file).
            options = row['options_en']
            if isinstance(options, str):
                options = ast.literal_eval(options)
            options_text = ", ".join(f"{i}. {opt}" for i, opt in enumerate(options, start=1))

            question_type = row.get('question_type_en', '')
            question_category = row.get('question_category_en', '')
            metadata = f"Type: {question_type}, Category: {question_category}"
            query_text = f"Question: Based on the image, {row['question_en']}\nQuestion Metadata: {metadata}\nOptions: {options_text}"

            batch_data.append({
                "encounter_id": row['encounter_id'],
                "qid": row['qid'],
                "query_text": query_text,
                "image_path": image_path,
                "answer_text": row['answer_text'],
                "question_type": question_type,
                "question_category": question_category,
            })

        except Exception as e:
            print(f"Error processing row {idx}: {e}")

    batch_file = os.path.join(save_dir, f"batch_{batch_idx}.pkl")
    with open(batch_file, 'wb') as f:
        pickle.dump(batch_data, f)

    return len(batch_data)

In [16]:
class MedicalImageDataset(torch.utils.data.Dataset):
    """Chat-formatted examples loaded from the ``batch_*.pkl`` files in a directory.

    Each item is ``{"messages": [...]}``: a system prompt, a user turn with the query
    text followed by the RGB image, and an assistant turn with the reference answer.
    """

    SYSTEM_MESSAGE = "You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain."

    def __init__(self, data_dir, processor):
        """Load every ``batch_*.pkl`` file in ``data_dir``.

        Args:
            data_dir: Directory containing pickled lists of example dicts with at least
                ``image_path``, ``query_text`` and ``answer_text`` keys.
            processor: Stored on the instance; not used by this class itself.
        """
        self.processor = processor
        self.examples = []

        batch_files = [
            name for name in os.listdir(data_dir)
            if name.startswith("batch_") and name.endswith(".pkl")
        ]
        # Sort by the numeric batch index: a plain string sort would put batch_10
        # before batch_2 and scramble the example order across batches.
        for batch_file in sorted(batch_files, key=self._batch_sort_key):
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(os.path.join(data_dir, batch_file), 'rb') as f:
                self.examples.extend(pickle.load(f))

    @staticmethod
    def _batch_sort_key(name):
        """Order ``batch_<n>.pkl`` numerically by ``n``; other names sort after, by name."""
        stem = name[len("batch_"):-len(".pkl")]
        return (0, int(stem), name) if stem.isdigit() else (1, 0, name)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]

        # Open the single image and keep an RGB copy; the context manager closes the
        # underlying file instead of leaving the handle to the garbage collector.
        with Image.open(example['image_path']) as src:
            image = src.convert("RGB")

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": self.SYSTEM_MESSAGE}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": example['query_text']},
                    {"type": "image", "image": image},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": example['answer_text']}],
            },
        ]

        return {"messages": messages}

def preprocess_dataset(df, batch_size=50, save_dir="processed_data"):
    """Run ``process_batch`` over consecutive ``batch_size``-row chunks of ``df``.

    Chunk ``k`` is written as ``batch_{k}.pkl`` inside ``save_dir``; progress is
    printed after every chunk and garbage is collected between chunks.

    Returns:
        Total number of records written across all chunks.
    """
    chunk_starts = range(0, len(df), batch_size)
    n_batches = len(chunk_starts)
    running_total = 0

    for batch_idx, start in enumerate(chunk_starts):
        print(f"Processing batch {batch_idx+1}/{n_batches}")
        chunk = df.iloc[start:start + batch_size]
        running_total += process_batch(chunk, batch_idx, save_dir)

        gc.collect()

        print(f"Processed {running_total} examples so far")

    return running_total

In [17]:
# Start from an empty output directory: every batch_*.pkl found there is loaded
# later, so files left over from an earlier run would be mixed in.
processed_data_dir = "processed_data_debug"

stale_output_present = os.path.exists(processed_data_dir)
if stale_output_present:
    shutil.rmtree(processed_data_dir)

In [18]:
# Write prompt records for every row of train_df as batch_*.pkl files.
total_examples = preprocess_dataset(
    train_df,
    save_dir=processed_data_dir,
    batch_size=500,
)
print(f"Total processed examples: {total_examples}")

Processing batch 1/1


Batch 0:   0%|          | 0/10 [00:00<?, ?it/s]

Processed 10 examples so far
Total processed examples: 10


In [20]:
# Pull HF_TOKEN from a local .env file (if one exists) for authenticated Hub downloads.
load_dotenv()
model_id = "google/gemma-3-4b-it"
hf_token = os.getenv("HF_TOKEN")

processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [21]:
model_id = "google/gemma-3-4b-it"

# bfloat16 requires a GPU whose compute-capability major version is >= 8. Check for
# CUDA first: get_device_capability() raises its own, less clear error without a GPU.
if not torch.cuda.is_available():
    raise ValueError("CUDA is not available. This notebook requires a bfloat16-capable GPU.")
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16. Use a different GPU.")

model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# 4-bit NF4 weights with double quantization; compute and storage dtypes follow torch_dtype.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs, token=hf_token)
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

print(f"Default chat template: {processor.tokenizer.chat_template}")
print(f"Special tokens map: {processor.tokenizer.special_tokens_map}")
print(f"Special tokens map: {processor.tokenizer.special_tokens_map}")

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Default chat template: {{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
    {%- if messages[0]['content'] is string -%}
        {%- set first_user_prefix = messages[0]['content'] + '

' -%}
    {%- else -%}
        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '

' -%}
    {%- endif -%}
    {%- set loop_messages = messages[1:] -%}
{%- else -%}
    {%- set first_user_prefix = "" -%}
    {%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
    {%- endif -%}
    {%- if (message['role'] == 'assistant') -%}
        {%- set role = "model" -%}
    {%- else -%}
        {%- set role = message['role'] -%}
    {%- endif -%}
    {{ '<start_of_turn>' + role + '
' + (first_user_prefix if loop.first else "") }}
    {%- if message['content'] is string -%}
        {{ messag

In [38]:
def generate_medical_analysis(image_path, question, options, question_type="", question_category="", model=None, processor=None, max_new_tokens=64, do_sample=True, temperature=0.5, top_p=0.9):
    """
    Generate medical image analysis based on a question and image using the base model.

    Args:
        image_path: Path (str or os.PathLike) to the medical image, or an already-loaded
            image object exposing ``convert`` (e.g. a PIL Image).
        question: The medical question to analyze
        options: List of possible options to choose from
        question_type: Optional metadata about question type
        question_category: Optional metadata about question category
        model: The model to use for inference. Defaults to the notebook's current
            global ``model``, looked up at call time.
        processor: The processor to use for tokenization. Defaults to the notebook's
            current global ``processor``, looked up at call time.
        max_new_tokens: Generation length cap (short is enough for option selection).
        do_sample: Sample instead of greedy decoding. Sampling makes repeated runs
            differ; pass False for reproducible evaluation.
        temperature: Sampling temperature (only used when ``do_sample`` is True).
        top_p: Nucleus-sampling threshold (only used when ``do_sample`` is True).

    Returns:
        The model's decoded response (the selected option), whitespace-stripped.

    Raises:
        ValueError: if no model/processor was passed and none is loaded globally.
    """
    # Resolve defaults at call time. A `model=model` default would capture the object
    # that existed when this cell ran: re-loading the model would be ignored, and
    # `del model` elsewhere could not free the GPU memory the default still references.
    if model is None:
        model = globals().get("model")
    if processor is None:
        processor = globals().get("processor")
    if model is None or processor is None:
        raise ValueError("No model/processor available: pass them explicitly or load them before calling.")

    # Load the image; the context manager closes the file once the RGB copy exists.
    if isinstance(image_path, (str, os.PathLike)):
        with Image.open(image_path) as src:
            image = src.convert("RGB")
    elif hasattr(image_path, "convert"):
        image = image_path.convert("RGB")
    else:
        image = image_path

    # Same prompt layout as the training records: numbered options plus metadata.
    options_text = ", ".join(f"{i}. {opt}" for i, opt in enumerate(options, start=1))
    metadata = f"Type: {question_type}, Category: {question_category}"
    query_text = f"Question: Based on the image, {question}\nQuestion Metadata: {metadata}\nOptions: {options_text}"

    system_message = "You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain."

    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "text", "text": query_text},
            {"type": "image", "image": image},
        ]},
    ]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # One prompt with one image: the processor expects a list of image lists, one per prompt.
    inputs = processor(
        text=[text],
        images=[[image]],
        padding=True,
        return_tensors="pt",
    )

    # Move tensors in place so `inputs` keeps its type (and `.input_ids` access below).
    for key in inputs:
        inputs[key] = inputs[key].to(model.device)

    # Stop at EOS or at the end of the model's turn.
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    sampling_kwargs = {"do_sample": do_sample}
    if do_sample:
        # Only pass sampling knobs when sampling, so greedy decoding gets no conflicting settings.
        sampling_kwargs.update(temperature=temperature, top_p=top_p)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        eos_token_id=stop_token_ids,
        disable_compile=True,
        **sampling_kwargs,
    )

    # Drop the echoed prompt tokens and decode only the newly generated ones.
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0].strip()

In [39]:
def load_test_cases(csv_file, images_dir, num_samples=10):
    """
    Load evaluation cases from the first ``num_samples`` rows of the dataset CSV.

    Rows with no image ids, or whose first image file does not exist (a warning is
    printed), are skipped, so fewer than ``num_samples`` cases may be returned.

    Args:
        csv_file: CSV with ``image_ids`` and ``options_en`` stored as Python list
            literals, plus ``question_en``, ``answer_text`` and optionally
            ``question_type_en`` / ``question_category_en``.
        images_dir: Directory that contains the image files.
        num_samples: Number of leading rows to consider.

    Returns:
        List of dicts with keys ``image_path``, ``question``, ``options``,
        ``question_type``, ``question_category`` and ``expected_answer``.

    Raises:
        ValueError: if a list column is not a valid Python literal.
    """
    # Only the leading rows are used, so only they are parsed.
    test_df = pd.read_csv(csv_file).head(num_samples)

    test_cases = []
    for _, row in test_df.iterrows():
        # literal_eval parses list literals without executing code (unlike eval);
        # an empty CSV cell (NaN) is treated as "no images".
        raw_ids = row['image_ids']
        image_ids = ast.literal_eval(raw_ids) if isinstance(raw_ids, str) else []
        if not image_ids:
            continue

        # Get the first image path
        image_path = os.path.join(images_dir, image_ids[0])
        if not os.path.exists(image_path):
            print(f"Warning: Image not found at {image_path}")
            continue

        test_cases.append({
            "image_path": image_path,
            "question": row['question_en'],
            "options": ast.literal_eval(row['options_en']),
            "question_type": row.get('question_type_en', ''),
            "question_category": row.get('question_category_en', ''),
            "expected_answer": row['answer_text']
        })

    return test_cases

In [40]:
# Main function: run the baseline model on one hand-written example, then on a few dataset rows.
if __name__ == "__main__":
    # Test with a sample image
    print("\n=== Sample Test ===")
    image_path = "2025_dataset/train/images_train/IMG_ENC00001_00001.jpg"
    question = "How much of the body is affected?"
    options = ["single spot", "limited area", "widespread", "Not mentioned"]
    question_type = "Site"
    question_category = "General"

    result = generate_medical_analysis(
        image_path=image_path,
        question=question,
        options=options,
        question_type=question_type,
        question_category=question_category
    )

    print(f"Question: {question}")
    print(f"Options: {options}")
    print(f"Model's answer: {result}")

    # Test with multiple examples from the dataset
    print("\n=== Dataset Tests ===")
    csv_file = os.path.join("2025_dataset", "train", "final_df_2.csv")
    images_dir = os.path.join("2025_dataset", "train", "images_train")

    # Guard only the loading step, so inference errors (e.g. CUDA OOM) surface instead
    # of being reported as "Error loading test cases" and replaced by the fallback cases.
    try:
        test_cases = load_test_cases(csv_file, images_dir, num_samples=5)
    except Exception as e:
        print(f"Error loading test cases: {e}")
        test_cases = None

    if test_cases is not None:
        correct = 0
        for i, test_case in enumerate(test_cases):
            print(f"\nTest Case {i+1}:")
            result = generate_medical_analysis(
                image_path=test_case["image_path"],
                question=test_case["question"],
                options=test_case["options"],
                question_type=test_case["question_type"],
                question_category=test_case["question_category"]
            )
            expected = test_case['expected_answer']

            print(f"Question: {test_case['question']}")
            print(f"Options: {test_case['options']}")
            print(f"Model's answer: {result}")
            print(f"Expected answer: {expected}")

            # str() guards against NaN answers read from the CSV; case and surrounding
            # whitespace are ignored when comparing.
            if result.strip().lower() == str(expected).strip().lower():
                correct += 1
                print("✓ Correct")
            else:
                print("✗ Incorrect")

        if test_cases:
            print(f"\nAccuracy: {correct}/{len(test_cases)} = {correct/len(test_cases):.2%}")

    else:
        # Fall back to hardcoded test cases if CSV loading fails
        print("\n=== Fallback Test Cases ===")
        test_cases = [
            {
                "image_path": "2025_dataset/train/images_train/IMG_ENC00002_00001.jpg",
                "question": "How much of the body is affected?",
                "options": ["single spot", "limited area", "widespread", "Not mentioned"],
                "question_type": "Site",
                "question_category": "General"
            },
            {
                "image_path": "2025_dataset/train/images_train/IMG_ENC00003_00001.jpg",
                "question": "How much of the body is affected?",
                "options": ["single spot", "limited area", "widespread", "Not mentioned"],
                "question_type": "Site",
                "question_category": "General"
            }
        ]

        for i, test_case in enumerate(test_cases):
            if os.path.exists(test_case["image_path"]):
                print(f"\nTest Case {i+1}:")
                result = generate_medical_analysis(**test_case)
                print(f"Question: {test_case['question']}")
                print(f"Options: {test_case['options']}")
                print(f"Model's answer: {result}")
            else:
                print(f"Image not found: {test_case['image_path']}")

    # Free up GPU memory. gc.collect() clears reference cycles before the CUDA cache is
    # emptied; any other live reference to the model (e.g. a function default argument)
    # still keeps its weights resident.
    del model
    gc.collect()
    torch.cuda.empty_cache()
    print("\nInference completed and GPU memory released")


=== Sample Test ===
Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: widespread

=== Dataset Tests ===

Test Case 1:
Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: widespread
Expected answer: limited area
✗ Incorrect

Test Case 2:
Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: limited area
Expected answer: limited area
✓ Correct

Test Case 3:
Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: limited area
Expected answer: limited area
✓ Correct

Test Case 4:
Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: limited area
Expected answer: widespread
✗ Incorrect

Test Case 5:
Que