In [1]:
# Load argument packages
import argparse
import re

# Load transformer package
from onmt.translate.translator import Translator
from onmt.translate import GNMTGlobalScorer
from onmt.model_builder import load_test_model
import onmt.opts as opts
import onmt

# Load data science packages
import pandas as pd
import numpy as np
import torch

# Load chemical packages
from rdkit import Chem
from rdkit.Chem import Descriptors, Descriptors3D, MolFromSmiles, Lipinski, AllChem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions

import matplotlib.pyplot as plt

IPythonConsole.molSize = (1000, 300)   # Change image size
IPythonConsole.ipython_useSVG = False  # Show as PNG

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Path to model
MODEL_all = '../available_models/MIT_reactants_pred_x10/MIT_reactants_pred_x10_model_average_20.pt'
MODEL_one = '../available_models/MIT_1reactant_pred_x10/MIT_1reactant_pred_x10_model_average_20.pt'

# Path to data
path_src  = '../data/MIT_reactants_pred_x10/src-test.txt'
path_tgt  = '../data/MIT_reactants_pred_x10/tgt-test.txt'

# Set number of predicted products
number_of_products = 3

# From SMILES to tokens
def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction
    """
    pattern =  "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = [token for token in regex.findall(smi)]
    assert smi == ''.join(tokens)
    return ' '.join(tokens)

# Get molecule descriptors
def get_descriptors_from_mol(mol_obj, descriptors_list, random_seed=0):

    descriptors_dict = {k: None for k in descriptors_list}
    for k in descriptors_list:
        try:
            if hasattr(Descriptors, k):
                descriptors_dict[k] = getattr(Descriptors, k)(mol_obj)
                continue

            if hasattr(Descriptors3D, k):
                hmol_obj = AllChem.AddHs(mol_obj)
                AllChem.EmbedMolecule(hmol_obj, useExpTorsionAnglePrefs=True,
                useBasicKnowledge=True, randomSeed=random_seed)
                AllChem.UFFOptimizeMolecule(hmol_obj)
                descriptors_dict[k] = getattr(Descriptors3D, k)(hmol_obj)
                continue

            if hasattr(Lipinski, k):
                descriptors_dict[k] = getattr(Lipinski, k)(mol_obj)

            else:
                raise NotImplementedError

        except:
                descriptors_dict[k] = None

    return descriptors_dict

# Reaction prediction function
def reactionPrediction(translator, reac_smi):
    
    """    
        Input:
            Model translator:
                translator (object)
            Reactants and reagents in SMILES
                reac_smi (str)                
                Example: reac_smi = 'N#Cc1ccsc1N.O=[N+]([O-])c1cc(F)c(F)cc1F>C1CCOC1.[H-].[Na+]'
                
        Return:
            Scores and products in SMILES:
                (list (float32), (list (str))
                Example: ([tensor(1.0000)], ['N#Cc1ccsc1Nc1cc(F)c(F)cc1[N+](=O)[O-]'])
            
        Footnote from Schwaller 2019:
            The product of the probabilities of all predicted
            tokens are used as a confidence score
    """

    # Tokenize SMILE molecules
    reac_tok = smi_tokenizer(reac_smi)

    # Output tokenized product
    scores, product_tok = translator.translate_test(src=[reac_tok], batch_size=1)

    # Obtain SMILES product from tokenized product
    product_smi = [pred.replace(' ','') for pred in product_tok[0]]
    
    # Transform log-probs into probs
    scores = [torch.exp(score) for score in scores[0]]
        
    return scores, product_smi

# Display products and scores in terminal
def show_products(scores, products):
    print("-------------------------\n")
    print("Score\t\tProduct\n")
    print("-------------------------\n")
    for iproduct, product in enumerate(products):
        properties = get_descriptors_from_mol(MolFromSmiles(product), descriptors_list, random_seed=0)
        print("%.2e\t%s\n"%(scores[iproduct], product))
        print(properties)
        print("-------------------------\n")
        
def canonicalize_smi(smi: str, remove_atom_mapping=False) -> str:
    """ Convert a SMILES string into its canonicalized form
    Args:
        smi: Reaction SMILES
        remove_atom_mapping: If True, remove atom mapping information from the canonicalized SMILES output
    Returns:
        SMILES reaction, canonicalized, as a string
    """
    mol = Chem.MolFromSmiles(smi)
    if not mol:
        raise NotCanonicalizableSmilesException("Molecule not canonicalizable")
    if remove_atom_mapping:
        for atom in mol.GetAtoms():
            if atom.HasProp("molAtomMapNumber"):
                atom.ClearProp("molAtomMapNumber")
    return Chem.MolToSmiles(mol)

class NotCanonicalizableSmilesException(ValueError):
    pass
        
# Loads model translator
def load_model(MODEL, number_of_products=1):

    # Parsing model parameters
    parser = argparse.ArgumentParser(description='translate.py',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    #opts.add_md_help_argument(parser)
    opts.translate_opts(parser)
    opt = parser.parse_args(['-model=%s'%MODEL,
                             '-src=%s'%'CCC',
                             '-batch_size=%s'%'64',
                             '-replace_unk',
                             '-max_length=%s'%'200'])
    dummy_parser = argparse.ArgumentParser(description='train.py')
    opts.model_opts(dummy_parser)
    dummy_opt = dummy_parser.parse_known_args([])[0]

    # Load transformer model
    fields, model, model_opt = load_test_model(opt)

    # Set score parameters
    scorer = GNMTGlobalScorer(opt.alpha, opt.beta,
                              opt.coverage_penalty,
                              opt.length_penalty)

    # Create dictionary with model parameters
    kwargs = {k: getattr(opt, k)
              for k in ["beam_size", "max_length", "min_length",
                        "stepwise_penalty", "block_ngram_repeat",
                        "ignore_when_blocking", "dump_beam",
                        "data_type", "replace_unk"]}

    # Create transfomer
    translator = Translator(model, fields=fields, global_scorer=scorer,
                            report_score=True, out_file=None,
                            copy_attn=model_opt.copy_attn, logger=None,
                            src_reader=onmt.inputters.str2reader["text"],
                            tgt_reader=onmt.inputters.str2reader["text"],
                            n_best=number_of_products, gpu=1, **kwargs)
    
    return translator

# Get descriptors
descriptors_list = ["MolLogP", "SlogP_VSA1", "Asphericity", "TPSA", "MolWt", "NumHDonors", "NumHAcceptors"]

# Generate first predicter
first_predicter = load_model(MODEL_all)

# Define canonalizer#
canonicalize_smi = lambda smi: 'NA' if not Chem.MolFromSmiles(smi) else Chem.MolToSmiles(Chem.MolFromSmiles(smi))

# Generate second predicter
second_predicter = load_model(MODEL_one)

# Reactant prediction
def reactantsPrediction(reac_smi):

    # Canonicalize reaction
    reac_smi_canon = '.'.join([canonicalize_smi(n) for n in reac_smi.split('.')])

    # Run model for a given reaction
    scores, reactants = reactionPrediction(first_predicter, reac_smi_canon)

    # Canonicalize reactants and put in a list
    reactants_list = [canonicalize_smi(reactant) for reactant in reactants[0].split('.')]

    # Use second predicter if one 'NA' is present in first prediction
    if reactants_list.count('NA') == 1:
        reactants_list.remove('NA')
        scores_one, reactant_one = reactionPrediction(second_predicter,  '.'.join(reactants_list) + '.' + reac_smi_canon)
        return reactant_one[0] + '.' + '.'.join(reactants_list)
    else:
        return reactants

In [2]:
# Load data
src  = pd.read_csv(path_src, header=None).replace('\s+', '', regex=True).values.flatten().tolist()

In [15]:
"""pred = []
N = len(pred)
for reaction in src[N:]:
    pred.append(reactantsPrediction(reaction)[0])"""

In [17]:
# Load target
tgt  = pd.read_csv(path_tgt, header=None).replace('\s+', '', regex=True).values.flatten().tolist()

# Calulate top-1 accuracy
counter_all = 0
counter_one = 0
for i in range(len(pred)):
    tgt_list = sorted([canonicalize_smi(j) for j in tgt[i].split('.')])
    pred_list = sorted([canonicalize_smi(j) for j in pred[i].split('.')])
    if tgt_list == pred_list:
        counter_all += 1
    for j in pred_list:
        if j in tgt_list and j != 'NA':
            counter_one += 1
            break
print('Top-1 accuracy (all reactants) = %.2f %%'%(100 * counter_all / len(pred)))
print('Top-1 accuracy (at least 1)    = %.2f %%'%(100 * counter_one / len(pred)))

Top-1 accuracy (all reactants) = 70.59 %
Top-1 accuracy (at least 1)    = 82.28 %


In [16]:
save_path = '../results/predictions_MIT_enhanced_reactants_pred_model_on_MIT_reactants_pred_test_V2.txt'
with open(save_path, 'w') as f:
    for item in pred:
        f.write("%s\n" % item)