In [1]:
from rdkit import Chem
from rdkit.Chem import RDConfig
from rdkit.Chem.QED import qed
import os
from tqdm import tqdm
import seaborn as sns
import numpy as np

In [2]:
import sys
sys.path.append('../evaluation/')
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer

In [3]:
from utils import build_pdb_dict

In [4]:
bridge_type = 'vp'

# egnn
# fixed point init
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-05-31_14_11_45.077216'
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-05-30_23_27_53.688104'
# Gaussian noise init
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-06-17_23_04_23.779433'
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-06-17_23_06_28.209248'

# transformer
# root_path = '../lightning_logs/vp_bridge_CombinedSparseGraphDataset_2024-06-01_21_36_34.208973'
# root_path = '../lightning_logs/vp_bridge_CombinedSparseGraphDataset_2024-05-31_23_42_37.443630'

# only basic

root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-07-19_14_29_28.164795'
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-07-19_14_29_44.297462'

# basic + aromatic

# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-07-20_02_24_30.913781'
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-07-20_13_18_12.746286'

# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-07-21_23_58_36.562980'

aromatic = False

gen_path = os.path.join(root_path, 'reconstructed_mols')
gen_path = gen_path + '_aromatic_mode' if aromatic else gen_path

In [5]:
raw_data_path = '../../data/cleaned_crossdocked_data/raw'
pdb_dict, pdb_rev_dict = build_pdb_dict(raw_data_path)

In [6]:
def get_mols(gen_path, raw_data_path=raw_data_path):
    gen_mols, ref_mols = {}, {}
    for file in tqdm(os.listdir(gen_path)):
        ligand = file.split('.')[0]
        ref_folder = pdb_rev_dict[file]

        gen_m = Chem.MolFromMolFile(os.path.join(gen_path, file))
        ref_m = Chem.MolFromMolFile(os.path.join(raw_data_path, ref_folder, file))
        if gen_m == None or ref_m == None:
            continue
            
        gen_mols[ligand] = gen_m
        ref_mols[ligand] = ref_m

    return gen_mols, ref_mols

In [7]:
gen_mols, ref_mols = get_mols(gen_path)

  0%|▍                                                                                                                                       | 93/28871 [00:00<01:34, 303.37it/s][21:11:29] Explicit valence for atom # 19 N, 4, is greater than permitted
  1%|█                                                                                                                                      | 214/28871 [00:00<01:54, 249.79it/s][21:11:29] Explicit valence for atom # 13 N, 4, is greater than permitted
  1%|█▎                                                                                                                                     | 273/28871 [00:01<01:46, 269.59it/s][21:11:29] Explicit valence for atom # 0 N, 4, is greater than permitted
  2%|██▊                                                                                                                                    | 602/28871 [00:02<01:36, 292.54it/s][21:11:30] Explicit valence for atom # 15 N, 4, is greater than permitt

In [8]:
len(gen_mols), len(ref_mols)

(28819, 28819)

In [9]:
def compute_sa_score(mols, threshold = 5.5):
    
    sa_scores = [sascorer.calculateScore(mol) if mol!=None else 10 for mol in tqdm(mols)]
    pct_easily_synthesized = len([score for score in sa_scores if score <= threshold])/len(sa_scores)
    return sa_scores, pct_easily_synthesized

In [10]:
sa_scores, pct_easily_synthesized = compute_sa_score(gen_mols.values())
pct_easily_synthesized

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28819/28819 [00:03<00:00, 7738.25it/s]


0.7863562233248899

In [11]:
ref_sa_scores, ref_pct_easily_synthesized = compute_sa_score(ref_mols.values())

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28819/28819 [00:03<00:00, 7319.93it/s]


In [12]:
ref_pct_easily_synthesized

0.9764391547243139

In [13]:
np.mean(sa_scores), np.mean(ref_sa_scores)

(4.615065682319571, 2.8614071372342096)

In [14]:
qed_scores = [qed(mol) for mol in tqdm(list(gen_mols.values()))]
ref_qed_scores = [qed(mol) for mol in tqdm(ref_mols.values())]
qed_scores, ref_qed_scores

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28819/28819 [00:23<00:00, 1227.60it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28819/28819 [00:22<00:00, 1307.42it/s]


([0.7217362549431576,
  0.4681048987635945,
  0.200123627369949,
  0.4816019169622866,
  0.6825206414359867,
  0.43537546360164203,
  0.44210961702042856,
  0.36304652681546207,
  0.4909691069150022,
  0.31041772849835914,
  0.3665752032363416,
  0.6035002406315105,
  0.23710536580837446,
  0.09278417492544642,
  0.1971591364962201,
  0.5647276065407333,
  0.35275816313825537,
  0.29975436196084215,
  0.47708121717536184,
  0.28465565148342153,
  0.4262123620618739,
  0.5265440122916606,
  0.5101790934381121,
  0.42467595773448463,
  0.5890571009862796,
  0.7038474143529596,
  0.24092971047257586,
  0.49680712743480754,
  0.6687182969332831,
  0.679897234138452,
  0.39985429803244105,
  0.46311772070346324,
  0.5671232539696827,
  0.7472989276585272,
  0.5456732277192704,
  0.4825489761530565,
  0.4672790498320149,
  0.5225843587187353,
  0.6197455583870372,
  0.5298963017621042,
  0.5076178485076364,
  0.12281493397303918,
  0.17953428131281757,
  0.6762813913281088,
  0.2799069039952

In [15]:
np.mean(qed_scores), np.mean(ref_qed_scores)

(0.4022812828086389, 0.5738257956063266)