In [6]:
import pickle
import os
from collections import Counter, defaultdict

import pulp
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs
import pandas as pd
from myopic_mces import MCES
from joblib import Parallel, delayed
from tqdm import tqdm
from tqdm_joblib import tqdm_joblib


from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

In [7]:
# Select a tautomer canonicalizer for the installed RDKit version:
#  - 'v1': the legacy Python MolStandardize.tautomer module (older RDKit); it is used with
#    only the two transforms listed in _TAUTOMER_TRANSFORMS;
#  - 'v2': rdMolStandardize.TautomerEnumerator (newer RDKit).
try:
    from rdkit.Chem.MolStandardize.tautomer import TautomerCanonicalizer, TautomerTransform
    _RD_TAUTOMER_CANONICALIZER = 'v1'
    _TAUTOMER_TRANSFORMS = (
        TautomerTransform('1,3 heteroatom H shift',
                          '[#7,S,O,Se,Te;!H0]-[#7X2,#6,#15]=[#7,#16,#8,Se,Te]'),
        TautomerTransform('1,3 (thio)keto/enol r', '[O,S,Se,Te;X2!H0]-[C]=[C]'),
    )
except ImportError:
    # ImportError (the parent of ModuleNotFoundError) so that a legacy module that exists
    # but does not provide these names also falls back to the newer API.
    from rdkit.Chem.MolStandardize.rdMolStandardize import TautomerEnumerator  # newer rdkit
    _RD_TAUTOMER_CANONICALIZER = 'v2'

def canonical_mol_from_inchi(inchi):
    """Parse an InChI string with RDKit and return the canonical tautomer of the molecule.

    Returns None when Chem.MolFromInchi cannot parse `inchi`. The canonicalizer depends on
    the module-level `_RD_TAUTOMER_CANONICALIZER` flag: 'v1' uses the legacy
    TautomerCanonicalizer restricted to `_TAUTOMER_TRANSFORMS`; any other value uses
    rdMolStandardize.TautomerEnumerator.

    Note that this function may be 50 times slower than Chem.MolFromInchi alone.
    """
    parsed = Chem.MolFromInchi(inchi)
    if parsed is None:
        return None

    # The canonicalizer is built per call rather than cached at module level.
    # NOTE(review): presumably this keeps unpicklable RDKit objects out of the globals
    # shipped to joblib workers — confirm before hoisting it.
    if _RD_TAUTOMER_CANONICALIZER != 'v1':
        return TautomerEnumerator().Canonicalize(parsed)
    return TautomerCanonicalizer(transforms=_TAUTOMER_TRANSFORMS).canonicalize(parsed)

def mol2smiles(mol):
    """Sanitize `mol` in place and return its SMILES string, or None on failure.

    Returns None when `mol` is None (e.g. a failed generation) or when
    Chem.SanitizeMol raises ValueError. On success `mol` is left sanitized.
    """
    # Chem.SanitizeMol(None) raises a Boost ArgumentError rather than ValueError,
    # so reject None explicitly instead of crashing the caller.
    if mol is None:
        return None
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        return None
    return Chem.MolToSmiles(mol)

def is_valid(mol):
    """Return True if `mol` is a usable single-molecule prediction.

    A molecule is accepted when it is not None, RDKit sanitization succeeds
    (performed in place by mol2smiles, so an accepted `mol` is left sanitized),
    and splitting it into fragments succeeds with at most one fragment.
    """
    if mol is None:
        return False
    # mol2smiles is called for its in-place sanitization; only its failure signal is used.
    if mol2smiles(mol) is None:
        return False

    try:
        mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
    except Exception:
        return False

    # Predictions with more than one disconnected fragment are rejected.
    # NOTE(review): a mol with no atoms presumably yields zero fragments and is accepted
    # here — confirm whether that is intended.
    return len(mol_frags) <= 1

def compute_metrics_for_one(t_inchi, p_inchi, solver, doMCES=False, doFull=False):
    """Score one ranked candidate list against its ground-truth molecule.

    Parameters
    ----------
    t_inchi : str
        InChI of the ground-truth molecule.
    p_inchi : list of str
        Candidate InChIs, best-ranked first (compute_metrics passes a de-duplicated,
        frequency-ordered list).
    solver : str or None
        Solver name forwarded to ``myopic_mces.MCES``; only used when ``doMCES`` is True.
    doMCES : bool, default False
        If True, score each candidate with element [1] of ``myopic_mces.MCES``'s
        result (presumably the MCES distance — confirm for the installed version).
        Otherwise, and whenever that call fails, the candidate's MCES value is the
        combined bond count of both molecules.
    doFull : bool, default False
        Report metrics for k = 1..100 if True, otherwise for k = 1..10.

    Returns
    -------
    collections.defaultdict(float)
        For every k:
        - 'acc@k': 1.0 if the true InChI is among the first k candidates.
        - 'mces@k': lowest MCES value among the first k, capped at 100.
        - 'tanimoto@k' / 'cosine@k': highest Morgan-fingerprint similarity among
          the first k.
        - 'close_match@k' / 'meaningful_match@k': 1.0 if tanimoto@k >= 0.675 / >= 0.4.
        With no candidates, mces@k is 100 and the similarities are 0.0.

    Raises
    ------
    ValueError
        If ``t_inchi`` cannot be parsed into a molecule.
    """
    # Disable RDKit logging again here: this runs inside joblib workers, which may not
    # share the notebook process's logger settings.
    RDLogger.DisableLog('rdApp.*')

    max_k = 100 if doFull else 10
    close_threshold = 0.675
    meaningful_threshold = 0.4

    true_mol = canonical_mol_from_inchi(t_inchi)
    if true_mol is None:
        raise ValueError(f'could not parse ground-truth InChI: {t_inchi!r}')
    true_smi = Chem.MolToSmiles(true_mol)
    true_fp = AllChem.GetMorganFingerprintAsBitVect(true_mol, 2, nBits=2048)
    true_num_bonds = true_mol.GetNumBonds()

    # Per-candidate scores. Each list gets exactly one entry per candidate so that
    # position j always refers to p_inchi[j].
    p_mces, p_tanimoto, p_cosine = [], [], []
    for pi in p_inchi:
        pmol = canonical_mol_from_inchi(pi)
        if pmol is None:
            # The candidate InChI could not be parsed back into a molecule. Keep its rank
            # slot, but score it as a complete miss (as if it had no bonds).
            p_mces.append(true_num_bonds)
            p_tanimoto.append(0.0)
            p_cosine.append(0.0)
            continue

        # Fallback MCES value: the bond count of both molecules combined.
        mces_fallback = true_num_bonds + pmol.GetNumBonds()
        mces = mces_fallback
        if doMCES:
            try:
                mces = MCES(true_smi, Chem.MolToSmiles(pmol), solver=solver, threshold=100,
                            always_stronger_bound=False, solver_options=dict(msg=0))[1]
            except Exception:
                mces = mces_fallback
        p_mces.append(mces)

        try:
            pmol_fp = AllChem.GetMorganFingerprintAsBitVect(pmol, 2, nBits=2048)
            tanimoto = DataStructs.TanimotoSimilarity(true_fp, pmol_fp)
            cosine = DataStructs.CosineSimilarity(true_fp, pmol_fp)
        except Exception:
            tanimoto, cosine = 0.0, 0.0
        # Append both after computing both, so a failure part-way cannot misalign the lists.
        p_tanimoto.append(tanimoto)
        p_cosine.append(cosine)

    # Running best-so-far over the ranking: element j is the best score among the first
    # j candidates. Element 0 (no candidates) holds the starting values: 100 for MCES
    # (min() therefore also caps every reported MCES at 100) and 0.0 for the similarities.
    prefix_min_mces = [100]
    prefix_max_tanimoto = [0.0]
    prefix_max_cosine = [0.0]
    for cand_mces, cand_tanimoto, cand_cosine in zip(p_mces, p_tanimoto, p_cosine):
        prefix_min_mces.append(min(prefix_min_mces[-1], cand_mces))
        prefix_max_tanimoto.append(max(prefix_max_tanimoto[-1], cand_tanimoto))
        prefix_max_cosine.append(max(prefix_max_cosine[-1], cand_cosine))

    # 0-based rank of the first exact InChI match, or -1 if the true molecule is absent.
    earliest_idx = p_inchi.index(t_inchi) if t_inchi in p_inchi else -1

    m_local = defaultdict(float)
    for k in range(1, max_k + 1):
        # With fewer than k candidates, the top-k scores are those of all candidates.
        idx = min(k, len(p_inchi))
        best_tanimoto = prefix_max_tanimoto[idx]
        m_local[f'acc@{k}'] = 1.0 if 0 <= earliest_idx < k else 0.0
        m_local[f'mces@{k}'] = prefix_min_mces[idx]
        m_local[f'tanimoto@{k}'] = best_tanimoto
        m_local[f'cosine@{k}'] = prefix_max_cosine[idx]
        m_local[f'close_match@{k}'] = 1.0 if best_tanimoto >= close_threshold else 0.0
        m_local[f'meaningful_match@{k}'] = 1.0 if best_tanimoto >= meaningful_threshold else 0.0

    return m_local

def _ranked_candidate_inchis(candidates, limit=None):
    """Return the InChIs of the valid `candidates`, de-duplicated and most-frequent first.

    Ties keep first-seen order (Counter.most_common). If `limit` is given, the
    ranked list is truncated to that many entries.
    """
    inchis = []
    for mol in candidates:
        if not is_valid(mol):
            continue
        inchi = Chem.MolToInchi(mol)
        # MolToInchi can return an empty string when InChI generation fails. Such an
        # entry cannot be parsed back later, so drop it here.
        if inchi:
            inchis.append(inchi)
    ranked = [inchi for inchi, _count in Counter(inchis).most_common()]
    return ranked if limit is None else ranked[:limit]


def compute_metrics(true, pred, csv_path, doMCES=False, doFull=False):
    """Compute averaged top-k retrieval metrics over a test set and write them to a CSV.

    Parameters
    ----------
    true : sequence of RDKit Mol
        Ground-truth molecules, one per test example.
    pred : sequence of sequences of RDKit Mol (or None)
        pred[i] holds the sampled candidate molecules for true[i]. Invalid or
        disconnected candidates (see is_valid) are discarded. The remaining ones are
        ranked by how often their InChI was sampled.
    csv_path : str or path-like
        Destination of the one-row CSV of averaged metrics.
    doMCES : bool, default False
        Compute MCES with myopic_mces (requires an available PuLP solver).
    doFull : bool, default False
        Report k = 1..100 if True, otherwise k = 1..10.

    Returns
    -------
    pandas.DataFrame
        The one-row frame that was written to `csv_path`.

    Raises
    ------
    ValueError
        If `true` and `pred` differ in length, or are empty.
    RuntimeError
        If doMCES is True but PuLP reports no available solver.
    """
    if len(true) != len(pred):
        raise ValueError(
            f'true and pred must have the same length, got {len(true)} and {len(pred)}')
    if len(true) == 0:
        raise ValueError('no molecules to evaluate')

    # Top-1..10 metrics only read the first 10 candidates, so the list is truncated
    # unless the full top-100 evaluation is requested.
    limit = None if doFull else 11
    pred_inchi = [_ranked_candidate_inchis(candidates, limit) for candidates in pred]
    true_inchi = [Chem.MolToInchi(mol) for mol in true]

    # compute_metrics_for_one uses the solver only when doMCES is set, so a missing
    # solver blocks only the MCES path.
    solver = None
    if doMCES:
        available_solvers = pulp.listSolvers(onlyAvailable=True)
        if not available_solvers:
            raise RuntimeError('doMCES=True requires at least one available PuLP solver')
        solver = available_solvers[0]

    with tqdm_joblib(tqdm(total=len(true_inchi))):
        results = Parallel(n_jobs=-1)(
            delayed(compute_metrics_for_one)(t_inchi, p_inchi, solver,
                                             doMCES=doMCES, doFull=doFull)
            for t_inchi, p_inchi in zip(true_inchi, pred_inchi)
        )

    # Sum every per-example metric, then average over the test set.
    n_examples = len(true_inchi)
    final_metrics = defaultdict(float)
    for per_example in results:
        for key, val in per_example.items():
            final_metrics[key] += val

    max_k = 100 if doFull else 10
    metric_names = ('acc', 'mces', 'tanimoto', 'cosine', 'close_match', 'meaningful_match')
    for k in range(1, max_k + 1):
        for metric in metric_names:
            final_metrics[f'{metric}@{k}'] /= n_examples

    df = pd.DataFrame(final_metrics, index=[0])
    df.to_csv(csv_path, index=False)
    return df


In [None]:
# Example: load model predictions as saved by the test step of diffusion_model_spec2mol.py.
# Paths and loading will differ for other runs.
# NOTE(review): pickle.load can execute arbitrary code from the file — only load trusted results.
CANOPUS_RESULTS_DIR = "../final_results/canopus"

canopus_true = []
canopus_pred = []
for idx in range(1, 5):
    # Eval run `idx` stores its files at indices idx-1, idx+3, idx+7, ... (stride 4).
    # Presumably one file per worker/shard — confirm against the writer.
    i = idx - 1
    while os.path.exists(f"{CANOPUS_RESULTS_DIR}/spec2mol-canopus-eval-{idx}_resume_pred_{i}.pkl"):
        with open(f"{CANOPUS_RESULTS_DIR}/spec2mol-canopus-eval-{idx}_resume_true_{i}.pkl", 'rb') as f:
            canopus_true.extend(pickle.load(f))
        with open(f"{CANOPUS_RESULTS_DIR}/spec2mol-canopus-eval-{idx}_resume_pred_{i}.pkl", 'rb') as f:
            canopus_pred.extend(pickle.load(f))
        i += 4

# compute_metrics pairs canopus_true[i] with canopus_pred[i]. Fail fast if nothing was
# found or the shards did not line up.
assert canopus_true, f"no prediction files found under {CANOPUS_RESULTS_DIR}"
assert len(canopus_true) == len(canopus_pred), (len(canopus_true), len(canopus_pred))

In [None]:
# Top-1..100 metrics without MCES for the CANOPUS predictions; results go to canopus_metrics.csv.
compute_metrics(
    canopus_true,
    canopus_pred,
    csv_path="canopus_metrics.csv",
    doMCES=False,
    doFull=True,
);