In [1]:
import pickle
import glob
from collections import defaultdict
import math
from tqdm import tqdm
from time import time
#import json
#import molvs
import random
#import policies

#from mcts import Node, mcts
import tensorflow as tf
from rdkit import Chem
from rdkit.Chem import AllChem

import os
import numpy as np

from highway_layer import Highway
#匯入深度學習的框架函式庫：keras
import keras
from keras import backend as K
from keras.initializers import Constant
from keras.utils import plot_model
#keras用以建立模型架構的函數
from keras.models import Sequential, load_model, Model

#keras中建立深度學習layer的函數

from keras.layers import Dense, Dropout, BatchNormalization, Activation, Multiply, Add, Lambda, Input

#keras訓練演算法函數
from keras import regularizers
from keras.optimizers import Adam

#keras提早判停的函數
from keras.callbacks import EarlyStopping, ModelCheckpoint

#it's hard to reproduce results, so close all seeds
#os.environ['PYTHONHASHSEED'] = '0'
#np.random.seed(0)
#tf.set_random_seed(0)
#random.seed(0)

#to solve problem:Blas GEMM launch failed
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
#config = tf.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
config.gpu_options.allocator_type = 'BFC' #A "Best-fit with coalescing" algorithm, simplified from a version of dlmalloc.
config.gpu_options.per_process_gpu_memory_fraction = 0.95
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config)) 

EXPLORE_PARAM =5.
# Monte Carlo tree search numbers: iters
iters = 5000
# rollout maximum depth: max_d
max_d = 5
# when node is visited by "visitn+2" times, expansions start
visitn = 0
win = 10.
lmax = 25.
seed = 0
fp_dim_e=16384
#fp_dim_e = 23086
fp_dim_r = 8192
fp_dim_i = 16384
recfp_dim = 2048
#random.seed(seed)
#np.random.seed(seed)

def wlmax(length, acprob):
    ksi = length-0.99*acprob
    return max(0,1-ksi/lmax)

class Node:
    def __init__(self, state, parent=None, action=None, is_terminal=False, length=0, prob=0, acprob=0):
        self.state = state
        self.children = []
        self.parent = parent
        self.n_visits = 0
        self.reward = 0
        self.action = action
        self.is_terminal = is_terminal
        self.length = length
        self.prob = prob
        self.acprob = acprob

    @property
    def value(self):
        """UCB1"""
        if self.n_visits == 0:
#            return float('inf')
            return self.prob*1e6
        return self.reward/self.n_visits + \
            EXPLORE_PARAM*self.prob*math.sqrt(math.log(self.parent.n_visits)/self.n_visits)

#    @property
    def score(self):    
        return wlmax(self.length, self.acprob)

    def best_child(self):
        return max(self.children, key=lambda n: n.value)


def mcts(root, expansion_net, filter_net, rollout_net, iterations=2000, max_depth=200):
    """
    Monte Carlo Tree Search
    - `expansion_policy` should be a function that takes a node and returns a
    list of child nodes
    - `rollout_policy` should be a function that takes a node and returns a
    reward for that node
    """
    pathall=[]
    root.children = expansion(root, expansion_net, filter_net)
    if not root.children: 
        print('No synthesis path found. Try adding more data to train model or increasing the rule selection number of `expansion`.')
        return None

    # MCTS
    for _ in tqdm(range(iterations)):
        cur_node = root

        # Selection
        while True:
            if cur_node.n_visits >= 0 and cur_node.children:
                cur_node = cur_node.best_child()
            else:
                break

        if cur_node.n_visits > visitn:
            # If selection took us to a terminal node,
            # this seems to be the best path
            if cur_node.is_terminal:
                # Return best path
                cur_node1 = root
                path1 = [cur_node1]
                #for _ in range(lmax*10):
                while True:
                    #if not cur_node.children: continue 
                    cur_node1 = cur_node1.best_child()
                    path1.append(cur_node1)
                    if cur_node1.is_terminal:
                        break      
                if path1 not in pathall:
                    pathall.append(path1)
                    
                # Update, reward-1 to avoid repeated process
                #cur_node.reward += win*wlmax(cur_node.length, cur_node.acprob)
                cur_node.reward += -win*(visitn+1)*wlmax(cur_node.length, cur_node.acprob)
                #cur_node.reward += -win*(cur_node.n_visits+1)*wlmax(cur_node.length, cur_node.acprob)
                cur_node.n_visits += 1
                parent = cur_node.parent
                while parent is not None:
                    #parent.reward += -win*wlmax(cur_node.length, parent.acprob)
                    #parent.reward += win*wlmax(cur_node.length, cur_node.acprob)
                    parent.reward += -win*(visitn+1)*wlmax(cur_node.length, cur_node.acprob)
                    #parent.reward += -1*(parent.n_visits+1)*wlmax(cur_node.length, cur_node.acprob)
                    parent.n_visits += 1
                    parent = parent.parent
                #print('test')
                #return pathall
                continue


            # Expansion
            s = time()
            cur_node.children = expansion(cur_node, expansion_net, filter_net)
#            print('Expansion took:', time() - s)
            if not cur_node.children:
                # Update
                cur_node.reward += -1*wlmax(cur_node.length, cur_node.acprob)
                cur_node.n_visits += 1
                parent = cur_node.parent
                while parent is not None:
                    #parent.reward += -1*wlmax(cur_node.length, parent.acprob)
                    parent.reward += -1*wlmax(cur_node.length, cur_node.acprob)
                    parent.n_visits += 1
                    parent = parent.parent
                continue
                
            cur_node = cur_node.best_child()

        # Rollout
        s = time()
        reward, length, acprob = rollout(cur_node, rollout_net, max_depth=max_depth)
#        print('Rollout took:', time() - s)

        # Update
        #cur_node.reward += reward*wlmax(length, cur_node.acprob)
        cur_node.reward += reward*wlmax(length, acprob)
        cur_node.n_visits += 1
        parent = cur_node.parent
        while parent is not None:
            #parent.reward += reward*wlmax(length, parent.acprob)
            parent.reward += reward*wlmax(length, acprob)
            parent.n_visits += 1
            parent = parent.parent

#    # Return best path
#    cur_node = root
#    path = [cur_node]
#    for _ in range(lmax*10):
#        if not cur_node.children: continue 
#        cur_node = cur_node.best_child()
#        path.append(cur_node)
#        if cur_node.is_terminal:
#            break

#    # Max depth exceeded, no path found
#    else:
#        return None

#    return path

    # Max depth exceeded, no path found
    #if not pathall: return None
    return pathall

def fps_to_arr_r(fps):
    """Faster conversion to ndarray"""
    arrs = []
    for fp in fps:
        onbits = list(fp.GetOnBits())
        arr = np.zeros(fp.GetNumBits())
        arr[onbits] = 1
        arrs.append(arr)
    arrs = np.array(arrs)
    return arrs




def fingerprint_mols_r(mols, fp_dim):
    fps = []
    for mol in mols:
        mol = Chem.MolFromSmiles(mol)

        # Necessary for fingerprinting
        # Chem.GetSymmSSSR(mol)

        # "When comparing the ECFP/FCFP fingerprints and
        # the Morgan fingerprints generated by the RDKit,
        # remember that the 4 in ECFP4 corresponds to the
        # diameter of the atom environments considered,
        # while the Morgan fingerprints take a radius parameter.
        # So the examples above, with radius=2, are roughly
        # equivalent to ECFP4 and FCFP4."
        # <http://www.rdkit.org/docs/GettingStartedInPython.html>
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=int(fp_dim), useChirality=1)
        # fold_factor = fp.GetNumBits()//fp_dim
        # fp = DataStructs.FoldFingerprint(fp, fold_factor)
        fps.append(fp)
    return fps

def fps_to_arr(fps):
    """Faster conversion to ndarray"""
    arrs = []
    for fp, info in zip(fps[0],fps[1]):
        onbits = list(fp.GetOnBits())
        arr = np.zeros(fp.GetNumBits())
        for onbit in onbits:
            arr[onbit] = len(info[onbit])
        arrs.append(arr)
    arrs = np.array(arrs)
    return arrs




def fingerprint_mols(mols, fp_dim):
    fps = []
    infos = []
    for mol in mols:
        mol = Chem.MolFromSmiles(mol)
        info={}
        # Necessary for fingerprinting
        # Chem.GetSymmSSSR(mol)

        # "When comparing the ECFP/FCFP fingerprints and
        # the Morgan fingerprints generated by the RDKit,
        # remember that the 4 in ECFP4 corresponds to the
        # diameter of the atom environments considered,
        # while the Morgan fingerprints take a radius parameter.
        # So the examples above, with radius=2, are roughly
        # equivalent to ECFP4 and FCFP4."
        # <http://www.rdkit.org/docs/GettingStartedInPython.html>
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=int(fp_dim), useChirality=1, bitInfo=info)
        # fold_factor = fp.GetNumBits()//fp_dim
        # fp = DataStructs.FoldFingerprint(fp, fold_factor)
        fps.append(fp)
        infos.append(info)
    return fps, infos

def preprocess_e(X, fp_dim):
    # Compute fingerprints
    dataX = fps_to_arr(fingerprint_mols(X, fp_dim))
    # Apply variance threshold
    # return np.log(X[:,self.idx] + 1) 
    #FPs = np.log(dataX[:,idx]+1)
    FPs = np.log(dataX+1)
    return FPs

def preprocess_r(X,fp_dim):
    # Compute fingerprints
    dataX = fps_to_arr_r(fingerprint_mols_r(X, fp_dim))
    FPs = np.log(dataX+1)
    return FPs

def preprocess_i(X, fp_dim):
    # Compute fingerprints
    FPs = fps_to_arr(fingerprint_mols(X, fp_dim))
    # Apply variance threshold
    # return np.log(X[:,self.idx] + 1) 
    #FPs = np.log(dataX[:,idx]+1)
#    FPs = np.log(dataX+1)
    return FPs
def smi_list_from_str(inchis):
    '''string separated by ++ to list of RDKit molecules'''
    return [inchi.strip() for inchi in inchis.split('++')]

def acc_top50(y_true, y_pred):
    return keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=50)

def acc_top10(y_true, y_pred):
    return keras.metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=10)

def fold(x):
    z=tf.subtract(x[0], x[1])
#    z_shape=tf.Tensor.shape(z)

#    z_shape=z.get_shape().as_list()
    zv=tf.reshape(z,[-1,8,2048])
    return tf.reduce_sum(zv, 1) 

def cosine(x):
    prod_net = x[0]
    react_net = x[1]
#    prod_norm = tf.nn.l2_normalize(prod_net, axis=-1)
#    react_norm = tf.nn.l2_normalize(react_net, axis=-1)
    cosine_sim = tf.reduce_sum(tf.multiply(prod_net, react_net), axis=-1,keepdims=True)
#    cosine_sim = tf.squeeze(cosine_sim,[1])
#    return tf.nn.sigmoid(cosine_sim)
    return tf.nn.sigmoid(cosine_sim)
# get average auc between different batches over the epoch, so don't use. otherwise validation process always get wrong results
def auc2(y_true, y_pred):
    auc = tf.metrics.auc(y_true, y_pred)[1]
    K.get_session().run(tf.local_variables_initializer())
    return auc

# AUC for a binary classifier, this AUC is a little underestimated due to minimum areas.
def auc1(y_true, y_pred):
    ptas = tf.stack([binary_PTA(y_true,y_pred,k) for k in np.linspace(0, 1, 1000)],axis=0)
    pfas = tf.stack([binary_PFA(y_true,y_pred,k) for k in np.linspace(0, 1, 1000)],axis=0)
    pfas = tf.concat([tf.ones((1,)) ,pfas],axis=0)
    binSizes = -(pfas[1:]-pfas[:-1])
    s = ptas*binSizes
    return K.sum(s, axis=0)
#-----------------------------------------------------------------------------------------------------------------------------------------------------
# PFA, prob false alert for binary classifier(FPR)
def binary_PFA(y_true, y_pred, threshold=K.variable(value=0.5)):
    y_pred = K.cast(y_pred >= threshold, 'float32')
    # N = total number of negative labels
    N = K.sum(1 - y_true)
    # FP = total number of false alerts, alerts from the negative class labels
    FP = K.sum(y_pred - y_pred * y_true)
    return FP/N
#-----------------------------------------------------------------------------------------------------------------------------------------------------
# P_TA prob true alerts for binary classifier(TPR)
def binary_PTA(y_true, y_pred, threshold=K.variable(value=0.5)):
    y_pred = K.cast(y_pred >= threshold, 'float32')
    # P = total number of positive labels
    P = K.sum(y_true)
    # TP = total number of correct alerts, alerts from the positive class labels
    TP = K.sum(y_pred * y_true)
    return TP/P
# PFA, prob false alert for binary classifier(FPR)
def FPR(y_true, y_pred):
    y_pred = K.cast(y_pred >= 0.9, 'float32')
    # N = total number of negative labels
    N = K.sum(1 - y_true)
    # FP = total number of false alerts, alerts from the negative class labels
    FP = K.sum(y_pred - y_pred * y_true)
    return FP/N
#-----------------------------------------------------------------------------------------------------------------------------------------------------
# P_TA prob true alerts for binary classifier(TPR)
def TPR(y_true, y_pred):
    y_pred = K.cast(y_pred >= 0.9, 'float32')
    # P = total number of positive labels
    P = K.sum(y_true)
    # TP = total number of correct alerts, alerts from the positive class labels
    TP = K.sum(y_pred * y_true)
    return TP/P

# ACC= (TP + TN) / (P + N)
def ACCR(y_true, y_pred):
    y_pred = K.cast(y_pred >= 0.9, 'float32')
    # P = total number of positive labels
    P = K.sum(y_true)
    # N = total number of negative labels
    N = K.sum(1 - y_true)    
    # TP = total number of correct alerts, alerts from the positive class labels
    TP = K.sum(y_pred * y_true)
    # TN = total number of correct alerts, alerts from the negtive class labels
    TN = K.sum((1-y_pred) * (1-y_true))    
    return (TP+TN)/(P+N)

    
# Load base compounds
starting_mols = set()
expansion_rules = []
rollout_rules = []
#with open('data/emolecules.smi', 'r') as f:
#    for line in tqdm(f, desc='Loading base compounds'):
#        smi = line.strip()
###        smi = molvs.standardize_smiles(smi)
#        smi = Chem.MolFromSmiles(smi)
#        if not smi: continue
#        smi = Chem.MolToSmiles(smi,allHsExplicit=0,allBondsExplicit=0)
#        starting_mols.add(smi)
#with open('data/emoleculestandard.dat', 'w') as f:
#    f.write('\n'.join(starting_mols))  

#'''    
#with open('data/emoleculestandard.dat', 'r') as f:

with open('data/emoleculestandard0701.dat', 'r') as f:
    for line in tqdm(f, desc='Loading base compounds'):
        smi = line.strip()
#        smi = molvs.standardize_smiles(smi)
#        smi = Chem.MolFromSmiles(smi)
#        if not smi: continue
#        smi = Chem.MolToSmiles(smi,allHsExplicit=0,allBondsExplicit=0)
        starting_mols.add(smi)
        
print('Base compounds:', len(starting_mols))
#'''
#'''    


with open('data/zincagent.dat', 'r') as f:
    for line in tqdm(f, desc='Loading base compounds'):
        smi = line.strip()
#        smi = molvs.standardize_smiles(smi)
#        smi = Chem.MolFromSmiles(smi)
#        if not smi: continue
#        smi = Chem.MolToSmiles(smi,allHsExplicit=0,allBondsExplicit=0)
        starting_mols.add(smi)

with open('data/electric_charge.dat', 'r') as f:
    for line in tqdm(f, desc='Loading base compounds'):
        smi = line.strip()
#        smi = molvs.standardize_smiles(smi)
#        smi = Chem.MolFromSmiles(smi)
#        if not smi: continue
#        smi = Chem.MolToSmiles(smi,allHsExplicit=0,allBondsExplicit=0)
        starting_mols.add(smi)    
print('Base compounds:', len(starting_mols))
#'''

'''
start=time()
with open('data/emoleculestandard.pickle', 'rb') as f:
    starting_mols = pickle.load(f)
print(time()-start, 's')        
print('Base compounds:', len(starting_mols))
'''

# Load policy networks
#with open('model/rules.json', 'r') as f:
#    rules = json.load(f)
#    rollout_rules = rules['rollout']
#    expansion_rules = rules['expansion']
    
with open('data/expansion_expansion.dat', 'r') as f:
    for i, l in tqdm(enumerate(f), desc='expansion'):
        rule = l.strip()
        expansion_rules.append(rule)
with open('data/rollout_rollout.dat', 'r') as f:
    for i, l in tqdm(enumerate(f), desc='rollout'):
        rule = l.strip()
        rollout_rules.append(rule)
        

save_dir = os.path.join(os.getcwd(), 'saved_models')
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path_r = os.path.join(save_dir, 'trained_model_rollout_all')  #2

#model_path_e = os.path.join(save_dir, 'trained_model_expansion_all')#4 #fpdim 16384 not 23086s 
#model_path_e = os.path.join(save_dir, 'trained_model_expansion_all_1')#4 #not all rules
model_path_e = os.path.join(save_dir, 'trained_model_expansion_0-testall')#4  #fpdim 16384 not 23086 #not all rules
model_path_i = os.path.join(save_dir, 'trained_model_inscope_all-test0-self-ratiohalf') #2
#model_path_i = os.path.join(save_dir, 'trained_model_inscope_test0-self-ratiohalf') #2
threshold = 0.9

rollout_net = load_model(model_path_r, custom_objects={'acc_top10': acc_top10,'acc_top50': acc_top50})
expansion_net = load_model(model_path_e, custom_objects={'acc_top10': acc_top10,'acc_top50': acc_top50, 'Highway': Highway})
filter_net =  load_model(model_path_i, custom_objects={'ACCR': ACCR,'auc2': auc2,'auc1': auc1,'TPR': TPR, 'FPR': FPR,'Highway': Highway,'fold': fold,'cosine': cosine, 'tf': tf})

def convert_to_retro(transform):
    '''This function takes a forward synthesis and converts it to a
    retrosynthesis. Only transforms with a single product are kept, since
    retrosyntheses should have a single reactant (and split it up accordingly).'''

    # Split up original transform
    reactants = transform.split('>>')[0]
    products  = transform.split('>>')[1]

    # Don't force products to be from different molecules (?)
    # -> any reaction template can be intramolecular (might remove later)
    #products = products[1:-1].replace(').(', '.')

    # Don't force the "products" of a retrosynthesis to be two different molecules!
    #reactants = reactants[1:-1].replace(').(', '.')

    return '>>'.join([products, reactants])

def transform(mol, rule, mode='exp'):
    """Apply transformation rule to a molecule to get reactants"""
    reactants = []
    reactants_smi = []
    results = []
    rxn = AllChem.ReactionFromSmarts(rule)
    #the below should be no isomerics for comparing with original molecule, because all rules I use cannot distinguish whether the product is isometrc or not 
    if mode == 'exp':
        mol_smi = Chem.MolToSmiles(mol,allHsExplicit=0,allBondsExplicit=0, isomericSmiles=0)
    try:
        results = rxn.RunReactants([mol])
    except Exception as e:
        print('error: {}'.format(e))
        print('rxn: {}'.format(rule))    
    
    if not results: return []    
    for result in results:
        mols = []
        mols_obj= []
        for i,mol in enumerate(result):
            repeat = 0
            try: 
                #To save time, so i cancell the following step in the below
                Chem.SanitizeMol(mol)
#                product.UpdatePropertyCache()
                #create product or reactant using molfromsmarts+sanitizemol is sometimes better than molfromsmiles, but still using molfromsmiles as possible as you can
                #To save time, so i cancell the following step in the below
                #mol=Chem.MolFromSmiles(Chem.MolToSmiles(mol,allHsExplicit=0,allBondsExplicit=0))
            except Exception as e:
                #print('warning1: {}'.format(e))
                #use pass is not good behavior, however i have validation finally
                break
            if not mol:
                break
            a = Chem.MolToSmiles(mol,allHsExplicit=0,allBondsExplicit=0, isomericSmiles=1)
            if '.' in a: break
            '''    
            if reactants and i == 0:
                for reac in reactants:
                    if a in reac:
                        repeat = 1
                        break
            if repeat: break
            '''
            mols.append(a)
            mols_obj.append(mol)
            if i == len(result)-1 and mode == 'exp':
                retrorule = convert_to_retro(rule)
                retrorxn = AllChem.ReactionFromSmarts(retrorule)
    
                try:
                    retroresults = retrorxn.RunReactants(mols_obj)
                except Exception as e:
                    print('error retro: {}'.format(e))
                    print('rxn retro: {}'.format(retrorule))    

                if not retroresults: break 
                test_mol = []    
                for result in retroresults:
                    
                    for i,mol in enumerate(result):
            
                        try: 
                            #To save time, so i cancell the following step in the below
                            Chem.SanitizeMol(mol)
            #                product.UpdatePropertyCache()
                            #create product or reactant using molfromsmarts+sanitizemol is sometimes better than molfromsmiles, but still using molfromsmiles as possible as you can
                            #To save time, so i cancell the following step in the below
                            #mol=Chem.MolFromSmiles(Chem.MolToSmiles(mol,allHsExplicit=0,allBondsExplicit=0))
                        except Exception as e:
                            #print('warning1: {}'.format(e))
                            #use pass is not good behavior, however i have validation finally
                            break
                        if not mol:
                            break
                        #the below should be no isomerics for comparing with original molecule, because all rules I use cannot distinguish whether the product is isometrc or not 
                        b = Chem.MolToSmiles(mol,allHsExplicit=0,allBondsExplicit=0, isomericSmiles=0)
                        test_mol.append(b)
                if not test_mol: break 
                if mol_smi not in test_mol: break
                
        else:
            reactants_smi_one = '++'.join(mols)
            if reactants_smi_one not in reactants_smi: 
                reactants_smi.append(reactants_smi_one)
                reactants.append(mols)
            
    # Only look at first set of results (TODO any reason not to?)
    #results = results[0]
    #reactants = [Chem.MolToSmiles(smi) for smi in results]
    return reactants


def expansion(node, expansion_net, filter_net):
    """Try expanding each molecule in the current state
    to possible reactants"""

    # Assume each mol is a SMILES string
    mols = node.state

    # Convert mols to format for prediction
    # If the mol is in the starting set, ignore
    mols = [mol for mol in mols if mol not in starting_mols]
    e_x = preprocess_e(mols, fp_dim_e)

    # Predict applicable rules
    predict = expansion_net.predict_on_batch(e_x)
    # get the rules index from high probability to low probability
    preds =np.argsort(-predict, axis=1)[:,:50]
    
    #below is for special case (dont use)
    #if node.length==0:
    #    preds =np.argsort(-predict, axis=1)[:,:49999]
 

    # Generate children for reactants
    children = []
    count = -1
    for mol, rule_idxs in zip(mols, preds):
        count += 1
        # State for children will
        # not include this mol
        new_state = node.state - {mol}

        prod = Chem.MolFromSmiles(mol)
        if not prod: return []       
        for idx in rule_idxs:
            # Extract actual rule
            rule = expansion_rules[idx]
            
            #below is for special case (dont use)
            #if node.length==0:
            #    if 'F-' not in rule: continue
            
            # TODO filter_net should check if the reaction will work?
            # should do as a batch

            # Apply rule
            reactants = transform(prod, rule, mode='exp')
            #, mode='exp'
            if not reactants: continue
#            reactants = list(reactant1) 
            X = np.zeros((len(reactants), fp_dim_i))
            X[0] = preprocess_i([mol], fp_dim_i)                
            for i, reactant in enumerate(reactants):
                X[i] = X[0]
                y = np.zeros((len(reactants),recfp_dim))
                
                n = np.zeros((1,recfp_dim))
                for b in reactant:
                    n += preprocess_i([b], recfp_dim)
                p = X[i].reshape((-1, recfp_dim))    
                y[i] = np.sum(p, 0, keepdims=True)- n 
            # Predict applicable rules
            predict_i = filter_net.predict_on_batch([X, y])

                
            Treactants = [reactant for reactant, pred in zip(reactants, predict_i) if pred >= threshold]
            Tpreds = [pred for reactant, pred in zip(reactants, predict_i) if pred >= threshold]
            if not Treactants: continue
            for reactant, pred in zip(Treactants, Tpreds):
                state = new_state | set(reactant)
                terminal = all(mol in starting_mols for mol in state)
                #child = Node(state=state, is_terminal=terminal, parent=node, action=rule, length=node.length+1, prob= predict[count,idx], acprob=node.acprob+predict[count,idx])
                child = Node(state=state, is_terminal=terminal, parent=node, action=rule, length=node.length+1, prob= pred[0]/30, acprob=node.acprob+pred[0]/30)
                children.append(child)
                #below is for special case (dont use)
                #if node.length==0:
                #    print(state)
                
    return children


def rollout(node, rollout_net, max_depth=30):
    cur = node
    ## only focus on unsolved molecules to avoid getting good reward but unsolved molecules appear
    #state = {mol for mol in cur.state if mol not in starting_mols}
    #cur = Node(state=state, is_terminal=cur.is_terminal, parent=cur.parent, action=cur.action, length=cur.length, prob= cur.prob, acprob=cur.acprob)
    for _ in range(max_depth):
        if cur.is_terminal:
            break

        # Select a random mol (that's not a starting mol)
        mols = [mol for mol in cur.state if mol not in starting_mols]
        mol = random.choice(mols)
        prod = Chem.MolFromSmiles(mol)
#        if not prod: return -1., cur.length 
        if not prod:
            continue
            '''
            # Partial reward if some starting molecules are found
            reward = sum(1 for mol in cur.state if mol in starting_mols)/len(cur.state)

            # Reward of -1 if no starting molecules are found
            if reward == 0:
                return -1., cur.length, cur.acprob

            return reward, cur.length, cur.acprob
            '''
        
        r_x = preprocess_r([mol], fp_dim_r)

        # Predict applicable rules
        predict = rollout_net.predict_on_batch(r_x)
        # get the rules index from high probability to low probability
        preds =np.argsort(-predict, axis=1)[:,:10]
        idx = np.random.choice(preds[0])
        rule = rollout_rules[idx]
        
        reactants = transform(prod, rule, mode='rol')
        #, mode='rol'
        if not reactants:
            continue
            '''
            # Partial reward if some starting molecules are found
            reward = sum(1 for mol in cur.state if mol in starting_mols)/len(cur.state)

            # Reward of -1 if no starting molecules are found
            if reward == 0:
                return -1., cur.length, cur.acprob

            return reward, cur.length, cur.acprob
            '''
#        reactants = list(reactant1)
        reactant = random.choice(reactants)
        
            
        state = cur.state | set(reactant)

        # State for children will
        # not include this mol
        state = state - {mol}

        terminal = all(mol in starting_mols for mol in state)
        cur = Node(state=state, is_terminal=terminal, parent=cur, action=rule, length=cur.length+1, prob= predict[0,idx], acprob=cur.acprob+predict[0,idx])

    # Max depth exceeded
    else:
        #print('Rollout reached max depth')

        # Partial reward if some starting molecules are found
        reward = sum(1 for mol in cur.state if mol in starting_mols)/len(cur.state)

        # Reward of -1 if no starting molecules are found
        if reward == 0:
            return -1., cur.length, cur.acprob
        if reward == 1:
            return win , cur.length, cur.acprob        

        return reward, cur.length, cur.acprob

    # Reward of 1 if solution is found
    return win , cur.length, cur.acprob


def plan(target_mol, expansion_net, filter_net, rollout_net, iterations=2000, max_depth=200):
    """Generate a synthesis plan for a target molecule (in SMILES form).
    If a path is found, returns a list of (action, state) tuples.
    If a path is not found, returns None."""
    root = Node(state={target_mol})
    pathall = []
    path = mcts(root, expansion_net, filter_net, rollout_net, iterations=iterations, max_depth=max_depth)
    if not path:
        print('No synthesis path found. Try increasing `iterations` or `max_depth`.')
    else:
        print('Path found:')
        #path = [(n.action, n.state) for n in path[1:]]
        for i in path:
            #ii = [(n.action, n.state, n.reward/n.n_visits) for n in i[1:]]
            ii = [(n.action, n.state) for n in i[1:]]
            ii.append(i[-1].score())
            pathall.append(ii)
    return pathall


if __name__ == '__main__':
    #Tropantiol-TRODAT-1:CN1[C@H]2CC[C@@H]1[C@H]([C@H](C2)C3=CC=C(C=C3)Cl)CN(CCNCCS)CCS
    #target_mol1 = 'CN1C2CCC1C(C(C2)C3=CC=C(C=C3)Cl)CN(CCNCCS)CCS'
    #target_mol1 = 'CN1[C@H]2CC[C@@H]1[C@H]([C@H](C2)C3=CC=C(C=C3)Cl)CN(CCNCCS)CCS'
    m1= 'COC(=O)[C@H]1[C@@H](c2ccc(Cl)cc2)C[C@@H]2CC[C@H]1N2C'
    m2= 'COC(=O)[C@H]1[C@@H](O)C[C@@H]2CC[C@H]1N2C'
    m3= 'CCOC(=O)[C@H]1[C@@H](O)C[C@@H]2CC[C@H]1N2C'
    m4= 'COC(=O)[C@H]1[C@@H](OC(=O)c2ccccc2)C[C@@H]2CC[C@H]1N2C'
    m5= 'CCOC(=O)[C@H]1[C@@H](OC(=O)c2ccccc2)C[C@@H]2CC[C@H]1N2C'
    m6= 'CN1[C@H]2CC[C@@H]1[C@@H](C(=O)O)[C@@H](O)C2'
    m7= 'CN1[C@H]2CC[C@@H]1[C@@H](C(=O)O)[C@@H](OC(=O)c1ccccc1)C2'
    m8= 'CC(C)OC(=O)[C@H]1[C@@H](OC(=O)c2ccccc2)C[C@@H]2CC[C@H]1N2C'
    m9= 'CN1[C@H]2CC[C@@H]1CC(=O)C2'
    m10= 'CN1[C@H]2CC[C@@H]1[C@@H](COC(=O)c1ccccc1)[C@@H](O)C2'
    m11='COC(=O)[C@H]1CC[C@@H]2CC[C@H]1N2'
    m12='CN1[C@@H]2CC=C(C(=O)O)[C@H]1CC2'
    m13='CN1[C@H]2CC[C@@H]1[C@@H](C(=O)O)[C@@H](OC(=O)c1ccc(O)cc1)C2'
    m14='CCOC(=O)C1=CC[C@@H]2CC[C@H]1N2C'
    m15='COC(=O)[C@H]1[C@@H](c2ccccc2)C[C@@H]2CC[C@H]1N2C'
    m16='CN1[C@H]2CC[C@@H]1[C@@H](C(=O)Oc1ccccc1)[C@@H](c1ccc(Cl)cc1)C2'
    m17='CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@@](C#N)(c2ccc3c(N)ncnn23)[C@H](O)[C@@H]1O)Oc1ccccc1' # REMDESIVIR
    m18='N#C[C@@]1(c2ccc3c(N)ncnn23)O[C@H](CO)[C@@H](O)[C@H]1O' # initiative of REMDESIVIR 
    m19='N#C[C@@]1(c2ccc3c(N)ncnn23)O[C@H](COCc2ccccc2)[C@@H](OCc2ccccc2)[C@H]1OCc1ccccc1'# initiative of REMDESIVIR
    m20='COC(=O)[C@H]1[C@@H](c2ccc(Cl)c(Cl)c2)C[C@@H]2CC[C@H]1N2C'
    m21='COC(=O)[C@H]1[C@@H](OC(=O)c2ccccc2)C[C@@H]2CC[C@H]1N2'
    m22='O=C(O[C@H]1C[C@@H]2CC[C@@H](N2)[C@H]1C(=O)O)c1ccccc1'
    m23='COC(=O)[C@H]1[C@@H](OC(=O)c2ccccc2)C[C@@H]2CC[C@H]1N2C(=O)O'
    m24='COC(=O)C1=CC[C@@H]2CC[C@H]1N2C'
    m25='COC(=O)C1=CC[C@@H]2CC[C@H]1N2C(=O)OC(C)(C)C'
    m26='COC(=O)C1=CC[C@@H]2CC[C@H]1N2Cc1ccccc1'
    #del_mol1 = [m1,m15,m16,m20]
    #del_mol1 = [m1,m2,m3,m4,m5,m6,m7,m8,m9 ,m10,m11,m13,m14,m17,m18,m19,m20,m21,m22,m23,m24,m25,m26]
    #starting_mols.difference_update(del_mol1)
    del_mol2 = m24
    starting_mols.remove(del_mol2)
    
    # new initiative reactant for trodat1
    #target_mol1 = 'COC(=O)[C@H]1[C@@H](c2ccc(Cl)cc2)C[C@@H]2CC[C@H]1N2C'
    #target_mol1 = 'COC(=O)[C@H]1[C@@H](O)C[C@@H]2CC[C@H]1N2C'
    #target_mol1 = 'COC(=O)C1=CC[C@@H]2CC[C@H]1N2C'  #origin initiative for trodat 1 ,also 'Clc1ccc([Mg]Br)cc1'
    #target_mol1 = 'CN1[C@H]2CC[C@@H]1[C@@H](C(=O)Oc1ccccc1)[C@@H](c1ccc(Cl)cc1)C2'
    #target_mol1 = 'COC(=O)[C@H]1[C@@H](c2ccccc2)C[C@@H]2CC[C@H]1N2C'
    #del_mol1 = ['COC(=O)[C@H]1[C@@H](c2ccc(Cl)cc2)C[C@@H]2CC[C@H]1N2C']
    #starting_mols.difference_update(del_mol1)
    
    # new initiative reactant for trodat1
    #target_mol1 = 'COC(=O)[C@H]1[C@@H](O)C[C@@H]2CC[C@H]1N2C'
    #del_mol1 = ['COC(=O)[C@H]1[C@@H](O)C[C@@H]2CC[C@H]1N2C']
    #starting_mols.difference_update(del_mol1)
    
    #below is 同位素
    #target_mol1 ='FCCn4ccc(c3cnc(n2ccc1ccncc12)nc3)n4'
    #target_mol1 ='CNc1ccc(/C=C/c2ccc(OCCCCCCF)nc2)cc1'
    #target_mol1 ='CNc3ccc(C2Nc1ccc(O)cc1S2)cc3'
    
    #target_mol1 ='Nc3cccc4cnc(n2ccc1ccncc12)cc34'
    #del_mol1 = 'Nc3cccc4cnc(n2ccc1ccncc12)cc34'
    #del_mol1 = [Chem.MolToSmiles(Chem.MolFromSmiles(del_mol1),allHsExplicit=0,allBondsExplicit=0, isomericSmiles=1)]
    #starting_mols.difference_update(del_mol1)

    #target_mol1 ='OCCC4CCN(c3ccc(n2ccc1ccncc12)nc3)CC4'
    #target_mol1 ='NNc1ccc(/C=C/c2ccc(OCCCCCCF)nc2)cc1'
    #target_mol1 ='FNc3ccc(C2Nc1ccc(O)cc1S2)cc3'
    #target_mol1 ='Fc3cccc4cnc(n2ccc1ccncc12)cc34'
    #target_mol1 ='NCCC4CCN(c3ccc(n2ccc1ccncc12)nc3)CC4'
    #target_mol1 ='Cc5ccc(S(=O)(=O)OCOc4ccc3cc(n2ccc1ccccc12)ncc3c4)cc5'
    #target_mol1 ='Cc5ccc(S(=O)(=O)OCCCOc4ccc3cc(n2ccc1ccncc12)ncc3c4)cc5'
    
    #target_mol1 ='FCCCOc4ccc3cc(n2ccc1ccncc12)ncc3c4'
    
    
    #target_mol1 = 'CC(=O)c2ccc1cc(N(C)CCOCCF)ccc1c2' #M5
    #target_mol1 = 'CC(c2ccc1cc(N(C)CCOCCF)ccc1c2)C(C#N)C#N' #M6 FEONM代謝產物 
    
    #target_mol1 = 'CC(=C(C#N)C#N)c2ccc1cc(N(C)CCOCCF)ccc1c2' #7 FEONM

    #target_mol1 = 'CCOC(=O)[C@H](CS)NCCN[C@@H](CS)C(=O)OCC' #ECD
    #target_mol1 = 'CCOC(=O)C(CS)NCCNC(CS)C(=O)OCC' #ECD    
    #target_mol1 = 'CCOC(=O)[C@H](C[S-])NCC[N-][C@@H](C[S-])C(=O)OCC' #ECD
    #target_mol1 = 'CCOC(=O)C(C[S-])NCC[N-]C(C[S-])C(=O)OCC' #ECD
    #CCOC(=O)[C@H](C[S-])NCC[N-][C@@H](C[S-])C(=O)OCC.O=[99Tc+4] #ECD
    #CCOC(=O)C(C[S-])NCC[N-]C(C[S-])C(=O)OCC.O=[Tc+4] #ECD

    #target_mol1 ='O=C(O)CNC(=O)CNC(=O)CNC(=O)CSC(=O)c1ccccc1' #S-benzoyl-MAG3
    #starting_mols.remove(target_mol1)
    
    #target_mol1 ='Cc3cc(Nc2ccc1nc(F)ccc1n2)ccn3' #JNJ311
    #target_mol1 ='Fc3ccc2nc(Nc1ccncc1)ccc2n3' #derivative of JNJ311 without methyl 
    #target_mol1 ='CC1=C(C=CN=C1)NC2=NC3=C(C=C2)N=C(C=C3)F' #derivative of JNJ311 with methyl
    
    #target_mol1 ='CC(C)C[C@@H]1CN2CCC3=CC(=C(C=C3[C@H]2C[C@H]1O)OC)OCCCF' #F-18-AV-133
    #target_mol1 ='CC(C)CC1CN2CCC3=CC(=C(C=C3C2CC1O)OC)OCCCF' #F-18-AV-133

    
    #starting_mols.add('[F-]')
    #target_mol1 = 'ICCF'
    #target_mol1 = Chem.MolToSmiles(Chem.MolFromSmiles(target_mol1),allHsExplicit=0,allBondsExplicit=0, isomericSmiles=1)
    #starting_mols.remove(target_mol1)
    
    target_mol1 = 'CC(=C(C#N)C#N)c3ccc2cc(N(C)CCOCCOS(=O)(=O)c1ccc(C)cc1)ccc2c3' #6TEONM
    #del_1 = 'CC(=O)c1ccc2cc(N(C)CCO)ccc2c1'
    #del_2 = 'CC(=O)c1ccc2cc(O)ccc2c1'
    #del_mol3 = [del_1,del_2]
    #starting_mols.difference_update(del_mol3)
    #target_mol1 = 'CC(=C(C#N)C#N)c1ccc2cc(N(C)CCOCCO)ccc2c1' #initiative of 6TEONM
    #target_mol1 = 'CC(=C(C#N)C#N)c1ccc2cc(N(C)CCO)ccc2c1' #initiative of 6TEONM
    #target_mol1 = 'CC(=O)c1ccc2cc(O)ccc2c1' #initiative of 6TEONM
    
    #below is ABtest1-1
    #target_mol1 ='[H][C@]12CCC[C@](CCCCC3=CC=C(OCOC)C=C3)(OC3=C1C(=O)CCC3)O2'
    #below is a arbitrary test
    #target_mol1 = '[H][C@@]12OC3=C(O)C=CC4=C3[C@@]11CCN(C)[C@]([H])(C4)[C@]1([H])C=C[C@@H]2O'
    #below is alreay in base compounds
    #target_mol1 = 'CC(=O)NC1=CC=C(O)C=C1'
    #below is REMDESIVIR
    #target_mol1 = 'CCC(CC)COC(=O)[C@H](C)N[P@](=O)(OC[C@H]1O[C@@](C#N)(c2ccc3c(N)ncnn23)[C@H](O)[C@@H]1O)Oc1ccccc1'
    #target_mol1 = m19
    #target_mol1 = 'CCC(CC)COC(=O)[C@H](C)NP(=O)(OC[C@H]1O[C@@](C#N)(c2ccc3c(N)ncnn23)[C@H](O)[C@@H]1O)Oc1ccccc1'
    try:
        target_mol = Chem.MolToSmiles(Chem.MolFromSmiles(target_mol1),allHsExplicit=0,allBondsExplicit=0, isomericSmiles=1)
        if target_mol not in starting_mols: 
#            root = Node(state={target_mol})
            path = plan(target_mol, expansion_net, filter_net, rollout_net, iterations=iters, max_depth=max_d)
            path.sort(key = lambda x: x[-1], reverse = True)
            if path:
                pathset = defaultdict(list)
                for elem in path:
                    pathset[elem[-1]].append(elem) 
                    
        else: print('target is alreay in base compounds')
    except Exception as e:
        print('error: {}'.format(e))
        print('mol which is unable to standize: {}'.format(target_mol1))  
 
    #    import ipdb; ipdb.set_trace()
    

Using TensorFlow backend.


Instructions for updating:
Colocations handled automatically by placer.


Loading base compounds: 26664423it [00:19, 1380301.50it/s]


Base compounds: 26664423


Loading base compounds: 9216407it [00:07, 1183738.35it/s]
Loading base compounds: 63435it [00:00, 1135809.62it/s]


Base compounds: 32636890


expansion: 55608it [00:00, 785317.21it/s]
rollout: 19728it [00:00, 760798.72it/s]


Instructions for updating:
Use tf.cast instead.


100%|██████████| 5000/5000 [01:55<00:00, 65.90it/s]


Path found:


In [4]:

len(path)

70

In [2]:
len(pathset)


67

In [2]:
path

NameError: name 'pathnot' is not defined

In [3]:
from rdkit.Chem import Draw
save_dir = os.path.join(os.getcwd(), 'images/','target_mol/')
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)

for key,value in pathset.items():
#    if key != 0.8732798339579952: break
    ms=[]
    ms.append(target_mol)
    for a in range(len(value[0])-1):
        value[0][a][1]
        for b in value[0][a][1]:
            if b not in ms:
                ms.append(b)
    msmol=[Chem.MolFromSmiles(mol) for mol in ms]
    
    img=Draw.MolsToGridImage(msmol[:],molsPerRow=3,subImgSize=(800,200))
    img.save(save_dir+str(key)+'.png')  
with open(save_dir+'pathset.pickle', 'wb') as f:
    pickle.dump(pathset, f)

In [4]:
with open('data/path-new.pickle', 'wb') as f:
    pickle.dump(path, f)

In [3]:
pathset

defaultdict(list,
            {0.9613110277032852: [[('([N;H0;+0:2]-[c;H0;+0:1])>>([Br;H0;+0]-[c;H0;+0:1]).([NH;+0:2])',
                {'CC(=O)c1ccc2cc(Br)ccc2c1', 'CNCCO'}),
               0.9613110277032852]],
             0.9613085261321068: [[('([N;H0;+0:1]-[c;H0;+0:2])>>([NH;+0:1]).([OH;+0]-[c;H0;+0:2])',
                {'CC(=O)c1ccc2cc(O)ccc2c1', 'CNCCO'}),
               0.9613085261321068]],
             0.9613060122871399: [[('([N;H0;+0:2]-[c;H0;+0:1])>>([F;H0;+0]-[c;H0;+0:1]).([NH;+0:2])',
                {'CC(=O)c1ccc2cc(F)ccc2c1', 'CNCCO'}),
               0.9613060122871399]],
             0.9613018823933601: [[('([N;H0;+0:2]-[c;H0;+0:1])>>([CH3;+0]-[O;H0;+0]-[c;H0;+0:1]).([NH;+0:2])',
                {'CNCCO', 'COc1ccc2cc(C(C)=O)ccc2c1'}),
               0.9613018823933601]],
             0.9612241890764236: [[('([N;H0;+0:2]-[c;H0;+0:1])>>([Cl;H0;+0]-[c;H0;+0:1]).([N;H0;+0:2]-[CH;+0]=[O;H0;+0])',
                {'CC(=O)c1ccc2cc(Cl)ccc2c1', 'CN(C=O)CCO'}),
            

In [5]:
len(path)


173