In [1]:
import pandas as pd
from rdkit import Chem
import rdkit.Chem.AllChem as AllChem
from rdkit.Chem import PandasTools
from rdkit import DataStructs
from rdkit import RDLogger
import numpy as np
import time
import func_timeout

In [2]:
def extract_template(rxn_smiles,max_prec):
    import rdchiral
    from rdchiral import template_extractor
    templates = []
    for i in index[0:max_prec]:
        rxnmapped = rxn_smiles[i]
        reaction = {
        'reactants': rxnmapped.split('>')[0],
        'products': rxnmapped.split('>')[-1],
        '_id': i,
        }
        template = template_extractor.extract_from_reaction(reaction)
        templates.append(template)
    #print(templates)
    #print(len(templates))
    #display(templates[0]['reaction_smarts'])
    #print('---------------------------')
    rxn_template = []
    n = 0
    for i in templates:
        if i is None or 'reaction_smarts' not in i:
            rxn_smarts = None
            n += 1
            #print(rxn_smarts)
            #print(i['reaction_id'])
            rxn_template.append(rxn_smarts)
            continue
        rxn_smarts = '(' + i['reaction_smarts'].replace('>>', ')>>')
        #print(rxn_smarts)
        #print(i['reaction_id'])
        rxn_template.append(rxn_smarts)
    #print(len(rxn_template))
    #print("number of invalid templates is %d" %n)
    return(rxn_template)

In [3]:
@func_timeout.func_set_timeout(20) 
def get_precru(i, rct, combine_enantiomers=False):
    from rdchiral.main import rdchiralRun
    return rdchiralRun(i, rct, combine_enantiomers=combine_enantiomers)

def propose_precursors(rxn_templates):
    from rdkit.Chem import rdChemReactions
    from rdchiral.main import rdchiralReactants, rdchiralReaction, rdchiralRun
    rct = rdchiralReactants(product_smiles)
    rxns = []
    precursors = []
    n=0
    for i in rxn_templates: 
        if i is None:
            #print('------------------------------')
            #print("invalid rxn_template")
            #print('------------------------------')
            rxn = None
            rxns.append(rxn)
            continue
        rxn_i = AllChem.ReactionFromSmarts(i)
        #num_r = rxn_i.GetNumReactantTemplates()
        #num_p = rxn_i.GetNumProductTemplates()
        #n += 1
        #print(num_r, num_p, n)
        #print('------------------------------')
        rxn = rdchiralReaction(i)
        rxns.append(rxn)
        #print(rxns)
    #print('##############################')
    for i in rxns:
        if i is None:
            precursor = None
            precursors.append(precursor)
            #print(precursor)
            continue
        
        try:
            precursor = get_precru(i, rct, combine_enantiomers=False)
        except:
            precursor = None
        precursors.append(precursor)

        # print(precursor)
    #print(len(precursors))
    return(precursors)

In [4]:
def get_rxns():
    rxns =[]
    coindexs = []
    for idx,i in enumerate(precursors):
        if i is None:
            rxnstr = None
            rxns.append(rxnstr)
            coindex = index[idx]
            coindexs.append(coindex)
            #print('------------------------------')
            #print("this reaction could not propose suitable smiles")
            #print('------------------------------')
            continue
        if len(i) ==0:
            rxnstr = None
            rxns.append(rxnstr)
            coindex = index[idx]
            coindexs.append(coindex)
            #print('------------------------------')
            #print("this reaction could not propose suitable smiles")
            #print('------------------------------')
            continue
        for j in i:
            m = Chem.MolFromSmiles(j,sanitize=False)
            if m is None:
                print('invalid smiles')
                rxnstr = None
                rxns.append(rxnstr)
                coindex = index[idx]
                coindexs.append(coindex)
            else:
                try:
                    Chem.SanitizeMol(m)
                    #print("smiles is ok")
                    #display(m)
                    rxnstr=j +'>>'+product_smiles
                    #print(rxnstr)
                    rxns.append(rxnstr)
                    coindex = index[idx]
                    coindexs.append(coindex)

                except:
                    #print('invalid chemistry')
                    rxnstr = None
                    rxns.append(rxnstr)
                    coindex = index[idx]
                    coindexs.append(coindex) 
    #print(rxns)
    #print(len(rxns))
    #print(coindexs)
    #print(len(coindexs))
    repeat_indexs = []
    t = {}
    for i in range(len(coindexs)):
        for j in range(i):
             if coindexs[i] == coindexs[j]:
                repeat_index = coindexs[j]
                #print(repeat_index)
                repeat_indexs.append(repeat_index)
    #print('------------------------------')
    t = list(zip(coindexs,rxns))
    #print(t)
    #print(len(t))
    return rxns,t

In [5]:
def remove_reactant_smiles():
    Reactant_smiles = []
    Reactant_mols = []
    for rxn_smi in df['Rxn']:
        rcts_smi = rxn_smi.split('>>')[0]
        rcts_mol = Chem.MolFromSmiles(rcts_smi)
        [atom.ClearProp('molAtomMapNumber') for atom in rcts_mol.GetAtoms()]
        Reactant_smile = Chem.MolToSmiles(rcts_mol, True)
        Reactant_smiles.append(Reactant_smile)
        Reactant_mol = Chem.MolFromSmiles(Reactant_smile, True)
        Reactant_mols.append(Reactant_mol) 
    return(Reactant_mols)

In [6]:
def rank():
    s_reacs =[]
    ss=[]
    for x in index:
        for k,v in t:
            if k == x:
                #print(k)

                if v is None:
                    #print(v)
                    s_reac = 0
                    s_reacs.append(s_reac)
                    s = 0
                    ss.append(s)
                    continue
                else:
                    i = str(v).split('>>')[0]
                    #print(i)
                    Proposal_mol = Chem.MolFromSmiles(i)
                    if Proposal_mol is None:
                        s_reac = 0
                        s_reacs.append(s_reac)
                        s = 0
                        ss.append(s)
                        continue
                    mfp_proposal = AllChem.GetMorganFingerprint(Proposal_mol,2,useFeatures=True)
                    mfp_Reactant_mol = AllChem.GetMorganFingerprint(Reactant_mols[x],2,useFeatures=True)
                    s_reac = DataStructs.TanimotoSimilarity(mfp_Reactant_mol, mfp_proposal)
                    s_reacs.append(s_reac)
                    s = s_reac*s_prods[x]
                    ss.append(s)
    #print(ss)
    order = np.argsort(ss)[::-1]
    #print(order)
    #print(len(ss))
    return(order)

In [7]:
def remove_atom_mapping():
    df_1=pd.read_csv("data_test_reaxys.csv",encoding="gbk")
    remove_prod_smiles = []
    remove_rcts_smiles = []
    remove_rxns_smiles = []

    for rxn_smi in df_1['Rxn']:
        prod_smi = rxn_smi.split('>>')[-1]
        prod_mol = Chem.MolFromSmiles(prod_smi)
        [atom.ClearProp('molAtomMapNumber') for atom in prod_mol.GetAtoms()]
        prod_smi_remove_atom_map = Chem.MolToSmiles(prod_mol, True)
        prod_smi_remove_atom_map = Chem.MolToSmiles(Chem.MolFromSmiles(prod_smi_remove_atom_map), True)
        remove_prod_smiles.append(prod_smi_remove_atom_map)      
    for rxn_smi in df_1['Rxn']:
        rcts_smi = rxn_smi.split('>>')[0]
        rcts_mol = Chem.MolFromSmiles(rcts_smi)
        [atom.ClearProp('molAtomMapNumber') for atom in rcts_mol.GetAtoms()]
        rcts_smi_remove_atom_map = Chem.MolToSmiles(rcts_mol, True)
        # Sometimes stereochem takes another canonicalization...
        rcts_smi_remove_atom_map = Chem.MolToSmiles(Chem.MolFromSmiles(rcts_smi_remove_atom_map), True)
        remove_rcts_smiles.append(rcts_smi_remove_atom_map)            

        rxn_smi_remove_atom_map = rcts_smi_remove_atom_map + '>>' + prod_smi_remove_atom_map
        remove_rxns_smiles.append(rxn_smi_remove_atom_map) 
    #print(remove_prod_smiles)
    #print('***********************')
    #print(remove_rcts_smiles)
    #print('***********************')
    #print(remove_rxns_smiles)    
    dataframe_test_remove = pd.DataFrame({'Product':remove_prod_smiles[:],'Reactant':remove_rcts_smiles[:],'Rxn':remove_rxns_smiles[:]})
    dataframe_test_remove.to_csv('remove_data_test_reaxys.csv',index=False)
    return()

In [8]:
def get_accuracy(n):
    accuracy = []
    m = 0
    for i,j in enumerate(order):
        if i < n:
            if rxns[j] is None:
                continue
            x = rxns[j].split('>>')[0]
            if x == true_precursors:
                display(Chem.MolFromSmiles(x))
                a = np.zeros(i,dtype = int)
                b = np.ones(n-i,dtype = int)
                accuracy = np.concatenate((a,b))
                #print(accuracy)
                m = m +1
                #print(m)
                break
        else:
            break
    if m == 0:
        accuracy = np.zeros(n,dtype = int)
        #print(accuracy)
    return(accuracy)

In [9]:
RDLogger.DisableLog('rdApp.*')
print('Starting......')
start_time = time.time()
#train:smiles2mol2fingerprint
Product_mols = []
df=pd.read_csv("data_train_USPTO.csv",encoding="gbk")
for i in df['Product']:
    Product_mol = Chem.MolFromSmiles(i)
    Product_mols.append(Product_mol)
mfps = []
for i in Product_mols:
    if i != None:
        mfp = AllChem.GetMorganFingerprint(i,2,useFeatures=True)
        mfps.append(mfp)


Starting......


In [10]:
#test:smiles2mol2fingerprint
df_1=pd.read_csv("data_test_reaxys.csv",encoding="gbk")
Reactant_mols = remove_reactant_smiles()
remove_atom_mapping()
n = 50
accuracies = np.empty([0,n])
total_length = len(df_1)
print('total length: ', total_length)

for a,i in enumerate(df_1['Product']):
    curr_time = time.time()
    spent_time = curr_time - start_time
    print('(', a+1, ' / ', total_length, ')......', "Spent time:", spent_time)
    start_time = time.time()
    
    product_smiles = i
    product_mol = Chem.MolFromSmiles(product_smiles)
    product_mfp = AllChem.GetMorganFingerprint(product_mol,2,useFeatures=True)
    #get product_similarity
    s_prods = []
    for i in mfps:
        s_prod = DataStructs.TanimotoSimilarity(product_mfp,i)
        s_prods.append(s_prod)
    #print(s_prods)
    index = np.argsort(s_prods)[::-1]
    #print(index)
    #print(len(index))
    rxn_template = extract_template(df['Rxn'],200)
    try:
        precursors = propose_precursors(rxn_template)
    except:
        continue
    rxns,t = get_rxns()
    order = rank()
    df_2 = pd.read_csv("remove_data_test_reaxys.csv",encoding="gbk")
    true_precursors = df_2['Reactant'][a]
    #true_precursors = df_1['Reactant'][a]
    accuracy = get_accuracy(n)
    accuracies = np.append(accuracies,[accuracy],axis=0)
    if a % int(total_length//20) == 0:
        print(5*a//(total_length//20), "%, ", 'accuracy:', np.mean(accuracies, axis=0))
#print(accuracies)
mean_accuracies = np.mean(accuracies, axis=0)
print(f"Top {n} accuracy: {mean_accuracies[n-1] * 100: .2f} %")

total length:  1086
( 1  /  1086 )...... Spent time: 103.48859071731567
0 %,  accuracy: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]
( 2  /  1086 )...... Spent time: 4.656772613525391
( 3  /  1086 )...... Spent time: 5.364403963088989
( 4  /  1086 )...... Spent time: 4.515038251876831
( 5  /  1086 )...... Spent time: 4.596171855926514
( 6  /  1086 )...... Spent time: 4.670725584030151
( 7  /  1086 )...... Spent time: 4.593396902084351
( 8  /  1086 )...... Spent time: 4.957525253295898
( 9  /  1086 )...... Spent time: 4.6630048751831055
( 10  /  1086 )...... Spent time: 4.430240631103516
( 11  /  1086 )...... Spent time: 4.878950834274292
( 12  /  1086 )...... Spent time: 5.124767541885376
( 13  /  1086 )...... Spent time: 4.4201273918151855
( 14  /  1086 )...... Spent time: 4.366297483444214
( 15  /  1086 )...... Spent time: 4.530383348464966
( 16  /  1086 )...... Spent time: 4.

( 148  /  1086 )...... Spent time: 5.66387939453125
( 149  /  1086 )...... Spent time: 5.8363869190216064
( 150  /  1086 )...... Spent time: 5.37966775894165
( 151  /  1086 )...... Spent time: 6.025655031204224
( 152  /  1086 )...... Spent time: 5.346535682678223
( 153  /  1086 )...... Spent time: 5.387300729751587
( 154  /  1086 )...... Spent time: 5.288697957992554
( 155  /  1086 )...... Spent time: 5.366674423217773
( 156  /  1086 )...... Spent time: 5.0377161502838135
( 157  /  1086 )...... Spent time: 5.351791620254517
( 158  /  1086 )...... Spent time: 5.179107189178467
( 159  /  1086 )...... Spent time: 5.144988775253296
( 160  /  1086 )...... Spent time: 5.248072385787964
( 161  /  1086 )...... Spent time: 5.216507196426392
( 162  /  1086 )...... Spent time: 5.476633548736572
( 163  /  1086 )...... Spent time: 5.491496562957764
15 %,  accuracy: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0

( 293  /  1086 )...... Spent time: 6.7810094356536865
( 294  /  1086 )...... Spent time: 6.736984491348267
( 295  /  1086 )...... Spent time: 6.179653882980347
( 296  /  1086 )...... Spent time: 6.404293060302734
( 297  /  1086 )...... Spent time: 6.707400560379028
( 298  /  1086 )...... Spent time: 6.6686179637908936
( 299  /  1086 )...... Spent time: 6.6951117515563965
( 300  /  1086 )...... Spent time: 6.070037364959717
( 301  /  1086 )...... Spent time: 6.651598930358887
( 302  /  1086 )...... Spent time: 6.542752981185913
( 303  /  1086 )...... Spent time: 6.74303936958313
( 304  /  1086 )...... Spent time: 6.315906763076782
( 305  /  1086 )...... Spent time: 6.438255071640015
( 306  /  1086 )...... Spent time: 7.84782600402832
( 307  /  1086 )...... Spent time: 6.927485466003418
( 308  /  1086 )...... Spent time: 6.563281297683716
( 309  /  1086 )...... Spent time: 6.991959095001221
( 310  /  1086 )...... Spent time: 8.020437002182007
( 311  /  1086 )...... Spent time: 7.09579682

( 439  /  1086 )...... Spent time: 7.2899229526519775
( 440  /  1086 )...... Spent time: 7.277369976043701
( 441  /  1086 )...... Spent time: 8.392072677612305
( 442  /  1086 )...... Spent time: 7.313794374465942
( 443  /  1086 )...... Spent time: 6.511108160018921
( 444  /  1086 )...... Spent time: 6.133446216583252
( 445  /  1086 )...... Spent time: 6.417776823043823
( 446  /  1086 )...... Spent time: 6.5715491771698
( 447  /  1086 )...... Spent time: 6.564323663711548
( 448  /  1086 )...... Spent time: 7.213678598403931
( 449  /  1086 )...... Spent time: 6.760992765426636
( 450  /  1086 )...... Spent time: 6.43156886100769
( 451  /  1086 )...... Spent time: 6.7306294441223145
( 452  /  1086 )...... Spent time: 6.447459936141968
( 453  /  1086 )...... Spent time: 6.794288158416748
( 454  /  1086 )...... Spent time: 6.540819406509399
( 455  /  1086 )...... Spent time: 6.740350246429443
( 456  /  1086 )...... Spent time: 8.744079113006592
( 457  /  1086 )...... Spent time: 11.507470607

( 588  /  1086 )...... Spent time: 9.103246450424194
( 589  /  1086 )...... Spent time: 9.060258388519287
( 590  /  1086 )...... Spent time: 10.081476211547852
( 591  /  1086 )...... Spent time: 12.502935647964478
( 592  /  1086 )...... Spent time: 11.946780681610107
( 593  /  1086 )...... Spent time: 10.442198038101196
( 594  /  1086 )...... Spent time: 11.777933597564697
( 595  /  1086 )...... Spent time: 10.670252799987793
55 %,  accuracy: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]
( 596  /  1086 )...... Spent time: 9.957935571670532
( 597  /  1086 )...... Spent time: 9.410788297653198
( 598  /  1086 )...... Spent time: 8.56952953338623
( 599  /  1086 )...... Spent time: 9.596198797225952
( 600  /  1086 )...... Spent time: 8.375972986221313
( 601  /  1086 )...... Spent time: 8.256152629852295
( 602  /  1086 )...... Spent time: 8.641246318817139
( 603  /  1086 )...... Spent 

( 733  /  1086 )...... Spent time: 6.992573022842407
( 734  /  1086 )...... Spent time: 7.169259786605835
( 735  /  1086 )...... Spent time: 7.045241594314575
( 736  /  1086 )...... Spent time: 7.769818305969238
( 737  /  1086 )...... Spent time: 8.174843072891235
( 738  /  1086 )...... Spent time: 11.142019510269165
( 739  /  1086 )...... Spent time: 9.861464023590088
( 740  /  1086 )...... Spent time: 8.659758567810059
( 741  /  1086 )...... Spent time: 8.384944200515747
( 742  /  1086 )...... Spent time: 7.583571195602417
( 743  /  1086 )...... Spent time: 7.697145223617554
( 744  /  1086 )...... Spent time: 8.50637936592102
( 745  /  1086 )...... Spent time: 8.029481172561646
( 746  /  1086 )...... Spent time: 10.193994998931885
( 747  /  1086 )...... Spent time: 8.977096557617188
( 748  /  1086 )...... Spent time: 7.906083106994629
( 749  /  1086 )...... Spent time: 9.734364032745361
( 750  /  1086 )...... Spent time: 7.870616674423218
( 751  /  1086 )...... Spent time: 8.07194375

( 878  /  1086 )...... Spent time: 5.81367301940918
( 879  /  1086 )...... Spent time: 5.867951154708862
( 880  /  1086 )...... Spent time: 5.97847843170166
( 881  /  1086 )...... Spent time: 5.956532716751099
( 882  /  1086 )...... Spent time: 6.1105265617370605
( 883  /  1086 )...... Spent time: 5.93735933303833
( 884  /  1086 )...... Spent time: 5.982226848602295
( 885  /  1086 )...... Spent time: 6.128127098083496
( 886  /  1086 )...... Spent time: 6.05482816696167
( 887  /  1086 )...... Spent time: 5.991153240203857
( 888  /  1086 )...... Spent time: 5.6335976123809814
( 889  /  1086 )...... Spent time: 5.8981146812438965
( 890  /  1086 )...... Spent time: 6.008402585983276
( 891  /  1086 )...... Spent time: 5.663800954818726
( 892  /  1086 )...... Spent time: 6.0067784786224365
( 893  /  1086 )...... Spent time: 6.322032928466797
( 894  /  1086 )...... Spent time: 6.398798704147339
( 895  /  1086 )...... Spent time: 6.446604251861572
( 896  /  1086 )...... Spent time: 6.322132825

( 1026  /  1086 )...... Spent time: 8.97342324256897
( 1027  /  1086 )...... Spent time: 7.523941516876221
95 %,  accuracy: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0.]
( 1028  /  1086 )...... Spent time: 7.023612976074219
( 1029  /  1086 )...... Spent time: 6.578040599822998
( 1030  /  1086 )...... Spent time: 7.237592458724976
( 1031  /  1086 )...... Spent time: 7.1860010623931885
( 1032  /  1086 )...... Spent time: 7.5683581829071045
( 1033  /  1086 )...... Spent time: 7.316396951675415
( 1034  /  1086 )...... Spent time: 8.076908349990845
( 1035  /  1086 )...... Spent time: 8.976901292800903
( 1036  /  1086 )...... Spent time: 9.246550559997559
( 1037  /  1086 )...... Spent time: 8.448448657989502
( 1038  /  1086 )...... Spent time: 10.141088247299194
( 1039  /  1086 )...... Spent time: 8.081944465637207
( 1040  /  1086 )...... Spent time: 7.5294764041900635
( 1041  /  1086 