In [1]:
import ast
import os, time, pickle, sys, math
import numpy as np
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
from pymol import cmd
from sklearn.metrics import pairwise_distances, pairwise_distances_argmin_min
from sklearn.cluster import AgglomerativeClustering

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.ML.Cluster import Butina
from scipy.cluster.hierarchy import fcluster, linkage, single
from scipy.spatial.distance import pdist

In [2]:
import pymesh

### 0. Define dataset parameters

In [3]:
# `da` selects which AnchorOutput/<pdbid_chains>_da_<da>/ run is read below and is
# embedded in the name of the output directory.
da = 6
outdir = './PocketDetectionData_COACH420_da' + str(da) + '/'
outdir_missing = not os.path.exists(outdir)
if outdir_missing:
    os.mkdir(outdir)

### 1. Load dataset files

In [4]:
# Each CSV row holds a PDB ID fused with the chain IDs to evaluate (e.g. "148lE")
# and a Python-literal list of ligand codes.
table = pd.read_csv("lists/coach420_pdbid_ligandname.csv")
dict_ligands = {}
for i in table.index:
    # ast.literal_eval only parses literals; eval() would execute arbitrary code
    # embedded in the CSV.
    list_ligand = ast.literal_eval(table.loc[i, 'ligand_name'])
    pdbid_chains = table.loc[i, 'pdbid']
    ligands = ",".join(list_ligand)
    table.loc[i, 'ligand_name'] = ligands
    table.loc[i, 'pdbid'] = pdbid_chains[:4]   # first 4 characters: PDB ID
    table.loc[i, 'chains'] = pdbid_chains[4:]  # remainder: chain IDs to evaluate
    dict_ligands[pdbid_chains[:4]] = ligands
# Rename by column name rather than by position so an unexpected column layout
# cannot silently mislabel columns.
table = table.rename(columns={'ligand_name': 'ligand', 'chains': 'eval_chains'})
print(table.shape)

(420, 3)


In [5]:
# Preview the first rows of the parsed table.
table.head(n=5)

Unnamed: 0,pdbid,ligand,eval_chains
0,148l,UUU,E
1,1a26,CNA,A
2,1a2k,GDP,C
3,1a4k,FRA,H
4,1a7x,FKA,A


In [6]:
# Superseded PDB IDs in the COACH420 list, mapped old ID -> updated ID.
dict_new_name = {
    '2zcp': '3w7f',
}


def update_pdbid_chains(pdbid_chains):
    """Map a "<pdbid>_<chains>" string onto the updated PDB ID, if any.

    The PDB ID part is replaced through `dict_new_name`; strings whose ID is not
    listed are returned unchanged. Raises ValueError unless the input contains
    exactly one underscore.
    """
    old_id, chain_part = pdbid_chains.split("_")
    if old_id in dict_new_name:
        return "{}_{}".format(dict_new_name[old_id], chain_part)
    return pdbid_chains

In [7]:
# One "<pdbid>_<chains>" entry per line; group the chain strings by PDB ID.
with open('lists/pdbid_chain_list_COACH420.txt', 'r') as f:
    list_task = [line.strip() for line in f]
# Drop blank lines (e.g. trailing empty lines); they would make the two-way split
# below raise ValueError.
list_task = [task for task in list_task if task]
dict_chains = {}
for task in list_task:
    pdbid, chains = task.split("_")
    dict_chains.setdefault(pdbid, []).append(chains)

In [8]:
# For each row, record every chain entry available for its (possibly renamed) PDB
# ID. Rows whose ID is absent from dict_chains are printed and get no 'chains' value.
for idx in table.index:
    pdbid = table.loc[idx, 'pdbid']
    pdbid = dict_new_name.get(pdbid, pdbid)
    if pdbid in dict_chains:
        table.loc[idx, 'chains'] = ",".join(dict_chains[pdbid])
    else:
        print(pdbid)

### 2. Get protein features

In [9]:
def load_anchor(list_samples):
    """Load the anchor coordinates for every "<pdbid>_<chains>" sample and stack them.

    Each sample is first mapped through update_pdbid_chains. The first element of
    AnchorOutput/<pdbid_chains>_da_<da>/anchors.npy is taken, and the per-sample
    arrays are concatenated along axis 0.
    """
    paths = ('AnchorOutput/{}_da_{}/anchors.npy'.format(update_pdbid_chains(sample), da)
             for sample in list_samples)
    return np.concatenate([np.load(path)[0] for path in paths])

In [10]:
anchor_dict = {}
for i in table.index:
    print(len(anchor_dict), "\r", end="")
    pdbid = table.loc[i, 'pdbid']
    try:
        list_chains = [pdbid + "_" + chain for chain in table.loc[i, 'chains'].split(",")]
        anchor_dict[pdbid] = load_anchor(list_chains)
    except Exception as E:
        print(i, E)
        pass

print("anchor_dict", len(anchor_dict))

anchor_dict 420


In [11]:
# Persist anchor_dict into the dataset output directory.
anchor_path = outdir + 'anchor_dict_thre' + str(da)
with open(anchor_path, 'wb') as fh:
    pickle.dump(anchor_dict, fh)

In [12]:
def load_atom_dict(list_filenames):
    """Load and merge per-chain atom features, coordinates and neighbor lists.

    For each "<pdbid>_<chains>" entry (superseded IDs mapped via update_pdbid_chains),
    AnchorOutput/<pdbid_chains>_da_<da>/atom_feature.pk is unpickled. Its first element
    is unpacked as (features, coords, neighbors). Neighbor indices are shifted by the
    number of atoms already loaded, so they index into the merged arrays.

    Returns:
        (features, coords, neighbors): features and coords concatenated along
        axis 0; neighbors is a list of per-atom index lists.
    """
    list_fa, list_coord, list_nei = [], [], []
    offset = 0
    for pdbid_chains in list_filenames:
        pdbid_chains = update_pdbid_chains(pdbid_chains)
        path = 'AnchorOutput/{}_da_{}/atom_feature.pk'.format(pdbid_chains, da)
        # Close the file deterministically instead of leaking the handle.
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(path, 'rb') as fh:
            fa, coord, nei = pickle.load(fh)[0]
        list_fa.append(fa)
        list_coord.append(coord)
        list_nei.extend([[j + offset for j in item] for item in nei])
        offset += len(fa)
    return (np.concatenate(list_fa), np.concatenate(list_coord), list_nei)

In [13]:
# pdbid -> (atom features, atom coords, neighbor lists) merged over all chains.
atom_dict = {}

for i in table.index:
    print(len(atom_dict), "\r", end="")
    pdbid = table.loc[i, 'pdbid']
    try:
        list_chains = [pdbid + "_" + chain for chain in table.loc[i, 'chains'].split(",")]
        atom_dict[pdbid] = load_atom_dict(list_chains)
    except Exception as E:
        # Report failures like the anchor/MaSIF loops do. A bare `except: pass` hid
        # them and also swallowed KeyboardInterrupt.
        print(i, E)

len(atom_dict)

419 

420

In [14]:
# Persist atom_dict into the dataset output directory.
atom_path = outdir + "atom_feature_coord_nei_dict_thre" + str(da)
with open(atom_path, "wb") as fh:
    pickle.dump(atom_dict, fh)

In [15]:
def load_masif_coords(list_filenames):
    """Return the MaSIF surface-point coordinates of all given chains as one array.

    For every "<pdbid>_<chains>" entry (mapped through update_pdbid_chains), the
    p1_X / p1_Y / p1_Z .npy files under 04a-precomputation_12A are stacked as
    columns. Assuming each file is 1-D, each chain gives an (n_points, 3) block.
    The blocks are concatenated along axis 0.
    """
    blocks = []
    for name in list_filenames:
        folder = 'MasifOutput/04a-precomputation_12A/precomputation/{}/'.format(update_pdbid_chains(name))
        xyz = [np.load(folder + 'p1_{}.npy'.format(axis)) for axis in 'XYZ']
        blocks.append(np.vstack(xyz).T)
    return np.concatenate(blocks)


def load_masif_feature_neighbor(list_filenames):
    """Load and merge per-chain MaSIF features and neighbor indices.

    For each entry, masif_feature.npy and masif_neighbor.npy are read from
    AnchorOutput/<pdbid_chains>_da_<da>/, and any NaN features are set to 0.
    Each chain's neighbor indices are shifted by the number of feature rows already
    loaded, so they index into the merged array.

    Returns:
        (features, neighbors), each concatenated along axis 0.
    """
    feats, neighbors = [], []
    n_loaded = 0
    for name in list_filenames:
        base = 'AnchorOutput/{}_da_{}/'.format(update_pdbid_chains(name), da)
        feat = np.load(base + 'masif_feature.npy')
        nei = np.load(base + 'masif_neighbor.npy')
        nan_mask = np.isnan(feat)
        if nan_mask.any():
            feat[nan_mask] = 0
        feats.append(feat)
        # NOTE(review): if neighbor arrays use a sentinel (e.g. -1) for padding, this
        # offset also shifts the sentinel -- confirm the file format.
        neighbors.append(nei + n_loaded)
        n_loaded += len(feat)
    return np.concatenate(feats), np.concatenate(neighbors)

In [16]:
# pdbid -> (MaSIF features, MaSIF surface coordinates, MaSIF neighbor indices),
# merged over all chains. Entries that fail to load, or whose feature and
# coordinate counts disagree, are reported and skipped.
masif_feature_coord_nei_dict = {}

for i in table.index:
    print(len(masif_feature_coord_nei_dict), "\r", end="")
    pdbid = table.loc[i, 'pdbid']
    try:
        list_chains = [pdbid + "_" + chain for chain in table.loc[i, 'chains'].split(",")]
        masif_feature, masif_neighbor = load_masif_feature_neighbor(list_chains)
        masif_coords = load_masif_coords(list_chains)
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would then let mismatched arrays through.
        if masif_feature.shape[0] != masif_coords.shape[0]:
            raise ValueError("{} {} {}".format(pdbid, masif_feature.shape[0], masif_coords.shape[0]))
        masif_feature_coord_nei_dict[pdbid] = (masif_feature, masif_coords, masif_neighbor)
    except Exception as E:
        print(E)
len(masif_feature_coord_nei_dict)

419 

420

In [17]:
# Persist the merged MaSIF data into the dataset output directory.
masif_path = outdir + 'masif_feature_coord_nei_dict'
with open(masif_path, 'wb') as fh:
    pickle.dump(masif_feature_coord_nei_dict, fh)

In [18]:
import torch

In [19]:
def _sparse_within_cutoff(coords_a, coords_b, cutoff=6):
    """Return a sparse COO float32 tensor of the pairwise distances <= `cutoff`.

    Entry (r, c) holds the Euclidean distance between coords_a[r] and coords_b[c].
    Pairs farther apart than `cutoff` are not stored. The dense shape is
    (len(coords_a), len(coords_b)).
    """
    dist = pairwise_distances(coords_a, coords_b)
    rows, cols = np.where(dist <= cutoff)
    indices = torch.as_tensor(np.vstack([rows, cols]), dtype=torch.long)
    values = torch.as_tensor(dist[rows, cols], dtype=torch.float32)
    # torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor constructor.
    return torch.sparse_coo_tensor(indices, values, size=dist.shape)


# Distance cutoff (same units as the coordinates) for keeping an edge.
edge_cutoff = 6

am_dict = {}  # anchor -> MaSIF surface point
aa_dict = {}  # anchor -> anchor
at_dict = {}  # anchor -> protein atom
# Distinct loop name: the old cell reassigned `i` to a tensor inside its own loop.
for idx in table.index:
    print(len(am_dict), len(aa_dict), len(at_dict), "\r", end="")
    pdbid = table.loc[idx, 'pdbid']
    # Entries that failed to load upstream are skipped here (the final table
    # filter drops them) instead of aborting the whole cell with a KeyError.
    if (pdbid not in anchor_dict or pdbid not in atom_dict
            or pdbid not in masif_feature_coord_nei_dict):
        print("missing inputs for", pdbid)
        continue
    anchor_coords = anchor_dict[pdbid]
    _, atom_coords, _ = atom_dict[pdbid]
    _, masif_coords, _ = masif_feature_coord_nei_dict[pdbid]
    aa_dict[pdbid] = _sparse_within_cutoff(anchor_coords, anchor_coords, edge_cutoff)
    am_dict[pdbid] = _sparse_within_cutoff(anchor_coords, masif_coords, edge_cutoff)
    at_dict[pdbid] = _sparse_within_cutoff(anchor_coords, atom_coords, edge_cutoff)

print(len(am_dict), len(aa_dict), len(at_dict))

420 420 420  52 


In [20]:
# Persist the three sparse distance graphs, one pickle file per graph type.
for fname, graph in (('am_dict', am_dict), ('aa_dict', aa_dict), ('at_dict', at_dict)):
    with open(outdir + fname, 'wb') as fh:
        pickle.dump(graph, fh)

### 3. Compute anchor labels from ligand coordinates

In [21]:
import pymol
import numpy as np

def get_ligand_counts_coords(filename, chains, ligand_list, removeHs=True,
                             min_atoms=5, clash_dist=1.5, contact_dist=4,
                             centroid_dist=5.5):
    """
    Collect the coordinates of ligand instances that pass simple geometric filters
    against the given protein chains.

    Input:
    filename: path+name of the pdb file
    chains: iterable of chain IDs (e.g. the string "AB"); the protein reference is
        every non-het atom of these chains
    ligand_list: list of ligand codes (3-letter IDs); codes in the fixed exclusion
        list below are ignored (comparison is case-insensitive)
    removeHs: whether to remove hydrogens before collecting coordinates
    min_atoms: instances with fewer atoms are discarded
    clash_dist: discard an instance if any of its atoms is closer than this to a protein atom
    contact_dist: discard an instance if none of its atoms is within this of a protein atom
    centroid_dist: discard an instance whose centroid is farther than this from every protein atom

    Output:
    count_dict: key: ligand id, value: number of accepted instances
    coord_dict: key: "<ligand id>_<chain>_<residue id>", value: n*3 numpy array of ligand coordinates
    list_ligand_ok: sorted list of ligand ids with at least one accepted instance
    """
    pymol.cmd.reinitialize()
    pymol.cmd.load(filename)
    if removeHs:
        pymol.cmd.remove('hydro')

    protein_coords = []
    protein_sele = "chain " + "+".join(chains) + " and not het"
    pymol.cmd.iterate_state(-1, protein_sele, "protein_coords.append((x,y,z))", space=locals())

    list_tabu = {"HOH", "DOD", "WAT", "NAG", "MAN", "UNK", "GLC", "ABA", "MPD", "GOL", "SO4", "PO4", '', 'U', 'HEM', 'PI',
                 'ASN', "GLY", "ALA", "PRO", "VAL", "LEU", "ILE", "MET", "PHE", "TYR", "TRP", "SER", "THR", "CYS",
                 "GLN", "LYS", "HIS", "ARG", "ASP", "GLU"}
    count_dict, coord_dict = {}, {}
    list_ligand_ok = set()
    for ligand in ligand_list:
        # Normalize case BEFORE the exclusion check; checking first let lowercase
        # codes of excluded residues (e.g. 'hoh') slip through.
        ligand = ligand.upper()
        if ligand in list_tabu:
            continue
        resi_set = set()
        pymol.cmd.iterate('resname {}'.format(ligand), "resi_set.add(chain+'_'+resi)", space=locals())
        count_dict[ligand] = 0
        for chain_resi in resi_set:
            chain, resi = chain_resi.split('_')
            # Also restrict by residue name so another residue that shares the same
            # chain/residue number is not merged into the ligand's coordinates.
            sele = 'chain {} and resi {} and resname {}'.format(chain, resi, ligand)
            coords = []
            pymol.cmd.iterate_state(-1, sele, "coords.append((x,y,z))", space=locals())
            if len(coords) < min_atoms:
                continue
            coords = np.array(coords)
            # Closest ligand-atom/protein-atom distance, computed once for both checks.
            min_dist = pairwise_distances(protein_coords, coords).min()
            if min_dist < clash_dist or min_dist > contact_dist:
                continue
            centroid = np.mean(coords, 0, keepdims=True)
            if pairwise_distances(protein_coords, centroid).min() > centroid_dist:
                continue
            coord_dict['{}_{}'.format(ligand, chain_resi)] = coords
            count_dict[ligand] += 1
            list_ligand_ok.add(ligand)
    # Sorted so the caller's joined ligand string is reproducible across runs
    # (set iteration order of strings depends on hash randomization).
    return count_dict, coord_dict, sorted(list_ligand_ok)

In [22]:
# pdbid -> {ligand instance key: coordinates} for the evaluated chains of each entry.
dict_ligand_coords = {}
dict_num_lig = {}
for i in table.index:
    pdbid = table.loc[i, 'pdbid']
    newid = dict_new_name.get(pdbid, pdbid)  # structure files use the updated PDB ID
    chains = table.loc[i, 'eval_chains']
    ligand_field = table.loc[i, 'ligand']
    if isinstance(ligand_field, str) and len(ligand_field) > 0:
        list_ligand = ligand_field.split(',')
    else:
        list_ligand = []
    # Prefer the 'fixed_' PDB file when present, otherwise the plain one.
    candidates = ['MasifOutput/00-raw_pdbs/fixed_{}.pdb'.format(newid),
                  'MasifOutput/00-raw_pdbs/{}.pdb'.format(newid)]
    filename = next((path for path in candidates if os.path.exists(path)), None)
    if filename is None:
        print("NO pdb file", pdbid)
        continue
    count_dict, coord_dict, list_ligand = get_ligand_counts_coords(filename, chains, list_ligand)
    dict_ligand_coords[pdbid] = coord_dict

    table.loc[i, 'ligand_used'] = ",".join(list_ligand)  # "" when no ligand accepted
    n_lig = int(sum(count_dict.values()))
    dict_num_lig[pdbid] = n_lig
    table.loc[i, 'num_ligands'] = n_lig
    print(i, '\r', end="")
# Rows skipped above have no count (NaN); treat them as 0 so the int cast cannot fail.
table['num_ligands'] = table['num_ligands'].fillna(0).astype(int)

 PyMOL not running, entering library mode (experimental)
419  

In [23]:
pickle.dump(dict_ligand_coords, open(outdir+"ligand_coords_dict", "wb"))

In [24]:
# Optional: reload cached results instead of recomputing them. The cells above write
# these files under `outdir`, so the paths need that prefix.
# anchor_dict = pickle.load(open(outdir + 'anchor_dict_thre' + str(da), "rb"))
# dict_ligand_coords = pickle.load(open(outdir + "ligand_coords_dict", "rb"))

In [25]:
# Total number of accepted ligand instances across all entries.
total = sum(len(coords_by_key) for coords_by_key in dict_ligand_coords.values())
print(total)

457


In [26]:
### save label
# An anchor is labelled 1 when it lies within `label_cutoff` of any atom of an
# accepted ligand instance, else 0. Entries without anchors or without any accepted
# ligand get no label.
label_cutoff = 4
label_dict = {}
for i in table.index:
    print(i, " \r", end='')
    pdbid = table.loc[i, 'pdbid']
    ligand_coords = dict_ligand_coords.get(pdbid, {})
    # Explicit skip; a bare `except: pass` used to hide this case (np.concatenate of
    # an empty list raises) along with any genuine error.
    if pdbid not in anchor_dict or len(ligand_coords) == 0:
        continue
    anchor_coords = anchor_dict[pdbid]
    cpd_coords = np.concatenate(list(ligand_coords.values()))
    ag = pairwise_distances(anchor_coords, cpd_coords).min(axis=1)
    label_dict[pdbid] = (ag <= label_cutoff).astype(int)

len(label_dict)

419      

348

In [27]:
# Persist the per-anchor labels into the dataset output directory.
label_path = outdir + 'anchor_label_n4_dict_' + str(da)
with open(label_path, 'wb') as fh:
    pickle.dump(label_dict, fh)

### 4. Save final dataset table

In [28]:
# Keep only rows whose entry has every per-protein artifact computed above.
# (Disabled stricter filters: also require num_ligands > 0 and/or pdbid in label_dict.)
required = (anchor_dict, atom_dict, masif_feature_coord_nei_dict, am_dict, at_dict, aa_dict)
success_list = []
for idx in table.index:
    print(idx, " \r", end='')
    pdbid = table.loc[idx, 'pdbid']
    if all(pdbid in store for store in required):
        success_list.append(idx)
print("success_list", len(success_list))

success_list 420


In [29]:
# Row counts before and after restricting the table to fully processed entries.
print(len(table))
table = table.loc[success_list]
print(len(table))

420
420


In [30]:
# Write the final dataset table as TSV without the pandas index
# (`index` is documented as a bool; None only worked by being falsy).
table.to_csv(outdir + "coach420_table_pocket_full.tsv", sep="\t", index=False)