In [1]:
from rdkit import Chem, RDLogger 
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers

RDLogger.DisableLog('rdApp.*')

def demap(smi):
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol)
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi))

def get_isomers(smi):
    mol = Chem.MolFromSmiles(smi)
    isomers = tuple(EnumerateStereoisomers(mol))
    isomers_smi = [Chem.MolToSmiles(x, isomericSmiles=True) for x in isomers]
    return isomers_smi
    
def isomer_match(preds, reac):
    try:
        reac_isomers = get_isomers(reac)
        for k, pred in enumerate(preds):
            pred_isomers = get_isomers(pred)
            if(set(pred_isomers).issubset(set(reac_isomers))):
                return k+1
        return -1
    except:
        return -1

In [2]:
import pandas as pd
dataset = 'USPTO_50K'
test_file = pd.read_csv('data/%s/raw_test.csv' % dataset)
ground_truth = [demap(rxn.split('>>')[0]) for rxn in test_file['reactants>reagents>production']]

In [11]:
GRA = True
use_class = False

if GRA:
    GRA = 'GRA'
else:
    GRA = 'noGRA'
    
if use_class:
    use_class = '_class'
else:
    use_class = ''
result_file = 'Results/%s_%s_outputs/decoded%s_prediction.txt' % (dataset, GRA, use_class)
results = {}
with open(result_file, 'r') as f:
    for i, line in enumerate(f.readlines()):
        results[i] = line.split('\t')
        
correct_k = []
for i in range(len(results)):
    correct_k.append(isomer_match(results[i], ground_truth[i]))
    if i % 100 == 0:
        print ('\rCalculating exact accuracy... %s/%s' % (i, len(results)), end='', flush=True)

Calculating exact accuracy... 3000/3005

In [10]:
ks = [1, 3, 5, 10, 50]
k_match = {k:0 for k in ks}

for k in correct_k:
    for kk in k_match:
        if k <= kk and k != -1:
            k_match[kk] += 1

for kk in k_match:
    print (kk, k_match[kk]/len(correct_k))

1 0.5696821515892421
3 0.7799511002444988
5 0.8410757946210269
10 0.8948655256723717
50 0.9682151589242054


In [5]:
print (len(correct_k))

665
