# Setup: imports and environment

In [1]:
import ast
import gc
import os
import pickle
import shutil
import sys

import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from dotenv import load_dotenv
from peft import LoraConfig, PeftModel
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer



In [2]:
# Run Python's garbage collector, then ask PyTorch to release its cached CUDA
# memory, so the notebook starts from as clean a memory state as possible.
gc.collect()
torch.cuda.empty_cache()

In [3]:
# Show the TRANSFORMERS_CACHE value inherited from the environment (None if unset).
print("TRANSFORMERS_CACHE:", os.getenv("TRANSFORMERS_CACHE"))

TRANSFORMERS_CACHE: /storage/coda1/p-dsgt_clef2025/0/kthakrar3/hf_cache


In [4]:
# Remove TRANSFORMERS_CACHE from this process's environment. The None default
# avoids a KeyError when the variable is not set; the removed value is the cell output.
os.environ.pop("TRANSFORMERS_CACHE", None)

'/storage/coda1/p-dsgt_clef2025/0/kthakrar3/hf_cache'

In [5]:
# Set HF_HOME, the directory where the downloaded model goes.
# NOTE(review): this assignment runs after `transformers`/`datasets` were imported
# in the first cell. The Hugging Face libraries appear to resolve their cache
# directory at import time, so this may not take effect — confirm, or move it
# ahead of the imports.
# NOTE(review): this is a hard-coded absolute, user-specific path; consider making
# it configurable.
os.environ["HF_HOME"] = "/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.hf_cache"

In [6]:
# Confirm the cache-related environment variables after the two changes above.
print("TRANSFORMERS_CACHE:", os.getenv("TRANSFORMERS_CACHE"))
print("HF_HOME:", os.getenv("HF_HOME"))

TRANSFORMERS_CACHE: None
HF_HOME: /storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.hf_cache


In [7]:
# Record which interpreter (and therefore which virtual environment) the kernel uses.
print(sys.executable)

# Full version string and structured version tuple, for reproducibility.
print(f"Python version: {sys.version}")
print(f"Python version info: {sys.version_info}")

/storage/coda1/p-dsgt_clef2025/0/kthakrar3/mediqa-magic-v2/.venv/bin/python
Python version: 3.10.10 (main, Apr 15 2024, 11:52:16) [GCC 11.4.1 20230605 (Red Hat 11.4.1-2)]
Python version info: sys.version_info(major=3, minor=10, micro=10, releaselevel='final', serial=0)


In [8]:
# Report the PyTorch build and whether a CUDA device is usable.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")
    # Note: the name is queried for device index 0, not for the current device.
    print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
CUDA device count: 1
Current CUDA device: 0
CUDA device name: NVIDIA A100 80GB PCIe


# Data loading + preprocessing

In [9]:
# Training metadata CSV and image folder, relative to the working directory.
train_csv_file = os.path.join("2025_dataset", "train", "final_df.csv")
train_images_dir = os.path.join("2025_dataset", "train", "images_train")

train_df = pd.read_csv(train_csv_file)

# These columns hold Python list literals serialized as strings. ast.literal_eval
# parses literals only, so a malformed or malicious CSV cell cannot execute code
# the way eval() would.
train_df['image_ids'] = train_df['image_ids'].apply(ast.literal_eval)
train_df['responses_en'] = train_df['responses_en'].apply(ast.literal_eval)

train_df.head()

Unnamed: 0,encounter_id,qid,answer_index,question_en,options_en,answer_text,author_id,image_ids,responses,query_title_en,query_content_en,image_paths,responses_en
0,ENC00001,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,U04473,"[IMG_ENC00001_00001.jpg, IMG_ENC00001_00002.jpg]","[{'author_id': 'U00217', 'content_zh': '银屑病，似与...",Pleural effusion accompanied by rash,A patient with pleural effusion is accompanied...,['C:\\Users\\karishma\\OneDrive\\Projects\\med...,[Psoriasis seems to have no relation to pleura...
1,ENC00002,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,U06063,"[IMG_ENC00002_00001.jpg, IMG_ENC00002_00002.jp...","[{'author_id': 'U11305', 'content_zh': '脚气', '...",What is on the bottom of the right foot?,"The patient is a 50-year-old male, who has bee...",['C:\\Users\\karishma\\OneDrive\\Projects\\med...,[Beriberi]
2,ENC00003,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,U00780,"[IMG_ENC00003_00001.jpg, IMG_ENC00003_00002.jp...","[{'author_id': 'U01131', 'content_zh': '瘙痒症，有无...",Interpreting Images - Is it magical skin?,"Male, 65 years old, skin lesions as shown in t...",['C:\\Users\\karishma\\OneDrive\\Projects\\med...,"[Pruritus, is there any other special medical ..."
3,ENC00004,CQID010-001,2,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",widespread,U00209,"[IMG_ENC00004_00001.jpg, IMG_ENC00004_00002.jpg]","[{'author_id': 'U06715', 'content_zh': '肢端角化病？...",Skin Disease,"Male, 15 years old, keratosis on both palms, s...",['C:\\Users\\karishma\\OneDrive\\Projects\\med...,"[Acrokeratosis?, Progressive Symmetrical Eryth..."
4,ENC00005,CQID010-001,1,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,U09050,[IMG_ENC00005_00001.jpg],"[{'author_id': 'U09402', 'content_zh': '是否神经性皮...",Perifollicular atrophy?,"Young female, silver-gray dot-like atrophy spo...",['C:\\Users\\karishma\\OneDrive\\Projects\\med...,"[Is it neurodermatitis?, Impotence?, Lichen Sc..."


In [10]:
# Debug mode: keep only the first 10 rows. This rebinds `train_df`, so every later
# cell works on the subset; remove or raise the limit for a real training run.
train_df = train_df.head(10)  # Start with 10 samples for quick debugging
print(f"Using {len(train_df)} samples for training")

Using 10 samples for training


In [11]:
# Keep only the columns used to build the prompts and targets below.
train_df = train_df[['encounter_id', 'qid', 'question_en', 'options_en', 'answer_text', 'image_ids']]

print(f"Filtered dataframe shape: {train_df.shape}")
print("Columns:", train_df.columns.tolist())

# display() renders the label string and the first three rows as separate outputs.
display("Sample row:", train_df.head(3))

Filtered dataframe shape: (10, 6)
Columns: ['encounter_id', 'qid', 'question_en', 'options_en', 'answer_text', 'image_ids']


'Sample row:'

Unnamed: 0,encounter_id,qid,question_en,options_en,answer_text,image_ids
0,ENC00001,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00001_00001.jpg, IMG_ENC00001_00002.jpg]"
1,ENC00002,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00002_00001.jpg, IMG_ENC00002_00002.jp..."
2,ENC00003,CQID010-001,How much of the body is affected?,"['single spot', 'limited area', 'widespread', ...",limited area,"[IMG_ENC00003_00001.jpg, IMG_ENC00003_00002.jp..."


In [12]:
def _image_is_readable(img_path):
    """Return True if the file at `img_path` can be fully decoded by PIL."""
    try:
        with Image.open(img_path) as img:
            img.load()  # force a full decode; Image.open alone is lazy
        return True
    except Exception as e:
        print(f"Corrupt or unreadable image at {img_path} — {e}")
        return False


def process_batch(batch_df, batch_idx, save_dir, images_dir=None):
    """Validate one slice of the dataframe and pickle the usable rows.

    A row is kept only if every image it references exists and can be decoded.
    Kept rows are written as a list of dicts to `<save_dir>/batch_<batch_idx>.pkl`
    with the keys `encounter_id`, `qid`, `query_text`, `image_paths` and
    `answer_text`.

    Args:
        batch_df: DataFrame slice with the columns `image_ids` (list of file
            names), `options_en` (list, or its string literal), `question_en`,
            `encounter_id`, `qid` and `answer_text`.
        batch_idx: Index used in the progress-bar label and the output file name.
        save_dir: Directory for the pickle file; created if missing.
        images_dir: Folder containing the images. Defaults to the notebook-level
            `train_images_dir` when omitted.

    Returns:
        Number of examples written to the batch file.
    """
    os.makedirs(save_dir, exist_ok=True)
    if images_dir is None:
        images_dir = train_images_dir
    batch_data = []

    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), desc=f"Batch {batch_idx}"):
        try:
            image_paths = [os.path.join(images_dir, img_id) for img_id in row['image_ids']]

            # Report skipped rows so dropped data is visible rather than silent.
            missing = [p for p in image_paths if not os.path.exists(p)]
            if missing:
                print(f"Skipping row {idx}: missing image file(s) {missing}")
                continue

            # Check every image (no short-circuit) so each corrupt file is reported.
            unreadable = [p for p in image_paths if not _image_is_readable(p)]
            if unreadable:
                continue

            # options_en arrives from the CSV as the string form of a list;
            # literal_eval parses it without executing code. Already-parsed
            # lists are accepted as-is.
            options = row['options_en']
            if isinstance(options, str):
                options = ast.literal_eval(options)
            options_text = ", ".join(f"{i+1}. {opt}" for i, opt in enumerate(options))
            query_text = f"Question: {row['question_en']} Options: {options_text}"

            batch_data.append({
                "encounter_id": row['encounter_id'],
                "qid": row['qid'],
                "query_text": query_text,
                "image_paths": image_paths,
                "answer_text": row['answer_text']
            })

        except Exception as e:
            # Best effort: one malformed row must not abort the whole batch.
            print(f"Error processing row {idx}: {e}")

    batch_file = os.path.join(save_dir, f"batch_{batch_idx}.pkl")
    with open(batch_file, 'wb') as f:
        pickle.dump(batch_data, f)

    return len(batch_data)

In [13]:
class MedicalImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over the pickled batches written by `process_batch`.

    Each item is a dict `{"messages": [...]}` holding a system, a user and an
    assistant turn. The user turn contains the query text followed by one
    image entry per image path, with the image loaded as an RGB PIL image.

    Args:
        data_dir: Directory containing `batch_<n>.pkl` files.
        processor: Stored on the instance as `self.processor`; not used by the
            methods of this class.
    """

    def __init__(self, data_dir, processor):
        self.processor = processor
        self.examples = []

        batch_files = [
            name for name in os.listdir(data_dir)
            if name.startswith("batch_") and name.endswith(".pkl")
        ]
        # Sort numerically so batch_10 follows batch_9; a plain string sort would
        # put batch_10 before batch_2 and scramble the original row order.
        for batch_file in sorted(batch_files, key=self._batch_sort_key):
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # batch files produced by this notebook.
            with open(os.path.join(data_dir, batch_file), 'rb') as f:
                self.examples.extend(pickle.load(f))

    @staticmethod
    def _batch_sort_key(file_name):
        """Order `batch_<n>.pkl` by integer n; non-numeric names sort last by name."""
        stem = file_name[len("batch_"):-len(".pkl")]
        if stem.isdecimal():
            return (0, int(stem), file_name)
        return (1, 0, file_name)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example = self.examples[idx]

        images = []
        for path in example['image_paths']:
            # convert() returns a new in-memory image, so the source file can be
            # closed as soon as the conversion is done.
            with Image.open(path) as img:
                images.append(img.convert("RGB"))

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are an AI assistant answering medical questions based on images."}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": example['query_text']},
                    *[{"type": "image", "image": img} for img in images],
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": example['answer_text']}],
            },
        ]

        return {"messages": messages}

def preprocess_dataset(df, batch_size=50, save_dir="processed_data"):
    """Run `process_batch` over `df` in consecutive slices of `batch_size` rows.

    Each slice is written to its own batch file in `save_dir`, and progress is
    printed after every slice.

    Returns:
        Total number of examples kept across all batches.
    """
    row_count = len(df)
    # Build the range first so an invalid batch_size raises here, as before.
    slice_starts = range(0, row_count, batch_size)
    kept_total = 0

    for batch_idx, start in enumerate(slice_starts):
        batch_total = (row_count - 1) // batch_size + 1
        print(f"Processing batch {batch_idx+1}/{batch_total}")

        kept_total += process_batch(df.iloc[start:start + batch_size], batch_idx, save_dir)

        # Collect garbage between batches to keep peak memory down.
        gc.collect()

        print(f"Processed {kept_total} examples so far")

    return kept_total

In [15]:
# Output folder for the pickled debug batches.
processed_data_dir = "processed_data_debug"

# Start from a clean folder so stale batch files from an earlier run are not
# picked up when the dataset is loaded.
if os.path.exists(processed_data_dir):
    shutil.rmtree(processed_data_dir)

In [16]:
# batch_size=500 exceeds the 10-row debug subset, so a single batch file is written.
total_examples = preprocess_dataset(train_df, batch_size=500, save_dir=processed_data_dir)
print(f"Total processed examples: {total_examples}")

Processing batch 1/1


Batch 0:   0%|          | 0/10 [00:00<?, ?it/s]

Processed 10 examples so far
Total processed examples: 10


In [17]:
# Reload the first batch file from disk and print the first processed example
# to check the pickled structure.
batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
with open(batch_file, 'rb') as f:
    batch_data = pickle.load(f)

print("\nSample of processed data (first example):")
sample_data = batch_data[0]
print(f"Encounter ID: {sample_data['encounter_id']}")
print(f"Question ID: {sample_data['qid']}")
print(f"Query text: {sample_data['query_text']}")
print(f"Number of images: {len(sample_data['image_paths'])}")
print(f"Image paths: {sample_data['image_paths']}")
print(f"Answer text: {sample_data['answer_text']}")


Sample of processed data (first example):
Encounter ID: ENC00001
Question ID: CQID010-001
Query text: Question: How much of the body is affected? Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned
Number of images: 2
Image paths: ['2025_dataset/train/images_train/IMG_ENC00001_00001.jpg', '2025_dataset/train/images_train/IMG_ENC00001_00002.jpg']
Answer text: limited area


In [18]:
# Print the second processed example (index 1). The label previously said
# "first example", a copy-paste leftover that mislabeled the output.
batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
with open(batch_file, 'rb') as f:
    batch_data = pickle.load(f)

print("\nSample of processed data (second example):")
sample_data = batch_data[1]
print(f"Encounter ID: {sample_data['encounter_id']}")
print(f"Question ID: {sample_data['qid']}")
print(f"Query text: {sample_data['query_text']}")
print(f"Number of images: {len(sample_data['image_paths'])}")
print(f"Image paths: {sample_data['image_paths']}")
print(f"Answer text: {sample_data['answer_text']}")


Sample of processed data (first example):
Encounter ID: ENC00002
Question ID: CQID010-001
Query text: Question: How much of the body is affected? Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned
Number of images: 4
Image paths: ['2025_dataset/train/images_train/IMG_ENC00002_00001.jpg', '2025_dataset/train/images_train/IMG_ENC00002_00002.jpg', '2025_dataset/train/images_train/IMG_ENC00002_00003.jpg', '2025_dataset/train/images_train/IMG_ENC00002_00004.jpg']
Answer text: limited area


In [19]:
# Print the third processed example (index 2). The label previously said
# "first example", a copy-paste leftover that mislabeled the output.
batch_file = os.path.join(processed_data_dir, "batch_0.pkl")
with open(batch_file, 'rb') as f:
    batch_data = pickle.load(f)

print("\nSample of processed data (third example):")
sample_data = batch_data[2]
print(f"Encounter ID: {sample_data['encounter_id']}")
print(f"Question ID: {sample_data['qid']}")
print(f"Query text: {sample_data['query_text']}")
print(f"Number of images: {len(sample_data['image_paths'])}")
print(f"Image paths: {sample_data['image_paths']}")
print(f"Answer text: {sample_data['answer_text']}")


Sample of processed data (first example):
Encounter ID: ENC00003
Question ID: CQID010-001
Query text: Question: How much of the body is affected? Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned
Number of images: 5
Image paths: ['2025_dataset/train/images_train/IMG_ENC00003_00001.jpg', '2025_dataset/train/images_train/IMG_ENC00003_00002.jpg', '2025_dataset/train/images_train/IMG_ENC00003_00003.jpg', '2025_dataset/train/images_train/IMG_ENC00003_00004.jpg', '2025_dataset/train/images_train/IMG_ENC00003_00005.jpg']
Answer text: limited area


In [20]:
# Load variables from a .env file into the environment, then read the Hugging Face
# access token from HF_TOKEN so it never has to be written in the notebook.
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
model_id = "google/gemma-3-4b-pt"
# Processor for the chosen checkpoint, requesting the fast image processor.
processor = AutoProcessor.from_pretrained(model_id, token=hf_token, use_fast=True)

In [21]:
# Build the dataset from the pickled batches and show the three chat turns
# (system, user, assistant) of the first example.
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"\nDataset size: {len(dataset)}")

example = dataset[0]
print(f"\nFirst example roles:")
print(f"System role: {example['messages'][0]}")
print(f"User role: {example['messages'][1]}")
print(f"Assistant role: {example['messages'][2]}")


Dataset size: 10

First example roles:
System role: {'role': 'system', 'content': [{'type': 'text', 'text': 'You are an AI assistant answering medical questions based on images.'}]}
User role: {'role': 'user', 'content': [{'type': 'text', 'text': 'Question: How much of the body is affected? Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=1944x2541 at 0x15541D9B75E0>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=2592x1944 at 0x15541D9CC190>}]}
Assistant role: {'role': 'assistant', 'content': [{'type': 'text', 'text': 'limited area'}]}


In [22]:
# Show the chat turns of the second example (index 1). The label previously
# read "First example roles", a copy-paste leftover.
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"\nDataset size: {len(dataset)}")

example = dataset[1]
print("\nSecond example roles:")
print(f"System role: {example['messages'][0]}")
print(f"User role: {example['messages'][1]}")
print(f"Assistant role: {example['messages'][2]}")


Dataset size: 10

First example roles:
System role: {'role': 'system', 'content': [{'type': 'text', 'text': 'You are an AI assistant answering medical questions based on images.'}]}
User role: {'role': 'user', 'content': [{'type': 'text', 'text': 'Question: How much of the body is affected? Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=2560x1920 at 0x15541D782B30>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=2560x1920 at 0x15541DA19270>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=2560x1920 at 0x15541D903880>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=2560x1920 at 0x15541D902E00>}]}
Assistant role: {'role': 'assistant', 'content': [{'type': 'text', 'text': 'limited area'}]}


In [23]:
# Show the chat turns of the third example (index 2). The label previously
# read "First example roles", a copy-paste leftover.
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"\nDataset size: {len(dataset)}")

example = dataset[2]
print("\nThird example roles:")
print(f"System role: {example['messages'][0]}")
print(f"User role: {example['messages'][1]}")
print(f"Assistant role: {example['messages'][2]}")


Dataset size: 10

First example roles:
System role: {'role': 'system', 'content': [{'type': 'text', 'text': 'You are an AI assistant answering medical questions based on images.'}]}
User role: {'role': 'user', 'content': [{'type': 'text', 'text': 'Question: How much of the body is affected? Options: 1. single spot, 2. limited area, 3. widespread, 4. Not mentioned'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=600x400 at 0x15544D0510F0>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=600x401 at 0x155550D22380>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=600x402 at 0x155550D20940>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=600x400 at 0x155550D21CC0>}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=600x399 at 0x155550D23700>}]}
Assistant role: {'role': 'assistant', 'content': [{'type': 'text', 'text': 'limited area'}]}


# Load model

In [24]:
model_id = "google/gemma-3-4b-pt"

# bfloat16 is used for both the model dtype and the 4-bit compute dtype below, so
# stop early on GPUs whose compute capability major version is below 8.
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16. Use a different GPU.")

# Base loading options: eager attention, bfloat16 weights, automatic device placement.
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# 4-bit NF4 quantization with double quantization; the compute and storage dtypes
# reuse the model dtype chosen above so the two stay in sync.
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs, token=hf_token)

# NOTE(review): this rebinds `processor` without use_fast=True, unlike the earlier
# load, so the slow image processor is used from here on (see the warning in the
# cell output). Confirm which variant is intended.
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


# Collate data

In [25]:
# Check whether the tokenizer ships with a chat template (prints None if it does not).
print(f"Current chat template: {processor.tokenizer.chat_template}")

Current chat template: None


In [27]:
# Jinja chat template for the fine-tuning data. Every message is wrapped in
# <start_of_turn>ROLE ... <end_of_turn>; the 'assistant' role is rendered as
# 'model'. Text items are emitted verbatim; each image item in a user message
# becomes a literal "<image>" placeholder. Image items in system and assistant
# messages are ignored. A trailing "<start_of_turn>model" is added when
# add_generation_prompt is set.
gemma_chat_template = """{% for message in messages %}
{% if message['role'] == 'user' %}
<start_of_turn>user
{% for content in message['content'] %}
{% if content['type'] == 'text' %}{{ content['text'] }}{% elif content['type'] == 'image' %}<image>{% endif %}
{% endfor %}
<end_of_turn>
{% elif message['role'] == 'assistant' %}
<start_of_turn>model
{% for content in message['content'] %}
{% if content['type'] == 'text' %}{{ content['text'] }}{% endif %}
{% endfor %}
<end_of_turn>
{% elif message['role'] == 'system' %}
<start_of_turn>system
{% for content in message['content'] %}
{% if content['type'] == 'text' %}{{ content['text'] }}{% endif %}
{% endfor %}
<end_of_turn>
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
<start_of_turn>model
{% endif %}
"""

# Install the template on both the tokenizer and the processor so either entry
# point renders conversations the same way.
processor.tokenizer.chat_template = gemma_chat_template
processor.chat_template = gemma_chat_template

# Register "<image>" as a special token if the vocabulary lacks it, and resize the
# model's embedding matrix to the new tokenizer length.
if "<image>" not in processor.tokenizer.get_vocab():
    processor.tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']})
    model.resize_token_embeddings(len(processor.tokenizer))

# Make the processor treat "<image>" as its begin-of-image marker.
processor.boi_token = "<image>"
# NOTE(review): `special_tokens_map` looks like a computed property that returns a
# fresh dict, in which case this assignment has no lasting effect — confirm.
processor.tokenizer.special_tokens_map['boi_token'] = "<image>"

In [28]:
# Confirm the custom template is now installed on the tokenizer.
print(f"Current chat template: {processor.tokenizer.chat_template}")

Current chat template: {% for message in messages %}
{% if message['role'] == 'user' %}
<start_of_turn>user
{% for content in message['content'] %}
{% if content['type'] == 'text' %}{{ content['text'] }}{% elif content['type'] == 'image' %}<image>{% endif %}
{% endfor %}
<end_of_turn>
{% elif message['role'] == 'assistant' %}
<start_of_turn>model
{% for content in message['content'] %}
{% if content['type'] == 'text' %}{{ content['text'] }}{% endif %}
{% endfor %}
<end_of_turn>
{% elif message['role'] == 'system' %}
<start_of_turn>system
{% for content in message['content'] %}
{% if content['type'] == 'text' %}{{ content['text'] }}{% endif %}
{% endfor %}
<end_of_turn>
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
<start_of_turn>model
{% endif %}



In [29]:
# LoRA adapter configuration: rank-16 adapters with alpha 16 and 5% dropout,
# applied to the attention projections (q/k/v/o) and the MLP projections
# (gate/up/down).
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    # Alternative that was tried: target_modules="all-linear"
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    # Trained and saved in full alongside the adapter, rather than via LoRA.
    modules_to_save=["lm_head", "embed_tokens"],
)

In [30]:
def process_vision_info(messages):
    """Collect every image referenced in a chat-style message list.

    A content item counts as an image candidate when it is a dict that either
    has an "image" key or has type == "image". The candidate is the value of
    the "image" key when present, otherwise the item itself. Only candidates
    exposing a `convert` method are kept; each is converted to RGB.

    Args:
        messages: List of dicts with an optional "content" entry, which may be
            a list of items or a single item.

    Returns:
        List of RGB-converted images, in order of appearance.
    """
    collected = []

    for message in messages:
        parts = message.get("content", [])
        if not isinstance(parts, list):
            parts = [parts]

        for part in parts:
            if not isinstance(part, dict):
                continue

            if "image" in part:
                candidate = part["image"]
            elif part.get("type") == "image":
                candidate = part
            else:
                continue

            # Anything without a convert() method (paths, plain dicts, None) is dropped.
            if hasattr(candidate, "convert"):
                collected.append(candidate.convert("RGB"))

    return collected

In [31]:
def create_dummy_image():
    """Return a 224x224 all-black RGB image used as a stand-in when none is available."""
    placeholder_size = (224, 224)
    return Image.new("RGB", placeholder_size, color="black")

In [32]:
# Ensure <image> is properly registered as a special token
if "<image>" not in processor.tokenizer.get_vocab():
    # Add the token to the vocabulary
    processor.tokenizer.add_special_tokens({'additional_special_tokens': ['<image>']})
    # VERY IMPORTANT: Resize the model's embeddings to match the new vocabulary size
    model.resize_token_embeddings(len(processor.tokenizer))

# Check what token ID is assigned to <image>
image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
print(f"<image> token ID: {image_token_id}")

<image> token ID: 262145


In [33]:
def collate_fn(examples):
    """Turn a list of `{"messages": [...]}` examples into one training batch.

    For each example the images are extracted, the chat template is rendered to
    text, and the image list is padded with dummy images or truncated so its
    length matches the number of "<image>" placeholders in the text (never
    truncating below one image). The processor then encodes texts and images
    together, and `labels` is a copy of `input_ids` with some ids set to -100.

    Returns:
        The processor output with an added "labels" entry.
    """
    rendered_texts = []
    image_groups = []

    for sample in examples:
        sample_images = process_vision_info(sample["messages"])
        if not sample_images:
            print(f"Using dummy image — Example roles: {[m['role'] for m in sample['messages']]}")
            sample_images = [create_dummy_image()]

        # Render the conversation as plain text; image items become "<image>".
        rendered = processor.apply_chat_template(
            sample["messages"], add_generation_prompt=False, tokenize=False
        )

        # Reconcile the number of images with the number of placeholders.
        placeholder_count = rendered.count("<image>")
        shortfall = placeholder_count - len(sample_images)
        if shortfall > 0:
            sample_images += [create_dummy_image()] * shortfall
        elif shortfall < 0:
            # Keep at least one image even when the text has no placeholder.
            sample_images = sample_images[:max(1, placeholder_count)]

        rendered_texts.append(rendered.strip())
        image_groups.append(sample_images)

    # One list of images per text, padded to a common sequence length.
    batch = processor(
        text=rendered_texts,
        images=image_groups,
        return_tensors="pt",
        padding=True
    )

    labels = batch["input_ids"].clone()
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")

    # Positions set to -100 should not contribute to the loss: padding, the
    # "<image>" placeholder, and id 262144 (presumably Gemma 3's image soft
    # token — TODO confirm).
    for ignored_id in (processor.tokenizer.pad_token_id, image_token_id, 262144):
        labels[labels == ignored_id] = -100

    batch["labels"] = labels
    return batch

In [34]:
# Smoke test: collate up to three examples and inspect the resulting tensors.
dataset = MedicalImageDataset(processed_data_dir, processor)
print(f"Dataset size: {len(dataset)}")

if len(dataset) > 0:
    sample_size = min(3, len(dataset))
    print(f"Sampling {sample_size} examples from dataset")

    sample_examples = [dataset[i] for i in range(sample_size)]

    print(f"Sample size: {len(sample_examples)}")
    print("First example keys:", list(sample_examples[0].keys()))

    batch = collate_fn(sample_examples)
    print("Collated batch contains:", list(batch.keys()))
    print(f"Input_ids shape: {batch['input_ids'].shape}")
    print(f"Labels shape: {batch['labels'].shape}")
else:
    print("ERROR: Dataset is empty! Check data loading process.")

Dataset size: 10
Sampling 3 examples from dataset
Sample size: 3
First example keys: ['messages']
Collated batch contains: ['input_ids', 'attention_mask', 'token_type_ids', 'pixel_values', 'labels']
Input_ids shape: torch.Size([3, 1359])
Labels shape: torch.Size([3, 1359])


In [35]:
# Dry run: iterate the whole dataset through a DataLoader with the custom collator
# to confirm every example can be batched. No model call is made here.
dataloader = DataLoader(
    dataset,
    batch_size=8,  # Adjust based on GPU memory
    shuffle=True,
    collate_fn=collate_fn
)

# Note: this rebinds `total_examples`, which earlier held the preprocessing count.
total_examples = 0
for batch in dataloader:
    # Process each batch. Note: in training, pass this to model.forward()
    batch_size = len(batch["input_ids"])
    total_examples += batch_size
    print(f"Processed batch with {batch_size} examples")

print(f"Processed all {total_examples} examples")

Processed batch with 8 examples
Processed batch with 2 examples
Processed all 10 examples


In [36]:
def debug_template_application(examples):
    """Print a diagnostic report on how the chat template renders each example.

    For every example this prints the role, content types and image count of
    each message, then the number of "<image>" placeholders in the rendered
    text and its first 200 characters. When no placeholder is found, the
    content items of the user messages are listed in detail.
    """
    def _content_items(message):
        # Normalize "content" to a list, wrapping a single item if needed.
        items = message.get("content", [])
        return items if isinstance(items, list) else [items]

    print("\n=== DEBUGGING TEMPLATE APPLICATION ===")
    for example_no, example in enumerate(examples):
        print(f"\nExample {example_no}:")

        # Part 1: structure of each message.
        print("Message structure:")
        for msg_no, message in enumerate(example["messages"]):
            print(f"  Message {msg_no} role: {message.get('role', 'unknown')}")

            items = _content_items(message)
            kinds = [
                part.get('type') if isinstance(part, dict) else type(part).__name__
                for part in items
            ]
            print(f"  Content types: {kinds}")

            n_images = len(
                [part for part in items if isinstance(part, dict) and part.get('type') == 'image']
            )
            print(f"  Image content count: {n_images}")

        # Part 2: what the template actually produces.
        rendered = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        n_placeholders = rendered.count("<image>")
        print(f"  Image tokens in text: {n_placeholders}")
        print(f"  Text preview: {rendered[:200]}...")

        if n_placeholders != 0:
            continue

        # Part 3: no placeholders were rendered, so dump the user content items.
        print("  WARNING: No image tokens found in text!")
        print("  Detailed content examination:")
        for message in example["messages"]:
            if message.get("role") != "user":
                continue

            for item_no, item in enumerate(_content_items(message)):
                if not isinstance(item, dict):
                    continue

                item_type = item.get('type')
                print(f"    Item {item_no}: type={item_type}, keys={list(item.keys())}")
                if item_type == 'image':
                    print("    Found image content item")
                else:
                    print(f"    Item has type={item_type}, not 'image'")
                                
# Run the template diagnostics on the examples sampled in the collate smoke test.
debug_template_application(sample_examples)


=== DEBUGGING TEMPLATE APPLICATION ===

Example 0:
Message structure:
  Message 0 role: system
  Content types: ['text']
  Image content count: 0
  Message 1 role: user
  Content types: ['text', 'image', 'image']
  Image content count: 2
  Message 2 role: assistant
  Content types: ['text']
  Image content count: 0
  Image tokens in text: 2
  Text preview: <start_of_turn>system
You are an AI assistant answering medical questions based on images.<end_of_turn>
<start_of_turn>user
Question: How much of the body is affected? Options: 1. single spot, 2. limi...

Example 1:
Message structure:
  Message 0 role: system
  Content types: ['text']
  Image content count: 0
  Message 1 role: user
  Content types: ['text', 'image', 'image', 'image', 'image']
  Image content count: 4
  Message 2 role: assistant
  Content types: ['text']
  Image content count: 0
  Image tokens in text: 4
  Text preview: <start_of_turn>system
You are an AI assistant answering medical questions based on images.<end_of_

In [37]:
# Training configuration for supervised fine-tuning.
# NOTE(review): the output directory name "gemma-product-description" does not
# match this medical-image task — presumably left over from a template; confirm.
# NOTE(review): push_to_hub=True is enabled while the notebook trains on the 10-row
# debug subset — confirm that uploading this run is intended.
args = SFTConfig(
    output_dir="gemma-product-description",
    num_train_epochs=1,
    # One example per step, accumulated over 4 steps before each optimizer update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    logging_steps=5,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    push_to_hub=True,
    report_to="tensorboard",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # The custom collator builds the model inputs, so no text field is named and
    # the trainer's own dataset preparation is skipped.
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,  # Keep the "messages" entry for the custom collator
    label_names=["labels"],  # Explicitly setting label_names
)

# remove_unused_columns is already set in the constructor above; the separate
# assignment `args.remove_unused_columns = False` is therefore not needed.

In [38]:
# Build the SFT trainer: the quantized base model plus the LoRA config, the custom
# dataset, and the custom collator that renders and encodes each batch.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

In [39]:
# Start training; the returned TrainOutput (step count, loss, metrics) is the cell output.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss


TrainOutput(global_step=2, training_loss=19.036113739013672, metrics={'train_runtime': 67.0704, 'train_samples_per_second': 0.149, 'train_steps_per_second': 0.03, 'total_flos': 174718733364480.0, 'train_loss': 19.036113739013672})

In [40]:
# Save the trained model via the trainer (no path is passed, so the trainer's
# configured output location is used).
trainer.save_model()
print("Training complete and model saved!")

Training complete and model saved!


In [None]:
# Free GPU memory before reloading the model for merging and testing.
# Dropping the names only removes references; run the garbage collector so objects
# kept alive by reference cycles are actually destroyed, then release PyTorch's
# cached CUDA blocks (same sequence as at the top of the notebook).
del model
del trainer
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Merge the trained LoRA adapter into the base model and save a standalone copy.
# PeftModel is imported in the notebook's top import cell.

# Processor saved with the fine-tuned model (includes any tokens added for training).
processor = AutoProcessor.from_pretrained(args.output_dir)

# Reload the base model without the 4-bit quantization config used for training.
# The token is passed for consistency with the other loads of this checkpoint.
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True, token=hf_token)

# Training resized the embeddings when "<image>" was added to the tokenizer, and
# the adapter stores full copies of embed_tokens/lm_head (modules_to_save). If the
# fine-tuned tokenizer has more tokens than the base one, apply the same resize
# here so the saved weights match the base model's shapes.
# NOTE(review): this cell has not been executed in this notebook — verify on a real run.
base_vocab_size = len(AutoProcessor.from_pretrained(model_id, token=hf_token).tokenizer)
if len(processor.tokenizer) != base_vocab_size:
    model.resize_token_embeddings(len(processor.tokenizer))

# Attach the adapter, fold its weights into the base model, and save the result.
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

# Save the processor next to the merged weights so the folder is self-contained.
processor.save_pretrained("merged_model")