In [1]:
from __future__ import absolute_import, division, print_function

import cPickle, gzip, glob
from six.moves import xrange
import numpy as np
import tensorflow as tf

from tensorflow.models.rnn.translate import data_utils
import seq2seq_model

from rdkit import Chem
from rdkit.Chem import AllChem
import parser.Smipar as Smipar

from rdkit import DataStructs
from rdkit.Chem.Fingerprints import FingerprintMols

In [2]:
# Model and run hyper-parameters, registered through TensorFlow's flags module
# and read back via FLAGS.<name> when the model is constructed.
flags = tf.app.flags

flags.DEFINE_float("learning_rate", 0.5, "Learning rate.")
flags.DEFINE_float("learning_rate_decay_factor", 0.99, "Learning rate decays by this much.")
flags.DEFINE_float("max_gradient_norm", 5.0, "Clip gradients to this norm.")
flags.DEFINE_integer("batch_size", 1, "Batch size to use during training.")
flags.DEFINE_integer("size", 600, "Size of each model layer.")
flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
flags.DEFINE_integer("reactant_vocab_size", 311, "Reactant vocabulary size.")
flags.DEFINE_integer("product_vocab_size", 180, "Product vocabulary size.")
flags.DEFINE_string("train_dir", "checkpoint/saved_models", "Training dir.")
FLAGS = flags.FLAGS

# Length buckets handed to Seq2SeqModel. The first element of each pair is the
# bound compared against the encoder token count when a bucket is selected;
# the second is presumably the decoder length bound -- TODO confirm against
# seq2seq_model.
_buckets = [(54, 54), (70, 60), (90, 65), (150, 80)]

# Token vocabularies for the encoder (reactants) and decoder (products) sides.
# A token's position in its list is used as its integer id.
# NOTE(review): cPickle.load can execute arbitrary code contained in the file;
# only load vocabulary files from a trusted source.
with gzip.open('data/vocab/vocab_list.pkl.gz', 'rb') as list_file:
    reactants_token_list, products_token_list = cPickle.load(list_file)

In [3]:
def create_model(session, forward_only):
    """Build a Seq2SeqModel from FLAGS and restore or initialise its variables.

    Args:
        session: TensorFlow session used to restore or initialise variables.
        forward_only: forwarded unchanged to Seq2SeqModel.

    Returns:
        The constructed model. Its variables are restored from the latest
        checkpoint recorded under FLAGS.train_dir when that checkpoint file
        exists; otherwise they are freshly initialised.
    """
    model = seq2seq_model.Seq2SeqModel(
        FLAGS.reactant_vocab_size,
        FLAGS.product_vocab_size,
        _buckets,
        FLAGS.size,
        FLAGS.num_layers,
        FLAGS.max_gradient_norm,
        FLAGS.batch_size,
        FLAGS.learning_rate,
        FLAGS.learning_rate_decay_factor,
        forward_only=forward_only)

    checkpoint = tf.train.get_checkpoint_state(FLAGS.train_dir)

    # No usable checkpoint: start from freshly initialised variables.
    if not (checkpoint and tf.gfile.Exists(checkpoint.model_checkpoint_path)):
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())
        return model

    saved_path = checkpoint.model_checkpoint_path
    print("Reading model parameters from %s" % saved_path)
    model.saver.restore(session, saved_path)
    return model

def cano(smiles): # canonicalize smiles by MolToSmiles function
    """Return the RDKit canonical SMILES for `smiles`.

    Args:
        smiles: SMILES string; the empty string is passed through unchanged.

    Returns:
        The canonical SMILES string produced by Chem.MolToSmiles, or '' when
        `smiles` is ''.

    Raises:
        ValueError: if RDKit cannot parse `smiles`.
    """
    if smiles == '':
        return ''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with
        # a message that names the offending string instead of letting
        # MolToSmiles reject the None argument.
        raise ValueError("Cannot parse SMILES: %r" % (smiles,))
    return Chem.MolToSmiles(mol)

def fp_sim(smiles1, smiles2): # fingerprint similarity (Tanimoto)
    """Return the fingerprint similarity of two SMILES strings.

    Args:
        smiles1: first SMILES string.
        smiles2: second SMILES string.

    Returns:
        The value of DataStructs.FingerprintSimilarity for the two molecules'
        FingerprintMols fingerprints, or 0 when either string cannot be
        parsed or fingerprinting fails.
    """
    try:
        mol1 = Chem.MolFromSmiles(smiles1)
        mol2 = Chem.MolFromSmiles(smiles2)
        if mol1 is None or mol2 is None:
            # Unparsable SMILES (e.g. a malformed prediction) scores as a miss.
            return 0
        fp1 = FingerprintMols.FingerprintMol(mol1)
        fp2 = FingerprintMols.FingerprintMol(mol2)
        return DataStructs.FingerprintSimilarity(fp1, fp2)
    except Exception:
        # Best-effort scoring: any RDKit failure counts as zero similarity.
        # Narrowed from a bare `except:` so that KeyboardInterrupt and
        # SystemExit still stop a long evaluation run.
        return 0
    
def avg(l):
    """Return the arithmetic mean of the numbers in `l` as a float.

    Args:
        l: sized sequence of numbers.

    Returns:
        sum(l) / len(l), or NaN when `l` is empty (instead of raising
        ZeroDivisionError).
    """
    if len(l) == 0:
        return float('nan')
    return sum(l) / float(len(l))

In [4]:
# Evaluate the restored model on every reaction problem set in data/problems/.
# Globbed once and sorted so both the per-file pass and the final summary see
# the same files in the same, deterministic order.
problem_files = sorted(glob.glob('data/problems/*.txt'))

# One result slot per problem file (a fixed count of 10 raised IndexError as
# soon as an eleventh file was present).
loss_list = [[] for _ in problem_files]
sim_list = [[] for _ in problem_files]
products_list = [[] for _ in problem_files]


def _tokenize(molecules):
    """Return the parser tokens of each molecule, separated by '.' tokens."""
    tokens = []
    for position, molecule in enumerate(molecules):
        if position:
            tokens.append('.')
        tokens += Smipar.parser_list(molecule)
    return tokens


def _print_summary(rxn_file, losses, sims):
    """Print one tab-separated summary row for a problem file."""
    if not losses:
        # Nothing was evaluated (empty file or every reaction skipped).
        print(rxn_file, 'no reactions evaluated', sep='\t')
        return
    print(rxn_file, avg(losses), avg(sims), sims.count(1), sims.count(0), sep='\t')


with tf.Session() as sess:
    model = create_model(sess, True)

    for i, rxn_file in enumerate(problem_files):
        with open(rxn_file, 'r') as f:
            # Blank lines would otherwise crash on the '>' split below.
            rsmi_list = [line for line in f.read().splitlines() if line.strip()]

        for rsmi in rsmi_list: # list of RSMIs to be tested
            # Reaction SMILES layout: reactants>agents>products
            split_rsmi = rsmi.split('>')
            reactant_list = _tokenize(cano(split_rsmi[0]).split('.'))
            agent_list = _tokenize(cano(split_rsmi[1]).split('.'))
            product_list = _tokenize(cano(split_rsmi[2]).split('.'))

            # Encoder input is the reactant tokens, a '>' token, then the agent tokens.
            reactant_list += '>'
            reactant_list += agent_list

            token_ids = [reactants_token_list.index(r) for r in reactant_list]
            product_ids = [products_token_list.index(p) for p in product_list]
            # Terminate the target sequence with EOS. This used to be appended
            # to `products`, a list of SMILES strings that is never read again,
            # so the marker never reached the decoder targets.
            product_ids.append(data_utils.EOS_ID)

            fitting_buckets = [b for b in xrange(len(_buckets))
                               if _buckets[b][0] > len(token_ids)]
            if not fitting_buckets:
                # min() of an empty list would raise ValueError; skip instead.
                print("Skipping (input longer than every bucket):", rsmi)
                continue
            bucket_id = min(fitting_buckets)

            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket_id: [(token_ids, product_ids)]}, bucket_id)
            _, loss, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                                target_weights, bucket_id, True)

            # Greedy decoding: most likely token per position, cut at the first EOS.
            # int() requires a single-element argmax, i.e. batch size 1.
            outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
            if data_utils.EOS_ID in outputs:
                outputs = outputs[:outputs.index(data_utils.EOS_ID)]
            products_ = ''.join([tf.compat.as_str(products_token_list[output]) for output in outputs])

            loss_list[i].append(loss)
            sim_list[i].append(fp_sim(split_rsmi[2], products_))
            products_list[i].append(products_)

            print(rsmi, products_list[i][-1], sim_list[i][-1])

        _print_summary(rxn_file, loss_list[i], sim_list[i])
        print('-' * 15)

    # Final recap: one summary row per problem file.
    for i, rxn_file in enumerate(problem_files):
        _print_summary(rxn_file, loss_list[i], sim_list[i])

Reading model parameters from checkpoint/saved_models/gen.ckpt-8
CC=C(C)C.Cl>>CCC(C)(C)Cl CCC(C)(C)Cl 1.0
CC1=CCCC1.BrBr>ClC(Cl)(Cl)Cl>CC1(Br)CCCC1Br CC1(Br)CCCC1Br 1.0
CC1=CCCC1.O>[H]B([H])[H].C1CCOC1.OO.[Na+].[OH-]>CC1CCCC1O CC1CCCC1O 1.0
CCC(C)=CC.[O-][O+]=O>CSC>CC=O.CCC(C)=O CC=O.CCC(C)=O 1.0
CC=C1CCCCC1.Br>COOC>CC(Br)C1CCCCC1 CC(Br)C1CCCCC1 1.0
CC=C1CCCCC1.Cl>COOC>CCC1(Cl)CCCCC1 CCC1(O)CCCCC1 0.4625
CCC=CC>OOC(=O)c1ccccc1>CCC1OC1C CCC1OC1C 1.0
CC1=CCCC1>O=[Os](=O)(=O)=O.OO>CC1(O)CCCC1O CC1(O)CCCC1O 1.0
CC1=CCCCC1>[K+].[O-][Mn](=O)(=O)=O.[OH-]>CC1(O)CCCCC1O CC1(O)CCCCC1O 1.0
CC1=CCCCC1>CC(=O)OO.OS(O)(=O)=O>CC(=O)CCCCC(O)=O CC1=C(C2CCCC)OCC1 0
CC1=CCCCC1.[O-][O+]=O>CSC>CC(=O)CCCCC=O CC(=O)CCCCC=O 1.0
CC1=C(C)CCCC1.[H][H]>[Pt]>CC1CCCCC1C CC1CCCCC1C 1.0
C1CCC2=CCCCC2C1.O>OS(O)(=O)=O>OC12CCCCC1CCCC2 OC12CCCCC1CCCC2 1.0
C1CCC2=CCCCC2C1.O>CC(=O)O[Hg]OC(C)=O.[Na+].[H][B-]([H])([H])[H]>OC12CCCCC1CCCC2 OC12CCCCC1CCCC2 1.0
C1CCC2=CCCCC2C1.ClCl.O>>OC12CCCCC1CCCC2Cl OC12CCCCC1CCCC2Cl 1.0
data/

Summary columns (tab-separated): problem set file, average loss (cross-entropy), average Tanimoto similarity, number right (similarity = 1), number of errors (similarity = 0)