## Choosing the best checkpoint from the PBT trials

In [1]:
from pathlib import Path

from ray.rllib.agents.ppo import PPOTrainer
from ray.tune import ExperimentAnalysis
from ray.tune.registry import register_env

from inventory_env.env_creator import inventory_env_creator

# Make the custom environment name resolvable before RLlib builds the trainer.
register_env("inventory_env", inventory_env_creator)

# Metric/mode used to rank the PBT trials and pick a checkpoint.
METRIC = "evaluation/episode_reward_mean"
MODE = "max"

# Results are resolved relative to the parent of the current working directory,
# so this cell only works when the notebook is launched from the expected folder.
path_to_results_dir = (
    Path().absolute().parent / "inventory_management_rl" / "experiments" / "experiment_results" / "pbt"
)
if not path_to_results_dir.is_dir():
    raise FileNotFoundError(f"PBT results directory not found: {path_to_results_dir}")

analysis = ExperimentAnalysis(path_to_results_dir, default_metric=METRIC, default_mode=MODE)
# scope="all" ranks each trial by its best reported value rather than its last one.
best_trial = analysis.get_best_trial(scope="all")
if best_trial is None:
    raise RuntimeError(f"No trial in {path_to_results_dir} reported metric {METRIC!r}")

best_checkpoint = analysis.get_best_checkpoint(best_trial)
if best_checkpoint is None:
    # get_best_checkpoint returns None when no checkpoint carries METRIC;
    # fail here with a clear message instead of inside agent.restore().
    raise RuntimeError(f"No checkpoint with metric {METRIC!r} found for trial {best_trial}")

# NOTE(review): under PBT, best_trial.config presumably holds the trial's most recent
# (possibly mutated) hyperparameters, which may differ from those in effect when
# best_checkpoint was written; harmless as long as the mutated keys do not change
# the model architecture — confirm.
config = best_trial.config
agent = PPOTrainer(config=config)
agent.restore(best_checkpoint)

[2m[36m(RolloutWorker pid=12347)[0m   logger.warn(
[2m[36m(RolloutWorker pid=12345)[0m   logger.warn(
2023-08-10 20:35:13,529	INFO trainable.py:495 -- Restored on 192.168.0.16 from checkpoint: /home/dibya/Dropbox/programming_projects/inventory_management_rl_experiments/experiments_full_length/experiment_results/pbt/PPO_inventory_env_351d6_00003_3_sgd_minibatch_size=64,train_batch_size=512_2023-03-16_16-35-57/checkpoint_011300/checkpoint-11300
2023-08-10 20:35:13,530	INFO trainable.py:503 -- Current state after restoring: {'_iteration': 11300, '_timesteps_total': 16230400, '_time_total': 91218.1564540863, '_episodes_total': 180320}
[2m[36m(RolloutWorker pid=12343)[0m   logger.warn(
[2m[36m(RolloutWorker pid=12341)[0m   logger.warn(


## Measuring average performance over 100,000 episodes with a fixed environment seed

In [2]:
import numpy as np

# Evaluation environment: same observation filter as training, raw (unfiltered)
# rewards, and a fixed environment seed. NOTE(review): only the environment is
# seeded here; if the agent samples actions stochastically, results can still vary.
env = inventory_env_creator(
    {"obs_filter": "my_normalize", "reward_filter": None},
    seed=0,
)
num_episodes = 100000


def _episode_return(policy_agent, episode_env):
    """Play a single episode with ``policy_agent`` and return the summed reward."""
    observation = episode_env.reset()
    total = 0
    done = False
    while not done:
        raw_action = policy_agent.compute_action(observation, unsquash_action=True)
        # np.around snaps the continuous action to the nearest integer
        # (presumably the env expects integer order quantities — confirm).
        observation, reward, done, _ = episode_env.step(np.around(raw_action))
        total += reward
    return total


all_r = [_episode_return(agent, env) for _ in range(num_episodes)]
baseline = sum(all_r) / num_episodes
print(baseline)

  logger.warn(


178504.99510324348


## Population Based Training (PBT) lives up to its promise!

- More stable training.
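- Higher performance: the best PBT checkpoint averages about 178,505 reward per episode over 100,000 evaluation episodes.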