# Inference: run the pretrained TankBind model on the KIBA data

In [None]:
import sys
tankbind_src_folder_path = "./tankbind/"
# Prepend the TankBind source folder so its modules (feature_utils, utils, model, ...)
# are found first when imported by name in the next cell.
sys.path[0:0] = [tankbind_src_folder_path]

In [None]:
# imports from tankbind
from feature_utils import get_protein_feature, get_clean_res_list, extract_torchdrug_feature_from_mol, get_canonical_smiles
from utils import construct_data_from_graph_gvp, evaulate_with_affinity, evaulate
from model import get_model
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords
from metrics import print_metrics, myMetric

# general imports
import os
import pandas as pd
import numpy as np
import torch
import logging
from tqdm import tqdm

import rdkit.Chem as Chem
from rdkit.Chem import AllChem
from Bio.PDB import PDBParser
import torchmetrics
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

import warnings
# warnings.filterwarnings("ignore") # only uncomment if appearing warnings are not relevant

## Load molecule_dict and protein_dict & kiba_data pt files

In [None]:
# Load the precomputed protein / molecule feature dictionaries and kiba_data
# (kiba_data is the complete DataFrame with the P2Rank information).
# NOTE(review): torch >= 2.6 defaults to torch.load(weights_only=True), which may refuse to
# unpickle non-tensor objects such as a DataFrame; pass weights_only=False if loading fails.
protein_dict, molecule_dict, kiba_data = (
    torch.load(pt_path)
    for pt_path in ("data/protein_dict.pt", "data/molecule_dict.pt", "data/kiba_data.pt")
)

# Dataset class + Creation

I also return the target affinities together with the model input, because some inputs may be skipped during inference (batches whose distance map is too large for memory). Returning both keeps each prediction paired with its label and in the original order.

In [None]:
class MyDataset_VS(Dataset):
    """PyG dataset yielding ``(TankBind input graph, target_affinity)`` pairs, one per row of ``data``.

    ``process()`` writes the DataFrame ``data`` and the precomputed ``protein_dict`` /
    ``molecule_dict`` to ``self.processed_paths``. ``__init__`` then always re-loads all three
    from those files, so once they exist the dataset can be built from ``root`` alone.

    Each row of ``data`` must provide ``smiles``, ``target_affinity``, ``pocket_com`` and
    ``protein_name``; ``use_whole_protein`` is optional and defaults to False.

    ``proteinMode`` and ``shake_nodes`` are stored but not used by this class.
    """

    def __init__(self, root, data=None, protein_dict=None, molecule_dict=None, proteinMode=0, compoundMode=1,
                 pocket_radius=20, shake_nodes=None,
                 transform=None, pre_transform=None, pre_filter=None):
        # Keep the in-memory inputs so process() can persist them if the base class calls it.
        self.data = data
        self.protein_dict = protein_dict
        self.molecule_dict = molecule_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        # Always read back from disk so first and later runs behave identically.
        self.data = torch.load(self.processed_paths[0])
        self.protein_dict = torch.load(self.processed_paths[1])
        self.molecule_dict = torch.load(self.processed_paths[2])
        self.proteinMode = proteinMode
        self.pocket_radius = pocket_radius
        self.compoundMode = compoundMode
        self.shake_nodes = shake_nodes

    @property
    def processed_file_names(self):
        return ['kiba_data.pt', 'protein_dict.pt', 'molecule_dict.pt']

    def process(self):
        # Save data and protein/molecule dictionaries
        torch.save(self.data, self.processed_paths[0])
        torch.save(self.protein_dict, self.processed_paths[1])
        torch.save(self.molecule_dict, self.processed_paths[2])

    def len(self):
        return len(self.data)

    @staticmethod
    def _parse_pocket_com(pocket_com):
        """Return the pocket centre of mass as a ``(1, 3)`` float array.

        Accepts a comma-separated string ``"x,y,z"`` or any array-like of three numbers
        (numpy array, list, tuple).
        """
        if isinstance(pocket_com, str):
            pocket_com = pocket_com.split(",")
        return np.asarray(pocket_com, dtype=float).reshape((1, 3))

    def get(self, idx):
        """Build the model input for row ``idx``.

        Returns:
            ``(data, target_affinity)``: the graph from ``construct_data_from_graph_gvp`` with
            ``compound_pair`` attached, plus the row's target affinity. Both are returned so
            labels stay aligned with predictions even if some batches are skipped later.

        Raises:
            ValueError: if the row's protein or SMILES has no precomputed features.
        """
        line = self.data.iloc[idx]
        smiles = line['smiles']
        target_affinity = line['target_affinity']
        pocket_com = self._parse_pocket_com(line['pocket_com'])
        use_whole_protein = line.get('use_whole_protein', False)

        protein_name = line['protein_name']
        protein_data = self.protein_dict.get(protein_name)
        if protein_data is None:
            raise ValueError(f"Protein {protein_name} not found in pre-calculated protein dictionary")

        (protein_node_xyz, protein_seq, protein_node_s, protein_node_v,
         protein_edge_index, protein_edge_s, protein_edge_v) = protein_data

        # Load precomputed molecular features
        molecule_data = self.molecule_dict.get(smiles)
        if molecule_data is None:
            raise ValueError(f"SMILES {smiles} not found in precomputed molecular dictionary")

        # Reuse the entry fetched above instead of a second dictionary lookup.
        (coords, compound_node_features, input_atom_edge_list,
         input_atom_edge_attr_list, pair_dis_distribution) = molecule_data

        data, _input_node_list, _keep_node = construct_data_from_graph_gvp(
            protein_node_xyz, protein_seq, protein_node_s, protein_node_v,
            protein_edge_index, protein_edge_s, protein_edge_v,
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
            pocket_radius=self.pocket_radius, use_whole_protein=use_whole_protein, includeDisMap=True,
            use_compound_com_as_pocket=False, chosen_pocket_com=pocket_com, compoundMode=self.compoundMode
        )
        # Rows of 16 values per compound atom pair (presumably TankBind's distance-bin count; TODO confirm).
        data.compound_pair = pair_dis_distribution.reshape(-1, 16)

        return data, target_affinity

### Create dataset instance:

In [None]:
# Directory where the processed dataset files are stored (TODO: some directory cleanup)
dataset_path = 'data/dataset'
# First run: pass the in-memory objects so they get written to the processed files.
# Later runs can use `MyDataset_VS(root=dataset_path)` instead, since __init__ re-loads
# everything from those files anyway.
dataset = MyDataset_VS(
    root=dataset_path,
    data=kiba_data,
    protein_dict=protein_dict,
    molecule_dict=molecule_dict,
)
# dataset = MyDataset_VS(root=dataset_path)

# Model testing

In [None]:
# Pick the compute device: the GPU when torch can see one, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(device)

In [None]:
batch_size = 4
# Batches whose distance map has at least this many rows are skipped to avoid running out
# of memory; their outputs are replaced by zero placeholders (see below).
max_dis_map_rows = 20000

logging.basicConfig(level=logging.INFO)
model = get_model(0, logging, device)

# self-dock model
modelFile = "/system/user/studentwork/hernler/tankbind_project_data/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()

# shuffle=False: outputs must come out in the same order as the rows of dataset.data.
data_loader = DataLoader(dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=0)
affinity_pred_list = []
y_pred_list = []  # TODO: add code to save y_preds
# One (vector_repr, y) tuple per batch; vector_repr is None for skipped batches until filled below.
vector_representations = []
skipped_batches = []  # (batch index, number of samples) for every skipped batch

# Pure inference: no autograd graph is needed, which saves memory on large batches.
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(tqdm(data_loader)):
        # Samples in this batch -- the final batch may hold fewer than batch_size.
        n_samples = x.num_graphs
        if x.dis_map.shape[0] < max_dis_map_rows:
            x = x.to(device)  # only move x to device as y is not used in the model
            y_pred, affinity_pred, vector_repr = model(x)
            affinity_pred_list.append(affinity_pred.detach().cpu())
            # assumes vector_repr has one row per sample, i.e. (n_samples, repr_dim) -- TODO confirm
            vector_representations.append((vector_repr.detach().cpu(), y))
        else:
            # Zero placeholders sized to the actual batch keep every output row aligned with its dataset row.
            affinity_pred_list.append(torch.zeros(n_samples))
            vector_representations.append((None, y))
            skipped_batches.append((batch_idx, n_samples))

if skipped_batches:
    logging.warning(
        "Skipped %d of %d batches because dis_map had >= %d rows; their affinity predictions "
        "and vector representations are zero placeholders.",
        len(skipped_batches), len(vector_representations), max_dis_map_rows,
    )

# Give placeholders the same width as the real representations; 128 is only a fallback for
# the case where every batch was skipped.
repr_dim = next((v.shape[-1] for v, _y in vector_representations if v is not None), 128)
for batch_idx, n_samples in skipped_batches:
    vector_representations[batch_idx] = (torch.zeros((n_samples, repr_dim)), vector_representations[batch_idx][1])

affinity_pred_list = torch.cat(affinity_pred_list)
# Concatenate along the sample axis -> one row per dataset row. torch.stack would fail on a
# smaller final batch and would add an extra batch dimension.
# TODO: check whether the target affinities (second tuple element) should be kept as well.
vector_representations = torch.cat([v for v, _y in vector_representations])

In [None]:
# save the affinity predictions
torch.save(affinity_pred_list, 'data/affinity_pred.pt')

# save the vector representations; torch.save does not create missing folders, so ensure it exists
os.makedirs('vector_representations', exist_ok=True)
torch.save(vector_representations, 'vector_representations/vector_representations.pt')

### Add affinity predictions to kiba dataframe

In [None]:
# kiba_df is the dataset's own DataFrame (not a copy), so the new column is also visible via dataset.data.
kiba_df = dataset.data
# Convert the (CPU) tensor explicitly instead of relying on pandas to interpret a torch tensor.
affinity_values = np.asarray(affinity_pred_list)
if len(affinity_values) != len(kiba_df):
    raise ValueError(
        f"Got {len(affinity_values)} affinity predictions for {len(kiba_df)} rows; "
        "predictions and rows are misaligned"
    )
kiba_df['affinity_pred'] = affinity_values

# save the updated kiba_df with affinity predictions
kiba_df.to_csv('data/kiba_data_with_affinity_pred.csv', index=False)

### Add vector representations to the KIBA DataFrame (only if needed, since the labels [target_affinity] are already linked with them)

In [None]:
# TODO: Check if this is needed
# vector_representations is a tensor at this point. Flatten any leading batch axes to one row
# per sample so it lines up with kiba_df; this handles both (N, D) and (num_batches, batch_size, D).
vector_rows = vector_representations.reshape(-1, vector_representations.shape[-1])
if vector_rows.shape[0] != len(kiba_df):
    raise ValueError(
        f"Got {vector_rows.shape[0]} vector representations for {len(kiba_df)} rows; "
        "representations and rows are misaligned"
    )

kiba_df['vector_repr'] = vector_rows.tolist()  # Convert tensor to list for DataFrame compatibility
# save the updated kiba_df with vector representations
# NOTE: to_csv writes each list as its string form; the .pt file keeps the numeric tensor.
os.makedirs('vector_representations', exist_ok=True)
kiba_df.to_csv('vector_representations/kiba_data_with_vector_repr.csv', index=False)

# Evaluation of the predicted affinities (needed for comparison with new model)

In [None]:
# The torchmetrics functional API expects tensors, not numpy arrays.
preds_t = torch.as_tensor(kiba_df['affinity_pred'].to_numpy(dtype=np.float32))
targets_t = torch.as_tensor(kiba_df['target_affinity'].to_numpy(dtype=np.float32))

# mean squared error (squared=False would return the *root* MSE, which was mislabelled before)
mse = torchmetrics.functional.mean_squared_error(preds_t, targets_t)
rmse = torch.sqrt(mse)

# mean absolute error
mae = torchmetrics.functional.mean_absolute_error(preds_t, targets_t)

print(f"Mean Squared Error: {mse.item()}")
print(f"Root Mean Squared Error: {rmse.item()}")
print(f"Mean Absolute Error: {mae.item()}")

Alternatively, using the older helper function:

In [None]:
def eval_metrics(preds, targets):
    """Return ``(mse, mae)`` between ``preds`` and ``targets`` as Python floats.

    ``targets`` is moved to the device and dtype of ``preds`` first. This allows, e.g.,
    GPU predictions to be compared with CPU labels.

    Raises:
        ValueError: if the shapes differ. MSELoss would otherwise broadcast with only a
            warning (e.g. ``(N,)`` vs ``(N, 1)``) and report a meaningless error.
    """
    if preds.shape != targets.shape:
        raise ValueError(
            f"preds and targets must have the same shape, got {tuple(preds.shape)} and {tuple(targets.shape)}"
        )
    targets = targets.to(device=preds.device, dtype=preds.dtype)
    criterion = torch.nn.MSELoss()
    with torch.no_grad():
        mse = criterion(preds, targets)
        mae = torch.mean(torch.abs(preds - targets))
    return mse.item(), mae.item()

In [None]:
# Plain CPU float32 tensors on both sides. Metrics need no gradients, and mixing a GPU
# tensor with a CPU tensor inside the loss would raise a device-mismatch error.
preds = torch.tensor(kiba_df['affinity_pred'].to_list(), dtype=torch.float32)
targets = torch.tensor(kiba_df['target_affinity'].to_list(), dtype=torch.float32)

mse, mae = eval_metrics(preds, targets)
print(f"Mean Squared Error: {mse}")
print(f"Mean Absolute Error: {mae}")