In [1]:
from rdkit import Chem, RDLogger 
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers

RDLogger.DisableLog('rdApp.*')

def demap(smi):
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol)
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi))
    
def get_isomers(smi):
    mol = Chem.MolFromSmiles(smi)
    isomers = tuple(EnumerateStereoisomers(mol))
    isomers_smi = [Chem.MolToSmiles(x, isomericSmiles=True) for x in isomers]
    return isomers_smi
    
def get_MaxFrag(smiles):
    return max(smiles.split('.'), key=len)

def isomer_match(preds, reac, MaxFrag = False):
    try:
        if MaxFrag:
            reac = get_MaxFrag(reac)
        reac_isomers = get_isomers(reac)
        for k, pred in enumerate(preds):
            if MaxFrag:
                pred = get_MaxFrag(pred)
            pred_isomers = get_isomers(pred)
            if(set(pred_isomers).issubset(set(reac_isomers))):
                return k+1
        return -1
    except:
        return -1

In [2]:
import pandas as pd
dataset = 'USPTO_50K'
test_file = pd.read_csv('data/%s/raw_test.csv' % dataset)

In [3]:
rxn_ps = [rxn.split('>>')[1] for rxn in test_file['reactants>reagents>production']]

In [4]:
ground_truth = [demap(rxn.split('>>')[0]) for rxn in test_file['reactants>reagents>production']]

In [5]:
GRA = True
class_given = False

result_dir = 'outputs/decoded_prediction' 
if class_given:
    result_dir += '_class'
    
if GRA:
    result_file = '%s/%s.txt' % (result_dir, dataset)
else:
    result_file = '%s/%s_noGRA.txt' % (result_dir, dataset)


results = {}
with open(result_file, 'r') as f:
    for i, line in enumerate(f.readlines()):
        results[i] = line.split('\n')[0].split('\t')

In [6]:
Exact_matches = []
MaxFrag_matches = [] # Description in Supporting Information

Exact_matches_multi = []
MaxFrag_matches_mumlti = [] 
for i in range(len(results)):
    match_exact = isomer_match(results[i], ground_truth[i], False)
    match_maxfrag = isomer_match(results[i], ground_truth[i], True)
    if len(rxn_ps[i].split('.')) > 1:
        Exact_matches_multi.append(match_exact)
        MaxFrag_matches_mumlti.append(match_maxfrag)
    Exact_matches.append(match_exact)
    MaxFrag_matches.append(match_maxfrag)
    if i % 100 == 0:
        print ('\rCalculating accuracy... %s/%s' % (i, len(results)), end='', flush=True)

Calculating accuracy... 5000/5007

In [7]:
ks = [1, 3, 5, 10, 50]
exact_k = {k:0 for k in ks}
MaxFrag_k = {k:0 for k in ks}

print(len(Exact_matches))
for i in range(len(Exact_matches)):
    for k in ks:
        if Exact_matches[i] <= k and Exact_matches[i] != -1:
            exact_k[k] += 1
        if MaxFrag_matches[i] <= k and MaxFrag_matches[i] != -1:
            MaxFrag_k[k] += 1

for k in ks:
    print ('Top-%d Exact accuracy: %.3f, MaxFrag accuracy: %.3f' % (k, exact_k[k]/len(Exact_matches), MaxFrag_k[k]/len(MaxFrag_matches)))

5007
Top-1 Exact accuracy: 0.534, MaxFrag accuracy: 0.578
Top-3 Exact accuracy: 0.775, MaxFrag accuracy: 0.808
Top-5 Exact accuracy: 0.859, MaxFrag accuracy: 0.883
Top-10 Exact accuracy: 0.924, MaxFrag accuracy: 0.940
Top-50 Exact accuracy: 0.977, MaxFrag accuracy: 0.983
