In [None]:
import torch

# Environment sanity check: report library version and GPU availability/size.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("WARNING: No GPU detected! Please enable GPU in Runtime settings.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print("GPU Memory: {:.2f} GB".format(torch.cuda.get_device_properties(0).total_memory / 1e9))

In [None]:
!pip install -q accelerate peft transformers bitsandbytes datasets pillow tqdm evaluate trl scikit-learn

In [None]:
# System prompt: radiologist role, output format (plain-text findings only),
# analysis guidance and two few-shot examples. The model is asked for a plain
# string, so the content section refers to "the report" rather than a JSON field.
SYS_PROMPT = """You are an experienced emergency radiologist analyzing imaging studies.

OUTPUT FORMAT:
You must respond with ONLY detailed findings as a string.

ANALYSIS APPROACH:
- Systematically examine the entire image for all abnormalities
- Report all identified lesions and pathological findings
- Be factual - if uncertain, describe what you observe without assuming
- Use professional radiological terminology
- Review the image multiple times if findings are ambiguous

REPORT CONTENT:
The report should contain a complete radiological description including:
- Primary findings related to the clinical question
- Additional incidental findings or lesions
- Relevant negative findings if clinically important

EXAMPLES:

Example 1 - Chest X-ray with pneumonia:
Input: Chest X-ray, patient with cough and fever
Output: Consolidation in the right lower lobe consistent with pneumonia. No pleural effusion or pneumothorax. Heart size normal.

Example 2 - Normal chest X-ray:
Input: Chest X-ray, routine screening
Output: Clear lung fields bilaterally. No consolidation, pleural effusion, or pneumothorax. Cardiac silhouette within normal limits. No acute bony abnormalities.
"""

# User prompt template; filled with the anatomy and a "<age_group> <gender>" subject.
USR_PROMPT = """Generate a radiology report for this {anatomy} X-ray of {subject}."""

### Loading the dataset
Load the dataset from the Hugging Face Hub and display its splits and columns.

In [None]:
from datasets import load_dataset, Image

# Download all splits of the pediatric chest X-ray dataset and type the
# "image" column as an `Image` feature in a single chained expression.
data = load_dataset("costinstroie/xray-chest-ped-test").cast_column("image", Image())

# Debug: show the splits, their columns and row counts
print(data)

### Processing the dataset
We use the custom prompts defined above to guide the model during fine-tuning. The system prompt sets the radiologist role and the expected output. The user prompt is filled in with the anatomy and the patient's age group and gender. To prepare the dataset for fine-tuning, we create a new column called "messages" that holds the conversation for each example: the system prompt, a user turn (the image plus the prompt text), and an assistant turn (the reference report).

In [None]:
from typing import Any


def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach a chat-formatted "messages" conversation to a training example.

    The conversation has three turns:
      * system    - the instructions in SYS_PROMPT
      * user      - an image placeholder followed by USR_PROMPT, filled with
                    "chest" and the patient's "<age_group> <gender>"
      * assistant - the reference report (the training target)

    The image itself is not embedded here; the {"type": "image"} entry is only
    a placeholder, and the collator supplies the actual image.

    Args:
        example: one dataset row; must provide "age_group", "gender" and "report".

    Returns:
        The same example dict (mutated in place) with a new "messages" key.
    """
    prompt = USR_PROMPT.format(
        anatomy="chest",
        subject=f"{example['age_group']} {example['gender']}",
    )

    example["messages"] = [
        {
            "role": "system",
            "content": [{"type": "text", "text": SYS_PROMPT}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": example["report"]},
            ],
        },
    ]
    return example

# Build the "messages" column for every split of the dataset
formatted_data = data.map(function=format_data)

# Debug: inspect the conversation built for the first training example
formatted_data["train"][0]["messages"]

### Loading the model and tokenizer
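Since MedGemma is a gated model, you must accept its license on the Hugging Face Hub and authenticate with a Hugging Face access token, for example with `huggingface-cli login` or `huggingface_hub.notebook_login()`. Authentication is also required to push the fine-tuned model to the Hugging Face Hub.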

We use the Transformers library to load the MedGemma 4B Instruct model and its processor. The model is configured to use bfloat16 precision for efficient computation on GPUs.

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

model_id = "google/medgemma-4b-it"

## Check that a CUDA GPU is present and supports bfloat16 (compute capability >= 8).
# Without CUDA, get_device_capability() cannot answer and fails with an
# unhelpful error, so stop early with an actionable message instead.
if not torch.cuda.is_available():
    raise RuntimeError(
        "No CUDA GPU detected. Please enable a GPU runtime that supports bfloat16."
    )
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# NOTE: BitsAndBytesConfig is imported but not used here; the model is loaded
# unquantized in bfloat16.
model_kwargs = dict(
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

# Use right padding to avoid issues during training
processor.tokenizer.padding_side = "right"

### Setting up the model

To fine-tune the MedGemma 4B Instruct model efficiently, we will use Low-Rank Adaptation (LoRA), a parameter-efficient fine-tuning method.

LoRA allows us to adapt large models by training only a small number of additional parameters, significantly reducing computational costs while maintaining performance.

In [None]:
from peft import LoraConfig

# LoRA adapters on every linear layer of the causal LM. The LM head and the
# token embeddings are listed in `modules_to_save`, so they stay fully
# trainable and are stored together with the adapter weights.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    modules_to_save=["lm_head", "embed_tokens"],
)

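To handle both image and text inputs during training, we define a custom collation function. This function turns the dataset examples into a format suitable for the model by tokenizing the text and preparing the image data. It also builds the labels for the loss by copying the input IDs and setting padding and image tokens to -100, so they are ignored.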

In [None]:
from typing import Any


def collate_fn(examples: list[dict[str, Any]]):
    """Collate chat-formatted examples into a training batch for MedGemma.

    Args:
        examples: dataset rows, each with an "image" (PIL image) and the
            "messages" conversation built by `format_data`.

    Returns:
        The processor output for the batch plus a "labels" tensor: a copy of
        "input_ids" in which padding, begin-of-image and image tokens are
        replaced by -100 so the loss ignores them.

    NOTE(review): prompt tokens (system + user turn) are not masked, so the
    loss is also computed on the prompt text.
    """
    texts = []
    images = []
    for example in examples:
        # X-rays are often stored as single-channel ("L") images; convert to
        # RGB so every image handed to the processor has 3 channels.
        images.append([example["image"].convert("RGB")])
        texts.append(
            processor.apply_chat_template(
                example["messages"], add_generation_prompt=False, tokenize=False
            ).strip()
        )

    # Tokenize the texts and process the images (one list of images per text)
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    tokenizer = processor.tokenizer
    labels = batch["input_ids"].clone()

    # Begin-of-image marker token (scalar id, compared element-wise below)
    boi_token_id = tokenizer.convert_tokens_to_ids(
        tokenizer.special_tokens_map["boi_token"]
    )
    # MedGemma specific image token ID
    image_soft_token_id = 262144

    # Mask tokens that are not used in the loss computation
    labels[labels == tokenizer.pad_token_id] = -100
    labels[labels == boi_token_id] = -100
    labels[labels == image_soft_token_id] = -100

    batch["labels"] = labels
    return batch

We use the SFTConfig class from the trl library to define the training arguments. These arguments control the fine-tuning process, including batch size, learning rate, and gradient accumulation steps.

In [None]:
from trl import SFTConfig

args = SFTConfig(
    # Output location and Hub publishing
    output_dir="medgemma-4b-it-ped",
    push_to_hub=True,
    report_to="none",
    # Optimisation schedule
    num_train_epochs=3,
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    optim="adamw_torch_fused",
    bf16=True,
    # Batching: effective batch size per device = 1 x 16 accumulation steps
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    # Memory savings
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Logging / evaluation cadence (values < 1 are fractions of total steps)
    logging_steps=0.1,
    eval_strategy="steps",
    eval_steps=0.1,
    save_strategy="epoch",
    # The custom collate_fn does the preprocessing, so keep raw columns intact
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)

The SFTTrainer simplifies the fine-tuning process by combining the model, dataset, data collator, training arguments, and LoRA configuration into a single workflow. This makes the process streamlined and user-friendly.

In [None]:
from trl import SFTTrainer

# Combine model, LoRA config, training arguments, processor, custom collator
# and the formatted train/validation splits into one trainer.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=args,
    processing_class=processor,
    data_collator=collate_fn,
    train_dataset=formatted_data["train"],
    eval_dataset=formatted_data["validation"],
)

### Model training
Once the model, dataset, and training configurations are set up, we can begin the fine-tuning process. The SFTTrainer simplifies this step, allowing us to train the model with just a single command:

In [None]:
# Run LoRA fine-tuning; epochs, batch sizes and evaluation cadence come from `args`.
trainer.train()

After training completes, `save_model()` saves the fine-tuned model to `output_dir`. Because `push_to_hub=True` is set in the training arguments, it also pushes the model to the Hugging Face Hub.

In [None]:
# Save the fine-tuned model to args.output_dir (with push_to_hub=True set in args, the Hub upload is triggered too).
trainer.save_model()

### Model Evaluation
To evaluate the performance of the MedGemma 4B model, we test both the base model and the fine-tuned model on the test split. This process involves clearing GPU memory, preparing the test data, generating the reports, and calculating key metrics such as accuracy and F1 score.

Before starting the evaluation, we remove the training setup to free up GPU memory and ensure a clean environment for testing.

In [None]:
import gc

# Drop the training objects, then run a garbage-collection pass so any
# remaining reference cycles holding GPU tensors are freed before the cached
# CUDA memory is returned.
del model
del trainer
gc.collect()
torch.cuda.empty_cache()

### Setting up for model testing
We format the test split to match the input structure required by the model. This involves creating a "messages" column that contains the system prompt and the user turn (image plus prompt) for each example, without the assistant's reference report.

In [None]:
from typing import Any


def format_test_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach an inference-time "messages" conversation to a test example.

    Same system and user turns as the training format, but without the
    assistant turn: the model must generate the report itself.

    Args:
        example: one dataset row; must provide "age_group" and "gender".

    Returns:
        The same example dict (mutated in place) with a new "messages" key.
    """
    prompt = USR_PROMPT.format(
        anatomy="chest",
        subject=f"{example['age_group']} {example['gender']}",
    )

    example["messages"] = [
        {
            "role": "system",
            "content": [{"type": "text", "text": SYS_PROMPT}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        },
    ]
    return example

# Only the test split is used for evaluation; add the prompt-only "messages"
test_data = data["test"].map(format_test_data)

### Model performance on the base model

To evaluate the base model's performance, we load the pre-trained model and processor, configure the generation settings, and prepare the prompts and images for testing.

In [None]:
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from transformers import GenerationConfig

# Load the *base* (not fine-tuned) model for the baseline evaluation.
model_kwargs = dict(
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

model = AutoModelForImageTextToText.from_pretrained(
    model_id, **model_kwargs
)

# Greedy decoding: start from the model's shipped generation config and turn
# sampling off (top_k/top_p cleared so they do not trigger sampling warnings).
gen_cfg = GenerationConfig.from_pretrained(model_id)
gen_cfg.update(
    do_sample=False,
    top_k=None,
    top_p=None,
    cache_implementation="dynamic",
)
model.generation_config = gen_cfg

# Use the base model's own processor (same fast image processor as in
# training) so the baseline does not depend on artifacts written to
# args.output_dir by the training run.
processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
tok = processor.tokenizer

model.config.pad_token_id = tok.pad_token_id
model.generation_config.pad_token_id = tok.pad_token_id

def chat_to_prompt(chat_turns):
    """Render a list of chat turns into the raw prompt string for the model.

    The generation prompt is appended so the model continues as the assistant,
    and the result is text (tokenize=False), not token ids.
    """
    rendered = processor.apply_chat_template(
        chat_turns,
        tokenize=False,
        add_generation_prompt=True,
    )
    return rendered

# One rendered prompt per test conversation, paired positionally with its image
prompts = list(map(chat_to_prompt, test_data["messages"]))
images = test_data["image"]  # image column of the test split (decoded images)
assert len(prompts) == len(images), "1 prompt must match 1 image!"

The predict_one function takes a prompt and an image as input, processes them using the model's processor, and generates a response. Only the newly generated tokens (those after the prompt) are decoded into human-readable text.

In [None]:
import torch
from typing import Union, Dict, Any, List
from transformers import AutoModelForImageTextToText, AutoProcessor


def predict_one(
    prompt,
    image,
    model,
    processor,
    *,
    device="cuda",
    dtype=torch.bfloat16,
    disable_compile=True,
    **gen_kwargs
) -> str:
    """Generate the model's answer for a single prompt/image pair.

    Args:
        prompt: fully rendered prompt text (e.g. from `apply_chat_template`
            with `add_generation_prompt=True`).
        image: the input image. PIL-style images that are not already RGB
            (e.g. single-channel X-rays) are converted to RGB; objects without
            a `convert` method are passed through unchanged.
        model: model exposing `generate`.
        processor: processor used to build model inputs and decode outputs.
        device: device the inputs are moved to; pass None to use `model.device`.
        dtype: dtype passed to `.to(device, dtype=...)` on the processor output.
        disable_compile: forwarded to `model.generate`.
        **gen_kwargs: extra generation arguments (e.g. `max_new_tokens`).

    Returns:
        The decoded text of the newly generated tokens only (the prompt part
        is sliced off), with special tokens skipped.
    """
    if device is None:
        device = model.device

    # Grayscale images would reach the vision encoder with a single channel.
    if hasattr(image, "convert") and getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = processor(text=prompt, images=image, return_tensors="pt").to(
        device, dtype=dtype
    )
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        generated = model.generate(
            **inputs,
            disable_compile=disable_compile,
            **gen_kwargs
        )
    # generate() returns prompt + continuation; decode only the continuation.
    return processor.decode(generated[0, prompt_len:], skip_special_tokens=True)

We use `predict_one` to generate a report for the sample at index 3 (the fourth sample) of the test split. The generated report is printed next to the reference report, and the X-ray image is displayed.

In [None]:
import textwrap

idx = 3  # 0-based position of the sample within the test split

# Index one row: column access such as test_data["image"][idx] can decode the
# whole column just to pick a single item.
sample = test_data[idx]
chat = sample["messages"]
prompt = chat_to_prompt(chat)

# run the one-sample helper
answer = predict_one(
    prompt=prompt,
    image=sample["image"],
    model=model,
    processor=processor,
    max_new_tokens=500,
)

print("Dataset report:", textwrap.fill(sample["report"], 80))
print("Model answer:", textwrap.fill(answer, 80))
sample["image"]
