~~~
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~

# Fine-tune MedGemma with Reinforcement Learning and TRL

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/google-health/medgemma/blob/main/notebooks/rl_with_trl.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogle-Health%2Fmedgemma%2Fmain%2Fnotebooks%2Frl_with_trl.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-health/medgemma/blob/main/notebooks/rl_with_trl.ipynb">
      <img alt="GitHub logo" src="https://github.githubassets.com/assets/GitHub-Mark-ea2971cee799.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/collections/google/medgemma-release-680aade845f90bec6a3f60c4">
      <img alt="Hugging Face logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on Hugging Face
    </a>
  </td>
</tr></tbody></table>

This notebook demonstrates how to fine-tune MedGemma with reinforcement learning (RL) on a text-only medical question-answering (QA) dataset using `trl` (Transformer Reinforcement Learning).

In this guide, you will use [`trl`](https://huggingface.co/docs/trl/grpo_trainer) to train the model with reinforcement learning (RL), specifically [GRPO (Group Relative Policy Optimization)](https://arxiv.org/abs/2402.03300), together with [Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685v2) to reduce computational costs while maintaining high performance.

**Citations:**

- LoRA: Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., ... & Chen, W. (2022). LoRA: Low-rank adaptation of large language models. In International Conference on Learning Representations (ICLR).

- GRPO: Shao, Z., Wang, P., Zhu, Q., Xu, R., Song, J., Bi, X., ... & Guo, D. (2024). DeepSeekMath: Pushing the limits of mathematical reasoning in open language models. arXiv preprint arXiv:2402.03300.


## Setup

To complete this tutorial, you'll need to have a runtime with sufficient resources to fine-tune the MedGemma model. **Note:** This guide requires a GPU that supports bfloat16 data type and has at least 40 GB of memory.

You can run this notebook in Google Colab using an A100 GPU:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **A100 GPU**.

*Note that this will take a long time to run (~10 hrs per epoch of MedQA train split on an A100 40GB).*

### Get access to MedGemma

Before you get started, make sure that you have access to MedGemma models on Hugging Face:

1. If you don't already have a Hugging Face account, you can create one for free by clicking [here](https://huggingface.co/join).
2. Head over to the [MedGemma model page](https://huggingface.co/google/medgemma-4b-it) and accept the usage conditions.


### Configure your HF token

Generate a Hugging Face `write` access token by going to [settings](https://huggingface.co/settings/tokens). **Note:** Make sure that the token has write access to push the fine-tuned model to Hugging Face Hub.

If you are using Google Colab, add your access token to the Colab Secrets manager to securely store it. If not, proceed to run the cell below to authenticate with Hugging Face.

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create a new secret with the name `HF_TOKEN`.
3. Copy/paste your token key into the Value input box of `HF_TOKEN`.
4. Toggle the button on the left to allow notebook access to the secret.
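5. If you are running in Colab Enterprise, the cell below also sets `HF_HOME` to `/content/hf` so that Hugging Face data used by downstream `trl` runs is stored under `/content`.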

In [None]:
import os
import sys

if "google.colab" in sys.modules and not os.environ.get("VERTEX_PRODUCT"):
    # Use secret if running in Google Colab
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
else:
    # Store Hugging Face data under `/content` if running in Colab Enterprise
    if os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE":
        os.environ["HF_HOME"] = "/content/hf"
    # Authenticate with Hugging Face
    from huggingface_hub import get_token
    if get_token() is None:
        from huggingface_hub import notebook_login
        notebook_login()

### Installation

In [None]:
! pip install -U -q transformers trl[vllm] peft datasets
# Tested with Python 3.12 and: pip install transformers==4.55.3 trl[vllm]==0.21.0 datasets==4.0.0

In [2]:
# Optional (Google Colab only): mount Google Drive to persist training runs.
# Highly recommended for saving model checkpoints, since the runtime's local disk is lost when the session ends.
from google.colab import drive
drive.mount('/content/drive')

## Dataset processing

This notebook uses the [MedQA](https://arxiv.org/abs/2009.13081) dataset, a multiple-choice question dataset derived from medical licensing exams in the US, China, and Taiwan, designed to assess medical knowledge and clinical reasoning skills.

Load the data using the Hugging Face `datasets` library, then create train and validation splits. To speed up evaluation during training, the validation split is a random subsample of 100 examples from the dev split.

**Dataset citation:** Jin, D., Pan, E., Oufattole, N., Weng, W. H., Fang, H., & Szolovits, P. (2021). What disease does this patient have? a large-scale open domain question answering dataset from medical exams. Applied Sciences, 11(14), 6421.

In [None]:
import datasets

big_prompt = f"""Answer the given question. Think step by step.
You can directly provide the answer (A single letter) inside <answer> and </answer> (e.g. <answer>A</answer>), without further additions.
Question: [QUESTION]
[OPTIONS]
"""

def process_medqa(data):
  return data.map(lambda x: {
                      'prompt': [
                          {'role': 'user', 'content': big_prompt.replace('[QUESTION]', x['data']['Question']).replace(
                              '[OPTIONS]', f"(A) {x['data']['Options']['A']} (B) {x['data']['Options']['B']} (C) {x['data']['Options']['C']} (D) {x['data']['Options']['D']}")}
                      ],
                      'answer': x['data']['Correct Option']
                  })

medqa_dataset = datasets.load_dataset("openlifescienceai/medqa")
train_dataset = process_medqa(medqa_dataset["train"])
val_dataset = process_medqa(medqa_dataset["dev"].shuffle(seed=42).select(range(100)))
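To see what `process_medqa` produces, here is a minimal sketch that applies the same formatting to a single record outside of `datasets`. The record is a made-up toy example that only mimics the nested fields the mapping reads (`data['Question']`, `data['Options']`, `data['Correct Option']`); `format_medqa_record` and `BIG_PROMPT` are illustrative names, not part of the original notebook.

```python
# Same template text as `big_prompt` above (no f-string needed: it has no placeholders).
BIG_PROMPT = """Answer the given question. Think step by step.
You can directly provide the answer (A single letter) inside <answer> and </answer> (e.g. <answer>A</answer>), without further additions.
Question: [QUESTION]
[OPTIONS]
"""

def format_medqa_record(record: dict) -> dict:
    """Build the conversational `prompt` and the `answer` letter for one raw MedQA-style record."""
    data = record["data"]
    # "(A) ... (B) ... (C) ... (D) ...", space-separated, as in process_medqa
    options = " ".join(f"({letter}) {data['Options'][letter]}" for letter in "ABCD")
    content = BIG_PROMPT.replace("[QUESTION]", data["Question"]).replace("[OPTIONS]", options)
    return {"prompt": [{"role": "user", "content": content}], "answer": data["Correct Option"]}

# Hypothetical toy record mimicking only the fields that process_medqa reads.
toy_record = {
    "data": {
        "Question": "Which vitamin deficiency causes scurvy?",
        "Options": {"A": "Vitamin A", "B": "Vitamin C", "C": "Vitamin D", "D": "Vitamin K"},
        "Correct Option": "B",
    }
}

example = format_medqa_record(toy_record)
print(example["prompt"][0]["content"])
print("Answer:", example["answer"])
```

Each processed example is a single-turn chat (`prompt`) plus the gold letter (`answer`); GRPO only ever sees the prompt, and `answer` is passed through to the reward function.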

## Post-train the model with LoRA via GRPO on MedQA

Traditional fine-tuning of large language models is resource-intensive because it requires adjusting billions of parameters. Parameter-Efficient Fine-Tuning (PEFT) addresses this by training a much smaller number of parameters. A common PEFT technique is *Low-Rank Adaptation (LoRA)*, which efficiently adapts large language models by training small, low-rank matrices that are added to the original model instead of updating the full weight matrices.
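A minimal pure-Python sketch of the LoRA idea, assuming the standard formulation from the LoRA paper: the frozen weight `W` is left untouched and a trainable low-rank product `B A` is added, scaled by `lora_alpha / r` (PEFT's default scaling), so with this notebook's `r=64` and `lora_alpha=64` the scale is 1.0. For a `d_out x d_in` layer, full fine-tuning trains `d_out * d_in` weights while LoRA trains `r * (d_in + d_out)`. The helper names and the 2560 x 2560 layer size are illustrative assumptions.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]

def lora_forward(W, A, B, x, lora_alpha, r):
    """y = W x + (lora_alpha / r) * B (A x).

    W (d_out x d_in) is frozen; only A (r x d_in) and B (d_out x r) would be trained.
    """
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scaling = lora_alpha / r
    return [b + scaling * u for b, u in zip(base, update)]

def lora_param_counts(d_in, d_out, r):
    """Trainable parameters for one layer: full fine-tuning vs. a rank-r LoRA adapter."""
    full = d_in * d_out
    lora = r * (d_in + d_out)
    return full, lora

# Toy 2x2 layer with a rank-1 adapter.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]    # r x d_in = 1 x 2
B = [[1.0], [0.0]]  # d_out x r = 2 x 1
y = lora_forward(W, A, B, [1.0, 2.0], lora_alpha=2, r=1)
print(y)  # [7.0, 2.0]

# Illustrative 2560 x 2560 projection with the notebook's rank r=64.
full, lora = lora_param_counts(2560, 2560, r=64)
print(full, lora, lora / full)  # 6553600 327680 0.05
```

Because LoRA initializes `B` to zeros, the adapted model starts out identical to the base model, and the adapter only trains about 5% as many weights as the full layer in this example.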

*GRPO (Group Relative Policy Optimization)* is a reinforcement learning (RL) algorithm that reduces training cost by eliminating the separate value (critic) model used in PPO. Instead, for each prompt GRPO samples a group of completions and estimates each completion's advantage by comparing its reward with the rewards of the other completions in the same group. It also adds a KL-divergence penalty against a reference model to the loss for training stability.
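The group-based advantage estimation can be sketched as follows. This is a simplified illustration, assuming the outcome-reward variant from the GRPO paper: each completion's advantage is its reward normalized by the mean and standard deviation of its group. The population standard deviation and `eps=1e-4` are assumptions (TRL's implementation may differ in such details), and the KL term and the clipped policy-gradient loss are omitted.

```python
import statistics

def group_relative_advantages(rewards, eps=1e-4):
    """Normalize each reward within its group: (r - mean) / (std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; an assumption for this sketch
    return [(r - mean) / (std + eps) for r in rewards]

def grpo_advantages(rewards, num_generations):
    """Split a flat reward list into consecutive groups of `num_generations` (one group per prompt)."""
    if len(rewards) % num_generations != 0:
        raise ValueError("rewards must contain whole groups of num_generations")
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        advantages.extend(group_relative_advantages(group))
    return advantages

# Two prompts, num_generations=4, binary correctness rewards.
rewards = [1.0, 0.0, 0.0, 1.0,   # prompt 1: half of the completions are correct
           1.0, 1.0, 1.0, 1.0]   # prompt 2: every completion is correct
adv = grpo_advantages(rewards, num_generations=4)
print([round(a, 3) for a in adv])  # [1.0, -1.0, -1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

Correct completions in a mixed group are pushed up and incorrect ones pushed down, while a group in which every completion gets the same reward has zero advantage everywhere and contributes no learning signal.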

This notebook demonstrates RL training with verifiable rewards and LoRA for both Gemma and MedGemma; choose the model by setting `ckpt` in the training configuration cell.

First, define a reward function that returns 1.0 when the answer letter extracted from the model's response matches the correct answer letter (i.e., 'A', 'B', 'C', or 'D'), and 0.0 otherwise.

In [None]:
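def extract_xml_answer(text: str) -> str:
    """Return the stripped text after the last <answer> tag, up to the next </answer> tag (if any).

    If the response has no <answer> tag, the whole stripped response is returned,
    which will not match a single answer letter.
    """
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()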

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print(f"-----Question:\n{q}\nAnswer:\n{answer[0]}\nResponse:\n{responses[0]}\nExtracted:\n{extracted_responses[0]}")
    print([(r,a, r == a) for r, a in zip(extracted_responses, answer)])
    return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
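To check how strict this reward is, you can apply the same scoring rule to hand-written completions. This sketch assumes the conversational format that `correctness_reward_func` expects (each completion is a list of messages whose first entry holds the generated text); `correctness_rewards` is a hypothetical logging-free copy of that function.

```python
def extract_xml_answer(text: str) -> str:
    # Same extraction rule as above.
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def correctness_rewards(completions, answer):
    """Same scoring rule as correctness_reward_func, without the printing."""
    responses = [completion[0]["content"] for completion in completions]
    return [1.0 if extract_xml_answer(r) == a else 0.0 for r, a in zip(responses, answer)]

# Hypothetical completions for a question whose correct answer is "B".
completions = [
    [{"role": "assistant", "content": "Vitamin C deficiency causes scurvy. <answer>B</answer>"}],
    [{"role": "assistant", "content": "<answer> B </answer>"}],     # whitespace is stripped
    [{"role": "assistant", "content": "<answer>(B)</answer>"}],     # extra characters: no match
    [{"role": "assistant", "content": "The answer is B."}],         # no tag: whole text compared
]
answer = ["B"] * 4
print(correctness_rewards(completions, answer))  # [1.0, 1.0, 0.0, 0.0]
```

Only an exact single-letter match inside the tags earns reward, so the reward also implicitly teaches the model to follow the requested output format.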

Next, configure training with `GRPOConfig` and define the LoRA adapter with `LoraConfig`.

In [None]:
import torch
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

ckpt = "google/medgemma-4b-it"
# ckpt = "google/gemma-3-4b-it"

training_args = GRPOConfig(
    output_dir="./tuned_medgemma4b",
    eval_on_start=True,                      # Run an evaluation at the very beginning of training.
    learning_rate=3e-5,                      # The initial learning rate for the AdamW optimizer.
    optim="adamw_torch_fused",               # Optimizer to use; 'adamw_torch_fused' is a faster, memory-efficient AdamW.
    per_device_train_batch_size=4,
    gradient_accumulation_steps=64,          # Accumulate gradients for this many steps to simulate a larger batch size (per_device_train_batch_size * gradient_accumulation_steps).
    num_generations=4,                       # Number of completions sampled per prompt; GRPO scores each completion relative to its group.
    max_prompt_length=512,                   # Maximum token length for input prompts.
    max_completion_length=512,               # Maximum token length for the model's generated completions.
    num_train_epochs=1,
    logging_steps=10,
    save_steps=10,
    eval_strategy="steps",
    eval_steps=10,
    report_to="tensorboard",
    use_vllm=True,                           # Use the vLLM library for significantly faster inference during generation.
    vllm_mode="colocate",                    # vLLM deployment mode; 'colocate' runs vLLM on the same GPU(s) as the trainer.
    vllm_gpu_memory_utilization=.35,         # Fraction of GPU memory that vLLM is allowed to use (35%).
    bf16=True,                               # Enable bfloat16 mixed precision training to save memory and speed up training.
    gradient_checkpointing=True,             # Save memory by trading compute (avoids storing all intermediate activations).
    gradient_checkpointing_kwargs={
        "use_reentrant": False               # Use PyTorch's non-reentrant checkpointing, the recommended variant.
    },
    model_init_kwargs={
        "device_map": "auto",
        "dtype": torch.bfloat16,             # Set model parameter data type to bfloat16.
        "attn_implementation": "eager"       # Gemma 3 recommends using the 'eager' attention implementation.
    },
    push_to_hub=True
)

# LoRA to reduce VRAM
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=64,
    target_modules="all-linear",
)
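It can help to work out how many prompts each optimizer step consumes. A small sketch, assuming (as in recent TRL versions) that `per_device_train_batch_size` counts completions rather than prompts and that incomplete final batches are dropped; `grpo_batch_budget` is a hypothetical helper, and you would pass `len(train_dataset)` as the number of prompts.

```python
def grpo_batch_budget(num_train_prompts, per_device_train_batch_size,
                      gradient_accumulation_steps, num_generations, num_processes=1):
    """Return (completions per optimizer step, prompts per step, optimizer steps per epoch)."""
    completions_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_processes
    if completions_per_step % num_generations != 0:
        raise ValueError("completions per step must be divisible by num_generations")
    # Each prompt is sampled num_generations times, so it occupies that many completion slots.
    prompts_per_step = completions_per_step // num_generations
    # Assumes an incomplete final batch is dropped.
    steps_per_epoch = num_train_prompts // prompts_per_step
    return completions_per_step, prompts_per_step, steps_per_epoch

# Settings from the GRPOConfig above, with a hypothetical 10,000-prompt training set.
print(grpo_batch_budget(10_000, 4, 64, 4))  # (256, 64, 156)
```

With the settings above on a single GPU, each optimizer step therefore sees 64 prompts with 4 completions each. Similarly, `vllm_gpu_memory_utilization=.35` on a 40 GB A100 reserves about 14 GB for vLLM.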

Train the model.

Note that this will take a long time to run (~10 hrs per epoch).

In [None]:
import datasets

trainer = GRPOTrainer(
    model=ckpt,
    reward_funcs=[correctness_reward_func],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=lora_config,
)
trainer.train()
trainer.save_model(output_dir=training_args.output_dir)

In [3]:
# Copy training results to Google Drive (adjust the destination path as needed).
! mkdir -p /content/drive/MyDrive/trl_colab_storage/
! cp -r ./tuned_medgemma4b/ /content/drive/MyDrive/trl_colab_storage/

In [None]:
# Visualize training curves
! pip install tensorboard
%load_ext tensorboard
%tensorboard --logdir /content/tuned_medgemma4b/

## Model evaluation: Effect of RL-tuning

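**Important:** Before you continue, you may need to restart the runtime to free the GPU memory (VRAM) used during training. After restarting, run the cells below, which reload the test data and redefine the helper functions they need.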

The following cells compute and print the accuracy of the baseline and fine-tuned models on the test dataset to assess the effect of RL-tuning.

We also load and process the test split using the same logic as before, repeated below for convenience.

In [None]:
# Reinstantiate env variables upon reload
import os
import datasets
if os.environ.get("VERTEX_PRODUCT") == "COLAB_ENTERPRISE":
    os.environ["HF_HOME"] = "/content/hf"

# Load test split
big_prompt = f"""Answer the given question. Think step by step.
You can directly provide the answer (A single letter) inside <answer> and </answer> (e.g. <answer>A</answer>), without further additions.
Question: [QUESTION]
[OPTIONS]
"""

def process_medqa(data):
  return data.map(lambda x: {
                      'prompt': [
                          {'role': 'user', 'content': big_prompt.replace('[QUESTION]', x['data']['Question']).replace(
                              '[OPTIONS]', f"(A) {x['data']['Options']['A']} (B) {x['data']['Options']['B']} (C) {x['data']['Options']['C']} (D) {x['data']['Options']['D']}")}
                      ],
                      'answer': x['data']['Correct Option']
                  })
  
medqa_dataset_test = datasets.load_dataset("openlifescienceai/medqa", split='test')
test_dataset = process_medqa(medqa_dataset_test)

In [None]:
import torch
import pandas as pd
from tqdm.auto import tqdm

def run_inference_batched(test_dataset, model, processor, batch_size=4, device="cuda", verbose=True):
    """
    Runs inference on a processed test dataset using batching for efficiency.

    Args:
        test_dataset: A dataset where each item has 'prompt' (chat history) and 'answer' (ground truth).
        model: The loaded PEFT model for inference.
        processor: The processor for tokenizing the input.
        batch_size (int): The number of samples to process at once. Adjust based on VRAM.
        device (str): The device to run inference on ('cuda' or 'cpu').
        verbose (bool): Whether to print progress and sample outputs.

    Returns:
        A list of dictionaries, with each dictionary containing the prompt,
        ground truth answer, and the model's generated answer.
    """
    results = []

    # --- Critical settings for batched generation ---
    # Set padding token to EOS token and enable left-padding
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
    processor.tokenizer.padding_side = "left"

    # Create an iterator for the batches
    num_samples = len(test_dataset)

    # Use tqdm for a progress bar if verbose is True
    batch_iterator = range(0, num_samples, batch_size)
    if verbose:
        print(f"Starting batched inference on {num_samples} samples with batch size {batch_size}...")
        batch_iterator = tqdm(batch_iterator, desc="Batch Inference")

    for i in batch_iterator:
        # 1. Prepare the current batch
        batch_data = test_dataset[i : i + batch_size]
        batch_prompts = batch_data['prompt']
        batch_ground_truths = batch_data['answer']

        # 2. Tokenize the entire batch at once with left-padding
        inputs = processor.tokenizer.apply_chat_template(
            batch_prompts,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,  # Return input_ids and attention_mask so padding is ignored during generation
            padding=True,  # Pad sequences to the length of the longest in the batch
        ).to(device)

        # 3. Generate responses for the entire batch in one go
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

        # 4. Decode only the generated part of the output
        # With left-padding, every prompt ends at the same index, so slicing off the prompt
        # is more robust than decoding the whole sequence and stripping the prompt text
        input_token_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_length:]
        model_generated_answers = processor.batch_decode(generated_tokens, skip_special_tokens=True)

        # 5. Store the results for the current batch
        for j in range(len(batch_prompts)):
            results.append({
                'prompt': batch_prompts[j],
                'ground_truth': batch_ground_truths[j],
                'model_answer': model_generated_answers[j].strip() # Use .strip() for clean output
            })

    # Optional: print a few examples from the final results
    if verbose:
        print("\n--- Sample of Batched Inference Results ---")
        for res in results[:3]: # Print first 3 results
            print(f"Ground Truth: {res['ground_truth']}")
            print(f"Model Answer: {res['model_answer']}\n")

    return results
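Why left padding? Decoder-only models generate by appending tokens on the right, so left padding makes every prompt in the batch end at the same position, and a single slice at the prompt length (as in `outputs[:, input_token_length:]`) separates prompt from generation. The attention mask marks the pad positions so the model can ignore them. A toy illustration with made-up token ids and hypothetical helpers (no model involved):

```python
def left_pad(sequences, pad_id):
    """Left-pad token-id lists to a common length; return (input_ids, attention_mask)."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [[pad_id] * (max_len - len(seq)) + seq for seq in sequences]
    # 1 marks a real token, 0 marks padding the model should ignore.
    attention_mask = [[0] * (max_len - len(seq)) + [1] * len(seq) for seq in sequences]
    return input_ids, attention_mask

def strip_prompt(outputs, prompt_length):
    """Keep only newly generated tokens, mirroring outputs[:, input_token_length:]."""
    return [row[prompt_length:] for row in outputs]

prompts = [[11, 12, 13], [21]]  # hypothetical token ids of two prompts
input_ids, attention_mask = left_pad(prompts, pad_id=0)
# Pretend the model appended the same two new tokens to each row.
outputs = [row + [91, 92] for row in input_ids]

print(input_ids)       # [[11, 12, 13], [0, 0, 21]]
print(attention_mask)  # [[1, 1, 1], [0, 0, 1]]
print(strip_prompt(outputs, prompt_length=len(input_ids[0])))  # [[91, 92], [91, 92]]
```

With right padding, pad tokens would sit between a short prompt and its continuation, and a single slice could no longer isolate the generated text.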

### Evaluate baseline performance

This cell calculates the baseline model's accuracy on the test data. This baseline serves as a benchmark to measure the fine-tuned model's performance improvement.

In [None]:
from transformers import Gemma3Processor, AutoModelForCausalLM

# Load model and processor
print("Loading model and processor...")
ckpt = "google/medgemma-4b-it"
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = Gemma3Processor.from_pretrained(ckpt)

# Run inference on the entire processed dataset
inference_results = run_inference_batched(
    test_dataset=test_dataset,
    model=model,
    processor=processor,
    batch_size=100,
)
results_df = pd.DataFrame(inference_results)

# Redefined here in case the runtime was restarted after training (same as above)
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

results_df['model_pred'] = results_df['model_answer'].apply(extract_xml_answer)
results_df['correct'] = results_df['ground_truth'] == results_df['model_pred']
print('Baseline Accuracy', results_df['correct'].mean())
# Save baseline results
results_df.to_csv('baseline_results.csv')

# Free GPU memory before loading the tuned model
import gc
del model
del processor
gc.collect()
torch.cuda.empty_cache()
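The accuracy above counts a response as wrong whether the model picked the wrong letter or never produced a parsable `<answer>` tag. To tell these apart, you could also measure how often responses contain a well-formed tag. A minimal sketch; the regular expression is our own stricter check (an assumption), not the rule used by `extract_xml_answer`.

```python
import re

# Assumed "well-formed" pattern: a single letter A-D inside <answer>...</answer>,
# with optional surrounding whitespace.
ANSWER_TAG = re.compile(r"<answer>\s*[A-D]\s*</answer>")

def format_rate(responses):
    """Fraction of responses containing at least one well-formed <answer>X</answer> tag."""
    responses = list(responses)
    if not responses:
        return 0.0
    return sum(1 for r in responses if ANSWER_TAG.search(r)) / len(responses)

sample = [
    "Reasoning... <answer>C</answer>",
    "<answer> A </answer>",
    "The answer is B.",       # no tag
    "<answer>(D)</answer>",   # extra characters inside the tag
]
print(format_rate(sample))  # 0.5
```

In the evaluation cells, `format_rate(results_df['model_answer'])` would report this rate for each model; a low rate points to formatting failures rather than wrong medical reasoning.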

### Evaluate tuned model performance

This cell calculates the fine-tuned model's accuracy on the test data. Comparing this with the baseline score shows the improvement from fine-tuning.
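The next cell loads a specific checkpoint directory (`checkpoint-159`), whose step number depends on your dataset size and batch settings. A minimal sketch of a hypothetical helper that picks the highest-numbered `checkpoint-<step>` subdirectory instead; `transformers.trainer_utils.get_last_checkpoint` offers similar functionality. Since the training cell also calls `trainer.save_model(output_dir=training_args.output_dir)`, the final adapter is saved directly in the output directory as well.

```python
import os
import re
import tempfile

_CHECKPOINT_DIR = re.compile(r"^checkpoint-(\d+)$")

def latest_checkpoint(output_dir):
    """Return the path of the highest-numbered `checkpoint-<step>` subdirectory, or None."""
    best_step, best_path = -1, None
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_DIR.match(name)
        path = os.path.join(output_dir, name)
        if match and os.path.isdir(path):
            step = int(match.group(1))  # numeric comparison: checkpoint-159 beats checkpoint-9
            if step > best_step:
                best_step, best_path = step, path
    return best_path

# Demo on a throwaway directory with empty fake checkpoint folders.
with tempfile.TemporaryDirectory() as tmp:
    for step in (9, 10, 159):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    print(os.path.basename(latest_checkpoint(tmp)))  # checkpoint-159
```

For example, `model_path = latest_checkpoint("/content/tuned_medgemma4b")` would select the most recent checkpoint from the training run.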

In [None]:
from peft import AutoPeftModelForCausalLM
from transformers import Gemma3Processor
import torch
import pandas as pd

# Define model and processor information (make sure the paths are right!)
model_path = "/content/tuned_medgemma4b/checkpoint-159"
ckpt = "google/medgemma-4b-it"

# Load Model and Processor
print("Loading model and processor...")
model = AutoPeftModelForCausalLM.from_pretrained(
    model_path,
    dtype=torch.bfloat16,
    device_map="auto",
)
processor = Gemma3Processor.from_pretrained(ckpt)

# Run inference on the entire processed dataset
inference_results = run_inference_batched(
    test_dataset=test_dataset,
    model=model,
    processor=processor,
    batch_size=64,
)
results_df = pd.DataFrame(inference_results)
results_df['model_pred'] = results_df['model_answer'].apply(extract_xml_answer)
results_df['correct'] = results_df['ground_truth'] == results_df['model_pred']
print('GRPO-tuned Accuracy', results_df['correct'].mean())
# Save trained_results
results_df.to_csv('trained_results.csv')

MedQA test accuracy obtained with this notebook after one epoch of training (~10 hours):

| Model    | Pre-RL Tuning | Post-RL Tuning |
| -------- | ------- | ------- |
| gemma-3-4b-it  | 0.479 | 0.535 |
| medgemma-4b-it | 0.644 | 0.652 |

Observations:

- **Both models improved**: Reinforcement learning led to an accuracy increase for both the generalist Gemma and the domain-specific MedGemma models.
- **Significant gain for Gemma**: The `gemma-3-4b-it` model improves by 5.6 percentage points (0.479 to 0.535), showing that RL-tuning can substantially help a general-purpose model on this task.
- **Marginal gain for MedGemma**: The `medgemma-4b-it` model, already specialized for medical data, improves by only 0.8 percentage points (0.644 to 0.652). This suggests it was already well adapted to this task, leaving less room for gains from this tuning method. A gain this small may also fall within run-to-run variation, since evaluation samples with `temperature=0.7`.
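These observations compare single evaluation runs. To gauge whether a difference exceeds question-level noise, one option (not part of the original evaluation) is an exact McNemar test on paired per-question correctness, since both models answer the same test questions. A minimal sketch; `mcnemar_exact` is a hypothetical helper and the example lists are made up.

```python
from math import comb

def mcnemar_exact(baseline_correct, tuned_correct):
    """Exact two-sided McNemar test on paired per-question correctness.

    Only discordant questions matter: b = baseline right and tuned wrong,
    c = baseline wrong and tuned right. Under the null hypothesis of no
    difference, b follows Binomial(b + c, 0.5). Returns (b, c, p_value).
    """
    baseline_correct = list(baseline_correct)
    tuned_correct = list(tuned_correct)
    if len(baseline_correct) != len(tuned_correct):
        raise ValueError("both runs must cover the same questions in the same order")
    pairs = list(zip(baseline_correct, tuned_correct))
    b = sum(1 for base_ok, tuned_ok in pairs if base_ok and not tuned_ok)
    c = sum(1 for base_ok, tuned_ok in pairs if tuned_ok and not base_ok)
    n = b + c
    # One-tail probability of a split at least as lopsided as observed, doubled for two sides.
    tail = sum(comb(n, k) for k in range(min(b, c) + 1)) / 2 ** n
    return b, c, min(1.0, 2 * tail)

# Made-up example: 100 questions; the tuned model fixes 10 and breaks none.
baseline_example = [True] * 50 + [False] * 10 + [False] * 40
tuned_example = [True] * 50 + [True] * 10 + [False] * 40
print(mcnemar_exact(baseline_example, tuned_example))  # (0, 10, 0.001953125)
```

For this notebook's runs, the saved CSVs list the test questions in the same order, so `mcnemar_exact(pd.read_csv('baseline_results.csv')['correct'], pd.read_csv('trained_results.csv')['correct'])` would compare them. The test treats each sampled run as fixed, so it does not capture the extra variation from sampling at `temperature=0.7`.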

### Additional optimizations
Note that this notebook is meant to be a starting point. Many optimizations are not covered here, including [DeepSpeed](https://huggingface.co/docs/trl/main/en/deepspeed_integration), [parallelization on multiple nodes](https://huggingface.co/docs/trl/main/en/grpo_trainer#grpo-at-scale-train-a-70b-model-on-multiple-nodes), and more.

We recommend checking out [GRPO Trainer](https://huggingface.co/docs/trl/main/en/grpo_trainer) for further details.

## Next steps

Explore the other [notebooks](https://github.com/google-health/medgemma/blob/main/notebooks) to learn what else you can do with the model.