In [1]:
""" Optuna example that optimizes the hyperparameters of
a reinforcement learning agent using the PPO implementation from Stable-Baselines3
on an OpenAI Gym environment.
This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo.
You can run this example as follows:
    $ python sb3_simple.py
"""
from typing import Any
from typing import Dict

import gym
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback
import torch
import torch.nn as nn


# Search budget: total trials, and how many random trials TPE/pruner wait for.
N_TRIALS = 100
N_STARTUP_TRIALS = 5

# Each trial trains for N_TIMESTEPS and is evaluated N_EVALUATIONS times,
# evenly spaced; each evaluation averages N_EVAL_EPISODES episodes.
N_EVALUATIONS = 2
N_TIMESTEPS = 1_000_000
EVAL_FREQ = N_TIMESTEPS // N_EVALUATIONS
N_EVAL_EPISODES = 3

ENV_ID = "LunarLander-v2"

# Fixed keyword arguments for the PPO constructor; sampled values are merged in.
DEFAULT_HYPERPARAMS = dict(policy="MlpPolicy", env=ENV_ID)


def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
    """Sample PPO hyperparameters for one Optuna trial.

    ``gamma`` and ``gae_lambda`` are searched as ``1 - x`` with ``x`` on a log
    scale, so values close to 1 are explored finely. The resulting true values
    (and the derived ``n_steps``) are stored as trial user attributes so they
    appear in the study results.

    Args:
        trial: The Optuna trial used to draw the hyperparameters.

    Returns:
        Keyword arguments for ``stable_baselines3.PPO``, including
        ``policy_kwargs`` with the sampled network layout and activation
        function.
    """
    n_steps = 2 ** trial.suggest_int("exponent_n_steps", 7, 11)
    gae_lambda = 1.0 - trial.suggest_float("gae_lambda", 0.001, 0.2, log=True)
    gamma = 1.0 - trial.suggest_float("gamma", 0.0001, 0.1, log=True)
    n_epochs = trial.suggest_categorical("n_epochs", [4, 10, 20])
    ent_coef = trial.suggest_float("ent_coef", 0.001, 0.1, log=True)
    # Network size and activation are categorical choices; the chosen names
    # are mapped to concrete policy settings below.
    net_arch_name = trial.suggest_categorical("net_arch", ["tiny", "small"])
    activation_name = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

    # Display true values
    trial.set_user_attr("gamma_", gamma)
    trial.set_user_attr("gae_lambda_", gae_lambda)
    trial.set_user_attr("n_steps", n_steps)

    # NOTE(review): the list-of-dict form matches the gym-based (pre-2.0)
    # Stable-Baselines3 this file imports; SB3 >= 2.0 expects a plain dict.
    net_arch = [
        {"pi": [64], "vf": [64]}
        if net_arch_name == "tiny"
        else {"pi": [64, 64], "vf": [64, 64]}
    ]

    activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU}[activation_name]

    return {
        "n_steps": n_steps,
        "gae_lambda": gae_lambda,
        "gamma": gamma,
        "n_epochs": n_epochs,
        "ent_coef": ent_coef,
        "policy_kwargs": {
            "net_arch": net_arch,
            "activation_fn": activation_fn,
        },
    }


class TrialEvalCallback(EvalCallback):
    """EvalCallback that reports every evaluation to an Optuna trial.

    Each time an evaluation is due (every ``eval_freq`` calls), the parent
    callback runs it. Its ``last_mean_reward`` is then reported to ``trial``
    as the intermediate value for step ``eval_idx``. Training is stopped, and
    ``is_pruned`` set, when the trial's pruner asks for it.
    """

    def __init__(
        self,
        eval_env: gym.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        # Number of evaluations reported so far; used as the Optuna step.
        self.eval_idx = 0
        # True once training was stopped because the pruner requested it.
        self.is_pruned = False

    def _on_step(self) -> bool:
        """Evaluate and report when due; return False to stop training."""
        if self.eval_freq <= 0 or self.n_calls % self.eval_freq != 0:
            return True

        # The parent performs the evaluation and returns its own
        # continue/stop decision; honour it instead of discarding it.
        continue_training = super()._on_step()
        self.eval_idx += 1
        self.trial.report(self.last_mean_reward, self.eval_idx)
        # Prune trial if need
        if self.trial.should_prune():
            self.is_pruned = True
            return False
        return continue_training


def objective(trial: optuna.Trial) -> float:
    """Train a PPO agent with hyperparameters sampled for ``trial``.

    Args:
        trial: The Optuna trial driving this run.

    Returns:
        The mean reward of the last evaluation. Returns ``nan`` when training
        failed on invalid (e.g. NaN) values, which tells the optimizer the
        trial failed.

    Raises:
        optuna.exceptions.TrialPruned: If the pruner stopped training early.
    """
    kwargs = DEFAULT_HYPERPARAMS.copy()
    # Sample hyperparameters
    kwargs.update(sample_ppo_params(trial))
    # Create the RL model
    model = PPO(**kwargs)
    # Create env used for evaluation
    eval_env = gym.make(ENV_ID)
    # Create the callback that will periodically evaluate
    # and report the performance
    eval_callback = TrialEvalCallback(
        eval_env,
        trial,
        n_eval_episodes=N_EVAL_EPISODES,
        eval_freq=EVAL_FREQ,
        deterministic=True,
    )

    nan_encountered = False
    try:
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except (AssertionError, ValueError) as e:
        # Random hyperparameters can drive the networks to NaN. Depending on
        # library versions this surfaces as an AssertionError or as a
        # ValueError from torch.distributions argument validation.
        print(e)
        nan_encountered = True
    finally:
        # Free memory
        model.env.close()
        eval_env.close()

    # Tell the optimizer that the trial failed
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward


if __name__ == "__main__":
    # Set pytorch num threads to 1 for faster training
    torch.set_num_threads(1)

    sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS)
    # Do not prune before 1/3 of the max budget is used
    pruner = MedianPruner(
        n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
    )

    study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
    try:
        study.optimize(objective, n_trials=N_TRIALS, timeout=600)
    except KeyboardInterrupt:
        # Ctrl+C ends the search early; report whatever has finished.
        pass

    print("Number of finished trials: ", len(study.trials))

    # study.best_trial raises if no trial completed (e.g. all failed or were
    # pruned, or the search was interrupted before the first one finished).
    has_completed = any(
        t.state == optuna.trial.TrialState.COMPLETE for t in study.trials
    )
    if not has_completed:
        print("No trial completed; there is no best trial to report.")
    else:
        print("Best trial:")
        trial = study.best_trial

        print("  Value: ", trial.value)

        print("  Params: ")
        for key, value in trial.params.items():
            print(f"    {key}: {value}")

        print("  User attrs:")
        for key, value in trial.user_attrs.items():
            print(f"    {key}: {value}")


  from .autonotebook import tqdm as notebook_tqdm
[32m[I 2022-10-08 20:29:17,025][0m A new study created in memory with name: no-name-fc7c1b30-09a6-4765-810a-cf2e88fe643e[0m
[32m[I 2022-10-08 20:29:19,631][0m Trial 0 finished with value: 500.0 and parameters: {'gamma': 0.0007415827707315834, 'max_grad_norm': 2.3344868693397762, 'gae_lambda': 0.008729403961766975, 'exponent_n_steps': 6, 'lr': 0.0017950518642776102, 'ent_coef': 1.6124087656990934e-07, 'ortho_init': False, 'net_arch': 'tiny', 'activation_fn': 'tanh'}. Best is trial 0 with value: 500.0.[0m
[32m[I 2022-10-08 20:29:22,139][0m Trial 1 finished with value: 500.0 and parameters: {'gamma': 0.030745055807499456, 'max_grad_norm': 0.7223464037686391, 'gae_lambda': 0.0013421773632395389, 'exponent_n_steps': 8, 'lr': 0.013639642894899533, 'ent_coef': 3.1498651377608986e-07, 'ortho_init': False, 'net_arch': 'tiny', 'activation_fn': 'tanh'}. Best is trial 0 with value: 500.0.[0m
[32m[I 2022-10-08 20:29:24,719][0m Trial 2 fin

Number of finished trials:  100
Best trial:
  Value:  500.0
  Params: 
    gamma: 0.0007415827707315834
    max_grad_norm: 2.3344868693397762
    gae_lambda: 0.008729403961766975
    exponent_n_steps: 6
    lr: 0.0017950518642776102
    ent_coef: 1.6124087656990934e-07
    ortho_init: False
    net_arch: tiny
    activation_fn: tanh
  User attrs:
    gamma_: 0.9992584172292684
    gae_lambda_: 0.991270596038233
    n_steps: 64
