## Tutorial (Dev) #2
- De novo generation using RNN
- Chain to lead optimization

In [11]:
# Imports

import sys
# Location of the repository root relative to this notebook.
# Change this if running the notebook from a different directory.
repo_root = "../../"
# Put the repository root at the front of the module search path, but only once,
# so that re-running this cell does not add duplicate entries.
if sys.path.count(repo_root) == 0:
    sys.path.insert(0, repo_root)
    
import os
from filter import ValidityFilter, RadicalFilter
from generator import MCTS
from language import Language
from node import MolSentenceNode
from policy import UCT
from transition import RNNTransition
from reward import JScoreReward

# De novo generation using RNN

In [None]:
# Load the RNN model and the language file

# Directory containing the RNN model, resolved relative to the repository root.
model_dir = os.path.join(repo_root, "model/smiles/drugs_zinc/gru")
# For RNNTransition, using the CPU tends to be faster even in GPU environments.
device = "cpu"

# In the YAML workflow, a language file with the same name as the model directory
# will be loaded automatically; here the language file is loaded explicitly.
smiles_lang = Language.load(os.path.join(model_dir, "smiles_zinc.lang"))

rnn_transition = RNNTransition(
    lang=smiles_lang,
    model_dir=model_dir,
    device=device,
    top_p=0.995,  # Use tokens covering the top 99.5% cumulative probability, ignoring the rest.
)

In [None]:
# Set up a de novo generator

# If set to True, the generated SMILES tensor will be converted to canonical SMILES
# to avoid duplicate molecules.
MolSentenceNode.use_canonical_smiles_as_key = True

# Start from an empty SMILES
key = ""
root = MolSentenceNode.node_from_key(key=key, lang=smiles_lang, device=device)

# Search components: reward function, molecule filters and selection policy.
reward = JScoreReward()
filters = [ValidityFilter(), RadicalFilter()]
policy = UCT(c=0.1, best_rate=0.5)

generator_de_novo = MCTS(
    root=root,
    transition=rnn_transition,
    reward=reward,
    filters=filters,
    filter_reward=0,
    policy=policy,
    # The tree structure of the transition graph is guaranteed, so duplication checks
    # within the search tree are unnecessary.
    avoid_duplicates=False,
    # Filter results are probabilistic, as the evaluation step involves rollout: removing
    # a child node whose rollout leads to a filtered node might eliminate a potentially
    # good path.
    cut_failed_child=False,
)

In [None]:
# Start generation

# Run the de novo search, bounded by both a generation count and a time limit.
generator_de_novo.generate(
    max_generations=1000,
    time_limit=60,
)
# Plot the outcome of the run.
generator_de_novo.plot(
    moving_average_window=0.05,
    reward_top_ps=[0.1],
)

# Chain to lead optimization

In [None]:
# Generated molecules with top_k rewards

# Number of best-scoring molecules to hand over to the lead optimization stage.
n_keys_to_pass = 5
top_k = generator_de_novo.top_k(k=n_keys_to_pass)
# Each entry is unpacked as a (key, reward) pair. Dedicated loop names are used so that
# neither the built-in `tuple` nor the notebook-level `key` / `reward` variables are
# overwritten in the kernel namespace.
for top_key, top_reward in top_k:
    print("Key: ", top_key, " Reward: ", top_reward)

In [None]:
# Make a virtual surrogate node as the common parent of top_k molecules

from node import CanonicalSMILESStringNode, SurrogateNode

# Keep only the keys; the rewards from the de novo stage are not needed here.
top_keys = [entry_key for entry_key, _ in top_k]
surrogate_root = SurrogateNode()
for top_smiles in top_keys:
    # Every selected molecule becomes a child of the surrogate root, each with the same
    # last_prob of 1 / (number of selected molecules).
    child_node = CanonicalSMILESStringNode.node_from_key(
        key=top_smiles,
        parent=surrogate_root,
        last_prob=1/len(top_keys),
    )
    surrogate_root.add_child(child_node)

In [None]:
# Set up a lead generator

from transition import JensenTransition
from policy import PUCT

generator_lead = MCTS(
    root=surrogate_root,
    transition=JensenTransition(),
    # Reward, filters and filter_reward are inherited from the de novo generator.
    reward=generator_de_novo.reward,
    filters=generator_de_novo.filters,
    filter_reward=generator_de_novo.filter_reward,
    policy=PUCT(c=0.2, best_rate=0.9),
    # The transition graph has cycles and convergences, so (unlike the de novo search)
    # duplication checks within the search are required.
    avoid_duplicates=True,
    # Filter results are not probabilistic here, so a child that fails the filters
    # can be removed.
    cut_failed_child=True,
)

generator_lead.inherit(generator_de_novo) # Inherits the generation results

In [None]:
# Lead optimization

# Run the lead optimization search, bounded by both a generation count and a time limit.
generator_lead.generate(
    max_generations=1000,
    time_limit=60,
)
# Plot the outcome of the run.
generator_lead.plot(
    moving_average_window=0.05,
    reward_top_ps=[0.1],
)