Skip to content

Commit

Permalink
Added support for using multiple envs for evaluation (DLR-RM#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed May 29, 2021
1 parent f77d9e1 commit 8f41f40
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 7 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Release 1.1.0a6 (WIP)
## Release 1.1.0a10 (WIP)

### Breaking Changes
- Upgrade to SB3 >= 1.1.0a6
- Upgrade to SB3 >= 1.1.0a10 (master version)
- Upgrade to sb3-contrib >= 1.1.0a6
- Add timeout handling (cf SB3 doc)
- `HER` is now a replay buffer class and no more an algorithm
Expand All @@ -13,6 +13,7 @@
- Add support for recording videos of training experiments (@mcres)
- Add support for dictionary observations
- Added experimental parallel training (with `utils.callbacks.ParallelTrainCallback`)
- Added support for using multiple envs for evaluation

### Bug fixes
- Fixed video rendering for PyBullet envs on Linux
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ For example (with tensorboard support):
python train.py --algo ppo --env CartPole-v1 --tensorboard-log /tmp/stable-baselines/
```

Evaluate the agent every 10000 steps using 10 episodes for evaluation:
Evaluate the agent every 10000 steps using 10 episodes for evaluation (using only one evaluation env):
```
python train.py --algo sac --env HalfCheetahBulletEnv-v0 --eval-freq 10000 --eval-episodes 10
python train.py --algo sac --env HalfCheetahBulletEnv-v0 --eval-freq 10000 --eval-episodes 10 --n-eval-envs 1
```

Save a checkpoint of the agent every 100000 steps:
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
stable-baselines3[extra,tests,docs]>=1.1.0a7
git+https://github.com/DLR-RM/stable-baselines3#egg=stable-baselines3[extra,tests,docs]
box2d-py==2.3.8
pybullet
gym-minigrid
Expand Down
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"--eval-freq", help="Evaluate the agent every n steps (if negative, no evaluation)", default=10000, type=int
)
parser.add_argument("--eval-episodes", help="Number of episodes to use for evaluation", default=5, type=int)
parser.add_argument("--n-eval-envs", help="Number of environments for evaluation", default=1, type=int)
parser.add_argument("--save-freq", help="Save the model every n steps (if negative, no checkpoint)", default=-1, type=int)
parser.add_argument(
"--save-replay-buffer", help="Save the replay buffer too (when applicable)", action="store_true", default=False
Expand Down Expand Up @@ -157,6 +158,7 @@
save_replay_buffer=args.save_replay_buffer,
verbose=args.verbose,
vec_env_type=args.vec_env,
n_eval_envs=args.n_eval_envs,
)

# Prepare experiment and launch hyperparameter optimization if needed
Expand Down
4 changes: 3 additions & 1 deletion utils/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
save_replay_buffer: bool = False,
verbose: int = 1,
vec_env_type: str = "dummy",
n_eval_envs: int = 1,
):
super(ExperimentManager, self).__init__()
self.algo = algo
Expand All @@ -97,6 +98,7 @@ def __init__(
self.save_freq = save_freq
self.eval_freq = eval_freq
self.n_eval_episodes = n_eval_episodes
self.n_eval_envs = n_eval_envs

self.n_envs = 1 # it will be updated when reading hyperparams
self.n_actions = None # For DDPG/TD3 action noise objects
Expand Down Expand Up @@ -400,7 +402,7 @@ def create_callbacks(self):

save_vec_normalize = SaveVecNormalizeCallback(save_freq=1, save_path=self.params_path)
eval_callback = EvalCallback(
self.create_envs(1, eval_env=True),
self.create_envs(self.n_eval_envs, eval_env=True),
callback_on_new_best=save_vec_normalize,
best_model_save_path=self.save_path,
n_eval_episodes=self.n_eval_episodes,
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.1.0a6
1.1.0a10

0 comments on commit 8f41f40

Please sign in to comment.