In [None]:
# Path setup: put the repository root on the import path so that project modules
# (presumably `utils`, imported in the next cell) can be found from this notebook.
# NOTE(review): "../../" is resolved against the kernel's working directory, which is
# normally the notebook's own folder — confirm if the notebook is launched elsewhere.
import sys

repo_root = "../../"
if sys.path.count(repo_root) == 0:
    # Prepend (not append) so repository modules win over same-named installed packages.
    sys.path.insert(0, repo_root)

In [None]:
import gc
import os

import optuna
from utils import generator_from_conf, conf_from_yaml

yaml_path = "config/optuna/d_score.yaml"

conf = conf_from_yaml(yaml_path)
# The study name is used to build the SQLite storage path, so fail fast with a clear
# message instead of a later `TypeError` from concatenating None into the URL.
name = conf.get("study_name")
if not name:
    raise KeyError(f"'study_name' must be set in {yaml_path}")
# Passed to study.optimize; Optuna treats n_trials=None as "no trial limit".
n_trials = conf.get("n_trials")

def objective(trial):
    """Optuna objective: build a generator from the YAML config with sampled hyperparameters, run it, and score it.

    The config is re-read from ``yaml_path`` on every call, so values written by one trial
    never leak into the next. Generation is split into ``n_steps`` equal chunks. After each
    chunk an intermediate value is reported to Optuna, so the pruner can stop unpromising
    trials early.

    Score = (1 - best_reward_rate) * top-p average reward + best_reward_rate * best reward

    Args:
        trial: the trial object supplied by ``study.optimize``.

    Returns:
        The score after the final step (higher is better).

    Raises:
        optuna.TrialPruned: when the pruner stops this trial after an intermediate report.
        ValueError: if ``n_steps`` in the config is not a positive integer.
    """
    conf = conf_from_yaml(yaml_path)

    conf.setdefault("transition_args", {})
    # conf["transition_args"]["model_dir"] = "model/smiles/" + trial.suggest_categorical("model", ["drugs_zinc_gru", "drugs_zinc_lstm", "tf25_ported"])
    conf["transition_args"]["sharpness"] = trial.suggest_float("sharpness", 0.8, 1.1)
    conf["transition_args"]["top_p"] = trial.suggest_float("top_p", 0.994, 0.999)

    conf.setdefault("policy_args", {})
    conf["policy_args"]["c"] = trial.suggest_float("c", 0.05, 0.4)
    conf["policy_args"]["best_rate"] = trial.suggest_float("best_rate", 0, 1)
    conf["policy_args"]["max_prior"] = trial.suggest_float("max_prior", 0, 0.6)

    conf.setdefault("generator_args", {})
    conf["generator_args"]["eval_width"] = trial.suggest_int("eval_width", 1, 40)
    conf["generator_args"]["n_evals"] = trial.suggest_int("n_evals", 1, 10)
    conf["generator_args"]["n_tries"] = trial.suggest_int("n_tries", 1, 3)
    # Currently disabled search dimensions:
    # conf["generator_args"]["filter_reward"] = trial.suggest_categorical("filter_reward", ["ignore", 0])
    # conf["generator_args"]["terminal_reward"] = trial.suggest_categorical("terminal_reward", ["ignore", -1])

    max_generations, time_limit = conf.get("max_generations"), conf.get("time_limit")
    n_steps, best_reward_rate = conf.get("n_steps"), conf.get("best_reward_rate")
    if not isinstance(n_steps, int) or n_steps < 1:
        # With n_steps == 0 the loop below would never run and the rewards would be unbound.
        raise ValueError(f"n_steps must be a positive integer, got {n_steps!r}")
    average_top_p = conf.get("average_top_p", 0.1)  # loop-invariant, read once

    generator = generator_from_conf(conf)
    try:
        generator.logger.info(f"params={trial.params}")

        for step in range(n_steps):
            # Each step gets an equal share of the total generation and time budget.
            generator.generate(max_generations=max_generations / n_steps, time_limit=time_limit / n_steps)
            average_reward = generator.average_reward()
            top_p_average_reward = generator.average_reward(top_p=average_top_p)
            best_reward = generator.best_reward
            value = (1 - best_reward_rate) * top_p_average_reward + best_reward_rate * best_reward
            trial.report(value, step)
            if trial.should_prune():
                print(f"Trial {trial.number} - Step {step}: average_reward={average_reward:.3f}, top_p_average_reward={top_p_average_reward:.3f}, best_reward={best_reward:.3f}, params={trial.params}")
                raise optuna.TrialPruned()

        trial.set_user_attr("average_reward", average_reward)
        trial.set_user_attr("top_p_average_reward", top_p_average_reward)
        trial.set_user_attr("best_reward", best_reward)
        print(f"Trial {trial.number}: average_reward={average_reward:.3f}, top_p_average_reward={top_p_average_reward:.3f}, best_reward={best_reward:.3f}")
        generator.analyze()
        # generator.plot(**conf.get("plot_args"))
        # The score is captured in a local, so it stays valid after `generator` is released
        # below. Previously it was read from the deleted `generator`, which raised.
        return value
    finally:
        # Release the generator on every exit path (success, pruning, or error) before the next trial.
        del generator
        gc.collect()
    
def print_trial(trial: optuna.trial.FrozenTrial):
    """Print a one-line summary of a stored trial: number, objective value, user attributes and params.

    Args:
        trial: a trial as stored in a study (e.g. an element of ``study.trials``).
            ``trial.value`` may be None for a trial that did not complete. It is then
            shown as "None" instead of raising ``TypeError`` in the ``:.3f`` format.
    """
    score = "None" if trial.value is None else f"{trial.value:.3f}"
    print(f"Trial {trial.number} score={score}, attrs={trial.user_attrs}, params={trial.params}")
    
def print_best_trials(study: optuna.Study, top_n: int = 5):
    """Print the ``top_n`` best completed trials of a single-objective ``study``, best first.

    Only COMPLETE trials are ranked. The ordering follows the study's direction: descending
    for "maximize" and ascending for "minimize".

    Args:
        study: the Optuna study to summarize.
        top_n: how many trials to print (default 5).
    """
    print("Optuna trials completed.")
    print("------ Best trials -----")
    # deepcopy=False: the trials are only read here, so skip copying every trial.
    completed = study.get_trials(deepcopy=False, states=(optuna.trial.TrialState.COMPLETE,))
    maximize = study.direction == optuna.study.StudyDirection.MAXIMIZE
    best_trials = sorted(completed, key=lambda t: t.value, reverse=maximize)[:top_n]
    for t in best_trials:
        print_trial(t)

In [None]:
# Start search (or resume it: load_if_exists lets this cell be re-run without DuplicatedStudyError)
import os

os.makedirs("optuna", exist_ok=True)  # SQLite does not create a missing parent directory
storage = "sqlite:///optuna/" + name + ".db"
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=0, interval_steps=1)
study = optuna.create_study(direction="maximize", study_name=name, storage=storage, sampler=sampler, pruner=pruner, load_if_exists=True)
# study.enqueue_trial({"sharpness": 1.0, "top_p": 0.995, "c": 0.25, "best_rate": 0.5, "max_prior": 0.4, "eval_width": 1, "n_evals": 1, "n_tries": 2})
study.optimize(objective, n_trials=n_trials)
print_best_trials(study)

In [None]:
# Continue search: reopen an existing study from its SQLite file and run more trials.
# NOTE(review): the study name is hard-coded here and may differ from `name` read from the YAML.
continue_study_name = "d_score_200000"
continue_storage = "sqlite:///optuna/" + continue_study_name + ".db"
study = optuna.load_study(study_name=continue_study_name, storage=continue_storage)
study.optimize(objective, n_trials=300)
print_best_trials(study)

In [None]:
# Add parameters: copy every finished trial of an existing study into a new study. Each trial is
# extended with a new categorical parameter ("filter_reward", fixed to "ignore"), so the sampler
# can warm-start on the enlarged search space.
source_storage = "sqlite:///optuna/d_score_200000.db"  # a single "sqlite:///" prefix + relative path
study = optuna.study.load_study(study_name="d_score_200000", storage=source_storage)

sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=0, interval_steps=1)
new_storage = "sqlite:///optuna/d_score_200000_new.db"
study_with_new_param = optuna.create_study(direction="maximize", study_name=name, storage=new_storage, sampler=sampler, pruner=pruner)

# Should match the distribution objective() suggests for "filter_reward" (currently commented out
# there), so copied and newly sampled trials describe the same parameter.
filter_reward_dist = optuna.distributions.CategoricalDistribution(["ignore", 0])

for old_trial in study.trials:  # `study.trials` returns copies, so rebinding their fields is safe
    # Skip unfinished trials. WAITING ones have no start time and fail add_trial's validation.
    # RUNNING ones (e.g. left by an interrupted run) would appear permanently running in the new study.
    if not old_trial.state.is_finished():
        continue
    old_trial.params = {**old_trial.params, "filter_reward": "ignore"}
    old_trial.distributions = {**old_trial.distributions, "filter_reward": filter_reward_dist}
    study_with_new_param.add_trial(old_trial)