In [1]:
import json
import pickle
import numpy as np
from queue import Queue
from tqdm import tqdm
from rdkit import Chem
from collections import Counter
from tqdm import trange, tqdm
from BLEU_utils import *

In [2]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``; the file is closed on exit."""
    # NOTE(review): pickle.load can execute arbitrary code - only point this at trusted files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


def _load_json(path):
    """Parse and return the JSON document stored at ``path``; the file is closed on exit."""
    with open(path) as handle:
        return json.load(handle)


def _parse_route_strings(raw_routes):
    """Turn '|'-joined route strings into lists of 'lhs>>rhs' reaction strings.

    Entries of ``raw_routes`` that are not ``str`` are skipped. For each reaction,
    only the text before the first '>' and the text after the last '>' are kept.
    """
    parsed_routes = []
    for route in raw_routes:
        if not isinstance(route, str):
            continue
        steps = []
        for rxn in route.split('|'):
            parts = rxn.split('>')
            steps.append(parts[0] + '>>' + parts[-1])
        parsed_routes.append(steps)
    return parsed_routes


# Reference data: every route, the reaction -> template mapping, and the test mapping.
all_routes = _load_pickle('data/all_routes.pickle')
all_templates = _load_json('../teamdrive/projects/n5routes/templates/all_routes_templates_1_0_0.json')
golden_dict = _load_pickle('../teamdrive/projects/n5routes/templates/golden_dict.pickle')
test_dict = _load_json('data/test_routes_templates_1_0_0.json')

# Routes to evaluate. The first three JSON files keep their routes under a 'routes'
# key as '|'-joined strings; the retrograph file is used exactly as loaded.
golden_routes = _load_pickle('test_routes/routes_golden.pkl')
retrostar_routes = _parse_route_strings(_load_json('test_routes/routes_retrostar.json')['routes'])
retrostarplus_routes = _parse_route_strings(_load_json('test_routes/routes_retrostarplus.json')['routes'])
egmcts_routes = _parse_route_strings(_load_json('test_routes/routes_egmcts.json')['routes'])
retrograph_routes = _load_json('test_routes/routes_retrograph.json')
golden_route_dict = build_route_dict(golden_routes)
retrostar_route_dict = build_route_dict(retrostar_routes)
retrostarplus_route_dict = build_route_dict(retrostarplus_routes)
egmcts_route_dict = build_route_dict(egmcts_routes)
retrograph_route_dict = build_route_dict(retrograph_routes)

test_route_dicts = [golden_route_dict, retrostar_route_dict, retrostarplus_route_dict, egmcts_route_dict, retrograph_route_dict]
# Left-hand side of the first reaction of every golden route.
targets = [routes[0].split('>>')[0] for routes in golden_routes]

In [3]:
# Build the n-gram lists from every route, then de-duplicate each list.
all_ngram_rxns, all_ngram_templates, all_rxns = build_vocab(all_routes, all_templates)
set_ngram_rxns = list(map(set, all_ngram_rxns))
set_ngram_templates = list(map(set, all_ngram_templates))

# One count per n-gram list: raw, unique reactions, unique templates.
print('Number of n-gram rxns for n in [2, 3, 4, 5]:', list(map(len, all_ngram_rxns)))
print('Number of n-gram rxns for n in [2, 3, 4, 5] (after removing duplicates):', list(map(len, set_ngram_rxns)))
print('Number of n-gram templates for n in [2, 3, 4, 5] (after removing duplicates):', list(map(len, set_ngram_templates)))

Number of n-gram rxns for n in [2, 3, 4, 5]: [818622, 342358, 141804, 56046]
Number of n-gram rxns for n in [2, 3, 4, 5] (after removing duplicates): [253758, 106978, 44796, 17836]
Number of n-gram templates for n in [2, 3, 4, 5] (after removing duplicates): [112047, 69456, 31191, 12816]


In [4]:
# Hold out every route whose first reaction's trimmed ID starts with '2014', '2015'
# or '2016'; all other routes become the vocabulary set.
# NOTE(review): the labels printed below say "before 2013", while the condition keeps
# every route whose prefix is not one of the three held-out values - confirm wording.
held_out_prefixes = ('2014', '2015', '2016')
vocab_routes, test_routes = [], []
for route in all_routes:
    first_rxn = extract_rxns(route)[0]
    # ID text before the first ';', with its two leading characters dropped.
    ID = first_rxn['metadata']['ID'].split(';')[0][2:]
    destination = test_routes if ID[:4] in held_out_prefixes else vocab_routes
    destination.append(route)
print('Number of vocab routes:', len(vocab_routes))
print('Number of test routes:', len(test_routes))

# Vocabulary from the non-held-out routes only. The template n-grams returned by
# build_vocab are replaced by ones looked up reaction-by-reaction in all_templates.
all_ngram_rxns, all_ngram_templates, all_rxns = build_vocab(vocab_routes, all_templates)
all_ngram_templates = [
    [tuple(all_templates[rxn] for rxn in ngram) for ngram in ngrams_of_one_size]
    for ngrams_of_one_size in all_ngram_rxns
]
set_ngram_rxns = list(map(set, all_ngram_rxns))
set_ngram_templates = list(map(set, all_ngram_templates))
print('Number of n-gram rxns for n in [2, 3, 4, 5] in routes before 2013:', list(map(len, all_ngram_rxns)))
print('Number of n-gram rxns for n in [2, 3, 4, 5] in routes before 2013 (after removing duplicates):', list(map(len, set_ngram_rxns)))
print('Number of n-gram templates for n in [2, 3, 4, 5] in routes before 2013 (after removing duplicates):', list(map(len, set_ngram_templates)))

Number of vocab routes: 395121
Number of test routes: 62326
Number of n-gram rxns for n in [2, 3, 4, 5] in routes before 2013: [701854, 290944, 119587, 46628]
Number of n-gram rxns for n in [2, 3, 4, 5] in routes before 2013 (after removing duplicates): [227607, 95849, 39985, 15895]
Number of n-gram templates for n in [2, 3, 4, 5] in routes before 2013 (after removing duplicates): [102480, 62376, 27888, 11426]


In [5]:
# Score the current test_routes against the vocabulary sets currently in scope.
bleu_ratio, bleu_template_ratio = evaluate_routes(test_routes, all_templates, set_ngram_templates, set_ngram_rxns)

# Average each of the four per-n result lists before reporting.
mean_rxn_ratio = [np.mean(bleu_ratio[k]) for k in range(4)]
print('n-gram rxn ratio for n in [2, 3, 4, 5] in routes after 2014:', mean_rxn_ratio)
mean_template_ratio = [np.mean(bleu_template_ratio[k]) for k in range(4)]
print('n-gram template ratio for n in [2, 3, 4, 5] in patents after 2014:', mean_template_ratio)

n-gram rxn ratio for n in [2, 3, 4, 5] in routes after 2014: [0.6996204799240451, 0.7093509749685522, 0.7057417117264733, 0.7172911202697639]
n-gram template ratio for n in [2, 3, 4, 5] in patents after 2014: [0.8236269022598164, 0.7340199872266381, 0.718275165219291, 0.7224085175471462]


In [6]:
# Random split keyed on the trimmed ID of each route's first reaction: every distinct
# ID gets one random draw, so all routes sharing an ID land on the same side.
# (The print labels call these groups "patents" - presumably one ID per patent.)
# A locally seeded generator makes the split identical on every Restart & Run All;
# the counts recorded below came from an unseeded run and will differ slightly.
split_rng = np.random.default_rng(0)
vocab_routes = []
test_routes = []
ID_random_dict = {}
for route in all_routes:
    rxn_nodes = extract_rxns(route)
    # ID text before the first ';', with its two leading characters dropped.
    ID = rxn_nodes[0]['metadata']['ID'].split(';')[0][2:]
    if ID not in ID_random_dict:
        ID_random_dict[ID] = split_rng.random()
    # IDs whose draw falls below 0.2 are held out for testing.
    if ID_random_dict[ID] < 0.2:
        test_routes.append(route)
    else:
        vocab_routes.append(route)
print('Number of vocab routes:', len(vocab_routes))
print('Number of test routes:', len(test_routes))

# Vocabulary from the non-held-out routes only. The template n-grams returned by
# build_vocab are replaced by ones looked up reaction-by-reaction in all_templates.
all_ngram_rxns, all_ngram_templates, all_rxns = build_vocab(vocab_routes, all_templates)
all_ngram_templates = [[tuple([all_templates[rxn] for rxn in curr_rxn_set]) for curr_rxn_set in n_gram_rxns] for n_gram_rxns in all_ngram_rxns]
set_ngram_rxns = [set(ngram_rxns) for ngram_rxns in all_ngram_rxns]
set_ngram_templates = [set(ngram_templates) for ngram_templates in all_ngram_templates]
print('Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab patents:', [len(ngram_rxns) for ngram_rxns in all_ngram_rxns])
print('Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab patents (after removing duplicates):', [len(ngram_rxns) for ngram_rxns in set_ngram_rxns])
print('Number of n-gram templates for n in [2, 3, 4, 5] in randomly 80% vocab patents (after removing duplicates):', [len(ngram_templates) for ngram_templates in set_ngram_templates])

Number of vocab routes: 365508
Number of test routes: 91939
Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab patents: [653028, 272586, 112492, 44470]
Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab patents (after removing duplicates): [232795, 98590, 41546, 16550]
Number of n-gram templates for n in [2, 3, 4, 5] in randomly 80% vocab patents (after removing duplicates): [103951, 64090, 28893, 11911]


In [7]:
# Score the current test_routes against the vocabulary sets currently in scope.
bleu_ratio, bleu_template_ratio = evaluate_routes(test_routes, all_templates, set_ngram_templates, set_ngram_rxns)

# Average each of the four per-n result lists before reporting.
mean_rxn_ratio = [np.mean(bleu_ratio[k]) for k in range(4)]
print('n-gram rxn ratio for n in [2, 3, 4, 5] in randomly 20% test patents:', mean_rxn_ratio)
mean_template_ratio = [np.mean(bleu_template_ratio[k]) for k in range(4)]
print('n-gram template ratio for n in [2, 3, 4, 5] in randomly 20% test patents:', mean_template_ratio)

n-gram rxn ratio for n in [2, 3, 4, 5] in randomly 20% test patents: [0.8174986741868043, 0.8259449513344698, 0.8437340846367436, 0.8506033522606918]
n-gram template ratio for n in [2, 3, 4, 5] in randomly 20% test patents: [0.9007241673368126, 0.8503401783925865, 0.8599397965266724, 0.8575460030475097]


In [8]:
# Route-level random split: each route gets its own draw and is held out for testing
# when the draw falls below 0.2.
# A locally seeded generator makes the split identical on every Restart & Run All;
# the counts recorded below came from an unseeded run and will differ slightly.
# The per-route ID extraction that used to sit in this loop was never read here,
# so it has been dropped.
split_rng = np.random.default_rng(0)
vocab_routes = []
test_routes = []
for route in all_routes:
    if split_rng.random() < 0.2:
        test_routes.append(route)
    else:
        vocab_routes.append(route)
print('Number of vocab routes:', len(vocab_routes))
print('Number of test routes:', len(test_routes))

# Vocabulary from the non-held-out routes only. The template n-grams returned by
# build_vocab are replaced by ones looked up reaction-by-reaction in all_templates.
all_ngram_rxns, all_ngram_templates, all_rxns = build_vocab(vocab_routes, all_templates)
all_ngram_templates = [[tuple([all_templates[rxn] for rxn in curr_rxn_set]) for curr_rxn_set in n_gram_rxns] for n_gram_rxns in all_ngram_rxns]
set_ngram_rxns = [set(ngram_rxns) for ngram_rxns in all_ngram_rxns]
set_ngram_templates = [set(ngram_templates) for ngram_templates in all_ngram_templates]
print('Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab routes:', [len(ngram_rxns) for ngram_rxns in all_ngram_rxns])
print('Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab routes (after removing duplicates):', [len(ngram_rxns) for ngram_rxns in set_ngram_rxns])
print('Number of n-gram templates for n in [2, 3, 4, 5] in randomly 80% vocab routes (after removing duplicates):', [len(ngram_templates) for ngram_templates in set_ngram_templates])

Number of vocab routes: 365933
Number of test routes: 91514
Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab routes: [655125, 274117, 113554, 44930]
Number of n-gram rxns for n in [2, 3, 4, 5] in randomly 80% vocab routes (after removing duplicates): [233872, 98821, 41494, 16559]
Number of n-gram templates for n in [2, 3, 4, 5] in randomly 80% vocab routes (after removing duplicates): [105544, 64914, 29156, 11996]


In [9]:
# Score the current test_routes against the vocabulary sets currently in scope.
bleu_ratio, bleu_template_ratio = evaluate_routes(test_routes, all_templates, set_ngram_templates, set_ngram_rxns)

# Average each of the four per-n result lists before reporting.
mean_rxn_ratio = [np.mean(bleu_ratio[k]) for k in range(4)]
print('n-gram rxn ratio for n in [2, 3, 4, 5] in randomly 20% test routes:', mean_rxn_ratio)
mean_template_ratio = [np.mean(bleu_template_ratio[k]) for k in range(4)]
print('n-gram template ratio for n in [2, 3, 4, 5] in randomly 20% test routes:', mean_template_ratio)

n-gram rxn ratio for n in [2, 3, 4, 5] in randomly 20% test routes: [0.8335099166810709, 0.834709017229501, 0.83926453091901, 0.8429471848530145]
n-gram template ratio for n in [2, 3, 4, 5] in randomly 20% test routes: [0.9431573610781163, 0.9071773454215905, 0.9019015940725265, 0.8990800056943555]


In [10]:
# Average number of items extract_rxns returns per route, over all routes.
np.mean(list(map(len, map(extract_rxns, all_routes))))

2.789545018329992

In [11]:
# Rebuild the vocabulary from every route (the split cells above overwrote these
# names with subsets) and de-duplicate each n-gram list.
all_ngram_rxns, all_ngram_templates, all_rxns = build_vocab(all_routes, all_templates)
set_ngram_rxns = list(map(set, all_ngram_rxns))
set_ngram_templates = list(map(set, all_ngram_templates))

In [12]:
# Load the golden routes with a context manager so the file handle is closed.
# NOTE(review): pickle.load can execute arbitrary code - only use on trusted files.
with open('test_routes/routes_golden.pkl', 'rb') as golden_file:
    golden_routes = pickle.load(golden_file)

# Score the golden routes against the vocabulary sets currently in scope, using
# test_dict as the template mapping.
bleu_ratio, bleu_template_ratio = evaluate_routes(golden_routes, test_dict, set_ngram_templates, set_ngram_rxns)
# NOTE(review): the recorded reaction-level ratios are all 0.0 - confirm this is expected.
print('n-gram rxn ratio for n in [2, 3, 4, 5]:', [np.mean(bleu_ratio[i]) for i in range(4)])
print('n-gram template ratio for n in [2, 3, 4, 5]:', [np.mean(bleu_template_ratio[i]) for i in range(4)])

n-gram rxn ratio for n in [2, 3, 4, 5]: [0.0, 0.0, 0.0, 0.0]
n-gram template ratio for n in [2, 3, 4, 5]: [0.24456292245765926, 0.13057631967515687, 0.09487612612612613, 0.06789617486338799]


In [13]:
# Load the '|'-joined route strings with a context manager so the file is closed.
with open('test_routes/routes_retrostar.json') as routes_file:
    retro_routes = json.load(routes_file)['routes']

# Keep only string entries; rewrite each reaction as '<text before first ">">>>
# <text after last ">">', splitting every reaction once.
retro_routes = [
    [parts[0] + '>>' + parts[-1] for parts in (rxn.split('>') for rxn in route.split('|'))]
    for route in retro_routes
    if isinstance(route, str)
]
bleu_ratio, bleu_template_ratio = evaluate_routes(retro_routes, test_dict, set_ngram_templates, set_ngram_rxns)
print('n-gram template ratio for n in [2, 3, 4, 5]:', [np.mean(bleu_template_ratio[i]) for i in range(4)])

n-gram template ratio for n in [2, 3, 4, 5]: [0.20682935751117565, 0.07747942696572833, 0.05364823348694316, 0.0405040504050405]


In [14]:
# Load the '|'-joined route strings with a context manager so the file is closed.
with open('test_routes/routes_retrostarplus.json') as routes_file:
    retro_routes = json.load(routes_file)['routes']

# Keep only string entries; rewrite each reaction as '<text before first ">">>>
# <text after last ">">', splitting every reaction once.
retro_routes = [
    [parts[0] + '>>' + parts[-1] for parts in (rxn.split('>') for rxn in route.split('|'))]
    for route in retro_routes
    if isinstance(route, str)
]
bleu_ratio, bleu_template_ratio = evaluate_routes(retro_routes, test_dict, set_ngram_templates, set_ngram_rxns)
print('n-gram template ratio for n in [2, 3, 4, 5]:', [np.mean(bleu_template_ratio[i]) for i in range(4)])

n-gram template ratio for n in [2, 3, 4, 5]: [0.20545199336182943, 0.08639840748274483, 0.06358447488584475, 0.041666666666666664]


In [15]:
# Load the '|'-joined route strings with a context manager so the file is closed.
with open('test_routes/routes_egmcts.json') as routes_file:
    retro_routes = json.load(routes_file)['routes']

# Keep only string entries; rewrite each reaction as '<text before first ">">>>
# <text after last ">">', splitting every reaction once.
retro_routes = [
    [parts[0] + '>>' + parts[-1] for parts in (rxn.split('>') for rxn in route.split('|'))]
    for route in retro_routes
    if isinstance(route, str)
]
bleu_ratio, bleu_template_ratio = evaluate_routes(retro_routes, test_dict, set_ngram_templates, set_ngram_rxns)
print('n-gram template ratio for n in [2, 3, 4, 5]:', [np.mean(bleu_template_ratio[i]) for i in range(4)])


n-gram template ratio for n in [2, 3, 4, 5]: [0.09017300246808442, 0.03085858585858586, 0.009667024704618688, 0.0]


In [16]:
# Load the retrograph routes with a context manager so the file is closed; unlike
# the other JSON route files, the parsed document is passed on without re-parsing.
with open('test_routes/routes_retrograph.json') as routes_file:
    retro_routes = json.load(routes_file)
bleu_ratio, bleu_template_ratio = evaluate_routes(retro_routes, test_dict, set_ngram_templates, set_ngram_rxns)
print('n-gram template ratio for n in [2, 3, 4, 5]:', [np.mean(bleu_template_ratio[i]) for i in range(4)])

n-gram template ratio for n in [2, 3, 4, 5]: [0.12513487394439776, 0.034314807999018516, 0.011267605633802818, 0.0022727272727272726]
