In [2]:
# Local runs: make the repository's top-level modules (e.g. `node`, `utils`) importable.
# Assumes the notebook's working directory is one level below the repo root — TODO confirm.
import sys

repo_root = "../"
# Prepend only when missing, so re-running this cell never adds a duplicate entry.
needs_path_entry = repo_root not in sys.path
if needs_path_entry:
    sys.path.insert(0, repo_root)

In [1]:
# Oracles that are scored through the GuacaMol-backed reward class.
guacamol_oracle_names = [
    "zaleplon_mpo",
    "perindopril_mpo",
    "albuterol_similarity",
    "amlodipine_mpo",
    "celecoxib_rediscovery",
    "deco_hop",
    "fexofenadine_mpo",
    "isomers_c7h8n2o2",
    "isomers_c9h10n2o2pf2cl",
    "median1",
    "median2",
    "mestranol_similarity",
    "osimertinib_mpo",
    "ranolazine_mpo",
    "scaffold_hop",
    "sitagliptin_mpo",
    "thiothixene_rediscovery",
    "troglitazone_rediscovery",
    "valsartan_smarts",
]
# Oracles that are scored through the TDC-backed reward class.
tdc_oracle_names = ["drd2", "gsk3b", "jnk3", "qed"]
# Every oracle evaluated per Optuna trial, in order: GuacaMol first, then TDC.
oracle_names = guacamol_oracle_names + tdc_oracle_names


def reward_class_name_from_oracle_name(oracle_name: str) -> str:
    """Return the name of the reward class that should score ``oracle_name``.

    Names listed in ``tdc_oracle_names`` map to ``"TDCReward"``; every other name
    (including unknown ones) falls back to ``"GuacaMolReward"``.
    """
    is_tdc_oracle = oracle_name in tdc_oracle_names
    return "TDCReward" if is_tdc_oracle else "GuacaMolReward"

In [None]:
import os
import optuna
from rdkit import RDLogger
from node import SMILESStringNode
from utils import generator_from_conf, conf_from_yaml
RDLogger.DisableLog('rdApp.*')

# Generator configs tuned by this study (r = RNN, j = Jensen). Paths are passed to
# conf_from_yaml together with repo_root — presumably resolved relative to it; verify.
yaml_path_r, yaml_path_j = (
    "config/optuna/mol_opt_rnn.yaml",
    "config/optuna/mol_opt_jensen.yaml",
)

def _suggest_rnn_params(trial, conf):
    """Sample the RNN generator's tunable hyperparameters from ``trial`` into ``conf`` in place.

    The ``transition_args``, ``policy_args`` and ``generator_args`` sub-dicts are created
    if the YAML config did not define them; keys not tuned here are left untouched.
    Suggestion order matches the original notebook so the search space is unchanged.
    """
    transition_args = conf.setdefault("transition_args", {})
    transition_args["sharpness"] = trial.suggest_float("sharpness", 0.8, 1.1)
    transition_args["top_p"] = trial.suggest_float("top_p", 0.993, 0.999)

    policy_args = conf.setdefault("policy_args", {})
    policy_args["c"] = trial.suggest_float("c", 0.01, 0.5)
    policy_args["best_rate"] = trial.suggest_float("best_rate", 0, 1)
    policy_args["prior"] = trial.suggest_float("prior", 0.3, 1.2)
    policy_args["prior_weight"] = trial.suggest_int("prior_weight", 0, 2)
    policy_args["max_prior"] = trial.suggest_float("max_prior", 0, 0.8)

    generator_args = conf.setdefault("generator_args", {})
    generator_args["eval_width"] = trial.suggest_int("eval_width", 1, 40)
    generator_args["n_evals"] = trial.suggest_int("n_evals", 1, 10)
    generator_args["n_tries"] = trial.suggest_int("n_tries", 1, 10)
    generator_args["terminal_reward"] = trial.suggest_categorical("terminal_reward", ["ignore", -1])


def objective(trial):
    """Optuna objective: run the RNN generator on every oracle and return the summed top-10 AUC.

    For each name in ``oracle_names`` the RNN config (``yaml_path_r``) is pointed at that
    oracle, a generator is built and run for up to ``max_generations`` generations (and the
    config's ``time_limit``), and its ``top_k_auc(top_k=10)`` is added to a running sum.
    The running sum is reported to Optuna after every oracle so the pruner can stop
    unpromising trials early.

    Each oracle writes to ``<output_dir>/trial_<trial.number>/<oracle_name>``, where
    ``output_dir`` is the value from the YAML config.

    Raises:
        optuna.TrialPruned: when the pruner decides to stop the trial.
    """
    conf_r = conf_from_yaml(yaml_path_r, repo_root)
    _suggest_rnn_params(trial, conf_r)

    # NOTE(review): the Jensen config is loaded but never used below, although the study
    # is named "..._and_jensen" — confirm whether it was meant to drive a second generator.
    conf_j = conf_from_yaml(yaml_path_j, repo_root)

    # Read the base directory once. Rebuilding it from conf_r["output_dir"] inside the loop
    # would nest every oracle's directory inside the previous oracle's.
    trial_output_dir = os.path.join(conf_r.get("output_dir"), "trial_" + str(trial.number))
    # Per-oracle generation budget; also caps oracle calls counted by the AUC.
    max_generations = 10000

    sum_auc = 0
    for step, oracle_name in enumerate(oracle_names):
        conf_r["reward_class"] = reward_class_name_from_oracle_name(oracle_name)
        conf_r["reward_args"] = {"objective": oracle_name}
        conf_r["output_dir"] = os.path.join(trial_output_dir, oracle_name)

        generator = generator_from_conf(conf_r, repo_root)
        generator.logger.info("reward=" + oracle_name)
        generator.logger.info(f"params={trial.params}")
        generator.generate(max_generations=max_generations, time_limit=conf_r.get("time_limit"))

        auc = generator.top_k_auc(top_k=10, max_oracle_calls=max_generations)
        generator.logger.info("top_10_auc: " + str(auc))
        sum_auc += auc

        # Report the cumulative score; MedianPruner compares it against other trials
        # at the same step (i.e. after the same number of oracles).
        trial.report(sum_auc, step)
        if trial.should_prune():
            print(f"{oracle_name} Trial {trial.number} - Step {step}: intermediate_score={sum_auc:.3f}, params={trial.params}")
            raise optuna.TrialPruned()

    trial.set_user_attr("sum_top_10_auc", sum_auc)
    print(f"Trial {trial.number}: sum_top_10_auc={sum_auc:.3f}")

    return sum_auc

In [None]:
# Persist the study in a local SQLite DB so the search can be resumed.
name = "mol_opt_rnn_and_jensen"
# SQLite does not create missing parent directories, so make sure optuna/ exists.
os.makedirs("optuna", exist_ok=True)
storage = "sqlite:///optuna/" + name + ".db"
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=0, interval_steps=1)
# load_if_exists=True: re-running this cell (or Restart & Run All) resumes the stored study
# instead of failing because a study with this name already exists in the DB.
study = optuna.create_study(
    direction="maximize",
    study_name=name,
    storage=storage,
    sampler=sampler,
    pruner=pruner,
    load_if_exists=True,
)
# Seed a fresh study with a hand-picked baseline; skip it on resume so it is not queued
# again. "terminal_reward" is not fixed here, so the sampler chooses it.
if not study.trials:
    study.enqueue_trial({"sharpness": 1.0, "top_p": 0.995, "c": 0.2, "best_rate": 0.8, "prior": 1.0, "prior_weight": 0, "max_prior": 0.5, "eval_width": 1, "n_evals": 1, "n_tries": 5})
study.optimize(objective, n_trials=100000)

[I 2025-07-18 14:27:47,060] A new study created in RDB with name: mol_opt_rnn_and_jensen
