# Hyperparameter tuning with Optuna

GitHub repo: https://github.com/araffin/tools-for-robotic-rl-icra2022

Optuna: https://github.com/optuna/optuna

Stable-Baselines3: https://github.com/DLR-RM/stable-baselines3

Documentation: https://stable-baselines3.readthedocs.io/en/master/

SB3 Contrib: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib

RL Baselines3 zoo: https://github.com/DLR-RM/rl-baselines3-zoo

[RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) is a collection of pre-trained Reinforcement Learning agents using Stable-Baselines3.

It also provides basic scripts for training, evaluating agents, tuning hyperparameters and recording videos.


## Introduction

In this notebook, you will learn the importance of tuning
hyperparameters. You will first try to optimize the parameters
manually and then we will see how to automate the search using Optuna.


## Install Dependencies and Stable Baselines3 Using Pip

The full list of dependencies can be found in the [README](https://github.com/DLR-RM/stable-baselines3).


```
pip install stable-baselines3[extra]
```

In [1]:
!pip install stable-baselines3

In [2]:
# Optional: install SB3 contrib to have access to additional algorithms
!pip install sb3-contrib





In [3]:
# Optuna will be used in the last part when doing hyperparameter tuning
!pip install optuna

## Imports

In [4]:
import gymnasium as gym
import numpy as np

The first thing you need to import is the RL model. Check the documentation to see which algorithm can be used on which problem.

In [5]:
from stable_baselines3 import PPO, A2C, SAC, TD3, DQN

In [6]:
# Algorithms from the contrib repo
# https://github.com/Stable-Baselines-Team/stable-baselines3-contrib
from sb3_contrib import QRDQN, TQC

In [7]:
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

# Part I: The Importance Of Tuned Hyperparameters



When compared with Supervised Learning, Deep Reinforcement Learning is
far more sensitive to the choice of hyperparameters such as learning
rate, number of neurons, number of layers, optimizer, etc.

A poor choice of hyperparameters can lead to poor or unstable convergence. This challenge is compounded by the variability in performance across random seeds (used to initialize the network weights and the environment).

In addition to hyperparameters, selecting the appropriate algorithm is
also an important choice. We will demonstrate this on the simple
Pendulum task.

See the [Gym documentation](https://gym.openai.com/envs/Pendulum-v0/): "The inverted
pendulum swingup problem is a classic problem in the control
literature. In this version of the problem, the pendulum starts in a
random position, and the goal is to swing it up so it stays upright."


Let's first try PPO with a small budget of 4000 steps (20 episodes, as each Pendulum episode lasts 200 steps):

In [8]:
env_id = "Pendulum-v1"
# Env used only for evaluation
eval_envs = make_vec_env(env_id, n_envs=10)
# 4000 training timesteps
budget_pendulum = 4000

### PPO

In [9]:
ppo_model = PPO("MlpPolicy", env_id, seed=0, verbose=0).learn(budget_pendulum)

In [10]:
mean_reward, std_reward = evaluate_policy(ppo_model, eval_envs, n_eval_episodes=100, deterministic=True)

print(f"PPO Mean episode reward: {mean_reward:.2f} +/- {std_reward:.2f}")

PPO Mean episode reward: -1144.99 +/- 250.23


### A2C

In [12]:
# Define and train an A2C model
a2c_model = A2C("MlpPolicy", env_id, verbose=0).learn(budget_pendulum)


In [13]:
# Evaluate the trained A2C model
mean_reward, std_reward = evaluate_policy(a2c_model, eval_envs, n_eval_episodes=100, deterministic=True)

print(f"A2C Mean episode reward: {mean_reward:.2f} +/- {std_reward:.2f}")

A2C Mean episode reward: -1241.39 +/- 112.48


Both are far from solving the environment (a good policy reaches a mean episode reward of around -200).
Now, let's try with an off-policy algorithm:

### Training PPO longer?

Maybe training longer would help?

You can try with 10x the budget, but for A2C and PPO, training longer won't help much here; better hyperparameters are needed instead.

In [14]:
# train longer
new_budget = 10 * budget_pendulum

ppo_model = PPO("MlpPolicy", env_id, seed=0, verbose=0).learn(new_budget)

In [15]:
mean_reward, std_reward = evaluate_policy(ppo_model, eval_envs, n_eval_episodes=100, deterministic=True)

print(f"PPO Mean episode reward: {mean_reward:.2f} +/- {std_reward:.2f}")

PPO Mean episode reward: -1266.26 +/- 304.38


### PPO - Tuned Hyperparameters

Using Optuna, we can in fact tune the hyperparameters and find a working solution (from the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml)):

In [16]:
tuned_params = {
    "gamma": 0.9,
    "use_sde": True,
    "sde_sample_freq": 4,
    "learning_rate": 1e-3,
}

# Note: this run uses a larger budget (50,000 steps instead of 4,000)
ppo_tuned_model = PPO("MlpPolicy", env_id, seed=1, verbose=1, **tuned_params).learn(50_000, log_interval=5)

Using cuda device
Creating environment from the given name 'Pendulum-v1'
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 200         |
|    ep_rew_mean          | -1.17e+03   |
| time/                   |             |
|    fps                  | 833         |
|    iterations           | 5           |
|    time_elapsed         | 12          |
|    total_timesteps      | 10240       |
| train/                  |             |
|    approx_kl            | 0.028126108 |
|    clip_fraction        | 0.243       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.61       |
|    explained_variance   | 0.808       |
|    learning_rate        | 0.001       |
|    loss                 | 18.8        |
|    n_updates            | 40          |
|    policy_gradient_loss | -0.0271     |
|    std                  | 0.923       |
|    value_loss           | 44.2        |
-----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 200         |
|    ep_rew_mean          | -1.06e+03   |
| time/                   |             |
|    fps                  | 797         |
|    iterations           | 10          |
|    time_elapsed         | 25          |
|    total_timesteps      | 20480       |
| train/                  |             |
|    approx_kl            | 0.031463865 |
|    clip_fraction        | 0.268       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.28       |
|    explained_variance   | 0.968       |
|    learning_rate        | 0.001       |
|    loss                 | 3.57        |
|    n_updates            | 90          |
|    policy_gradient_loss | -0.0427     |
|    std                  | 0.49        |
|    value_loss           | 8.6         |
-----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 200        |
|    ep_rew_mean          | -686       |
| time/                   |            |
|    fps                  | 786        |
|    iterations           | 15         |
|    time_elapsed         | 39         |
|    total_timesteps      | 30720      |
| train/                  |            |
|    approx_kl            | 0.01979826 |
|    clip_fraction        | 0.24       |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.44      |
|    explained_variance   | 0.994      |
|    learning_rate        | 0.001      |
|    loss                 | 1.02       |
|    n_updates            | 140        |
|    policy_gradient_loss | -0.00708   |
|    std                  | 0.284      |
|    value_loss           | 2.21       |
----------------------------------------


----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 200        |
|    ep_rew_mean          | -325       |
| time/                   |            |
|    fps                  | 784        |
|    iterations           | 20         |
|    time_elapsed         | 52         |
|    total_timesteps      | 40960      |
| train/                  |            |
|    approx_kl            | 0.03706546 |
|    clip_fraction        | 0.243      |
|    clip_range           | 0.2        |
|    entropy_loss         | -1.25      |
|    explained_variance   | 0.996      |
|    learning_rate        | 0.001      |
|    loss                 | 0.437      |
|    n_updates            | 190        |
|    policy_gradient_loss | 0.00105    |
|    std                  | 0.229      |
|    value_loss           | 1.01       |
----------------------------------------


-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 200         |
|    ep_rew_mean          | -215        |
| time/                   |             |
|    fps                  | 784         |
|    iterations           | 25          |
|    time_elapsed         | 65          |
|    total_timesteps      | 51200       |
| train/                  |             |
|    approx_kl            | 0.022205412 |
|    clip_fraction        | 0.285       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.14       |
|    explained_variance   | 0.996       |
|    learning_rate        | 0.001       |
|    loss                 | 0.185       |
|    n_updates            | 240         |
|    policy_gradient_loss | -0.0165     |
|    std                  | 0.2         |
|    value_loss           | 0.698       |
-----------------------------------------


In [17]:
mean_reward, std_reward = evaluate_policy(ppo_tuned_model, eval_envs, n_eval_episodes=100, deterministic=True)

print(f"Tuned PPO Mean episode reward: {mean_reward:.2f} +/- {std_reward:.2f}")

Tuned PPO Mean episode reward: -193.94 +/- 105.82


Note: if you try SAC on the simple MountainCarContinuous environment, you will encounter some issues without tuned hyperparameters: https://github.com/rail-berkeley/softlearning/issues/76

Simple environments can be challenging even for SOTA algorithms.

# Part II: Grad Student Descent


### Challenge (10 minutes): "Grad Student Descent"
The challenge is to find the best hyperparameters (max performance) for A2C on `CartPole-v1` with a limited budget of 20,000 training steps.


Maximum reward: 500 on `CartPole-v1`

The hyperparameters should work for different random seeds.

In [18]:
budget = 20_000

#### The baseline: default hyperparameters

In [19]:
eval_envs_cartpole = make_vec_env("CartPole-v1", n_envs=10)

In [20]:
model = A2C("MlpPolicy", "CartPole-v1", seed=8, verbose=1).learn(budget)

Using cuda device
Creating environment from the given name 'CartPole-v1'
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 20.2     |
|    ep_rew_mean        | 20.2     |
| time/                 |          |
|    fps                | 637      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -0.688   |
|    explained_variance | -0.0129  |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 1.95     |
|    value_loss         | 8.69     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 19.9     |
|    ep_rew_mean        | 19.9     |
| time/                 |          |
|    fps                | 679      |
|    iterations         | 200      |
|    time_elapsed       | 1        |
|    total_timesteps    | 1000     |
| train/                |          |
|    entropy_loss       | -0.693   |
|    explained_variance | 0.178    |
|    learning_rate      | 0.0007   |
|    n_updates          | 199      |
|    policy_loss        | -5.49    |
|    value_loss         | 88.4     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 19.6     |
|    ep_rew_mean        | 19.6     |
| time/                 |          |
|    fps                | 697      |
|    iterations         | 300      |
|    time_elapsed       | 2        |
|    total_timesteps    | 1500     |
| train/                |          |
|    entropy_loss       | -0.691   |
|    explained_variance | -0.0665  |
|    learning_rate      | 0.0007   |
|    n_updates          | 299      |
|    policy_loss        | 1.81     |
|    value_loss         | 7.53     |
------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 19.7      |
|    ep_rew_mean        | 19.7      |
| time/                 |           |
|    fps                | 702       |
|    iterations         | 400       |
|    time_elapsed       | 2         |
|    total_timesteps    | 2000      |
| train/                |           |
|    entropy_loss       | -0.69     |
|    explained_variance | -0.000388 |
|    learning_rate      | 0.0007    |
|    n_updates          | 399       |
|    policy_loss        | 1.65      |
|    value_loss         | 6.52      |
-------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 20.7     |
|    ep_rew_mean        | 20.7     |
| time/                 |          |
|    fps                | 717      |
|    iterations         | 500      |
|    time_elapsed       | 3        |
|    total_timesteps    | 2500     |
| train/                |          |
|    entropy_loss       | -0.68    |
|    explained_variance | -0.178   |
|    learning_rate      | 0.0007   |
|    n_updates          | 499      |
|    policy_loss        | -12.9    |
|    value_loss         | 476      |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 21.3     |
|    ep_rew_mean        | 21.3     |
| time/                 |          |
|    fps                | 723      |
|    iterations         | 600      |
|    time_elapsed       | 4        |
|    total_timesteps    | 3000     |
| train/                |          |
|    entropy_loss       | -0.692   |
|    explained_variance | 0.00671  |
|    learning_rate      | 0.0007   |
|    n_updates          | 599      |
|    policy_loss        | 1.46     |
|    value_loss         | 5.61     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 22.2     |
|    ep_rew_mean        | 22.2     |
| time/                 |          |
|    fps                | 724      |
|    iterations         | 700      |
|    time_elapsed       | 4        |
|    total_timesteps    | 3500     |
| train/                |          |
|    entropy_loss       | -0.693   |
|    explained_variance | 0.031    |
|    learning_rate      | 0.0007   |
|    n_updates          | 699      |
|    policy_loss        | 1.33     |
|    value_loss         | 4.7      |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 23       |
|    ep_rew_mean        | 23       |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 800      |
|    time_elapsed       | 5        |
|    total_timesteps    | 4000     |
| train/                |          |
|    entropy_loss       | -0.689   |
|    explained_variance | 0.17     |
|    learning_rate      | 0.0007   |
|    n_updates          | 799      |
|    policy_loss        | 1.32     |
|    value_loss         | 4.34     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 24.2     |
|    ep_rew_mean        | 24.2     |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 900      |
|    time_elapsed       | 6        |
|    total_timesteps    | 4500     |
| train/                |          |
|    entropy_loss       | -0.688   |
|    explained_variance | 0.006    |
|    learning_rate      | 0.0007   |
|    n_updates          | 899      |
|    policy_loss        | 1.33     |
|    value_loss         | 4.33     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 25.3     |
|    ep_rew_mean        | 25.3     |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 1000     |
|    time_elapsed       | 6        |
|    total_timesteps    | 5000     |
| train/                |          |
|    entropy_loss       | -0.625   |
|    explained_variance | 0.0399   |
|    learning_rate      | 0.0007   |
|    n_updates          | 999      |
|    policy_loss        | 1.54     |
|    value_loss         | 3.87     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 29.1     |
|    ep_rew_mean        | 29.1     |
| time/                 |          |
|    fps                | 727      |
|    iterations         | 1100     |
|    time_elapsed       | 7        |
|    total_timesteps    | 5500     |
| train/                |          |
|    entropy_loss       | -0.671   |
|    explained_variance | 0.00653  |
|    learning_rate      | 0.0007   |
|    n_updates          | 1099     |
|    policy_loss        | 0.939    |
|    value_loss         | 3.27     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 32.3     |
|    ep_rew_mean        | 32.3     |
| time/                 |          |
|    fps                | 727      |
|    iterations         | 1200     |
|    time_elapsed       | 8        |
|    total_timesteps    | 6000     |
| train/                |          |
|    entropy_loss       | -0.62    |
|    explained_variance | 0.0192   |
|    learning_rate      | 0.0007   |
|    n_updates          | 1199     |
|    policy_loss        | 0.995    |
|    value_loss         | 2.82     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 34.6     |
|    ep_rew_mean        | 34.6     |
| time/                 |          |
|    fps                | 727      |
|    iterations         | 1300     |
|    time_elapsed       | 8        |
|    total_timesteps    | 6500     |
| train/                |          |
|    entropy_loss       | -0.607   |
|    explained_variance | 0.000248 |
|    learning_rate      | 0.0007   |
|    n_updates          | 1299     |
|    policy_loss        | 0.572    |
|    value_loss         | 2.51     |
------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 38.1      |
|    ep_rew_mean        | 38.1      |
| time/                 |           |
|    fps                | 727       |
|    iterations         | 1400      |
|    time_elapsed       | 9         |
|    total_timesteps    | 7000      |
| train/                |           |
|    entropy_loss       | -0.425    |
|    explained_variance | -0.000446 |
|    learning_rate      | 0.0007    |
|    n_updates          | 1399      |
|    policy_loss        | 1.1       |
|    value_loss         | 2.13      |
-------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 42.1      |
|    ep_rew_mean        | 42.1      |
| time/                 |           |
|    fps                | 726       |
|    iterations         | 1500      |
|    time_elapsed       | 10        |
|    total_timesteps    | 7500      |
| train/                |           |
|    entropy_loss       | -0.65     |
|    explained_variance | -0.000627 |
|    learning_rate      | 0.0007    |
|    n_updates          | 1499      |
|    policy_loss        | 0.649     |
|    value_loss         | 1.77      |
-------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 46       |
|    ep_rew_mean        | 46       |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 1600     |
|    time_elapsed       | 11       |
|    total_timesteps    | 8000     |
| train/                |          |
|    entropy_loss       | -0.474   |
|    explained_variance | -0.00596 |
|    learning_rate      | 0.0007   |
|    n_updates          | 1599     |
|    policy_loss        | 0.85     |
|    value_loss         | 1.42     |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 48.9     |
|    ep_rew_mean        | 48.9     |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 1700     |
|    time_elapsed       | 11       |
|    total_timesteps    | 8500     |
| train/                |          |
|    entropy_loss       | -0.404   |
|    explained_variance | 0.00036  |
|    learning_rate      | 0.0007   |
|    n_updates          | 1699     |
|    policy_loss        | 0.882    |
|    value_loss         | 1.1      |
------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 55        |
|    ep_rew_mean        | 55        |
| time/                 |           |
|    fps                | 726       |
|    iterations         | 1800      |
|    time_elapsed       | 12        |
|    total_timesteps    | 9000      |
| train/                |           |
|    entropy_loss       | -0.534    |
|    explained_variance | -0.000588 |
|    learning_rate      | 0.0007    |
|    n_updates          | 1799      |
|    policy_loss        | 0.495     |
|    value_loss         | 0.823     |
-------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 58.8     |
|    ep_rew_mean        | 58.8     |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 1900     |
|    time_elapsed       | 13       |
|    total_timesteps    | 9500     |
| train/                |          |
|    entropy_loss       | -0.505   |
|    explained_variance | 0.000926 |
|    learning_rate      | 0.0007   |
|    n_updates          | 1899     |
|    policy_loss        | 0.295    |
|    value_loss         | 0.584    |
------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 62.8      |
|    ep_rew_mean        | 62.8      |
| time/                 |           |
|    fps                | 726       |
|    iterations         | 2000      |
|    time_elapsed       | 13        |
|    total_timesteps    | 10000     |
| train/                |           |
|    entropy_loss       | -0.547    |
|    explained_variance | -0.000189 |
|    learning_rate      | 0.0007    |
|    n_updates          | 1999      |
|    policy_loss        | 0.271     |
|    value_loss         | 0.387     |
-------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 67.4      |
|    ep_rew_mean        | 67.4      |
| time/                 |           |
|    fps                | 725       |
|    iterations         | 2100      |
|    time_elapsed       | 14        |
|    total_timesteps    | 10500     |
| train/                |           |
|    entropy_loss       | -0.557    |
|    explained_variance | -0.000216 |
|    learning_rate      | 0.0007    |
|    n_updates          | 2099      |
|    policy_loss        | 0.22      |
|    value_loss         | 0.225     |
-------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 72.6      |
|    ep_rew_mean        | 72.6      |
| time/                 |           |
|    fps                | 725       |
|    iterations         | 2200      |
|    time_elapsed       | 15        |
|    total_timesteps    | 11000     |
| train/                |           |
|    entropy_loss       | -0.585    |
|    explained_variance | -4.77e-07 |
|    learning_rate      | 0.0007    |
|    n_updates          | 2199      |
|    policy_loss        | 0.104     |
|    value_loss         | 0.109     |
-------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 77.2     |
|    ep_rew_mean        | 77.2     |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 2300     |
|    time_elapsed       | 15       |
|    total_timesteps    | 11500    |
| train/                |          |
|    entropy_loss       | -0.416   |
|    explained_variance | 7.03e-05 |
|    learning_rate      | 0.0007   |
|    n_updates          | 2299     |
|    policy_loss        | 0.162    |
|    value_loss         | 0.0354   |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 81.8     |
|    ep_rew_mean        | 81.8     |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 2400     |
|    time_elapsed       | 16       |
|    total_timesteps    | 12000    |
| train/                |          |
|    entropy_loss       | -0.532   |
|    explained_variance | 0.000919 |
|    learning_rate      | 0.0007   |
|    n_updates          | 2399     |
|    policy_loss        | 0.0178   |
|    value_loss         | 0.00197  |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 84.5     |
|    ep_rew_mean        | 84.5     |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 2500     |
|    time_elapsed       | 17       |
|    total_timesteps    | 12500    |
| train/                |          |
|    entropy_loss       | -0.56    |
|    explained_variance | -0.00834 |
|    learning_rate      | 0.0007   |
|    n_updates          | 2499     |
|    policy_loss        | 0.000931 |
|    value_loss         | 9.36e-06 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 91.1     |
|    ep_rew_mean        | 91.1     |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 2600     |
|    time_elapsed       | 17       |
|    total_timesteps    | 13000    |
| train/                |          |
|    entropy_loss       | -0.51    |
|    explained_variance | -0.0151  |
|    learning_rate      | 0.0007   |
|    n_updates          | 2599     |
|    policy_loss        | 0.00124  |
|    value_loss         | 5.94e-06 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 94.6     |
|    ep_rew_mean        | 94.6     |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 2700     |
|    time_elapsed       | 18       |
|    total_timesteps    | 13500    |
| train/                |          |
|    entropy_loss       | -0.368   |
|    explained_variance | -0.0187  |
|    learning_rate      | 0.0007   |
|    n_updates          | 2699     |
|    policy_loss        | 0.00071  |
|    value_loss         | 1.2e-06  |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 98.8     |
|    ep_rew_mean        | 98.8     |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 2800     |
|    time_elapsed       | 19       |
|    total_timesteps    | 14000    |
| train/                |          |
|    entropy_loss       | -0.544   |
|    explained_variance | 0.301    |
|    learning_rate      | 0.0007   |
|    n_updates          | 2799     |
|    policy_loss        | 0.00018  |
|    value_loss         | 1.58e-07 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 104      |
|    ep_rew_mean        | 104      |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 2900     |
|    time_elapsed       | 19       |
|    total_timesteps    | 14500    |
| train/                |          |
|    entropy_loss       | -0.469   |
|    explained_variance | 0.122    |
|    learning_rate      | 0.0007   |
|    n_updates          | 2899     |
|    policy_loss        | 0.000126 |
|    value_loss         | 2.27e-07 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 107      |
|    ep_rew_mean        | 107      |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 3000     |
|    time_elapsed       | 20       |
|    total_timesteps    | 15000    |
| train/                |          |
|    entropy_loss       | -0.533   |
|    explained_variance | nan      |
|    learning_rate      | 0.0007   |
|    n_updates          | 2999     |
|    policy_loss        | 6.52e-06 |
|    value_loss         | 5.59e-10 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 111      |
|    ep_rew_mean        | 111      |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 3100     |
|    time_elapsed       | 21       |
|    total_timesteps    | 15500    |
| train/                |          |
|    entropy_loss       | -0.423   |
|    explained_variance | nan      |
|    learning_rate      | 0.0007   |
|    n_updates          | 3099     |
|    policy_loss        | 1.31e-06 |
|    value_loss         | 2.33e-11 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 114      |
|    ep_rew_mean        | 114      |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 3200     |
|    time_elapsed       | 22       |
|    total_timesteps    | 16000    |
| train/                |          |
|    entropy_loss       | -0.435   |
|    explained_variance | -3.12    |
|    learning_rate      | 0.0007   |
|    n_updates          | 3199     |
|    policy_loss        | 1.17e-05 |
|    value_loss         | 1.61e-09 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 119      |
|    ep_rew_mean        | 119      |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 3300     |
|    time_elapsed       | 22       |
|    total_timesteps    | 16500    |
| train/                |          |
|    entropy_loss       | -0.424   |
|    explained_variance | nan      |
|    learning_rate      | 0.0007   |
|    n_updates          | 3299     |
|    policy_loss        | 9.67e-07 |
|    value_loss         | 5.24e-10 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 124      |
|    ep_rew_mean        | 124      |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 3400     |
|    time_elapsed       | 23       |
|    total_timesteps    | 17000    |
| train/                |          |
|    entropy_loss       | -0.399   |
|    explained_variance | -2.32    |
|    learning_rate      | 0.0007   |
|    n_updates          | 3399     |
|    policy_loss        | 1.46e-05 |
|    value_loss         | 1.51e-09 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 129      |
|    ep_rew_mean        | 129      |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 3500     |
|    time_elapsed       | 24       |
|    total_timesteps    | 17500    |
| train/                |          |
|    entropy_loss       | -0.536   |
|    explained_variance | nan      |
|    learning_rate      | 0.0007   |
|    n_updates          | 3499     |
|    policy_loss        | 8.35e-06 |
|    value_loss         | 1.13e-09 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 134      |
|    ep_rew_mean        | 134      |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 3600     |
|    time_elapsed       | 24       |
|    total_timesteps    | 18000    |
| train/                |          |
|    entropy_loss       | -0.463   |
|    explained_variance | -0.00633 |
|    learning_rate      | 0.0007   |
|    n_updates          | 3599     |
|    policy_loss        | 0.000121 |
|    value_loss         | 5.39e-08 |
------------------------------------


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 138       |
|    ep_rew_mean        | 138       |
| time/                 |           |
|    fps                | 725       |
|    iterations         | 3700      |
|    time_elapsed       | 25        |
|    total_timesteps    | 18500     |
| train/                |           |
|    entropy_loss       | -0.453    |
|    explained_variance | -0.104    |
|    learning_rate      | 0.0007    |
|    n_updates          | 3699      |
|    policy_loss        | -0.000458 |
|    value_loss         | 1.3e-06   |
-------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 147      |
|    ep_rew_mean        | 147      |
| time/                 |          |
|    fps                | 725      |
|    iterations         | 3800     |
|    time_elapsed       | 26       |
|    total_timesteps    | 19000    |
| train/                |          |
|    entropy_loss       | -0.341   |
|    explained_variance | 0.427    |
|    learning_rate      | 0.0007   |
|    n_updates          | 3799     |
|    policy_loss        | 4.41e-05 |
|    value_loss         | 1.37e-08 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 151      |
|    ep_rew_mean        | 151      |
| time/                 |          |
|    fps                | 726      |
|    iterations         | 3900     |
|    time_elapsed       | 26       |
|    total_timesteps    | 19500    |
| train/                |          |
|    entropy_loss       | -0.298   |
|    explained_variance | 0.0542   |
|    learning_rate      | 0.0007   |
|    n_updates          | 3899     |
|    policy_loss        | 0.000255 |
|    value_loss         | 1.21e-07 |
------------------------------------


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 155      |
|    ep_rew_mean        | 155      |
| time/                 |          |
|    fps                | 727      |
|    iterations         | 4000     |
|    time_elapsed       | 27       |
|    total_timesteps    | 20000    |
| train/                |          |
|    entropy_loss       | -0.465   |
|    explained_variance | -0.00157 |
|    learning_rate      | 0.0007   |
|    n_updates          | 3999     |
|    policy_loss        | 0.000204 |
|    value_loss         | 1.21e-06 |
------------------------------------


In [21]:
mean_reward, std_reward = evaluate_policy(model, eval_envs_cartpole, n_eval_episodes=50, deterministic=True)

print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

mean_reward:249.34 +/- 97.58


**Your goal is to beat that baseline and get closer to the optimal score of 500**

Time to tune!

In [22]:
import torch.nn as nn

In [24]:
policy_kwargs = dict(
    # Separate networks for the actor (pi) and the critic (vf)
    net_arch=dict(pi=[64, 64], vf=[64, 64]),
    activation_fn=nn.Tanh,
)

hyperparams = dict(
    n_steps=5,  # number of steps to collect before each policy update
    learning_rate=7e-4,
    gamma=0.99,  # discount factor
    max_grad_norm=0.5,  # maximum norm used for gradient clipping
    ent_coef=0.0,  # entropy coefficient in the loss
)

model = A2C(
    "MlpPolicy",
    "CartPole-v1",
    seed=8,
    verbose=0,
    policy_kwargs=policy_kwargs,
    **hyperparams,
).learn(budget)
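To get a feel for what `net_arch` controls, the sketch below counts trainable parameters for the two architectures that appear in the hint further down ("default" and "tiny"). It assumes the SB3 actor-critic layout for a Discrete action space: a flatten feature extractor with no parameters, separate `pi` and `vf` MLPs, then a linear action head and a linear value head. The helpers `linear_params`, `mlp_params` and `actor_critic_params` are hypothetical, written only for this illustration.

```python
def linear_params(n_in: int, n_out: int) -> int:
    # Weights plus biases of one fully connected layer
    return n_in * n_out + n_out


def mlp_params(n_in: int, hidden: list) -> tuple:
    # Parameters of a stack of hidden layers, and the width of its last layer
    total, prev = 0, n_in
    for width in hidden:
        total += linear_params(prev, width)
        prev = width
    return total, prev


def actor_critic_params(obs_dim: int, n_actions: int, pi: list, vf: list) -> int:
    pi_body, pi_out = mlp_params(obs_dim, pi)
    vf_body, vf_out = mlp_params(obs_dim, vf)
    action_head = linear_params(pi_out, n_actions)  # logits of the discrete action distribution
    value_head = linear_params(vf_out, 1)  # scalar state-value estimate
    return pi_body + vf_body + action_head + value_head


# CartPole-v1: 4-dimensional observation, 2 discrete actions
default_size = actor_critic_params(4, 2, pi=[64, 64], vf=[64, 64])
tiny_size = actor_critic_params(4, 2, pi=[64], vf=[64])
print(default_size, tiny_size)
```

If the layout assumption holds, `default_size` should match `sum(p.numel() for p in model.policy.parameters())` for the model trained above; the "tiny" network has roughly ten times fewer parameters.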

In [25]:
mean_reward, std_reward = evaluate_policy(model, eval_envs_cartpole, n_eval_episodes=50, deterministic=True)

print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

mean_reward:250.64 +/- 81.99
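The `+/-` term printed by `evaluate_policy` is the standard deviation of the episodic returns (SB3 computes it with `np.std`), which measures the episode-to-episode spread rather than how precisely the mean is known. A minimal stdlib sketch of both quantities, using a hypothetical helper `summarize_returns` and made-up returns:

```python
import math
import statistics


def summarize_returns(returns):
    """Mean, population std (like np.std with its default ddof=0) and standard error of the mean."""
    mean = statistics.fmean(returns)
    std = statistics.pstdev(returns)
    sem = std / math.sqrt(len(returns))
    return mean, std, sem


# Hypothetical returns from 4 evaluation episodes (illustrative numbers, not real results)
mean, std, sem = summarize_returns([500.0, 180.0, 220.0, 100.0])
print(f"{mean:.2f} +/- {std:.2f} (standard error {sem:.2f})")

# Standard error of the mean for the two evaluations above (50 episodes each)
sem_first = 97.58 / math.sqrt(50)
sem_second = 81.99 / math.sqrt(50)
print(f"{sem_first:.1f} {sem_second:.1f}")
```

Applied to the two evaluations above, the standard errors are about 13.8 and 11.6, so the 1.3-point gap between 249.34 and 250.64 is well within evaluation noise.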


Hint: recommended hyperparameter ranges

```python
gamma = trial.suggest_float("gamma", 0.9, 0.99999, log=True)
max_grad_norm = trial.suggest_float("max_grad_norm", 0.3, 5.0, log=True)
# from 2**3 = 8 to 2**10 = 1024
n_steps = 2 ** trial.suggest_int("exponent_n_steps", 3, 10)
learning_rate = trial.suggest_float("lr", 1e-5, 1, log=True)
ent_coef = trial.suggest_float("ent_coef", 1e-8, 0.1, log=True)
# net_arch tiny: {"pi": [64], "vf": [64]}
# net_arch default: {"pi": [64, 64], "vf": [64, 64]}
# activation_fn = nn.Tanh / nn.ReLU
```
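Most ranges above use `log=True`, which Optuna documents as sampling the value from the range in the log domain, and `n_steps` is searched through its exponent. The stdlib sketch below mimics plain random sampling (what Optuna's random sampler, and TPE during its startup trials, would do) to show the effect; `sample_log_uniform` is a hypothetical helper, not an Optuna function.

```python
import math
import random


def sample_log_uniform(rng: random.Random, low: float, high: float) -> float:
    # Uniform in log space: every decade of [low, high] is equally likely
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
lrs = [sample_log_uniform(rng, 1e-5, 1.0) for _ in range(1000)]
# [1e-5, 1e-3) covers 2 of the 5 decades, so roughly 40% of the samples land there;
# a plain uniform draw on [1e-5, 1] would put only about 0.1% of samples there.
frac_below_1e3 = sum(lr < 1e-3 for lr in lrs) / len(lrs)

# n_steps: searching over the exponent makes every power of two equally likely
n_steps_choices = [2 ** k for k in range(3, 11)]
print(f"{frac_below_1e3:.2f}", n_steps_choices)
```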

# Part III: Automatic Hyperparameter Tuning

In this part, we will write a script that searches for the best hyperparameters automatically.

### Imports

In [26]:
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances

### Config

In [27]:
N_TRIALS = 100  # Maximum number of trials
N_JOBS = 1 # Number of jobs to run in parallel
N_STARTUP_TRIALS = 5  # Stop random sampling after N_STARTUP_TRIALS
N_EVALUATIONS = 2  # Number of evaluations during the training
N_TIMESTEPS = int(2e4)  # Training budget
EVAL_FREQ = int(N_TIMESTEPS / N_EVALUATIONS)
N_EVAL_ENVS = 5
N_EVAL_EPISODES = 10
TIMEOUT = int(60 * 15)  # 15 minutes

ENV_ID = "CartPole-v1"

DEFAULT_HYPERPARAMS = {
    "policy": "MlpPolicy",
    "env": ENV_ID,
}
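`EVAL_FREQ` is compared against the number of callback calls, and in SB3 the callback is called once per (vectorized) `env.step()`, not once per timestep. With the single training env that A2C builds from the `ENV_ID` string the two coincide, so evaluations happen at 10,000 and 20,000 timesteps. The sketch below is a simplified model of an on-policy training loop (rollouts of `n_steps` vectorized steps until the budget is reached); `eval_schedule` is a hypothetical helper that ignores early stopping.

```python
def eval_schedule(total_timesteps: int, n_steps: int, n_envs: int, eval_freq: int) -> list:
    """Timesteps at which an EvalCallback-style check `n_calls % eval_freq == 0` fires."""
    num_timesteps, n_calls, evals = 0, 0, []
    while num_timesteps < total_timesteps:
        for _ in range(n_steps):  # one rollout
            num_timesteps += n_envs  # each vectorized step advances every env
            n_calls += 1  # the callback is called once per vectorized step
            if eval_freq > 0 and n_calls % eval_freq == 0:
                evals.append(num_timesteps)
    return evals


print(eval_schedule(20_000, n_steps=5, n_envs=1, eval_freq=10_000))
print(eval_schedule(20_000, n_steps=1024, n_envs=1, eval_freq=10_000))
print(eval_schedule(20_000, n_steps=5, n_envs=4, eval_freq=10_000))
```

The last call illustrates the pitfall: with 4 training envs and the same `EVAL_FREQ`, no evaluation would ever run within the budget, so nothing would be reported to Optuna.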

### Exercise (5 minutes): Define the search space

In [None]:
from typing import Any, Dict
import torch
import torch.nn as nn

def sample_a2c_params(trial: optuna.Trial) -> Dict[str, Any]:
    """
    Sampler for A2C hyperparameters.

    :param trial: Optuna trial object
    :return: The sampled hyperparameters for the given trial.
    """
    # Discount factor between 0.9 and 0.9999
    gamma = 1.0 - trial.suggest_float("gamma", 0.0001, 0.1, log=True)
    max_grad_norm = trial.suggest_float("max_grad_norm", 0.3, 5.0, log=True)
    # 8, 16, 32, ... 1024
    n_steps = 2 ** trial.suggest_int("exponent_n_steps", 3, 10)

    ### YOUR CODE HERE
    # - learning rate search space [1e-5, 1] (log scale) -> `suggest_float`
    # - network architecture search space ["tiny", "small"] -> `suggest_categorical`
    # - activation function search space ["tanh", "relu"] -> `suggest_categorical`
    learning_rate = trial.suggest_float("lr", 1e-5, 1, log=True)
    net_arch = trial.suggest_categorical("net_arch", ["tiny", "small"])
    activation_fn = trial.suggest_categorical("activation_fn", ["tanh", "relu"])

    ### END OF YOUR CODE

    # Store the actual values (not the raw sampled ones) so they appear in the study results
    trial.set_user_attr("gamma_", gamma)
    trial.set_user_attr("n_steps", n_steps)

    # Map the categorical choices to SB3 arguments
    net_arch = (
        {"pi": [64], "vf": [64]}
        if net_arch == "tiny"
        else {"pi": [64, 64], "vf": [64, 64]}
    )

    activation_fn = {"tanh": nn.Tanh, "relu": nn.ReLU}[activation_fn]

    return {
        "n_steps": n_steps,
        "gamma": gamma,
        "learning_rate": learning_rate,
        "max_grad_norm": max_grad_norm,
        "policy_kwargs": {
            "net_arch": net_arch,
            "activation_fn": activation_fn,
        },
    }
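The sampler above draws `1 - gamma` on a log scale instead of `gamma` itself (as in the hint). Over [0.9, 0.9999], `log(gamma)` is almost linear, so sampling `gamma` on a log scale is close to uniform and most samples land far from 1. A common rule of thumb is that the discount factor sets an effective horizon of about `1 / (1 - gamma)` steps, i.e. 10 to 10,000 steps over this range. The stdlib sketch below compares the median gamma of the two parametrizations; `log_uniform_quantile` and `effective_horizon` are hypothetical helpers.

```python
import math


def log_uniform_quantile(low: float, high: float, q: float) -> float:
    # Value at quantile q (between 0 and 1) of a log-uniform distribution on [low, high]
    return math.exp(math.log(low) + q * (math.log(high) - math.log(low)))


def effective_horizon(gamma: float) -> float:
    # Rule of thumb: rewards much further than 1 / (1 - gamma) steps away are heavily discounted
    return 1.0 / (1.0 - gamma)


# Median gamma when gamma itself is sampled on a log scale in [0.9, 0.9999]
median_direct = log_uniform_quantile(0.9, 0.9999, 0.5)
# Median gamma when 1 - gamma is sampled on a log scale in [0.0001, 0.1]
median_reparam = 1.0 - log_uniform_quantile(0.0001, 0.1, 0.5)

print(f"direct:  gamma={median_direct:.4f}, horizon={effective_horizon(median_direct):.0f}")
print(f"reparam: gamma={median_reparam:.4f}, horizon={effective_horizon(median_reparam):.0f}")
```

With the reparametrization, half of the trials explore horizons above roughly 300 steps, which matters for CartPole-v1 where a successful episode lasts 500 steps.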

### Define the objective function

First we define a custom callback to report the results of periodic evaluations to Optuna:

In [None]:
import gymnasium as gym
from stable_baselines3.common.callbacks import EvalCallback

class TrialEvalCallback(EvalCallback):
    """
    Callback used for evaluating and reporting a trial.

    :param eval_env: Evaluation environment
    :param trial: Optuna trial object
    :param n_eval_episodes: Number of evaluation episodes
    :param eval_freq: Evaluate the agent every ``eval_freq`` calls of the callback.
    :param deterministic: Whether the evaluation should
        use a stochastic or deterministic policy.
    :param verbose: Verbosity level
    """

    def __init__(
        self,
        eval_env: gym.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):

        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            # Evaluate policy (done in the parent class)
            super()._on_step()
            self.eval_idx += 1
            # Send report to Optuna
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Prune the trial if needed
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True
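`trial.should_prune()` delegates the decision to the study's pruner, here a `MedianPruner`. Roughly, for a maximization study it prunes a trial whose best reported value so far is below the median of what completed trials reported at the same step, once enough trials have completed and the warmup steps have passed. The sketch below is a simplified stdlib version of that rule, not Optuna's implementation (Optuna also supports `interval_steps` and counts startup trials over all completed trials); `should_prune_median` is a hypothetical helper.

```python
import statistics


def should_prune_median(
    best_so_far: float,
    step: int,
    completed_at_step: list,
    n_startup_trials: int = 5,
    n_warmup_steps: int = 0,
) -> bool:
    """Simplified median-pruning rule for a maximization study.

    best_so_far: best value this trial has reported up to `step`
    completed_at_step: values that completed trials reported at `step`
    """
    if len(completed_at_step) < n_startup_trials:
        return False  # not enough finished trials to compare against
    if step < n_warmup_steps:
        return False  # too early in this trial's training
    return best_so_far < statistics.median(completed_at_step)


previous = [120.0, 180.0, 250.0, 310.0, 500.0]  # hypothetical rewards at evaluation 1
print(should_prune_median(150.0, step=1, completed_at_step=previous))  # below the median of 250
print(should_prune_median(300.0, step=1, completed_at_step=previous))
print(should_prune_median(150.0, step=1, completed_at_step=previous[:3]))  # only 3 completed trials
```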

### Exercise (10 minutes): Define the objective function

Then we define the objective function, which samples hyperparameters, creates and trains the model, and returns the result to Optuna.

In [None]:
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env


def objective(trial: optuna.Trial) -> float:
    """
    Objective function used by Optuna to evaluate
    one configuration (i.e., one set of hyperparameters).

    Given a trial object, it samples hyperparameters,
    evaluates them and reports the result (mean episodic reward after training).

    :param trial: Optuna trial object
    :return: Mean episodic reward after training
    """

    kwargs = DEFAULT_HYPERPARAMS.copy()
    ### YOUR CODE HERE
    # 1. Sample hyperparameters and update the default keyword arguments
    kwargs.update(sample_a2c_params(trial))

    # Create the RL model
    model = A2C(**kwargs)

    # 2. Create the envs used for evaluation
    eval_envs = make_vec_env(ENV_ID, n_envs=N_EVAL_ENVS)

    # 3. Create the callback that periodically evaluates the agent
    # and reports the mean reward over N_EVAL_EPISODES every EVAL_FREQ calls
    eval_callback = TrialEvalCallback(
        eval_envs,
        trial,
        n_eval_episodes=N_EVAL_EPISODES,
        eval_freq=EVAL_FREQ,
        deterministic=True,
    )

    ### END OF YOUR CODE

    nan_encountered = False
    try:
        # Train the model
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except (AssertionError, ValueError) as e:
        # Sometimes, random hyperparams can generate NaN
        print(e)
        nan_encountered = True
    finally:
        # Free memory
        model.env.close()
        eval_envs.close()

    # Tell the optimizer that the trial failed
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward

### The optimization loop

In [None]:
import torch as th

# Set PyTorch num threads to 1 for faster training
th.set_num_threads(1)
# Select the sampler, can be random, TPESampler, CMAES, ...
sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS)
# Do not prune before 1/3 of the max budget is used
pruner = MedianPruner(
    n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3
)
# Create the study and start the hyperparameter optimization
study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")

try:
    study.optimize(objective, n_trials=N_TRIALS, n_jobs=N_JOBS, timeout=TIMEOUT)
except KeyboardInterrupt:
    pass

print("Number of finished trials: ", len(study.trials))

print("Best trial:")
trial = study.best_trial

print(f"  Value: {trial.value}")

print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

print("  User attrs:")
for key, value in trial.user_attrs.items():
    print(f"    {key}: {value}")

# Write report
study.trials_dataframe().to_csv("study_results_a2c_cartpole.csv")

fig1 = plot_optimization_history(study)
fig2 = plot_param_importances(study)

fig1.show()
fig2.show()
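Conceptually, `study.optimize()` keeps asking the sampler for new parameters and calling `objective` until `n_trials` trials have run or `timeout` seconds have elapsed; a trial that raises `TrialPruned` is recorded as pruned and a NaN return value as failed. The stdlib sketch below mirrors that loop with plain random sampling in place of TPE. `Pruned`, `random_search` and `toy_objective` are hypothetical names for this illustration only, and this sketch checks the clock only between trials.

```python
import math
import random
import time


class Pruned(Exception):
    """Hypothetical stand-in for optuna.exceptions.TrialPruned in this sketch."""


def random_search(objective, n_trials: int, timeout: float, seed: int = 0):
    """Maximize `objective` with random sampling; stop after n_trials or timeout seconds."""
    rng = random.Random(seed)
    start = time.monotonic()
    best_value, best_params, states = -math.inf, None, []
    for _ in range(n_trials):
        if time.monotonic() - start > timeout:
            break  # the clock is only checked between trials
        params = {"lr": math.exp(rng.uniform(math.log(1e-5), math.log(1.0)))}
        try:
            value = objective(params)
        except Pruned:
            states.append("PRUNED")
            continue
        if math.isnan(value):
            states.append("FAIL")  # e.g. training diverged
            continue
        states.append("COMPLETE")
        if value > best_value:  # direction="maximize"
            best_value, best_params = value, params
    return best_value, best_params, states


def toy_objective(params):
    # Hypothetical objective: best at lr = 1e-3, pruned for lr > 0.1, diverges for lr > 0.5
    lr = params["lr"]
    if lr > 0.5:
        return float("nan")
    if lr > 0.1:
        raise Pruned()
    return -((math.log10(lr) + 3) ** 2)


best_value, best_params, states = random_search(toy_objective, n_trials=20, timeout=60)
print(best_params, states.count("COMPLETE"), states.count("PRUNED"), states.count("FAIL"))
```

TPE differs from this sketch only in how `params` is chosen: after `N_STARTUP_TRIALS` random trials, it proposes values from regions where earlier trials scored well.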

Complete example: https://github.com/DLR-RM/rl-baselines3-zoo

# Conclusion

What we have seen in this notebook:
- the importance of good hyperparameters
- how to do automatic hyperparameter search with Optuna
