In [1]:
import sys

# Local runs: put the repository root on the import path so the project modules
# imported further down (`node`, `utils`) can be found. The membership check keeps
# re-executions of this cell from stacking duplicate entries onto sys.path.
repo_root = "../"
if not any(entry == repo_root for entry in sys.path):
    sys.path.insert(0, repo_root)

In [None]:
# Benchmark oracles evaluated by the search: GuacaMol tasks first, then TDC tasks.
guacamol_oracle_names = [
    "zaleplon_mpo", "deco_hop", "thiothixene_rediscovery", "scaffold_hop",
    "albuterol_similarity", "amlodipine_mpo", "celecoxib_rediscovery", "fexofenadine_mpo",
    "isomers_c7h8n2o2", "isomers_c9h10n2o2pf2cl", "median1", "median2",
    "mestranol_similarity", "perindopril_mpo", "osimertinib_mpo", "ranolazine_mpo",
    "sitagliptin_mpo", "troglitazone_rediscovery", "valsartan_smarts",
]
tdc_oracle_names = ["drd2", "gsk3b", "jnk3", "qed"]
oracle_names = [*guacamol_oracle_names, *tdc_oracle_names]


def reward_class_name_from_oracle_name(oracle_name: str) -> str:
    """Map an oracle name to the name of the reward class that evaluates it.

    Names listed in ``tdc_oracle_names`` map to ``"TDCReward"``; every other name,
    including ones not in either list, falls back to ``"GuacaMolReward"``.
    """
    is_tdc_oracle = oracle_name in tdc_oracle_names
    return "TDCReward" if is_tdc_oracle else "GuacaMolReward"

In [3]:
import os
import optuna
from rdkit import RDLogger
from node import SMILESStringNode
from utils import generator_from_conf, conf_from_yaml
RDLogger.DisableLog('rdApp.*')

# Base configs for the two chained generators: stage 1 (yaml_path_1) and stage 2
# (yaml_path_2). Both are loaded in `objective` via conf_from_yaml relative to repo_root.
yaml_path_1, yaml_path_2 = (
    "config/optuna/mol_opt_rnn.yaml",
    "config/optuna/mol_opt_jensen.yaml",
)

def objective(
    trial: "optuna.trial.Trial",
    generation_budget: int = 10000,
    max_oracle_calls: int = 10000,
    top_k: int = 10,
) -> float:
    """Optuna objective for a two-stage (chained) generator search over all oracles.

    Stage 1 is configured from ``yaml_path_1``. The trial tunes its transition
    ``sharpness`` / ``top_p``, policy class and ``c`` / ``best_rate``, and
    ``terminal_reward``. Stage 2 is configured from ``yaml_path_2``; the trial
    tunes its policy ``c`` / ``best_rate``. For every oracle in ``oracle_names``:

    1. Stage 1 runs for ``n_generations_until_lead`` generations.
    2. Stage 2 is built with ``predecessor=generator_1`` and
       ``n_top_keys_to_pass=n_keys_to_pass``. It runs for the rest of
       ``generation_budget``.
    3. Stage 2's top-k AUC is stored as a user attribute and added to a running
       sum, which is reported to the pruner after each oracle.

    Other policy/generator args (e.g. ``prior``, ``prior_weight``, ``max_prior``,
    ``eval_width``, ``n_evals``, ``n_tries``) are not tuned here and keep whatever
    the YAML configs specify.

    Args:
        trial: Trial supplied by ``study.optimize``.
        generation_budget: Total generations shared by both stages (original: 10000).
        max_oracle_calls: Passed to ``top_k_auc`` (original: 10000).
        top_k: ``k`` for the top-k AUC score (original: 10).

    Returns:
        Sum of the per-oracle top-k AUCs (the study maximizes this).

    Raises:
        optuna.TrialPruned: If the pruner stops the trial after some oracle.
    """
    # --- Stage 1 search space ---
    conf_1 = conf_from_yaml(yaml_path_1, repo_root)
    conf_1.setdefault("transition_args", {})
    conf_1["transition_args"]["sharpness"] = trial.suggest_float("sharpness_1", 0.8, 1.1)
    conf_1["transition_args"]["top_p"] = trial.suggest_float("top_p_1", 0.993, 0.999)

    conf_1["policy_class"] = trial.suggest_categorical("policy_1", ["UCT", "PUCT"])
    conf_1.setdefault("policy_args", {})
    conf_1["policy_args"]["c"] = trial.suggest_float("c_1", 0.01, 1)
    conf_1["policy_args"]["best_rate"] = trial.suggest_float("best_rate_1", 0, 1)

    conf_1.setdefault("generator_args", {})
    conf_1["generator_args"]["terminal_reward"] = trial.suggest_categorical("terminal_reward_1", ["ignore", -1])

    # --- Stage 2 search space ---
    conf_2 = conf_from_yaml(yaml_path_2, repo_root)
    conf_2.setdefault("policy_args", {})
    conf_2["policy_args"]["c"] = trial.suggest_float("c_2", 0.01, 1)
    conf_2["policy_args"]["best_rate"] = trial.suggest_float("best_rate_2", 0, 1)

    # --- Hand-off between stages ---
    n_generations_until_lead = trial.suggest_categorical("n_generations_until_lead", [100, 200, 300, 500, 1000, 2000, 2500, 3000, 5000])
    n_keys_to_pass = trial.suggest_categorical("n_keys_to_pass", [1, 3, 5, 10, 20])

    sum_auc = 0
    for i, oracle_name in enumerate(oracle_names):
        # conf_1 is reused across oracles; only the reward and output directory change.
        conf_1["reward_class"] = reward_class_name_from_oracle_name(oracle_name)
        conf_1["reward_args"] = {"objective": oracle_name}
        conf_1["output_dir"] = os.path.join("generation_result", f"trial_{trial.number}", oracle_name)

        generator_1 = generator_from_conf(conf_1, repo_root)
        generator_1.logger.info("reward=" + oracle_name)
        generator_1.logger.info(f"params={trial.params}")
        generator_1.generate(max_generations=n_generations_until_lead, time_limit=conf_1.get("time_limit"))

        # NOTE(review): unlike stage 1, repo_root is not passed here; this relies on
        # generator_from_conf's default for that argument - confirm it is intended.
        generator_2 = generator_from_conf(conf_2, predecessor=generator_1, n_top_keys_to_pass=n_keys_to_pass)
        generator_2.generate(max_generations=generation_budget - n_generations_until_lead, time_limit=conf_2.get("time_limit"))

        auc = generator_2.top_k_auc(top_k=top_k, max_oracle_calls=max_oracle_calls, finish=True)
        trial.set_user_attr(oracle_name, auc)
        generator_2.logger.info(f"top_{top_k}_auc: " + str(auc))

        sum_auc += auc
        # Report the running sum at step i so MedianPruner can stop weak trials early.
        trial.report(sum_auc, i)
        if trial.should_prune():
            print(f"{oracle_name} Trial {trial.number} - Step {i}: intermediate_score={sum_auc:.3f}, params={trial.params}")
            raise optuna.TrialPruned()

    trial.set_user_attr(f"sum_top_{top_k}_auc", sum_auc)
    # Do not reference the loop variable here: it would name only the last oracle and is
    # undefined when oracle_names is empty.
    print(f"Trial {trial.number}: sum_top_{top_k}_auc={sum_auc:.3f}, aucs={trial.user_attrs}")

    return sum_auc

In [None]:
# Create a new SQLite-backed study and seed it with hand-picked configurations.
# optuna.create_study raises DuplicatedStudyError if a study with this name already
# exists in the storage; use the "continue search" cell to resume instead.
name = "mol_opt_chain"
# SQLite will not create missing parent directories for the database file.
os.makedirs("optuna", exist_ok=True)
storage = "sqlite:///optuna/" + name + ".db"
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0, interval_steps=1)
study = optuna.create_study(direction="maximize", study_name=name, storage=storage, sampler=sampler, pruner=pruner)

# Warm-start parameters: enqueued trials are evaluated before the sampler proposes new ones.
warm_start_params = [
    {"policy_1": "PUCT", "sharpness_1": 1.0, "top_p_1": 0.995, "c_1": 1, "best_rate_1": 0.5, "terminal_reward_1": -1, "c_2": 0.25, "best_rate_2": 0.9, "n_generations_until_lead": 200, "n_keys_to_pass": 10},
    {"policy_1": "UCT", "sharpness_1": 1.0, "top_p_1": 0.995, "c_1": 0.05, "best_rate_1": 0.95, "terminal_reward_1": -1, "c_2": 0.25, "best_rate_2": 0.9, "n_generations_until_lead": 1000, "n_keys_to_pass": 10},
    {"policy_1": "UCT", "sharpness_1": 1.0, "top_p_1": 0.995, "c_1": 0.2, "best_rate_1": 0.5, "terminal_reward_1": -1, "c_2": 0.1, "best_rate_2": 0.95, "n_generations_until_lead": 200, "n_keys_to_pass": 5},
]
for params in warm_start_params:
    study.enqueue_trial(params)

study.optimize(objective, n_trials=100000)

[I 2025-07-22 11:29:07,604] A new study created in RDB with name: mol_opt_chain
  state = torch.load(os.path.join(model_dir, "model.pt"), map_location=self.device)
  state = torch.load(os.path.join(model_dir, "model.pt"), map_location=self.device)


In [None]:
# Continue search: reattach to the study stored under the same name and SQLite URL
# used when it was created, then run more trials against it.
name = "mol_opt_chain"
storage = f"sqlite:///optuna/{name}.db"
study = optuna.load_study(storage=storage, study_name=name)
study.optimize(objective, n_trials=30000)

  state = torch.load(os.path.join(model_dir, "model.pt"), map_location=self.device)
  state = torch.load(os.path.join(model_dir, "model.pt"), map_location=self.device)
