In [2]:
import ast
import os, time, pickle, sys, math
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pymol import cmd
from sklearn.metrics import pairwise_distances, pairwise_distances_argmin_min
from sklearn.cluster import AgglomerativeClustering

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.ML.Cluster import Butina
from scipy.cluster.hierarchy import fcluster, linkage, single
from scipy.spatial.distance import pdist

In [3]:
import pymesh

### 0. Define dataset parameters

In [4]:
# Dataset parameter `da`: it is interpolated into the AnchorOutput input paths
# and into the names of the files written to `outdir` below.
da = 6
outdir = './PocketDetectionData_HOLOplus_da{}/'.format(da)
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of `if not os.path.exists(...): os.mkdir(...)`.
os.makedirs(outdir, exist_ok=True)

### 1. Load dataset files

In [5]:
# Table of query samples; the columns used below are 'pdbid', 'chains' and
# 'original_sample'.
table = pd.read_csv(
    '/data/lishuya/lab/PocketAnchor/Revise1_new_data/Clean_protein_HOLO4k_table.tsv',
    sep='\t',
)
holo_ligand_table = pd.read_csv("lists/holo4k_pdbid_ligandname.csv")

# The 'ligand_name' column stores the textual repr of a Python list of ligand
# names. Parse it with ast.literal_eval rather than eval so that a malformed or
# malicious CSV cell cannot execute arbitrary code.
ligand_dict = {}
for i in holo_ligand_table.index:
    list_ligand = ast.literal_eval(holo_ligand_table.loc[i, 'ligand_name'])
    ligand_dict[holo_ligand_table.loc[i, 'pdbid']] = ",".join(list_ligand)

# Each row inherits the ligand names of the sample named in 'original_sample';
# a KeyError here means that sample is missing from the ligand list file.
ligand_list = [ligand_dict[pdbid] for pdbid in table['original_sample']]
table['ligand'] = ligand_list
print(table.shape)

(491, 4)


In [15]:
# Renaming of samples is currently switched off. The disabled logic mapped an
# old PDB id to a new one before the sample name was used in a path:
#
#     dict_new_name = {'1fpx': '6cig', '1m98': '5ui2', '1pmq': '4z9l'}
#
#     pdbid, chains = pdbid_chains.split("_")
#     if pdbid not in dict_new_name:
#         return pdbid_chains
#     newid = dict_new_name[pdbid]
#     return "_".join([newid, chains])
def update_pdbid_chains(pdbid_chains):
    """Return the sample name ``pdbid_chains`` exactly as it was given.

    This is a pass-through hook: the loaders call it on every sample name
    before building a file path, so a renaming rule can be re-enabled in
    one place.
    """
    resolved_name = pdbid_chains
    return resolved_name

In [8]:
# Map each pdbid to a one-element list holding its 'chains' string.
# If a pdbid occurs in several rows, the last row wins.
dict_chains = {
    pdbid: [chains]
    for pdbid, chains in zip(table['pdbid'], table['chains'])
}

In [10]:
# Preview the first rows, including the 'ligand' column added when the table was loaded.
table.head()

Unnamed: 0,pdbid,chains,original_sample,ligand
0,1bkd,R,18gs,gdn
1,11as,AB,18gs,gdn
2,1eog,AB,18gs,gdn
3,1p7s,A,18gs,gdn
4,6wsk,A,18gs,gdn


### 2. Get protein features

In [16]:
def load_anchor(list_samples):
    """Load and concatenate the anchor arrays of several samples.

    For every name in ``list_samples`` (after ``update_pdbid_chains``) the
    file ``AnchorOutput/<name>_da_<da>/anchors.npy`` is read and its first
    element is kept. The kept arrays are joined along axis 0 in input order.
    """
    per_sample_anchors = [
        np.load('AnchorOutput/{}_da_{}/anchors.npy'.format(update_pdbid_chains(sample), da))[0]
        for sample in list_samples
    ]
    return np.concatenate(per_sample_anchors)

In [17]:
# pdbid -> concatenated anchor array over all comma-separated chain groups.
anchor_dict = {}
for row in table.index:
    print(len(anchor_dict), "\r", end="")
    pdbid = table.loc[row, 'pdbid']
    try:
        chain_groups = table.loc[row, 'chains'].split(",")
        list_chains = [pdbid + "_" + chain for chain in chain_groups]
        anchor_dict[pdbid] = load_anchor(list_chains)
    except Exception as E:
        # Best effort: report the problem and move on to the next sample.
        print(E)

print("anchor_dict", len(anchor_dict))

anchor_dict 491


In [18]:
# Persist the anchor dictionary next to the other dataset files.
anchor_dict_path = outdir + 'anchor_dict_thre' + str(da)
with open(anchor_dict_path, 'wb') as handle:
    pickle.dump(anchor_dict, handle)

In [19]:
def load_atom_dict(list_filenames):
    """Load and merge the atom features of several samples.

    For every name in ``list_filenames`` (after ``update_pdbid_chains``) the
    pickle ``AnchorOutput/<name>_da_<da>/atom_feature.pk`` is read and its
    first element is unpacked as ``(features, coordinates, neighbours)``.

    Returns a tuple ``(list_fa, list_coord, list_nei)``:
      * ``list_fa``    - features of all samples concatenated along axis 0,
      * ``list_coord`` - coordinates of all samples concatenated along axis 0,
      * ``list_nei``   - one list of neighbour indices per atom, shifted by
        the number of atoms of the preceding samples so that the indices
        address rows of the concatenated arrays.

    NOTE(review): the files are unpickled, so they must come from a trusted
    source.
    """
    list_fa = []
    list_coord = []
    list_nei = []
    count = 0  # number of atoms contributed by the samples processed so far
    for pdbid_chains in list_filenames:
        pdbid_chains = update_pdbid_chains(pdbid_chains)
        feature_path = 'AnchorOutput/{}_da_{}/atom_feature.pk'.format(pdbid_chains, da)
        # Use a context manager so the file handle is closed right away
        # instead of being left to the garbage collector.
        with open(feature_path, 'rb') as f:
            fa, coord, nei = pickle.load(f)[0]
        list_fa.append(fa)
        list_coord.append(coord)
        list_nei.extend([[jtem + count for jtem in item] for item in nei])
        count += len(fa)
    list_fa = np.concatenate(list_fa)
    list_coord = np.concatenate(list_coord)
    return (list_fa, list_coord, list_nei)

In [20]:
# pdbid -> (atom features, atom coordinates, atom neighbour lists)
atom_dict = {}

for i in table.index:
    print(len(atom_dict), "\r", end="")
    pdbid = table.loc[i, 'pdbid']
    try:
        list_chains = [pdbid + "_" + chain for chain in table.loc[i, 'chains'].split(",")]
        atom_dict[pdbid] = load_atom_dict(list_chains)
    except Exception as E:
        # Still best effort, but report which sample failed and why, as the
        # sibling loading loops do. `except Exception` (not a bare except)
        # lets KeyboardInterrupt stop the loop.
        print(pdbid, E)

len(atom_dict)

489 

490

In [21]:
with open(outdir+"atom_feature_coord_nei_dict_thre"+str(da), "wb") as f:
    pickle.dump(atom_dict, f)

In [22]:
def load_masif_coords(list_filenames):
    """Load the MaSIF precomputed coordinates of several samples.

    For each sample the ``p1_X``, ``p1_Y`` and ``p1_Z`` arrays are stacked
    and transposed so that the three files become the three columns. The
    per-sample arrays are then concatenated along axis 0 in input order.
    """
    axis_file_templates = (
        'MasifOutput/04a-precomputation_12A/precomputation/{}/p1_X.npy',
        'MasifOutput/04a-precomputation_12A/precomputation/{}/p1_Y.npy',
        'MasifOutput/04a-precomputation_12A/precomputation/{}/p1_Z.npy',
    )
    coords_per_sample = []
    for sample in list_filenames:
        sample = update_pdbid_chains(sample)
        xyz_rows = [np.load(template.format(sample)) for template in axis_file_templates]
        coords_per_sample.append(np.vstack(xyz_rows).T)
    return np.concatenate(coords_per_sample)


def load_masif_feature_neighbor(list_filenames):
    """Load and merge MaSIF features and neighbour indices of several samples.

    NaN feature values are replaced by 0. The neighbour indices of each
    sample are shifted by the number of feature rows of the preceding
    samples, so they address rows of the concatenated feature array.

    Returns ``(features, neighbours)``, both concatenated along axis 0.
    """
    feature_blocks, neighbor_blocks = [], []
    row_offset = 0
    for sample in list_filenames:
        sample = update_pdbid_chains(sample)
        feat = np.load('AnchorOutput/{}_da_{}/masif_feature.npy'.format(sample, da))
        nei = np.load('AnchorOutput/{}_da_{}/masif_neighbor.npy'.format(sample, da))
        nan_mask = np.isnan(feat)
        if nan_mask.any():
            feat[nan_mask] = 0
        feature_blocks.append(feat)
        neighbor_blocks.append(nei + row_offset)
        row_offset += len(feat)
    return np.concatenate(feature_blocks), np.concatenate(neighbor_blocks)

In [23]:
# pdbid -> (MaSIF features, MaSIF coordinates, MaSIF neighbour indices)
masif_feature_coord_nei_dict = {}

for row in table.index:
    print(len(masif_feature_coord_nei_dict), "\r", end="")
    pdbid = table.loc[row, 'pdbid']
    try:
        chain_groups = table.loc[row, 'chains'].split(",")
        list_chains = [pdbid + "_" + chain for chain in chain_groups]
        masif_feature, masif_neighbor = load_masif_feature_neighbor(list_chains)
        masif_coords = load_masif_coords(list_chains)
        # Features and coordinates must describe the same set of points.
        n_feature_rows = masif_feature.shape[0]
        n_coord_rows = masif_coords.shape[0]
        assert n_feature_rows == n_coord_rows, "{} {} {}".format(pdbid, n_feature_rows, n_coord_rows)

        masif_feature_coord_nei_dict[pdbid] = (masif_feature, masif_coords, masif_neighbor)
    except Exception as E:
        # Best effort: report the problem and continue with the next sample.
        print(E)
len(masif_feature_coord_nei_dict)

[Errno 2] No such file or directory: 'AnchorOutput/1ksv_A_da_6/masif_feature.npy'
489 

490

In [24]:
# Persist the MaSIF feature/coordinate/neighbour dictionary.
masif_dict_path = outdir + 'masif_feature_coord_nei_dict'
with open(masif_dict_path, 'wb') as handle:
    pickle.dump(masif_feature_coord_nei_dict, handle)

In [25]:
import torch

In [26]:
def _sparse_distances_within(coords_a, coords_b, cutoff=6):
    """Return the pairwise distances <= ``cutoff`` as a sparse float tensor.

    The result has shape ``(len(coords_a), len(coords_b))``; entry (r, c) is
    stored only when the distance between row r of ``coords_a`` and row c of
    ``coords_b`` is at most ``cutoff``.
    """
    dist = pairwise_distances(coords_a, coords_b)
    sele = np.where(dist <= cutoff)
    indices = torch.LongTensor(np.vstack(sele))
    values = torch.FloatTensor(dist[sele])
    return torch.sparse.FloatTensor(indices, values, torch.Size([dist.shape[0], dist.shape[1]]))


# pdbid -> sparse distance matrix between anchors and
#   am: MaSIF points, aa: anchors, at: atoms
am_dict = {}
aa_dict = {}
at_dict = {}
for i in table.index:
    print(len(am_dict), len(aa_dict), len(at_dict), "\r", end="")
    pdbid = table.loc[i, 'pdbid']
    try:
        anchor_coords = anchor_dict[pdbid]
        masif_coords = masif_feature_coord_nei_dict[pdbid][1]
        atom_coords = atom_dict[pdbid][1]

        aa_sparse = _sparse_distances_within(anchor_coords, anchor_coords)
        am_sparse = _sparse_distances_within(anchor_coords, masif_coords)
        at_sparse = _sparse_distances_within(anchor_coords, atom_coords)
    except Exception as E:
        # Samples lacking any of the three inputs are skipped, but the reason
        # is now reported. The original bare `except: pass` hid it, and also
        # swallowed KeyboardInterrupt.
        print(pdbid, repr(E))
        continue

    # Store all three only after every matrix was built, so the dictionaries
    # always contain the same set of pdbids.
    aa_dict[pdbid] = aa_sparse
    am_dict[pdbid] = am_sparse
    at_dict[pdbid] = at_sparse

print(len(am_dict), len(aa_dict), len(at_dict))

490 490 490 


In [20]:
# Persist the three sparse-distance dictionaries, one pickle file each.
for dict_name, sparse_dict in (('am_dict', am_dict),
                               ('aa_dict', aa_dict),
                               ('at_dict', at_dict)):
    with open(outdir+dict_name, 'wb') as f:
        pickle.dump(sparse_dict, f)

### 3. Get label

In [28]:
import pymol
import numpy as np

# def get_ligand_counts_coords(filename, chains, ligand_list, removeHs=True):
#     """
#     Input: 
#     filename: path+name of pdb file
#     ligand_list: list of ligand codes (3-letter IDs)
#     removeHs: whether to removeHs from the coordinates

#     Output:
#     count_dict: key: ligand id, value: number of occurrence
#     coord_dict: key: ligand id + chain + residue id, value: n*3 numpy array of compound coordinates
#     """
#     pymol.cmd.reinitialize()
#     pymol.cmd.load(filename)
#     if removeHs:
#         pymol.cmd.remove('hydro')
    
#     protein_coords = []
#     pymol.cmd.iterate_state(-1, "chain "+"+".join([x for x in chains])+" and not het", "protein_coords.append((x,y,z))", space=locals())

#     count_dict, coord_dict = {}, {}
#     list_tabu = ["HOH", "DOD", "WAT", "NAG", "MAN", "UNK", "GLC", "ABA", "MPD", "GOL", "SO4", "PO4", '', 'U', 'HEM', 'PI']
#     list_tabu += ['ASN', "GLY", "ALA", "PRO", "VAL", "LEU", "ILE", "MET", "PHE", "TYR", "TRP", "SER", "THR", "CYS", \
#                  "GLN", "LYS", "HIS", "ARG", "ASP", "GLU"]
#     list_ligand_ok = set()
#     for ligand in ligand_list:
#         if ligand in list_tabu:
#             continue
#         resi_set = set()
#         ligand = ligand.upper()
#         pymol.cmd.iterate('resname {}'.format(ligand), "resi_set.add(chain+'_'+resi)", space=locals())
#         count_dict[ligand] = 0
#         for chain_resi in resi_set:
#             chain, resi = chain_resi.split('_')
#             pymol.cmd.select('{}_{}'.format(ligand, chain_resi), 'chain {} and resi {}'.format(chain, resi))
#             coords = []
#             pymol.cmd.iterate_state(-1, '{}_{}'.format(ligand, chain_resi), "coords.append((x,y,z))", space=locals())
#             if len(coords) < 5:
#                 continue
#             coords = np.array(coords) 
#             if pairwise_distances(protein_coords, coords).min() < 1.5:
#                 continue
#             if pairwise_distances(protein_coords, coords).min() > 4:
#                 continue
#             if pairwise_distances(protein_coords, np.mean(coords, 0, keepdims=True)).min() > 5.5:
#                 continue
#             coord_dict['{}_{}'.format(ligand, chain_resi)] = coords
#             count_dict[ligand] += 1
#             list_ligand_ok.add(ligand)
#     return count_dict, coord_dict, list(list_ligand_ok)

In [29]:
# dict_ligand_coords = {}
# dict_num_lig = {}
# for i in table.index:
#     pdbid = table.loc[i, 'pdbid']
#     if pdbid in dict_new_name:
#         newid = dict_new_name[pdbid]
#     else:
#         newid = pdbid
#     chains = table.loc[i, 'chains']
#     if isinstance(table.loc[i, 'ligand'], str) and len(table.loc[i, 'ligand']) > 0:
#         list_ligand = table.loc[i, 'ligand'].split(',')
#     else:
#         list_ligand = []
#     if os.path.exists('MasifOutput/00-raw_pdbs/fixed_{}.pdb'.format(newid)):
#         filename = 'MasifOutput/00-raw_pdbs/fixed_{}.pdb'.format(newid)
#     elif os.path.exists('MasifOutput/00-raw_pdbs/{}.pdb'.format(newid)):
#         filename = 'MasifOutput/00-raw_pdbs/{}.pdb'.format(newid)
#     else:
#         print("NO pdb file", pdbid)
#         continue
#     count_dict, coord_dict, list_ligand = get_ligand_counts_coords(filename, chains, list_ligand)
#     dict_ligand_coords[pdbid] = coord_dict

#     if len(list_ligand) == 0:
#         table.loc[i, 'ligand_used'] = ""
#     else:
#         table.loc[i, 'ligand_used'] = ",".join(list_ligand)
#     dict_num_lig[pdbid] = sum(count_dict.values())
#     table.loc[i, 'num_ligands'] = int(sum(count_dict.values()))
#     print(i, '\r', end="")
# table['num_ligands'] = np.array(table['num_ligands'], dtype=int)

In [54]:
# Get matching CA coordinates of a protein before and after alignment.
def _load_ca_coords_by_resi(pdb_path, chain):
    """Load ``pdb_path`` into a fresh PyMOL session and map resi -> CA xyz.

    Only atoms named CA in ``chain`` are read. If a residue id occurs more
    than once, the first atom encountered is kept. The dict preserves the
    order in which PyMOL iterates over the atoms.
    """
    pymol.cmd.reinitialize()
    pymol.cmd.load(pdb_path)
    ca_coords = {}
    # A single pass yields both the residue id and the coordinates, so no
    # per-residue selection string has to be built.
    pymol.cmd.iterate_state(-1, 'name CA and chain {}'.format(chain),
                            'ca_coords.setdefault(resi, [x, y, z])',
                            space={'ca_coords': ca_coords})
    return ca_coords


def get_coord_lists(query_pdbid, chains, num_coords=None, debug=False):
    """Return CA coordinates of the same residues before and after alignment.

    Parameters
    ----------
    query_pdbid : str
        PDB id used to build both file names.
    chains : str
        Chain string of the sample. It is used in full for the "before" file
        name, but only its first character is used as the chain selection.
    num_coords : int or None
        If given, only the first ``num_coords`` shared residues are returned.
    debug : bool
        Print the shapes and the shared residue ids.

    Returns
    -------
    (pos_before, pos_after) : two arrays with one ``[x, y, z]`` row per
        shared residue, in the same residue order. "before" comes from
        ``01-benchmark_pdbs/<pdbid>_<chains>.pdb``, "after" from
        ``HOLO4k_aligned/<pdbid>.pdb``.

    Raises
    ------
    ValueError
        If the two structures share no CA residue id in the selected chain.
    Errors raised by PyMOL (e.g. a missing file) are propagated.
    """
    # coords after align
    pos_after_dict = _load_ca_coords_by_resi(
        '/data/lishuya/lab/PocketAnchor/Revise1_new_data/HOLO4k_aligned/{}.pdb'.format(query_pdbid),
        chains[0])

    # coords before align (masif processed and add h)
    prefix = '/data/lishuya/lab/PocketAnchorData/MasifOutput/01-benchmark_pdbs/'
    pos_before_dict = _load_ca_coords_by_resi(
        prefix+'{}_{}.pdb'.format(query_pdbid, chains), chains[0])

    # Keep the residue order of the "before" structure. Going through
    # list(set(...)) would make the order, and therefore the residues kept by
    # the num_coords slice, vary from run to run.
    resi_record = [resi for resi in pos_before_dict if resi in pos_after_dict]
    if not resi_record:
        raise ValueError(
            'no common CA residues for {} chain {}'.format(query_pdbid, chains[0]))
    pos_before = np.array([pos_before_dict[r] for r in resi_record])
    pos_after = np.array([pos_after_dict[r] for r in resi_record])

    if debug:
        print('pos_before', pos_before.shape)
        print('resi_record', resi_record, len(resi_record))
        print('pos_after', pos_after.shape)
    assert len(pos_before) == len(pos_after)
    if num_coords is not None:
        pos_before = pos_before[:num_coords]
        pos_after = pos_after[:num_coords]
        if debug:
            print('pos_before', pos_before.shape)
            print('pos_after', pos_after.shape)
    return pos_before, pos_after

In [40]:
def restore_position(Xref, Yref, X):
    """Apply to ``X`` the rigid transform that superimposes ``Xref`` on ``Yref``.

    A rotation and translation are fitted from the paired reference points by
    SVD of their 3x3 cross-covariance matrix, then applied to ``X``.

    Parameters
    ----------
    Xref, Yref : arrays of shape (3, N)
        Paired reference points, one point per column.
    X : array of shape (3, M)
        Points given in the frame of ``Xref``.

    Returns
    -------
    Array of shape (3, M): ``X`` expressed in the frame of ``Yref``.

    Raises
    ------
    AssertionError
        If mapping ``Yref`` back does not reproduce ``Xref`` to within a mean
        absolute error of 1e-3, i.e. the references are not related by a
        rigid motion.
    """
    c1 = Xref.mean(1, keepdims=True)
    c2 = Yref.mean(1, keepdims=True)
    H = np.matmul((Xref - c1), (Yref - c2).T)
    # numpy returns the third SVD factor already transposed (V^T).
    U, _, Vt = np.linalg.svd(H)
    # Take the reflection sign from det(V U^T), which is always +-1.
    # det(H) is 0 for coplanar/degenerate references; that would zero the
    # last axis and produce a singular "rotation".
    d = np.sign(np.linalg.det(np.matmul(Vt.T, U.T)))
    I = np.eye(3)
    I[2, 2] = d
    rotation_ = np.matmul(Vt.T, np.matmul(I, U.T))
    translation_ = c2 - np.matmul(rotation_, c1)
    # The inverse of a rotation matrix is its transpose.
    Zref = np.matmul(rotation_.T, Yref - translation_)
    err = Zref - Xref
    assert np.mean(np.abs(err))<1e-3, "reference points are not related by a rigid transform"
    return np.matmul(rotation_, X) + translation_

In [30]:
# pickle.dump(dict_ligand_coords, open(outdir+"ligand_coords_dict", "wb"))

In [34]:
# The file is a copy from PocketDetectionData_HOLO4k_da6.
ligand_coords_path = outdir + "ligand_coords_dict"
with open(ligand_coords_path, "rb") as handle:
    dict_ligand_coords = pickle.load(handle)
len(dict_ligand_coords)

4009

In [47]:
# for i in table.index:
#     ref_pdbid = table.loc[i, 'original_sample']
#     break

# dict_ligand_coords[ref_pdbid]

In [59]:
### save label
# pdbid -> 0/1 label per anchor (1 if some transferred ligand atom lies within 4 of it)
label_dict = {}
processed = 0
for row in table.index:
    print(row, " \r", end='')
    query_pdbid, chains, ref_pdbid = table.loc[row, ['pdbid', 'chains', 'original_sample']]
    # (disabled) skip samples that are already labelled:
    #     if query_pdbid in label_dict:
    #         continue
    try:
        anchor_coords = anchor_dict[query_pdbid]
        # All ligand atoms of the reference sample, stacked into one array.
        ref_ligand_arrays = list(dict_ligand_coords[ref_pdbid].values())
        cpd_coords_ref = np.concatenate(ref_ligand_arrays)
        # Fit the transform on up to 100 shared CA atoms and use it to move
        # the ligand atoms from the aligned frame into the query frame.
        ca_before, ca_after = get_coord_lists(query_pdbid, chains, 100)
        cpd_coords_query = restore_position(ca_after.T, ca_before.T, cpd_coords_ref.T).T
        nearest_ligand_dist = pairwise_distances(anchor_coords, cpd_coords_query).min(axis=1)
        label_dict[query_pdbid] = (nearest_ligand_dist <= 4).astype(int)
    except Exception as E:
        print(E)
        continue

len(label_dict)

 Error: failed to open file "/data/lishuya/lab/PocketAnchorData/MasifOutput/01-benchmark_pdbs/1ksv_A.pdb"
453  
490  

489

In [60]:
# Persist the per-anchor labels.
label_dict_path = outdir + 'anchor_label_n4_dict_' + str(da)
with open(label_dict_path, 'wb') as handle:
    pickle.dump(label_dict, handle)

In [62]:
# Per sample: number of anchors and number of positively labelled anchors.
for anchor_labels in label_dict.values():
    n_anchors, n_positive = anchor_labels.size, anchor_labels.sum()
    print(n_anchors, n_positive)

311 14
922 22
675 36
350 6
594 6
327 7
324 8
335 7
674 30
638 13
311 20
725 19
425 3
467 0
538 0
487 14
379 8
599 0
440 0
360 0
390 11
397 9
624 7
336 0
343 0
349 0
1010 12
1167 7
359 18
744 11
533 10
556 6
565 0
344 4
356 3
341 5
1329 12
530 7
996 13
366 12
326 13
331 11
991 15
645 11
440 6
429 6
440 7
448 6
428 8
441 8
375 9
1484 12
455 12
413 16
325 5
337 0
335 15
337 15
208 19
335 1
546 0
198 2
411 0
527 11
766 0
343 0
255 0
382 0
335 0
310 4
514 9
508 0
487 0
505 10
495 0
292 0
670 0
665 0
570 0
424 0
524 20
789 0
697 0
1036 0
1422 11
590 0
774 0
399 13
1056 14
1101 6
596 19
277 0
1113 0
2244 0
552 8
542 25
535 22
542 0
361 3
557 0
649 13
1945 8
544 17
326 16
1433 9
294 14
1045 22
910 11
366 16
534 14
762 0
512 7
496 7
616 0
515 0
985 0
1469 0
1460 0
1405 21
572 13
386 3
368 0
746 0
508 0
381 0
403 0
479 7
501 11
343 0
271 16
643 21
487 11
488 13
393 0
434 0
430 1
810 0
954 0
480 0
458 0
775 0
387 0
341 16
458 0
759 21
536 20
525 0
496 20
1041 0
453 9
412 6
1100 0
1575 15
589 16
5

### 4. Save final dataset table

In [63]:
# A row is kept only if every feature dictionary and the label dictionary
# contain its pdbid.
required_dicts = (
    anchor_dict,
    atom_dict,
    masif_feature_coord_nei_dict,
    am_dict,
    at_dict,
    aa_dict,
    label_dict,
)
success_list = []
for row in table.index:
    print(row, " \r", end='')
    pdbid = table.loc[row, 'pdbid']
    # (disabled) rows with table.loc[i, 'num_ligands'] == 0 used to be skipped here
    if all(pdbid in required for required in required_dicts):
        success_list.append(row)
print("success_list", len(success_list))

0  1  2  3  4  5  6  7  8  9  10  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99  100  101  102  103  104  105  106  107  108  109  110  111  112  113  114  115  116  117  118  119  120  121  122  123  124  125  126  127  128  129  130  131  132  133  134  135  136  137  138  139  140  141  142  143  144  145  146  147  148  149  150  151  152  153  154  155  156  157  158  159  160  161  162  163  164  165  166  167  168  169  170  171  172  173  174  175  176  177  178  179  180  181  182  183  184  

In [64]:
# Keep only the fully processed rows and report the row count before and after.
n_rows_before = table.shape[0]
print(n_rows_before)
table = table.loc[success_list]
n_rows_after = table.shape[0]
print(n_rows_after)

491
489


In [65]:
# Write the final dataset table as a tab-separated file without the index column.
final_table_path = outdir + "holoplus_table_pocket_full.tsv"
table.to_csv(final_table_path, sep="\t", index=None)