In [1]:
#for local
import sys
repo_root = "../"
if repo_root not in sys.path:
    sys.path.insert(0, repo_root)

In [None]:
from datetime import datetime
import importlib
import numpy as np
import logging
import os
import shutil
from typing import Any
import yaml
import optuna
from rdkit import RDLogger
from generator import Generator
from language import Language
from node import MolSentenceNode
from utils import add_sep, class_from_package, make_logger
RDLogger.DisableLog('rdApp.*')

yaml_path = "config/optuna_generation.yaml"

with open(os.path.join(repo_root, yaml_path)) as f:
    conf = yaml.safe_load(f)
transition_args = conf.get("transition_args", {})
model_dir = os.path.join(repo_root, transition_args.pop("model_dir"))
lang_path = conf.get("lang_path")
if lang_path is None:
    lang_name = os.path.basename(os.path.normpath(model_dir)) + ".lang"
    lang_path = add_sep(model_dir) + lang_name
lang = Language.load(lang_path)
policy_args = conf.get("policy_args", {})
reward_class = class_from_package("reward", conf.get("reward_class"))
reward = reward_class(**conf.get("reward_args", {}))
filter_settings = conf.get("filters", [])
filters = []
for s in filter_settings:
    filter_class = class_from_package("filter", s.pop("filter_class"))
    filters.append(filter_class(**s))
generator_args = conf.get("generator_args", {})

def objective(trial):
    transition_args["top_p"] = 1 - trial.suggest_loguniform("1-top_p", 0.001, 0.015)
    transition_args["temperature"] = trial.suggest_uniform("temperature", 0.7, 1.2)
    # policy_class = trial.suggest_categorical("policy_class", ["UCB", "PUCT"])
    policy_class = "UCB"
    policy_args["c"] = trial.suggest_uniform("c", 0.01, 1)
    policy_args["best_rate"] = trial.suggest_uniform("best_rate", 0, 1)
    generator_args["filtered_reward"] = trial.suggest_uniform("filtered_reward", -1, 0.2)
    generator_args["rollout_width"] = trial.suggest_int("rollout_width", 1, 40)
    # generator_args["allow_rollout_overlaps"] = trial.suggest_categorical("allow_rollout_overlaps", [True, False])
    generator_args["n_rollouts"] = trial.suggest_int("n_rollouts", 1, 10)
    generator_args["n_tries"] = trial.suggest_int("n_tries", 1, 5)

    output_dir=os.path.join(repo_root, "sandbox", conf["output_dir"], datetime.now().strftime("%m-%d_%H-%M")) + os.sep
    console_level = logging.ERROR
    file_level = logging.DEBUG if conf.get("debug") else logging.INFO
    logger = make_logger(output_dir, console_level=console_level, file_level=file_level)
    logger.info("params:" + str(trial.params))

    transition_class = class_from_package("transition", conf["transition_class"])
    transition = transition_class(model_dir=model_dir, lang=lang, logger=logger, device=conf.get("device"), **transition_args)
    
    policy_class = class_from_package("policy", policy_class)
    policy = policy_class(**policy_args)
    generator_args["policy"] = policy
        
    root = MolSentenceNode.bos_node(lang, device=conf.get("device")) # TODO: change after root node generalization
    
    generator_class = class_from_package("generator", conf.get("generator_class", "MCTS"))
    generator = generator_class(root=root, transition=transition, reward=reward, filters=filters, output_dir=output_dir, logger=logger, **generator_args)
    
    max_generations, time_limit = conf.get("max_generations"), conf.get("time_limit")
    n_steps, best_reward_rate, mean_reward_window = conf.get("n_steps"), conf.get("best_reward_rate"), conf.get("mean_reward_window")
    
    for i in range(0, n_steps):
        generator.generate(max_generations=max_generations / n_steps, time_limit=time_limit / n_steps)
        mean_reward = generator.mean_reward(window=mean_reward_window)
        intermediate_value = (1 - best_reward_rate) * mean_reward + best_reward_rate * generator.best_reward
        trial.report(intermediate_value, i)
        if trial.should_prune():
            print(f"Trial {trial.number} - Step {i}: mean_reward={mean_reward:.3f}, best_reward={generator.best_reward:.3f}, params={trial.params}")
            raise optuna.TrialPruned()

    trial.set_user_attr("mean_reward", mean_reward)
    trial.set_user_attr("best_reward", generator.best_reward)
    print(f"Trial {trial.number}: mean_reward={trial.user_attrs['mean_reward']:.3f}, best_reward={trial.user_attrs['best_reward']:.3f}")
    generator.plot(**conf.get("plot_args"))
    return (1 - best_reward_rate) * mean_reward + best_reward_rate * generator.best_reward
    
def print_trial(trial: optuna.Trial):
    print(f"Trial {trial.number} score={trial.value:.3f}, mean_reward={trial.user_attrs['mean_reward']:.3f}, best_reward={trial.user_attrs['best_reward']:.3f}, params={trial.params}")
    
def print_best_trials(study: optuna.Study):
    print("Optuna trials completed.")
    print("------ Best trials -----")
    best_trials = sorted([t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE], key=lambda t: t.value, reverse=True)[:5]
    for t in best_trials:
        print_trial(t)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# start search
name = conf.get("study_name")
storage = "sqlite:///generation_result/optuna_" + name + ".db"
sampler = sampler=optuna.samplers.TPESampler(multivariate=True, group=True)
# sampler = optuna.samplers.GPSampler(deterministic_objective=False) # better if not using pruner?
pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0, interval_steps=1)
study = optuna.create_study(direction="maximize", study_name=name, storage=storage, sampler=sampler, pruner=pruner)
study.enqueue_trial({"1-top_p": 0.002598657659083491, "temperature": 0.8808830873822, "c": 0.33646460754494056, "best_rate": 0.6042315721899106, "filtered_reward": 0.16523029006562534, "rollout_width": 15, "n_rollouts": 5, "n_tries": 2})
study.enqueue_trial({"1-top_p": 0.005, "temperature": 1, "c": 0.2, "best_rate": 0.5, "filtered_reward": 0, "rollout_width": 1, "n_rollouts": 1, "n_tries": 1})
study.enqueue_trial({"1-top_p": 0.015, "temperature": 1.2, "c": 0.3, "best_rate": 0.7, "filtered_reward": 0, "rollout_width": 15, "n_rollouts": 10, "n_tries": 3})
study.optimize(objective, n_trials=conf.get("n_trials"))
print_best_trials(study)

[I 2025-06-23 02:35:44,407] A new study created in RDB with name: d_score_200000


Trial 0: mean_reward=0.367, best_reward=0.621


[I 2025-06-23 04:39:03,015] Trial 0 finished with value: 0.4434488203183479 and parameters: {'1-top_p': 0.002598657659083491, 'temperature': 0.8808830873822, 'c': 0.33646460754494056, 'best_rate': 0.6042315721899106, 'filtered_reward': 0.16523029006562534, 'rollout_width': 15, 'n_rollouts': 5, 'n_tries': 2}. Best is trial 0 with value: 0.4434488203183479.


Trial 1: mean_reward=0.402, best_reward=0.577


[I 2025-06-23 06:09:56,504] Trial 1 finished with value: 0.4545571971174015 and parameters: {'1-top_p': 0.005, 'temperature': 1.0, 'c': 0.2, 'best_rate': 0.5, 'filtered_reward': 0.0, 'rollout_width': 1, 'n_rollouts': 1, 'n_tries': 1}. Best is trial 1 with value: 0.4545571971174015.


Trial 2: mean_reward=0.352, best_reward=0.561


[I 2025-06-23 08:18:10,638] Trial 2 finished with value: 0.4147843812009473 and parameters: {'1-top_p': 0.015, 'temperature': 1.2, 'c': 0.3, 'best_rate': 0.7, 'filtered_reward': 0.0, 'rollout_width': 15, 'n_rollouts': 10, 'n_tries': 3}. Best is trial 1 with value: 0.4545571971174015.
[W 2025-06-23 08:34:03,193] Trial 3 failed with parameters: {'1-top_p': 0.007513321515291435, 'temperature': 0.7636748119044094, 'c': 0.7979051452303083, 'best_rate': 0.11289165017210212, 'filtered_reward': -0.8759411378666879, 'rollout_width': 4, 'n_rollouts': 9, 'n_tries': 3} because of the following error: KeyError('mean_reward').
Traceback (most recent call last):
  File "/opt/anaconda3/envs/v3-forge/lib/python3.11/site-packages/optuna/study/_optimize.py", line 201, in _run_trial
    value_or_values = func(trial)
                      ^^^^^^^^^^^
  File "/var/folders/3h/91wbm7510zq63w21p0pzq_g40000gn/T/ipykernel_82850/664523383.py", line 78, in objective
    print(f"Trial {trial.number} - Step {i}: mea

KeyError: 'mean_reward'

In [None]:
# continue search
study = optuna.study.load_study(study_name="d_score_200000", storage="sqlite:///generation_result/optuna_d_score_25000.db")
study.optimize(objective, n_trials=8)
print_best_trials(study)