In [1]:
# Local runs only: put the repository root on the import path so the
# project modules imported by the following cells can be found.
import sys

repo_root = "../"
already_on_path = repo_root in sys.path
if not already_on_path:
    # Prepend so the repository's modules take precedence over others.
    sys.path.insert(0, repo_root)

In [None]:
from datetime import datetime
import importlib
import numpy as np
import logging
import os
import shutil
from typing import Any
import yaml
import optuna
from rdkit import RDLogger
from generator import Generator
from language import Language
from node import MolSentenceNode
from utils import add_sep, class_from_package, make_logger
# Turn off every RDKit log channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')

# --- Read the experiment configuration (path is relative to repo_root) ---
yaml_path = "config/mcts_d_score_optuna.yaml"
config_file = os.path.join(repo_root, yaml_path)
with open(config_file) as stream:
    conf = yaml.safe_load(stream)

# --- Transition settings ---
# "model_dir" is removed from transition_args here because it is passed to
# the transition constructor as its own keyword argument later on.
transition_args = conf.get("transition_args", {})
model_dir = os.path.join(repo_root, transition_args.pop("model_dir"))

# --- Language ---
# Without an explicit "lang_path", fall back to "<model_dir>/<dirname>.lang",
# where <dirname> is the last path component of model_dir.
lang_path = conf.get("lang_path")
if lang_path is None:
    lang_path = (
        add_sep(model_dir)
        + os.path.basename(os.path.normpath(model_dir))
        + ".lang"
    )
lang = Language.load(lang_path)

# --- Reward ---
reward_class = class_from_package("reward", conf.get("reward_class"))
reward_kwargs = conf.get("reward_args", {})
reward = reward_class(**reward_kwargs)

# --- Filters ---
# Each entry names its class under "filter_class"; that key is popped first
# and whatever remains in the entry is forwarded as constructor kwargs.
filter_settings = conf.get("filters", [])
filters = [
    class_from_package("filter", entry.pop("filter_class"))(**entry)
    for entry in filter_settings
]

# Base generator kwargs; the per-trial values are added inside objective().
generator_args = conf.get("generator_args", {})

def objective(trial: "optuna.Trial"):
    """Run one generation experiment with hyperparameters sampled from *trial*.

    Samples the policy, sampling and generator hyperparameters, builds the
    transition, policy, root node and generator from the module-level
    configuration (``conf``, ``model_dir``, ``lang``, ``reward``, ``filters``,
    ``transition_args``, ``generator_args``), runs the generator and reports
    two objective values.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        A 2-tuple ``(generator.mean_reward(window=...), generator.best_reward)``.
    """
    # --- Hyperparameter search space ---
    # suggest_float replaces the deprecated suggest_uniform / suggest_loguniform;
    # parameter names and ranges are unchanged.
    policy_name = trial.suggest_categorical("policy_class", ["UCB", "PUCT"])
    c = trial.suggest_float("c", 0.01, 2, log=True)
    best_rate = trial.suggest_float("best_rate", 0, 1)
    n_rollouts = trial.suggest_int("n_rollouts", 1, 5)
    n_tries = trial.suggest_int("n_tries", 1, 5)
    # Sample 1 - top_p on a log scale so values of top_p close to 1 are
    # explored with finer resolution.
    bottom_p = trial.suggest_float("1-top_p", 0.0005, 0.05, log=True)
    top_p = 1 - bottom_p
    temperature = trial.suggest_float("temperature", 0.8, 1.2)
    filtered_reward = trial.suggest_float("filtered_reward", -2, 0.2)

    # --- Per-trial output directory and logger ---
    # Seconds plus the trial number keep trials that start within the same
    # minute from sharing (and overwriting) one output directory.
    run_name = f"{datetime.now().strftime('%m-%d_%H-%M-%S')}_trial{trial.number}"
    output_dir = os.path.join(repo_root, "sandbox", conf["output_dir"], run_name) + os.sep
    console_level = logging.ERROR
    file_level = logging.DEBUG if conf.get("debug") else logging.INFO
    logger = make_logger(output_dir, console_level=console_level, file_level=file_level)
    logger.info("params:" + str(trial.params))

    # --- Transition ---
    # Work on a per-trial copy so the module-level transition_args dict is not
    # mutated and no state leaks from one trial into the next.
    trial_transition_args = dict(transition_args, top_p=top_p, temperature=temperature)
    transition_class = class_from_package("transition", conf["transition_class"])
    transition = transition_class(
        model_dir=model_dir,
        lang=lang,
        logger=logger,
        device=conf.get("device"),
        **trial_transition_args,
    )

    # --- Policy ---
    policy_class = class_from_package("policy", policy_name)
    policy = policy_class(c=c, best_rate=best_rate)

    root = MolSentenceNode.bos_node(lang, device=conf.get("device"))  # TODO: change after root node generalization

    # --- Generator ---
    # Same reasoning as above: merge the sampled values into a copy rather
    # than into the shared generator_args dict.
    trial_generator_args = dict(
        generator_args,
        policy=policy,
        filtered_reward=filtered_reward,
        n_rollouts=n_rollouts,
        n_tries=n_tries,
    )
    generator_class = class_from_package("generator", conf.get("generator_class", "MCTS"))
    generator = generator_class(
        root=root,
        transition=transition,
        reward=reward,
        filters=filters,
        output_dir=output_dir,
        logger=logger,
        **trial_generator_args,
    )

    generator.generate(time_limit=conf.get("time_limit"), max_generations=conf.get("max_generations"))
    return generator.mean_reward(window=conf.get("mean_reward_window")), generator.best_reward
    
# Two-objective study: both values returned by objective() are maximized.
objective_directions = ["maximize", "maximize"]
study = optuna.create_study(directions=objective_directions)

# The number of trials comes from the "n_trials" entry of the configuration.
trial_budget = conf.get("n_trials")
study.optimize(objective, n_trials=trial_budget)
print("Optuna trials completed.")

# Print the objective values and parameters of each trial in best_trials.
for best in study.best_trials:
    print(f"values: {best.values} params: {best.params}")

  from .autonotebook import tqdm as notebook_tqdm
[I 2025-06-20 21:07:59,824] A new study created in memory with name: no-name-92d4a681-ac4e-4885-844d-06b7d209143c
