In [None]:
from rdkit import Chem, RDLogger 
# Turn off RDKit log output for every logger matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')

def demap(smi):
    """Strip atom-map numbers from a SMILES string and return its fragments.

    The string is parsed with RDKit, every atom's map number is set to 0,
    the molecule is written back with Chem.MolToSmiles, and the result is
    split on '.' into a set of fragment SMILES.
    """
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    unmapped_smiles = Chem.MolToSmiles(mol)
    return set(unmapped_smiles.split('.'))

def clean_product(products):
    """Drop single-atom fragments from a product SMILES and de-map the rest.

    Args:
        products: '.'-separated product SMILES string.

    Returns:
        The set returned by demap() for the fragments that have more than one
        atom, or an empty set when no such fragment exists.
    """
    kept = [product for product in products.split('.')
            if Chem.MolFromSmiles(product).GetNumAtoms() > 1]
    if not kept:
        # demap('') can never yield an empty set (''.split('.') is ['']), so
        # return one explicitly; exact_match checks len(products) == 0 to
        # recognise reactions with no remaining product.
        return set()
    return demap('.'.join(kept))
        
def exact_match(reactants, products, preds):
    """Return the 1-based rank of the first prediction that hits a product.

    Args:
        reactants: set of reactant fragment SMILES.
        products: set of ground-truth product fragment SMILES.
        preds: ranked iterable of predicted SMILES strings whose fragments
            are separated by '.'.

    Returns:
        1 if `products` is empty. Otherwise the 1-based position of the first
        prediction whose fragments, united with `reactants`, share at least
        one element with `products`. -1 if no prediction matches or if an
        error occurs while processing the predictions.
    """
    if len(products) == 0:
        return 1
    try:
        for rank, pred in enumerate(preds, start=1):
            pred_set = reactants.union(pred.split('.'))
            if pred_set.intersection(products):
                return rank
    except Exception:
        # Best effort: a malformed prediction counts as "no match". Catching
        # Exception instead of using a bare except lets KeyboardInterrupt and
        # SystemExit propagate, so a long scoring loop can still be stopped.
        return -1
    return -1

In [None]:
import pandas as pd
dataset = 'USPTO_480k'

# Ground-truth reactions, keyed by the 0-based line index in the test file:
# reactants[i] holds the de-mapped reactant fragments and products[i] the
# cleaned, de-mapped product fragments of line i.
reactants, products = {}, {}
with open('data/%s/test.txt' % dataset, 'r') as f:
    for idx, row in enumerate(f):
        # Only the text before the first space is the reaction SMILES.
        mapped_rxn = row.split(' ')[0]
        lhs, rhs = mapped_rxn.split('>>')
        reactants[idx] = demap(lhs)
        products[idx] = clean_product(rhs)

In [None]:
scenario = 'sep' # sep or mix

result_file = 'outputs/decoded_prediction/LocalTransform_%s.txt' % scenario

# Keyed by the 0-based line index in the result file: results[i] collects
# element 0 of every prediction entry on line i, scores[i] collects element 1.
results = {}
scores = {}
with open(result_file, 'r') as f:
    for i, line in enumerate(f.readlines()):
        # The first tab-separated field is skipped; the rest are prediction entries.
        predictions = line.split('\n')[0].split('\t')[1:]
        # NOTE(review): eval() executes whatever expression the file contains.
        # If the entries are plain Python literals, ast.literal_eval would be
        # a safer drop-in -- confirm the file format before switching.
        # Each entry is evaluated once and reused for both dictionaries.
        parsed = [eval(p) for p in predictions]
        results[i] = [entry[0] for entry in parsed]
        scores[i] = [entry[1] for entry in parsed]

In [None]:
# exact_matches[i] is the value exact_match() returns for test case i.
exact_matches = []
n_results = len(results)
for i in range(n_results):
    rank = exact_match(reactants[i], products[i], results[i])
    exact_matches.append(rank)
    if i % 100 == 0:
        # Progress line, rewritten in place every 100 cases.
        print ('\rCalculating exact accuracy... %s/%s' % (i, n_results), end='', flush=True)

In [None]:
ks = [1, 2, 3, 5]
# exact_accu[k] counts the entries of exact_matches that are at most k,
# excluding -1 (the value used when no prediction matched).
exact_accu = {
    k: sum(1 for rank in exact_matches if rank != -1 and rank <= k)
    for k in ks
}

for k in ks:
    print ('Top-%d accuracy: %.3f,' % (k, exact_accu[k]/len(results)))