In [1]:
# Make the repository root importable when the notebook runs from a local checkout.
import sys

repo_root = "../"
# Prepend only when absent, so re-running this cell does not stack duplicate entries.
if sys.path.count(repo_root) == 0:
    sys.path.insert(0, repo_root)

In [2]:
# Objectives that reward_class_name_from_oracle_name maps to "GuacaMolReward".
guacamol_oracle_names = [
    "zaleplon_mpo",
    "perindopril_mpo",
    "albuterol_similarity",
    "amlodipine_mpo",
    "celecoxib_rediscovery",
    "deco_hop",
    "fexofenadine_mpo",
    "isomers_c7h8n2o2",
    "isomers_c9h10n2o2pf2cl",
    "median1",
    "median2",
    "mestranol_similarity",
    "osimertinib_mpo",
    "ranolazine_mpo",
    "scaffold_hop",
    "sitagliptin_mpo",
    "thiothixene_rediscovery",
    "troglitazone_rediscovery",
    "valsartan_smarts",
]
# Objectives that reward_class_name_from_oracle_name maps to "TDCReward".
tdc_oracle_names = [
    "drd2",
    "gsk3b",
    "jnk3",
    "qed",
]
# Full benchmark list, GuacaMol objectives first; this is the evaluation order.
oracle_names = [*guacamol_oracle_names, *tdc_oracle_names]
    
def reward_class_name_from_oracle_name(oracle_name: str) -> str:
    """Return the name of the reward class to use for ``oracle_name``.

    Names listed in ``tdc_oracle_names`` map to ``"TDCReward"``; every other
    name (including unknown ones) maps to ``"GuacaMolReward"``.
    """
    is_tdc_oracle = oracle_name in tdc_oracle_names
    return "TDCReward" if is_tdc_oracle else "GuacaMolReward"

In [3]:
import os
import optuna
from rdkit import RDLogger
from node import SMILESStringNode
from utils import generator_from_conf, conf_from_yaml
# Turn off RDKit logging for every channel matching 'rdApp.*' to keep cell output readable.
RDLogger.DisableLog('rdApp.*')

# YAML configs read by `objective`: yaml_path_1 configures the first generation
# stage, yaml_path_2 the second. Both are passed to conf_from_yaml with repo_root.
yaml_path_1, yaml_path_2 = (
    "config/optuna/mol_opt_rnn.yaml",
    "config/optuna/mol_opt_jensen.yaml",
)

def _suggest_policy_args(trial, conf, suffix):
    """Suggest the policy hyperparameters used by both stages and store them in ``conf["policy_args"]``.

    Args:
        trial: Optuna trial used to draw the values.
        conf: Stage configuration dict; its ``"policy_args"`` entry is created
            if missing and updated in place (existing keys are kept).
        suffix: Stage tag appended to each parameter name (e.g. ``"1"`` -> ``"c_1"``).
    """
    policy_args = conf.setdefault("policy_args", {})
    policy_args["c"] = trial.suggest_float(f"c_{suffix}", 0.01, 0.5)
    policy_args["best_rate"] = trial.suggest_float(f"best_rate_{suffix}", 0, 1)
    policy_args["prior"] = trial.suggest_float(f"prior_{suffix}", 0.3, 1.2)
    policy_args["prior_weight"] = trial.suggest_int(f"prior_weight_{suffix}", 0, 2)
    policy_args["max_prior"] = trial.suggest_float(f"max_prior_{suffix}", 0, 0.8)


def objective(trial):
    """Optuna objective: summed top-10 AUC of a two-stage generation over all oracles.

    For every name in ``oracle_names`` a stage-1 generator (built from
    ``yaml_path_1``) runs for ``n_generations_until_lead`` generations. Its
    top ``n_keys_to_pass`` keys become the ``root`` of a stage-2 generator
    (built from ``yaml_path_2``), which runs for the rest of the budget.
    The top-10 AUC of stage 2 is stored as a user attribute under the oracle
    name, and the running sum is reported to Optuna after each oracle so the
    pruner can stop unpromising trials early.

    Args:
        trial: The Optuna trial supplying the hyperparameters.

    Returns:
        The sum of the per-oracle top-10 AUC values.

    Raises:
        optuna.TrialPruned: If the pruner asks to stop after an oracle.
    """
    # Used both as the generation budget split across the two stages and as
    # max_oracle_calls for the AUC computation.
    total_budget = 10000

    # --- Stage 1 hyperparameters (suggestion order kept identical to the original) ---
    conf_1 = conf_from_yaml(yaml_path_1, repo_root)
    transition_args = conf_1.setdefault("transition_args", {})
    transition_args["sharpness"] = trial.suggest_float("sharpness_1", 0.8, 1.1)
    transition_args["top_p"] = trial.suggest_float("top_p_1", 0.993, 0.999)

    _suggest_policy_args(trial, conf_1, "1")

    generator_args = conf_1.setdefault("generator_args", {})
    generator_args["eval_width"] = trial.suggest_int("eval_width_1", 1, 40)
    generator_args["n_evals"] = trial.suggest_int("n_evals_1", 1, 10)
    generator_args["n_tries"] = trial.suggest_int("n_tries_1", 1, 10)
    generator_args["terminal_reward"] = trial.suggest_categorical("terminal_reward_1", ["ignore", -1])

    # --- Stage 2 hyperparameters ---
    conf_2 = conf_from_yaml(yaml_path_2, repo_root)
    _suggest_policy_args(trial, conf_2, "2")

    # --- Hand-over between the stages ---
    n_generations_until_lead = trial.suggest_categorical("n_generations_until_lead", [200, 500, 1000, 1500, 2000, 2500, 5000])
    n_keys_to_pass = trial.suggest_categorical("n_keys_to_pass", [1, 3, 5, 10])

    sum_auc = 0
    for i, oracle_name in enumerate(oracle_names):
        conf_1["reward_class"] = reward_class_name_from_oracle_name(oracle_name)
        conf_1["reward_args"] = {"objective": oracle_name}
        conf_1["output_dir"] = os.path.join("generation_result", "trial_" + str(trial.number), oracle_name)

        generator_1 = generator_from_conf(conf_1, repo_root)
        generator_1.logger.info("reward="+oracle_name)
        generator_1.logger.info(f"params={trial.params}")

        generator_1.generate(max_generations=n_generations_until_lead, time_limit=conf_1.get("time_limit"))
        # top_k yields pairs; only the first element (the key) is passed on.
        top_keys = [key for key, _ in generator_1.top_k(k=n_keys_to_pass)]

        # Stage 2 reuses the reward and output directory of stage 1.
        conf_2["root"] = top_keys
        conf_2["reward_class"] = conf_1["reward_class"]
        conf_2["reward_args"] = conf_1["reward_args"]
        conf_2["output_dir"] = conf_1["output_dir"]
        # Pass repo_root here too, so both stages are built the same way.
        generator_2 = generator_from_conf(conf_2, repo_root)
        # NOTE(review): presumably carries stage-1 state/records over to stage 2 - confirm.
        generator_2.inherit(generator_1)
        generator_2.generate(max_generations=total_budget - n_generations_until_lead, time_limit=conf_2.get("time_limit"))

        auc = generator_2.top_k_auc(top_k=10, max_oracle_calls=total_budget)
        trial.set_user_attr(oracle_name, auc)
        generator_2.logger.info("top_10_auc: " + str(auc))

        # Report the cumulative score with the oracle index as the pruning step.
        sum_auc += auc
        trial.report(sum_auc, i)
        if trial.should_prune():
            # Print the real step index (the original always printed 0).
            print(f"{oracle_name} Trial {trial.number} - Step {i}: intermediate_score={sum_auc:.3f}, params={trial.params}")
            raise optuna.TrialPruned()

    trial.set_user_attr("sum_top_10_auc", sum_auc)
    # The summary covers all oracles, so it no longer carries the last
    # oracle's name (a leaked loop variable in the original).
    print(f"Trial {trial.number}: sum_top_10_auc={sum_auc:.3f}, aucs={trial.user_attrs}")

    return sum_auc

In [None]:
name = "mol_opt_rnn_and_jensen"
# SQLite does not create missing parent directories for the database file,
# so make sure the relative "optuna/" directory exists first.
os.makedirs("optuna", exist_ok=True)
storage = "sqlite:///optuna/" + name + ".db"
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=0, interval_steps=1)
# load_if_exists=True lets a re-run of this cell (e.g. after a kernel restart)
# resume the study stored in the database instead of raising because the name is taken.
study = optuna.create_study(direction="maximize", study_name=name, storage=storage, sampler=sampler, pruner=pruner, load_if_exists=True)
# Baseline configuration evaluated first; skip_if_exists=True avoids queueing it
# again when the study is resumed. "terminal_reward_1" is not fixed here, so the
# sampler chooses it.
study.enqueue_trial({"sharpness_1": 1.0, "top_p_1": 0.995, "c_1": 0.2, "best_rate_1": 0.5, "prior_1": 1.0, "prior_weight_1": 0, "max_prior_1": 0.5, "eval_width_1": 1, "n_evals_1": 1, "n_tries_1": 5, "c_2": 0.2, "best_rate_2": 0.9, "prior_2": 1.0, "prior_weight_2": 0, "max_prior_2": 0.5, "n_generations_until_lead": 1000, "n_keys_to_pass": 3}, skip_if_exists=True)
study.optimize(objective, n_trials=100000)

[I 2025-07-18 17:10:48,456] A new study created in RDB with name: mol_opt_rnn_and_jensen
  state = torch.load(os.path.join(model_dir, "model.pt"), map_location=self.device)
  state = torch.load(os.path.join(model_dir, "model.pt"), map_location=self.device)
Found local copy...
Found local copy...
Found local copy...
Found local copy...
Found local copy...
Found local copy...
[I 2025-07-18 23:51:12,898] Trial 0 finished with value: 11.110070052858662 and parameters: {'sharpness_1': 1.0, 'top_p_1': 0.995, 'c_1': 0.2, 'best_rate_1': 0.5, 'prior_1': 1.0, 'prior_weight_1': 0, 'max_prior_1': 0.5, 'eval_width_1': 1, 'n_evals_1': 1, 'n_tries_1': 5, 'terminal_reward_1': 'ignore', 'c_2': 0.2, 'best_rate_2': 0.9, 'prior_2': 1.0, 'prior_weight_2': 0, 'max_prior_2': 0.5, 'n_generations_until_lead': 1000, 'n_keys_to_pass': 3}. Best is trial 0 with value: 11.110070052858662.


qed Trial 0: sum_top_10_auc=11.110, aucs={'zaleplon_mpo': 0.5615201967038963, 'perindopril_mpo': 0.47048494629055965, 'albuterol_similarity': 0.6705461615552808, 'amlodipine_mpo': 0.5710178716242713, 'celecoxib_rediscovery': 0.4973205369649948, 'deco_hop': 0.28542519583251563, 'fexofenadine_mpo': 0.3357686240647808, 'isomers_c7h8n2o2': 0.9014877600070033, 'isomers_c9h10n2o2pf2cl': 0.7944732598488922, 'median1': 0.2700821473818322, 'median2': 0.18992111398050007, 'mestranol_similarity': 0.5615568244727, 'osimertinib_mpo': 0.40363779442401704, 'ranolazine_mpo': 0.305285816346567, 'scaffold_hop': 0.5155927324380237, 'sitagliptin_mpo': 0.5921147074825305, 'thiothixene_rediscovery': 0.33542303375530197, 'troglitazone_rediscovery': 0.2685631972193421, 'valsartan_smarts': 0.0986181661136761, 'drd2': 0.9217037228638366, 'gsk3b': 0.5299256000000002, 'jnk3': 0.3604950000000007, 'qed': 0.66910564348814, 'sum_top_10_auc': 11.110070052858662}


Found local copy...
Found local copy...
Found local copy...
Found local copy...
Found local copy...
Found local copy...
[I 2025-07-19 05:36:43,481] Trial 1 finished with value: 12.524292052608176 and parameters: {'sharpness_1': 0.8993949961289143, 'top_p_1': 0.9955243193920423, 'c_1': 0.4729463903964468, 'best_rate_1': 0.12591140133254186, 'prior_1': 0.7603587919077075, 'prior_weight_1': 0, 'max_prior_1': 0.017544054169471048, 'eval_width_1': 5, 'n_evals_1': 1, 'n_tries_1': 3, 'terminal_reward_1': -1, 'c_2': 0.09881280524670522, 'best_rate_2': 0.10005763426214409, 'prior_2': 0.7920248354440952, 'prior_weight_2': 0, 'max_prior_2': 0.3298597697523124, 'n_generations_until_lead': 500, 'n_keys_to_pass': 5}. Best is trial 1 with value: 12.524292052608176.


qed Trial 1: sum_top_10_auc=12.524, aucs={'zaleplon_mpo': 0.43532352678626596, 'perindopril_mpo': 0.4735500900697541, 'albuterol_similarity': 0.7640640012465727, 'amlodipine_mpo': 0.4844618537147413, 'celecoxib_rediscovery': 0.5039586781178609, 'deco_hop': 0.6130715530149068, 'fexofenadine_mpo': 0.7109851734204236, 'isomers_c7h8n2o2': 0.9722160392439811, 'isomers_c9h10n2o2pf2cl': 0.9203585248279702, 'median1': 0.2760988269668758, 'median2': 0.20221133570283478, 'mestranol_similarity': 0.6753939085144826, 'osimertinib_mpo': 0.6505341399558492, 'ranolazine_mpo': 0.5474378988465087, 'scaffold_hop': 0.5032133928310634, 'sitagliptin_mpo': 0.5416535696443474, 'thiothixene_rediscovery': 0.37795331646428226, 'troglitazone_rediscovery': 0.44718017802555743, 'valsartan_smarts': 0.0, 'drd2': 0.5491824069434862, 'gsk3b': 0.8202997000000001, 'jnk3': 0.2946648000000001, 'qed': 0.7604791382704118, 'sum_top_10_auc': 12.524292052608176}


Found local copy...
Found local copy...
Found local copy...
Found local copy...
Found local copy...
Found local copy...
[I 2025-07-19 15:58:13,822] Trial 2 finished with value: 5.339083486356405 and parameters: {'sharpness_1': 1.0754295561016243, 'top_p_1': 0.9967912999628799, 'c_1': 0.05628664890142018, 'best_rate_1': 0.061181518356504094, 'prior_1': 0.4803641849541345, 'prior_weight_1': 2, 'max_prior_1': 0.5588943832337846, 'eval_width_1': 3, 'n_evals_1': 10, 'n_tries_1': 1, 'terminal_reward_1': -1, 'c_2': 0.018448327958695998, 'best_rate_2': 0.6283253828036415, 'prior_2': 0.7106798077713078, 'prior_weight_2': 0, 'max_prior_2': 0.01711004879027378, 'n_generations_until_lead': 500, 'n_keys_to_pass': 10}. Best is trial 1 with value: 12.524292052608176.


qed Trial 2: sum_top_10_auc=5.339, aucs={'zaleplon_mpo': 0.2905656403794343, 'perindopril_mpo': 0.15539789330278272, 'albuterol_similarity': 0.193605525450912, 'amlodipine_mpo': 0.19693083806553136, 'celecoxib_rediscovery': 0.2344799426945726, 'deco_hop': 0.05470902821148473, 'fexofenadine_mpo': 0.2537415760803627, 'isomers_c7h8n2o2': 0.2692591176562359, 'isomers_c9h10n2o2pf2cl': 0.3104958912872647, 'median1': 0.31961203666967986, 'median2': 0.24014725914452356, 'mestranol_similarity': 0.17643084855531996, 'osimertinib_mpo': 0.299172334932039, 'ranolazine_mpo': 0.20241494666796933, 'scaffold_hop': 0.04627733719372498, 'sitagliptin_mpo': 0.4450775601444394, 'thiothixene_rediscovery': 0.19764315913515254, 'troglitazone_rediscovery': 0.3177397100189318, 'valsartan_smarts': 0.0, 'drd2': 0.4115485398923347, 'gsk3b': 0.17873900000000006, 'jnk3': 0.174964, 'qed': 0.3701313008737089, 'sum_top_10_auc': 5.339083486356405}
