## Tutorial #2
- Basic workflow (without YAML)
- Check transition
- Save and load
- Logger and seed

# Basic workflow (without YAML)

This section gives an overview of how each class works and how to set it up.

A more practical workflow (de novo molecular generation chained to lead optimization) is covered in `tutorial_user_2.ipynb`.

In [None]:
# Imports (may take some time on the first run)

import sys
repo_root = "../../" # Change this if running the notebook from a different directory
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

from filter import ValidityFilter, RadicalFilter
from generator import MCTS
from node import SMILESStringNode
from policy import PUCT
from transition import JensenTransition
from reward import LogPReward

In [None]:
# Set up a generator (without YAML)

root_node = SMILESStringNode.node_from_key("c1ccccc1") # Benzene
reward = LogPReward()
# Note: ValidityFilter checks whether the molecule is valid. Other filters and rewards typically
# assume validity and do not recheck it, so this filter should usually come first in molecular generation.
filters = [ValidityFilter(), RadicalFilter()]

# Hover over a class name (e.g., "PUCT") to see its available arguments, types, default values, and descriptions.
# Note: This may not be supported in some IDEs, and is not supported for classes loaded via lazy imports.
policy = PUCT(c=0.1, best_rate=0.9)

generator = MCTS(root=root_node, transition=JensenTransition(), reward=reward, filters=filters, filter_reward=[-1,0], policy=policy,
                 avoid_duplicates=True, cut_failed_child=True,
                 info_interval=100, output_dir="generation_result/tutorial_1") # sandbox/tutorials/generation_result/tutorial_1
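
The note on filter order can be illustrated with a small, library-independent sketch. It assumes filters are applied in list order and stop at the first failure, and that `filter_reward` supplies one fallback reward per filter; both are assumptions about the generator's behavior, not confirmed API details. The names `first_failed_filter`, `is_nonempty`, `starts_with_letter`, and `reward_or_penalty` are hypothetical.

```python
from typing import Callable, Optional, Sequence


def first_failed_filter(item: str, filters: Sequence[Callable[[str], bool]]) -> Optional[int]:
    """Return the index of the first filter that rejects `item`, or None if all pass."""
    for index, passes in enumerate(filters):
        if not passes(item):
            return index  # Short-circuit: later filters never see a rejected item.
    return None


# Hypothetical stand-ins for a validity check and a stricter downstream check.
def is_nonempty(text: str) -> bool:
    return len(text) > 0


def starts_with_letter(text: str) -> bool:
    # Would raise IndexError on "", so it relies on is_nonempty running first.
    return text[0].isalpha()


checks = [is_nonempty, starts_with_letter]
penalties = [-1, 0]  # Assumed: one fallback reward per filter, in the same order.


def reward_or_penalty(item: str, reward_fn: Callable[[str], float]) -> float:
    """Score `item` with reward_fn if it passes every check, else use the penalty."""
    failed = first_failed_filter(item, checks)
    return reward_fn(item) if failed is None else penalties[failed]
```

Swapping the two entries in `checks` would make `starts_with_letter` fail with an exception on an empty string, which is the same reason a validity check belongs at the front of the list.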

In [None]:
# Start generation

# Stops when either 1000 nodes have been generated or 60 seconds have passed, whichever comes first.
# Each generated molecule is logged to a CSV file in the output directory.
generator.generate(max_generations=1000, time_limit=60)
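
The "whichever comes first" stopping rule can be sketched as a plain loop. This is only an illustration of the idea; `run_until` and `step` are hypothetical names, and the real `generate()` may check its limits differently.

```python
import time
from typing import Callable, Optional


def run_until(step: Callable[[], None],
              max_steps: Optional[int] = None,
              time_limit: Optional[float] = None) -> int:
    """Call step() until max_steps calls were made or time_limit seconds elapsed.

    Returns the number of completed steps. If both limits are None, the loop
    never ends, so at least one limit should be given.
    """
    start = time.monotonic()  # Monotonic clock: unaffected by system clock changes.
    done = 0
    while True:
        if max_steps is not None and done >= max_steps:
            break
        if time_limit is not None and time.monotonic() - start >= time_limit:
            break
        step()
        done += 1
    return done
```

For example, `run_until(lambda: None, max_steps=1000, time_limit=60)` finishes as soon as 1000 steps have run, long before the time limit applies.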

In [None]:
# Analyze and plot results

generator.analyze()
# Plot the objective values and final reward of the generated molecules.
# The plots are also saved to the output directory.
generator.plot(moving_average_window=0.05, reward_top_ps=[0.1, 0.5])

# Check transition

In [None]:
from utils import draw_mol

for child in generator.root.children: # child nodes of the root node
    print(f"Probability: {child.last_prob:.3f} Action: {child.last_action}")
    draw_mol(child.mol(), width=70, height=70) # All MolNode subclasses have a mol() method.
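
To show how per-child probabilities like the ones printed above can steer a search, here is a minimal sketch of the textbook PUCT score (the AlphaGo Zero-style form). It is not necessarily what the `PUCT` class computes, and the `best_rate` argument is not modeled; `puct_score` and the child statistics below are made up for illustration.

```python
import math


def puct_score(mean_value: float, prior: float, parent_visits: int,
               child_visits: int, c: float = 0.1) -> float:
    """Textbook PUCT: exploitation term plus a prior-weighted exploration bonus."""
    return mean_value + c * prior * math.sqrt(parent_visits) / (1 + child_visits)


# Made-up statistics for three children of one node.
children = [
    {"action": "A", "prior": 0.6, "visits": 10, "mean_value": 0.2},
    {"action": "B", "prior": 0.3, "visits": 1, "mean_value": 0.1},
    {"action": "C", "prior": 0.1, "visits": 0, "mean_value": 0.0},
]
parent_visits = sum(child["visits"] for child in children)


def select(c: float) -> dict:
    """Pick the child with the highest PUCT score for exploration constant c."""
    return max(
        children,
        key=lambda ch: puct_score(ch["mean_value"], ch["prior"],
                                  parent_visits, ch["visits"], c),
    )


greedy_choice = select(c=0.1)       # Small c: the mean value dominates.
exploring_choice = select(c=2.0)    # Large c: rarely visited children gain weight.
```

With these numbers, a small `c` favors the child with the best mean value, while a large `c` shifts the choice toward a less-visited child with a reasonable prior.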

# Save and load

In [None]:
# Continue generation using the existing generator

generator.generate(max_generations=200, time_limit=60)

In [None]:
# Save the generator and its current progress to a file

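import os

# The generator's output directory can be fetched using output_dir().
# os.path.join works whether or not the directory ends with a path separator.
save_path = os.path.join(generator.output_dir(), "save.gtr")
generator.save(save_path)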

In [None]:
# Load generator

# Some transitions rely on heavy models, so they are kept out of the generator's saved state
# and must be passed in again when loading.
generator = MCTS.load_file(save_path, transition=JensenTransition())
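
The pattern of saving progress without the heavy component and reattaching it on load can be sketched with the standard library. This is an assumption-laden illustration, not the format of `.gtr` files; `Searcher`, its `model` and `history` attributes, and the file name `state.pkl` are all hypothetical.

```python
import os
import pickle
import tempfile


class Searcher:
    """Toy stand-in for a generator that owns a heavy, non-serialized component."""

    def __init__(self, model, history=None):
        self.model = model            # Heavy component, e.g. a neural network.
        self.history = history or []  # Lightweight progress worth saving.

    def save(self, path):
        # Only the lightweight state is written; the model is left out on purpose.
        with open(path, "wb") as handle:
            pickle.dump({"history": self.history}, handle)

    @classmethod
    def load(cls, path, model):
        with open(path, "rb") as handle:
            state = pickle.load(handle)
        # The caller supplies a freshly constructed model to reattach.
        return cls(model=model, history=state["history"])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state.pkl")
    original = Searcher(model=str.upper, history=["a", "b"])
    original.save(path)
    restored = Searcher.load(path, model=str.upper)
```

Keeping the model out of the file keeps saves small and avoids tying the saved state to one particular model object.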

In [None]:
# Continue generation using the loaded generator

generator.generate(max_generations=200, time_limit=60)

# Logger and seed
(Optional) You can specify a logger and a random seed for generation.
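
As a reminder of what a fixed seed provides, the sketch below uses only Python's `random` module: the same seed yields the same sequence of draws, and `None` lets Python choose its own seed. It assumes nothing about which random number generators `set_seed` covers; `sample_run` is a hypothetical helper.

```python
import random


def sample_run(seed, n=5):
    """Draw n pseudo-random integers from a generator seeded with `seed`."""
    rng = random.Random(seed)  # seed=None lets Python pick a seed from system entropy.
    return [rng.randint(0, 99) for _ in range(n)]


first = sample_run(0)
second = sample_run(0)   # Same seed: identical sequence.
other = sample_run(1)    # Different seed: a different sequence.
```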

In [None]:
import logging
from utils import make_logger, set_seed

seed = 0 # Set to None if you want the seed to be set automatically.

for i in range(3):
    output_dir = f"generation_result/tutorial_1_seed_{seed}/{i}"

    # Make a logger
    logger = make_logger(output_dir=output_dir, console_level=logging.INFO, file_level=logging.INFO)
    logger.info("------------------------------------------------------------------")

    # Set the seed
    set_seed(seed, logger=logger)

    # Make a new root node. The root node defined above already has child nodes, so reusing it would continue generation from the existing search tree instead of starting fresh.
    new_root = SMILESStringNode.node_from_key("c1ccccc1")

    # Make a generator and start generation
    generator = MCTS(root=new_root, transition=JensenTransition(), reward=reward, filters=filters, filter_reward=0, policy=policy,
                     avoid_duplicates=True, cut_failed_child=True,
                     info_interval=1, # With info_interval=1, the key of every generated node is logged to the console and file
                     output_dir=output_dir, logger=logger)
    generator.generate(max_generations=50)
