In [None]:
import torch

# Sanity-check the GPU stack before loading any models.
# Expected: CUDA build "12.8" (depending on your GPU you may need to downgrade the
# torch version pinned in requirements.txt), and CUDA availability True.
cuda_build = torch.version.cuda
print(cuda_build)
gpu_ready = torch.cuda.is_available()
print(gpu_ready)

In [None]:
import os, re
# Make a Hugging Face token available via the HF_TOKEN environment variable,
# which huggingface_hub / transformers read when downloading checkpoints.
try:
    # Colab: the token is read from the notebook's Secrets panel.
    from google.colab import userdata
except ImportError:
    # Not running on Colab (e.g. a local GPU machine set up from requirements.txt).
    userdata = None

if userdata is not None:
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
elif not os.environ.get("HF_TOKEN"):
    # Only prompt when the token was not already exported in the shell environment.
    from getpass import getpass
    os.environ["HF_TOKEN"] = getpass("Enter your Hugging Face token: ")

# Install Dependencies

In [None]:
! pip install --upgrade --quiet accelerate bitsandbytes huggingface_hub transformers

# Download the prediction and chat models from the Hugging Face Hub

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

PREDICT_VARIANT = "2b-predict"  # @param ["2b-predict", "9b-predict", "27b-predict"]
CHAT_VARIANT = "9b-chat" # @param ["9b-chat", "27b-chat"]
USE_CHAT = True # @param {type: "boolean"}

# Both models are loaded with 4-bit bitsandbytes weights to reduce GPU memory use.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)


def load_txgemma(variant):
    """Load the tokenizer and 4-bit model for ``google/txgemma-<variant>``.

    Args:
        variant: Checkpoint suffix, e.g. "2b-predict" or "9b-chat".

    Returns:
        A ``(tokenizer, model)`` pair; the model is placed with ``device_map="auto"``.
    """
    repo_id = f"google/txgemma-{variant}"
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        device_map="auto",
        quantization_config=quantization_config,
    )
    return tokenizer, model


predict_tokenizer, predict_model = load_txgemma(PREDICT_VARIANT)

# The chat model is optional; skip it to save download time and GPU memory.
if USE_CHAT:
    chat_tokenizer, chat_model = load_txgemma(CHAT_VARIANT)

# Looking at TxGemma's prompt format

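TxGemma's prediction models were fine-tuned on tasks from the Therapeutics Data Commons (TDC), so they expect prompts in a specific format. The cells below download TDC's prompt templates (`tdc_prompts.json`) and display one of them, purely to visualize that prompt structure.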

In [None]:
import json
from huggingface_hub import hf_hub_download

# Fetch the TDC prompt templates from the same predict checkpoint selected above,
# so the templates shown match the model actually loaded.
# NOTE(review): assumes every *-predict repo ships tdc_prompts.json (2b-predict does).
tdc_prompts_filepath = hf_hub_download(
    repo_id=f"google/txgemma-{PREDICT_VARIANT}",
    filename="tdc_prompts.json",
)

# JSON is UTF-8; be explicit so loading does not depend on the platform's locale encoding.
with open(tdc_prompts_filepath, "r", encoding="utf-8") as f:
    tdc_prompts_json = json.load(f)

In [None]:
# Pick a TDC task whose prompt template to display; any key of tdc_prompts_json works.
# BBB_Martins (blood-brain-barrier penetration) is used here as an example.
task_name = "BBB_Martins"
tdc_prompts_json[task_name]

# Test the prediction and chat models

In [None]:
# Example ligand SMILES string taken from Miko's compiled dataset, for testing purposes only.
SMILES = "C1CCN(C1)C2=CC=CC=C2NC3=NS(=O)(=O)C4=CC=CC=C43"

# Free-form two-choice question: (A) does NOT bind HIF-2α vs (B) does bind HIF-2α.
# The ligand SMILES above is interpolated into the "Ligand SMILES:" line.
# NOTE(review): this prompt is not built from a tdc_prompts.json template, which is the
# format TxGemma-predict was fine-tuned on; predict-model answers may be unreliable —
# consider adapting an existing TDC binding-task template instead.
HIF_PROMPT = f"""
Instructions: Answer the following question about ligand–protein binding.

Context: Hypoxia‐inducible factor 2α (HIF-2α) is a transcription factor whose activity depends on specific ligand binding.
Key properties influencing binding include measured or predicted binding affinity, lipophilicity (pLogP), and chemical
similarity to known HIF-2α binders.

Question: Given a ligand’s SMILES string below, predict whether it
  (A) does NOT bind HIF-2α
  (B) does bind HIF-2α

Ligand SMILES: {SMILES}

Answer:
"""

# NOTE: fine-tuning TxGemma is expected to be required for reliable answers on this
# custom HIF-2α task; the helpers below only query the models as loaded.
def txgemma_generate(tokenizer, model, prompt, max_new_tokens):
    """Generate a completion for ``prompt`` and return only the newly generated text.

    Args:
        tokenizer: Tokenizer paired with ``model``.
        model: Causal LM exposing ``.device`` and ``.generate``.
        prompt: Raw prompt string (no chat template is applied).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded completion with special tokens removed. The prompt itself is
        not included.
    """
    # Send inputs to wherever device_map="auto" placed the model instead of assuming "cuda".
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # For causal LMs, generate() returns the prompt tokens followed by the new tokens;
    # drop the echoed prompt so callers receive just the model's answer.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)


def txgemma_predict(prompt, max_new_tokens=8):
    """Query the TxGemma prediction model; the small default budget suits short, label-style answers."""
    return txgemma_generate(predict_tokenizer, predict_model, prompt, max_new_tokens)


def txgemma_chat(prompt, max_new_tokens=32):
    """Query the TxGemma chat model (requires the cell that loads it to have run with USE_CHAT=True)."""
    return txgemma_generate(chat_tokenizer, chat_model, prompt, max_new_tokens)

# Ask both models the same HIF-2α question and show their answers one after the other.
predict_response = txgemma_predict(HIF_PROMPT)
print(f"Prediction model response: {predict_response}")
print("=================================================================")
if USE_CHAT:
    chat_response = txgemma_chat(HIF_PROMPT)
    print(f"Chat model response: {chat_response}")