In [1]:
# Make the repository root importable when running this notebook locally
# (the project package is imported from source, not from an installed copy).
import sys

repo_root = "../"  # relative to the notebook's working directory
already_on_path = repo_root in sys.path
if not already_on_path:
    # Prepend so the local sources win over anything else on the path.
    sys.path.insert(0, repo_root)

In [2]:
import optuna
from rdkit import RDLogger
from utils import generator_from_conf, conf_from_yaml
# Turn off all RDKit 'rdApp.*' log channels so RDKit messages do not flood
# the cell output during the search.
RDLogger.DisableLog('rdApp.*')

# The YAML config drives both the generator (see `objective`) and the study settings.
yaml_path = "config/optuna_d_score.yaml"
conf = conf_from_yaml(yaml_path, repo_root)
name, n_trials = conf.get("study_name"), conf.get("n_trials")

def _blended_score(top_p_average_reward, best_reward, best_reward_rate):
    """Blend the top-p average reward with the best reward found so far.

    ``best_reward_rate`` is the weight on ``best_reward``; the remaining
    ``1 - best_reward_rate`` is the weight on ``top_p_average_reward``.
    """
    return (1 - best_reward_rate) * top_p_average_reward + best_reward_rate * best_reward


def objective(trial):
    """Optuna objective: run a generator with sampled hyperparameters and score it.

    The YAML config is loaded fresh, the sampled transition/policy/generator
    arguments are written into it, and a generator is built from it. Generation
    is split into ``n_steps`` chunks. After each chunk the blended score
    (see ``_blended_score``) is reported so the study's pruner can stop
    unpromising trials early.

    Args:
        trial: the ``optuna.Trial`` being evaluated.

    Returns:
        The blended score of ``top_p_average_reward`` and ``best_reward`` at the end.

    Raises:
        ValueError: if ``n_steps`` in the config is not a positive integer.
        optuna.TrialPruned: when the pruner decides to stop this trial.
    """
    # Reload per trial so the in-place edits below always start from the YAML values.
    conf = conf_from_yaml(yaml_path, repo_root)
    conf.setdefault("transition_args", {})
    # conf["transition_args"]["model_dir"] = "model/smiles/" + trial.suggest_categorical("model", ["drugs_zinc_gru", "drugs_zinc_lstm", "tf25_ported"])
    conf["transition_args"]["sharpness"] = trial.suggest_float("sharpness", 0.8, 1.1)
    conf["transition_args"]["top_p"] = trial.suggest_float("top_p", 0.994, 0.999)

    conf.setdefault("policy_args", {})
    conf["policy_args"]["c"] = trial.suggest_float("c", 0.05, 0.4)
    conf["policy_args"]["best_rate"] = trial.suggest_float("best_rate", 0, 1)
    conf["policy_args"]["prior"] = trial.suggest_float("prior", 0.3, 0.9)
    conf["policy_args"]["prior_weight"] = trial.suggest_int("prior_weight", 0, 2)
    conf["policy_args"]["max_prior"] = trial.suggest_float("max_prior", 0, 0.6)
    # conf["policy_args"]["c_action"] = trial.suggest_float("c_action", 0, 0.4)
    # conf["policy_args"]["prior_offset"] = trial.suggest_float("prior_offset", 0, 0.2)

    conf.setdefault("generator_args", {})
    conf["generator_args"]["eval_width"] = trial.suggest_int("eval_width", 1, 40)
    # conf["generator_args"]["allow_rollout_overlaps"] = trial.suggest_categorical("allow_rollout_overlaps", [True, False])
    conf["generator_args"]["n_evals"] = trial.suggest_int("n_evals", 1, 10)
    conf["generator_args"]["n_tries"] = trial.suggest_int("n_tries", 1, 3)
    # conf["generator_args"]["filter_reward"] = trial.suggest_categorical("filter_reward", ["ignore", 0])
    # conf["generator_args"]["failed_parent_reward"] = trial.suggest_categorical("failed_parent_reward", ["ignore", -1])
    # conf["generator_args"]["terminal_reward"] = trial.suggest_categorical("terminal_reward", ["ignore", -1])
    # conf["generator_args"]["cut_failed_child"] = trial.suggest_categorical("cut_failed_child", [True, False])

    max_generations, time_limit = conf.get("max_generations"), conf.get("time_limit")
    best_reward_rate = conf.get("best_reward_rate")
    average_top_p = conf.get("average_top_p", 0.1)
    n_steps = conf.get("n_steps", 1)
    # With n_steps < 1 the loop below would never run and the rewards would be unbound.
    if not isinstance(n_steps, int) or n_steps < 1:
        raise ValueError(f"n_steps must be a positive integer, got {n_steps!r}")

    generator = generator_from_conf(conf, repo_root)
    generator.logger.info(f"params={trial.params}")

    for step in range(n_steps):
        # Each call gets 1/n_steps of the budgets. NOTE(review): assumes generate()
        # treats these as per-call (not cumulative) limits — confirm.
        generator.generate(max_generations=max_generations / n_steps, time_limit=time_limit / n_steps)
        average_reward = generator.average_reward()
        top_p_average_reward = generator.average_reward(top_p=average_top_p)
        intermediate_value = _blended_score(top_p_average_reward, generator.best_reward, best_reward_rate)
        trial.report(intermediate_value, step)
        if trial.should_prune():
            print(f"Trial {trial.number} - Step {step}: average_reward={average_reward:.3f}, top_p_average_reward={top_p_average_reward:.3f}, best_reward={generator.best_reward:.3f}, params={trial.params}")
            raise optuna.TrialPruned()

    trial.set_user_attr("average_reward", average_reward)
    trial.set_user_attr("top_p_average_reward", top_p_average_reward)
    trial.set_user_attr("best_reward", generator.best_reward)
    print(f"Trial {trial.number}: average_reward={average_reward:.3f}, top_p_average_reward={top_p_average_reward:.3f}, best_reward={generator.best_reward:.3f}")
    generator.analyze()
    # plot_args may be missing or null in the YAML; `**None` would raise TypeError.
    generator.plot(**(conf.get("plot_args") or {}))
    return _blended_score(top_p_average_reward, generator.best_reward, best_reward_rate)
    
def print_trial(trial: "optuna.trial.FrozenTrial"):
    """Print a one-line summary of a finished trial: number, score, user attrs, params.

    Trials without a value (pruned, failed or still running) are printed with
    ``score=None`` instead of failing on the ``:.3f`` format.
    """
    score = "None" if trial.value is None else f"{trial.value:.3f}"
    print(f"Trial {trial.number} score={score}, attrs={trial.user_attrs}, params={trial.params}")
    
def print_best_trials(study: optuna.Study, n_best: int = 5):
    """Print the ``n_best`` best completed trials of a single-objective ``study``.

    Only COMPLETE trials are ranked (others have no final value). Ranking follows
    the study's direction: highest value first when maximizing, lowest when minimizing.
    """
    print("Optuna trials completed.")
    print("------ Best trials -----")
    # deepcopy=False avoids copying every trial just to read it.
    completed = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
    maximize = study.direction == optuna.study.StudyDirection.MAXIMIZE
    best_trials = sorted(completed, key=lambda t: t.value, reverse=maximize)[:n_best]
    if not best_trials:
        print("(no completed trials)")
    for t in best_trials:
        print_trial(t)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# start search (re-running this cell resumes the same study instead of failing)
import os

# SQLite cannot create missing parent directories for the DB file.
os.makedirs("optuna", exist_ok=True)
storage = "sqlite:///optuna/" + name + ".db"
# Optional `seed` in the YAML makes the sampler reproducible; absent -> unseeded.
sampler = optuna.samplers.TPESampler(multivariate=True, group=True, seed=conf.get("seed"))
# sampler = optuna.samplers.GPSampler(deterministic_objective=False) # better if not using pruner?
pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0, interval_steps=1)
study = optuna.create_study(direction="maximize", study_name=name, storage=storage, sampler=sampler, pruner=pruner, load_if_exists=True)
# Baseline parameters to evaluate first; skipped if already enqueued/run on a resume.
study.enqueue_trial({"sharpness": 1.0, "top_p": 0.995, "c": 0.25, "best_rate": 0.5, "prior": 0.6, "prior_weight": 1, "max_prior": 0.4, "eval_width": 1, "n_evals": 1, "n_tries": 2}, skip_if_exists=True)
study.optimize(objective, n_trials=n_trials)
print_best_trials(study)

[I 2025-07-11 16:09:06,809] A new study created in RDB with name: d_score_pfv_false


Trial 0: average_reward=0.291, top_p_average_reward=0.480, best_reward=0.547


[I 2025-07-11 16:09:50,677] Trial 0 finished with value: 0.5001970114049721 and parameters: {'sharpness': 1.0, 'top_p': 0.995, 'c': 0.25, 'best_rate': 0.5, 'prior': 0.6, 'prior_weight': 1, 'max_prior': 0.4, 'eval_width': 1, 'n_evals': 1, 'n_tries': 2}. Best is trial 0 with value: 0.5001970114049721.


Trial 1: average_reward=0.267, top_p_average_reward=0.454, best_reward=0.531


[I 2025-07-11 16:10:34,956] Trial 1 finished with value: 0.477386387395601 and parameters: {'sharpness': 0.8669840197758373, 'top_p': 0.9987506543175005, 'c': 0.12319895750585806, 'best_rate': 0.49356519524561115, 'prior': 0.5606676588548496, 'prior_weight': 1, 'max_prior': 0.29951772625559997, 'eval_width': 13, 'n_evals': 1, 'n_tries': 2}. Best is trial 0 with value: 0.5001970114049721.


Trial 2: average_reward=0.179, top_p_average_reward=0.400, best_reward=0.483


[I 2025-07-11 16:11:33,605] Trial 2 finished with value: 0.42507418656824114 and parameters: {'sharpness': 0.8039210287693077, 'top_p': 0.9963330023659503, 'c': 0.3234836920438077, 'best_rate': 0.8859287960358315, 'prior': 0.6720829526124672, 'prior_weight': 0, 'max_prior': 0.08596387494497874, 'eval_width': 31, 'n_evals': 3, 'n_tries': 3}. Best is trial 0 with value: 0.5001970114049721.
[I 2025-07-11 16:11:34,919] Trial 3 pruned. 


Trial 3 - Step 0: average_reward=0.226, top_p_average_reward=0.401, best_reward=0.415, params={'sharpness': 1.013273249445924, 'top_p': 0.9949751009223485, 'c': 0.21991906702918132, 'best_rate': 0.602180697829364, 'prior': 0.3189618604008091, 'prior_weight': 0, 'max_prior': 0.08101651103516669, 'eval_width': 5, 'n_evals': 2, 'n_tries': 2}


[I 2025-07-11 16:11:36,030] Trial 4 pruned. 


Trial 4 - Step 0: average_reward=0.201, top_p_average_reward=0.402, best_reward=0.412, params={'sharpness': 1.023209982350212, 'top_p': 0.995961427706737, 'c': 0.36408465131316076, 'best_rate': 0.9760740669789458, 'prior': 0.5177142340774937, 'prior_weight': 0, 'max_prior': 0.3027286150719825, 'eval_width': 8, 'n_evals': 2, 'n_tries': 1}


[I 2025-07-11 16:11:37,145] Trial 5 pruned. 


Trial 5 - Step 0: average_reward=0.163, top_p_average_reward=0.371, best_reward=0.422, params={'sharpness': 0.8891915437619239, 'top_p': 0.99472936333149, 'c': 0.056084194706174, 'best_rate': 0.6343012728659116, 'prior': 0.6640110747869536, 'prior_weight': 0, 'max_prior': 0.023415796874547933, 'eval_width': 25, 'n_evals': 1, 'n_tries': 2}


[I 2025-07-11 16:11:38,423] Trial 6 pruned. 


Trial 6 - Step 0: average_reward=0.183, top_p_average_reward=0.369, best_reward=0.398, params={'sharpness': 0.9306932583797423, 'top_p': 0.9960466310322268, 'c': 0.2583182154118876, 'best_rate': 0.3796608391653997, 'prior': 0.32278428423654976, 'prior_weight': 2, 'max_prior': 0.21445028579999642, 'eval_width': 7, 'n_evals': 9, 'n_tries': 2}


[I 2025-07-11 16:11:39,718] Trial 7 pruned. 


Trial 7 - Step 0: average_reward=0.168, top_p_average_reward=0.398, best_reward=0.430, params={'sharpness': 0.9038235291343043, 'top_p': 0.9987132879931269, 'c': 0.1632502000680769, 'best_rate': 0.6792233866081655, 'prior': 0.5299671343129426, 'prior_weight': 0, 'max_prior': 0.12160927852766332, 'eval_width': 6, 'n_evals': 10, 'n_tries': 2}


[I 2025-07-11 16:11:43,220] Trial 8 pruned. 


Trial 8 - Step 1: average_reward=0.225, top_p_average_reward=0.402, best_reward=0.442, params={'sharpness': 0.9363242133917015, 'top_p': 0.9940198890880976, 'c': 0.1621996041955174, 'best_rate': 0.24269591751143738, 'prior': 0.7267688688842345, 'prior_weight': 2, 'max_prior': 0.010190561112537378, 'eval_width': 11, 'n_evals': 10, 'n_tries': 1}


[I 2025-07-11 16:11:45,378] Trial 9 pruned. 


Trial 9 - Step 0: average_reward=0.210, top_p_average_reward=0.386, best_reward=0.421, params={'sharpness': 0.859364736683662, 'top_p': 0.9956101491949338, 'c': 0.09154931675770149, 'best_rate': 0.8461313765264119, 'prior': 0.5620594808283406, 'prior_weight': 0, 'max_prior': 0.29922904940198686, 'eval_width': 38, 'n_evals': 7, 'n_tries': 1}


[I 2025-07-11 16:11:46,977] Trial 10 pruned. 


Trial 10 - Step 0: average_reward=0.182, top_p_average_reward=0.352, best_reward=0.379, params={'sharpness': 0.9208795414163355, 'top_p': 0.9949016408949012, 'c': 0.33852523314243516, 'best_rate': 0.6999537671751707, 'prior': 0.7476172969429653, 'prior_weight': 2, 'max_prior': 0.2825513928737057, 'eval_width': 5, 'n_evals': 5, 'n_tries': 2}


[I 2025-07-11 16:11:48,417] Trial 11 pruned. 


Trial 11 - Step 0: average_reward=0.163, top_p_average_reward=0.395, best_reward=0.429, params={'sharpness': 0.8249809980683172, 'top_p': 0.99837943271985, 'c': 0.1772016434987026, 'best_rate': 0.18360343087036102, 'prior': 0.6313911895879695, 'prior_weight': 1, 'max_prior': 0.3369581984087986, 'eval_width': 8, 'n_evals': 1, 'n_tries': 2}


[I 2025-07-11 16:11:49,680] Trial 12 pruned. 


Trial 12 - Step 0: average_reward=0.178, top_p_average_reward=0.364, best_reward=0.382, params={'sharpness': 0.9491498807495848, 'top_p': 0.9947096962566511, 'c': 0.17984636622753464, 'best_rate': 0.20415443504399733, 'prior': 0.6202598459184269, 'prior_weight': 2, 'max_prior': 0.2017191322818818, 'eval_width': 5, 'n_evals': 2, 'n_tries': 2}


[I 2025-07-11 16:11:50,551] Trial 13 pruned. 


Trial 13 - Step 0: average_reward=0.177, top_p_average_reward=0.367, best_reward=0.380, params={'sharpness': 1.0969583763841246, 'top_p': 0.9945417123090554, 'c': 0.29734909597602366, 'best_rate': 0.5579393880765007, 'prior': 0.6346219505854832, 'prior_weight': 1, 'max_prior': 0.3631276611053328, 'eval_width': 1, 'n_evals': 1, 'n_tries': 2}


[I 2025-07-11 16:11:53,454] Trial 14 pruned. 


Trial 14 - Step 1: average_reward=0.173, top_p_average_reward=0.376, best_reward=0.479, params={'sharpness': 0.9349711125776244, 'top_p': 0.9985908077746981, 'c': 0.24000788721413968, 'best_rate': 0.6225717528847224, 'prior': 0.4030524170363848, 'prior_weight': 2, 'max_prior': 0.35091230184628824, 'eval_width': 13, 'n_evals': 2, 'n_tries': 2}


[I 2025-07-11 16:11:54,911] Trial 15 pruned. 


Trial 15 - Step 0: average_reward=0.212, top_p_average_reward=0.389, best_reward=0.439, params={'sharpness': 0.9479498166476279, 'top_p': 0.9944046682953256, 'c': 0.3080089375628622, 'best_rate': 0.5629902271305005, 'prior': 0.5852069971091104, 'prior_weight': 1, 'max_prior': 0.5862375606233109, 'eval_width': 7, 'n_evals': 4, 'n_tries': 3}


[I 2025-07-11 16:11:56,103] Trial 16 pruned. 


Trial 16 - Step 0: average_reward=0.184, top_p_average_reward=0.389, best_reward=0.406, params={'sharpness': 0.9361363702006338, 'top_p': 0.9986725763657498, 'c': 0.05782086083764572, 'best_rate': 0.5526047434242304, 'prior': 0.7693257032919613, 'prior_weight': 1, 'max_prior': 0.40757816789677015, 'eval_width': 20, 'n_evals': 1, 'n_tries': 2}


In [None]:
# continue search on the study configured in the YAML (the one `objective` evaluates).
# load_study does not persist the sampler/pruner, so pass the same setup as at creation;
# otherwise Optuna's defaults would silently be used for the resumed trials.
storage = "sqlite:///optuna/" + name + ".db"
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0, interval_steps=1)
study = optuna.load_study(study_name=name, storage=storage, sampler=sampler, pruner=pruner)
study.optimize(objective, n_trials=300)
print_best_trials(study)

In [None]:
# add parameters: copy the trials of an existing study into a new study/storage,
# adding a new search parameter so the sampler can keep using the old history.
SOURCE_STUDY = "d_score_200000"
source_storage = "sqlite:///optuna/" + SOURCE_STUDY + ".db"
study = optuna.load_study(study_name=SOURCE_STUDY, storage=source_storage)

sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0, interval_steps=1)
new_storage = "sqlite:///optuna/" + SOURCE_STUDY + "_new.db"
study_with_new_param = optuna.create_study(direction="maximize", study_name=name, storage=new_storage, sampler=sampler, pruner=pruner)

# The old trials ran without this parameter; record the value they effectively used.
new_param_name = "cut_failed_child"
new_param_value = False
new_param_dist = optuna.distributions.CategoricalDistribution([True, False])

# Skip RUNNING/WAITING trials (e.g. left over from an interrupted run) so they are
# not copied into the new study as never-finishing trials.
finished_states = (
    optuna.trial.TrialState.COMPLETE,
    optuna.trial.TrialState.PRUNED,
    optuna.trial.TrialState.FAIL,
)
for trial in study.get_trials(deepcopy=True, states=finished_states):
    trial.params = {**trial.params, new_param_name: new_param_value}
    trial.distributions = {**trial.distributions, new_param_name: new_param_dist}
    study_with_new_param.add_trial(trial)