In [1]:
# Path setup
import sys
# Prepend the relative repository root to the module search path,
# unless that exact entry is already listed.
repo_root = "../../"
if sys.path.count(repo_root) == 0:
    sys.path.insert(0, repo_root)

In [2]:
# Oracle names that are scored with the GuacaMol reward class.
guacamol_oracle_names = [
    "zaleplon_mpo",
    "isomers_c7h8n2o2",
    "isomers_c9h10n2o2pf2cl",
    "troglitazone_rediscovery",
    "median1",
    "sitagliptin_mpo",
    "thiothixene_rediscovery",
    "deco_hop",
    "albuterol_similarity",
    "scaffold_hop",
    "amlodipine_mpo",
    "celecoxib_rediscovery",
    "fexofenadine_mpo",
    "median2",
    "mestranol_similarity",
    "perindopril_mpo",
    "osimertinib_mpo",
    "ranolazine_mpo",
    "valsartan_smarts",
]
# Oracle names that are scored with the TDC reward class.
tdc_oracle_names = [
    "drd2",
    "gsk3b",
    "jnk3",
    "qed",
]
# Full evaluation order: GuacaMol oracles first, then TDC oracles.
oracle_names = [*guacamol_oracle_names, *tdc_oracle_names]
    
def reward_class_name_from_oracle_name(oracle_name: str) -> str:
    """Return the name of the reward class to use for ``oracle_name``.

    Names listed in ``tdc_oracle_names`` map to ``"TDCReward"``; every other
    name maps to ``"GuacaMolReward"``.
    """
    is_tdc_oracle = oracle_name in tdc_oracle_names
    return "TDCReward" if is_tdc_oracle else "GuacaMolReward"

In [None]:
import os
import optuna
from node import SMILESStringNode
from utils import generator_from_conf, conf_from_yaml

# Base configuration file loaded afresh at the start of every trial.
yaml_path = "config/optuna/mol_opt_v2_repli.yaml"

def objective(trial):
    """Optuna objective: sum of top-10 AUC over every oracle in ``oracle_names``.

    Loads the base configuration from ``yaml_path``, overrides the policy
    parameter ``c`` and the generator parameter ``n_evals`` with values
    suggested by ``trial``, then builds and runs one generator per oracle.
    Each oracle's AUC is stored as a trial user attribute under the oracle's
    name, and the running sum is reported to Optuna after each oracle so the
    pruner can stop unpromising trials early.

    Args:
        trial: The Optuna trial supplying parameter suggestions and receiving
            intermediate reports and user attributes.

    Returns:
        The sum of the per-oracle top-10 AUC values.

    Raises:
        optuna.TrialPruned: If a run ends with fewer unique keys than the
            generation budget, if the pruner asks to stop, or if any other
            exception occurs (the traceback is printed and the trial is then
            marked as pruned instead of failing the whole study).
    """
    import traceback

    # Single budget shared by the generation cap, the unique-key threshold and
    # the AUC oracle-call cap, which all use the same value in this notebook.
    budget = 10000

    try:
        conf = conf_from_yaml(yaml_path)
        conf.setdefault("transition_args", {})

        conf.setdefault("policy_args", {})
        conf["policy_args"]["c"] = trial.suggest_float("c", 0.01, 0.4)

        conf.setdefault("generator_args", {})
        conf["generator_args"]["n_evals"] = trial.suggest_int("n_evals", 1, 5)

        sum_auc = 0
        for i, oracle_name in enumerate(oracle_names):
            # Point the shared configuration at the current oracle and give
            # each (trial, oracle) pair its own output directory.
            conf["reward_class"] = reward_class_name_from_oracle_name(oracle_name)
            conf["reward_args"] = {"objective": oracle_name}
            conf["output_dir"] = "generation_result" + os.sep + "trial_" + str(trial.number) + os.sep + oracle_name

            generator = generator_from_conf(conf)
            generator.logger.info("reward="+oracle_name)
            generator.logger.info(f"params={trial.params}")
            generator.generate(max_generations=budget, time_limit=conf.get("time_limit", 1000))

            # A run that did not reach the full budget of unique keys is
            # treated as unusable and the whole trial is pruned.
            if len(generator.unique_keys) < budget:
                raise optuna.TrialPruned()

            auc = generator.auc(top_k=10, max_oracle_calls=budget, finish=True)
            trial.set_user_attr(oracle_name, auc)
            generator.logger.info("top_10_auc: " + str(auc))

            sum_auc += auc
            # Report the running sum, using the oracle index as the step.
            trial.report(sum_auc, i)
            if trial.should_prune():
                print(f"{oracle_name} Trial {trial.number} - Step {i}: intermediate_score={sum_auc:.3f}, params={trial.params}")
                raise optuna.TrialPruned()

        trial.set_user_attr("sum_top_10_auc", sum_auc)
        # NOTE: `oracle_name` here is simply the last oracle evaluated; it is
        # kept in the message so existing log output stays unchanged.
        print(f"{oracle_name} Trial {trial.number}: sum_top_10_auc={sum_auc:.3f}, aucs={trial.user_attrs}")
        return sum_auc
    except optuna.TrialPruned:
        # Deliberate pruning: propagate unchanged.
        raise
    except Exception as exc:
        # Keep the study running, but do not hide the failure: show the
        # traceback and carry the cause into the prune message.
        traceback.print_exc()
        raise optuna.TrialPruned(f"Trial {trial.number} pruned after {type(exc).__name__}: {exc}") from exc

In [None]:
# Create the study (or reopen it if it already exists) and start the search.
name = "mol_opt_rnn_only"
# SQLite cannot create the database file if its parent directory is missing.
os.makedirs("optuna", exist_ok=True)
storage = "sqlite:///optuna/" + name + ".db"
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2, interval_steps=1)
study = optuna.create_study(direction="maximize", study_name=name, storage=storage, sampler=sampler, pruner=pruner, load_if_exists=True)
# Seed the search with hand-picked starting points. The keys must match the
# parameter names suggested inside `objective` ("c" and "n_evals"); a key that
# is never suggested is not applied to the trial.
study.enqueue_trial({"c": 0.1, "n_evals": 3})
study.enqueue_trial({"c": 0.282842712475, "n_evals": 3})
study.enqueue_trial({"c": 0.01, "n_evals": 3})
study.optimize(objective, n_trials=100000)

In [None]:
# Continue a previously started search from the stored study.
name = "mol_opt_rnn_only"
storage = "sqlite:///optuna/" + name + ".db"
# The sampler and pruner are not restored from storage; when omitted,
# `load_study` falls back to Optuna's defaults. Recreate the same sampler and
# pruner used when the study was created so the resumed search behaves the
# same way.
sampler = optuna.samplers.TPESampler(multivariate=True, group=True)
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=2, interval_steps=1)
study = optuna.study.load_study(study_name=name, storage=storage, sampler=sampler, pruner=pruner)
study.optimize(objective, n_trials=30000)