In [None]:
import sys

import numpy as np

sys.path.append("../../")

from hyperparameter_optimization.hyperopt import save_search_space


import dill as pickle
from hyperopt import hp
from hyperopt.pyll import scope
from utils import CategoricalCrossentropyLoss, MSELoss, generate_layer_widths
import gymnasium as gym
import torch
from muzero.action_functions import action_as_onehot as action_function
from torch.optim import Adam, SGD

# Hyperopt search space for MuZero on CartPole.
# - Single-option `hp.choice(label, [value])` entries pin a value while still
#   recording it under a named label in every trial.
# - quniform / qloguniform entries are the dimensions actually searched.
# - The nested dicts inside the "optimizer" and "multi_process" choices are
#   flattened into top-level keys by `prep_params` before the agent is built.
search_space = {
    "kernel_initializer": hp.choice(
        "kernel_initializer",
        [
            "he_uniform",
            "he_normal",
            "glorot_uniform",
            "glorot_normal",
            "orthogonal",
        ],
    ),
    "optimizer": hp.choice(
        "optimizer",
        [
            {
                "optimizer": "adam",
                # "adam_epsilon": hp.qloguniform(
                #     "adam_epsilon", np.log(1e-8), np.log(0.5), 1e-8
                # ),
                # 10**-k for integer k in [1, 8]  ->  epsilon in {1e-1, ..., 1e-8}
                "adam_epsilon": 10 ** (-hp.quniform("adam_epsilon", 1, 8, 1)),
            },
            {
                "optimizer": "sgd",
                # 0.0 .. 0.9 in steps of 0.1
                "momentum": hp.quniform("momentum", 0, 0.9, 0.1),
            },
        ],
    ),
    # Dense-only network: every conv / residual setting is pinned to "none".
    # ("conv_layers" appears exactly once; a duplicate key would silently
    # overwrite this entry.)
    "conv_layers": hp.choice("conv_layers", [[]]),
    # 10**-k for integer k in [1, 4]  ->  {1e-1, 1e-2, 1e-3, 1e-4}
    "learning_rate": 10 ** (-hp.quniform("learning_rate", 1, 4, 1)),
    # presumably CartPole's episode-return range — TODO confirm against CartPoleConfig
    "known_bounds": hp.choice("known_bounds", [[0, 500]]),
    "residual_filters": hp.choice("residual_filters", [[]]),
    "residual_stacks": hp.choice("residual_stacks", [[]]),
    "actor_and_critic_conv_filters": hp.choice("actor_and_critic_conv_filters", [[]]),
    "reward_conv_layers": hp.choice("reward_conv_layers", [[]]),
    "actor_dense_layer_widths": hp.choice(
        "actor_dense_layer_widths", [[512], [256], [128], []]
    ),
    "critic_dense_layer_widths": hp.choice(
        "critic_dense_layer_widths", [[512], [256], [128], []]
    ),
    "reward_dense_layer_widths": hp.choice(
        "reward_dense_layer_widths", [[512], [256], [128], []]
    ),
    "dense_layer_widths": hp.choice(
        "dense_layer_widths", [[512, 64], [256, 64], [128, 64]]
    ),
    "noisy_sigma": hp.choice("noisy_sigma", [0.0]),
    "value_loss_factor": hp.choice("value_loss_factor", [1.0]),
    # 0.1 .. 5.0 in steps of 0.1
    "root_dirichlet_alpha": hp.quniform(
        "root_dirichlet_alpha", 0.1, 5.0, 0.1
    ),  # hp.choice("root_dirichlet_alpha", [0.3, 1.0, 2.0]),
    "root_exploration_fraction": hp.choice("root_exploration_fraction", [0.25]),
    # log-uniform over [25, 50] quantized to 25  ->  {25, 50}
    "num_simulations": scope.int(
        hp.qloguniform("num_simulations", np.log(25), np.log(50), 25)
    ),
    # "temperature_updates" is not sampled: prep_params derives it from
    # training_steps (at 1/3 and 2/3 of training).
    # "temperature_updates": hp.choice("temperature_updates", [[15000, 30000]]),
    "temperatures": hp.choice("temperatures", [[1.0, 0.5, 0.25]]),
    "temperature_with_training_steps": hp.choice(
        "temperature_with_training_steps", [True]
    ),
    "clip_low_prob": hp.choice("clip_low_prob", [0.0]),
    "pb_c_base": hp.choice("pb_c_base", [19652]),
    "pb_c_init": hp.choice("pb_c_init", [1.25]),
    "value_loss_function": hp.choice(
        "value_loss_function", [CategoricalCrossentropyLoss()]
    ),
    "reward_loss_function": hp.choice(
        "reward_loss_function", [CategoricalCrossentropyLoss()]
    ),
    "policy_loss_function": hp.choice(
        "policy_loss_function", [CategoricalCrossentropyLoss()]
    ),
    # {15000, 30000, 45000}
    "training_steps": scope.int(hp.quniform("training_steps", 15000, 45000, 15000)),
    # "minibatch_size": scope.int(
    #     hp.qloguniform("minibatch_size", np.log(8), np.log(64), 8)
    # ),
    # "min_replay_buffer_size": scope.int(
    #     hp.qloguniform("min_replay_buffer_size", np.log(1000), np.log(10000), 1000)
    # ),
    # "replay_buffer_size": scope.int(
    #     hp.qloguniform("replay_buffer_size", np.log(10000), np.log(200000), 10000)
    # ),
    # 2**k for integer k in [3, 7]  ->  {8, 16, 32, 64, 128}
    "minibatch_size": scope.int(2 ** (hp.quniform("minibatch_size", 3, 7, 1))),
    # log-uniform over [1000, 20000], quantized to multiples of 1000
    "min_replay_buffer_size": scope.int(
        hp.qloguniform("min_replay_buffer_size", np.log(1000), np.log(20000), 1000)
    ),
    # 10**k for integer k in [4, 6]  ->  {1e4, 1e5, 1e6}
    "replay_buffer_size": scope.int(10 ** (hp.quniform("replay_buffer_size", 4, 6, 1))),
    "unroll_steps": hp.choice("unroll_steps", [5]),
    "n_step": hp.choice("n_step", [10]),
    # integers 0 .. 10
    "clipnorm": scope.int(hp.quniform("clipnorm", 0, 10.0, 1)),
    "weight_decay": hp.choice("weight_decay", [1e-4]),
    "per_alpha": hp.choice("per_alpha", [0.0]),
    "per_beta": hp.choice("per_beta", [0.0]),
    "per_beta_final": hp.choice("per_beta_final", [0.0]),
    "per_epsilon": hp.choice("per_epsilon", [1e-4]),
    "action_function": hp.choice("action_function", [action_function]),
    "multi_process": hp.choice(
        "multi_process",
        [
            {
                "multi_process": True,
                # integers 1 .. 3
                "num_workers": scope.int(hp.quniform("num_workers", 1, 3, 1)),
            },
            # {
            #     "multi_process": False,
            #     "games_per_generation": scope.int(
            #         hp.qloguniform("games_per_generation", np.log(8), np.log(32), 8)
            #     ),
            # },
        ],
    ),
    "lr_ratio": hp.choice("lr_ratio", [float("inf")]),
    # {5, 10, 15}
    "support_range": scope.int(hp.quniform("support_range", 5, 15, 5)),
}

# No hand-picked starting configurations for this search.
initial_best_config = []

# save_search_space returns the (search_space, initial_best_config) pair used
# below; presumably it also persists them to disk (cf. the search_space.pkl /
# best_config.pkl paths in the run cell) — TODO confirm.
search_space, initial_best_config = save_search_space(search_space, initial_best_config)

In [None]:
def prep_params(params):
    """Flatten a hyperopt sample into the parameter dict used to build the agent config.

    ``params`` is mutated in place and also returned. The function:

    * adds empty ``residual_layers``, ``actor_conv_layers`` and
      ``critic_conv_layers`` (this search does not explore conv/residual layers);
    * replaces the nested ``multi_process`` choice with a plain boolean and adds
      either ``num_workers`` (multi-process) or ``games_per_generation``
      (single process);
    * replaces the nested ``optimizer`` choice with the torch optimizer class
      (``Adam`` or ``SGD``) and adds that optimizer's own hyperparameter
      (``adam_epsilon`` or ``momentum``);
    * sets ``temperature_updates`` to 1/3 and 2/3 of ``training_steps``.

    Args:
        params: one sample from the search space, with ``multi_process`` and
            ``optimizer`` still in their nested-dict form.

    Returns:
        The same ``params`` dict, flattened.

    Raises:
        ValueError: if the sampled optimizer name is neither ``"adam"`` nor ``"sgd"``.
    """
    params["residual_layers"] = []
    params["actor_conv_layers"] = []
    params["critic_conv_layers"] = []

    multi_process = params["multi_process"]
    if multi_process["multi_process"]:
        params["num_workers"] = multi_process["num_workers"]
        params["multi_process"] = True
    else:
        params["games_per_generation"] = multi_process["games_per_generation"]
        params["multi_process"] = False

    optimizer = params["optimizer"]
    if optimizer["optimizer"] == "adam":
        params["adam_epsilon"] = optimizer["adam_epsilon"]
        params["optimizer"] = Adam
    elif optimizer["optimizer"] == "sgd":
        params["momentum"] = optimizer["momentum"]
        params["optimizer"] = SGD
    else:
        # Fail here instead of leaving a raw dict in params["optimizer"], which
        # would only break later, far from the cause.
        raise ValueError(
            f"Unsupported optimizer {optimizer['optimizer']!r}; expected 'adam' or 'sgd'"
        )

    # Temperature schedule boundaries at one and two thirds of training.
    params["temperature_updates"] = [
        params["training_steps"] / 3,
        2 * params["training_steps"] / 3,
    ]
    return params

In [None]:
from agents.random import RandomAgent
from hyperparameter_optimization.hyperopt import (
    sarl_objective,
    set_sarl_config,
    SarlHyperoptConfig,
)
from hyperopt import atpe, tpe, fmin, space_eval
from hyperopt.exceptions import AllTrialsFailed

from muzero.muzero_agent_torch import MuZeroAgent
from agent_configs import MuZeroConfig
from game_configs import CartPoleConfig

search_space_path, initial_best_config_path = (
    "search_space.pkl",
    "best_config.pkl",
)
# search_space = pickle.load(open(search_space_path, "rb"))
# initial_best_config = pickle.load(open(initial_best_config_path, "rb"))
file_name = "cartpole_muzero"
# Where fmin checkpoints the Trials object; also where a previous run is resumed from.
trials_path = f"./{file_name}_trials.p"
# max_evals for a fresh run (1 trial — presumably a smoke-test setting; raise for a real search).
max_trials = 1
trials_step = 24  # how many additional trials to do after loading the last ones

set_sarl_config(
    SarlHyperoptConfig(
        file_name=file_name,
        eval_method="rolling_average",
        make_env=CartPoleConfig().make_env,
        prep_params=prep_params,
        agent_class=MuZeroAgent,
        agent_config=MuZeroConfig,
        game_config=CartPoleConfig,
        checkpoint_interval=50,
        test_interval=100,
        test_trials=25,
        last_n_rolling_avg=10,
        device="cpu",
    )
)

# Resume from a saved Trials object if one exists, extending max_trials by trials_step.
# Only a missing file means "start fresh": any other error (e.g. a corrupt pickle)
# propagates, so an existing trials file is never silently overwritten by a new run.
# NOTE: pickle/dill executes arbitrary code on load — only load trials files this
# notebook wrote itself.
try:
    with open(trials_path, "rb") as trials_file:
        trials = pickle.load(trials_file)
except FileNotFoundError:
    print("No saved Trials! Starting from scratch.")
    trials = None
else:
    print("Found saved Trials! Loading...")
    max_trials = len(trials.trials) + trials_step
    print(
        "Rerunning from {} trials to {} (+{}) trials".format(
            len(trials.trials), max_trials, trials_step
        )
    )

best = fmin(
    fn=sarl_objective,  # Objective Function to optimize
    space=search_space,  # Hyperparameter's Search Space
    algo=atpe.suggest,  # Optimization algorithm (Adaptive TPE)
    max_evals=max_trials,  # Total number of trials, including previously loaded ones
    trials=trials,  # Record the results (None -> fmin creates a fresh Trials)
    # early_stop_fn=no_progress_loss(5, 1),
    trials_save_file=trials_path,
    points_to_evaluate=initial_best_config,
    show_progressbar=False,
)
print(best)
# fmin reports hp.choice entries by index; space_eval maps them back to actual values.
best_trial = space_eval(search_space, best)