In [None]:
import torch
# Report the PyTorch build and, when a CUDA device is visible, its name and memory.
print(f"PyTorch version: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")
if not has_cuda:
    print("WARNING: No GPU detected! Please enable GPU in Runtime settings.")
else:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is divided by 1e9 so the value is printed in GB
    print(f"GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB")

In [None]:
!pip install -q accelerate peft transformers bitsandbytes datasets pillow tqdm evaluate trl scikit-learn

In [None]:
import os

from google.colab import drive

# Mount Google Drive at /content/drive so the dataset folder can be read.
drive.mount('/content/drive')

DATASET_PATH = "/content/drive/MyDrive/pediatric_xray_dataset_chest"

# Report whether the dataset folder is reachable and, if so, what it contains.
if not os.path.exists(DATASET_PATH):
    print(f"✗ Dataset NOT found at: {DATASET_PATH}")
else:
    print(f"✓ Dataset found at: {DATASET_PATH}")
    print(f"  Files: {os.listdir(DATASET_PATH)}")

### Loading the dataset
Load the training, validation, and test splits from their JSONL files and print the first entry of each to show the record structure.

In [None]:
import os
import json

def load_jsonl_dataset(file_path):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        file_path: Path of the ``.jsonl`` file. Every non-blank line must hold
            exactly one JSON document.

    Returns:
        list: The parsed records, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. stray newlines at the end of the file) hold no record
        # and would make json.loads raise, so they are skipped.
        data = [json.loads(line) for line in f if line.strip()]
    print(f"Loaded {len(data)} entries from {os.path.basename(file_path)}")
    return data

# Locations of the three split files inside the dataset folder.
train_file_path = os.path.join(DATASET_PATH, 'train.jsonl')
val_file_path = os.path.join(DATASET_PATH, 'val.jsonl')
test_file_path = os.path.join(DATASET_PATH, 'test.jsonl')

# Read each split, announcing the source file before it is parsed.
print(f"Loading training data from: {train_file_path}")
train_data = load_jsonl_dataset(train_file_path)

print(f"Loading validation data from: {val_file_path}")
val_data = load_jsonl_dataset(val_file_path)

print(f"Loading test data from: {test_file_path}")
test_data = load_jsonl_dataset(test_file_path)

# Print one record per split as a quick look at the schema.
for split_label, split_records in (
    ("train_data", train_data),
    ("val_data", val_data),
    ("test_data", test_data),
):
    print(f"\nFirst entry of {split_label}:")
    print(split_records[0])

In [None]:
from datasets import Dataset, DatasetDict
import os

def prepare_hf_dataset_format(data_list, base_path):
    """Turn raw JSONL records into flat rows for a Hugging Face Dataset.

    Each output row holds the image path joined onto ``base_path``, the report
    text, the age group, a constant ``"chest"`` region, and the record's
    ``"sex"`` value stored under the ``"gender"`` key.

    Args:
        data_list: Iterable of dicts with the keys ``"image"``, ``"report"``,
            ``"age_group"`` and ``"sex"``.
        base_path: Directory that the relative image paths are joined onto.

    Returns:
        list[dict]: One row per input record, in input order.
    """
    return [
        {
            "image": os.path.join(base_path, record["image"]),
            "report": record["report"],
            "age_group": record["age_group"],
            "region": "chest",
            "gender": record["sex"],
        }
        for record in data_list
    ]

# Convert every split (train, validation and test) into Dataset-ready rows.
processed_train_data, processed_val_data, processed_test_data = (
    prepare_hf_dataset_format(split_records, DATASET_PATH)
    for split_records in (train_data, val_data, test_data)
)

# Wrap each list of rows in a Hugging Face Dataset.
hf_train_dataset, hf_val_dataset, hf_test_dataset = map(
    Dataset.from_list,
    (processed_train_data, processed_val_data, processed_test_data),
)

# Bundle the three splits under the keys "train", "validation" and "test".
data = DatasetDict(
    dict(
        train=hf_train_dataset,
        validation=hf_val_dataset,
        test=hf_test_dataset,
    )
)

print("Dataset created successfully:")

data

### Processing the dataset
We create a custom prompt that will be used to guide the model during fine-tuning. The prompt asks whether any of a fixed list of chest findings is present and is filled in with each example's anatomical region and age group. To prepare the dataset for fine-tuning, we add a new column called "messages". This column contains structured data representing a user query (the image plus the prompt) and an assistant response (the report).

In [None]:
from typing import Any


def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """Attach a chat-style ``messages`` list to one dataset example.

    The user turn holds an image placeholder followed by a question built from
    the example's ``"region"`` and ``"age_group"``; the assistant turn holds the
    example's ``"report"`` text.

    Args:
        example: Dataset row with at least the keys ``"region"``,
            ``"age_group"`` and ``"report"``.

    Returns:
        The same ``example`` dict, with a ``"messages"`` key added.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    # The finding list is spelled out in full; "infiltrates" was previously
    # misspelled, which put a typo into every training prompt.
    prompt = "{question} in this {anatomy} X-ray of a {subject}?".format(
        question="Are there any lung consolidations, infiltrates, opacities, pleural effusion, pneumothorax or pneumoperitoneum",
        anatomy=example["region"],
        subject=example["age_group"],
    )
    example["messages"] = [
        {
            "role": "user",
            "content": [
                # Placeholder only: the image itself is supplied separately at collation time.
                {
                    "type": "image",
                },
                {
                    "type": "text",
                    "text": prompt,
                },
            ],
        },
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": example["report"],
                },
            ],
        },
    ]
    return example

# Apply format_data to every split of the DatasetDict; each example gains a "messages" entry.
formatted_data = data.map(format_data)

# Uncomment to inspect the messages built for the first training example:
# formatted_data["train"][0]["messages"]

### Loading the model and tokenizer
Since MedGemma is a gated model, you need to log in to Hugging Face with an access token; the cell below reads it from the `HF_TOKEN` environment variable. Logging in also allows you to save your fine-tuned model to the Hugging Face Hub.

In [None]:
from huggingface_hub import login
import os

# The access token comes from the environment so it is never written into the notebook.
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token)

We use the Transformers library to load the MedGemma 4B Instruct model and its processor. The model is configured to use bfloat16 precision for efficient computation on GPUs.

In [None]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

model_id = "google/medgemma-4b-it"

# The model is loaded in bfloat16 below, so stop early on GPUs whose major
# compute capability is below 8.
gpu_major_capability = torch.cuda.get_device_capability()[0]
if gpu_major_capability < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Keyword arguments forwarded to from_pretrained.
model_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

# Pad on the right-hand side to avoid issues during training.
processor.tokenizer.padding_side = "right"

### Setting up the model

To fine-tune the MedGemma 4B Instruct model efficiently, we will use Low-Rank Adaptation (LoRA), a parameter-efficient fine-tuning method.

LoRA allows us to adapt large models by training only a small number of additional parameters, significantly reducing computational costs while maintaining performance.

In [None]:
from peft import LoraConfig

# LoRA settings: rank-16 adapters on all linear layers, plus lm_head and
# embed_tokens listed in modules_to_save.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules="all-linear",
    modules_to_save=["lm_head", "embed_tokens"],
)

To handle both image and text inputs during training, we define a custom collation function. This function processes the dataset examples into a format suitable for the model, including tokenizing text and preparing image data.

In [None]:
from typing import Any


def collate_fn(examples: list[dict[str, Any]]):
    """Build one training batch (inputs plus loss labels) from formatted examples.

    Each example's ``"messages"`` are rendered to text with the processor's chat
    template, and its ``"image"`` is passed to the processor together with that
    text. Labels are a copy of ``input_ids`` in which padding and image-related
    tokens are replaced by ``-100``.

    Args:
        examples: Dataset rows that provide the keys ``"image"`` and ``"messages"``.

    Returns:
        The processor output with an added ``"labels"`` entry.
    """
    texts = []
    images = []
    for example in examples:
        # One image per example, so the processor receives a list of one-element lists.
        # NOTE(review): the value is passed through unchanged; in this notebook it is a
        # file path string, so this relies on the processor loading images from paths — confirm.
        images.append([example["image"]])
        texts.append(
            processor.apply_chat_template(
                example["messages"], add_generation_prompt=False, tokenize=False
            ).strip()
        )

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, with the padding and image tokens masked in
    # the loss computation
    labels = batch["input_ids"].clone()

    tokenizer = processor.tokenizer
    # Keep the begin-of-image id as a plain scalar: comparing the label tensor
    # against a one-element list does not reliably produce an element-wise mask,
    # which could leave this token unmasked.
    boi_token_id = tokenizer.convert_tokens_to_ids(
        tokenizer.special_tokens_map["boi_token"]
    )
    # 262144 is presumably the id of the image placeholder tokens that the
    # processor inserts for each image — TODO confirm against the tokenizer.
    ignored_token_ids = (tokenizer.pad_token_id, boi_token_id, 262144)

    # Mask tokens that are not used in the loss computation
    for ignored_token_id in ignored_token_ids:
        labels[labels == ignored_token_id] = -100

    batch["labels"] = labels
    return batch

We use the SFTConfig class from the trl library to define the training arguments. These arguments control the fine-tuning process, including batch size, learning rate, and gradient accumulation steps.

In [None]:
from trl import SFTConfig

args = SFTConfig(
    # Output location, checkpointing and reporting
    output_dir="medgemma-4b-it-ped",
    save_strategy="epoch",
    push_to_hub=True,
    report_to="none",
    # Training length and batch sizes (1 sample per step, accumulated over 8 steps)
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    # Memory savings and numeric precision
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,
    # Optimizer and learning-rate schedule
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # Logging and evaluation cadence (values below 1 are a fraction of total training steps)
    logging_steps=0.1,
    eval_strategy="steps",
    eval_steps=0.1,
    # Batches are built by the custom collator, so the dataset columns are left untouched
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)

The SFTTrainer simplifies the fine-tuning process by combining the model, dataset, data collator, training arguments, and LoRA configuration into a single workflow. This makes the process streamlined and user-friendly.

In [None]:
from trl import SFTTrainer

# Combine the model, LoRA configuration, data splits, collator and training arguments.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    processing_class=processor,
    train_dataset=formatted_data["train"],
    eval_dataset=formatted_data["validation"],
    data_collator=collate_fn,
    args=args,
)

### Model training
Once the model, dataset, and training configurations are set up, we can begin the fine-tuning process. The SFTTrainer simplifies this step, allowing us to train the model with just a single command:

In [None]:
# Optional allocator setting, left disabled; presumably kept as a workaround for
# CUDA out-of-memory errors — TODO confirm before enabling.
#import os
#os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"

# Run fine-tuning with the trainer configured above.
trainer.train()

After training is complete, the fine-tuned model can be saved locally with the save_model() method. Because `push_to_hub=True` is set in the training arguments, the same call also pushes the model to the Hugging Face Hub.

In [None]:
# Write the fine-tuned model to the trainer's output directory.
trainer.save_model()