In [22]:
# Load data science packages
import numpy as np
import pandas as pd

# Load argument packages
import argparse
import os
import re

# Load chemistry packages
import rdkit.Chem as Chem
import rdkit.Chem.AllChem as AllChem
from rdkit.Chem.Draw import IPythonConsole 
from rdkit.Chem.Draw.MolDrawing import MolDrawing, DrawingOptions
from rdkit.Chem import PandasTools
from rdkit import RDLogger                                                                                                                                                               

# Turn off RDKit log output for every logger matching 'rdApp.*'
RDLogger.DisableLog('rdApp.*')
# Ask PandasTools to render molecule images in all DataFrames
PandasTools.RenderImagesInAllDataFrames(images=True)

# Load visualization package and display settings
import matplotlib.pyplot as plt
# Notebook rendering settings for RDKit molecule depictions
IPythonConsole.molSize = (1000, 300)   # Rendered image size (presumably width, height in pixels)
IPythonConsole.ipython_useSVG = False  # Disable SVG rendering so molecules are shown as PNG

# Input files: source, target and model-prediction SMILES, one entry per line
path_src  = '../data/MIT_reactants_pred_x10/src-test.txt'
path_tgt  = '../data/MIT_reactants_pred_x10/tgt-test.txt'
path_pred = '../results/predictions_MIT_reactants_pred_x10_model_average_20_on_MIT_reactants_pred_x10_test.txt'


def _load_smiles_lines(path):
    """
    Read a header-less text file and return its entries as a flat list.

    The file is parsed with ``pd.read_csv(header=None)`` and every run of
    whitespace inside string values is removed (presumably to undo
    space-separated tokenization -- confirm against the data files).

    NOTE(review): ``pd.read_csv`` skips blank lines by default, so an empty
    line in one file would shift its entries relative to the other files.
    Confirm that the inputs never contain blank lines.
    """
    # Raw string: '\s' is a regex escape, not a valid Python string escape.
    return (
        pd.read_csv(path, header=None)
        .replace(r'\s+', '', regex=True)
        .values.flatten().tolist()
    )


# Load data
src  = _load_smiles_lines(path_src)
pred = _load_smiles_lines(path_pred)
tgt  = _load_smiles_lines(path_tgt)

# From SMILES to tokens
# Compiled once at definition time instead of on every call. Written as a raw
# string so regex escapes such as ``\[`` or ``\(`` are not interpreted as
# (invalid) Python string escapes.
_SMI_TOKEN_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction.

    Parameters
    ----------
    smi : str
        SMILES string without whitespace.

    Returns
    -------
    str
        The tokens of ``smi`` joined by single spaces. An empty input
        yields an empty string.

    Raises
    ------
    AssertionError
        If ``smi`` contains characters that the token pattern does not
        match, i.e. the tokens do not reassemble into the input.
    """
    tokens = _SMI_TOKEN_REGEX.findall(smi)
    # Guard against silently dropping characters the pattern does not cover.
    assert smi == ''.join(tokens), "SMILES contains untokenizable characters: %r" % (smi,)
    return ' '.join(tokens)

# SMILES functions
def canonicalize_smi(smi):
    """
    Return the RDKit ``MolToSmiles`` output for ``smi``, or 'NA' if RDKit
    cannot parse it.

    The SMILES is parsed once and the resulting molecule is reused for
    both the validity check and the conversion.
    """
    mol = Chem.MolFromSmiles(smi)
    return 'NA' if not mol else Chem.MolToSmiles(mol)


def equivalent_smi(smi):
    """
    Return a randomized (``doRandom=True``) SMILES for the molecule encoded
    by ``smi``, or 'NA' if RDKit cannot parse it.

    The SMILES is parsed once and the resulting molecule is reused for
    both the validity check and the conversion.
    """
    mol = Chem.MolFromSmiles(smi)
    return 'NA' if not mol else Chem.MolToSmiles(mol, doRandom=True)

# Create round trip dataset, 'pred_reactants.reagents'
# For every source entry, the last '.'-separated fragment is taken as the
# product and all preceding fragments as reagents. The round-trip source is
# the predicted reactants followed by those reagents; the round-trip target
# is the product.
# NOTE(review): entries are paired by index, so this assumes `pred` holds
# exactly one prediction per `src` entry, in the same order -- confirm.
round_trip_src = []
round_trip_tgt = []
for i, s in enumerate(src):
    reactants = pred[i]
    fragments = s.split('.')
    product   = fragments[-1]
    # Split by position rather than by value, so a reagent whose text
    # happens to equal the product is not dropped.
    reagents  = fragments[:-1]
    # With no reagents this is just `reactants` (no trailing '.').
    round_trip_src.append('.'.join([reactants] + reagents))
    round_trip_tgt.append(product)

In [23]:
save_path = '../data/round_trip/'
# Create the output directory if it is missing so a fresh run does not fail.
os.makedirs(save_path, exist_ok=True)


def _write_tokenized(path, smiles_list):
    """
    Write each entry of ``smiles_list`` to ``path``, one per line, after
    passing it through ``smi_tokenizer``. An existing file is overwritten.
    """
    # Plain 'w' is enough: the file is only written, never read back.
    with open(path, 'w') as f:
        for smi in smiles_list:
            f.write("%s\n" % (smi_tokenizer(smi)))


_write_tokenized(save_path + 'src-rt-8M.txt', round_trip_src)
_write_tokenized(save_path + 'tgt-rt-8M.txt', round_trip_tgt)