In [1]:
import numpy as np
import pandas as pd
import json
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
from rdkit.DataManip.Metric.rdMetricMatrixCalc import GetTanimotoSimMat
from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity
from scipy.spatial.distance import squareform
from scipy.cluster.hierarchy import linkage, fcluster

In [28]:
class MoleculeSimilarity() :
    
    def __init__(self, smiles_list) :
        self.smiles_list = self.remove_duplicated_smiles(smiles_list)
        self.mols = [Chem.MolFromSmiles(smiles) for smiles in self.smiles_list]
        self.fps = [self.get_morgan_fingerprint(mol) for mol in self.mols]
        self.sim_matrix = self.get_sim_matrix()
        print('Similarity matrix ready')
        
    def remove_duplicated_smiles(self, smiles_list) :
        new_list = []
        for smiles in smiles_list :
            if not smiles in new_list :
                new_list.append(smiles)
        return new_list
        
    def get_sim_matrix(self) :
#         sim_triangle = GetTanimotoSimMat(self.fps)
#         sim_matrix = self.tri2mat(sim_triangle)
#         return sim_matrix
        
        n_fps = len(self.fps)
        sim_matrix = np.eye(n_fps)
        for i, fp1 in enumerate(self.fps) :
            fp1 = self.fps[i]
            other_fps = self.fps[i+1:]
            sims = BulkTanimotoSimilarity(fp1, other_fps)
            for j, sim in enumerate(sims) :
                sim_matrix[i, i + 1 + j] = sim_matrix[i + 1 + j, i] = sim
        return sim_matrix
        
    def get_morgan_fingerprint(self, mol) :
        return AllChem.GetMorganFingerprintAsBitVect(mol, 3, useChirality=True)
        
    def get_similarity(self, fp1, fp2) :
        return DataStructs.TanimotoSimilarity(fp1, fp2)
        
    def find_closest_in_set(self, smiles, n=1) :
        
        #import pdb; pdb.set_trace()
        
        if smiles in self.smiles_list :
            smiles_index = self.smiles_list.index(smiles)
            sims = self.sim_matrix[smiles_index]
            sims = [sim for i, sim in enumerate(sims) if i != smiles_index]
            smiles_list = [s for i, s in enumerate(self.smiles_list) if i != smiles_index]
        else :
            input_mol = Chem.MolFromSmiles(smiles)
            input_fp = self.get_morgan_fingerprint(input_mol)
            sims = [self.get_similarity(input_fp, fp2) for fp2 in self.fps]
            smiles_list = self.smiles_list
            
        sims = np.array(sims)
        best_sim_indexes = np.argsort(-sims)[:n] # negate to get best
        closest_smiles = [smiles_list[i] for i in best_sim_indexes]
        closest_sims = sims[best_sim_indexes]
        return closest_smiles, closest_sims
        
    def tri2mat(self, tri_arr):
        n = len(tri_arr)
        m = int((np.sqrt(1 + 4 * 2 * n) + 1) / 2)
        arr = np.ones([m, m])
        for i in range(m):
            for j in range(i):
                arr[i][j] = tri_arr[i + j - 1]
                arr[j][i] = tri_arr[i + j - 1]
        return arr

In [29]:
smiles_df = pd.read_csv('data/smiles_df.csv')
all_smiles = smiles_df['smiles'].unique()
ms = MoleculeSimilarity(all_smiles)

Similarity matrix ready


In [30]:
dm = 1 - ms.sim_matrix
condensed_dm = squareform(dm)

In [31]:
Z = linkage(condensed_dm)

In [32]:
max_value = 0.5
T = fcluster(Z, t=max_value, criterion='distance')
print(f"Found {max(T)+1} clusters with max {max_value} different.\n")

Found 7786 clusters with max 0.5 different.



In [33]:
cluster_filename = 'molecule_clusters.json'
with open(cluster_filename, 'w') as f:
    json.dump(T.tolist(), f, indent=4)
print('Flat cluster result save at {}\n'.format(cluster_filename))

smiles_to_cluster = {all_smiles[i] : int(cluster_id) for i, cluster_id in enumerate(T)}
cluster_filename = 'smiles_clusters.json'
with open(cluster_filename, 'w') as f:
    json.dump(smiles_to_cluster, f, indent=4)
print('Flat cluster result save at {}\n'.format(cluster_filename))

Flat cluster result save at molecule_clusters.json

Flat cluster result save at smiles_clusters.json



In [34]:
from collections import Counter
from pdbbind_metadata_processor import PDBBindMetadataProcessor

In [35]:
clusters = T
counter = Counter(clusters)

In [36]:
top10clusters = counter.most_common()[:10]

In [37]:
top10clusters

[(5252, 1763),
 (5152, 148),
 (4875, 125),
 (4433, 84),
 (5187, 76),
 (4846, 48),
 (4258, 44),
 (5081, 40),
 (4344, 35),
 (4599, 33)]

In [41]:
pp = PDBBindMetadataProcessor()
table = pp.get_master_dataframe()

In [43]:
for big_cluster_id, count in top10clusters :
    most_common_smiles = [smiles for smiles, cluster_id in smiles_to_cluster.items() if cluster_id == big_cluster_id]
    most_common_pdbs = smiles_df[smiles_df['smiles'].isin(most_common_smiles)]['id'].unique()
    filtered_table = table[table['PDB code'].isin(most_common_pdbs)]
    print(filtered_table['protein name'].value_counts())

THROMBIN LIGHT CHAIN                           17
TRYPSIN                                        15
THROMBIN                                       15
GLUTAMATE CARBOXYPEPTIDASE 2                   14
TRANSPORTER                                    11
                                               ..
AMINO ACID TRANSPORTER                          1
GLUTAMATE [NMDA] RECEPTOR SUBUNIT 3A            1
BETAINE--HOMOCYSTEINE S-METHYLTRANSFERASE 1     1
GAG-PRO-POL POLYPROTEIN                         1
BOTULINUM NEUROTOXIN A LIGHT CHAIN              1
Name: protein name, Length: 204, dtype: int64
GLYCOGEN PHOSPHORYLASE                           16
BIFUNCTIONAL LIGASE/REPRESSOR BIRA               10
SERINE/THREONINE-PROTEIN KINASE HASPIN            9
ADENOSINE KINASE                                  9
PANTOTHENATE SYNTHETASE                           7
                                                 ..
3-PHOSPHOINOSITIDE-DEPENDENT PROTEIN KINASE 1     1
MRNA CAPPING ENZYME                     

In [45]:
frac_train = 0.8
frac_val = 0.1

train_cutoff = int(frac_train * len(all_smiles))
val_cutoff = int((frac_train + frac_val) * len(all_smiles))
train_inds = []
val_inds = []
test_inds = []

In [46]:
unique_cluster_ids = list(counter.keys())

In [50]:
import os
import random
data_dir_path = 'data/'
ecfp_similarity_splits_dir_name = 'ecfp_similarity_splits'
ecfp_similarity_splits_dir_path = os.path.join(data_dir_path, ecfp_similarity_splits_dir_name)
if not os.path.exists(ecfp_similarity_splits_dir_path) :
    os.mkdir(ecfp_similarity_splits_dir_path)

In [58]:
cluster_smiles

['COC(=O)N1CC[NH+]([C@H]2CC[C@H](Nc3nc(Nc4cnn(C)c4)nc4ccc(CC#N)nc34)CC2)CC1']

In [62]:
from tqdm import tqdm

In [64]:
for i in range(5) :
    
    random.shuffle(unique_cluster_ids)
    
    train_smiles = []
    val_smiles = []
    test_smiles = []
    
    for current_cluster_id in tqdm(unique_cluster_ids):
        cluster_smiles = [smiles 
                          for smiles, cluster_id in smiles_to_cluster.items() 
                          if cluster_id == current_cluster_id]
        if len(train_smiles) + len(cluster_smiles) > train_cutoff:
            if len(train_smiles) + len(val_smiles) + len(cluster_smiles) > val_cutoff:
                test_smiles.extend(cluster_smiles)
            else:
                val_smiles.extend(cluster_smiles)
        else:
            train_smiles.extend(cluster_smiles)
    
    with open(os.path.join(ecfp_similarity_splits_dir_path, f'train_smiles_ecfp_similarity_split_{i}.txt'), 'w') as f :
        for pdb in train_smiles :
            f.write(pdb)
            f.write('\n')
        
    with open(os.path.join(ecfp_similarity_splits_dir_path, f'val_smiles_ecfp_similarity_split_{i}.txt'), 'w') as f :
        for pdb in val_smiles :
            f.write(pdb)
            f.write('\n')
        
    with open(os.path.join(ecfp_similarity_splits_dir_path, f'test_smiles_ecfp_similarity_split_{i}.txt'), 'w') as f :
        for pdb in test_smiles :
            f.write(pdb)
            f.write('\n')

100%|███████████████████████████████████████| 7785/7785 [01:37<00:00, 79.58it/s]
100%|███████████████████████████████████████| 7785/7785 [01:36<00:00, 80.47it/s]
100%|███████████████████████████████████████| 7785/7785 [01:37<00:00, 79.62it/s]
100%|███████████████████████████████████████| 7785/7785 [01:37<00:00, 79.79it/s]
100%|███████████████████████████████████████| 7785/7785 [01:38<00:00, 79.41it/s]


In [65]:
len(train_smiles)

11508

In [66]:
len(val_smiles)

1438

In [67]:
len(test_smiles)

1439