In [1]:
from rdkit import Chem
from rdkit.Chem import RDConfig
from rdkit.Chem.QED import qed
import os
from tqdm import tqdm
import seaborn as sns
import numpy as np

In [2]:
import sys
sys.path.append('../evaluation/')
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
import sascorer

In [3]:
from utils import build_pdb_dict

In [29]:
bridge_type = 'vp'

# egnn
# fixed point init
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-05-31_14_11_45.077216'
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-05-30_23_27_53.688104'
# Gaussian noise init
# root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-06-17_23_04_23.779433'
root_path = '../lightning_logs/vp_bridge_egnn_CombinedSparseGraphDataset_2024-06-17_23_06_28.209248'

# transformer
# root_path = '../lightning_logs/vp_bridge_CombinedSparseGraphDataset_2024-06-01_21_36_34.208973'
# root_path = '../lightning_logs/vp_bridge_CombinedSparseGraphDataset_2024-05-31_23_42_37.443630'
aromatic = False

gen_path = os.path.join(root_path, 'reconstructed_mols')
gen_path = gen_path + '_aromatic_mode' if aromatic else gen_path

In [30]:
raw_data_path = '../../data/cleaned_crossdocked_data/raw'
pdb_dict, pdb_rev_dict = build_pdb_dict(raw_data_path)

In [31]:
def get_mols(gen_path, raw_data_path=raw_data_path):
    gen_mols, ref_mols = {}, {}
    for file in tqdm(os.listdir(gen_path)):
        ligand = file.split('.')[0]
        ref_folder = pdb_rev_dict[file]

        gen_m = Chem.MolFromMolFile(os.path.join(gen_path, file))
        ref_m = Chem.MolFromMolFile(os.path.join(raw_data_path, ref_folder, file))
        if gen_m == None or ref_m == None:
            continue
            
        gen_mols[ligand] = gen_m
        ref_mols[ligand] = ref_m

    return gen_mols, ref_mols

In [32]:
gen_mols, ref_mols = get_mols(gen_path)

  0%|                                                                                                                                                                  | 0/27169 [00:00<?, ?it/s][23:04:39] Explicit valence for atom # 11 N, 4, is greater than permitted
[23:04:39] Explicit valence for atom # 6 N, 4, is greater than permitted
[23:04:39] Explicit valence for atom # 4 N, 4, is greater than permitted
  0%|▌                                                                                                                                                      | 102/27169 [00:00<00:28, 954.50it/s][23:04:39] Explicit valence for atom # 17 N, 4, is greater than permitted
[23:04:39] Explicit valence for atom # 23 N, 4, is greater than permitted
[23:04:39] Explicit valence for atom # 23 N, 4, is greater than permitted
[23:04:39] Explicit valence for atom # 54 N, 4, is greater than permitted
[23:04:39] Explicit valence for atom # 23 N, 4, is greater than permitted
  1%|█                  

In [33]:
len(gen_mols), len(ref_mols)

(25792, 25792)

In [34]:
def compute_sa_score(mols, threshold = 5.5):
    
    sa_scores = [sascorer.calculateScore(mol) if mol!=None else 10 for mol in tqdm(mols)]
    pct_easily_synthesized = len([score for score in sa_scores if score <= threshold])/len(sa_scores)
    return sa_scores, pct_easily_synthesized

In [35]:
sa_scores, pct_easily_synthesized = compute_sa_score(gen_mols.values())
pct_easily_synthesized

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25792/25792 [00:14<00:00, 1814.40it/s]


0.5446650124069479

In [36]:
ref_sa_scores, ref_pct_easily_synthesized = compute_sa_score(ref_mols.values())

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25792/25792 [00:12<00:00, 2034.61it/s]


In [37]:
ref_pct_easily_synthesized

0.9805366004962779

In [38]:
np.mean(sa_scores), np.mean(ref_sa_scores)

(5.376894217425706, 2.8132263697875293)

In [39]:
qed_scores = [qed(mol) for mol in tqdm(list(gen_mols.values()))]
ref_qed_scores = [qed(mol) for mol in tqdm(ref_mols.values())]
qed_scores, ref_qed_scores

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25792/25792 [01:00<00:00, 425.07it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25792/25792 [00:45<00:00, 573.11it/s]


([0.6330878093502369,
  0.680371168077178,
  0.5341584722139744,
  0.6620506962890078,
  0.22077220517015492,
  0.46298719118228016,
  0.4991127520896325,
  0.6750072052018974,
  0.22037244531495875,
  0.5594883792779578,
  0.3257543829146056,
  0.2273610765698463,
  0.47604674484479204,
  0.4523845983369914,
  0.4791108038614709,
  0.45368977080340855,
  0.40364897881552686,
  0.6287021629921485,
  0.5005123294608407,
  0.44296800122194463,
  0.5717796844939236,
  0.4379503787441874,
  0.412144360685853,
  0.30572545697169157,
  0.7529898065103162,
  0.7033069577281051,
  0.18988091813189584,
  0.4532122486343476,
  0.42956919752537503,
  0.6625778799211935,
  0.27721456093472013,
  0.4848768622294365,
  0.5692666781656894,
  0.32146019735884707,
  0.5088310301358154,
  0.598073809158448,
  0.6651493862440453,
  0.5906324728519738,
  0.4875791480222217,
  0.42779815060323223,
  0.38127065449816,
  0.09481519907300501,
  0.559900864675293,
  0.5063979188875107,
  0.7834056952954008,
  

In [40]:
np.mean(qed_scores), np.mean(ref_qed_scores)

(0.47958556739743874, 0.5838308142332064)