In [24]:
# Load data science packages
import numpy as np
import pandas as pd

# Load argument packages
import argparse
import re

# Load chemistry packages
import rdkit.Chem as Chem
import rdkit.Chem.AllChem as AllChem
from rdkit.Chem.Draw import IPythonConsole 
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions
from rdkit.Chem import PandasTools
from rdkit import RDLogger                                                                                                                                                               

RDLogger.DisableLog('rdApp.*')
PandasTools.RenderImagesInAllDataFrames(images=True)

# Load visualization package and display settings
import matplotlib.pyplot as plt
IPythonConsole.molSize = (1000, 300)   # Change image size
IPythonConsole.ipython_useSVG = False  # Show as PNG

#path_src  = '../data/MIT_reactants_pred_8M_noreagents/src-test.txt'
#path_tgt  = '../data/MIT_reactants_pred_8M_noreagents/tgt-test.txt'
#path_pred = '../results/predictions_MIT_reactants_pred_8M_noreagents_model_average_20_on_MIT_reactants_pred_8M_noreagents_test.txt'

path_src  = '../data/round_trip/src-rt-800k.txt'
path_tgt  = '../data/round_trip/tgt-rt-800k.txt'
path_pred = '../results/round_trip/predictions_rt_800k.txt'

# Load data
src  = pd.read_csv(path_src, header=None).replace('\s+', '', regex=True).values.flatten().tolist()
pred = pd.read_csv(path_pred, header=None).replace('\s+', '', regex=True).values.flatten().tolist()
tgt  = pd.read_csv(path_tgt, header=None).replace('\s+', '', regex=True).values.flatten().tolist()

# SMILES functions
canonicalize_smi = lambda smi: 'NA' if not Chem.MolFromSmiles(smi) else Chem.MolToSmiles(Chem.MolFromSmiles(smi))
equivalent_smi   = lambda smi: 'NA' if not Chem.MolFromSmiles(smi) else Chem.MolToSmiles(Chem.MolFromSmiles(smi), doRandom=True)

In [25]:
counter_all = 0
counter_one = 0
for i in range(len(tgt)):
    tgt_list = sorted([canonicalize_smi(j) for j in tgt[i].split('.')])
    pred_list = sorted([canonicalize_smi(j) for j in pred[i].split('.')])
    if tgt_list == pred_list:
        counter_all += 1
    for j in pred_list:
        if j in tgt_list and j != 'NA':
            counter_one += 1
            break
print('Top-1 accuracy (all reactants) = %.2f %%'%(100 * counter_all / len(tgt)))
print('Top-1 accuracy (at least 1)    = %.2f %%'%(100 * counter_one / len(tgt)))

Top-1 accuracy (all reactants) = 90.49 %
Top-1 accuracy (at least 1)    = 90.72 %


In [13]:
# src, tgt, pred, error
error_reactants  = []

for i in range(len(pred)):
    tgt_list = sorted([canonicalize_smi(j) for j in tgt[i].split('.')])
    pred_list = sorted([canonicalize_smi(j) for j in pred[i].split('.')])
    if 'NA' in pred_list:
        continue
    if 'NA' in tgt_list:
        print('Error in target: %s'%(tgt[i]))
        continue
    if tgt_list != pred_list:
        error_reactants.append(','.join([src[i], tgt[i], pred[i]]))

In [68]:
round_trip = []
for i, s in enumerate(error_reactants):
    
    reactants = s.split(',')[-1]
    product   = s.split(',')[0].split('.')[-1]
    reagents  = '.'.join([n for n in s.split(',')[0].split('.') if n != product])
    round_trip.append([reactants + '.' + reagents, reactants, product])


In [70]:
# From SMILES to tokens
def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction
    """
    pattern =  "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = [token for token in regex.findall(smi)]
    assert smi == ''.join(tokens)
    return ' '.join(tokens)

In [81]:
# Load argument packages
import argparse
import re

# Load transformer package
from onmt.translate.translator import Translator
from onmt.translate import GNMTGlobalScorer
from onmt.model_builder import load_test_model
import onmt.opts as opts
import onmt

# Load data science packages
import numpy as np
import torch

# Load chemical packages
from rdkit import Chem
from rdkit.Chem import Descriptors, Descriptors3D, MolFromSmiles, Lipinski, AllChem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions

import matplotlib.pyplot as plt

IPythonConsole.molSize = (1000, 300)   # Change image size
IPythonConsole.ipython_useSVG = False  # Show as PNG

# Path to model
MODEL = '../available_models/MIT_mixed_augm/MIT_mixed_augm_model_average_20.pt'

# Set number of predicted products
number_of_products = 1

# Reaction prediction function
def reactionPrediction(translator, reac_smi):

    # Tokenize SMILE molecules
    reac_tok = smi_tokenizer(reac_smi)

    # Output tokenized product
    scores, product_tok = translator.translate_test(src=[reac_tok], batch_size=1)

    # Obtain SMILES product from tokenized product
    product_smi = [pred.replace(' ','') for pred in product_tok[0]]
    
    # Transform log-probs into probs
    scores = [torch.exp(score) for score in scores[0]]
        
    return scores, product_smi
        
# Loads model translator
def load_model(number_of_products=1):

    # Parsing model parameters
    parser = argparse.ArgumentParser(description='translate.py',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    #opts.add_md_help_argument(parser)
    opts.translate_opts(parser)
    opt = parser.parse_args(['-model=%s'%MODEL,
                             '-src=%s'%'CCC',
                             '-batch_size=%s'%'64',
                             '-replace_unk',
                             '-max_length=%s'%'200'])
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    # Load transformer model
    fields, model, model_opt = load_test_model(opt)

    # Set score parameters
    scorer = GNMTGlobalScorer(opt.alpha, opt.beta,
                              opt.coverage_penalty,
                              opt.length_penalty)

    # Create dictionary with model parameters
    kwargs = {k: getattr(opt, k)
              for k in ["beam_size", "max_length", "min_length",
                        "stepwise_penalty", "block_ngram_repeat",
                        "ignore_when_blocking", "dump_beam",
                        "data_type", "replace_unk"]}

    # Create transfomer
    translator = Translator(model, fields=fields, global_scorer=scorer,
                            report_score=True, out_file=None,
                            copy_attn=model_opt.copy_attn, logger=None,
                            src_reader=onmt.inputters.str2reader["text"],
                            tgt_reader=onmt.inputters.str2reader["text"],
                            n_best=number_of_products, gpu=-1, **kwargs)
    
    return translator

# Generate translator
translator = load_model(number_of_products)

In [101]:
%%time
pass_fail = []
for i,reaction in enumerate(round_trip):
    scores, products = reactionPrediction(translator, reaction[0])
    if products[0] == reaction[-1]:
        pass_fail.append('PASS')
    else:
        pass_fail.append('FAIL')
    if i%1000 == 0:
        print('[%d / %d]'%(i+1, len(round_trip)))

[1 / 11379]
[1001 / 11379]
[2001 / 11379]
[3001 / 11379]
[4001 / 11379]
[5001 / 11379]
[6001 / 11379]
[7001 / 11379]
[8001 / 11379]
[9001 / 11379]
[10001 / 11379]
[11001 / 11379]
CPU times: user 1d 22h 58min 27s, sys: 1min 59s, total: 1d 23h 26s
Wall time: 54min 19s


In [103]:
print('Pass: %.2f %%'%(100 * pass_fail.count('PASS') / len(pass_fail)))
print('Fail: %.2f %%'%(100 * pass_fail.count('FAIL') / len(pass_fail)))

Pass: 72.78 %
Fail: 27.22 %


In [114]:
print('Round trip pass = %.2f %%'%(100 * (len(pred) - len(error_reactants) + pass_fail.count('PASS')) / len(pred)))

Round trip pass = 92.05 %


In [120]:
save_path = '../results/wrong_reactant_predictions_V2.csv'
with open(save_path, 'w+') as f:
    f.write("src,tgt,pred,round_trip\n" )
    for it, item in enumerate(error_reactants):
        f.write("%s\n" % (item + ',' + pass_fail[it]))