In [None]:
# Show which matplotlib version the kernel is using: as the last expression
# in the cell, the version string is echoed as the cell's output.
import matplotlib
matplotlib.__version__

In [None]:
# NOTE(review): the two star imports pull every public name from the project
# modules `tree` and `tree_viz` into the notebook namespace. Names used in
# later cells without an explicit import in the visible cells (e.g. Node,
# State, utcbeam, utcsearch, logger, Chem, Descriptors, sascorer,
# make_tree_nodes, make_node_info, mass_plotting) presumably come from
# them -- confirm against those modules.
from tree import *
from tree_viz import *
# import json
from rdkit import RDLogger
from IPython.display import display
import random
# Switch off RDKit log output for every channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*') 

In [None]:
# Seed Python's built-in random number generator so that every call made
# through the `random` module is repeatable from one run to the next.
seed = 42
random.seed(seed)

# Capture the freshly seeded generator state and immediately restore it.
# `state` stays available, so it can be handed to random.setstate() again
# later to rewind the generator to this exact point.
state = random.getstate()
random.setstate(state)

# Input parameters

In [None]:
# ---- Search mode ----------------------------------------------------------
# True  -> run the parallel-trees MCTS
# False -> run the single-tree MCTS
parallel = True

# ---- Search budget --------------------------------------------------------
levels = 5        # number of "turns"
num_sims = 10000  # number of simulations (default 10000); lower it to test the code
beamsize = 10     # beam size
# `alpha` is forwarded to the search routines. The original note calls it the
# "boltzmann constant"; NOTE(review): confirm its exact role in tree.py.
alpha = 2

# ---- Output ---------------------------------------------------------------
# File name under which the MCTS output molecule is saved.
filename = 'mcts_molecules.json'

# ---- Optimisation targets -------------------------------------------------
# (A molecule-size parameter, `size = 10`, is currently disabled.)
goal = 2.3406     # target of the chemical characteristic -- here log P
sa_target = 0     # target of the synthesizability score
max_value1 = 20   # max value of the target chemical characteristic
max_value2 = 10   # max value of the synthesizability score

# ---- Generation alphabet --------------------------------------------------
# SMILES symbols that are the choices for molecular generation.
choices = ['C', 'O', '=', 'N', 'c', '1', 'S', 'P', 'F', '\n']

# ---- Derived values -------------------------------------------------------
# Counter used to track the number of turns.
turn = levels + 2

# Run MCTS

In [None]:
# Both search modes start from the same root: a node wrapping the initial
# State built from the input parameters. It is built once here instead of
# being duplicated in each branch.
root_node = Node(State(
    goal=goal,
    sa_target=sa_target,
    allchoices=choices,
    max_value1=max_value1,
    max_value2=max_value2,
    turn=turn))

if parallel:
    # Parallel-trees search: at every level utcbeam is given the current
    # population of nodes, and the nodes it returns become the population
    # for the next level.
    current_nodes = [root_node]
    for level in range(levels):
        logger.info(f"This is the turn number: {level}")
        next_nodes = utcbeam(budget=num_sims, rootpop=current_nodes, beamsize=beamsize, alpha=alpha)

        # Report every node of the level just searched, with its children.
        for node in current_nodes:
            print(f"Level {level}")
            print(f"This is one of the current nodes: {node}")
            print(f"Num Children: {len(node.children)}")
            for child_idx, child in enumerate(node.children):
                print(child_idx, child)

        print("These are the best children:")
        for node in next_nodes:
            print(node)

        current_nodes = next_nodes
        print("--------------------------------")

else:
    # Single-tree search: follow one node per level.
    current_node = root_node
    for level in range(levels):
        logger.info(f"This is the turn number: {level}")
        # The budget shrinks as the turn number grows: num_sims // (level + 1).
        # Floor division keeps the budget a whole number of simulations; true
        # division would pass a fractional float such as 3333.33.
        next_node = utcsearch(num_sims // (level + 1), current_node, alpha=alpha)
        # next_node is the best known child, or a newly expanded child node,
        # with the highest reward value.
        print(f"level {level}")
        print(f"Num Children: {len(current_node.children)}")
        for child_idx, child in enumerate(current_node.children):
            print(child_idx, child)
        print(f"Best Child: {next_node.state}")
        current_node = next_node
        print("--------------------------------")

logger.info("MCTS finished.")

In [None]:
# root_node = current_node
# while root_node.parent is not None:
#     root_node = root_node.parent
# out1 = tree_search(root_node,thresh=4)
# out1

# Visualize final output molecule(s)

In [None]:
if parallel:
    # `allsmiles` must stay index-aligned with `current_nodes`, because the
    # tree-plotting cell looks up allsmiles[i] by node position. Every SMILES
    # string is therefore recorded, even one RDKit cannot parse.
    allsmiles = []
    # Only successfully parsed molecules go into the grid image.
    allmols = []

    for node in current_nodes:
        smiles = node.state.smiles
        allsmiles.append(smiles)

        # MolFromSmiles returns None (it does not raise) when the SMILES is
        # invalid. RDKit logging is disabled in this notebook, so without this
        # check the failure would only surface as an AttributeError below.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f"Skipping molecule: RDKit could not parse SMILES {smiles!r}")
            continue

        # Store the legend fields as molecule properties.
        mol.SetProp("name", Chem.rdMolDescriptors.CalcMolFormula(mol))
        mol.SetProp("logP", f"{Descriptors.MolLogP(mol):.3f}")
        mol.SetProp("SA score", f"{sascorer.calculateScore(mol):.3f}")
        allmols.append(mol)

    if allmols:
        img = Chem.Draw.MolsToGridImage(
            allmols,
            legends=[
                f"{m.GetProp('name')}\nlogP: {m.GetProp('logP')}\nSA: {m.GetProp('SA score')}"
                for m in allmols
            ],
            subImgSize=(350, 350))

        display(img)
    else:
        print("No valid molecules to display.")

else:
    # Single-tree run: the final SMILES is the concatenation of the moves
    # stored on the final node's state.
    smiles = ''.join(current_node.state.moves)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Same None-on-failure contract as above; report instead of crashing.
        print(f"Cannot display molecule: RDKit could not parse SMILES {smiles!r}")
    else:
        logp = Descriptors.MolLogP(mol)
        sa_score = sascorer.calculateScore(mol)
        print(f"SA score: {sa_score:.3f}")
        print(f"LogP: {logp:.3f}")
        display(mol)

# Visualize the search tree of the best children

In [None]:
if not parallel:
    # Single-tree run: build the tree data for the final node (sized by
    # `turn`) and plot it next to that node's SMILES string.
    tuples = make_tree_nodes(node=current_node, size=turn)
    node_info = make_node_info(node=current_node, size=turn)
    mass_plotting(node_info=node_info, tuples=tuples, smiles=smiles)
else:
    # Parallel run: one set of plots per final node (sized by `levels`).
    # allsmiles[idx] is the SMILES recorded for the node at the same position.
    for idx, beam_node in enumerate(current_nodes):
        tuples = make_tree_nodes(node=beam_node, size=levels)
        node_info = make_node_info(node=beam_node, size=levels)
        # display(node_info)
        mass_plotting(
            node_info=node_info,
            params=['visits', 'reward', 'logp', 'sa_score'],
            tuples=tuples,
            smiles=allsmiles[idx],
        )

# Visualize the search tree of the OK children