# Dataset creation & Model Testing

In [57]:
# imports
import sys
import os
import logging
import numpy as np
import random
import pandas as pd
import torch
from tqdm import tqdm

import rdkit.Chem as Chem
from Bio.PDB import PDBParser
import torchmetrics
from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

Add tankbind directory to System path:

In [3]:
tankbind_src_folder_path = "./tankbind/"
sys.path.insert(0, tankbind_src_folder_path)

In [52]:
# imports from tankbind
from feature_utils import get_protein_feature, get_clean_res_list, extract_torchdrug_feature_from_mol, get_canonical_smiles
from utils import construct_data_from_graph_gvp, evaulate_with_affinity, evaulate
from model import get_model
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords
from metrics import print_metrics, myMetric

In [6]:
dataset_path = f"dataset_10"
df = pd.read_csv(f'{dataset_path}/Mcule_10000.csv', index_col=0)

Get the protein names:

In [None]:
def extract_protein_names(file_path):
    protein_names = []
    with open(file_path, 'r') as file:
        for line in file:
            # Get the base name without directory path
            base_name = os.path.basename(line.strip())
            # Split the base name to get the protein name without extension
            protein_name = os.path.splitext(base_name)[0]
            protein_names.append(protein_name)
    return protein_names

file_path = 'pdb_dataset.ds'
protein_names = extract_protein_names(file_path)

# print(protein_names[0:5]. len(protein_names))

Randomly select 1/10 of the dataset:

In [None]:
n = len(protein_names) * 0.1
sampled_protein_names = random.sample(protein_names, n)

# print(sampled_protein_names[0:5], len(sampled_protein_names))

#### Get protein features:

In [None]:
def process_proteins(protein_names, pdb_directory):
    parser = PDBParser(QUIET=True)
    protein_dict = {}

    for proteinName in protein_names:
        try:
            proteinFile = f"{pdb_directory}/{proteinName}.pdb"
            s = parser.get_structure(proteinName, proteinFile)
            res_list = list(s.get_residues())
            clean_res_list = get_clean_res_list(res_list, ensure_ca_exist=True)
            protein_dict[proteinName] = get_protein_feature(clean_res_list)
        except Exception as e:
            print(f"Error processing {proteinName}: {e}")

    return protein_dict

pdb_directory = "PDB_files"  # Directory containing PDB files

protein_dict = process_proteins(sampled_protein_names, pdb_directory)

# torch.save(protein_dict, 'protein_dict.pt')   # execute only in first notebook run

In [None]:
protein_names_new = list(protein_dict.keys())  # updated protein names after processing, in case some proteins failed to process

# print(len(protein_names_new))

#### Create protein dataframe:

In [None]:
info = []
for protein_name in protein_names_new:
    for i, line in tqdm(df.iterrows(), total=len(df)):
        smiles = line['smiles']
        compund_name = ''
        protein_name = protein_name
        com = ",".join([str(x.round(3)) for x in protein_dict[protein_name][0].mean(axis=0).numpy()])
        info.append([protein_name, compund_name, smiles, "protein_center", com])

info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'smiles', 'pocket_name', 'pocket_com'])

# torch.save(info, 'data.pt')   # execute only in first notebook run

### Construct Dataset

In [7]:
torch.set_num_threads(1)

In [8]:
class MyDataset_VS(Dataset):
    def __init__(self, root, data=None, protein_dict=None, molecule_dict=None, proteinMode=0, compoundMode=1,
                 pocket_radius=20, shake_nodes=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.data = data
        self.protein_dict = protein_dict
        self.molecule_dict = molecule_dict
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data = torch.load(self.processed_paths[0])
        self.protein_dict = torch.load(self.processed_paths[1])
        self.molecule_dict = torch.load(self.processed_paths[2])
        self.proteinMode = proteinMode
        self.pocket_radius = pocket_radius
        self.compoundMode = compoundMode
        self.shake_nodes = shake_nodes

    @property
    def processed_file_names(self):
        return ['data.pt', 'proteins.pt', 'molecules.pt']

    def process(self):
        # Save data and protein dictionary
        torch.save(self.data, self.processed_paths[0])
        torch.save(self.protein_dict, self.processed_paths[1])
        torch.save(self.molecule_dict, self.processed_paths[2])

    def len(self):
        return len(self.data)

    def get(self, idx):
        line = self.data.iloc[idx]
        smiles = line['smiles']
        pocket_com = line['pocket_com']
        pocket_com = np.array(pocket_com.split(",")).astype(float) if isinstance(pocket_com, str) else pocket_com
        pocket_com = pocket_com.reshape((1, 3))
        use_whole_protein = line.get('use_whole_protein', False)

        protein_name = line['protein_name']
        protein_data = self.protein_dict.get(protein_name)
        
        if protein_data is None:
            raise ValueError(f"Protein {protein_name} not found in pre-calculated protein dictionary")

        protein_node_xyz, protein_seq, protein_node_s, protein_node_v, protein_edge_index, protein_edge_s, protein_edge_v = protein_data

        # Load precomputed molecular features
        molecule_data = self.molecule_dict.get(smiles)
        if molecule_data is None:
            raise ValueError(f"SMILES {smiles} not found in precomputed molecular dictionary")
        
        coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list, pair_dis_distribution = self.molecule_dict[smiles]

        data, input_node_list, keepNode = construct_data_from_graph_gvp(
            protein_node_xyz, protein_seq, protein_node_s, protein_node_v, 
            protein_edge_index, protein_edge_s, protein_edge_v,
            coords, compound_node_features, input_atom_edge_list, input_atom_edge_attr_list,
            pocket_radius=self.pocket_radius, use_whole_protein=use_whole_protein, includeDisMap=True,
            use_compound_com_as_pocket=False, chosen_pocket_com=pocket_com, compoundMode=self.compoundMode
        )
        data.compound_pair = pair_dis_distribution.reshape(-1, 16)
        
        return data

In [9]:
# create dataset dirrectory
# os.system(f"rm -r {dataset_path}") # only on first run
# os.system(f"mkdir -p {dataset_path}") # only on first run

Precalculate the molecule_dict s.t. the Dataset creation is no longer dependent on torchdrug:

In [None]:
# execute only in first notebook run
smiles_list = info['smiles'].to_list()
smiles_set = set(smiles_list)

molecule_dict = {}
for molecule in smiles_set:
    smiles = get_canonical_smiles(molecule)
    mol = Chem.MolFromSmiles(smiles)
    mol.Compute2DCoords()
    molecule_dict[molecule] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)

torch.save(molecule_dict, 'dataset_10/processed/molecules.pt')   

In [None]:
# only excute after running the code above once (& only if needed), otherwise FileNotFoundError
data = torch.load(f'{dataset_path}/processed/data.pt')
protein_dict = torch.load(f'{dataset_path}/processed/proteins.pt')
molecule_dict = torch.load(f'{dataset_path}/processed/molecules.pt')

#### Create dataset instance:

In [11]:
# dataset = MyDataset_VS(root=dataset_path, data=info, protein_dict=protein_dict, molecule_dict=molecule_dict) # only on first run, otherwise execute line below
dataset = MyDataset_VS(root=dataset_path)

### Model Testing

In [None]:
batch_size = 5 # higher batchsize possible only if enough memmory is available (eg.: 10)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

logging.basicConfig(level=logging.INFO)

model = get_model(0, logging, device)

# load self-dock model
modelFile = "./model/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()

data_loader = DataLoader(dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=0) # changed num_workers from 8 to 0, due to multiprocessing error
affinity_pred_list = []
for data in tqdm(data_loader):
    if data.dis_map.shape[0] < 50000: # to filter out proteins with dismap size > 10000 -> 50000 due to batch size 5, can be adjusted as needed (if it's too big, it will cause a memory error)
        data = data.to(device)
        y_pred, affinity_pred = model(data)
        affinity_pred_list.append(affinity_pred.detach().cpu())
    else:
        affinity_pred_list.append(torch.zeros(5, 1))

affinity_pred_list = torch.cat(affinity_pred_list)

In [None]:
info = dataset.data
info['affinity'] = affinity_pred_list

In [None]:
info.to_csv(f"{dataset_path}/result_info.csv")

### Choose compund with highest affinity score for each protein:

In [18]:
info = pd.read_csv(f'{dataset_path}/result_info.csv') # only execute if needed

In [20]:
chosen = info.loc[info.groupby('protein_name',sort=False)['affinity'].agg('idxmax')].reset_index()

In [27]:
chosen = chosen[chosen['affinity'] != 0]
chosen_smiles = chosen['smiles'].tolist()

chosen_data = MyDataset_VS(root=dataset_path, data=chosen) 
chosen_data.data = chosen

# print(len(chosen_data))

In [28]:
chosen

Unnamed: 0.2,index,Unnamed: 0.1,Unnamed: 0,protein_name,compound_name,smiles,pocket_name,pocket_com,affinity
0,9658,9658,9658,2R3G,,N1(N=CC(C(=O)NCC2=CC=C(C(OC)=C2)OC)=C1COC)C1N=...,protein_center,"1.924,29.866,21.385",8.819671
1,15087,15087,15087,1P4O,,C(/N1CCC(CC1)(C(N)=O)N1CCCCC1)(\NC1=NC(C)=CC(C...,protein_center,"17.63,64.754,17.18",6.751139
2,29273,29273,29273,8AO3,,N12N=C(SC1=NC(=O)C(=CC1C=C(N(C3C=CC(=CC=3)CC)C...,protein_center,"-0.879,3.457,37.96",9.114771
3,38556,38556,38556,6W4P,,C12(C3C=CC=CC=3C(C3C=CC=CC=3)(C3=CC=CC=C3)N1C1...,protein_center,"178.04,175.021,180.363",5.292430
4,44114,44114,44114,3OJM,,N1C(C2C=CC(=CC=2)C)=CSC=1NN=CC1C2C=CC=CC=2C=C2...,protein_center,"3.855,32.326,32.293",6.531941
...,...,...,...,...,...,...,...,...,...
81,816059,476059,816059,5U1M,,CN(C1SC2C(=CC(=CC=2)Cl)N=1)C(C1(C#N)CCCCC1)=O,protein_center,"8.684,19.363,14.771",8.048314
82,829833,489833,829833,7UP4,,C1(C(=O)N2CCC3C(=CC=CC=3)C2)=NN(C2CCC(CC1=2)N1...,protein_center,"27.132,49.381,62.994",4.963066
83,836655,496655,836655,1W7H,,N1(CC2=C(C)C=C(C=C2C)C)C=NC2C(=CC3=C(CCC(C)C3)...,protein_center,"21.934,25.734,41.109",8.003905
84,847459,507459,847459,6YPK,,N1([C@H]2[C@](C)(O)CCN(C(=O)CCC3N(CC)C4=C(C=CC...,protein_center,"-106.935,-186.89,310.699",7.838318


In [30]:
# get model if previous code not executed in this notebook
device = 'cuda' if torch.cuda.is_available() else 'cpu'

logging.basicConfig(level=logging.INFO)

model = get_model(0, logging, device)

# load self-dock model
modelFile = "./model/self_dock.pt"

model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()


14:28:29   5 stack, readout2, pred dis map add self attention and GVP embed, compound model GIN


In [32]:
data_loader = DataLoader(chosen_data, batch_size=1, follow_batch=['x', 'y', 'compound_pair'], shuffle=False, num_workers=0)

y_preds = []
tankbind_list = []

for i, data_with_batch_info in enumerate(tqdm(data_loader)):

    y_pred, affinity_pred = model(data_with_batch_info)

    coords = data_with_batch_info.coords.to(device)
    protein_nodes_xyz = data_with_batch_info.node_xyz.to(device)
    n_compound = coords.shape[0] 
    n_protein = protein_nodes_xyz.shape[0] 
    y_pred = y_pred.reshape(n_protein, n_compound).to(device).detach()
    y_preds.append(y_pred)
    y = data_with_batch_info.dis_map.reshape(n_protein, n_compound).to(device)
    compound_pair_dis_constraint = torch.cdist(coords, coords)

    # Extract SMILES and generate 2D coordinates
    smiles = chosen_smiles[i]  
    mol = Chem.MolFromSmiles(smiles)
    mol.Compute2DCoords()

    # Compute LAS distance constraint mask
    LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()

    # Calculate information
    info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
                                  LAS_distance_constraint_mask=LAS_distance_constraint_mask,
                                  n_repeat=1, show_progress=False)

    # Save to file
    toFile = f'{dataset_path}/chosen/KIBA_tankbind_{i}.sdf' # instead of i use 'chosen.iloc(i)['protein_name']' as file name
    new_coords = info.sort_values("loss")['coords'].iloc[0].astype(np.double)
    write_with_new_coords(mol, new_coords, toFile)
    
    tankbind_list.append([mol, new_coords])


100%|██████████| 85/85 [08:43<00:00,  6.16s/it]


In [None]:
output_file = 'KIBA_tankbind.sdf'

# Write each molecule with its new coordinates to a temporary file and append to the final SDF file
with open(output_file, 'w') as final_sdf:
    for i, (mol, new_coords) in enumerate(tankbind_list):
        temp_file = f'temp_{i}.sdf'
        write_with_new_coords(mol, new_coords, temp_file)
        with open(temp_file, 'r') as temp_sdf:
            final_sdf.write(temp_sdf.read())
        os.remove(temp_file)

### Evaluation

In [55]:
def evaluate(data_loader, model, criterion, device, saveFileName=None):
    y_list = []
    y_pred_list = []
    batch_loss = 0.0
    for data in tqdm(data_loader):
        data = data.to(device)
        y_pred, _ = model(data)
        with torch.no_grad():
            loss = criterion(y_pred, data.y)
        batch_loss += len(y_pred)*loss.item()
        y_list.append(data.y)
        y_pred_list.append(y_pred.sigmoid().detach())
        # torch.cuda.empty_cache()
    y = torch.cat(y_list)
    y_pred = torch.cat(y_pred_list)
    metrics = {"loss":batch_loss/len(y_pred)}
    metrics.update(myMetric(y_pred, y))
    if saveFileName:
        torch.save((y, y_pred), saveFileName)
    return metrics

In [78]:
def myMetric(y_pred, y, threshold=0.5):
    y = y.float()
    criterion = torch.nn.BCELoss()
    with torch.no_grad():
        loss = criterion(y_pred, y)

    # y = y.long()
    y = y.bool()
    acc = torchmetrics.functional.accuracy(y_pred, y, task='binary', threshold=threshold)
    auroc = torchmetrics.functional.auroc(y_pred, y, task='binary')
    precision_0, precision_1 = torchmetrics.functional.precision(y_pred, y, task='multiclass',
                                      num_classes=2,
                                      average='none', threshold=threshold)
    recall_0, recall_1 = torchmetrics.functional.recall(y_pred, y, task='multiclass',
                                      num_classes=2,
                                      average='none', threshold=threshold)
    f1_0, f1_1 = torchmetrics.functional.f1_score(y_pred, y, task='multiclass',
                                      num_classes=2,
                                      average='none', threshold=threshold)
    return {"BCEloss":loss.item(),
            "acc":acc, "auroc":auroc, "precision_1":precision_1,
           "recall_1":recall_1, "f1_1":f1_1,"precision_0":precision_0,
           "recall_0":recall_0, "f1_0":f1_0}

In [81]:
saveFileName = f"{dataset_path}/chosen_metrics.pt"
metrics = evaluate(data_loader, model, torch.nn.MSELoss(), device, saveFileName=saveFileName)

100%|██████████| 85/85 [01:52<00:00,  1.33s/it]


15:35:19   epoch 0   , test,  loss:95.865, BCEloss: 9.708, acc: 0.006, auroc: 0.629, precision_1: 0.000, recall_1: 0.000, f1_1: 0.000, precision_0: 0.994, recall_0: 1.000, f1_0: 0.997


In [82]:
print_metrics(metrics)

'loss:95.865, BCEloss: 9.708, acc: 0.006, auroc: 0.629, precision_1: 0.000, recall_1: 0.000, f1_1: 0.000, precision_0: 0.994, recall_0: 1.000, f1_0: 0.997'