In [16]:
from rdkit import Chem, RDLogger 
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers

RDLogger.DisableLog('rdApp.*')

def demap(smi):
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol)
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi))
    
def get_isomers(smi):
    mol = Chem.MolFromSmiles(smi)
    isomers = tuple(EnumerateStereoisomers(mol))
    isomers_smi = [Chem.MolToSmiles(x, isomericSmiles=True) for x in isomers]
    return isomers_smi
    
def get_MaxFrag(smiles):
    return max(smiles.split('.'), key=len)

def isomer_match(preds, reac):
    try:
        reac_isomers = get_isomers(reac)
        for k, pred in enumerate(preds):
            pred_isomers = get_isomers(pred)
            if(set(pred_isomers).issubset(set(reac_isomers))):
                return k+1
        return -1
    except:
        return -1

In [25]:
import pandas as pd
dataset = 'USPTO_MIT'
test_file = pd.read_csv('data/%s/raw_test.csv' % dataset)

In [26]:
rxn_ps = [rxn.split('>>')[1] for rxn in test_file['reactants>reagents>production']]

In [27]:
ground_truth = [demap(rxn.split('>>')[0]) for rxn in test_file['reactants>reagents>production']]
ground_truth_MaxFrag = [get_MaxFrag(demap(rxn.split('>>')[0])) for rxn in test_file['reactants>reagents>production']]

In [36]:
GRA = True
class_given = False

result_dir = 'outputs/decoded_prediction' 
if class_given:
    result_dir += '_class'
    
if GRA:
    result_file = '%s/%s.txt' % (result_dir, dataset)
else:
    result_file = '%s/%s_noGRA.txt' % (result_dir, dataset)


results = {}
with open(result_file, 'r') as f:
    for i, line in enumerate(f.readlines()):
        results[i] = line.split('\n')[0].split('\t')
        
results_MaxFrag = {}
with open(result_file, 'r') as f:
    for i, line in enumerate(f.readlines()):
        Unique_Max_Frags = []
        for smiles in line.split('\n')[0].split('\t'):
            MaxFrag = get_MaxFrag(smiles) 
            if MaxFrag not in Unique_Max_Frags:
                Unique_Max_Frags.append(MaxFrag)
        results_MaxFrag[i] = Unique_Max_Frags

In [37]:
Exact_matches = []
MaxFrag_matches = [] # Description in Supporting Information

Exact_matches_multi = []
MaxFrag_matches_multi = [] 
for i in range(len(results)):
    match_exact = isomer_match(results[i], ground_truth[i])
    match_maxfrag = isomer_match(results_MaxFrag[i], ground_truth_MaxFrag[i])
    if len(rxn_ps[i].split('.')) > 1:
        Exact_matches_multi.append(match_exact)
        MaxFrag_matches_multi.append(match_maxfrag)
    Exact_matches.append(match_exact)
    MaxFrag_matches.append(match_maxfrag)
    if i % 100 == 0:
        print ('\rCalculating accuracy... %s/%s' % (i, len(results)), end='', flush=True)

Calculating accuracy... 39900/40000

In [38]:
ks = [1, 3, 5, 10, 50]
exact_k = {k:0 for k in ks}
MaxFrag_k = {k:0 for k in ks}

print(len(Exact_matches))
for i in range(len(Exact_matches)):
    for k in ks:
        if Exact_matches[i] <= k and Exact_matches[i] != -1:
            exact_k[k] += 1
        if MaxFrag_matches[i] <= k and MaxFrag_matches[i] != -1:
            MaxFrag_k[k] += 1

for k in ks:
    print ('Top-%d Exact accuracy: %.3f, MaxFrag accuracy: %.3f' % (k, exact_k[k]/len(Exact_matches), MaxFrag_k[k]/len(MaxFrag_matches)))

40000
Top-1 Exact accuracy: 0.541, MaxFrag accuracy: 0.603
Top-3 Exact accuracy: 0.737, MaxFrag accuracy: 0.800
Top-5 Exact accuracy: 0.794, MaxFrag accuracy: 0.853
Top-10 Exact accuracy: 0.844, MaxFrag accuracy: 0.899
Top-50 Exact accuracy: 0.904, MaxFrag accuracy: 0.939


In [39]:
ks = [1, 3, 5, 10, 50]
exact_k_multi = {k:0 for k in ks}
MaxFrag_k_multi = {k:0 for k in ks}

print(len(Exact_matches_multi))
for i in range(len(Exact_matches_multi)):
    for k in ks:
        if Exact_matches_multi[i] <= k and Exact_matches_multi[i] != -1:
            exact_k_multi[k] += 1
        if MaxFrag_matches_multi[i] <= k and MaxFrag_matches_multi[i] != -1:
            MaxFrag_k_multi[k] += 1

for k in ks:
    print ('Top-%d Exact accuracy: %.3f, MaxFrag accuracy: %.3f' % (k, exact_k_multi[k]/len(Exact_matches_multi), MaxFrag_k_multi[k]/len(MaxFrag_matches_multi)))

471
Top-1 Exact accuracy: 0.444, MaxFrag accuracy: 0.469
Top-3 Exact accuracy: 0.669, MaxFrag accuracy: 0.699
Top-5 Exact accuracy: 0.730, MaxFrag accuracy: 0.766
Top-10 Exact accuracy: 0.752, MaxFrag accuracy: 0.790
Top-50 Exact accuracy: 0.758, MaxFrag accuracy: 0.798
