In [None]:
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Quick start with Model Garden - TxGemma

<table><tbody><tr>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Fgoogle-gemini%2Fgemma-cookbook%2Fmain%2FTxGemma%2F%255BTxGemma%255DQuickstart_with_Model_Garden.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma/%5BTxGemma%5DQuickstart_with_Model_Garden.ipynb">
      <img alt="GitHub logo" src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" width="32px"><br> View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates how to use TxGemma in Vertex AI to generate predictions from therapeutics-related text prompts using two methods:

* **Online predictions** are synchronous requests that are made to the endpoint deployed from Model Garden and are served with low latency. Online predictions are useful when you need responses in real time, for example when serving predictions in production. The cost for online prediction is based on the time a virtual machine spends waiting in an active state (an endpoint with a deployed model) to handle prediction requests.

* **Batch predictions** are asynchronous requests that are run on a set number of prompts specified in a single job. They are made directly to an uploaded model and do not use an endpoint deployed from Model Garden. Batch predictions are useful if you want to generate predictions for a large number of prompts and don't require low latency. The cost for batch prediction is based on the time a virtual machine spends running your prediction job.

Vertex AI makes it easy to serve your model and make it accessible to the world. Learn more about [Vertex AI](https://cloud.google.com/vertex-ai/docs/start/introduction-unified-platform).

### Objectives

- Deploy TxGemma to a Vertex AI Endpoint and get online predictions.
- Upload TxGemma to Vertex AI Model Registry and get batch predictions.

<!-- @import cloud/ml/applications/vision/model_garden/notebooks/templates/common.ipynb:costs -->

## Before you begin

In [None]:
# @title Import packages and define common functions

import datetime
import importlib
import io
import json
import os
import uuid

import google.auth
import google.auth.transport.requests
from google.cloud import aiplatform
from google.cloud import storage
from IPython.display import Markdown, display
import openai

if not os.path.isdir("vertex-ai-samples"):
  ! git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples.git

common_util = importlib.import_module(
    "vertex-ai-samples.community-content.vertex_model_garden.model_oss.notebook_util.common_util"
)

models, endpoints = {}, {}

In [None]:
# @title Set up Google Cloud environment

# @markdown #### Prerequisites

# @markdown 1. Make sure that [billing is enabled](https://cloud.google.com/billing/docs/how-to/modify-project) for your project.

# @markdown 2. Make sure that either the Compute Engine API is enabled or that you have the [Service Usage Admin](https://cloud.google.com/iam/docs/understanding-roles#serviceusage.serviceUsageAdmin) (`roles/serviceusage.serviceUsageAdmin`) role to enable the API.

# @markdown This section sets the default Google Cloud project and region, enables the Compute Engine API (if not already enabled), and initializes the Vertex AI API.

# Get the default project ID.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Compute Engine API, if not already.
print("Enabling Compute Engine API.")
! gcloud services enable compute.googleapis.com

# Initialize Vertex AI API.
print("Initializing Vertex AI API.")
aiplatform.init(project=PROJECT_ID, location=REGION)

## Format prompts for therapeutic tasks

The following sections in this notebook demonstrate prompting TxGemma for therapeutic development tasks from [Therapeutics Data Commons](https://tdcommons.ai/) (TDC).

For these predictive tasks, prompts should be formatted according to the TDC structure, including:

- **Instruction:** Briefly describes the task.

- **Context:** Provides 2-3 sentences of relevant biochemical background, derived from TDC descriptions and literature.

- **Question:** Queries a specific therapeutic property, incorporating textual representations of therapeutics and/or targets as inputs.

  - Inputs can include SMILES strings, amino acid sequences, nucleotide sequences, and natural language text.

  - Optional few-shot examples can be provided.

In [None]:
# @title Load prompt template

# @markdown First, load a JSON file that contains the prompt format for various TDC tasks.

! gcloud storage cp gs://hai-def-playground-bucket/txgemma/tdc_prompts.json tdc_prompts.json

with open("tdc_prompts.json", "r") as f:
    tdc_prompts_json = json.load(f)

In [None]:
# @title Prepare sample prompt

# @markdown Construct a prompt using the template and an input drug SMILES string from the [BBB (Blood-Brain Barrier), Martins et al.](https://tdcommons.ai/single_pred_tasks/adme#bbb-blood-brain-barrier-martins-et-al) dataset. This prompt will be used for generating predictions in the next sections.

# Example task and input
task_name = "BBB_Martins"
input_type = "{Drug SMILES}"
drug_smiles = "CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21"

prompt_template = tdc_prompts_json[task_name]
# `str.replace` returns the template unchanged when the placeholder is absent,
# which would silently send the model a prompt without the input molecule.
if input_type not in prompt_template:
    raise ValueError(
        f"Placeholder {input_type!r} not found in the {task_name!r} prompt template."
    )

TDC_PROMPT = prompt_template.replace(input_type, drug_smiles)
print("Formatted prompt:\n")
print(TDC_PROMPT)

## Get online predictions

In [None]:
# @title Import deployed model

# @markdown To get [online predictions](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions), you will need a TxGemma [Vertex AI Endpoint](https://cloud.google.com/vertex-ai/docs/general/deployment) that has been deployed from Model Garden. If you have not already done so, go to the [TxGemma model card](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/txgemma) in Model Garden and click "Deploy" to deploy the model.

# @markdown This section gets the Vertex AI Endpoint resource that you deployed from Model Garden to use for online predictions.

# @markdown Fill in the endpoint ID and region below. You can find your deployed endpoint on the [Vertex AI online prediction page](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).

ENDPOINT_ID = ""  # @param {type: "string", placeholder:"e.g. 123456789"}
ENDPOINT_REGION = ""  # @param {type: "string", placeholder:"e.g. us-central1"}

# Fail fast with an actionable message instead of an opaque API error.
ENDPOINT_ID = ENDPOINT_ID.strip()
if not ENDPOINT_ID:
    raise ValueError(
        "Set ENDPOINT_ID to the ID of the TxGemma endpoint you deployed from Model Garden."
    )

# Default to the notebook region when none is given. ENDPOINT_REGION is also
# used to build the endpoint URL for the OpenAI client, which needs a region.
ENDPOINT_REGION = ENDPOINT_REGION.strip() or REGION

endpoints["endpoint"] = aiplatform.Endpoint(
    endpoint_name=ENDPOINT_ID,
    project=PROJECT_ID,
    location=ENDPOINT_REGION,
)

### Explore predictive capabilities

TxGemma models are designed to process and understand information related to various therapeutic modalities and targets, including small molecules, proteins, nucleic acids, diseases, and cell lines, and can generate predictions on a broad set of therapeutic development tasks.

You can send [online prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) requests to the endpoint with formatted text prompts.

In [None]:
# @title #### Run inference on a therapeutic task

# @markdown This section demonstrates prompting TxGemma for a predictive task from TDC.

# One request instance wrapping the formatted TDC prompt: at most 8 generated
# tokens, with temperature 0 (typically greedy decoding) for a stable answer.
tdc_instance = {"prompt": TDC_PROMPT, "max_tokens": 8, "temperature": 0}
instances = [tdc_instance]

# Synchronous online prediction against the deployed endpoint.
response = endpoints["endpoint"].predict(instances=instances)
predictions = response.predictions

# One prediction is returned per instance; show the only one.
print(predictions[0])

### Explore conversational capabilities with TxGemma-Chat

TxGemma features conversational models that add reasoning and explainability to predictions and can be used in multi-turn interactions. Their conversational ability comes at the expense of some predictive power.

You can send [online prediction](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) requests to the endpoint using the [OpenAI SDK](https://github.com/openai/openai-python).

**For this section, make sure that you have selected a TxGemma-Chat model variant.**

In [None]:
# @title #### Ask questions in a multi-turn conversation

# @markdown This section demonstrates prompting TxGemma for conversational use.

# @markdown In this example, first prompt the model to answer a question regarding a predictive task using the TDC format. Then, ask a follow-up question requesting the model to provide its reasoning for the predicted answer.

# Obtain an OAuth access token from Application Default Credentials; it is
# passed to the OpenAI client as its API key.
# NOTE: the token is fetched once here — re-run this cell if later requests
# start failing with authentication errors.
creds, _ = google.auth.default()
creds.refresh(google.auth.transport.requests.Request())

# Fall back to the notebook region if no endpoint region was entered; an empty
# region would otherwise produce an invalid hostname below.
endpoint_region = ENDPOINT_REGION or REGION

ENDPOINT_RESOURCE_URL = (
    f"https://{endpoint_region}-aiplatform.googleapis.com/v1beta1/"
    f"projects/{PROJECT_ID}/"
    f"locations/{endpoint_region}/"
    f"endpoints/{ENDPOINT_ID}"
)

client = openai.OpenAI(base_url=ENDPOINT_RESOURCE_URL, api_key=creds.token)

questions = [
    TDC_PROMPT,  # Initial question is a predictive task from TDC
    "Explain your reasoning based on the molecule structure.",
]

# The whole conversation so far is sent on every turn, so the follow-up
# question can refer back to the model's earlier answer.
messages = []

display(Markdown("\n\n---\n\n"))
for question in questions:
    display(Markdown(f"**User:**\n\n{question}\n\n---\n\n"))
    messages.append({"role": "user", "content": question})
    model_response = client.chat.completions.create(
        # Empty model name; presumably the endpoint routes to its deployed model.
        model="",
        messages=messages,
        temperature=0,
        max_tokens=512,
    )
    response_content = model_response.choices[0].message.content
    display(Markdown(f"**TxGemma:**\n\n{response_content}\n\n---\n\n"))
    messages.append({"role": "assistant", "content": response_content})

## Get batch predictions

In [None]:
# @title Get access to TxGemma

# @markdown The prediction container directly loads the model from Hugging Face Hub.

# @markdown To enable access to the TxGemma models, you must provide a Hugging Face User Access Token. You can follow the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-tokens) to create a **read** access token and specify it in the `HF_TOKEN` field below.

HF_TOKEN = ""  # @param {type:"string", placeholder:"Hugging Face User Access Token"}

# Avoid pasting secrets into the notebook where possible: when the form field
# is left empty, fall back to an HF_TOKEN environment variable if one is set.
HF_TOKEN = HF_TOKEN.strip() or os.environ.get("HF_TOKEN", "")
if not HF_TOKEN:
    print(
        "Warning: HF_TOKEN is empty. A Hugging Face access token is required "
        "to access the TxGemma models."
    )

In [None]:
# @title Upload model to Vertex AI Model Registry

# @markdown To get [batch predictions](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions), you must first upload the prebuilt TxGemma model to [Vertex AI Model Registry](https://cloud.google.com/vertex-ai/docs/model-registry/introduction). Batch prediction requests are made directly to a model in Model Registry without deploying to an endpoint.

MODEL_VARIANT = "2b-predict"  # @param ["2b-predict", "9b-chat", "9b-predict", "27b-chat", "27b-predict"]

MODEL_ID = "txgemma-" + MODEL_VARIANT

# The pre-built serving docker image.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20250114_0916_RC00_maas"

# Batch prediction hardware per model size; every option uses Nvidia L4 GPUs.
# Size tags are checked in order and the first one found in MODEL_ID wins; any
# other variant (e.g. 27b) falls through to the largest machine shape.
# See https://cloud.google.com/vertex-ai/docs/predictions/configure-compute#batch_prediction
# for details on configuring compute for Vertex AI batch predictions.
_COMPUTE_BY_SIZE = (
    ("2b", "g2-standard-12", 1),
    ("9b", "g2-standard-24", 2),
)
MACHINE_TYPE, ACCELERATOR_COUNT = next(
    (
        (machine, gpu_count)
        for size_tag, machine, gpu_count in _COMPUTE_BY_SIZE
        if size_tag in MODEL_ID
    ),
    ("g2-standard-48", 4),
)
ACCELERATOR_TYPE = "NVIDIA_L4"


def upload_model(
    model_name: str,
    model_id: str,
    tensor_parallel_size: int = ACCELERATOR_COUNT,
    gpu_memory_utilization: float = 0.95,
    swap_space: int = 16,
    hf_token: str = "",
) -> aiplatform.Model:
    """Registers a vLLM-served TxGemma model in Vertex AI Model Registry.

    The model is uploaded but not deployed; batch prediction jobs run directly
    against the returned resource. The serving container loads the weights
    from Hugging Face Hub, so a Hugging Face access token is required.

    Args:
      model_name: Display name of the model in Model Registry.
      model_id: Hugging Face model ID, e.g. "google/txgemma-2b-predict".
      tensor_parallel_size: Value for vLLM's --tensor-parallel-size flag;
        defaults to ACCELERATOR_COUNT.
      gpu_memory_utilization: Value for vLLM's --gpu-memory-utilization flag.
      swap_space: Value for vLLM's --swap-space flag.
      hf_token: Hugging Face access token; when empty, the notebook-level
        HF_TOKEN is used.

    Returns:
      The uploaded `aiplatform.Model`.

    Raises:
      ValueError: If no Hugging Face access token is available.
    """
    token = hf_token or HF_TOKEN
    # Fail here rather than after GPUs are provisioned for a batch job whose
    # container cannot download the model.
    if not token:
        raise ValueError(
            "A Hugging Face access token is required to load TxGemma; set HF_TOKEN."
        )

    # Command line for vLLM's API server inside the serving container; the
    # host/port must match serving_container_ports below.
    vllm_args = [
        "python",
        "-m",
        "vllm.entrypoints.api_server",
        "--host=0.0.0.0",
        "--port=8080",
        f"--model={model_id}",
        f"--tensor-parallel-size={tensor_parallel_size}",
        f"--swap-space={swap_space}",
        f"--gpu-memory-utilization={gpu_memory_utilization}",
        "--enable-chunked-prefill",
        "--disable-log-stats",
    ]

    env_vars = {
        "MODEL_ID": model_id,
        "DEPLOY_SOURCE": "notebook",
        "HF_TOKEN": token,
    }

    return aiplatform.Model.upload(
        display_name=model_name,
        serving_container_image_uri=SERVE_DOCKER_URI,
        serving_container_args=vllm_args,
        serving_container_ports=[8080],
        serving_container_predict_route="/generate",
        serving_container_health_route="/ping",
        serving_container_environment_variables=env_vars,
    )


# Register the selected variant; the display name comes from common_util
# (presumably MODEL_ID plus a datetime stamp, per the helper's name).
model_display_name = common_util.get_job_name_with_datetime(prefix=MODEL_ID)
models["model"] = upload_model(
    model_name=model_display_name,
    model_id="google/" + MODEL_ID,
)

In [None]:
# @title Set up Google Cloud resources

# @markdown This section sets up a [Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing batch prediction inputs and outputs and gets the [Compute Engine default service account](https://cloud.google.com/compute/docs/access/service-accounts#default_service_account) which will be used to run the batch prediction jobs.

# @markdown 1. Make sure that you have the following required roles:
# @markdown - [Storage Admin](https://cloud.google.com/iam/docs/understanding-roles#storage.admin) (`roles/storage.admin`) to create and use Cloud Storage buckets
# @markdown - [Service Account User](https://cloud.google.com/iam/docs/understanding-roles#iam.serviceAccountUser) (`roles/iam.serviceAccountUser`) on either the project or the Compute Engine default service account

# @markdown 2. Set up a Cloud Storage bucket.
# @markdown - A new bucket will automatically be created for you.
# @markdown - [Optional] To use an existing bucket, specify the `gs://` bucket URI. The specified Cloud Storage bucket should be located in the same region as where the notebook was launched. Note that a multi-region bucket (e.g. "us") is not considered a match for a single region (e.g. "us-central1") covered by the multi-region range.

BUCKET_URI = ""  # @param {type:"string", placeholder:"[Optional] Cloud Storage bucket URI"}

# Cloud Storage bucket for storing batch prediction artifacts.
# A unique bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value of BUCKET_URI above.
if BUCKET_URI is None or BUCKET_URI.strip() == "":
    now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}-{str(uuid.uuid4())[:4]}"
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    ! gcloud storage buckets create --location {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    shell_output = ! gcloud storage buckets describe {BUCKET_NAME} | grep "location:" | sed "s/location://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            f"Bucket region {bucket_region} is different from notebook region {REGION}"
        )
print(f"Using this Cloud Storage Bucket: {BUCKET_URI}")

# Service account used for running the prediction container.
# Gets the Compute Engine default service account. If you prefer using your own
# custom service account, change the value of SERVICE_ACCOUNT below.
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
print("Using this service account:", SERVICE_ACCOUNT)

### Predict

You can send [batch prediction requests](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#request_a_batch_prediction) to the model using a [JSON Lines](https://jsonlines.org/) file to specify a list of input instances with text prompts to generate predictions. For more details on configuring batch prediction jobs, see how to [format your input data](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#input_data_requirements) and [choose compute settings](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#choose_machine_type_and_replica_count).

In [None]:
# @title #### Generate predictions in batch from text prompts

# @markdown This section shows an example of using TxGemma to generate predictions in batch from text prompts.

batch_predict_instances = [
    {"prompt": TDC_PROMPT, "max_tokens": 8, "temperature": 0},
    {"prompt": TDC_PROMPT, "max_tokens": 8, "temperature": 0},
]

# Write instances to JSON Lines file
os.makedirs("batch_predict_input", exist_ok=True)
instances_filename = "gcs_instances.jsonl"
with open(f"batch_predict_input/{instances_filename}", "w") as f:
   for line in batch_predict_instances:
       json_str = json.dumps(line)
       f.write(json_str)
       f.write("\n")

# Copy the file to Cloud Storage
batch_predict_prefix = f"batch-predict-{MODEL_ID}"
! gcloud storage cp ./batch_predict_input/{instances_filename} {BUCKET_URI}/{batch_predict_prefix}/input/{instances_filename}

batch_predict_job_name = common_util.get_job_name_with_datetime(prefix=f"batch-predict-{MODEL_ID}")

batch_predict_job = models["model"].batch_predict(
    job_display_name=batch_predict_job_name,
    gcs_source=os.path.join(BUCKET_URI, batch_predict_prefix, f"input/{instances_filename}"),
    gcs_destination_prefix=os.path.join(BUCKET_URI, batch_predict_prefix, "output"),
    machine_type=MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_COUNT,
    service_account=SERVICE_ACCOUNT,
)

batch_predict_job.wait()

print(batch_predict_job.display_name)
print(batch_predict_job.resource_name)
print(batch_predict_job.state)

In [None]:
# @title #### Get prediction results

# @markdown This section shows an example of [retrieving batch prediction results](https://cloud.google.com/vertex-ai/docs/predictions/get-batch-predictions#retrieve_batch_prediction_results) from the JSON Lines file(s) in the output Cloud Storage location.


def download_gcs_files_as_json(gcs_files_prefix, bucket_name=None):
    """Reads JSON Lines files from Cloud Storage and parses every record.

    Args:
      gcs_files_prefix: Object-name prefix (without the "gs://<bucket>/" part);
        every blob whose name starts with it is read.
      bucket_name: "gs://<bucket>" URI of the bucket to read from. Defaults to
        the notebook's BUCKET_NAME when omitted.

    Returns:
      A list with one parsed JSON value per non-blank line across all matching
      blobs.
    """
    if bucket_name is None:
        bucket_name = BUCKET_NAME
    client = storage.Client()
    bucket = storage.bucket.Bucket.from_string(bucket_name, client)
    records = []
    for blob in bucket.list_blobs(prefix=gcs_files_prefix):
        with blob.open("r") as f:
            # Skip blank lines (e.g. a trailing empty line); json.loads rejects them.
            records.extend(json.loads(line) for line in f if line.strip())
    return records


# The results live under the job's output directory; list them by object-name
# prefix, i.e. the directory with its leading "gs://<bucket>/" removed.
batch_predict_output_dir = batch_predict_job.output_info.gcs_output_directory
output_dir_in_bucket = (
    batch_predict_output_dir.removeprefix(f"{BUCKET_NAME}/").rstrip("/")
)
batch_predict_output_files_prefix = f"{output_dir_in_bucket}/prediction.results"
batch_predict_results = download_gcs_files_as_json(
    gcs_files_prefix=batch_predict_output_files_prefix
)

# Display first two batch prediction results
for result in batch_predict_results[:2]:
    print(result["prediction"])

## Next steps

Explore the other [notebooks](https://github.com/google-gemini/gemma-cookbook/blob/main/TxGemma) to learn what else you can do with the model.

## Clean up resources

In [None]:
# @markdown  Delete the experiment models and endpoints to recycle the resources
# @markdown  and avoid unnecessary continuous charges that may incur.

# Undeploy model and delete endpoint.
for endpoint in endpoints.values():
    endpoint.delete(force=True)

# Delete models.
for model in models.values():
    model.delete()

delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_NAME