In [1]:
import os
import sys

# This is to force the path to be on the same level as the dl_ba folder
sys.path.append("../..")

from transformers import AutoTokenizer
import torch
from datasets import load_dataset

import time

from balm import common_utils
from balm.models.utils import load_trained_model, load_pretrained_pkd_bounds
from balm.configs import Configs
from balm.models import BALM

# Run on the GPU when one is available; fall back to CPU so the notebook still
# executes (more slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load Pretrained BindingDB BALM

In [2]:
# Rebuild BALM from the PEFT config and load its trained weights for inference.
config_filepath = "../../default_configs/balm_peft.yaml"
configs = Configs(**common_utils.load_yaml(config_filepath))

model = load_trained_model(
    BALM(configs.model_configs), configs.model_configs, is_training=False
)
model.to(DEVICE)
model.eval()

# pKd bounds looked up from the checkpoint path; they are later used to rescale
# the model's cosine similarity into a pKd value.
pkd_lower_bound, pkd_upper_bound = load_pretrained_pkd_bounds(
    configs.model_configs.checkpoint_path
)

# One tokenizer per encoder, loaded protein first, then drug.
protein_tokenizer, drug_tokenizer = (
    AutoTokenizer.from_pretrained(name_or_path)
    for name_or_path in (
        configs.model_configs.protein_model_name_or_path,
        configs.model_configs.drug_model_name_or_path,
    )
)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  checkpoint = torch.load(


trainable params: 128,160 || all params: 148,923,641 || trainable%: 0.0861
trainable params: 221,184 || all params: 3,648,624 || trainable%: 6.0621
Loading checkpoint from BALM/bdb-cleaned-r-esm-lokr-chemberta-loha-cosinemse
Merging protein model with its adapter
Merging drug model with its adapter


# Polaris Data Loading

In [3]:
import polaris as po

# Fetch the ASAP antiviral-potency competition and the train/test split it defines.
CHALLENGE = 'antiviral-potency-2025'
competition = po.load_competition("asap-discovery/" + CHALLENGE)
train, test = competition.get_train_test_split()

In [4]:
import numpy as np

# Build per-target training pairs (input, pIC50), skipping any record whose value
# for that target is NaN. Targets are checked SARS-CoV-2 first, then MERS.
sars2_train = []
mers_train = []
target_buckets = (
    ('pIC50 (SARS-CoV-2 Mpro)', sars2_train),
    ('pIC50 (MERS-CoV Mpro)', mers_train),
)

for record in train:
    compound, measurements = record[0], record[1]
    for column, bucket in target_buckets:
        value = measurements[column]
        if not np.isnan(value):
            bucket.append((compound, value))

# Predictions with the pretrained model

In [21]:
# Score every test entry against both Mpro targets with the pretrained model.
start = time.perf_counter()

# For each target, the two chains are concatenated into a single sequence
# before tokenization.
sars2_chains = (
    "SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTTITVNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVT",
    "SGFRKMAFPSGKVEGCMVQVTCGTTTLNGLWLDDVVYCPRHVICTSEDMLNPNYEDLLIRKSNHNFLVQAGNVQLRVIGHSMQNCVLKLKVDTANPKTPKYKFVRIQPGQTFSVLACYNGSPSGVYQCAMRPNFTIKGSFLNGSCGSVGFNIDYDCVSFCYMHHMELPTGVHAGTDLEGNFYGPFVDRQTAQAAGTDTTITVNVLAWLYAAVINGDRWFLNRFTTTLNDFNLVAMKYNYEPLTQDHVDILGPLSAQTGIAVLDMCASLKELLQNGMNGRTILGSALLEDEFTPFDVVRQCSGVT",
)
mers_chains = (
    "SGLVKMSHPSGDVEACMVQVTCGSMTLNGLWLDNTVWCPRHVMCPADQLSDPNYDALLISMTNHSFSVQKHIGAPANLRVVGHAMQGTLLKLTVDVANPSTPAYTFTTVKPGAAFSVLACYNGRPTGTFTVVMRPNYTIKGSFLCGSCGSVGYTKEGSVINFCYMHQMELANGTHTGSAFDGTMYGAFMDKQVHQVQLTDKYCSVNVVAWLYAAILNGCAWFVKPNRTSVVSFNEWALANQFTEFVGTQSVDMLAVKTGVAIEQLLYAIQQLYTGFQGKQILGSTMLEDEFTPEDVNMQIMGV",
    "SGLVKMSHPSGDVEACMVQVTCGSMTLNGLWLDNTVWCPRHVMCPADQLSDPNYDALLISMTNHSFSVQKHIGAPANLRVVGHAMQGTLLKLTVDVANPSTPAYTFTTVKPGAAFSVLACYNGRPTGTFTVVMRPNYTIKGSFLCGSCGSVGYTKEGSVINFCYMHQMELANGTHTGSAFDGTMYGAFMDKQVHQVQLTDKYCSVNVVAWLYAAILNGCAWFVKPNRTSVVSFNEWALANQFTEFVGTQSVDMLAVKTGVAIEQLLYAIQQLYTGFQGKQILGSTMLEDEFTPEDVNMQIMGV",
)

# Polaris target column -> protein sequence it is scored against. Insertion
# order fixes the key order of `predictions` (SARS-CoV-2 first, then MERS).
target_sequences = {
    'pIC50 (SARS-CoV-2 Mpro)': sars2_chains[0] + sars2_chains[1],
    'pIC50 (MERS-CoV Mpro)': mers_chains[0] + mers_chains[1],
}


def predict_pkd(protein_inputs, drug_inputs):
    """Return BALM's predicted pKd, as a Python scalar, for one protein/drug pair.

    Args:
        protein_inputs: tokenizer output (``input_ids``/``attention_mask``) for the
            protein sequence, already moved to ``DEVICE``.
        drug_inputs: tokenizer output for the compound, already moved to ``DEVICE``.

    Uses the notebook-level ``model`` and the pretrained ``pkd_lower_bound`` /
    ``pkd_upper_bound`` to convert the model's cosine similarity into pKd.
    """
    cosine_similarity = model(
        {
            "protein_input_ids": protein_inputs["input_ids"],
            "protein_attention_mask": protein_inputs["attention_mask"],
            "drug_input_ids": drug_inputs["input_ids"],
            "drug_attention_mask": drug_inputs["attention_mask"],
        }
    )["cosine_similarity"]
    pkd = model.cosine_similarity_to_pkd(
        cosine_similarity,
        pkd_upper_bound=pkd_upper_bound,
        pkd_lower_bound=pkd_lower_bound,
    )
    return pkd.item()


# The protein sequences are identical for every compound: tokenize them once
# instead of once per test entry.
protein_inputs_by_target = {
    target: protein_tokenizer(sequence, return_tensors="pt").to(DEVICE)
    for target, sequence in target_sequences.items()
}
predictions = {target: [] for target in target_sequences}

# Inference only: disable autograd so no computation graph is built per call.
with torch.no_grad():
    for entry in test:
        # Each test entry is passed directly to the drug tokenizer.
        drug_inputs = drug_tokenizer(entry, return_tensors="pt").to(DEVICE)
        for target, target_inputs in protein_inputs_by_target.items():
            predictions[target].append(predict_pkd(target_inputs, drug_inputs))

# Count what was actually predicted rather than re-materializing the test set.
n_predictions = sum(len(values) for values in predictions.values())
print(f"Time taken for {n_predictions} predictions: {time.perf_counter() - start}")

Time taken for 594 predictions: 16.27002239227295


In [30]:
# Submission is a remote, irreversible call: refuse to upload incomplete results
# (e.g. left over from an interrupted prediction cell). Every target must have
# exactly one prediction per test entry.
n_test_entries = len(list(test))
for target_column, target_values in predictions.items():
    assert len(target_values) == n_test_entries, (
        f"{target_column}: {len(target_values)} predictions "
        f"for {n_test_entries} test entries"
    )

# The same repository URL serves as both the report and the code link.
repo_url = "https://github.com/meyresearch/polaris_challenge/tree/potency"
competition.submit_predictions(
    predictions=predictions,
    prediction_name="BALM_potency_pretrained",
    prediction_owner="ialibay",
    report_url=repo_url,
    github_url=repo_url,
    user_attributes={"Method": "BALM"},
)

Output()

# Few-shot training

TBD