In [None]:
import pandas as pd
from rdkit import Chem
import rdkit.Chem.AllChem as AllChem
from rdkit.Chem import PandasTools
from rdkit import DataStructs
import numpy as np
import time
import func_timeout
from rdkit import RDLogger

def extract_template(rxn_smiles, max_prec, indices=None):
    """Extract rdchiral templates from the ``max_prec`` most similar training reactions.

    Parameters
    ----------
    rxn_smiles : Series or mapping
        Atom-mapped reaction SMILES (``reactants>agents>products``), looked up
        by the entries of ``indices``.
    max_prec : int
        Number of leading entries of ``indices`` to use.
    indices : sequence of int, optional
        Training-row indices, most similar product first. When omitted, the
        module-level ``index`` set by the evaluation loop is used, which was
        the original (implicit) behaviour.

    Returns
    -------
    list
        One entry per used index, aligned with ``indices[:max_prec]``. Each
        entry is either the template SMARTS with its left-hand side wrapped in
        parentheses, or None when rdchiral returned no ``reaction_smarts``.
    """
    from rdchiral import template_extractor

    if indices is None:
        indices = index

    rxn_template = []
    for i in indices[0:max_prec]:
        rxnmapped = rxn_smiles[i]
        reaction = {
            'reactants': rxnmapped.split('>')[0],
            'products': rxnmapped.split('>')[-1],
            '_id': i,
        }
        template = template_extractor.extract_from_reaction(reaction)
        if template is None or 'reaction_smarts' not in template:
            # Keep a placeholder so the output stays aligned with ``indices``.
            rxn_template.append(None)
            continue
        # Parenthesise the left-hand side (the side matched against the target
        # product in propose_precursors) so RDKit treats it as ONE component
        # template, even when it contains several '.'-separated fragments.
        rxn_template.append('(' + template['reaction_smarts'].replace('>>', ')>>'))
    return rxn_template

@func_timeout.func_set_timeout(20)
def get_precru(i, rct, combine_enantiomers=False):
    """Run one rdchiral template against a prepared product, capped at 20 seconds.

    ``i`` is an ``rdchiralReaction`` and ``rct`` an ``rdchiralReactants``, as
    built in ``propose_precursors``. The result of ``rdchiralRun`` is returned
    unchanged. The ``func_set_timeout`` decorator aborts the call with
    ``func_timeout.FunctionTimedOut`` when it runs past the limit.
    """
    from rdchiral import main as rdchiral_main

    outcomes = rdchiral_main.rdchiralRun(
        i, rct, combine_enantiomers=combine_enantiomers
    )
    return outcomes

def propose_precursors(rxn_templates, product=None):
    """Apply each retro-template to the current target product.

    Parameters
    ----------
    rxn_templates : list of (str or None)
        Output of ``extract_template``. None entries are passed through.
    product : str, optional
        Target product SMILES. When omitted, the module-level
        ``product_smiles`` set by the evaluation loop is used (original
        behaviour).

    Returns
    -------
    list
        Aligned with ``rxn_templates``. Each entry is what ``get_precru``
        (``rdchiralRun``) returned for that template. It is None when the
        template was missing, could not be built into an
        ``rdchiralReaction``, raised while running, or timed out.

    Raises
    ------
    Whatever ``rdchiralReactants`` raises when the product itself cannot be
    prepared. The caller handles that case.
    """
    from rdchiral.main import rdchiralReactants, rdchiralReaction

    if product is None:
        product = product_smiles
    rct = rdchiralReactants(product)

    precursors = []
    for smarts in rxn_templates:
        if smarts is None:
            precursors.append(None)
            continue
        try:
            rxn = rdchiralReaction(smarts)
        except Exception:
            # Previously a single unparsable template raised out of this
            # function and discarded the proposals of every other template.
            precursors.append(None)
            continue
        try:
            precursor = get_precru(rxn, rct, combine_enantiomers=False)
        except (Exception, func_timeout.FunctionTimedOut):
            # FunctionTimedOut is named explicitly because func_timeout does not
            # derive it from Exception. Unlike the old bare ``except:``, this
            # lets KeyboardInterrupt through.
            precursor = None
        precursors.append(precursor)
    return precursors

def _precursor_rxn(smi, product):
    """Return ``smi + '>>' + product`` if ``smi`` parses and sanitizes, else None."""
    mol = Chem.MolFromSmiles(smi, sanitize=False)
    if mol is None:
        print('invalid smiles')
        return None
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        # Parsable but chemically invalid (e.g. bad valence).
        return None
    return smi + '>>' + product


def get_rxns(precursor_lists=None, indices=None, product=None):
    """Turn proposed precursor sets into reaction SMILES tagged with their source.

    Parameters
    ----------
    precursor_lists : list, optional
        Output of ``propose_precursors``. Defaults to module-level ``precursors``.
    indices : sequence of int, optional
        Training indices aligned with ``precursor_lists``. Defaults to
        module-level ``index``.
    product : str, optional
        Target product SMILES. Defaults to module-level ``product_smiles``.

    Returns
    -------
    rxns : list of (str or None)
        The entry is ``"<precursors>>><product>"`` for each proposed precursor
        SMILES that parses and sanitizes. It is None for an invalid proposal,
        and a single None stands for a template that produced nothing.
    t : list of tuple
        ``(training index, rxn)`` pairs aligned with ``rxns``. One training
        index repeats when its template proposed several precursor sets.
    """
    if precursor_lists is None:
        precursor_lists = precursors
    if indices is None:
        indices = index
    if product is None:
        product = product_smiles

    rxns = []
    coindexs = []
    for idx, proposals in enumerate(precursor_lists):
        source = indices[idx]
        if proposals is None or len(proposals) == 0:
            # Template missing/failed, or it ran but proposed nothing.
            rxns.append(None)
            coindexs.append(source)
            continue
        for smi in proposals:
            rxns.append(_precursor_rxn(smi, product))
            coindexs.append(source)
    t = list(zip(coindexs, rxns))
    return rxns, t

def remove_reactant_smiles():
    """Return the reactant molecules of every reaction in ``df['Rxn']``, without atom maps.

    Reads the module-level ``df``. For each reaction SMILES, the text before
    ``>>`` is parsed and its ``molAtomMapNumber`` properties are cleared. The
    molecule is then written to isomeric SMILES and parsed again. The
    returned list is positionally aligned with the rows of ``df``.
    """
    def _unmapped_reactant_mol(rxn_smi):
        mol = Chem.MolFromSmiles(rxn_smi.split('>>')[0])
        for atom in mol.GetAtoms():
            atom.ClearProp('molAtomMapNumber')
        unmapped_smiles = Chem.MolToSmiles(mol, True)
        return Chem.MolFromSmiles(unmapped_smiles, True)

    return [_unmapped_reactant_mol(rxn_smi) for rxn_smi in df['Rxn']]

def rank():
    """Score every proposed reaction and return positions into ``rxns``, best first.

    Reads these module-level names:
    - ``index``: training indices, most similar product first.
    - ``t``: ``(training index, rxn SMILES or None)`` pairs from ``get_rxns``.
    - ``Reactant_mols``: reactant molecules of the training reactions.
    - ``s_prods``: product similarity of each training reaction to the target.

    Scoring works as follows:
    - A proposal scores Tanimoto(proposal reactants, source reactants) times
      the product similarity of its source.
    - Missing or unparsable proposals score 0.

    Scores are emitted in ``index`` order, grouped by source and keeping each
    source's order within ``t``. That is the same order as ``rxns``, so the
    returned positions index ``rxns`` directly.

    Returns
    -------
    numpy.ndarray
        ``np.argsort(scores)[::-1]``.
    """
    # Group proposals by source once. The old code rescanned all of ``t`` for
    # every entry of ``index``, which spans the whole training set.
    by_source = {}
    for k, v in t:
        by_source.setdefault(k, []).append(v)

    ss = []
    for x in index:
        proposals = by_source.get(x)
        if not proposals:
            continue
        mfp_reactant = None  # fingerprint of the source reactants, computed once
        for v in proposals:
            if v is None:
                ss.append(0)
                continue
            proposal_mol = Chem.MolFromSmiles(str(v).split('>>')[0])
            if proposal_mol is None:
                ss.append(0)
                continue
            mfp_proposal = AllChem.GetMorganFingerprint(proposal_mol, 2, useFeatures=True)
            if mfp_reactant is None:
                mfp_reactant = AllChem.GetMorganFingerprint(Reactant_mols[x], 2, useFeatures=True)
            s_reac = DataStructs.TanimotoSimilarity(mfp_reactant, mfp_proposal)
            ss.append(s_reac * s_prods[x])
    return np.argsort(ss)[::-1]

def _unmapped_canonical_smiles(smi):
    """Return canonical isomeric SMILES for ``smi`` with atom-map numbers removed."""
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.ClearProp('molAtomMapNumber')
    once = Chem.MolToSmiles(mol, True)
    # Sometimes stereochem takes another canonicalization...
    return Chem.MolToSmiles(Chem.MolFromSmiles(once), True)


def remove_atom_mapping(input_path="data_processed_test.csv",
                        output_path="remove_data_processed_test.csv"):
    """Write a copy of the test reactions with atom mapping stripped.

    Reads ``input_path`` (column ``Rxn``, atom-mapped reaction SMILES). It
    writes ``output_path`` with columns ``Product``, ``Reactant`` and ``Rxn``,
    each holding canonical, map-free SMILES for the same row.

    Returns
    -------
    tuple
        An empty tuple, as before. Callers ignore it.
    """
    df_1 = pd.read_csv(input_path, encoding="gbk")
    remove_prod_smiles = []
    remove_rcts_smiles = []
    remove_rxns_smiles = []
    for rxn_smi in df_1['Rxn']:
        prod_clean = _unmapped_canonical_smiles(rxn_smi.split('>>')[-1])
        rcts_clean = _unmapped_canonical_smiles(rxn_smi.split('>>')[0])
        remove_prod_smiles.append(prod_clean)
        remove_rcts_smiles.append(rcts_clean)
        # Pair each reactant side with its OWN product. The old two-loop
        # version reused the last product of the file for every row.
        remove_rxns_smiles.append(rcts_clean + '>>' + prod_clean)
    dataframe_test_remove = pd.DataFrame({
        'Product': remove_prod_smiles,
        'Reactant': remove_rcts_smiles,
        'Rxn': remove_rxns_smiles,
    })
    dataframe_test_remove.to_csv(output_path, index=False)
    return ()

def get_accuracy(n):
    """Return a length-``n`` 0/1 vector of top-k hits for the current target.

    Walks the first ``n`` positions of the module-level ranking ``order``,
    whose entries are positions into ``rxns``. It looks for the first
    proposal whose reactant side is string-equal to ``true_precursors``.

    - If found at 0-based rank ``i``, entries ``i..n-1`` are 1: the target
      counts as recovered for every top-k with k > i. The matched molecule is
      shown with ``display`` (IPython).
    - Otherwise every entry is 0.

    None proposals are skipped but still occupy a rank position.
    """
    for rank_pos, rxn_idx in enumerate(order):
        if rank_pos >= n:
            break
        candidate = rxns[rxn_idx]
        if candidate is None:
            continue
        reactant_side = candidate.split('>>')[0]
        if reactant_side != true_precursors:
            continue
        display(Chem.MolFromSmiles(reactant_side))
        return (np.arange(n) >= rank_pos).astype(int)
    return np.zeros(n, dtype=int)

# prepare_data: split the processed USPTO reactions into a train part (the
# first 45000 rows) and a test part (the rest). Each part is saved with only
# its product SMILES and atom-mapped reaction SMILES.
TRAIN_SIZE = 45000
data = pd.read_csv(r'C:\Users\Administrator\Desktop\data_processed_USPTO.csv', encoding="gbk")
train_part = data.iloc[:TRAIN_SIZE]
test_part = data.iloc[TRAIN_SIZE:]
dataframe = pd.DataFrame({'Product': train_part['prod_smiles'], 'Rxn': train_part['rxn_smiles']})
dataframe.to_csv('data_processed_train.csv', index=False)
dataframe_test = pd.DataFrame({'Product': test_part['prod_smiles'], 'Rxn': test_part['rxn_smiles']})
dataframe_test.to_csv('data_processed_test.csv', index=False)


RDLogger.DisableLog('rdApp.*')
print('Starting......')
start_time = time.time()
# train: SMILES -> mol -> Morgan feature fingerprint (radius 2) for every training product
df = pd.read_csv("data_processed_train.csv", encoding="gbk")
Product_mols = [Chem.MolFromSmiles(smi) for smi in df['Product']]
mfps = [AllChem.GetMorganFingerprint(mol, 2, useFeatures=True) for mol in Product_mols]

# test: reactant molecules of the training reactions (used by rank) and the
# map-free test reactions (ground truth for get_accuracy)
df_1 = pd.read_csv("data_processed_test.csv", encoding="gbk")
Reactant_mols = remove_reactant_smiles()
remove_atom_mapping()
# The ground-truth file does not change inside the loop, so read it once.
df_2 = pd.read_csv("remove_data_processed_test.csv", encoding="gbk")
n = 50
accuracy_rows = []
total_length = len(df_1)
print('total length: ', total_length)
# max(1, ...) avoids a modulo-by-zero when there are fewer than 20 targets.
progress_step = max(1, total_length // 20)

for a, product_smiles in enumerate(df_1['Product']):
    curr_time = time.time()
    spent_time = curr_time - start_time
    print('(', a+1, ' / ', total_length, ')......', "Spent time:", spent_time)
    start_time = time.time()

    product_mol = Chem.MolFromSmiles(product_smiles)
    product_mfp = AllChem.GetMorganFingerprint(product_mol, 2, useFeatures=True)
    # similarity of the target product to every training product
    s_prods = [DataStructs.TanimotoSimilarity(product_mfp, fp) for fp in mfps]
    # training rows, most similar product first
    index = np.argsort(s_prods)[::-1]
    rxn_template = extract_template(df['Rxn'], 200)
    true_precursors = df_2['Reactant'][a]
    try:
        precursors = propose_precursors(rxn_template)
    except (Exception, func_timeout.FunctionTimedOut):
        # Count a target whose proposals could not be generated as a miss.
        # Previously it was skipped, which silently removed it from the
        # denominator and inflated every top-k accuracy.
        accuracy = np.zeros(n, dtype=int)
    else:
        rxns, t = get_rxns()
        order = rank()
        accuracy = get_accuracy(n)
    accuracy_rows.append(accuracy)
    if a % progress_step == 0:
        print(100 * a // total_length, "%, ", 'accuracy:', np.mean(accuracy_rows, axis=0))

# shape (number of targets, n); reshape keeps the (0, n) shape when empty
accuracies = np.array(accuracy_rows, dtype=float).reshape(-1, n)
mean_accuracies = np.mean(accuracies, axis=0)

In [None]:
# Print top-k accuracy (as a percentage) for a few standard cut-offs;
# mean_accuracies[k - 1] is the fraction of targets recovered within the top k.
for i in (1, 3, 5, 10, 20, 50):
    top_k_pct = mean_accuracies[i - 1] * 100
    print(f"Top {i} accuracy: {top_k_pct: .2f} %")