In [1]:
from rdkit import Chem, RDLogger 
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers

# Silence every RDKit log channel ('rdApp.*') so RDKit warnings/errors are not printed.
# NOTE(review): presumably to keep parse errors for invalid predicted SMILES out of the
# notebook output during scoring — confirm.
RDLogger.DisableLog('rdApp.*')

def demap(smi):
    """Strip atom-map numbers from a SMILES string and return canonical RDKit SMILES.

    Parameters
    ----------
    smi : str
        SMILES, possibly atom-mapped (e.g. the reactant side of a mapped reaction).

    Returns
    -------
    str
        RDKit SMILES of the same molecule(s) with every atom-map number set to 0.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail with the
        # offending input instead of an opaque AttributeError on None.GetAtoms().
        raise ValueError('cannot parse SMILES: %r' % (smi,))
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol)
    # Re-parse and re-serialize the demapped SMILES. NOTE(review): presumably so the
    # result matches what RDKit emits for an unmapped input parsed from scratch (the
    # demapped molecule may still carry parse-time state from the mapped SMILES) — confirm.
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi))

def exact_match(preds, reac):
    """Return the 1-based rank of the first prediction that equals ``reac`` exactly.

    Parameters
    ----------
    preds : iterable of str
        Ranked candidate SMILES, best first.
    reac : str
        Ground-truth SMILES to compare against (plain string equality).

    Returns
    -------
    int
        1-based position of the first candidate equal to ``reac``; -1 if no candidate
        matches, ``preds`` is empty, or ``preds`` is not iterable (e.g. None).
    """
    try:
        return next((rank for rank, pred in enumerate(preds, start=1) if pred == reac), -1)
    except TypeError:
        # Non-iterable ``preds`` (e.g. None) counts as "no match". Only TypeError is
        # caught so genuine errors and KeyboardInterrupt/SystemExit are not swallowed.
        return -1
    
def get_isomers(smi):
    """Enumerate stereoisomers of a SMILES string as isomeric SMILES strings.

    Enumeration uses RDKit's ``EnumerateStereoisomers`` with its default options.
    NOTE(review): per the RDKit docs the defaults expand only stereo elements left
    unassigned in ``smi``; verify against the installed RDKit version.

    Parameters
    ----------
    smi : str
        SMILES to expand.

    Returns
    -------
    list of str
        One ``MolToSmiles(..., isomericSmiles=True)`` string per enumerated isomer.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if mol is None:
        # Passing None into EnumerateStereoisomers fails with an obscure RDKit error;
        # report the unparseable input explicitly instead.
        raise ValueError('cannot parse SMILES: %r' % (smi,))
    return [Chem.MolToSmiles(isomer, isomericSmiles=True)
            for isomer in EnumerateStereoisomers(mol)]
    
def isomer_match(preds, reac):
    """Return the 1-based rank of the first prediction that matches ``reac`` up to stereo.

    A candidate matches when every stereoisomer ``get_isomers`` enumerates for it is
    also among the stereoisomers enumerated for ``reac``.

    Candidates that ``get_isomers`` cannot process (e.g. unparseable SMILES) are skipped
    individually, so an invalid prediction ranked early does not hide a valid match
    ranked later.

    Parameters
    ----------
    preds : iterable of str
        Ranked candidate SMILES, best first.
    reac : str
        Ground-truth SMILES.

    Returns
    -------
    int
        1-based position of the first matching candidate; -1 if none matches, ``reac``
        cannot be processed, or ``preds`` is not iterable (e.g. None).
    """
    try:
        reac_isomers = set(get_isomers(reac))
        ranked = enumerate(preds, start=1)
    except Exception:
        # Unprocessable ground truth or non-iterable preds: nothing can match.
        return -1
    for rank, pred in ranked:
        try:
            pred_isomers = set(get_isomers(pred))
        except Exception:
            # Best-effort scoring: skip only this candidate, keep checking later ones.
            continue
        # An empty set is trivially a subset of anything; never count it as a match.
        if pred_isomers and pred_isomers.issubset(reac_isomers):
            return rank
    return -1

In [2]:
import pandas as pd

# `dataset` names the folder under data/ that holds the raw test split.
dataset = 'USPTO_50K'
test_file = pd.read_csv('data/{}/raw_test.csv'.format(dataset))

# Ground truth: the reactant half (left of the first '>>') of every reaction,
# passed through demap to strip atom-map numbers and canonicalize the SMILES.
ground_truth = [
    demap(rxn.partition('>>')[0])
    for rxn in test_file['reactants>reagents>production']
]

In [3]:
# Switches selecting which decoder run's prediction file to evaluate.
GRA = True
use_class = False

# Replace each switch with the fragment it contributes to the results path.
GRA = 'GRA' if GRA else 'noGRA'
use_class = '_class' if use_class else ''
result_file = 'results/{}_{}_outputs/decoded{}_prediction.txt'.format(dataset, GRA, use_class)

# results[n] = tab-separated candidate SMILES from line n of the prediction file, in file
# order (presumably best-ranked first, since list position is scored as top-k rank later).
results = {}
with open(result_file, 'r') as pred_file:
    for line_no, raw_line in enumerate(pred_file):
        results[line_no] = raw_line.rstrip('\n').split('\t')
        
# Score every test reaction. Each list stores, per reaction, the 1-based rank of the first
# candidate matching the ground truth (exactly / up to stereoisomers), or -1 if none does.
n_reactions = len(results)

# Accuracies are reported relative to len(results): a prediction file shorter than the test
# set would silently be scored on a different population, and a longer one (e.g. a trailing
# blank line) would crash mid-loop with IndexError. Fail fast with a clear message instead.
assert n_reactions == len(ground_truth), (
    'prediction file has %d lines but the test set has %d reactions'
    % (n_reactions, len(ground_truth)))

exact_matches = []
isomer_matches = []
for i in range(n_reactions):
    exact_matches.append(exact_match(results[i], ground_truth[i]))
    isomer_matches.append(isomer_match(results[i], ground_truth[i]))
    done = i + 1
    # Throttle progress output, but always report the final count so the log ends at n/n.
    if done % 100 == 0 or done == n_reactions:
        print('\rCalculating exact accuracy... %s/%s' % (done, n_reactions), end='', flush=True)

Calculating exact accuracy... 5000/5007

In [4]:
ks = [1, 3, 5, 10, 50]
n_eval = len(results)

# Top-k hit counts. A stored rank of -1 means "no candidate matched"; any other rank is the
# 1-based position of the first matching candidate, so it is a hit for every k >= rank.
exact_k = {}
isomer_k = {}
for k in ks:
    exact_k[k] = sum(1 for i in range(n_eval)
                     if exact_matches[i] != -1 and exact_matches[i] <= k)
    isomer_k[k] = sum(1 for i in range(n_eval)
                      if isomer_matches[i] != -1 and isomer_matches[i] <= k)

for k in ks:
    print('Top-%d accuracy: exact: %.3f, isomer: %.3f'
          % (k, exact_k[k] / n_eval, isomer_k[k] / n_eval))

Top-1 accuracy: exact: 0.529, isomer: 0.534
Top-3 accuracy: exact: 0.768, isomer: 0.775
Top-5 accuracy: exact: 0.852, isomer: 0.859
Top-10 accuracy: exact: 0.916, isomer: 0.924
Top-50 accuracy: exact: 0.969, isomer: 0.977
