In [1]:
import os
import json
import torch
import pandas as pd
from datasets import Dataset
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig

In [31]:
# Where the training annotations and their images live (relative paths).
train_images_dir = os.path.join("2024_dataset", "images", "train")
train_json_file = os.path.join("2024_dataset", "train_downloaded.json")

# Parse the annotation file; the handle is closed as soon as parsing finishes.
with open(train_json_file, mode="r", encoding="utf-8") as json_handle:
    train_data = json.load(json_handle)

# Convert dataset to Hugging Face format
def format_data(sample):
    """Build one chat-style training example from a raw annotation record.

    The image is looked up as ``<encounter_id>.jpg`` inside ``train_images_dir``.

    Args:
        sample: Mapping with at least ``encounter_id`` and ``responses``;
            ``query_title_en`` / ``query_content_en`` are optional and default
            to empty strings.

    Returns:
        ``None`` when the image file does not exist (so the caller can drop the
        record); otherwise a dict with a single ``"messages"`` key holding the
        system, user (text + image) and assistant turns.
    """
    image_file = os.path.join(train_images_dir, f"{sample['encounter_id']}.jpg")
    if not os.path.exists(image_file):
        return None  # nothing to train on without the picture

    # Decode now so the example carries an Image object rather than a path.
    rgb_image = Image.open(image_file).convert("RGB")

    title = sample.get('query_title_en', '')
    body = sample.get('query_content_en', '')
    prompt = f"Analyze the following image: {title} {body}".strip()

    # Only the first response is used as the target answer.
    answer = sample["responses"][0]["content_en"]

    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": "You are an AI assisting in medical image analysis."}],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image", "image": rgb_image},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": answer}],
    }

    return {"messages": [system_turn, user_turn, assistant_turn]}

# Process dataset.
# format_data() opens and decodes an image, so call it exactly once per entry
# and filter on the result instead of calling it a second time in the condition.
formatted_data = [
    record
    for record in (format_data(entry) for entry in train_data)
    if record is not None  # None marks an entry whose image file is missing
]

# Convert to Hugging Face Dataset
dataset = Dataset.from_list(formatted_data)

In [32]:
def process_vision_info(messages):
    """Collect the images referenced by a structured ``messages`` list.

    Every message's ``content`` is scanned (a non-list content is treated as a
    one-element list). For each dict element carrying a non-``None`` ``"image"``
    value:

    * a string is treated as a file path, opened and converted to RGB;
    * any other value is passed through unchanged.

    Elements whose ``"image"`` is ``None`` are skipped, so text parts that
    merely carry an empty ``image`` field do not end up as ``None`` entries in
    the result.
    NOTE(review): presumably such fields appear after the examples are stored
    in a ``datasets.Dataset`` (schema unification) -- confirm on a real row.

    Args:
        messages: Iterable of message dicts, each optionally holding a
            ``"content"`` entry.

    Returns:
        List of images in the order they appear in ``messages``.
    """
    image_inputs = []

    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]

        for element in content:
            if not isinstance(element, dict):
                continue

            image = element.get("image")
            if image is None:
                # Missing key or explicit null: this part carries no image.
                continue

            # Convert string paths to Image objects
            if isinstance(image, str):
                image = Image.open(image).convert("RGB")

            image_inputs.append(image)

    return image_inputs

In [None]:
# Authenticated variant of the model-loading cell (labelled as the version that
# works on PACE). Follows the Gemma QLoRA fine-tuning documentation.

# Model ID (must be defined here so the cell runs on a fresh kernel)
model_id = "google/gemma-3-4b-pt"

# Read the Hugging Face token from the environment only. Never put a literal
# token in the notebook: any token that was ever committed must be treated as
# leaked and revoked. When HF_TOKEN is unset this is None, and the libraries
# fall back to credentials cached by `huggingface-cli login`.
hf_token = os.getenv("HF_TOKEN")

# Ensure GPU supports bfloat16
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16. Use a different GPU.")

# Load model with QLoRA (4-bit quantization)
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU
    torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map="auto", # Let torch decide how to load the model
)

# BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and processor once, quantized and authenticated. An extra
# full-precision load beforehand would only cost time and memory, because its
# result is overwritten by this one.
model = AutoModelForImageTextToText.from_pretrained(model_id, token=hf_token, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it", token=hf_token)

In [None]:
# Quantised (QLoRA) load of the Gemma 3 checkpoint -- follows the documentation.

# Checkpoint that will be fine-tuned
model_id = "google/gemma-3-4b-pt"

# Refuse to continue on GPUs whose major compute capability is below 8.
if not torch.cuda.get_device_capability()[0] >= 8:
    raise ValueError("GPU does not support bfloat16. Use a different GPU.")

# Keyword arguments shared by the model load below.
# NOTE(review): with device_map="auto" some modules can be placed on CPU/disk
# when GPU memory is short, which the 4-bit load rejects (see the ValueError
# recorded under this cell).
model_kwargs = {
    "attn_implementation": "eager",  # Use "flash_attention_2" when running on Ampere or newer GPU
    "torch_dtype": torch.bfloat16,   # dtype for the loaded weights
    "device_map": "auto",            # automatic device placement
}

# 4-bit NF4 quantisation with double quantisation; compute and storage dtypes
# mirror the model dtype chosen above.
model_kwargs.update(
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
        bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
    )
)

# Pretrained weights come from `model_id`; the processor comes from the
# instruction-tuned sibling checkpoint.
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

ValueError: Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules in 32-bit, you need to set `llm_int8_enable_fp32_cpu_offload=True` and pass a custom `device_map` to `from_pretrained`. Check https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu for more details. 

In [9]:
# Keep the chat template that ships with the instruction-tuned processor.
# Chat templates are Jinja, so a str.format-style template such as
# "<s>{role}: {content}</s>" must NOT be assigned here: single-brace
# placeholders are literal text, every conversation would render to that same
# constant string, and the message text and image parts would be dropped.
if not getattr(processor, "chat_template", None):
    raise ValueError(
        "processor has no chat template; load the 'google/gemma-3-4b-it' "
        "processor or assign a valid Jinja chat template."
    )

In [24]:
# LoRA adapter settings handed to the trainer.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter size / scaling
    r=16,
    lora_alpha=16,
    # lora_dropout=0.05,
    bias="none",
    # Which modules receive adapters, and which are listed to be saved as-is
    target_modules="all-linear",
    modules_to_save=["lm_head", "embed_tokens"],
)

In [25]:
# Training configuration for the supervised fine-tuning run.
args = SFTConfig(
    # --- output, checkpoints and reporting ---
    output_dir="gemma-product-description",
    save_strategy="epoch",
    push_to_hub=True,
    report_to="tensorboard",
    logging_steps=5,
    # --- schedule and batch shape (1 sample x 4 accumulation steps per device) ---
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # --- optimisation ---
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    lr_scheduler_type="constant",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # --- precision and memory ---
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # --- dataset handling: preparation is skipped, the collator does the work ---
    dataset_text_field="messages",
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,  # Critical for custom datasets
    label_names=["labels"],  # Explicitly setting label_names
)

# Re-assert after construction so the collator still sees the "messages" column.
args.remove_unused_columns = False

In [27]:
def collate_fn(examples):
    """Turn a list of chat examples into one padded multimodal training batch.

    For every example the chat template is rendered to text and its images are
    collected with ``process_vision_info``; both are then encoded together by
    the module-level ``processor``.

    Labels are a copy of ``input_ids`` in which tokens that must not contribute
    to the loss are set to ``-100``: padding, and -- when the processor exposes
    them -- the image placeholder token and the begin-of-image token.
    NOTE(review): ``image_token_id`` / ``boi_token`` are looked up defensively
    with ``getattr``; confirm they exist on the installed Gemma 3 processor.

    Args:
        examples: List of dicts, each with a ``"messages"`` list.

    Returns:
        The processor's batch with an added ``"labels"`` tensor.

    Raises:
        ValueError: If no example has a usable ``"messages"`` list.
    """
    texts = []
    images = []

    for example in examples:
        messages = example["messages"]
        if not isinstance(messages, list):
            print("Warning: 'messages' should be a list but isn't.")
            continue

        # Extract image and text inputs
        image_inputs = process_vision_info(messages)
        text = processor.apply_chat_template(
            messages, add_generation_prompt=False, tokenize=False
        )

        texts.append(text.strip())
        images.append(image_inputs)

    if not texts:
        # Fail with a clear message instead of handing empty lists to the processor.
        raise ValueError("collate_fn received no example with a valid 'messages' list.")

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Labels start as the inputs; everything listed below is excluded from the loss.
    labels = batch["input_ids"].clone()

    tokenizer = processor.tokenizer
    ignored_token_ids = [
        tokenizer.pad_token_id,
        getattr(processor, "image_token_id", None),  # image placeholder tokens
    ]
    boi_token = getattr(processor, "boi_token", None)
    if boi_token is not None:
        ignored_token_ids.append(tokenizer.convert_tokens_to_ids(boi_token))

    for token_id in ignored_token_ids:
        if token_id is not None:
            labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch

In [28]:
# Wire the model, LoRA settings, dataset and custom collator into the trainer.
trainer = SFTTrainer(
    model=model,
    processing_class=processor,
    peft_config=peft_config,
    train_dataset=dataset,
    data_collator=collate_fn,
    args=args,
)

# Run the fine-tuning loop.
trainer.train()

ValueError: Target module Identity() is not supported. Currently, only the following modules are supported: `torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, `transformers.pytorch_utils.Conv1D`.

In [None]:
from peft import PeftModel

# Reload the base checkpoint (no quantization arguments this time).
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Attach the adapter saved in the training output directory, fold it into the
# base weights, and write the result as sharded safetensors.
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained(
    "merged_model",
    max_shard_size="2GB",
    safe_serialization=True,
)

# Store the processor next to the merged weights.
processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")