~~~
Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
~~~

# Fine-tuning TxGemma with Hugging Face

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/TxGemma/[TxGemma]Finetune_with_Hugging_Face.ipynb">
      <img alt="Google Colab logo" src="https://www.tensorflow.org/images/colab_logo_32px.png" width="32px"><br> Run in Google Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DFinetune_with_Hugging_Face.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://huggingface.co/collections/google/txgemma-release-67dd92e931c857d15e4d1e87">
      <img alt="HuggingFace logo" src="https://huggingface.co/front/assets/huggingface_logo-noborder.svg" width="32px"><br> View on HuggingFace
    </a>
  </td>
</tr></tbody></table>

This notebook demonstrates fine-tuning TxGemma models to generalize to new therapeutic development tasks using Hugging Face libraries.

The demo uses Hugging Face's [Transformer Reinforcement Learning (`TRL`)](https://github.com/huggingface/trl) library to train the model with Supervised Fine-Tuning (SFT), using [Parameter-Efficient Fine-Tuning (`PEFT`)](https://github.com/huggingface/peft) with Low-Rank Adaptation (LoRA) to reduce computational costs. The training data is a subset of the [TrialBench](https://arxiv.org/abs/2407.00631) dataset, used to fine-tune TxGemma to predict adverse events in clinical trials.


## Setup

To complete this tutorial, you need a Colab runtime with enough resources to fine-tune and run the TxGemma model. You can use a T4 GPU:

1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.
2. Select **Change runtime type**.
3. Under **Hardware accelerator**, select **T4 GPU**.

### Get access to TxGemma

Before you get started, make sure that you have access to TxGemma models on Hugging Face:

1. If you don't already have a Hugging Face account, you can create one for free by clicking [here](https://huggingface.co/join).
2. Head over to the [TxGemma model page](https://huggingface.co/google/txgemma-2b-predict) and accept the usage conditions.

### Configure your HF token

Generate a Hugging Face `read` access token by clicking [here](https://huggingface.co/settings/tokens) and add your access token to the Colab Secrets manager to securely store it.

1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src="https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg" alt="The Secrets tab is found on the left panel." width=50%>
2. Create a new secret with the name `HF_TOKEN`.
3. Copy/paste your token key into the Value input box of `HF_TOKEN`.
4. Toggle the button on the left to allow notebook access to the secret.

In [1]:
import os
from google.colab import userdata
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
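
The `userdata.get` call only works inside Colab. A minimal sketch of a fallback for other environments, assuming you export `HF_TOKEN` in your shell before starting Python (the helper `get_hf_token` is hypothetical, not part of any library):

```python
import os
from typing import Optional


def get_hf_token() -> Optional[str]:
    """Return the HF token from Colab Secrets if available, else from the environment."""
    try:
        # This import only succeeds inside a Colab runtime.
        from google.colab import userdata
    except ImportError:
        # Outside Colab: read a token exported beforehand, e.g. `export HF_TOKEN=...`.
        return os.environ.get("HF_TOKEN")
    return userdata.get("HF_TOKEN")


token = get_hf_token()
if token:
    os.environ["HF_TOKEN"] = token
```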

### Install dependencies

In [2]:
! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl --use-deprecated=legacy-resolver

## Load model from Hugging Face Hub

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "google/txgemma-2b-predict"

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

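
As a rough check on why 4-bit loading helps, the memory needed to store the weights scales with the number of bits per parameter. A back-of-the-envelope sketch; the 2-billion parameter count is a hypothetical round number, not TxGemma's exact size, and the estimate ignores quantization constants, activations, optimizer state, and any layers bitsandbytes leaves unquantized:

```python
def weight_memory_gb(num_params: int, bits_per_param: float) -> float:
    """Approximate memory for the weights alone, in gigabytes (1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Hypothetical 2-billion-parameter model.
n = 2_000_000_000
for label, bits in [("float32", 32), ("bfloat16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>10}: ~{weight_memory_gb(n, bits):.1f} GB")
```

On a T4 with 16 GB of GPU memory, the gap between 16-bit and 4-bit weights leaves more room for activations and LoRA optimizer state during training.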

## Load dataset

This notebook uses adverse event prediction data from [TrialBench](https://arxiv.org/abs/2407.00631) to fine-tune TxGemma. The dataset has been preprocessed into an instruction-tuning format and is available in [Cloud Storage](https://console.cloud.google.com/storage/browser/healthai-us/txgemma/datasets).

Load the dataset using the Hugging Face [`datasets`](https://github.com/huggingface/datasets) library.

In [4]:
from datasets import load_dataset

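# The run shown below loads a local file, `chembl.json`, with `combine` and
# `text` fields; it must already exist in the working directory.
# To use the TrialBench adverse event data described above instead, download it:
# ! wget -nc https://storage.googleapis.com/healthai-us/txgemma/datasets/trialbench_adverse-event-rate-prediction_train.jsonl
# then pass that file to `data_files`. Its column names may differ from the
# `combine`/`text` fields used in the rest of this notebook.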
data = load_dataset(
    "json",
    data_files="chembl.json",
    split="train",
)
# Display dataset details
data

Dataset({
    features: ['combine', 'text'],
    num_rows: 13847
})

Each data point includes:

* `"combine"`: the question, which asks the model for a Yes/No prediction about a compound given as a SMILES string (for example, whether the drug is administered parenterally).

* `"text"`: the answer, which is either "Yes" or "No".
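
Because the answer is a binary Yes/No label, it helps to know the accuracy of always predicting the most common answer before judging a fine-tuned model. A minimal sketch in plain Python (the helper name is hypothetical); with the loaded dataset you could pass `data["text"]` as the list of labels:

```python
from collections import Counter


def majority_baseline(labels):
    """Return the most frequent label and the accuracy of always predicting it."""
    if not labels:
        raise ValueError("labels must be non-empty")
    counts = Counter(label.strip() for label in labels)
    label, count = counts.most_common(1)[0]
    return label, count / len(labels)


print(majority_baseline(["No", "No", "Yes", "No"]))  # ('No', 0.75)
```

A fine-tuned model is only learning something useful if it beats this baseline on the validation split.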

Below is an example from the dataset:

In [5]:
data['combine'][0]

'Question: Is this drug administered parenterally?\nCompound: Cc1ncccc1OC[C@@H]1CCCN1'

In [6]:
data["text"][0]

'No'

`SFTTrainer` trains on one complete text sequence per example. Because the `"text"` column here holds only the answer, define a function that joins each question and its answer into a single training string ending with the `<eos>` token. In a later section, this function is passed to the `SFTTrainer`, which applies it to the dataset before tokenization.

In [7]:
def formatting_func(example):
    text = f"{example['combine']} {example['text']}<eos>"
    return text

# Display formatted training data example
print(formatting_func(data[0]))

Question: Is this drug administered parenterally?
Compound: Cc1ncccc1OC[C@@H]1CCCN1 No<eos>


In [8]:

# Hold out 20% of the examples for validation; the seed makes the split reproducible.
split_dataset = data.train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset['train']
valid_dataset = split_dataset['test']
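
The split sizes follow from `test_size=0.2`. As far as we know, `datasets` rounds the test share up and the train share down (the same convention as scikit-learn); the sketch below reproduces that arithmetic for the 13,847 rows loaded above (the helper is hypothetical):

```python
import math


def split_sizes(num_rows: int, test_size: float) -> tuple:
    """Return (n_train, n_test) for a fractional test_size, rounding test up and train down."""
    n_test = math.ceil(test_size * num_rows)
    n_train = math.floor((1 - test_size) * num_rows)
    return n_train, n_test


print(split_sizes(13847, 0.2))  # (11077, 2770)
```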


## Try out the pretrained model

Prompt the pretrained model with a sample adverse event prediction question to see how it performs. Before fine-tuning, the model does not follow the instruction and returns an unrelated answer (`188` in the run below).

In [9]:
prompt = "From the following information about a clinical trial, predict whether it would have an adverse event.\n\nDrug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

From the following information about a clinical trial, predict whether it would have an adverse event.

Drug: C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2

Answer:188


## Fine-tune the model with LoRA

Traditional fine-tuning of large language models is resource-intensive because it updates billions of parameters. Parameter-Efficient Fine-Tuning (PEFT) addresses this by training only a small number of parameters, using techniques such as Low-Rank Adaptation (LoRA). LoRA adapts a large language model by training small, low-rank matrices that are added to the original weights, instead of updating the full weight matrices.
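
To make the idea concrete: for a frozen weight matrix W of shape (d_out, d_in), LoRA learns B (d_out × r) and A (r × d_in) and computes W x + (alpha / r) · B A x. The dependency-free sketch below uses toy matrices (it illustrates the math, not PEFT's implementation) and also compares parameter counts for a hypothetical 2048 × 2048 projection:

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Compute W x + (alpha / r) * B (A x) without forming the full update B A."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))  # (d_out x r) @ ((r x d_in) @ x)
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    return r * (d_in + d_out)


# Toy 3x3 layer with rank-1 adapters. B starts at zero, as in LoRA's
# initialization, so the adapted layer initially matches the frozen one.
W = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
A = [[0.5, -0.5, 1.0]]     # r x d_in
B = [[0.0], [0.0], [0.0]]  # d_out x r
x = [1.0, 1.0, 1.0]
print(lora_forward(W, A, B, x, alpha=8, r=1))  # [1.0, 2.0, 3.0]

# Full fine-tuning vs. LoRA (r=8) for one 2048 x 2048 projection.
print(2048 * 2048, lora_param_count(2048, 2048, 8))  # 4194304 32768
```

Only A and B receive gradients, which is why the trainable parameter count drops by roughly two orders of magnitude per layer in this example.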


This section demonstrates fine-tuning TxGemma using LoRA and the `SFTTrainer` from the Hugging Face `TRL` library.

First, define the [`LoraConfig`](https://huggingface.co/docs/peft/main/en/package_reference/lora), including the rank of the adaptation matrices and the model layers to add LoRA adapters to.

In [10]:
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
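
`LoraConfig` also controls how strongly the adapter update is scaled. If we read the PEFT defaults correctly, `lora_alpha` defaults to 8, so with `r=8` the scale alpha / r is 1.0; with `use_rslora=True`, PEFT uses alpha / sqrt(r) instead. A small sketch of that arithmetic (the helper is hypothetical, not a PEFT function):

```python
import math


def lora_scaling(r: int, lora_alpha: float, use_rslora: bool = False) -> float:
    """Scale applied to the B A update: alpha / r, or alpha / sqrt(r) for rank-stabilized LoRA."""
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r


print(lora_scaling(8, 8))                    # 1.0
print(lora_scaling(16, 8))                   # 0.5
print(lora_scaling(16, 8, use_rslora=True))  # 2.0
```

With plain LoRA, raising `r` without raising `lora_alpha` shrinks the update; one common choice is to scale `lora_alpha` with `r` (for example `r=16, lora_alpha=16`).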

Prepare the quantized model for training and attach the LoRA adapters.

In [11]:
from peft import prepare_model_for_kbit_training, get_peft_model

# Preprocess quantized model for training
model = prepare_model_for_kbit_training(model)

# Create PeftModel from quantized model and configuration
model = get_peft_model(model, lora_config)
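
After `get_peft_model`, only the adapter weights require gradients. PEFT models provide `model.print_trainable_parameters()` to report this. The sketch below shows the same bookkeeping for any object whose `.parameters()` yields items with `numel()` and `requires_grad`; the helper name is hypothetical and the stand-in parameters are fake so the example runs without PyTorch. As far as we know, PEFT's built-in version also corrects for 4-bit weights being packed two per byte, which this sketch does not:

```python
from dataclasses import dataclass


def count_parameters(module):
    """Return (trainable, total) parameter counts for a PyTorch-style module."""
    trainable = total = 0
    for param in module.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total


# Stand-ins so the sketch runs without PyTorch; with a real model, call
# count_parameters(model) directly.
@dataclass
class FakeParam:
    size: int
    requires_grad: bool

    def numel(self):
        return self.size


class FakeModule:
    def parameters(self):
        return [FakeParam(1_000_000, False), FakeParam(32_768, True)]


trainable, total = count_parameters(FakeModule())
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
```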

This example trains TxGemma with Supervised Fine-Tuning (SFT).

Construct the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), which handles the complete training loop, including data loading, forward and backward passes, and optimizer steps. Pass it the LoRA configuration and formatting function defined earlier, along with an `SFTConfig` that holds the training parameters and a `compute_metrics` function that reports accuracy during evaluation.

In [20]:
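import numpy as np


def preprocess_logits_for_metrics(logits, labels):
    # Reduce full-vocabulary logits to predicted token ids on the GPU, so that
    # evaluation does not accumulate (num_examples x seq_len x vocab) logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


def compute_metrics(eval_pred):
    # Token-level accuracy over positions that have a label. Padded and
    # ignored positions carry the label -100 and are excluded.
    predictions = eval_pred.predictions
    labels = eval_pred.label_ids
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # For a causal LM, the prediction at position t is for the token at t + 1,
    # so shift predictions and labels by one before comparing them.
    predictions = np.asarray(predictions)[:, :-1]
    labels = np.asarray(labels)[:, 1:]

    mask = labels != -100
    if not mask.any():
        return {"accuracy": 0.0}
    accuracy = (predictions[mask] == labels[mask]).mean()
    return {"accuracy": float(accuracy)}

In [21]:
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    args=SFTConfig(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=2,
        max_steps=200,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=5,
        max_seq_length=512,
        output_dir="/content/outputs",
        optim="paged_adamw_8bit",
        report_to="none",
        # PeftModel hides the base model's forward signature, so name the
        # label column explicitly; otherwise evaluation receives no labels.
        label_names=["labels"],
    ),
    peft_config=lora_config,
    formatting_func=formatting_func,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)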

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
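
With `per_device_train_batch_size=1` and `gradient_accumulation_steps=4` on a single GPU, each optimizer step sees 4 examples, so `max_steps=200` covers 800 training examples, well under one epoch. A sketch of that arithmetic (hypothetical helpers; the 11,077-row figure assumes the 80/20 split of the 13,847 loaded rows):

```python
def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    """Number of examples contributing to each optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


def epochs_covered(max_steps: int, batch: int, num_train_rows: int) -> float:
    """Fraction of the training set seen after max_steps optimizer steps."""
    return max_steps * batch / num_train_rows


batch = effective_batch_size(1, 4)
print(batch, 200 * batch, round(epochs_covered(200, batch, 11077), 3))  # 4 800 0.072
```

To train on the full split instead, set `num_train_epochs` rather than `max_steps`, at a proportionally longer runtime.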


Launch the fine-tuning process.

In [22]:
trainer.train()



| Step | Training Loss |
|-----:|--------------:|
| 5 | 0.5301 |
| 10 | 0.6619 |
| 15 | 0.5745 |
| 20 | 0.5457 |
| 25 | 0.5482 |
| 30 | 0.5067 |
| 35 | 0.5804 |
| 40 | 0.4664 |
| 45 | 0.52 |
| 50 | 0.5168 |


TrainOutput(global_step=200, training_loss=0.6408772671222687, metrics={'train_runtime': 425.7341, 'train_samples_per_second': 1.879, 'train_steps_per_second': 0.47, 'total_flos': 666023178470400.0, 'train_loss': 0.6408772671222687})

In [23]:
# Free cached GPU memory, but keep the model on the GPU: the evaluation
# batches are placed on cuda:0, so moving the model to the CPU causes a
# device-mismatch RuntimeError.
torch.cuda.empty_cache()
metrics = trainer.evaluate()
print(metrics)

In [None]:
import torch
from tqdm import tqdm

# Measure exact-match accuracy of generated answers on the validation split.
model.to("cuda")
model.eval()

correct = 0
total = 0

for sample in tqdm(valid_dataset):
    # 1. Use the question as the prompt; the model should complete it with the answer.
    prompt = sample["combine"]
    label = sample["text"].strip()

    # 2. Tokenize and move to the GPU.
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    # 3. Generate a short continuation.
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=8)

    # 4. Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    generated = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    answer = generated.split()[0] if generated else ""

    # 5. Compare the first generated word to the label.
    if answer == label:
        correct += 1
    total += 1

# 6. Compute accuracy.
accuracy = correct / total
print(f"Validation Accuracy: {accuracy:.4f}")


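
Exact string matching on generated text is brittle: the model may add punctuation, change case, or emit something other than Yes/No. A sketch of a stricter parser (the helper `parse_yes_no` is hypothetical) that reads the first word of the generated text and maps it to `Yes` or `No`, returning `None` for anything else so those cases can be counted separately:

```python
import re
from typing import Optional


def parse_yes_no(generated: str) -> Optional[str]:
    """Return 'Yes' or 'No' from the first word of generated text, else None."""
    match = re.match(r"\s*([A-Za-z]+)", generated)
    if match is None:
        return None
    word = match.group(1).lower()
    return {"yes": "Yes", "no": "No"}.get(word)


print(parse_yes_no(" No"))   # No
print(parse_yes_no("yes."))  # Yes
print(parse_yes_no("188"))   # None
```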

## Test the fine-tuned model

Prompt the fine-tuned model with a sample question to see how it performs. After fine-tuning, the model responds with an answer in the expected Yes/No format.


In [15]:
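# Note: this prompt uses "Drug:" and omits the "Question:" prefix, whereas the
# training examples use "Question: ...\nCompound: ...". Matching the training
# template exactly is likely to give more reliable answers.
prompt = "Is this drug administered parenterally?\nDrug: Cc1cc([C@H](C)Nc2nccc(N3C(=O)OC[C@@H]3[C@@H](C)F)n2)ncc1-c1ccnc(C(F)(F)F)c1"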
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Is this drug administered parenterally?
Drug: Cc1cc([C@H](C)Nc2nccc(N3C(=O)OC[C@@H]3[C@@H](C)F)n2)ncc1-c1ccnc(C(F)(F)F)c1 No
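
To query the model on a new compound, it is safest to reuse the template from `formatting_func`, minus the answer and `<eos>`. A minimal sketch; `build_prompt` is a hypothetical helper whose layout mirrors the `Question:`/`Compound:` training example shown earlier:

```python
def formatting_func(example):
    # Same template as the training-time formatting function.
    return f"{example['combine']} {example['text']}<eos>"


def build_prompt(question: str, smiles: str) -> str:
    """Build an inference prompt in the training format, leaving the answer to the model."""
    return f"Question: {question}\nCompound: {smiles}"


prompt = build_prompt("Is this drug administered parenterally?", "Cc1ncccc1OC[C@@H]1CCCN1")
example = {"combine": prompt, "text": "No"}

# The training text should start with the inference prompt followed by a space,
# so the model learns to continue exactly this prompt with the answer.
assert formatting_func(example).startswith(prompt + " ")
print(prompt)
```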


## Next steps

Explore the other [notebooks](https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma) to learn what else you can do with the model.