In [1]:
# For local runs: put the repository root on the import path so the project
# modules imported in the next cell (generator, language, node, utils) resolve.
import sys

repo_root = "../"
if repo_root not in sys.path:
    sys.path[:0] = [repo_root]

In [2]:
from datetime import datetime
import importlib
import numpy as np
import logging
import os
import shutil
from typing import Any
import yaml
import optuna
from rdkit import RDLogger
from generator import Generator
from language import Language
from node import MolSentenceNode
from utils import add_sep, class_from_package, make_logger
RDLogger.DisableLog('rdApp.*')  # silence RDKit's 'rdApp.*' log channels
optuna.logging.disable_default_handler()  # remove Optuna's default log handler; trial summaries are printed by our own callback

# Experiment config file, relative to the repository root.
yaml_path = "config/optuna_generation.yaml"

# Load the whole experiment configuration; every setting below is read from `conf`.
with open(os.path.join(repo_root, yaml_path)) as f:
    conf = yaml.safe_load(f)
# Transition settings: "model_dir" is popped out (and resolved against repo_root)
# so the remaining entries can be forwarded as keyword arguments to the
# transition class inside `objective`.
transition_args = conf.get("transition_args", {})
model_dir = os.path.join(repo_root, transition_args.pop("model_dir"))
# Language file: "lang_path" from the config is used as-is (NOT joined with
# repo_root); otherwise it defaults to a "<last component of model_dir>.lang"
# file inside model_dir (add_sep presumably appends a path separator — confirm).
lang_path = conf.get("lang_path")
if lang_path is None:
    lang_name = os.path.basename(os.path.normpath(model_dir)) + ".lang"
    lang_path = add_sep(model_dir) + lang_name
lang = Language.load(lang_path)
# Base policy kwargs; `objective` adds the tuned "c" and "best_rate" on top of these.
policy_args = conf.get("policy_args", {})
# Reward: class looked up by name in the `reward` package, built from "reward_args".
reward_class = class_from_package("reward", conf.get("reward_class"))
reward = reward_class(**conf.get("reward_args", {}))
# Filters: each entry names its class via "filter_class"; the remaining keys are
# constructor kwargs. Note that `s.pop` mutates the entries stored in `conf`.
filter_settings = conf.get("filters", [])
filters = []
for s in filter_settings:
    filter_class = class_from_package("filter", s.pop("filter_class"))
    filters.append(filter_class(**s))
# Base generator kwargs; `objective` adds the tuned values and the policy on top of these.
generator_args = conf.get("generator_args", {})

def objective(trial):
    """Optuna objective: run one generation job with sampled hyperparameters and score it.

    Samples transition (top_p, temperature), policy (class, c, best_rate) and
    generator hyperparameters from ``trial``. It then builds a fresh transition,
    policy and generator from the module-level configuration and runs
    ``generator.generate``.

    Args:
        trial: the Optuna trial supplying the hyperparameters.

    Returns:
        ``(1 - best_reward_rate) * mean_reward + best_reward_rate * best_reward``,
        where ``best_reward_rate`` comes from ``conf["best_reward_rate"]``.
        ``mean_reward`` and ``best_reward`` are also stored as trial user attributes.
    """
    # Work on per-trial copies so sampled values (and the policy object) do not
    # leak into the module-level dicts shared by every trial.
    trial_transition_args = dict(transition_args)
    trial_policy_args = dict(policy_args)
    trial_generator_args = dict(generator_args)

    # suggest_float(..., log=True) is the non-deprecated form of
    # suggest_loguniform / suggest_uniform (same distributions, same param names).
    trial_transition_args["top_p"] = 1 - trial.suggest_float("1-top_p", 0.002, 0.02, log=True)
    trial_transition_args["temperature"] = trial.suggest_float("temperature", 0.8, 1.2)
    policy_name = trial.suggest_categorical("policy_class", ["UCB", "PUCT"])
    trial_policy_args["c"] = trial.suggest_float("c", 0.01, 2.0, log=True)
    trial_policy_args["best_rate"] = trial.suggest_float("best_rate", 0.0, 1.0)
    trial_generator_args["filtered_reward"] = trial.suggest_float("filtered_reward", -2.0, 0.2)
    trial_generator_args["rollout_width"] = trial.suggest_int("rollout_width", 1, 10)
    trial_generator_args["allow_rollout_overlaps"] = trial.suggest_categorical("allow_rollout_overlaps", [True, False])
    trial_generator_args["n_rollouts"] = trial.suggest_int("n_rollouts", 1, 10)
    trial_generator_args["n_tries"] = trial.suggest_int("n_tries", 1, 10)

    # A minute-resolution timestamp alone collides when two trials start within
    # the same minute, so suffix the trial number to give each trial its own directory.
    run_name = datetime.now().strftime("%m-%d_%H-%M") + f"_trial{trial.number}"
    output_dir = os.path.join(repo_root, "sandbox", conf["output_dir"], run_name) + os.sep
    console_level = logging.ERROR
    file_level = logging.DEBUG if conf.get("debug") else logging.INFO
    logger = make_logger(output_dir, console_level=console_level, file_level=file_level)
    logger.info("params:" + str(trial.params))

    transition_class = class_from_package("transition", conf["transition_class"])
    transition = transition_class(model_dir=model_dir, lang=lang, logger=logger, device=conf.get("device"), **trial_transition_args)

    policy_class = class_from_package("policy", policy_name)
    policy = policy_class(**trial_policy_args)
    trial_generator_args["policy"] = policy

    root = MolSentenceNode.bos_node(lang, device=conf.get("device")) # TODO: change after root node generalization

    generator_class = class_from_package("generator", conf.get("generator_class", "MCTS"))
    generator = generator_class(root=root, transition=transition, reward=reward, filters=filters, output_dir=output_dir, logger=logger, **trial_generator_args)

    generator.generate(time_limit=conf.get("time_limit"), max_generations=conf.get("max_generations"))
    best_reward_rate = conf.get("best_reward_rate")
    mean_reward = generator.mean_reward(window=conf.get("mean_reward_window"))
    trial.set_user_attr("mean_reward", mean_reward)
    trial.set_user_attr("best_reward", generator.best_reward)
    return (1 - best_reward_rate) * mean_reward + best_reward_rate * generator.best_reward

def log_callback(study: optuna.Study, trial: optuna.trial.FrozenTrial):
    """Optuna ``study.optimize`` callback: print a one-line summary of the finished trial.

    Args:
        study: the study being optimized (unused; required by the callback signature).
        trial: the finished trial. Optuna passes callbacks a ``FrozenTrial``.
    """
    print_trial(trial)
    
def print_trial(trial: "optuna.trial.FrozenTrial"):
    """Print a one-line summary of a finished trial.

    Shows the trial number, the objective score, the ``mean_reward`` and
    ``best_reward`` user attributes, and the sampled parameters. Missing values
    are printed as ``None`` instead of raising. This covers a trial that did not
    complete normally, which has ``value`` None and may lack the user attributes.

    Args:
        trial: a finished trial (as stored in ``study.trials`` or passed to callbacks).
    """
    def _fmt(value):
        # Three decimals for numbers; "None" when the value is absent.
        return "None" if value is None else f"{value:.3f}"

    attrs = trial.user_attrs
    print(
        f"Trial {trial.number} score={_fmt(trial.value)}, "
        f"mean_reward={_fmt(attrs.get('mean_reward'))}, "
        f"best_reward={_fmt(attrs.get('best_reward'))}, "
        f"params={trial.params}"
    )
    
def print_best_trials(study: optuna.Study, n: int = 5):
    """Print the ``n`` best completed trials of ``study``, best first.

    Only trials in state COMPLETE are considered, because other states carry no
    objective value. The ordering follows the study's optimization direction, so
    it is also correct for minimization studies (assumes a single-objective study).

    Args:
        study: the study whose trials are summarized.
        n: how many trials to print (default 5).
    """
    print("Optuna trials completed.")
    print("------ Best trials -----")
    completed = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
    maximize = study.direction == optuna.study.StudyDirection.MAXIMIZE
    best_trials = sorted(completed, key=lambda t: t.value, reverse=maximize)[:n]
    for t in best_trials:
        print_trial(t)

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [3]:
# Create the study persisted in generation_result/optuna_<study_name>.db, or resume
# it if it already exists, then run the configured number of trials.
# load_if_exists=True lets this cell be re-run without a duplicated-study error.
# An optional "seed" config key makes the TPE sampler reproducible; None keeps the default.
name = conf.get("study_name")
os.makedirs("generation_result", exist_ok=True)  # SQLite will not create the parent directory
study = optuna.create_study(
    direction="maximize",
    study_name=name,
    storage="sqlite:///generation_result/optuna_" + name + ".db",
    sampler=optuna.samplers.TPESampler(seed=conf.get("seed")),
    load_if_exists=True,
)
study.optimize(objective, n_trials=conf.get("n_trials"), callbacks=[log_callback])
print_best_trials(study)

Trial 0 score=0.259, mean_reward=0.177, best_reward=0.450, params={'1-top_p': 0.01311618576535542, 'temperature': 0.909535226412763, 'policy_class': 'PUCT', 'c': 0.010156568262119215, 'best_rate': 0.9106261790495621, 'filtered_reward': -0.5800526439568316, 'rollout_width': 8, 'allow_rollout_overlaps': True, 'n_rollouts': 4, 'n_tries': 10}
Trial 1 score=0.297, mean_reward=0.227, best_reward=0.459, params={'1-top_p': 0.0052619152221884255, 'temperature': 0.9602059116067898, 'policy_class': 'UCB', 'c': 0.03025987391770358, 'best_rate': 0.89326870230366, 'filtered_reward': -0.011185826239096919, 'rollout_width': 7, 'allow_rollout_overlaps': True, 'n_rollouts': 1, 'n_tries': 6}
Trial 2 score=0.220, mean_reward=0.136, best_reward=0.417, params={'1-top_p': 0.01512668816478695, 'temperature': 1.0549615879883538, 'policy_class': 'PUCT', 'c': 0.5450707599166557, 'best_rate': 0.296298006934008, 'filtered_reward': -1.6193171859292772, 'rollout_width': 1, 'allow_rollout_overlaps': True, 'n_rollouts

In [None]:
# Continue an existing study: reload it from its SQLite storage and run a few more trials.
# The loaded study must be assigned to `study`; otherwise optimize() silently runs on
# whatever `study` object happens to be left in the kernel.
name = conf.get("study_name")
study = optuna.load_study(study_name=name, storage="sqlite:///generation_result/optuna_" + name + ".db")
study.optimize(objective, n_trials=3, callbacks=[log_callback])
print_best_trials(study)