In [1]:
import os
import pandas as pd
import torch
import pickle
import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import display
from tqdm import tqdm
from litschnet import LitSchNet
from conf_ensemble_dataset_in_memory import ConfEnsembleDataset
from rmsd_predictor_evaluator import RMSDPredictorEvaluator
from pdbbind_metadata_processor import PDBBindMetadataProcessor
from chembl_connector import ChEMBLConnector
from torch.utils.data import ConcatDataset
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles, MakeScaffoldGeneric, GetScaffoldForMol

In [60]:
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.DataManip.Metric.rdMetricMatrixCalc import GetTanimotoSimMat
from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity
class SimilaritySearch:
    """Nearest-neighbour lookup over a set of SMILES using Tanimoto similarity.

    On construction the input SMILES are de-duplicated, parsed with RDKit,
    fingerprinted (see ``get_morgan_fingerprint``) and the full pairwise
    similarity matrix is computed once, so that queries for molecules that
    are already in the set are a simple row lookup.

    Attributes:
        smiles_list: unique SMILES, in first-seen order.
        smiles_to_index: maps each SMILES to its row/column in ``sim_matrix``.
        mols: RDKit molecules, parallel to ``smiles_list``.
        fps: fingerprints, parallel to ``smiles_list``.
        sim_matrix: symmetric ``(n, n)`` array of pairwise similarities with
            ones on the diagonal.
    """

    def __init__(self, smiles_list):
        """Build fingerprints and the pairwise similarity matrix.

        Args:
            smiles_list: iterable of SMILES strings (duplicates allowed).

        Raises:
            ValueError: if RDKit cannot parse one or more of the SMILES.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # row order of the matrix is reproducible from one run to the next
        # (the iteration order of a set of strings is not).
        self.smiles_list = list(dict.fromkeys(smiles_list))
        self.mols = [Chem.MolFromSmiles(smiles) for smiles in self.smiles_list]

        # MolFromSmiles returns None for unparsable input; fail here with a
        # readable message instead of inside the fingerprint call.
        invalid_smiles = [smiles
                          for smiles, mol in zip(self.smiles_list, self.mols)
                          if mol is None]
        if invalid_smiles:
            raise ValueError(
                f'Could not parse {len(invalid_smiles)} SMILES, '
                f'e.g. {invalid_smiles[:5]}')

        self.smiles_to_index = {smiles: i
                                for i, smiles in enumerate(self.smiles_list)}
        self.fps = [self.get_morgan_fingerprint(mol) for mol in self.mols]
        self.sim_matrix = self.get_sim_matrix()
        print('Similarity matrix ready')

    def get_sim_matrix(self):
        """Return the symmetric pairwise Tanimoto matrix of ``self.fps``.

        Only the upper triangle is computed (one bulk call per row); each
        result is mirrored into the lower triangle. The diagonal is 1.
        """
        n_fps = len(self.fps)
        sim_matrix = np.eye(n_fps)
        # The last fingerprint has no later partner, hence n_fps - 1.
        for i in range(n_fps - 1):
            sims = BulkTanimotoSimilarity(self.fps[i], self.fps[i + 1:])
            sim_matrix[i, i + 1:] = sims
            sim_matrix[i + 1:, i] = sims
        return sim_matrix

    def get_morgan_fingerprint(self, mol):
        """Return the radius-3, chirality-aware Morgan bit vector of ``mol``."""
        return AllChem.GetMorganFingerprintAsBitVect(mol, 3, useChirality=True)

    def get_similarity(self, fp1, fp2):
        """Return the Tanimoto similarity between two fingerprints."""
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    def find_closest_in_set(self, smiles, n=1):
        """Return the ``n`` most similar SMILES of the set and their scores.

        If ``smiles`` is itself a member of the set (exact string match), it
        is excluded from the candidates and the precomputed matrix row is
        used. Otherwise it is parsed and compared against every fingerprint.

        Args:
            smiles: query SMILES.
            n: number of neighbours to return (fewer are returned when the
                candidate pool is smaller).

        Returns:
            ``(closest_smiles, closest_sims)``: a list of SMILES and a numpy
            array of the matching similarities, best first.

        Raises:
            ValueError: if ``smiles`` is not in the set and cannot be parsed.
        """
        smiles_index = self.smiles_to_index.get(smiles)
        if smiles_index is not None:
            # Drop the query's own entry so it is never its own neighbour.
            sims = np.delete(self.sim_matrix[smiles_index], smiles_index)
            candidates = (self.smiles_list[:smiles_index]
                          + self.smiles_list[smiles_index + 1:])
        else:
            input_mol = Chem.MolFromSmiles(smiles)
            if input_mol is None:
                raise ValueError(f'Could not parse SMILES: {smiles!r}')
            input_fp = self.get_morgan_fingerprint(input_mol)
            sims = BulkTanimotoSimilarity(input_fp, self.fps)
            candidates = self.smiles_list

        sims = np.array(sims, dtype=float)
        # Negate to sort descending; a stable sort makes ties deterministic
        # (lowest index first).
        best_sim_indexes = np.argsort(-sims, kind='stable')[:n]
        closest_smiles = [candidates[i] for i in best_sim_indexes]
        closest_sims = sims[best_sim_indexes]
        return closest_smiles, closest_sims

    def tri2mat(self, tri_arr):
        """Expand a condensed lower triangle into a full symmetric matrix.

        ``tri_arr`` must hold the strictly-lower-triangle entries row by row:
        (1,0), (2,0), (2,1), (3,0), (3,1), (3,2), ... -- understood to be the
        layout returned by RDKit's ``GetTanimotoSimMat`` (confirm against the
        RDKit documentation before relying on it).

        Args:
            tri_arr: sequence of length ``m * (m - 1) / 2``.

        Returns:
            ``(m, m)`` float array, symmetric, with ones on the diagonal.

        Raises:
            ValueError: if ``len(tri_arr)`` is not a triangular number.
        """
        n = len(tri_arr)
        # Solve m * (m - 1) / 2 == n for the matrix size m.
        m = int((np.sqrt(1 + 4 * 2 * n) + 1) / 2)
        if m * (m - 1) // 2 != n:
            raise ValueError(
                f'Length {n} is not a valid condensed-triangle size')
        arr = np.ones([m, m])
        for i in range(m):
            # Row i of the lower triangle starts after the i*(i-1)/2 entries
            # of the rows above it.
            row_offset = i * (i - 1) // 2
            for j in range(i):
                arr[i][j] = arr[j][i] = tri_arr[row_offset + j]
        return arr

In [49]:
# PDBbind metadata (peptide ligands kept) joined to ChEMBL target classes on
# the UniProt accession.
pdbbind_table = PDBBindMetadataProcessor().get_master_dataframe(remove_peptide_ligands=False)
chembl_table = ChEMBLConnector().get_target_table(level=3)
merged_table = pd.merge(pdbbind_table, chembl_table,
                        left_on='Uniprot ID', right_on='accession')

# Per-ligand SMILES table; only rows flagged `included` are used downstream.
smiles_df = pd.read_csv('data/smiles_df.csv')
included_smiles = smiles_df.loc[smiles_df['included'], 'smiles'].unique()
included_pdb_ids = smiles_df.loc[smiles_df['included'], 'id'].unique()

In [61]:
%%time
ss = SimilaritySearch(smiles_list=included_smiles)

Similarity matrix ready
CPU times: user 23.1 s, sys: 288 ms, total: 23.4 s
Wall time: 23.4 s


In [67]:
# Inspect every merged row annotated with the BRPF1 gene symbol.
merged_table.loc[merged_table['gene_symbol_lowercase'].eq('brpf1')]

Unnamed: 0,PDB code,resolution,release year_x,-logKd/Ki,Kd/Ki,reference,ligand name,activity_list,sep,value,units,release year_y,Uniprot ID,protein name,active,accession,component_synonym,protein_class_desc,level3,gene_symbol_lowercase
4438,3mo8,1.69,2010,3.10,Kd=0.8mM //,3mo8.pdf,(12-mer),"[mM, =, 0.8]",=,800000.00,nM,2010,P55201,PEREGRIN,False,P55201,BRPF1,epigenetic regulator reader brd,epigenetic regulator reader brd,brpf1
4439,3mo8,1.69,2010,3.10,Kd=0.8mM //,3mo8.pdf,(12-mer),"[mM, =, 0.8]",=,800000.00,nM,2010,P55201,PEREGRIN,False,P55201,BRPF1,epigenetic regulator reader phd,epigenetic regulator reader phd,brpf1
4440,3mo8,1.69,2010,3.10,Kd=0.8mM //,3mo8.pdf,(12-mer),"[mM, =, 0.8]",=,800000.00,nM,2010,P55201,PEREGRIN,False,P55201,BRPF1,epigenetic regulator reader methyl-lysine pwwp,epigenetic regulator reader methyl-lysine,brpf1
4441,5o4t,1.50,2018,3.40,IC50>400uM //,5mwg.pdf,(9KT),"[uM, >, 400]",>,400000.00,nM,2018,P55201,PEREGRIN,False,P55201,BRPF1,epigenetic regulator reader brd,epigenetic regulator reader brd,brpf1
4442,5o4t,1.50,2018,3.40,IC50>400uM //,5mwg.pdf,(9KT),"[uM, >, 400]",>,400000.00,nM,2018,P55201,PEREGRIN,False,P55201,BRPF1,epigenetic regulator reader phd,epigenetic regulator reader phd,brpf1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4532,5myg,2.30,2017,7.51,Kd=31nM //,5myg.pdf,(LS8),"[nM, =, 31]",=,31.00,nM,2017,P55201,PEREGRIN,True,P55201,BRPF1,epigenetic regulator reader phd,epigenetic regulator reader phd,brpf1
4533,5myg,2.30,2017,7.51,Kd=31nM //,5myg.pdf,(LS8),"[nM, =, 31]",=,31.00,nM,2017,P55201,PEREGRIN,True,P55201,BRPF1,epigenetic regulator reader methyl-lysine pwwp,epigenetic regulator reader methyl-lysine,brpf1
4534,4uye,1.65,2014,8.02,Kd=9.54nM //,4uyd.pdf,(9F9),"[nM, =, 9.54]",=,9.54,nM,2014,P55201,PEREGRIN,True,P55201,BRPF1,epigenetic regulator reader brd,epigenetic regulator reader brd,brpf1
4535,4uye,1.65,2014,8.02,Kd=9.54nM //,4uyd.pdf,(9F9),"[nM, =, 9.54]",=,9.54,nM,2014,P55201,PEREGRIN,True,P55201,BRPF1,epigenetic regulator reader phd,epigenetic regulator reader phd,brpf1


In [64]:
# For each gene (most frequent first): take every included ligand bound to it,
# find that ligand's nearest neighbour among all included SMILES (itself
# excluded), and report the fraction of ligands whose neighbour is bound to
# the same gene.

# Loop-invariant lookups, built once instead of re-filtering the frames for
# every single SMILES.
included_df = smiles_df[smiles_df['included']]
pdbs_by_smiles = included_df.groupby('smiles')['id'].apply(set).to_dict()
genes_by_pdb = (merged_table
                .groupby('PDB code')['gene_symbol_lowercase']
                .apply(set)
                .to_dict())

gene_hit_rates = {}  # gene symbol -> fraction of ligands with a same-gene neighbour
for gene_symbol in merged_table['gene_symbol_lowercase'].value_counts().index:
    print(gene_symbol)
    gene_table = merged_table[(merged_table['gene_symbol_lowercase'] == gene_symbol)
                              & (merged_table['PDB code'].isin(included_pdb_ids))]
    gene_pdbs = gene_table['PDB code'].unique()
    gene_smiles = included_df[included_df['id'].isin(gene_pdbs)]['smiles'].unique()

    gene_in_closest = []
    for smiles in gene_smiles:
        closest_smiles, closest_sims = ss.find_closest_in_set(smiles)
        # True when any PDB entry of any returned neighbour is annotated
        # with the current gene.
        gene_in_closest.append(any(
            gene_symbol in genes_by_pdb.get(pdb_id, ())
            for closest in closest_smiles
            for pdb_id in pdbs_by_smiles.get(closest, ())))

    # A gene with no included ligand has no defined rate: report nan
    # explicitly instead of letting np.mean([]) emit a RuntimeWarning.
    hit_rate = np.mean(gene_in_closest) if gene_in_closest else float('nan')
    gene_hit_rates[gene_symbol] = hit_rate
    print(hit_rate)

ca2
0.8021978021978022
bace1
0.9535603715170279
gag-pol
0.6748466257668712
cdk2
0.6666666666666666
brd4
0.6684782608695652
f2
0.7439024390243902
hsp90aa1
0.7730061349693251
mapk14
0.7
jak2
0.5737704918032787
kdm4a
0.7142857142857143
tnks2
0.819047619047619
pygm
0.7625
brpf1
0.3548387096774194
ephx2
0.3125
prkaca
0.7352941176470589
f10
0.8350515463917526
aurka
0.5212765957446809
ptpn1
0.8923076923076924
chek1
0.6744186046511628
bla
0.75
plau
0.5483870967741935
pim1
0.37209302325581395
ache
0.7710843373493976
dpp4
0.8170731707317073
pde10a
0.6428571428571429
crebbp
0.675
egfr
0.7878787878787878
src
0.4782608695652174
ampc
0.5675675675675675
mdm2
0.7924528301886793
gria2
0.6981132075471698
esr1
0.6530612244897959
ren
0.859375
jak3
0.696969696969697
pparg
0.5172413793103449
mapk1
0.6779661016949152
wdr5
0.6190476190476191
ttr
0.5932203389830508
kdm5a
0.7857142857142857
pik3cg
0.6290322580645161
rorc
0.6949152542372882
csnk2a1
0.6981132075471698
met
0.7169811320754716
folh1
0.91666666666666

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


0.36363636363636365
ca12
0.625
cdpk1
0.625
bcl2
0.75
ptpn11
0.6363636363636364
parp14
0.5833333333333334
rpa1
0.6923076923076923
mmp8
0.42857142857142855
lpxc
0.5333333333333333
dck
0.8125
hint1
0.6666666666666666
hspa8
0.5
nr1h2
0.8125
ada
0.6666666666666666
gart
0.8666666666666667
cgas
0.5555555555555556
grb2
nan
hprt1
0.0
dxr
nan
cpa1
0.4166666666666667
ephb4
0.6428571428571429
gls
0.9230769230769231
ppif
0.8333333333333334
ywhaz
nan
ripk2
0.13333333333333333
lgb
0.5714285714285714
pde6d
0.7692307692307693
alpha-lp
1.0
eif4e
0.0
rabggta
0.5555555555555556
crtm
0.0
grin1
0.6428571428571429
fimh
0.7142857142857143
stk24
0.14285714285714285
pfkfb3
0.9285714285714286
egln1
0.46153846153846156
chka
0.7692307692307693
mapkapk2
0.4166666666666667
fto
0.42857142857142855
lacz
0.0
furin
0.6666666666666666
gckr
1.0
bts1
nan
phgdh
0.46153846153846156
mapk8
0.1
gstp1
0.2222222222222222
shc
0.7692307692307693
prep
0.9230769230769231
setd7
0.4
adcy10
0.36363636363636365
hmgcr
1.0
pde9a
0.83333333