# overview

We start from the raw PDBbind dataset downloaded from http://www.pdbbind.org.cn/download.php

1. filter out those unable to process using RDKit.

2. Process the protein by only preserving the chains that with at least one atom within 10Å from any atom of the ligand.

3. Use p2rank to segment protein into blocks.

4. extract protein and ligand features.

5. construct the training and test dataset.


In [1]:
import time
from datetime import datetime as dt
def tt():
    print(time.strftime("%m-%d %H:%M:%S", time.localtime(time.time()+8*3600)))

In [2]:
CTN = True
tt()

10-10 09:59:26


In [3]:
tankbind_src_folder_path = "../tankbind/"
import sys
import torch
sys.path.insert(0, tankbind_src_folder_path)
tt()

10-10 09:59:28


In [4]:
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import defaultdict
tt()

10-10 09:59:28


# process the raw PDBbind dataset.

In [5]:
from utils import read_pdbbind_data
tt()

10-10 09:59:30


In [6]:
# raw PDBbind dataset could be downloaded from http://www.pdbbind.org.cn/download.php
pre = "/home/jovyan/data/pdbbind2020"
df_pdb_id = pd.read_csv(f'{pre}/index/INDEX_general_PL_name.2020', sep="  ", comment='#', header=None, names=['pdb', 'year', 'uid', 'd', 'e','f','g','h','i','j','k','l','m','n','o'], engine='python')
df_pdb_id = df_pdb_id[['pdb','uid']]
data = read_pdbbind_data(f'{pre}/index/INDEX_general_PL_data.2020')
data = data.merge(df_pdb_id, on=['pdb'])
tt()


10-10 09:59:33


In [10]:
ctn_path = "/home/jovyan/dataspace/NFT/main/CTNs"
tt()

10-10 10:15:28


# ligand file should be readable by RDKit.

In [7]:
from feature_utils import read_mol
tt()

10-10 09:59:35


In [12]:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
pdb_list = []
probem_list = []
if not CTN:
    for pdb in tqdm(data.pdb):
        sdf_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.sdf"
        mol2_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.mol2"
        mol, problem = read_mol(sdf_fileName, mol2_fileName)
        if problem:
            probem_list.append(pdb)
            continue
        pdb_list.append(pdb)
    torch.save(pdb_list, f"{ctn_path}/pdb_list.pt")
pdb_list = torch.load(f"{ctn_path}/pdb_list.pt")
tt()

10-10 10:15:44


In [13]:
data = data.query("pdb in @pdb_list").reset_index(drop=True)
print(data.shape)
tt()

(19127, 7)
10-10 10:15:47


### for ease of RMSD evaluation later, we renumber the atom index to be consistent with the smiles

In [14]:
from feature_utils import write_renumbered_sdf
pre_main = "/home/jovyan/dataspace/NFT/main"
toFolder = f"{pre_main}/renumber_atom_index_same_as_smiles"
os.system(f"mkdir -p {toFolder}")
# for PDB_Code in lig_id_dict.keys(), the sdf in NCI dataset will be used.
lig_id_dict = torch.load("../../dataspace/NFT/main/ligand_name_dicts/ligdict_id.pt")
tt()

10-10 10:15:51


In [17]:
if CTN:
    for pdb in tqdm(pdb_list):
        if pdb not in lig_id_dict:  # Use PDBBind dataset sdf
            sdf_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.sdf"
            mol2_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.mol2"
            toFile = f"{toFolder}/{pdb}.sdf"
            write_renumbered_sdf(toFile, sdf_fileName, mol2_fileName)
        else:  # Use NCI dataset sdf
            sdf_fileName = f"/home/jovyan/dataspace/pdb_bind_2020/{pdb}/{pdb}_ligand.sdf"
            toFile = f"{toFolder}/{pdb}.sdf"
            write_renumbered_sdf(toFile, sdf_fileName, None)
tt()

100%|██████████| 19127/19127 [14:15<00:00, 22.36it/s]

10-10 10:30:38





# process PDBbind proteins, removing extra chains, cutoff 10A

In [33]:
toFolder = f"{pre_main}/protein_remove_extra_chains_10A/"
os.system(f"mkdir -p {toFolder}")
tt()

10-10 10:33:26


In [19]:
input_ = []
cutoff = 10
for pdb in data.pdb.values:
    pdbFile = f"{pre}/pdbbind_files/{pdb}/{pdb}_protein.pdb"
    ligandFile = f"{pre_main}/renumber_atom_index_same_as_smiles/{pdb}.sdf"
    toFile = f"{toFolder}/{pdb}_protein.pdb"
    x = (pdbFile, ligandFile, cutoff, toFile)
    input_.append(x)
tt()

10-10 10:30:38


In [20]:
from feature_utils import select_chain_within_cutoff_to_ligand_v2
tt()

10-10 10:30:39


In [21]:
if not CTN:
    import mlcrate as mlc
    import os
    pool = mlc.SuperPool(64)
    pool.pool.restart()
    _ = pool.map(select_chain_within_cutoff_to_ligand_v2,input_)
    pool.exit()
tt()

10-10 10:30:39


In [22]:
# previously, I found that 2r1w has no chain near the ligand.
data = data.query("pdb != '2r1w'").reset_index(drop=True)
tt()

10-10 10:30:39


# p2rank segmentation

In [23]:
if not CTN:
    p2rank_prediction_folder = f"{pre_main}/p2rank_protein_remove_extra_chains_10A"
    os.system(f"mkdir -p {p2rank_prediction_folder}")
    ds = f"{p2rank_prediction_folder}/protein_list.ds"
    with open(ds, "w") as out:
        for pdb in data.pdb.values:
            out.write(f"../protein_remove_extra_chains_10A/{pdb}_protein.pdb\n")
tt()

10-10 10:30:39


In [24]:
if not CTN:
    # takes about 30 minutes.
    p2rank = "bash /home/jovyan/p2rank_2.3/prank"
    cmd = f"{p2rank} predict {ds} -o {p2rank_prediction_folder}/p2rank -threads 16"
    os.system(cmd)
tt()

10-10 10:30:40


In [25]:
if not CTN:
    data.to_csv(f"{pre_main}/data.csv")
tt()

10-10 10:30:40


### Continue From Here

In [26]:
data = pd.read_csv(f"{pre_main}/data.csv")
print(data.shape)
tt()

(19126, 8)
10-10 10:30:40


In [27]:
pdb_list = data.pdb.values
tt()

10-10 10:30:40


In [28]:
tankbind_data_path = f"{pre_main}/tankbind_data"
os.system(f"mkdir -p {tankbind_data_path}")

0

In [29]:
name_list = pdb_list
d_list = []

if not CTN:
    for name in tqdm(name_list):
        p2rankFile = f"{pre_main}/p2rank_protein_remove_extra_chains_10A/p2rank/{name}_protein.pdb_predictions.csv"
        d = pd.read_csv(p2rankFile)
        d.columns = d.columns.str.strip()
        d_list.append(d.assign(name=name))
    d = pd.concat(d_list).reset_index(drop=True)
    d.reset_index(drop=True).to_feather(f"{tankbind_data_path}/p2rank_result.feather")

In [30]:
d = pd.read_feather(f"{tankbind_data_path}/p2rank_result.feather")
tt()

10-10 10:30:41


In [31]:
pockets_dict = {}
for name in tqdm(name_list):
    pockets_dict[name] = d[d.name == name].reset_index(drop=True)
tt()

100%|██████████| 19126/19126 [02:41<00:00, 118.11it/s]

10-10 10:33:23





In [34]:
if not CTN:
    torch.save(pockets_dict, f"{tankbind_data_path}/pocket_dict.pt")
pockets_dict = torch.load(f"{tankbind_data_path}/pocket_dict.pt")
tt()

10-10 10:35:17


# protein feature

In [35]:
from feature_utils import get_protein_feature
from feature_utils import nciyes_get_protein_feature
tt()

10-10 10:35:20


In [36]:
input_ = []
protein_embedding_folder = f"{tankbind_data_path}/gvp_protein_embedding"
os.system(f"mkdir -p {protein_embedding_folder}")
for pdb in pdb_list:
    proteinFile = f"{pre_main}/protein_remove_extra_chains_10A/{pdb}_protein.pdb"
    toFile = [f"{protein_embedding_folder}/{pdb}.pt", f"{protein_embedding_folder}/{pdb}_id.pt"]
    x = (pdb, proteinFile, toFile)
    input_.append(x)
tt()

10-10 10:35:24


In [37]:
def get_full_id(full_id_ls: list, resname):
    chain_id = full_id_ls[2]
    res_id = full_id_ls[3][1]
    return chain_id + "_" + str(res_id) + "_" + resname
tt()

10-10 10:35:30


In [38]:
from Bio.PDB import PDBParser
from feature_utils import get_clean_res_list
import torch
torch.set_num_threads(1)

def batch_run(x):
    protein_dict = {}
    protein_id_dict = {}
    pdb, proteinFile, toFile = x
    parser = PDBParser(QUIET=True)
    s = parser.get_structure(pdb, proteinFile)
    res_list = get_clean_res_list(s.get_residues(), verbose=False, ensure_ca_exist=True)
    res_full_id_list = [get_full_id(x.full_id, x.get_resname()) for x in res_list]
    protein_dict[pdb], protein_id_dict[pdb] = nciyes_get_protein_feature(res_list, res_full_id_list)
    torch.save(protein_dict, toFile[0])
    torch.save(protein_id_dict, toFile[1])
tt()

10-10 10:35:35


In [39]:
if not CTN:
    import mlcrate as mlc
    import os
    pool = mlc.SuperPool(64)
    pool.pool.restart()
    _ = pool.map(batch_run,input_)
    pool.exit()
tt()

10-10 10:35:39


In [40]:
import torch
protein_dict = {}
protein_id_dict = {}
for pdb in tqdm(pdb_list):
    protein_dict.update(torch.load(f"{protein_embedding_folder}/{pdb}.pt"))
    protein_id_dict.update(torch.load(f"{protein_embedding_folder}/{pdb}_id.pt"))
tt()

100%|██████████| 19126/19126 [12:50<00:00, 24.81it/s]

10-10 10:48:38





# Compound Features

In [41]:
from feature_utils import extract_torchdrug_feature_from_mol
if not CTN:
    compound_dict = {}
    skip_pdb_list = []
    for pdb in tqdm(pdb_list):
        mol, _ = read_mol(f"{pre_main}/renumber_atom_index_same_as_smiles/{pdb}.sdf", None)
        # extract features from sdf.
        try:
            compound_dict[pdb] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)  # self-dock set has_LAS_mask to true
        except Exception as e:
            print(e)
            skip_pdb_list.append(pdb)
            print(pdb)
    torch.save(compound_dict, f"{tankbind_data_path}/compound_torchdrug_features.pt")
tt()

 28%|██▊       | 5270/19126 [03:53<13:12, 17.48it/s]


3kqs


100%|██████████| 19126/19126 [15:41<00:00, 20.32it/s]

10-10 11:04:22





In [43]:
compound_dict = torch.load(f"{tankbind_data_path}/compound_torchdrug_features.pt")
skip_pdb_list = ['3kqs']
tt()

In [45]:
data = data.query("pdb not in @skip_pdb_list").reset_index(drop=True)
tt()

10-10 11:05:54


### NCI Features

In [46]:
lig_nameid_dict = torch.load("../../dataspace/NFT/main/ligand_name_dicts/ligdict_nameid.pt")
lig_id_dict = torch.load("../../dataspace/NFT/main/ligand_name_dicts/ligdict_id.pt")
NCI_df = pd.read_csv("/home/jovyan/dataspace/processed_nci_data/v10.8/Data.v10.8.NCIs.csv")
tt()

10-10 11:06:04


In [52]:
# for pdb in tqdm(pdb_list):
from collections import defaultdict
from tqdm import tqdm
if not CTN:
    nci_info = {}
    nci_save_path = "/home/jovyan/dataspace/NFT/main/nci_protein_ligand_matrix"
    # os.system(f"rm -rf {nci_save_path}")
    os.system(f"mkdir -p {nci_save_path}")
    for pdb in tqdm(pdb_list):
        #print(pdb)
        if True:
            if (pdb not in lig_nameid_dict) and (pdb not in lig_id_dict):
                nci_info[pdb] = "BAD_NoLigName"
                continue
                
            
            # Lig
            mol, _ = read_mol(f"{pre_main}/renumber_atom_index_same_as_smiles/{pdb}.sdf", None)
            
            # Protein
            proteinFile = f"{pre_main}/protein_remove_extra_chains_10A/{pdb}_protein.pdb"
            toFile = [f"{protein_embedding_folder}/{pdb}.pt", f"{protein_embedding_folder}/{pdb}_id.pt"]
            ebd = torch.load(toFile[0])[pdb]
            protein_ids = torch.load(toFile[1])[pdb]
            
            # inverse protein_ids:
            protein_id_inverse = defaultdict(list)
            for _id, _name in protein_ids.items():
                protein_id_inverse[_name[0]].append(_id)
            
            protein_id_inverse = {k:v for k, v in protein_id_inverse.items()}
            
            # inverse lig_ids:
            
            if pdb in lig_nameid_dict:
                liglen = len(lig_nameid_dict[pdb])
                lig_id_inverse = {_value[1]:_key for _key, _value in lig_nameid_dict[pdb].items()}
            elif pdb in lig_id_dict:
                liglen = len(lig_id_dict[pdb])
                lig_id_inverse = {_value:_key for _key, _value in lig_id_dict[pdb].items()}
                
            ncis = NCI_df[NCI_df.PDB_Code==pdb]
            if len(ncis) == 0:
                nci_info[pdb] = "BAD_EmptyNCI"
            
            pl_0 = ebd[0].shape[0]
            pl_1 = len(protein_ids)
            ll_0 = mol.GetNumAtoms()
            ll_1 = liglen
            if pl_0 != pl_1:
                nci_info[pdb] = "BAD_ResNum"
                continue
            elif ll_0 != ll_1:
                nci_info[pdb] = "BAD_AtomNum"
                continue
            
            nci_matrix = torch.zeros(pl_0, ll_0)
            nci_matrix_0 = torch.zeros(pl_0, ll_0)
            
            if pdb in nci_info:
                torch.save({pdb:nci_matrix_0}, f"{nci_save_path}/{pdb}_nci_matrix.pt")
                continue
            
            for line in ncis.iterrows():
                ResFullID, l_id = line[1]['ResFullID'], line[1]['l_id']
                
                try:
                    p_index = protein_id_inverse[ResFullID]
                except:
                    nci_info[pdb] = "BAD_ResKey"
                    torch.save({pdb:nci_matrix_0}, f"{nci_save_path}/{pdb}_nci_matrix.pt")
                    break
                try:
                    l_index = lig_id_inverse[l_id]
                except:
                    nci_info[pdb] = "BAD_AtomKey"
                    torch.save({pdb:nci_matrix_0}, f"{nci_save_path}/{pdb}_nci_matrix.pt")
                    break
                
                if len(p_index) > 1:
                    nci_info[pdb] = "BAD_ResIncoherant"
                    torch.save({pdb:nci_matrix_0}, f"{nci_save_path}/{pdb}_nci_matrix.pt")
                    break
                else:
                    nci_matrix[p_index[0]][l_index] = 1
                    
            if pdb in nci_info:
                continue
            else:
                nci_info[pdb] = "YES"
                torch.save({pdb:nci_matrix}, f"{nci_save_path}/{pdb}_nci_matrix.pt")
    torch.save(nci_info, "/home/jovyan/dataspace/NFT/main/nci_info.pt")
tt()

10-10 11:56:44


In [53]:
for k, v in nci_info.items():
    if v == "BAD_ResNum" or v == "BAD_AtomNum":
        print(k)
tt()

10-10 11:56:46


In [55]:
nci_info = torch.load("/home/jovyan/dataspace/NFT/main/nci_info.pt")
tt()

10-10 11:56:58


In [56]:
nci_info_stat = defaultdict(int)
for k,v in nci_info.items():
    nci_info_stat[v] += 1
nci_info_stat

defaultdict(int,
            {'YES': 10727,
             'BAD_NoLigName': 5680,
             'BAD_EmptyNCI': 2597,
             'BAD_ResKey': 119,
             'BAD_ResIncoherant': 3})

In [60]:
nci_dict = {}
if not CTN:
    for pdb in tqdm(pdb_list):
        if pdb not in nci_info or nci_info[pdb] != "YES":
            continue
        else:
            nci_dict.update(torch.load(f"{nci_save_path}/{pdb}_nci_matrix.pt"))
    torch.save(nci_dict, f"{tankbind_data_path}/nci_matrix.pt")
tt()

10-10 12:04:32


In [61]:
nci_dict = torch.load(f"{tankbind_data_path}/nci_matrix.pt")
tt()

10-10 12:04:36


# construct dataset.

In [63]:
# we use the time-split defined in EquiBind paper.
# https://github.com/HannesStark/EquiBind/tree/main/data
valid = np.loadtxt("/home/jovyan/data/Equiband/timesplit_no_lig_overlap_val", dtype=str)
test = np.loadtxt("/home/jovyan/data/Equiband/timesplit_test", dtype=str)
def assign_group(pdb, valid=valid, test=test):
    if pdb in valid:
        return 'valid'
    if pdb in test:
        return 'test'
    return 'train'

data['group'] = data.pdb.map(assign_group)
tt()

10-10 12:04:43


In [65]:
print(data.value_counts("group"))
tt()

group
train    17794
valid      968
test       363
dtype: int64
10-10 12:04:50


In [67]:
data['name'] = data['pdb']
tt()

10-10 12:04:57


In [72]:
if CTN:
    info = []
    for i, line in tqdm(data.iterrows(), total=data.shape[0]):
        pdb = line['pdb']
        uid = line['uid']
        # smiles = line['smiles']
        smiles = ""
        affinity = line['affinity']
        group = line['group']

        compound_name = line['name']
        protein_name = line['name']

        pocket = pockets_dict[pdb].head(10)
        pocket.columns = pocket.columns.str.strip()
        pocket_coms = pocket[['center_x', 'center_y', 'center_z']].values
        
        has_nci = (nci_info[protein_name] == "YES")
        compound_sdf_from_nci = (protein_name in lig_id_dict.keys())
        # native block.
        info.append([protein_name, compound_name, pdb, smiles, affinity, uid, None, True, False, group, has_nci, compound_sdf_from_nci])
        # protein center as a block.
        protein_com = protein_dict[protein_name][0].numpy().mean(axis=0).astype(float).reshape(1, 3)
        info.append([protein_name, compound_name, pdb+"_c", smiles, affinity, uid, protein_com, False, False, group, has_nci, compound_sdf_from_nci])
        
        for idx, pocket_line in pocket.iterrows():
            pdb_idx = f"{pdb}_{idx}"
            info.append([protein_name, compound_name, pdb_idx, smiles, affinity, uid, pocket_coms[idx].reshape(1, 3), False, False, group, has_nci, compound_sdf_from_nci])
    info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'pdb', 'smiles', 'affinity', 'uid', 'pocket_com', 
                                    'use_compound_com', 'use_whole_protein',
                                    'group', 'has_nci_info', 'compound_sdf_from_nci'])
    info.to_csv(f"{tankbind_data_path}/info.csv")
info = pd.read_csv(f"{tankbind_data_path}/info.csv")
tt()



100%|██████████| 19125/19125 [00:22<00:00, 866.22it/s]


10-10 12:15:28


In [75]:
info.shape
tt()

10-10 12:23:53


In [76]:
info
tt()

10-10 12:23:56


In [43]:
from data_nci import 

In [None]:
toFilePre = f"{pre}/dataset"
os.system(f"mkdir -p {toFilePre}")
dataset = TankBindDataSet(toFilePre, data=info, protein_dict=protein_dict, compound_dict=compound_dict)

In [None]:
dataset = TankBindDataSet(toFilePre)


In [45]:
t = []
data = dataset.data
pre_pdb = None
for i, line in tqdm(data.iterrows(), total=data.shape[0]):
    pdb = line['compound_name']
    d = dataset[i]
    p_length = d['node_xyz'].shape[0]
    c_length = d['coords'].shape[0]
    y_length = d['y'].shape[0]
    num_contact = (d.y > 0).sum()
    t.append([i, pdb, p_length, c_length, y_length, num_contact])



100%|██████████| 161940/161940 [11:33<00:00, 233.55it/s]


In [38]:
# data = data.drop(['p_length', 'c_length', 'y_length', 'num_contact'], axis=1)

In [46]:
t = pd.DataFrame(t, columns=['index', 'pdb' ,'p_length', 'c_length', 'y_length', 'num_contact'])
t['num_contact'] = t['num_contact'].apply(lambda x: x.item())

In [47]:
data = pd.concat([data, t[['p_length', 'c_length', 'y_length', 'num_contact']]], axis=1)

In [48]:
native_num_contact = data.query("use_compound_com").set_index("protein_name")['num_contact'].to_dict()
data['native_num_contact'] = data.protein_name.map(native_num_contact)
# data['fract_of_native_contact'] = data['num_contact'] / data['native_num_contact']

In [49]:
torch.save(data, f"{toFilePre}/processed/data.pt")

In [51]:
info = torch.load(f"{toFilePre}/processed/data.pt")


In [52]:
test = info.query("group == 'test'").reset_index(drop=True)
test_pdb_list = info.query("group == 'test'").protein_name.unique()

In [53]:
subset_protein_dict = {}
for pdb in tqdm(test_pdb_list):
    subset_protein_dict[pdb] = protein_dict[pdb]

100%|██████████| 363/363 [00:00<00:00, 251866.39it/s]


In [54]:
subset_compound_dict = {}
for pdb in tqdm(test_pdb_list):
    subset_compound_dict[pdb] = compound_dict[pdb]

100%|██████████| 363/363 [00:00<00:00, 182208.28it/s]


In [None]:

toFilePre = f"{pre}/test_dataset"
os.system(f"mkdir -p {toFilePre}")
dataset = TankBindDataSet(toFilePre, data=test, protein_dict=subset_protein_dict, compound_dict=subset_compound_dict)

In [None]:
def canonical_smiles(smiles):
    return Chem.MolToSmiles(Chem.MolFromSmiles(smiles))