# Inference - run pretrained model with kiba data

In [None]:
import os
# Set which GPU to use (here, GPU 2) by setting the CUDA_VISIBLE_DEVICES environment variable 
# (important if executed on server, check which GPU is available)
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda:0


In [2]:
gpu_id = device.index
print('Gpu name:', torch.cuda.get_device_name(gpu_id))
print('Memory allocated:', round(torch.cuda.memory_allocated(gpu_id) / (1024 ** 3),2), 'GB')
print('Memory cached:', round(torch.cuda.memory_reserved(gpu_id) / (1024 ** 3),2), 'GB')

Gpu name: Tesla P40
Memory allocated: 0.0 GB
Memory cached: 0.0 GB


In [None]:
import sys
# Add the tankbind source folder to the Python path (needed for imports from tankbind to work)
tankbind_src_folder_path = "./tankbind/"
sys.path.insert(0, tankbind_src_folder_path)

In [4]:
# imports from tankbind
from feature_utils import get_protein_feature, get_clean_res_list, extract_torchdrug_feature_from_mol, get_canonical_smiles
from utils import construct_data_from_graph_gvp, evaulate_with_affinity, evaulate
from model import get_model
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords
from metrics import print_metrics, myMetric

# general imports
import os
import pandas as pd
import numpy as np
import torch
import logging
from tqdm.notebook import tqdm

import rdkit.Chem as Chem
from rdkit.Chem import AllChem
from Bio.PDB import PDBParser
import torchmetrics
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

import warnings
# warnings.filterwarnings("ignore") # NOTE: only uncomment if appearing warnings are not relevant

## Load molecule_dict and protein_dict & kiba_data pt files

In [None]:
# load protein and molecule dictionaries & filtered kiba_data
protein_dict = torch.load("data/protein_dict.pt")
molecule_dict = torch.load("data/molecule_dict.pt")
kiba_data = torch.load('data/kiba_data_small_dismap.pt') # NOTE: kiba_data is the complete DataFrame with the P2Rank information

### Remove samples with dis_map > 10000 (due to memory issues) --> removed 12620 samples

In [None]:
# # Uncomment & execute only if the dataset was previously created with the unfiltered kiba dataset (inlcuding dis_map size > 10000)

# # check & drop dis_map size > 10000
# data_loader = DataLoader(dataset, batch_size=1, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=0)
# indices = [] # list to store indices of elements with dis_map size > 10000
# for i, (elem, y) in tqdm(enumerate(data_loader)):
#    if elem.dis_map.shape[0] > 10000:
#       indices.append(i)
# print(f"Number of elements with dis_map size > 10000: {len(indices)}")

# data_no_dis_map = dataset.data
# data_no_dis_map = data_no_dis_map.drop(indices)  # drop the elements with dis_map size > 10000
# torch.save(data_no_dis_map, 'data/kiba_data_small_dismap.pt')  # save the modified dataset without dis_map

118254it [18:46, 104.96it/s]

Number of elements with dis_map size > 10000: 0





# Dataset class & Dataset creation

I also return the target affinities together with the model input since some of the inputs might be discarded during training due to memory size issues. So I return both to keep them correctly assigned/ordered.

NOTE: The dataset class is copied from the Tankbind repository, with only some small changes to also return the target affinity.

In [None]:
class MyDataset_VS(Dataset):
    def __init__(self, root, data=None, protein_dict=None, molecule_dict=None, proteinMode=0, compoundMode=1,
                 pocket_radius=20, shake_nodes=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.data = data
        self.protein_dict = protein_dict
        self.molecule_dict = molecule_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data = torch.load(self.processed_paths[0])
        self.protein_dict = torch.load(self.processed_paths[1])
        self.molecule_dict = torch.load(self.processed_paths[2])
        self.proteinMode = proteinMode
        self.pocket_radius = pocket_radius
        self.compoundMode = compoundMode
        self.shake_nodes = shake_nodes

    @property
    def processed_file_names(self):
        return ['kiba_data.pt', 'protein_dict.pt', 'molecule_dict.pt']

    def process(self):
        # Save data and protein dictionary
        torch.save(self.data, self.processed_paths[0])
        torch.save(self.protein_dict, self.processed_paths[1])
        torch.save(self.molecule_dict, self.processed_paths[2])

    def len(self):
        return len(self.data)
    
    def get(self, idx):
        line = self.data.iloc[idx]
        smiles = line['smiles']
        target_affinity = line['target_affinity'] # get the target affinity value so we can return it together with the data (used for evaluation)
        pocket_com = line['pocket_com']
        pocket_com = np.array(pocket_com.split(",")).astype(float) if isinstance(pocket_com, str) else pocket_com
        pocket_com = pocket_com.reshape((1, 3))
        use_whole_protein = line.get('use_whole_protein', False)

        protein_name = line['protein_name']
        protein_data = self.protein_dict.get(protein_name)
        
        if protein_data is None:
            raise ValueError(f"Protein {protein_name} not found in pre-calculated protein dictionary")

        protein_node_xyz, protein_seq, protein_node_s, protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v = protein_data

        # Load precomputed molecular features
        molecule_data = self.molecule_dict.get(smiles)
        if molecule_data is None:
            raise ValueError(f"SMILES {smiles} not found in precomputed molecular dictionary")
        
        coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, pair_dis_distribution = self.molecule_dict[smiles]

        data, input_node_list, keepNode = construct_data_from_graph_gvp(
            protein_node_xyz, protein_seq, protein_node_s, protein_node_v, 
            protein_edge_index, protein_edge_s, protein_edge_v,
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
            pocket_radius=self.pocket_radius, use_whole_protein=use_whole_protein, includeDisMap=True,
            use_compound_com_as_pocket=False, chosen_pocket_com=pocket_com, compoundMode=self.compoundMode
        )
        data.compound_pair = pair_dis_distribution.reshape(-1, 16)
        
        return data, target_affinity

### Create dataset instance:

In [None]:
dataset_path = 'data' # Specify the path where the dataset will be stored

# dataset = MyDataset_VS(root=dataset_path, data=kiba_data, protein_dict=protein_dict, molecule_dict=molecule_dict) # NOTE: use this only on first run, otherwise execute line below

dataset = MyDataset_VS(root=dataset_path)

# Model testing - takes around 33h on cpu

In [None]:
import model
import importlib

# reload the whole module so changes in IaBNet_with_affinity and get_model are reflected
importlib.reload(model)

# NOTE: forced reload is needed so the changes made in the Tankbind files are actually applied

<module 'model' from '/system/user/studentwork/hernler/./tankbind/model.py'>

### Masking function for bringing the vector representations to the same size

In [9]:
# convert z to same size vector representation 
def masked_mean_pool(z, z_mask):
    z_mask_unsqueezed = z_mask.unsqueeze(-1)  # [B, P, C, 1]
    masked_z = z * z_mask_unsqueezed
    sum_z = masked_z.sum(dim=(1, 2))  # [B, H]
    norm = z_mask_unsqueezed.sum(dim=(1, 2)) + 1e-6  # [B, 1]
    return sum_z / norm  # [B, H]


In [None]:
# Clears GPU memory - only uncomment/execute if necessary
# torch.cuda.empty_cache()
# torch.cuda.ipc_collect()  # Optional: cleans up inter-process memory

In [None]:
# !!!! Code without dis_map check: !!!! NOTE: use this only with the filtered dataset

batch_size = 6 # NOTE: max batchsize for for GPU with 24GB RAM (with higher batchsize --> I recieved a memory error when testing was around 97% complete)

logging.basicConfig(level=logging.INFO)
model = get_model(0, logging, device)

# load pretrainded self-dock model
modelFile = "model/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()

data_loader = DataLoader(dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=0)

affinity_pred_list = []
vector_representations = []

for x, y in tqdm(data_loader):     
    x = x.to(device) # only move x to device as y is not used in the model
    y_pred, affinity_pred = model(x)

    vector_repr = masked_mean_pool(model.vec_repr, model.z_mask) # apply the masked mean pooling to the vector representation

    affinity_pred_list.append(affinity_pred.detach().cpu())
    vector_representations.append(vector_repr.detach().cpu())

# concatenate the lists into tensors
affinity_pred_list = torch.cat(affinity_pred_list)
vector_representations = torch.cat(vector_representations)

# save the affinity predictions & vector representations
torch.save(affinity_pred_list, 'data/affinity_pred.pt')
torch.save(vector_representations, 'vector_representations/vector_representations.pt')

21:00:16   5 stack, readout2, pred dis map add self attention and GVP embed, compound model GIN


  0%|          | 0/17606 [00:00<?, ?it/s]

In [None]:
# check tensor shapes
print(f"affinity_pred_list shape: {affinity_pred_list.shape}")
print(f"vector_representations shape: {vector_representations.shape}")
print(f"vector_representations[0] shape: {vector_representations[0].shape}") # shape of the first vector representation

affinity_pred_list shape: torch.Size([105634])
vector_representations shape: torch.Size([105634, 128])
vector_representations[0] shape: torch.Size([128])


### Add affinity predictions to kiba dataframe

In [None]:
# kiba_df = dataset.data

# if new session, load the affinity predictions & the kiba df
affinity_pred_list = torch.load('data/affinity_pred.pt')
kiba_df = torch.load('data/kiba_data_small_dismap.pt')  # load the filtered kiba_df (used for inference)

kiba_df['affinity_pred'] = affinity_pred_list
kiba_df.head()
# save the updated kiba_df with affinity predictions
# kiba_df.to_csv('data/kiba_data_with_affinity_pred.csv', index=False)

Unnamed: 0.1,Unnamed: 0,protein_name,compound_name,smiles,pocket_name,pocket_com,target_affinity,affinity_pred
0,0,2R5T,,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,best_p2rank_pocket,"32.53,34.506,67.174",11.1,6.020516
1,1,3BRT,,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,best_p2rank_pocket,"14.396,20.696,11.566",11.1,1.52152
2,2,3BRT,,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,best_p2rank_pocket,"14.396,20.696,11.566",11.1,1.52152
3,3,1IVO,,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,best_p2rank_pocket,"115.598,69.377,45.458",11.1,2.156317
4,4,1MFG,,COC1=C(C=C2C(=C1)CCN=C2C3=CC(=C(C=C3)Cl)Cl)Cl,best_p2rank_pocket,"8.068,0.821,17.188",11.1,3.940505


# Evaluation of the predicted affinities

**Mean squared error, mean absolute error & r2-score:**

In [None]:
def eval_metrics(targets, predictions):
    mse = torchmetrics.functional.mean_squared_error(torch.tensor(predictions), torch.tensor(targets))
    mae = torchmetrics.functional.mean_absolute_error(torch.tensor(predictions), torch.tensor(targets))
    r2 = torchmetrics.functional.r2_score(torch.tensor(predictions), torch.tensor(targets))
    return mse.item(), mae.item(), r2.item()

**Concordance index:**

In [None]:
def concordance_idx(y_true, y_pred):
    # 1. Convert inputs to NumPy arrays
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # 2. Sort by predicted values
    order = np.argsort(y_pred)
    y_true = y_true[order]

    # 3. Find unique labels and group sizes
    unique_labels, inverse = np.unique(y_true, return_inverse=True)
    counts = np.bincount(inverse)  # number of samples per unique label

    print(f"Found {len(unique_labels)} unique labels")

    # 4. Compute cumulative counts
    cum_counts = np.cumsum(counts)  # prefix sum for fast pair counting

    concordant = 0
    total_pairs = 0

    # 5. Iterate over label groups
    for i, count_i in enumerate(tqdm(counts, desc="Processing label groups")):
        # Total pairs with higher labels (since sorted by prediction, no pair expansion)
        higher = cum_counts[-1] - cum_counts[i]
        total_pairs += count_i * higher

        # Since it is sorted by prediction, all these pairs are concordant
        concordant += count_i * higher

    c_index = concordant / total_pairs if total_pairs > 0 else np.nan

    return c_index


In [None]:
mse, mae, r2 = eval_metrics(kiba_df['target_affinity'].values, affinity_pred_list)
c_index = concordance_idx(kiba_df['target_affinity'].values, affinity_pred_list)

OR (old function):

In [6]:
def eval_metrics(preds, targets):
    criterion = torch.nn.MSELoss()
    with torch.no_grad():
        mse = criterion(preds, targets)
        mae = torch.mean(torch.abs(preds - targets))
    return mse.item(), mae.item()

In [8]:
preds = torch.tensor(kiba_df['affinity_pred'].to_list(), requires_grad=True, device=device)
targets = torch.tensor(kiba_df['target_affinity'].to_list(), device=device)

mse, mae = eval_metrics(preds, targets)
print(f"Mean Squared Error: {mse:.4f}")
print(f"Mean Absolute Error: {mae:.4f}")

Mean Squared Error: 36.7511
Mean Absolute Error: 5.6623


**Concordance index:**

In [11]:
def concordance_idx(y_true, y_pred):
    print("[Step 1] Converting inputs to NumPy arrays...")
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    print("[Step 2] Sorting by predicted values...")
    order = np.argsort(y_pred)
    y_true = y_true[order]

    print("[Step 3] Finding unique labels and group sizes...")
    unique_labels, inverse = np.unique(y_true, return_inverse=True)
    counts = np.bincount(inverse)  # number of samples per unique label

    print(f"    Found {len(unique_labels)} unique labels")

    print("[Step 4] Computing cumulative counts...")
    cum_counts = np.cumsum(counts)  # prefix sum for fast pair counting

    concordant = 0
    total_pairs = 0

    print("[Step 5] Iterating over label groups...")
    for i, count_i in enumerate(tqdm(counts, desc="Processing label groups")):
        # Total pairs with higher labels (since sorted by prediction, no pair expansion)
        higher = cum_counts[-1] - cum_counts[i]
        total_pairs += count_i * higher

        # Since we sorted by prediction, all these pairs are concordant
        concordant += count_i * higher

    print("[Step 6] Finalizing concordance index...")
    c_index = concordant / total_pairs if total_pairs > 0 else np.nan
    print(f"[Done] Concordance Index: {c_index:.4f}")

    return c_index


In [13]:
c_index = concordance_idx(kiba_df['target_affinity'].to_numpy(), kiba_df['affinity_pred'].to_numpy())

[Step 1] Converting inputs to NumPy arrays...
[Step 2] Sorting by predicted values...
[Step 3] Finding unique labels and group sizes...
    Found 2678 unique labels
[Step 4] Computing cumulative counts...
[Step 5] Iterating over label groups...


Processing label groups:   0%|          | 0/2678 [00:00<?, ?it/s]

[Step 6] Finalizing concordance index...
[Done] Concordance Index: 1.0000


## Save target affinities for model training later

In [20]:
targets = torch.tensor(kiba_df['target_affinity'].values)

torch.save(targets, "vector_representations/labels.pt")

# load with:
# labels = torch.load("vector_representations/labels.pt")