# Installations

In [2]:
import os
import sys
import torch
import numpy as np
import pandas as pd
from datasets import Dataset
from PIL import Image
import gc
import pickle
from tqdm.auto import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
from dotenv import load_dotenv
from torch.utils.data import DataLoader
import shutil



In [3]:
gc.collect()
torch.cuda.empty_cache()

In [4]:
print("TRANSFORMERS_CACHE:", os.getenv("TRANSFORMERS_CACHE"))

TRANSFORMERS_CACHE: /storage/coda1/p-dsgt_clef2025/0/kthakrar3/hf_cache


In [5]:
os.environ.pop("TRANSFORMERS_CACHE", None)

'/storage/coda1/p-dsgt_clef2025/0/kthakrar3/hf_cache'

In [6]:
os.environ["HF_HOME"] = "/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.hf_cache"

# where the downloaded model goes

In [7]:
print("TRANSFORMERS_CACHE:", os.getenv("TRANSFORMERS_CACHE"))
print("HF_HOME:", os.getenv("HF_HOME"))

TRANSFORMERS_CACHE: None
HF_HOME: /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.hf_cache


In [8]:
print(sys.executable)

print(f"Python version: {sys.version}")
print(f"Python version info: {sys.version_info}")

/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.venv/bin/python
Python version: 3.10.10 (main, Apr 15 2024, 11:52:16) [GCC 11.4.1 20230605 (Red Hat 11.4.1-2)]
Python version info: sys.version_info(major=3, minor=10, micro=10, releaselevel='final', serial=0)


In [9]:
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
CUDA device count: 1
Current CUDA device: 0
CUDA device name: NVIDIA A100 80GB PCIe


# Data loading + preprocessing

In [10]:
train_csv_file = os.path.join("2025_dataset", "train", "final_df_2.csv")
train_images_dir = os.path.join("2025_dataset", "train", "images_train")

train_df = pd.read_csv(train_csv_file)

train_df['image_ids'] = train_df['image_ids'].apply(eval)
train_df['responses_en'] = train_df['responses_en'].apply(eval)

train_df.head()

Unnamed: 0,encounter_id,qid,answer_index,question_en,options_en,question_type_en,question_category_en,answer_text,author_id,image_ids,responses,query_title_en,query_content_en,image_paths,responses_en
0,ENC00001,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U04473,"[IMG_ENC00001_00001.jpg, IMG_ENC00001_00002.jpg]","[{'author_id': 'U00217', 'content_zh': '银屑病，似与...",Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,[Psoriasis seems to have no relation to pleura...
1,ENC00002,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U06063,"[IMG_ENC00002_00001.jpg, IMG_ENC00002_00002.jp...","[{'author_id': 'U11305', 'content_zh': '脚气', '...",What is on the bottom of the right foot?,"The patient is a 50-year-old male, who has bee...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,[Beriberi]
2,ENC00003,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U00780,"[IMG_ENC00003_00001.jpg, IMG_ENC00003_00002.jp...","[{'author_id': 'U01131', 'content_zh': '瘙痒症，有无...",Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,"[Pruritus, is there any other special medical ..."
3,ENC00004,CQID010-001,2,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,widespread,U00209,"[IMG_ENC00004_00001.jpg, IMG_ENC00004_00002.jpg]","[{'author_id': 'U06715', 'content_zh': '肢端角化病？...",Skin Disease,"Male, 15 years old, keratosis on both palms, s...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,"[Acrokeratosis?, Progressive Symmetrical Eryth..."
4,ENC00005,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",Site,General,limited area,U09050,[IMG_ENC00005_00001.jpg],"[{'author_id': 'U09402', 'content_zh': '是否神经性皮...",Perifollicular atrophy?,"Young female, silver-gray dot-like atrophy spo...",['/storage/coda1/p-dsgt_clef2025/0/kthakrar3/m...,"[Is it neurodermatitis?, Impotence?, Lichen Sc..."


In [11]:
train_df = train_df.head(10)  # Start with 10 samples for quick debugging
print(f"Using {len(train_df)} samples for training")

Using 10 samples for training


In [12]:
train_df = train_df[['encounter_id', 'qid', 'question_en', 'options_en', 'answer_text', 'image_ids', 'question_type_en', 'question_category_en']]

print(f"Filtered dataframe shape: {train_df.shape}")
print("Columns:", train_df.columns.tolist())

display("Sample row:", train_df.head(3))

Filtered dataframe shape: (10, 8)
Columns: ['encounter_id', 'qid', 'question_en', 'options_en', 'answer_text', 'image_ids', 'question_type_en', 'question_category_en']


'Sample row:'

Unnamed: 0,encounter_id,qid,question_en,options_en,answer_text,image_ids,question_type_en,question_category_en
0,ENC00001,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00001_00001.jpg, IMG_ENC00001_00002.jpg]",Site,General
1,ENC00002,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00002_00001.jpg, IMG_ENC00002_00002.jp...",Site,General
2,ENC00003,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00003_00001.jpg, IMG_ENC00003_00002.jp...",Site,General


In [13]:
def process_batch(batch_df, batch_idx, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    batch_data = []
    
    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {batch_idx}"):
        try:
            # Only take the first image from the list
            if not row['image_ids'] or len(row['image_ids']) == 0:
                continue
                
            # Get just the first image path
            image_path = os.path.join(train_images_dir, row['image_ids'][0])
            
            if not os.path.exists(image_path):
                continue

            # Verify the image is valid
            try:
                with Image.open(image_path) as img:
                    img.load()
            except Exception as e:
                print(f"Corrupt or unreadable image at {image_path} — {e}")
                continue
            
            # Format options text
            options_text = ", ".join([f"{i+1}. {opt}" for i, opt in enumerate(eval(row['options_en']))])
            
            # Create metadata string
            metadata = f"Type: {row.get('question_type_en', '')}, Category: {row.get('question_category_en', '')}"
            
            # Create the full query text with instructions
            query_text = f"Question: Based on the image, {row['question_en']}\nQuestion Metadata: {metadata}\nOptions: {options_text}"
#             query_text += "\n\nCRITICAL INSTRUCTION: Only respond with an option if it is **clearly and unambiguously** supported by the image. If the image is unclear, incomplete, or could fit multiple answers, respond with: 'Not mentioned'. You must respond with the **exact text** of one option below. No numbers, no explanation. Given the medical context, err on the side of caution."
            
            batch_data.append({
                "encounter_id": row['encounter_id'],
                "qid": row['qid'],
                "query_text": query_text,
                "image_path": image_path,
                "answer_text": row['answer_text'],
                "question_type": row.get('question_type_en', ''),
                "question_category": row.get('question_category_en', '')
            })
        
        except Exception as e:
            print(f"Error processing row {idx}: {e}")
    
    batch_file = os.path.join(save_dir, f"batch_{batch_idx}.pkl")
    with open(batch_file, 'wb') as f:
        pickle.dump(batch_data, f)
    
    return len(batch_data)

In [14]:
class MedicalImageDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, processor):
        self.processor = processor
        self.examples = []
        
        for batch_file in sorted(os.listdir(data_dir)):
            if batch_file.startswith("batch_") and batch_file.endswith(".pkl"):
                with open(os.path.join(data_dir, batch_file), 'rb') as f:
                    batch_data = pickle.load(f)
                    self.examples.extend(batch_data)
    
    def __len__(self):
        return len(self.examples)
    
    def __getitem__(self, idx):
        example = self.examples[idx]
        
        # Open just one image, convert to RGB
        image = Image.open(example['image_path']).convert("RGB")
        
        # Use consistent system message
        system_message = "You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain."
        
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": example['query_text']},
                    {"type": "image", "image": image},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": example['answer_text']}],
            },
        ]
        
        return {"messages": messages}

def preprocess_dataset(df, batch_size=50, save_dir="processed_data"):
    total_processed = 0
    
    for i in range(0, len(df), batch_size):
        batch_df = df.iloc[i:i+batch_size]
        batch_idx = i // batch_size
        
        print(f"Processing batch {batch_idx+1}/{(len(df)-1)//batch_size + 1}")
        processed = process_batch(batch_df, batch_idx, save_dir)
        total_processed += processed
        
        gc.collect()
        
        print(f"Processed {total_processed} examples so far")
    
    return total_processed

In [15]:
processed_data_dir = "processed_data_debug"

if os.path.exists(processed_data_dir):
    shutil.rmtree(processed_data_dir)

In [16]:
total_examples = preprocess_dataset(train_df, batch_size=500, save_dir=processed_data_dir)
print(f"Total processed examples: {total_examples}")

Processing batch 1/1


Batch 0:   0%|          | 0/10 [00:00<?, ?it/s]

Processed 10 examples so far
Total processed examples: 10


In [17]:
batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
with open(batch_file, 'rb') as f:
    batch_data = pickle.load(f)

print("\nSample of processed data (first example):")
sample_data = batch_data[0]
print(f"Encounter ID: {sample_data['encounter_id']}")
print(f"Question ID: {sample_data['qid']}")
print(f"Query text: {sample_data['query_text']}")
# print(f"Number of images: {len(sample_data['image_path'])}")
print(f"Image path: {sample_data['image_path']}")  # Changed from 'image_paths' to 'image_path'
print(f"Answer text: {sample_data['answer_text']}")
print(f"Question type: {sample_data['question_type']}")
print(f"Question category: {sample_data['question_category']}")


Sample of processed data (first example):
Encounter ID: ENC00001
Question ID: CQID010-001
Query text: Question: Based on the image, How much of the body is affected?
Question Metadata: Type: Site, Category: General
Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned
Image path: 2025_dataset/train/images_train/IMG_ENC00001_00001.jpg
Answer text: limited area
Question type: Site
Question category: General


In [18]:
batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
with open(batch_file, 'rb') as f:
    batch_data = pickle.load(f)

print("\nSample of processed data (first example):")
sample_data = batch_data[1]
print(f"Encounter ID: {sample_data['encounter_id']}")
print(f"Question ID: {sample_data['qid']}")
print(f"Query text: {sample_data['query_text']}")
# print(f"Number of images: {len(sample_data['image_path'])}")
print(f"Image path: {sample_data['image_path']}")  # Changed from 'image_paths' to 'image_path'
print(f"Answer text: {sample_data['answer_text']}")
print(f"Question type: {sample_data['question_type']}")
print(f"Question category: {sample_data['question_category']}")


Sample of processed data (first example):
Encounter ID: ENC00002
Question ID: CQID010-001
Query text: Question: Based on the image, How much of the body is affected?
Question Metadata: Type: Site, Category: General
Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned
Image path: 2025_dataset/train/images_train/IMG_ENC00002_00001.jpg
Answer text: limited area
Question type: Site
Question category: General


In [19]:
batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
with open(batch_file, 'rb') as f:
    batch_data = pickle.load(f)

print("\nSample of processed data (first example):")
sample_data = batch_data[2]
print(f"Encounter ID: {sample_data['encounter_id']}")
print(f"Question ID: {sample_data['qid']}")
print(f"Query text: {sample_data['query_text']}")
# print(f"Number of images: {len(sample_data['image_path'])}")
print(f"Image path: {sample_data['image_path']}")  # Changed from 'image_paths' to 'image_path'
print(f"Answer text: {sample_data['answer_text']}")
print(f"Question type: {sample_data['question_type']}")
print(f"Question category: {sample_data['question_category']}")


Sample of processed data (first example):
Encounter ID: ENC00003
Question ID: CQID010-001
Query text: Question: Based on the image, How much of the body is affected?
Question Metadata: Type: Site, Category: General
Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned
Image path: 2025_dataset/train/images_train/IMG_ENC00003_00001.jpg
Answer text: limited area
Question type: Site
Question category: General


In [20]:
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
model_id = "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [21]:
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"\nDataset size: {len(dataset)}")

example = dataset[0]
print(f"\nFirst example roles:")
print(f"System role: {example['messages'][0]}")
print(f"User role: {example['messages'][1]}")
print(f"Assistant role: {example['messages'][2]}")


Dataset size: 10

First example roles:
System role: {'role': 'system', 'content': [{'type': 'text', 'text': "You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain."}]}
User role: {'role': 'user', 'content': [{'type': 'text', 'text': 'Question: Based on the image, How much of the body is affected?\nQuestion Metadata: Type: Site, Category: General\nOptions: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=1944x2541 at 0x15541D84F850>}]}
Assistant role: {'role': 'a

In [22]:
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"\nDataset size: {len(dataset)}")

example = dataset[1]
print(f"\nFirst example roles:")
print(f"System role: {example['messages'][0]}")
print(f"User role: {example['messages'][1]}")
print(f"Assistant role: {example['messages'][2]}")


Dataset size: 10

First example roles:
System role: {'role': 'system', 'content': [{'type': 'text', 'text': "You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain."}]}
User role: {'role': 'user', 'content': [{'type': 'text', 'text': 'Question: Based on the image, How much of the body is affected?\nQuestion Metadata: Type: Site, Category: General\nOptions: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=2560x1920 at 0x15541D962920>}]}
Assistant role: {'role': 'a

In [23]:
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"\nDataset size: {len(dataset)}")

example = dataset[2]
print(f"\nFirst example roles:")
print(f"System role: {example['messages'][0]}")
print(f"User role: {example['messages'][1]}")
print(f"Assistant role: {example['messages'][2]}")


Dataset size: 10

First example roles:
System role: {'role': 'system', 'content': [{'type': 'text', 'text': "You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain."}]}
User role: {'role': 'user', 'content': [{'type': 'text', 'text': 'Question: Based on the image, How much of the body is affected?\nQuestion Metadata: Type: Site, Category: General\nOptions: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=600x400 at 0x15541D3CB340>}]}
Assistant role: {'role': 'ass

# Load model

In [24]:
model_id = "google/gemma-3-4b-it"

if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16. Use a different GPU.")

model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs, token=hf_token)
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

print(f"Default chat template: {processor.tokenizer.chat_template}")
print(f"Special tokens map: {processor.tokenizer.special_tokens_map}")

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Default chat template: {{ bos_token }}
{%- if messages[0]['role'] == 'system' -%}
    {%- if messages[0]['content'] is string -%}
        {%- set first_user_prefix = messages[0]['content'] + '

' -%}
    {%- else -%}
        {%- set first_user_prefix = messages[0]['content'][0]['text'] + '

' -%}
    {%- endif -%}
    {%- set loop_messages = messages[1:] -%}
{%- else -%}
    {%- set first_user_prefix = "" -%}
    {%- set loop_messages = messages -%}
{%- endif -%}
{%- for message in loop_messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{ raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }}
    {%- endif -%}
    {%- if (message['role'] == 'assistant') -%}
        {%- set role = "model" -%}
    {%- else -%}
        {%- set role = message['role'] -%}
    {%- endif -%}
    {{ '<start_of_turn>' + role + '
' + (first_user_prefix if loop.first else "") }}
    {%- if message['content'] is string -%}
        {{ messag

# Collate data

In [103]:
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
#     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"],
)

In [105]:
def create_dummy_image():
    """Create a small black image as a placeholder."""
    return Image.new('RGB', (224, 224), color='black')

In [108]:
def collate_fn(examples):
    texts = []
    images = []
    
    for example in examples:
        # Extract image from messages
        image_input = None
        for msg in example["messages"]:
            if msg["role"] == "user":
                for content in msg["content"]:
                    if isinstance(content, dict) and content.get("type") == "image" and "image" in content:
                        image_input = content["image"]
                        break
        
        if image_input is None:
            image_input = create_dummy_image()
            
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        
        texts.append(text.strip())
        images.append([image_input])
    
    batch = processor(
        text=texts, 
        images=images,
        return_tensors="pt", 
        padding=True
    )
    
    labels = batch["input_ids"].clone()
    
    # Get token IDs for start_of_image and end_of_image
    boi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    eoi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["eoi_token"]
    )
    
    # Mask tokens that shouldn't contribute to the loss
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == boi_token_id] = -100
    labels[labels == eoi_token_id] = -100
    
    batch["labels"] = labels
    return batch

In [109]:
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"Dataset size: {len(dataset)}")

if len(dataset) == 0:
    print("ERROR: Dataset is empty! Check data loading process.")
    pass
else:
    sample_size = min(3, len(dataset))
    print(f"Sampling {sample_size} examples from dataset")
    
    sample_examples = [dataset[i] for i in range(sample_size)]
    
    print(f"Sample size: {len(sample_examples)}")
    print("First example keys:", list(sample_examples[0].keys()))
    
    batch = collate_fn(sample_examples)
    print("Collated batch contains:", list(batch.keys()))
    print(f"Input_ids shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")

Dataset size: 10
Sampling 3 examples from dataset
Sample size: 3
First example keys: ['messages']
Collated batch contains: ['input_ids', 'attention_mask', 'token_type_ids', 'pixel_values', 'labels']
Input_ids shape: torch.Size([3, 422])
Labels shape: torch.Size([3, 422])


In [110]:
dataloader = DataLoader(
    dataset,
    batch_size=8,  # Adjust based on GPU memory
    shuffle=True,
    collate_fn=collate_fn
)

total_examples = 0
for batch in dataloader:
    # Process each batch. Note: in training, pass this to model.forward()
    batch_size = len(batch["input_ids"])
    total_examples += batch_size
    print(f"Processed batch with {batch_size} examples")

print(f"Processed all {total_examples} examples")

Processed batch with 8 examples
Processed batch with 2 examples
Processed all 10 examples


In [111]:
def debug_template_application(examples):
    print("\n=== DEBUGGING TEMPLATE APPLICATION ===")
    for i, example in enumerate(examples):
        print(f"\nExample {i}:")
        
        # First, check message structure
        print("Message structure:")
        for j, msg in enumerate(example["messages"]):
            role = msg.get("role", "unknown")
            print(f"  Message {j} role: {role}")
            
            content = msg.get("content", [])
            if not isinstance(content, list):
                content = [content]
                
            print(f"  Content types: {[c.get('type') if isinstance(c, dict) else type(c).__name__ for c in content]}")
            
            # Check specifically for image content
            image_count = sum(1 for c in content if isinstance(c, dict) and c.get('type') == 'image')
            print(f"  Image content count: {image_count}")
        
        # Next, check template application
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        
        # Check for various possible image tokens
        image_tokens = ["<image>", "<start_of_image>", "<image_soft_token>"]
        for token in image_tokens:
            count = text.count(token)
            print(f"  {token} tokens found: {count}")
        
        # Print text snippet to see if tokens appear
        print(f"  Text preview: {text[:200]}...")
        
        # Get all available tokens from special tokens map
        special_tokens = processor.tokenizer.special_tokens_map
        print(f"  Special tokens in processor: {special_tokens}")
        
        # Tokenize the text to see what token IDs are actually being used
        token_ids = processor.tokenizer(text, return_tensors="pt").input_ids[0]
        # Get a few token IDs around where an image might be
        if len(token_ids) > 20:
            print(f"  Sample token IDs (first 20): {token_ids[:20].tolist()}")
            
debug_template_application(sample_examples)


=== DEBUGGING TEMPLATE APPLICATION ===

Example 0:
Message structure:
  Message 0 role: system
  Content types: ['text']
  Image content count: 0
  Message 1 role: user
  Content types: ['text', 'image']
  Image content count: 1
  Message 2 role: assistant
  Content types: ['text']
  Image content count: 0
  <image> tokens found: 0
  <start_of_image> tokens found: 1
  <image_soft_token> tokens found: 0
  Text preview: <bos><start_of_turn>user
You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note t...
  Special tokens in processor: {'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'boi_token': '<start_of_image>', 'eoi_token': '<end_of_image>', 'image_token': '<image_soft_token>'}
  Sample token IDs (first 20): [2, 2, 105, 2364, 107, 3048, 659, 496, 5526, 2471, 3671, 16326, 236761, 5180, 1186, 4209, 563, 531, 17318, 506]

Example 

# Train model

In [113]:
args = SFTConfig(
    output_dir="gemma-product-description",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=5,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=True,
    report_to="tensorboard",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,  # Critical for custom datasets
    label_names=["labels"],  # Explicitly setting label_names
)

# args.remove_unused_columns = False # Not needed in our case but leaving it in in case the loaded data changes

In [114]:
# # Test with a small batch
# test_indices = list(range(min(3, len(dataset))))
# test_batch = [dataset[i] for i in test_indices]

# Now initialize your actual trainer with the regular collate_fn
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

In [115]:
# Start training
trainer.train()

Step,Training Loss




TrainOutput(global_step=2, training_loss=60.474578857421875, metrics={'train_runtime': 43.9671, 'train_samples_per_second': 0.227, 'train_steps_per_second': 0.045, 'total_flos': 87786641260032.0, 'train_loss': 60.474578857421875})

In [56]:
# Save the model
trainer.save_model()
print("Training complete and model saved!")



Training complete and model saved!


In [116]:
# Before you can test your model, make sure to free the memory.

# free the memory again
del model
del trainer
torch.cuda.empty_cache()

In [117]:
from peft import PeftModel

# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

['merged_model/processor_config.json']

# Perform inference

In [118]:
# Load Model with PEFT adapter
model = AutoModelForImageTextToText.from_pretrained(
    "merged_model",  # Use merged model for inference
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)
processor = AutoProcessor.from_pretrained("merged_model")

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [120]:
def generate_medical_analysis(image_path, question, options, question_type="", question_category="", model=model, processor=processor):
    """
    Generate medical image analysis based on a question and image.
    
    Args:
        image_path: Path to the medical image
        question: The medical question to analyze
        options: List of possible options to choose from
        question_type: Optional metadata about question type
        question_category: Optional metadata about question category
        model: The model to use for inference
        processor: The processor to use for tokenization
        
    Returns:
        The model's response (selected option)
    """
    # Load the image
    if isinstance(image_path, str):
        image = Image.open(image_path).convert("RGB")
    else:
        # Already a PIL Image
        image = image_path.convert("RGB") if hasattr(image_path, 'convert') else image_path
    
    # Format options text
    options_text = ", ".join([f"{i+1}. {opt}" for i, opt in enumerate(options)])
    
    # Create metadata string
    metadata = f"Type: {question_type}, Category: {question_category}"
    
    # Create the query text
    query_text = f"Question: Based on the image, {question}\nQuestion Metadata: {metadata}\nOptions: {options_text}"
    
    # System message
    system_message = "You are a medical image analysis assistant. Your only task is to examine the provided clinical images and select the exact option text that best describes what you see. Note this is not the full context so if you are unsure or speculate other regions being affected, respond with 'Not mentioned'. You must respond with the full text of one of the provided options, exactly as written. Do not include any additional words or reasoning. Given the medical context, err on the side of caution when uncertain."
    
    # Convert to messages format
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {"role": "user", "content": [
            {"type": "text", "text": query_text},
            {"type": "image", "image": image},
        ]},
    ]
    
    # Apply chat template
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    
    # Extract image from messages - needs to be in correct format for processor
    image_inputs = []
    for msg in messages:
        if msg["role"] == "user":
            for content in msg["content"]:
                if isinstance(content, dict) and content.get("type") == "image" and "image" in content:
                    image_inputs.append(content["image"])
                    break
    
    # Tokenize the text and process the images
    inputs = processor(
        text=[text],
        images=[image_inputs],  # Nested list as processor expects
        padding=True,
        return_tensors="pt",
    )
    
    # Move the inputs to the device
    inputs = inputs.to(model.device)
    
    # Generate the output
    stop_token_ids = [processor.tokenizer.eos_token_id, processor.tokenizer.convert_tokens_to_ids("<end_of_turn>")]
    generated_ids = model.generate(
        **inputs, 
        max_new_tokens=64,  # Shorter is fine for option selection
        top_p=0.9,
        do_sample=True, 
        temperature=0.5,  # Lower temp for more precise answers
        eos_token_id=stop_token_ids, 
        disable_compile=True
    )
    
    # Trim the generation and decode the output to text
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    
    return output_text[0]

In [121]:
# Example usage:
if __name__ == "__main__":
    # Test with a sample image from your test set
    image_path = "2025_dataset/train/images_train/IMG_ENC00001_00001.jpg"
    question = "How much of the body is affected?"
    options = ["single spot", "limited area", "widespread", "Not mentioned"]
    question_type = "Site"
    question_category = "General"
    
    result = generate_medical_analysis(
        image_path=image_path,
        question=question,
        options=options,
        question_type=question_type,
        question_category=question_category
    )
    
    print(f"Question: {question}")
    print(f"Options: {options}")
    print(f"Model's answer: {result}")
    
    # You can test with multiple examples
    test_cases = [
        {
            "image_path": "2025_dataset/train/images_train/IMG_ENC00002_00001.jpg",
            "question": "How much of the body is affected?",
            "options": ["single spot", "limited area", "widespread", "Not mentioned"],
            "question_type": "Site",
            "question_category": "General"
        },
        {
            "image_path": "2025_dataset/train/images_train/IMG_ENC00003_00001.jpg",
            "question": "How much of the body is affected?",
            "options": ["single spot", "limited area", "widespread", "Not mentioned"],
            "question_type": "Site", 
            "question_category": "General"
        }
    ]
    
    print("\n=== More Test Cases ===")
    for i, test_case in enumerate(test_cases):
        print(f"\nTest Case {i+1}:")
        result = generate_medical_analysis(**test_case)
        print(f"Question: {test_case['question']}")
        print(f"Options: {test_case['options']}")
        print(f"Model's answer: {result}")

Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: 3. widespread

=== More Test Cases ===

Test Case 1:
Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: limited area

Test Case 2:
Question: How much of the body is affected?
Options: ['single spot', 'limited area', 'widespread', 'Not mentioned']
Model's answer: limited area
