# overview

We start from the raw PDBbind dataset downloaded from http://www.pdbbind.org.cn/download.php

1. Filter out complexes whose ligand files cannot be processed by RDKit.

2. Process each protein, preserving only the chains that have at least one atom within 10 Å of any ligand atom.

3. Use p2rank to segment each protein into blocks (predicted pockets).

4. Extract protein and ligand features.

5. Construct the training and test datasets.


In [1]:
CTN = True

In [2]:
tankbind_src_folder_path = "../tankbind/"
import sys
import torch
sys.path.insert(0, tankbind_src_folder_path)

In [3]:
import os
from collections import defaultdict

import numpy as np
import pandas as pd
from tqdm import tqdm

# process the raw PDBbind dataset.

In [4]:
from utils import read_pdbbind_data

In [5]:
# raw PDBbind dataset could be downloaded from http://www.pdbbind.org.cn/download.php
pre = "/home/jovyan/data/pdbbind2020"
df_pdb_id = pd.read_csv(f'{pre}/index/INDEX_general_PL_name.2020', sep="  ", comment='#', header=None, names=['pdb', 'year', 'uid', 'd', 'e','f','g','h','i','j','k','l','m','n','o'], engine='python')
df_pdb_id = df_pdb_id[['pdb','uid']]
data = read_pdbbind_data(f'{pre}/index/INDEX_general_PL_data.2020')
data = data.merge(df_pdb_id, on=['pdb'])


# ligand file should be readable by RDKit.

In [6]:
from feature_utils import read_mol

In [7]:
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
pdb_list = []
probem_list = []
for pdb in tqdm(data.pdb):
    sdf_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.sdf"
    mol2_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.mol2"
    mol, problem = read_mol(sdf_fileName, mol2_fileName)
    if problem:
        probem_list.append(pdb)
        continue
    pdb_list.append(pdb)

100%|██████████| 19442/19442 [00:52<00:00, 372.17it/s]


In [8]:
data = data.query("pdb in @pdb_list").reset_index(drop=True)

In [9]:
data.shape

(19127, 7)

### For ease of later RMSD evaluation, we renumber the ligand atom indices to be consistent with the SMILES

In [10]:
from feature_utils import write_renumbered_sdf
pre_main = "/home/jovyan/dataspace/NFT/main"
toFolder = f"{pre_main}/renumber_atom_index_same_as_smiles"
os.system(f"mkdir -p {toFolder}")

0

In [11]:
if not CTN:
    
    for pdb in tqdm(pdb_list):
        sdf_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.sdf"
        mol2_fileName = f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.mol2"
        toFile = f"{toFolder}/{pdb}.sdf"
        write_renumbered_sdf(toFile, sdf_fileName, mol2_fileName)


In [46]:
from feature_utils import write_renumbered_sdf
pre_main = "/home/jovyan/dataspace/NFT/main"
toFolder = f"{pre_main}/renumber_atom_index_same_as_smiles_from_NCI_sdf"
os.system(f"mkdir -p {toFolder}")

0

In [53]:
if False:
    import shutil
    os.system("mkdir -p /home/jovyan/dataspace/Predownloads/WZCase2")
    os.system("mkdir -p /home/jovyan/dataspace/Predownloads/PDBCase2")
    for pdb in lig_id_dict.keys():
        shutil.copy(f"/home/jovyan/dataspace/pdb_bind_2020/{pdb}/{pdb}_ligand.sdf", f"/home/jovyan/dataspace/Predownloads/WZCase2/{pdb}_NCI.sdf")
        shutil.copy(f"{pre}/pdbbind_files/{pdb}/{pdb}_ligand.sdf", f"/home/jovyan/dataspace/Predownloads/PDBCase2/{pdb}_PDBBind.sdf")

# process PDBbind proteins, removing extra chains, cutoff 10A

In [12]:
toFolder = f"{pre_main}/protein_remove_extra_chains_10A/"
os.system(f"mkdir -p {toFolder}")

0

In [13]:
input_ = []
cutoff = 10
for pdb in data.pdb.values:
    pdbFile = f"{pre}/pdbbind_files/{pdb}/{pdb}_protein.pdb"
    ligandFile = f"{pre_main}/renumber_atom_index_same_as_smiles/{pdb}.sdf"
    toFile = f"{toFolder}/{pdb}_protein.pdb"
    x = (pdbFile, ligandFile, cutoff, toFile)
    input_.append(x)

In [14]:
from feature_utils import select_chain_within_cutoff_to_ligand_v2

In [15]:
if not CTN:
    import mlcrate as mlc
    import os
    pool = mlc.SuperPool(64)
    pool.pool.restart()
    _ = pool.map(select_chain_within_cutoff_to_ligand_v2,input_)
    pool.exit()

In [16]:
# previously, I found that 2r1w has no chain near the ligand.
data = data.query("pdb != '2r1w'").reset_index(drop=True)

# p2rank segmentation

In [17]:
if not CTN:
    p2rank_prediction_folder = f"{pre_main}/p2rank_protein_remove_extra_chains_10A"
    os.system(f"mkdir -p {p2rank_prediction_folder}")
    ds = f"{p2rank_prediction_folder}/protein_list.ds"
    with open(ds, "w") as out:
        for pdb in data.pdb.values:
            out.write(f"../protein_remove_extra_chains_10A/{pdb}_protein.pdb\n")

In [18]:
if not CTN:
    # takes about 30 minutes.
    p2rank = "bash /home/jovyan/p2rank_2.3/prank"
    cmd = f"{p2rank} predict {ds} -o {p2rank_prediction_folder}/p2rank -threads 16"
    os.system(cmd)

In [19]:
if not CTN:
    data.to_csv(f"{pre_main}/data.csv")

### Continue From Here

In [20]:
data = pd.read_csv(f"{pre_main}/data.csv")

In [21]:
pdb_list = data.pdb.values

In [22]:
tankbind_data_path = f"{pre_main}/tankbind_data"
os.system(f"mkdir -p {tankbind_data_path}")

0

In [23]:
name_list = pdb_list
d_list = []

if not CTN:
    for name in tqdm(name_list):
        p2rankFile = f"{pre_main}/p2rank_protein_remove_extra_chains_10A/p2rank/{name}_protein.pdb_predictions.csv"
        d = pd.read_csv(p2rankFile)
        d.columns = d.columns.str.strip()
        d_list.append(d.assign(name=name))
    d = pd.concat(d_list).reset_index(drop=True)
    d.reset_index(drop=True).to_feather(f"{tankbind_data_path}/p2rank_result.feather")

In [24]:
d = pd.read_feather(f"{tankbind_data_path}/p2rank_result.feather")

In [25]:
pockets_dict = {}
for name in tqdm(name_list):
    pockets_dict[name] = d[d.name == name].reset_index(drop=True)



100%|██████████| 19126/19126 [02:30<00:00, 126.79it/s]


In [26]:
if not CTN:
    torch.save(pockets_dict, f"{tankbind_data_path}/pocket_dict.pt")
    
    
pockets_dict = torch.load(f"{tankbind_data_path}/pocket_dict.pt")

# protein feature

In [27]:
from feature_utils import get_protein_feature
from feature_utils import nciyes_get_protein_feature


In [28]:
input_ = []
protein_embedding_folder = f"{tankbind_data_path}/gvp_protein_embedding"
os.system(f"mkdir -p {protein_embedding_folder}")
for pdb in pdb_list:
    proteinFile = f"{pre_main}/protein_remove_extra_chains_10A/{pdb}_protein.pdb"
    toFile = [f"{protein_embedding_folder}/{pdb}.pt", f"{protein_embedding_folder}/{pdb}_id.pt"]
    x = (pdb, proteinFile, toFile)
    input_.append(x)

In [29]:
def get_full_id(full_id_ls: list, resname):
    """Build the residue key ``"<chain>_<resseq>_<resname>"``.

    ``full_id_ls`` is indexed like a Bio.PDB residue full id: element 2 is
    the chain id, and element 3 is the residue id, whose element 1 is the
    residue sequence number. Any insertion code is not part of the key.
    The key format has to match the ``ResFullID`` column of the NCI table
    that these keys are later looked up against.
    """
    chain = full_id_ls[2]
    resseq = full_id_ls[3][1]
    return "_".join([chain, str(resseq), resname])

In [30]:
from Bio.PDB import PDBParser
from feature_utils import get_clean_res_list
import torch
torch.set_num_threads(1)

def batch_run(x):
    """Extract protein features and residue keys for one structure and save them.

    Args:
        x: a ``(pdb, proteinFile, toFile)`` tuple. ``pdb`` is the PDB code,
            ``proteinFile`` the path of the chain-filtered protein PDB file,
            and ``toFile`` a pair ``[features_path, ids_path]``.

    Side effects:
        Writes ``{pdb: features}`` to ``toFile[0]`` and ``{pdb: ids}`` to
        ``toFile[1]`` with ``torch.save``. Both values are the two outputs of
        ``nciyes_get_protein_feature``.
    """
    pdb, proteinFile, toFile = x
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure(pdb, proteinFile)
    res_list = get_clean_res_list(structure.get_residues(), verbose=False, ensure_ca_exist=True)
    # Use the documented get_full_id() accessor: it computes and caches the id.
    # The bare ``.full_id`` attribute is only that cache and, depending on the
    # Biopython version, may still be None when read directly.
    res_full_id_list = [get_full_id(res.get_full_id(), res.get_resname()) for res in res_list]
    features, ids = nciyes_get_protein_feature(res_list, res_full_id_list)
    torch.save({pdb: features}, toFile[0])
    torch.save({pdb: ids}, toFile[1])

In [31]:
if not CTN:
    import mlcrate as mlc
    import os
    pool = mlc.SuperPool(64)
    pool.pool.restart()
    _ = pool.map(batch_run,input_)
    pool.exit()

In [32]:
import torch
protein_dict = {}
protein_id_dict = {}
for pdb in tqdm(pdb_list):
    protein_dict.update(torch.load(f"{protein_embedding_folder}/{pdb}.pt"))
    protein_id_dict.update(torch.load(f"{protein_embedding_folder}/{pdb}_id.pt")) 

100%|██████████| 19126/19126 [01:36<00:00, 199.03it/s]


# Compound Features

In [33]:
from feature_utils import extract_torchdrug_feature_from_mol


if not CTN:
    compound_dict = {}
    skip_pdb_list = []
    for pdb in tqdm(pdb_list):
        mol, _ = read_mol(f"{pre_main}/renumber_atom_index_same_as_smiles/{pdb}.sdf", None)
        # extract features from sdf.
        try:
            compound_dict[pdb] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)  # self-dock set has_LAS_mask to true
        except Exception as e:
            print(e)
            skip_pdb_list.append(pdb)
            print(pdb)

In [34]:
if not CTN:
    torch.save(compound_dict, f"{tankbind_data_path}/compound_torchdrug_features.pt")

In [35]:
compound_dict = torch.load(f"{tankbind_data_path}/compound_torchdrug_features.pt")

In [36]:
skip_pdb_list = ['3kqs']

In [37]:
skip_pdb_list

['3kqs']

In [38]:
data = data.query("pdb not in @skip_pdb_list").reset_index(drop=True)

### NCI Features

In [39]:
lig_nameid_dict = torch.load("../../dataspace/NFT/main/ligand_name_dicts/ligdict_nameid.pt")
lig_id_dict = torch.load("../../dataspace/NFT/main/ligand_name_dicts/ligdict_id.pt")

In [40]:
NCI_df = pd.read_csv("/home/jovyan/dataspace/processed_nci_data/v10.8/Data.v10.8.NCIs.csv")

In [41]:
# Build a binary (protein residue x ligand atom) NCI contact matrix for each
# complex. Every complex gets a status string in ``nci_info``; only complexes
# whose status is "YES" have their matrix saved under ``nci_save_path``.
from collections import defaultdict
from tqdm import tqdm
if not CTN:
    nci_info = {}
    nci_save_path = "/home/jovyan/dataspace/NFT/main/nci_protein_ligand_matrix"
    os.system(f"mkdir -p {nci_save_path}")
    # Split the NCI table by PDB code once, instead of re-scanning the whole
    # table with a boolean mask for every complex inside the loop.
    ncis_by_pdb = {code: group for code, group in NCI_df.groupby("PDB_Code")}
    for pdb in tqdm(pdb_list):
        # Ligand: map NCI atom ids back to atom indices of the renumbered mol.
        if pdb in lig_nameid_dict:
            liglen = len(lig_nameid_dict[pdb])
            lig_id_inverse = {_value[1]: _key for _key, _value in lig_nameid_dict[pdb].items()}
        elif pdb in lig_id_dict:
            liglen = len(lig_id_dict[pdb])
            lig_id_inverse = {_value: _key for _key, _value in lig_id_dict[pdb].items()}
        else:
            # Skip right away. Falling through would reuse liglen /
            # lig_id_inverse left over from the previous complex (or raise
            # NameError on the first one) and could overwrite this status.
            nci_info[pdb] = "BAD_NoLigName"
            continue

        mol, _ = read_mol(f"{pre_main}/renumber_atom_index_same_as_smiles/{pdb}.sdf", None)

        # Protein: features and residue ids written by batch_run.
        ebd = torch.load(f"{protein_embedding_folder}/{pdb}.pt")[pdb]
        protein_ids = torch.load(f"{protein_embedding_folder}/{pdb}_id.pt")[pdb]

        # Invert protein_ids: residue key -> list of residue indices. More than
        # one index for the same key means the key is ambiguous.
        protein_id_inverse = defaultdict(list)
        for _id, _name in protein_ids.items():
            protein_id_inverse[_name[0]].append(_id)
        protein_id_inverse = dict(protein_id_inverse)

        pl_0 = ebd[0].shape[0]
        pl_1 = len(protein_ids)
        ll_0 = mol.GetNumAtoms()
        ll_1 = liglen
        if pl_0 != pl_1:
            nci_info[pdb] = "BAD_ResNum"
            continue
        if ll_0 != ll_1:
            nci_info[pdb] = "BAD_AtomNum"
            continue

        ncis = ncis_by_pdb.get(pdb)
        if ncis is None or len(ncis) == 0:
            nci_info[pdb] = "BAD_EmptyNCI"
            continue

        nci_matrix = torch.zeros(pl_0, ll_0)
        status = "YES"
        for _, row in ncis.iterrows():
            ResFullID, l_id = row['ResFullID'], row['l_id']
            try:
                p_index = protein_id_inverse[ResFullID]
            except KeyError:
                status = "BAD_ResKey"
                break
            try:
                l_index = lig_id_inverse[l_id]
            except KeyError:
                status = "BAD_AtomKey"
                break
            if len(p_index) > 1:
                status = "BAD_ResIncoherant"
                break
            nci_matrix[p_index[0]][l_index] = 1

        nci_info[pdb] = status
        if status == "YES":
            torch.save({pdb: nci_matrix}, f"{nci_save_path}/{pdb}_nci_matrix.pt")

In [42]:
if not CTN:
    torch.save(nci_info, "/home/jovyan/dataspace/NFT/main/nci_info.pt")

nci_info = torch.load("/home/jovyan/dataspace/NFT/main/nci_info.pt")

In [43]:
nci_info_stat = defaultdict(int)
for k,v in nci_info.items():
    nci_info_stat[v] += 1
nci_info_stat

defaultdict(int,
            {'YES': 10727,
             'BAD_NoLigName': 5680,
             'BAD_EmptyNCI': 2597,
             'BAD_ResKey': 119,
             'BAD_ResIncoherant': 3})

In [44]:
nci_dict = {}
if not CTN:
    for pdb in tqdm(pdb_list):
        if pdb not in nci_info or nci_info[pdb] != "YES":
            continue
        else:
            nci_dict.update(torch.load(f"{nci_save_path}/{pdb}_nci_matrix.pt"))

In [45]:
if not CTN:
    torch.save(nci_dict, f"{tankbind_data_path}/nci_matrix.pt")

nci_dict = torch.load(f"{tankbind_data_path}/nci_matrix.pt")

# construct dataset.

In [60]:
# we use the time-split defined in EquiBind paper.
# https://github.com/HannesStark/EquiBind/tree/main/data
valid = np.loadtxt("/home/jovyan/data/Equiband/timesplit_no_lig_overlap_val", dtype=str)
test = np.loadtxt("/home/jovyan/data/Equiband/timesplit_test", dtype=str)
def assign_group(pdb, valid=valid, test=test):
    """Return the time-split group of ``pdb``: 'valid', 'test' or 'train'.

    ``valid`` is checked before ``test``. An id found in neither collection
    is assigned to 'train'. The defaults are the split arrays loaded just
    above, bound once when the function is defined.
    """
    for group_name, members in (('valid', valid), ('test', test)):
        if pdb in members:
            return group_name
    return 'train'

data['group'] = data.pdb.map(assign_group)

In [61]:
data.value_counts("group")

group
train    17794
valid      968
test       363
dtype: int64

In [62]:
data['name'] = data['pdb']

In [72]:
info = []
for i, line in tqdm(data.iterrows(), total=data.shape[0]):
    pdb = line['pdb']
    uid = line['uid']
    # smiles = line['smiles']
    smiles = ""
    affinity = line['affinity']
    group = line['group']

    compound_name = line['name']
    protein_name = line['name']

    pocket = pockets_dict[pdb].head(10)
    pocket.columns = pocket.columns.str.strip()
    pocket_coms = pocket[['center_x', 'center_y', 'center_z']].values
    
    has_nci = (nci_info[protein_name] == "YES")
    compound_sdf_from_nci = (protein_name in lig_id_dict.keys())
    # native block.
    info.append([protein_name, compound_name, pdb, smiles, affinity, uid, None, True, False, group, has_nci, compound_sdf_from_nci])
    # protein center as a block.
    protein_com = protein_dict[protein_name][0].numpy().mean(axis=0).astype(float).reshape(1, 3)
    info.append([protein_name, compound_name, pdb+"_c", smiles, affinity, uid, protein_com, False, False, group, has_nci, compound_sdf_from_nci])
    
    for idx, pocket_line in pocket.iterrows():
        pdb_idx = f"{pdb}_{idx}"
        info.append([protein_name, compound_name, pdb_idx, smiles, affinity, uid, pocket_coms[idx].reshape(1, 3), False, False, group, has_nci, compound_sdf_from_nci])
info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'pdb', 'smiles', 'affinity', 'uid', 'pocket_com', 
                                   'use_compound_com', 'use_whole_protein',
                                  'group', 'has_nci_info', 'compound_sdf_from_nci'])



100%|██████████| 19125/19125 [00:20<00:00, 936.43it/s]


In [73]:
info.shape

(162029, 12)

In [74]:
info

Unnamed: 0,protein_name,compound_name,pdb,smiles,affinity,uid,pocket_com,use_compound_com,use_whole_protein,group,has_nci_info,compound_sdf_from_nci
0,3zzf,3zzf,3zzf,,0.40,Q01217,,True,False,train,True,False
1,3zzf,3zzf,3zzf_c,,0.40,Q01217,"[[5.51331901550293, 36.50146484375, 14.4291219...",False,False,train,True,False
2,3zzf,3zzf,3zzf_0,,0.40,Q01217,"[[9.2232, 36.6453, 4.2458]]",False,False,train,True,False
3,3zzf,3zzf,3zzf_1,,0.40,Q01217,"[[-3.9652, 36.9019, 2.8611]]",False,False,train,True,False
4,3zzf,3zzf,3zzf_2,,0.40,Q01217,"[[16.5628, 39.1406, 26.3637]]",False,False,train,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...
162024,2avi,2avi,2avi_5,,15.22,P02701,"[[-24.8819, 33.8811, 24.9982]]",False,False,train,False,False
162025,2avi,2avi,2avi_6,,15.22,P02701,"[[5.0382, 35.5432, 16.4793]]",False,False,train,False,False
162026,2avi,2avi,2avi_7,,15.22,P02701,"[[-4.7665, 15.8424, 22.5071]]",False,False,train,False,False
162027,2avi,2avi,2avi_8,,15.22,P02701,"[[4.7665, 64.5276, 22.5071]]",False,False,train,False,False


In [43]:
# TankBindDataSet is used below to build and reload the processed datasets;
# the original line ended after "import" and was a syntax error.
from data_nci import TankBindDataSet

In [None]:
toFilePre = f"{pre}/dataset"
os.system(f"mkdir -p {toFilePre}")
dataset = TankBindDataSet(toFilePre, data=info, protein_dict=protein_dict, compound_dict=compound_dict)

In [None]:
dataset = TankBindDataSet(toFilePre)


In [45]:
t = []
data = dataset.data
pre_pdb = None
for i, line in tqdm(data.iterrows(), total=data.shape[0]):
    pdb = line['compound_name']
    d = dataset[i]
    p_length = d['node_xyz'].shape[0]
    c_length = d['coords'].shape[0]
    y_length = d['y'].shape[0]
    num_contact = (d.y > 0).sum()
    t.append([i, pdb, p_length, c_length, y_length, num_contact])



100%|██████████| 161940/161940 [11:33<00:00, 233.55it/s]


In [38]:
# data = data.drop(['p_length', 'c_length', 'y_length', 'num_contact'], axis=1)

In [46]:
t = pd.DataFrame(t, columns=['index', 'pdb' ,'p_length', 'c_length', 'y_length', 'num_contact'])
t['num_contact'] = t['num_contact'].apply(lambda x: x.item())

In [47]:
data = pd.concat([data, t[['p_length', 'c_length', 'y_length', 'num_contact']]], axis=1)

In [48]:
native_num_contact = data.query("use_compound_com").set_index("protein_name")['num_contact'].to_dict()
data['native_num_contact'] = data.protein_name.map(native_num_contact)
# data['fract_of_native_contact'] = data['num_contact'] / data['native_num_contact']

In [49]:
torch.save(data, f"{toFilePre}/processed/data.pt")

In [51]:
info = torch.load(f"{toFilePre}/processed/data.pt")


In [52]:
test = info.query("group == 'test'").reset_index(drop=True)
test_pdb_list = info.query("group == 'test'").protein_name.unique()

In [53]:
subset_protein_dict = {}
for pdb in tqdm(test_pdb_list):
    subset_protein_dict[pdb] = protein_dict[pdb]

100%|██████████| 363/363 [00:00<00:00, 251866.39it/s]


In [54]:
subset_compound_dict = {}
for pdb in tqdm(test_pdb_list):
    subset_compound_dict[pdb] = compound_dict[pdb]

100%|██████████| 363/363 [00:00<00:00, 182208.28it/s]


In [None]:

toFilePre = f"{pre}/test_dataset"
os.system(f"mkdir -p {toFilePre}")
dataset = TankBindDataSet(toFilePre, data=test, protein_dict=subset_protein_dict, compound_dict=subset_compound_dict)

In [None]:
def canonical_smiles(smiles):
    """Return RDKit's canonical SMILES for ``smiles``.

    Raises:
        ValueError: if RDKit cannot parse ``smiles``.
    """
    # Import locally so the helper does not depend on a notebook-level ``Chem``
    # name; without it, calling this function raised NameError.
    from rdkit import Chem

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None, which
        # MolToSmiles would otherwise reject with an opaque argument error.
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    return Chem.MolToSmiles(mol)