~~~
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~

# Fine-tuning TxGemma with Hugging Face

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/TxGemma/[TxGemma]Finetune_with_Hugging_Face.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DFinetune_with_Hugging_Face.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/collections/google/txgemma-release-67dd92e931c857d15e4d1e87">
      <img alt="HuggingFace logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on HuggingFace
    </a>
  </td>
</tr></tbody></table>

This notebook demonstrates fine-tuning TxGemma models to generalize to new therapeutic development tasks using Hugging Face libraries.

The demo uses Hugging Face's [Transformer Reinforcement Learning (`TRL`)](https://github.com/huggingface/trl) library to train the model with Supervised Fine-Tuning (SFT), using [Parameter-Efficient Fine-Tuning (`PEFT`)](https://github.com/huggingface/peft) with Low-Rank Adaptation (LoRA) to reduce computational costs. The training data is a subset of the [TrialBench](https://arxiv.org/abs/2407.00631) dataset, used here to fine-tune TxGemma to predict adverse events in clinical trials.


## Setup

To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to fine-tune and run the TxGemma model. In this case, you can use a T4 GPU:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.

### Get access to TxGemma

Before you get started, make sure that you have access to TxGemma models on Hugging Face:

1. If you don't already have a Hugging Face account, you can create one for free by clicking [here](https://huggingface.co/join).
2. Head over to the [TxGemma model page](https://huggingface.co/google/txgemma-2b-predict) and accept the usage conditions.

### Configure your HF token

Generate a Hugging Face `read` access token on the [access tokens page](https://huggingface.co/settings/tokens), then store it securely in the Colab Secrets manager:

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create a new secret with the name `HF_TOKEN`.
3. Copy/paste your token key into the Value input box of `HF_TOKEN`.
4. Toggle the button on the left to allow notebook access to the secret.

In [None]:
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
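
Outside Colab, `google.colab.userdata` is not available. The following is a minimal sketch of an alternative, assuming the token has already been exported as an `HF_TOKEN` environment variable in your shell; the helper name `resolve_hf_token` is hypothetical and not part of any library.

```python
import os


def resolve_hf_token(env=None):
    """Return the Hugging Face token from an environment mapping, or raise."""
    # Default to the real process environment; a plain dict can be passed for testing.
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    if not token:
        raise RuntimeError("HF_TOKEN is not set; export it before running the notebook.")
    return token
```

Failing early with a clear message is usually easier to debug than an authentication error raised later, during model download.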

### Install dependencies

In [None]:
! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.0/76.0 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m411.0/411.0 kB[0m [31m38.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m110.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m336.4/336.4 kB[0m [31m29.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.9/183.9 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m17.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━

## Load model from Hugging Face Hub

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/txgemma-2b-predict"

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)
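
To give a rough sense of why the cell above loads the weights in 4-bit, the sketch below estimates weight memory at different precisions. The parameter count of about 2.6 billion is an assumption used for illustration, and the helper `weight_memory_gib` is hypothetical. The estimate covers weights only: it ignores activations, optimizer state, quantization constants, and any layers that stay in higher precision, so actual usage will be higher.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


# Assumption for illustration only: roughly 2.6 billion parameters.
approx_params = 2.6e9

for name, bits in [("float32", 32), ("bfloat16", 16), ("4-bit NF4", 4)]:
    print(f"{name:>10}: {weight_memory_gib(approx_params, bits):.2f} GiB")
```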


## Load dataset

This notebook uses adverse event prediction data from [TrialBench](https://arxiv.org/abs/2407.00631) to fine-tune TxGemma. The dataset has been preprocessed into an instruction-tuning format and is available in [Cloud Storage](https://console.cloud.google.com/storage/browser/healthai-us/txgemma/datasets).

Load the dataset using the Hugging Face [`datasets`](https://github.com/huggingface/datasets) library.

In [None]:
from datasets import load_dataset

! wget -nc https://storage.googleapis.com/healthai-us/txgemma/datasets/trialbench_adverse-event-rate-prediction_train.jsonl
data = load_dataset(
    "json",
    data_files="/content/trialbench_adverse-event-rate-prediction_train.jsonl",
    split="train",
)

# Display dataset details
data


Dataset({
    features: ['input_text', 'output_text'],
    num_rows: 14368
})

Each data point includes:

* `"input_text"`: Question, which prompts the model to predict whether there will be an adverse event given information about a clinical trial. Inputs include drug SMILES strings and textual information.

* `"output_text"`: Answer, which is either "Yes" or "No".

Below is an example from the dataset:

In [None]:
data["input_text"][0]

'From the following information about a clinical trial, predict whether it would have an adverse event.\n\nTitle: Safety, Tolerability and Pharmacokinetics of Single and Repeat Doses of GSK2292767 in Healthy Participants Who Smoke Cigarettes\nSummary: This study is the first administration of GSK2292767 to humans. The study will evaluate the safety, tolerability, pharmacokinetics (PK) and pharmacodynamics (PD) of single and repeat inhaled doses of GSK2292767 in healthy smokers. This study is intended to provide sufficient confidence in the safety of the molecule and preliminary information on target engagement to allow progression to further repeat dose and proof of mechanism studies. This is a two part, single site, randomized, double-blind (sponsor open), placebo controlled study. Part A will consist of two 3-period interlocking cohorts to evaluate the safety, tolerability and pharmacokinetics of ascending single doses of GSK2292767 administered as a dry powder inhalation. Part B is 

In [None]:
data["output_text"][0]

'No'
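
Because the answers are binary, it can be useful to check how balanced the labels are before training. The sketch below counts labels over rows shaped like the dataset's records; `label_counts` and the toy rows are hypothetical and stand in for the real data.

```python
from collections import Counter


def label_counts(examples):
    """Count how often each answer appears in an iterable of dataset rows."""
    return Counter(example["output_text"] for example in examples)


# Toy rows for illustration only; these are not taken from TrialBench.
toy_rows = [
    {"input_text": "trial A ...", "output_text": "Yes"},
    {"input_text": "trial B ...", "output_text": "No"},
    {"input_text": "trial C ...", "output_text": "No"},
]
print(label_counts(toy_rows))
```

In the notebook, `label_counts(data)` should apply the same count to the loaded dataset, since iterating over it yields one dictionary per row.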

The expected data format for training is a single `"text"` column containing a full sequence of text.

Here, define a function that properly formats each example in the dataset. In a later section, it will be passed to the `SFTTrainer`, which applies the formatting function to the dataset before tokenization.

In [None]:
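def formatting_func(example):
    # Join the prompt and the answer into one training sequence and
    # terminate it with the tokenizer's end-of-sequence token, `<eos>`.
    text = f"{example['input_text']} {example['output_text']}<eos>"
    return text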

# Display formatted training data example
print(formatting_func(data[0]))

From the following information about a clinical trial, predict whether it would have an adverse event.

Title: Safety, Tolerability and Pharmacokinetics of Single and Repeat Doses of GSK2292767 in Healthy Participants Who Smoke Cigarettes
Summary: This study is the first administration of GSK2292767 to humans. The study will evaluate the safety, tolerability, pharmacokinetics (PK) and pharmacodynamics (PD) of single and repeat inhaled doses of GSK2292767 in healthy smokers. This study is intended to provide sufficient confidence in the safety of the molecule and preliminary information on target engagement to allow progression to further repeat dose and proof of mechanism studies. This is a two part, single site, randomized, double-blind (sponsor open), placebo controlled study. Part A will consist of two 3-period interlocking cohorts to evaluate the safety, tolerability and pharmacokinetics of ascending single doses of GSK2292767 administered as a dry powder inhalation. Part B is plan

## Try out the pretrained model

Prompt the pretrained model to see how it performs on a sample adverse event prediction task. Prior to fine-tuning, the model does not follow the instruction and returns an answer that is neither "Yes" nor "No".

In [None]:
prompt = "From the following information about a clinical trial, predict whether it would have an adverse event.\n\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

From the following information about a clinical trial, predict whether it would have an adverse event.

Drug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2

Answer:188


## Fine-tune the model with LoRA

Traditional fine-tuning of large language models is resource-intensive because it requires adjusting billions of parameters. Parameter-Efficient Fine-Tuning (PEFT) addresses this by training a smaller number of parameters, using techniques like Low-Rank Adaptation (LoRA). LoRA efficiently adapts large language models by training small, low-rank matrices that are added to the original model instead of updating the full-weight matrices.
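
The idea can be illustrated with a small, dependency-free sketch. For a weight matrix `W` of shape `d_out × d_in`, LoRA keeps `W` frozen and trains two small matrices, `A` (`r × d_in`) and `B` (`d_out × r`), so the adapted layer computes `W x + scale * B (A x)`. The 2048 × 2048 layer size below is a hypothetical example rather than a TxGemma dimension, and the function names are illustrative only.

```python
def lora_param_counts(d_in, d_out, r):
    """Trainable parameters: full weight matrix vs. a rank-r LoRA adapter."""
    full = d_in * d_out
    lora = r * d_in + d_out * r  # A is (r x d_in), B is (d_out x r)
    return full, lora


def matvec(M, x):
    """Multiply a matrix (list of rows) by a vector (list of numbers)."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(x, W, A, B, scale=1.0):
    """Compute W x + scale * B (A x) without ever forming the product B A."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    return [b + scale * u for b, u in zip(base, update)]


# Hypothetical 2048 x 2048 projection with rank 8.
full, lora = lora_param_counts(2048, 2048, 8)
print(f"full: {full:,}  LoRA: {lora:,}  ratio: {lora / full:.2%}")
```

In LoRA, `B` is initialized to zero, so `lora_forward` initially returns the same output as the frozen layer and training starts from the pretrained model's behavior.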


This section demonstrates fine-tuning TxGemma using LoRA and the `SFTTrainer` from the Hugging Face `TRL` library.

First, define the [`LoraConfig`](https://huggingface.co/docs/peft/main/en/package_reference/lora), including the rank of the adaptation matrices and the model layers to add LoRA adapters to.

In [None]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

Prepare the model for training.

In [None]:
from peft import prepare_model_for_kbit_training, get_peft_model

# Preprocess quantized model for training
model = prepare_model_for_kbit_training(model)

# Create PeftModel from quantized model and configuration
model = get_peft_model(model, lora_config)

This example uses the Supervised Fine-Tuning (SFT) method to train the TxGemma model.

Here, construct the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer) that handles the complete training loop, including data loading, forward and backward passes, and optimizer steps. Specify the LoRA configuration and dataset formatting function defined earlier and the `SFTConfig` with training parameters.

In [None]:
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=data,
    args=SFTConfig(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=50,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=5,
        max_seq_length=512,
        output_dir="/content/outputs",
        optim="paged_adamw_8bit",
        report_to="none",
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
)
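
The batch settings above determine how much data this short run actually covers. The sketch below works through the arithmetic, assuming a single GPU; `training_budget` is a hypothetical helper, and the numbers come from the `SFTConfig` values and the dataset size shown earlier.

```python
def training_budget(per_device_batch, grad_accum_steps, max_steps,
                    dataset_size, num_devices=1):
    """Return (effective batch size, examples seen, fraction of the dataset)."""
    # Gradients from several small batches are accumulated before each optimizer step.
    effective_batch = per_device_batch * grad_accum_steps * num_devices
    examples_seen = effective_batch * max_steps
    return effective_batch, examples_seen, examples_seen / dataset_size


effective_batch, examples_seen, fraction = training_budget(
    per_device_batch=1, grad_accum_steps=4, max_steps=50, dataset_size=14368
)
print(f"effective batch size: {effective_batch}")
print(f"examples seen: {examples_seen} ({fraction:.1%} of the dataset)")
```

With these settings the run sees only a small part of the training set, which keeps the demo fast. Raising `max_steps` would be one way to train on more of the data.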



Launch the fine-tuning process.

In [None]:
trainer.train()


| Step | Training Loss |
|------|---------------|
| 5    | 16.7508       |
| 10   | 10.9105       |
| 15   | 8.5497        |
| 20   | 6.3555        |
| 25   | 5.2708        |
| 30   | 4.8518        |
| 35   | 4.0853        |
| 40   | 3.8537        |
| 45   | 3.7279        |
| 50   | 3.5955        |


TrainOutput(global_step=50, training_loss=6.79513858795166, metrics={'train_runtime': 158.6588, 'train_samples_per_second': 1.261, 'train_steps_per_second': 0.315, 'total_flos': 726227766793728.0, 'train_loss': 6.79513858795166})

## Test the fine-tuned model

Prompt the fine-tuned model to see how it performs on a sample adverse event prediction task. After fine-tuning, the model has learned to respond with an appropriate answer to the prompt.


In [None]:
prompt = "From the following information about a clinical trial, predict whether it would have an adverse event.\n\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


From the following information about a clinical trial, predict whether it would have an adverse event.

Drug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2

Answer: Yes
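
When scoring many prompts, the decoded text needs to be reduced to a label. The sketch below is one possible approach, assuming the prompt ends with `Answer:` as in the example above; `extract_answer` is a hypothetical helper, not part of `transformers` or `TRL`.

```python
def extract_answer(generated_text, marker="Answer:"):
    """Return "Yes" or "No" from decoded model output, or None if neither is found."""
    # Keep only what follows the last occurrence of the marker.
    _, found, tail = generated_text.rpartition(marker)
    if not found:
        return None
    tokens = tail.strip().split()
    if not tokens:
        return None
    # Normalize the first word after the marker and map it to a label.
    first = tokens[0].strip(".,;:!").lower()
    return {"yes": "Yes", "no": "No"}.get(first)


print(extract_answer("Drug: ...\n\nAnswer: Yes"))
print(extract_answer("Drug: ...\n\nAnswer:188"))
```

Returning `None` for anything other than "Yes" or "No" makes malformed generations, such as the pretrained model's `188`, easy to count separately from wrong predictions.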


## Next steps

Explore the other [notebooks](https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma) to learn what else you can do with the model.