In [None]:
import random
from hrro_env_norm import HRROenv
from ray import air, tune
from ray.tune.schedulers import PopulationBasedTraining
from ray.rllib.algorithms.ppo import PPOConfig
from auxiliary_classes import (
    Membrane,
    Solution,
    DesignParameters,
    OperationParameters,
)


def explore(config):
    """Repair a PBT-perturbed PPO config so it is still valid for training.

    Passed to ``PopulationBasedTraining`` as ``custom_explore_fn``; it runs
    after the scheduler has resampled or perturbed the mutated
    hyperparameters. The dict is modified in place and also returned.

    Args:
        config: Trial config dict. Must contain ``train_batch_size``,
            ``sgd_minibatch_size`` and ``num_sgd_iter``. ``lambda`` is
            clamped only when present.

    Returns:
        The same ``config`` dict after the fixes below.
    """
    # A minibatch needs at least one sample. Repeatedly scaling an int
    # hyperparameter down can presumably truncate it to 0.
    if config["sgd_minibatch_size"] < 1:
        config["sgd_minibatch_size"] = 1
    # Ensure we collect enough timesteps to do SGD: at least two minibatches
    # per train batch.
    if config["train_batch_size"] < config["sgd_minibatch_size"] * 2:
        config["train_batch_size"] = config["sgd_minibatch_size"] * 2
    # Ensure we run at least one SGD iteration.
    if config["num_sgd_iter"] < 1:
        config["num_sgd_iter"] = 1
    # Perturbation multiplies the current value, so GAE lambda can drift
    # above 1.0 even though resampling stays in [0.9, 1.0]. GAE is only
    # defined for lambda in [0, 1], so clamp it back.
    if config.get("lambda", 0.0) > 1.0:
        config["lambda"] = 1.0
    return config


# Population Based Training scheduler. Perturbation is checked every 100
# units of ``time_total_s`` (total training seconds). Per the Ray PBT docs,
# a mutated value is resampled from the entries below with probability 0.25,
# otherwise perturbed. ``explore`` post-processes each new config.
pbt = PopulationBasedTraining(
    custom_explore_fn=explore,
    time_attr="time_total_s",
    resample_probability=0.25,
    perturbation_interval=100,
    # Hyperparameters PBT is allowed to change. Entry order is kept as-is,
    # since the scheduler walks this dict when drawing random numbers.
    hyperparam_mutations={
        # Continuous PPO knobs, resampled uniformly.
        "lambda": lambda: random.uniform(0.9, 1.0),
        "clip_param": lambda: random.uniform(0.01, 0.5),
        # Learning rate: a discrete set of candidates.
        "lr": [5e-4, 1e-4, 5e-5, 1e-5],
        # SGD schedule and batch sizes, resampled as integers.
        "num_sgd_iter": lambda: random.randint(1, 30),
        "sgd_minibatch_size": lambda: random.randint(128, 16_384),
        "train_batch_size": lambda: random.randint(2_000, 60_000),
    },
)

In [None]:
# RLlib PPO configuration, built step by step. Values wrapped in
# ``tune.choice`` are sampled per trial by Tune; the others are fixed
# starting values.
config_PPO = PPOConfig()

# Environment: HRROenv, parameterised by the project's membrane, solution,
# design and operation objects.
config_PPO = config_PPO.environment(
    HRROenv,
    env_config={
        'membrane': Membrane().membrane_xus180808_double(),
        'solution': Solution(),
        'design': DesignParameters.Nijhuis_BIA(),
        'operation': OperationParameters()
    }
)

# Sampling: 25 remote rollout workers with 2 envs each. An env is also
# created on the local worker.
config_PPO = config_PPO.rollouts(
    num_rollout_workers=25,
    num_envs_per_worker=2,
    create_env_on_local_worker=True
)

# TensorFlow 2 backend with eager tracing.
config_PPO = config_PPO.framework('tf2', eager_tracing=True)

# NOTE(review): one evaluation worker is requested but no
# `evaluation_interval` is set; RLlib may then never evaluate automatically
# during `train()`. Confirm this worker is actually used.
config_PPO = config_PPO.evaluation(evaluation_num_workers=1)

# One GPU for the learner and one CPU per rollout worker.
config_PPO = config_PPO.resources(num_gpus=1, num_cpus_per_worker=1)

config_PPO = config_PPO.experimental(_disable_action_flattening=True)

# PPO loss / SGD settings and an LSTM-wrapped fully connected model.
config_PPO = config_PPO.training(
    lr=tune.choice([5e-4, 1e-4, 5e-5, 1e-5]),
    kl_coeff=0.0,
    lambda_=0.99,
    clip_param=0.2,
    num_sgd_iter=tune.choice([3, 5, 10, 15, 30]),
    sgd_minibatch_size=tune.choice([128, 512, 1024, 2048, 4096]),
    train_batch_size=tune.choice([10000, 20000, 30000]),
    model={
        'fcnet_hiddens': [256, 256, 256, 128],
        'use_lstm': True,
        'lstm_cell_size': tune.choice([64, 128, 256]),
        'max_seq_len': tune.choice([20, 30, 40, 50]),
        'lstm_use_prev_action': True,
        'lstm_use_prev_reward': True,
    }
)

# Run the search: 100 samples of RLlib's "PPO" trainable under the PBT
# scheduler. Each trial stops after 200 training iterations and trials are
# ranked by maximising episode_reward_mean.
tuner = tune.Tuner(
    "PPO",
    param_space=config_PPO,
    run_config=air.RunConfig(stop={"training_iteration": 200}),
    tune_config=tune.TuneConfig(
        scheduler=pbt,
        num_samples=100,
        mode="max",
        metric="episode_reward_mean",
    ),
)
results = tuner.fit()

# Show the config of the best trial under the metric/mode configured above.
print("best hyperparameters: ", results.get_best_result().config)