In [1]:
import gymnasium as gym
from pogema import GridConfig
from stable_baselines3 import DQN, A2C
from stable_baselines3.common.evaluation import evaluate_policy

%load_ext autoreload
%autoreload 2
%matplotlib inline

grid_config = GridConfig(
    size=8,
    density=0.3,
    num_agents=1,
    max_episode_steps=30
)

env = gym.make("Pogema-v0",grid_config=grid_config)

a2c_model = A2C(
    "MlpPolicy",
    env,
    verbose=1,
    seed=42
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


  logger.warn(
  logger.warn(


#### Optuna Integration

In [11]:
""" Optuna example that optimizes the hyperparameters of
a reinforcement learning agent using A2C implementation from Stable-Baselines3
on a Gymnasium environment.

This is a simplified version of what can be found in https://github.com/DLR-RM/rl-baselines3-zoo.

You can run this example as follows:
    $ python sb3_simple.py

"""
from typing import Any
from typing import Dict

import gymnasium
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
import torch
import torch.nn as nn


N_TRIALS = 100
N_STARTUP_TRIALS = 5
N_EVALUATIONS = 2
N_TIMESTEPS = int(1.2e5)
EVAL_FREQ = int(N_TIMESTEPS / N_EVALUATIONS)
N_EVAL_EPISODES = 3


DEFAULT_HYPERPARAMS = {
    "policy": "MlpPolicy",
    "verbose": 1,
    "env": env,
    "seed": 42
}

def sample_a2c_params(trial: optuna.Trial) -> Dict[str, Any]:
    """Sampler for A2C hyperparameters."""
    gamma = 1.0 - trial.suggest_float("gamma", 0.0001, 0.1, log=True)
    learning_rate = trial.suggest_float("lr", 1e-5, 1, log=True)

    trial.set_user_attr("gamma", gamma)

    return {
        "gamma": gamma,
        "learning_rate": learning_rate,
    }


class TrialEvalCallback(EvalCallback):
    """Callback used for evaluating and reporting a trial."""

    def __init__(
        self,
        eval_env: gymnasium.Env,
        trial: optuna.Trial,
        n_eval_episodes: int = 5,
        eval_freq: int = 10000,
        deterministic: bool = True,
        verbose: int = 0,
    ):
        super().__init__(
            eval_env=eval_env,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            deterministic=deterministic,
            verbose=verbose,
        )
        self.trial = trial
        self.eval_idx = 0
        self.is_pruned = False

    def _on_step(self) -> bool:
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            super()._on_step()
            self.eval_idx += 1
            self.trial.report(self.last_mean_reward, self.eval_idx)
            # Prune trial if need.
            if self.trial.should_prune():
                self.is_pruned = True
                return False
        return True


def objective(trial: optuna.Trial) -> float:
    kwargs = DEFAULT_HYPERPARAMS.copy()
    # Sample hyperparameters.
    kwargs.update(sample_a2c_params(trial))
    # Create the RL model.
    model = A2C(**kwargs)
    # Create env used for evaluation.
    eval_env = Monitor(env)
    # Create the callback that will periodically evaluate and report the performance.
    eval_callback = TrialEvalCallback(
        eval_env, trial, n_eval_episodes=N_EVAL_EPISODES, eval_freq=EVAL_FREQ, deterministic=True
    )

    nan_encountered = False
    try:
        model.learn(N_TIMESTEPS, callback=eval_callback)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN.
        print(e)
        nan_encountered = True
    finally:
        # Free memory.
        model.env.close()
        eval_env.close()

    # Tell the optimizer that the trial failed.
    if nan_encountered:
        return float("nan")

    if eval_callback.is_pruned:
        raise optuna.exceptions.TrialPruned()

    return eval_callback.last_mean_reward


if __name__ == "__main__":
    # Set pytorch num threads to 1 for faster training.
    torch.set_num_threads(1)

    sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS)
    # Do not prune before 1/3 of the max budget is used.
    pruner = MedianPruner(n_startup_trials=N_STARTUP_TRIALS, n_warmup_steps=N_EVALUATIONS // 3)

    study = optuna.create_study(sampler=sampler, pruner=pruner, direction="maximize")
    try:
        study.optimize(objective, n_trials=N_TRIALS, n_jobs=8, timeout=600)
    except KeyboardInterrupt:
        pass

    print("Number of finished trials: ", len(study.trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

    print("  User attrs:")
    for key, value in trial.user_attrs.items():
        print("    {}: {}".format(key, value))


  from .autonotebook import tqdm as notebook_tqdm
[I 2023-11-05 20:37:38,983] A new study created in memory with name: no-name-1422e052-4fd0-4e6f-bb31-31e014e0e8ca


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 21.6      |
|    ep_rew_mean        | 0.174     |
| time/                 |           |
|    fps                | 85        |
|    iterations     

[I 2023-11-05 21:00:40,407] Trial 3 finished with value: 0.0 and parameters: {'gamma': 0.0006126144803556276, 'lr': 3.964705082143868e-05}. Best is trial 3 with value: 0.0.


-------------------------------------
| rollout/              |           |
|    ep_len_mean        | 35.4      |
|    ep_rew_mean        | 0         |
| time/                 |           |
|    fps                | 86        |
|    iterations         | 23900     |
|    time_elapsed       | 1381      |
|    total_timesteps    | 119500    |
| train/                |           |
|    entropy_loss       | -6.89e-06 |
|    explained_variance | -0.965    |
|    learning_rate      | 0.0107    |
|    n_updates          | 23899     |
|    policy_loss        | -1.74e-11 |
|    value_loss         | 1.85e-08  |
-------------------------------------
-------------------------------------
| eval/                 |           |
|    mean_ep_length     | 25.3      |
|    mean_reward        | 0         |
| time/                 |           |
|    total_timesteps    | 120000    |
| train/                |           |
|    entropy_loss       | -1.29e-05 |
|    explained_variance | -0.000364 |
|    learnin

[I 2023-11-05 21:00:42,134] Trial 5 finished with value: 0.0 and parameters: {'gamma': 0.02117041979375509, 'lr': 0.017679498097075814}. Best is trial 3 with value: 0.0.


------------------------------------
| rollout/              |          |
|    ep_len_mean        | 23.7     |
|    ep_rew_mean        | 0.25     |
| time/                 |          |
|    fps                | 86       |
|    iterations         | 23900    |
|    time_elapsed       | 1383     |
|    total_timesteps    | 119500   |
| train/                |          |
|    entropy_loss       | -1.27    |
|    explained_variance | -712     |
|    learning_rate      | 4.44e-05 |
|    n_updates          | 23899    |
|    policy_loss        | -0.0462  |
|    value_loss         | 0.00102  |
------------------------------------
-------------------------------------
| eval/                 |           |
|    mean_ep_length     | 13        |
|    mean_reward        | 0         |
| time/                 |           |
|    total_timesteps    | 120000    |
| train/                |           |
|    entropy_loss       | -4.75e-05 |
|    explained_variance | -10.5     |
|    learning_rate      | 0.0

[I 2023-11-05 21:00:43,224] Trial 0 finished with value: 0.0 and parameters: {'gamma': 0.0002758623592175237, 'lr': 0.005422536876064272}. Best is trial 3 with value: 0.0.


------------------------------------
| eval/                 |          |
|    mean_ep_length     | 38.7     |
|    mean_reward        | 0.333    |
| time/                 |          |
|    total_timesteps    | 120000   |
| train/                |          |
|    entropy_loss       | -0.279   |
|    explained_variance | -9.77    |
|    learning_rate      | 0.00012  |
|    n_updates          | 23999    |
|    policy_loss        | -0.00132 |
|    value_loss         | 0.000397 |
------------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 24.6     |
|    ep_rew_mean     | 0.26     |
| time/              |          |
|    fps             | 86       |
|    iterations      | 24000    |
|    time_elapsed    | 1384     |
|    total_timesteps | 120000   |
---------------------------------


[I 2023-11-05 21:00:43,329] Trial 7 finished with value: 0.3333333333333333 and parameters: {'gamma': 0.016506877957572757, 'lr': 0.00012022785707842533}. Best is trial 7 with value: 0.3333333333333333.


------------------------------------
| eval/                 |          |
|    mean_ep_length     | 32.7     |
|    mean_reward        | 0        |
| time/                 |          |
|    total_timesteps    | 120000   |
| train/                |          |
|    entropy_loss       | -1.18    |
|    explained_variance | -34      |
|    learning_rate      | 6.1e-05  |
|    n_updates          | 23999    |
|    policy_loss        | 0.0493   |
|    value_loss         | 0.00116  |
------------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 23.9     |
|    ep_rew_mean     | 0.22     |
| time/              |          |
|    fps             | 86       |
|    iterations      | 24000    |
|    time_elapsed    | 1384     |
|    total_timesteps | 120000   |
---------------------------------


[I 2023-11-05 21:00:43,356] Trial 1 finished with value: 0.0 and parameters: {'gamma': 0.03343026257158648, 'lr': 6.100579220265278e-05}. Best is trial 7 with value: 0.3333333333333333.


-------------------------------------
| eval/                 |           |
|    mean_ep_length     | 22.3      |
|    mean_reward        | 0.333     |
| time/                 |           |
|    total_timesteps    | 120000    |
| train/                |           |
|    entropy_loss       | -4.08e-05 |
|    explained_variance | -0.000994 |
|    learning_rate      | 0.0158    |
|    n_updates          | 23999     |
|    policy_loss        | 1.85e-06  |
|    value_loss         | 0.523     |
-------------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 21.8     |
|    ep_rew_mean     | 0.23     |
| time/              |          |
|    fps             | 86       |
|    iterations      | 24000    |
|    time_elapsed    | 1384     |
|    total_timesteps | 120000   |
---------------------------------


[I 2023-11-05 21:00:43,531] Trial 6 finished with value: 0.3333333333333333 and parameters: {'gamma': 0.0158925638342946, 'lr': 0.015822596384922064}. Best is trial 7 with value: 0.3333333333333333.


-------------------------------------
| eval/                 |           |
|    mean_ep_length     | 35.7      |
|    mean_reward        | 0         |
| time/                 |           |
|    total_timesteps    | 120000    |
| train/                |           |
|    entropy_loss       | -6.09e-06 |
|    explained_variance | -86.9     |
|    learning_rate      | 0.0107    |
|    n_updates          | 23999     |
|    policy_loss        | -0        |
|    value_loss         | 4.73e-08  |
-------------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 32.1     |
|    ep_rew_mean     | 0        |
| time/              |          |
|    fps             | 86       |
|    iterations      | 24000    |
|    time_elapsed    | 1384     |
|    total_timesteps | 120000   |
---------------------------------


[I 2023-11-05 21:00:43,709] Trial 2 finished with value: 0.0 and parameters: {'gamma': 0.00030429344956590254, 'lr': 0.010658196803117007}. Best is trial 7 with value: 0.3333333333333333.


------------------------------------
| eval/                 |          |
|    mean_ep_length     | 11       |
|    mean_reward        | 0.667    |
| time/                 |          |
|    total_timesteps    | 120000   |
| train/                |          |
|    entropy_loss       | -1.29    |
|    explained_variance | -9.42    |
|    learning_rate      | 4.44e-05 |
|    n_updates          | 23999    |
|    policy_loss        | 0.00841  |
|    value_loss         | 0.000321 |
------------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 24.4     |
|    ep_rew_mean     | 0.26     |
| time/              |          |
|    fps             | 86       |
|    iterations      | 24000    |
|    time_elapsed    | 1384     |
|    total_timesteps | 120000   |
---------------------------------


[I 2023-11-05 21:00:43,803] Trial 4 finished with value: 0.6666666666666666 and parameters: {'gamma': 0.049494589643343134, 'lr': 4.441422507181185e-05}. Best is trial 4 with value: 0.6666666666666666.


Number of finished trials:  8
Best trial:
  Value:  0.6666666666666666
  Params: 
    gamma: 0.049494589643343134
    lr: 4.441422507181185e-05
  User attrs:
    gamma: 0.9505054103566568


### Train agent with best hyper parameters

In [14]:
a2c_model = A2C(
    "MlpPolicy",
    env,
    verbose=1,
    gamma=1-0.049494589643343134,
    learning_rate=4.441422507181185e-05,
    seed=42
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [15]:
# Train agent and save it
a2c_model.learn(int(1.2e5))
a2c_model.save("saved/a2c_baseline")

------------------------------------
| rollout/              |          |
|    ep_len_mean        | 27.1     |
|    ep_rew_mean        | 0.222    |
| time/                 |          |
|    fps                | 2163     |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -1.61    |
|    explained_variance | -14      |
|    learning_rate      | 4.44e-05 |
|    n_updates          | 99       |
|    policy_loss        | -0.21    |
|    value_loss         | 0.0197   |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 26.3     |
|    ep_rew_mean        | 0.216    |
| time/                 |          |
|    fps                | 2075     |
|    iterations         | 200      |
|    time_elapsed       | 0        |
|    total_timesteps    | 1000     |
| train/                |          |
|

#### Testing 

In [4]:
a2c_model = A2C.load("saved/a2c_baseline")

env.reset()

mean_reward, std_reward = evaluate_policy(a2c_model, env, deterministic=True, n_eval_episodes=20)
print(f"mean_reward:{mean_reward:.2f} +/- {std_reward:.2f}")

mean_reward:0.10 +/- 0.30


In [5]:
# RANDOM SEED
from IPython.display import SVG, display
from pogema.animation import AnimationMonitor, AnimationConfig

env = AnimationMonitor(env)

def evaluate_success_rate(model, env, num_episodes=100):
    success_count = 0
    step_array = []
    for i in range(num_episodes):
        print(f'---{i}---')
        obs = env.reset()

        # Check if observation is a tuple and extract the first element if true.
        if isinstance(obs, tuple):
            obs = obs[0]
        max_step = 100
        steps_taken = 0
        done = truncated = False
        while not done and max_step > 0:
            action, _ = model.predict(obs)
            next_obs, reward, done, truncated, info = env.step(action)
            print(action,max_step,success_count,done)
            max_step -= 1
            steps_taken += 1
            # Check if next_obs is a tuple and extract the first element if true.
            if isinstance(next_obs, tuple):
                next_obs = next_obs[0]
            obs = next_obs

            # Check if agent was successful in that episode.
            if done:
                success_count += 1
                step_array.append(steps_taken)
                env.save_animation(f"render{i}.svg", AnimationConfig(egocentric_idx=0))
                break

    success_rate = success_count / num_episodes
    return success_rate, step_array

success_rate,step_array = evaluate_success_rate(a2c_model, env)
print(f"Agent Success Rate: {success_rate * 100:.2f}%")
print(f"steps to termination : {step_array}")

---0---
2 100 0 False
0 99 0 False
2 98 0 False
0 97 0 False
4 96 0 False
2 95 0 False
2 94 0 False
3 93 0 False
3 92 0 False
3 91 0 False
3 90 0 False
0 89 0 False
2 88 0 False
4 87 0 False
4 86 0 False
1 85 0 False
1 84 0 False
3 83 0 False
4 82 0 False
4 81 0 False
3 80 0 False
1 79 0 False
2 78 0 False
1 77 0 False
3 76 0 False
1 75 0 False
1 74 0 False
1 73 0 False
3 72 0 False
2 71 0 False
2 70 0 False
4 69 0 False
1 68 0 False
2 67 0 False
1 66 0 False
2 65 0 False
3 64 0 False
1 63 0 False
3 62 0 False
1 61 0 False
2 60 0 False
2 59 0 False
1 58 0 False
3 57 0 False
4 56 0 False
2 55 0 False
3 54 0 False
2 53 0 False
1 52 0 False
3 51 0 False
4 50 0 False
1 49 0 False
0 48 0 False
3 47 0 False
1 46 0 False
3 45 0 False
0 44 0 False
3 43 0 False
1 42 0 False
4 41 0 False
0 40 0 False
3 39 0 False
1 38 0 False
2 37 0 False
4 36 0 False
3 35 0 False
3 34 0 False
1 33 0 False
3 32 0 False
3 31 0 False
1 30 0 False
1 29 0 False
2 28 0 False
1 27 0 False
0 26 0 False
4 25 0 False
3 2

In [None]:
# we got 35% success rate for 1.25e5 steps
# we got 7% success rate for 3.0e5 steps
# better to use 1.25e5 steps