In [1]:
import os
from google.colab import userdata
os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")

# Install dependencies

In [21]:
! pip install --upgrade --quiet bitsandbytes datasets peft transformers trl rdkit

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.3/34.3 MB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[?25h

# Load model from HF

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model = "google/txgemma-"
CHAT_VARIANT = "9b-chat" # @param ["9b-chat", "27b-chat"]

model_id = base_model + CHAT_VARIANT

# Use 4-bit quantization to reduce memory usage
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map={"":0},
    torch_dtype="auto",
    attn_implementation="eager",
)

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/852 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.67G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.90G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

# Load Dataset and Clean It

## Known Binders

In [22]:
import pandas as pd

df = pd.read_csv("Known_HIF_Binders.csv")

In [23]:
clean_df = df[["Ligand SMILES", "IC50 (nM)"]].dropna()

# remove rows that contain '<' or '>'
has_censor = clean_df["IC50 (nM)"] \
    .astype(str) \
    .str.contains(r"[<>]")

# count how many rows will be dropped
dropped_count = has_censor.sum()
print(f"Dropping {dropped_count} rows with '<' or '>' in IC50")

# keep only the rows *without* '<' or '>'
clean_df = clean_df.loc[~has_censor].reset_index(drop=True)
clean_df

Dropping 198 rows with '<' or '>' in IC50


Unnamed: 0,Ligand SMILES,IC50 (nM)
0,FC(F)(F)c1ccc(NC(=O)c2ccc(CN3CCOCC3)cn2)cc1,230
1,FC(F)S(=O)(=O)c1ccc(Oc2cc(F)cc(c2)C#N)c(Cl)c1C#N,230
2,CS(=O)(=O)c1ccc(Oc2cc(F)cc(Cl)c2)c2CCC(O)c12,230
3,Cc1cc(ccc1Oc1cc(F)cc(c1)C#N)S(=O)(=O)C(F)F,230
4,Fc1cc(Cl)cc(Oc2ccc(c(F)c2C#N)S(=O)(=O)C(F)(F)F)c1,230
...,...,...
949,FC(F)(F)S(=O)(=O)c1ccc(Oc2cccc(c2)C#N)c(Br)c1,210
950,[O-][N+](=O)c1c(Cl)c(ccc1Oc1cccc(Cl)c1)C(F)(F)F,210
951,Cc1nc(SCCC(O)=O)c2cc(sc2n1)-c1ccccc1,212
952,FC(F)c1cc(ccc1Oc1cc(F)cc(Cl)c1)S(=O)(=O)C(F)(F)F,220


In [27]:
import re
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, Lipinski

# --- 2) Clean & standardize IC50, compute pIC50 ---
def parse_ic50_to_pic50(ic50_str):
    """Convert a string like '<5' or '200' (in nM) to pIC50."""
    # strip any whitespace
    s = str(ic50_str).strip()
    try:
        nm = float(s)
    except ValueError:
        return np.nan  # unparseable

    # convert nM → M
    m = nm * 1e-9
    # pIC50
    pic50 = -np.log10(m)
    return pic50

clean_df["pIC50"] = clean_df["IC50 (nM)"].apply(parse_ic50_to_pic50)

# --- 3) Bin into activity classes ---
# strong binder if pIC50 ≥ 7 (IC50 ≤ 100 nM), else weak/non-binder
threshold = 7.0
clean_df["activity_class"] = np.where(clean_df["pIC50"] >= threshold, "strong", "weak")

# --- 4) Compute 2D descriptors via RDKit ---
def compute_descriptors(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return {
            "MolWt": np.nan,
            "TPSA": np.nan,
            "HBD": np.nan,
            "HBA": np.nan,
            "RotBonds": np.nan,
            "LogP": np.nan,
        }
    return {
        "MolWt": Descriptors.MolWt(mol),
        "TPSA": Descriptors.TPSA(mol),
        "HBD": Lipinski.NumHDonors(mol),
        "HBA": Lipinski.NumHAcceptors(mol),
        "RotBonds": Descriptors.NumRotatableBonds(mol),
        "LogP": Crippen.MolLogP(mol),
    }

# apply and expand into separate columns
desc_df = clean_df["Ligand SMILES"].apply(compute_descriptors).apply(pd.Series)
clean_df = pd.concat([clean_df, desc_df], axis=1)
clean_df["is_known_binder"] = True

# --- 5) View the table ---
print(clean_df.head())

clean_df.to_csv("binders_enriched.csv", index=False)


                                       Ligand SMILES IC50 (nM)     pIC50  \
0        FC(F)(F)c1ccc(NC(=O)c2ccc(CN3CCOCC3)cn2)cc1       230  6.638272   
1   FC(F)S(=O)(=O)c1ccc(Oc2cc(F)cc(c2)C#N)c(Cl)c1C#N       230  6.638272   
2       CS(=O)(=O)c1ccc(Oc2cc(F)cc(Cl)c2)c2CCC(O)c12       230  6.638272   
3         Cc1cc(ccc1Oc1cc(F)cc(c1)C#N)S(=O)(=O)C(F)F       230  6.638272   
4  Fc1cc(Cl)cc(Oc2ccc(c(F)c2C#N)S(=O)(=O)C(F)(F)F)c1       230  6.638272   

  activity_class    MolWt   TPSA  HBD  HBA  RotBonds     LogP    MolWt   TPSA  \
0           weak  365.355  54.46  1.0  4.0       4.0  3.18490  365.355  54.46   
1           weak  386.738  90.95  0.0  5.0       4.0  4.01106  386.738  90.95   
2           weak  356.802  63.60  1.0  4.0       3.0  3.65450  356.802  63.60   
3           weak  341.310  67.16  0.0  4.0       4.0  3.79440  341.310  67.16   
4           weak  397.708  67.16  0.0  4.0       3.0  4.57568  397.708  67.16   

   HBD  HBA  RotBonds     LogP  is_known_binder  
0  1.0

## Duds