from https://github.com/brilee/python_uct

In [None]:
# main function for the Monte Carlo Tree Search from https://www.geeksforgeeks.org/ml-monte-carlo-tree-search-mcts/
def monte_carlo_tree_search(root):
    """Run Monte Carlo Tree Search from ``root`` and return its best child.

    This is a pseudocode-level sketch: ``resources_left`` is a placeholder
    helper that is expected to report whether the search budget is spent.

    Args:
        root: An initialized dummy node; in our case, a SMILES string of a
            hypothetical product.

    Returns:
        Whatever ``best_child(root)`` returns.
    """
    # `time` and `computational_power` are placeholders for whatever budget
    # the caller wants to enforce; they are not defined in this sketch.
    while resources_left(time, computational_power):
        # If node.visits == 0, rollout, else node.expand().
        leaf = traverse(root)
        simulation_result = rollout(leaf)
        backpropagate(leaf, simulation_result)

    return best_child(root)

# function for node traversal 
def traverse(node):
    """Walk down from ``node`` and return the node to simulate from.

    Follows ``best_uct`` for as long as ``fully_expanded`` is true for the
    current node, then returns ``pick_univisted(current.children)`` when that
    is truthy and the current node itself otherwise.
    """
    current = node
    while True:
        if not fully_expanded(current):
            break
        current = best_uct(current)

    # Fall back to the node itself in case no children are present or the
    # node is terminal.
    unvisited = pick_univisted(current.children)
    if unvisited:
        return unvisited
    return current

# function for the result of the simulation 
def rollout(node):
    """Simulate from ``node`` until a terminal node and return its result.

    Repeatedly replaces the current node with ``rollout_policy(current)``
    while ``non_terminal(current)`` holds, then returns ``result(current)``.
    """
    current = node
    while True:
        if not non_terminal(current):
            return result(current)
        current = rollout_policy(current)

# function for randomly selecting a child node (the heart of MCTS)
def rollout_policy(node):
    """Return the next node of a rollout: ``pick_random`` over the children."""
    candidates = node.children
    return pick_random(candidates)

# function for backpropagation 
def backpropagate(node, result):
    """Propagate a simulation ``result`` from ``node`` up towards the root.

    Each visited node gets ``node.stats`` replaced by
    ``update_stats(node, result)``. The walk stops as soon as
    ``is_root(node)`` is true, so the root's own stats are not updated here.

    Args:
        node: The node the simulation started from.
        result: The simulation outcome to fold into every ancestor.
    """
    if is_root(node):
        return
    node.stats = update_stats(node, result)
    # The same result must travel all the way up, so pass it along.
    backpropagate(node.parent, result)

# function for selecting the best child 
# node with highest number of visits 
def best_child(node):
    """Return the child of ``node`` chosen by ``max_visits``.

    NOTE(review): ``max_visits`` is a placeholder helper; it is assumed to
    take the collection of children, like ``pick_random`` does — confirm.
    """
    return max_visits(node.children)


In [38]:
import collections
import random

import rdkit
from rdkit import Chem
from rdkit.Chem.EState import Fingerprinter
from rdkit.Chem import Descriptors
from rdkit.Chem import rdFMCS
from rdkit.Chem.rdmolops import RDKFingerprint
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit import DataStructs
from rdkit.Avalon.pyAvalonTools import GetAvalonFP

import numpy as np
import math

class UCTNode():
    """A search-tree node; it corresponds to one molecular state.

    A node does not store its own visit count or total value. They live in
    the parent's per-child containers (``child_number_visits`` and
    ``children_total_values``) under the key ``move``; the ``number_visits``
    and ``total_value`` properties transparently proxy to those entries.
    Consequently every node whose statistics are read or written, the root
    included, needs a parent object that exposes both containers.
    """

    # Default length of the per-child arrays (number of candidate moves).
    NUM_MOVES = 362

    def __init__(self, mol_state, move, parent=None, num_moves=None):
        """Create an unexpanded node.

        Args:
            mol_state: State-like object for this node; it must provide
                ``react(move)`` for children to be created from it.
            move: Index of the move that produced this node. It is also the
                key of this node's statistics in the parent's containers.
            parent: The parent node. With the default ``None`` the
                ``number_visits`` / ``total_value`` properties raise
                ``AttributeError``, so pass a stand-in parent for the root.
            num_moves: Length of the per-child arrays. Defaults to
                ``NUM_MOVES``. Priors given to ``expand`` must have this
                length for ``best_child`` to work.
        """
        self.mol_state = mol_state
        self.move = move
        self.is_expanded = False  # nodes start unexpanded
        self.parent = parent
        self.children = {}  # Dict[move, UCTNode], starts with no children
        self.num_moves = self.NUM_MOVES if num_moves is None else num_moves
        # Current preference for each follow-up move.
        self.child_priors = np.zeros([self.num_moves], dtype=np.float32)
        # Number of visits of each child.
        self.child_number_visits = np.zeros(
            [self.num_moves], dtype=np.float32)
        # Total accumulated value of each child.
        self.children_total_values = np.zeros(
            [self.num_moves], dtype=np.float32)

    # Each node no longer knows about its own statistics, so the property
    # getters and setters below alias them to the relevant entry in the
    # parent's child containers.

    @property
    def number_visits(self):
        """Visit count of this node, read from the parent's container."""
        return self.parent.child_number_visits[self.move]

    @number_visits.setter
    def number_visits(self, value):
        self.parent.child_number_visits[self.move] = value

    @property
    def total_value(self):
        """Total value of this node, read from the parent's container."""
        return self.parent.children_total_values[self.move]

    @total_value.setter
    def total_value(self, value):
        self.parent.children_total_values[self.move] = value

    def child_Q(self):
        """Return the per-child quality: total value / (1 + visits)."""
        return self.children_total_values / (1 + self.child_number_visits)

    def child_U(self):
        """Return the per-child upper-confidence term.

        Computed as ``sqrt(own visits) * prior / (1 + child visits)``.
        """
        return math.sqrt(self.number_visits) * (
            self.child_priors / (1 + self.child_number_visits))

    def best_child(self):
        """Return the index of the child with the highest Q + U score."""
        return np.argmax(self.child_Q() + self.child_U())

    def select_leaf(self):
        """Descend through expanded nodes and return the first unexpanded one.

        Returns ``self`` immediately when this node is not expanded. Child
        nodes are created on the way down when they do not exist yet.
        """
        current = self
        while current.is_expanded:
            best_move = current.best_child()
            current = current.maybe_add_child(best_move)
        return current

    def expand(self, child_priors):
        """Mark this node as expanded and store ``child_priors``."""
        self.is_expanded = True
        self.child_priors = child_priors

    def maybe_add_child(self, move):
        """Return the child for ``move``, creating it first if necessary.

        A new child's state is ``self.mol_state.react(move)`` and it inherits
        this node's ``num_moves``.
        """
        if move not in self.children:
            self.children[move] = UCTNode(
                self.mol_state.react(move), move, parent=self,
                num_moves=self.num_moves)
        return self.children[move]

    def backup(self, value_estimate: float):
        """Add one visit and ``value_estimate`` to this node and its ancestors.

        The walk continues while the current object has a non-``None``
        ``parent``, so a root with a stand-in parent is updated as well and
        the stand-in itself is not.
        """
        current = self
        while current.parent is not None:
            current.number_visits += 1
            current.total_value += value_estimate
            current = current.parent

def UCT_search(mol_state, num_reads):
    """Run a UCT search from ``mol_state`` and return the most-visited path.

    The search performs ``num_reads`` select / evaluate / expand / backup
    iterations. Afterwards it walks down from the root, always following the
    most-visited child, until it reaches a state whose SMILES is in the set
    of terminal states or until the tree has no node for the chosen move.

    Args:
        mol_state: Starting state; it must expose ``smiles``, ``mol`` and
            ``react(move)``.
        num_reads: Number of search iterations to run.

    Returns:
        The list of nodes along the chosen path, root excluded. The first
        entry is the first retrosynthetic node. The list is empty when the
        root is already terminal or was never expanded into.
    """
    # The root molecule has no real parent; the DummyNode only provides the
    # containers that hold the root's own statistics.
    root = UCTNode(mol_state, move=None, parent=DummyNode())

    for _ in range(num_reads):
        leaf = root.select_leaf()
        child_priors, value_estimate = Metrics.evaluate(leaf.mol_state.mol)
        leaf.expand(child_priors)
        leaf.backup(value_estimate)

    print('tree search complete')

    # Currently semi-random SMILES drawn from the module-level list `s`.
    terminal_states = {random.choice(s) for _ in range(10)}

    print('terminal states enumerated')

    pathway_choice = []
    # If the root is a terminal SMILES the loop stops immediately.
    while root.mol_state.smiles not in terminal_states:
        # Index of the most-visited child node.
        idx = np.argmax(root.child_number_visits)
        if idx not in root.children:
            # Failsafe: the search never created a node for this move (on an
            # unvisited node argmax over all zeroes is 0), so the path cannot
            # be extended any further.
            break
        print(idx)
        root = root.children[idx]  # the child becomes the new root
        pathway_choice.append(root)

    return pathway_choice
        

class DummyNode:
    """Parentless stand-in node holding only the per-child statistics.

    Both containers are ``defaultdict(float)``, so any key that has not been
    written yet reads as ``0.0``.
    """

    def __init__(self):
        self.child_number_visits = collections.defaultdict(float)
        self.children_total_values = collections.defaultdict(float)
        self.parent = None

class Metrics():
    """Placeholder evaluator producing random priors and a random value."""

    @classmethod
    def evaluate(cls, mol):
        """Return follow-up move priors and a value estimate for a node.

        Args:
            mol: The molecule of the node being evaluated. It is currently
                ignored; the outputs are random.

        Returns:
            A tuple ``(child_priors, value_estimate)``: an array of 362
            uniform random numbers in ``[0, 1)`` and a single such number.
        """
        child_priors = np.random.random([362])
        value_estimate = np.random.random()
        return child_priors, value_estimate

class State():
    """A molecular structure: a SMILES string, its RDKit mol and an enzyme.

    TODO: make sure move indexes map to reactions; decide whether the enzyme
    belongs on the state or on the node.
    """

    def __init__(self, smiles=None, enzyme=None):
        """Store the SMILES string and enzyme and parse the molecule.

        Args:
            smiles: A SMILES string, or ``None`` for an empty state.
            enzyme: An EC number, or ``None``.
        """
        self.smiles = smiles
        # Only parse when there is something to parse, so that the default
        # `State()` can be constructed.
        self.mol = (
            None if smiles is None
            else Chem.rdmolfiles.MolFromSmiles(smiles))
        self.enzyme = enzyme

    def react(self, move, structure_list=None, enzyme_list=None):
        """Return the state reached by applying ``move``.

        Args:
            move: Index into ``structure_list`` and ``enzyme_list``.
            structure_list: SMILES strings indexed by move. Defaults to the
                module-level list ``s``, looked up at call time.
            enzyme_list: Enzymes indexed by move. Defaults to the
                module-level list ``e``, looked up at call time.

        Returns:
            A new ``State`` built from the entries at index ``move``.
        """
        # Resolve the defaults here rather than in the signature so the class
        # can be defined before `s` and `e` exist.
        if structure_list is None:
            structure_list = s
        if enzyme_list is None:
            enzyme_list = e
        # Eventually this should build a new structure or point to a new
        # SMILES instead of looking one up.
        return State(structure_list[move], enzyme_list[move])

In [39]:
import string

def randomSmiles(m1):
    """Return a SMILES string for ``m1`` after tagging it with shuffled ranks.

    Sets the ``_canonicalRankingNumbers`` property on the molecule and a
    shuffled ``_canonicalRankingNumber`` on every atom, so ``m1`` is modified
    in place, then returns ``Chem.MolToSmiles(m1)``.

    NOTE(review): whether ``MolToSmiles`` honours these properties presumably
    depends on the RDKit version — confirm; ``doRandom=True`` may be the
    supported alternative.
    """
    m1.SetProp("_canonicalRankingNumbers", "True")
    atom_count = m1.GetNumAtoms()
    shuffled_ranks = list(range(atom_count))
    random.shuffle(shuffled_ranks)
    for atom_idx, rank in enumerate(shuffled_ranks):
        atom = m1.GetAtomWithIdx(atom_idx)
        atom.SetProp("_canonicalRankingNumber", str(rank))
    return Chem.MolToSmiles(m1)

m1 = Chem.MolFromSmiles("CNOPc1ccccc1")
m2 = Chem.MolFromSmiles("O=C(O)C(=O)CCC(=O)O")

# Placeholder move tables: 1000 SMILES strings generated from m1 and 1000
# random 8-character enzyme labels.
s = [randomSmiles(m1) for _ in range(1000)]
# The enzyme labels do not depend on a structure, so no SMILES is generated
# while building this list.
e = [
    ''.join(random.choices(string.ascii_uppercase + string.digits, k=8))
    for _ in range(1000)
]


Testing: if the search fails with an index error (a `KeyError`) of 0, it probably means the walk reached a depth that had not been properly expanded. All arrays are initialized with zeroes, so `argmax` over an unvisited node's counts returns index 0.

In [40]:
import time

# Search settings: the starting molecule and the number of iterations.
num_reads = 10000
smiles = 'CC1CCC2CC(C(=CC=CC=CC(CC(C(=O)C(C(C(=CC(C(=O)CC(OC(=O)C3CCCCN3C(=O)C(=O)C1(O2)O)C(C)CC4CCC(C(C4)OC)O)C)C)O)OC)C)C)C)OC'
state = State(smiles)

# Time the whole search, including the final walk down the tree.
tick = time.time()
node_list = UCT_search(state, num_reads)
tock = time.time()
print("Took %s sec to run %s times" % (tock - tick, num_reads))

# First pass over the pathway: one SMILES string per node.
print("parsing nodelist")
for member in node_list:
    print(member.mol_state.smiles)

print("nodelist enzymes")
print("oops, to view the enzymes for this pathway you need a premium metamoles subscription")
print("just kidding")

# Second pass over the pathway: one enzyme label per node.
print("parsing enzymes")
for member in node_list:
    print(member.mol_state.enzyme)

tree search complete
terminal states enumerated
57
Took 3.5820913314819336 sec to run 10000 times
parsing nodelist
c1ccc(PONC)cc1
nodelist enzymes
oops, to view the enzymes for this pathway you need a premium metamoles subscription
just kidding
parsing enzymes
4RWKUP9W
