Add a training example using RLLib (#72)
* fix remote env problems
* finish this script
* update this example
PENG Zhenghao committed Sep 13, 2021
1 parent df2c115 · commit 65cb8e4
Showing 6 changed files with 238 additions and 34 deletions.
@@ -0,0 +1,26 @@
########################
Training with RLLib
########################

We provide a script demonstrating how to use `RLLib <https://docs.ray.io/en/latest/rllib.html>`_ to
train RL agents:

.. code-block:: shell

    # Make sure the current folder does not have a sub-folder named metadrive
    python -m metadrive.examples.train_generalization_experiment

    # You can also use GPUs and a customized experiment name:
    python -m metadrive.examples.train_generalization_experiment \
        --exp-name CUSTOMIZED_EXP_NAME \
        --num-gpus HOW_MANY_GPUS_IN_THIS_MACHINE
In this example, we leave the training hyper-parameter :code:`config["num_envs_per_worker"] = 1` at its default, so that each process (Ray worker) contains only one MetaDrive instance.
We further set the evaluation workers :code:`config["evaluation_num_workers"] = 5`, so that the test-set environments are hosted in separate processes.
By utilizing this feature of RLLib, we avoid the issue of running multiple MetaDrive instances in a single process.
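
The sketch below (not a verbatim excerpt from the script) shows how these two keys sit in an RLLib trainer config:

.. code-block:: python

    config = dict(
        env=MetaDriveEnv,
        # RLLib's default: one environment instance per Ray worker process.
        num_envs_per_worker=1,
        # Host the test-set environments in 5 separate evaluation worker processes.
        evaluation_num_workers=5,
    )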

We welcome more examples using MetaDrive in different contexts! If you would like to share your code, please show it off by opening a new issue. Thanks!

.. note:: We tested this script with :code:`ray==1.2.0`. If you find this script incompatible with newer RLLib, please contact us.
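
To reproduce the tested setup, you can pin the version at install time (assuming a standard pip environment):

.. code-block:: shell

    pip install "ray[rllib]==1.2.0"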
@@ -0,0 +1,206 @@
""" | ||
This script demonstrates how to train a set of policies under different number of training scenarios and test them | ||
in the same test set using rllib. | ||
We verified this script with ray==1.2.0. Please report to use if you find newer version of ray is not compatible with | ||
this script. | ||
""" | ||
import argparse
import copy
from typing import Dict

import numpy as np

from metadrive import MetaDriveEnv

try:
    import ray
    from ray import tune

    from ray.tune import CLIReporter
    from ray.rllib.agents.callbacks import DefaultCallbacks
    from ray.rllib.env import BaseEnv
    from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
    from ray.rllib.policy import Policy
except ImportError:
    ray = None
    raise ValueError("Please install ray through 'pip install ray'.")


class DrivingCallbacks(DefaultCallbacks):
    def on_episode_start(
        self, *, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
        env_index: int, **kwargs
    ):
        episode.user_data["velocity"] = []
        episode.user_data["steering"] = []
        episode.user_data["step_reward"] = []
        episode.user_data["acceleration"] = []
        episode.user_data["cost"] = []

    def on_episode_step(
        self, *, worker: RolloutWorker, base_env: BaseEnv, episode: MultiAgentEpisode, env_index: int, **kwargs
    ):
        info = episode.last_info_for()
        if info is not None:
            episode.user_data["velocity"].append(info["velocity"])
            episode.user_data["steering"].append(info["steering"])
            episode.user_data["step_reward"].append(info["step_reward"])
            episode.user_data["acceleration"].append(info["acceleration"])
            episode.user_data["cost"].append(info["cost"])

    def on_episode_end(
        self, worker: RolloutWorker, base_env: BaseEnv, policies: Dict[str, Policy], episode: MultiAgentEpisode,
        **kwargs
    ):
        arrive_dest = episode.last_info_for()["arrive_dest"]
        crash = episode.last_info_for()["crash"]
        out_of_road = episode.last_info_for()["out_of_road"]
        # No terminal event means the episode was truncated by the horizon (max steps).
        max_step_rate = not (arrive_dest or crash or out_of_road)
        episode.custom_metrics["success_rate"] = float(arrive_dest)
        episode.custom_metrics["crash_rate"] = float(crash)
        episode.custom_metrics["out_of_road_rate"] = float(out_of_road)
        episode.custom_metrics["max_step_rate"] = float(max_step_rate)
        episode.custom_metrics["velocity_max"] = float(np.max(episode.user_data["velocity"]))
        episode.custom_metrics["velocity_mean"] = float(np.mean(episode.user_data["velocity"]))
        episode.custom_metrics["velocity_min"] = float(np.min(episode.user_data["velocity"]))
        episode.custom_metrics["steering_max"] = float(np.max(episode.user_data["steering"]))
        episode.custom_metrics["steering_mean"] = float(np.mean(episode.user_data["steering"]))
        episode.custom_metrics["steering_min"] = float(np.min(episode.user_data["steering"]))
        episode.custom_metrics["acceleration_min"] = float(np.min(episode.user_data["acceleration"]))
        episode.custom_metrics["acceleration_mean"] = float(np.mean(episode.user_data["acceleration"]))
        episode.custom_metrics["acceleration_max"] = float(np.max(episode.user_data["acceleration"]))
        episode.custom_metrics["step_reward_max"] = float(np.max(episode.user_data["step_reward"]))
        episode.custom_metrics["step_reward_mean"] = float(np.mean(episode.user_data["step_reward"]))
        episode.custom_metrics["step_reward_min"] = float(np.min(episode.user_data["step_reward"]))
        episode.custom_metrics["cost"] = float(sum(episode.user_data["cost"]))

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        result["success"] = np.nan
        result["crash"] = np.nan
        result["out"] = np.nan
        result["max_step"] = np.nan
        result["length"] = result["episode_len_mean"]
        result["cost"] = np.nan
        if "custom_metrics" not in result:
            return

        if "success_rate_mean" in result["custom_metrics"]:
            result["success"] = result["custom_metrics"]["success_rate_mean"]
            result["crash"] = result["custom_metrics"]["crash_rate_mean"]
            result["out"] = result["custom_metrics"]["out_of_road_rate_mean"]
            result["max_step"] = result["custom_metrics"]["max_step_rate_mean"]
        if "cost_mean" in result["custom_metrics"]:
            result["cost"] = result["custom_metrics"]["cost_mean"]


def train(
    trainer,
    config,
    stop,
    exp_name,
    num_gpus=0,
    test_mode=False,
    checkpoint_freq=10,
    keep_checkpoints_num=None,
    custom_callback=None,
    max_failures=5,
    **kwargs
):
    ray.init(num_gpus=num_gpus)
    used_config = {
        "callbacks": custom_callback if custom_callback else DrivingCallbacks,  # Must have!
    }
    used_config.update(config)
    config = copy.deepcopy(used_config)

    if not isinstance(stop, dict) and stop is not None:
        assert np.isscalar(stop)
        stop = {"timesteps_total": int(stop)}

    if keep_checkpoints_num is not None and not test_mode:
        assert isinstance(keep_checkpoints_num, int)
        kwargs["keep_checkpoints_num"] = keep_checkpoints_num
        kwargs["checkpoint_score_attr"] = "episode_reward_mean"

    metric_columns = CLIReporter.DEFAULT_COLUMNS.copy()
    progress_reporter = CLIReporter(metric_columns)
    progress_reporter.add_metric_column("success")
    progress_reporter.add_metric_column("crash")
    progress_reporter.add_metric_column("out")
    progress_reporter.add_metric_column("max_step")
    progress_reporter.add_metric_column("length")
    progress_reporter.add_metric_column("cost")
    kwargs["progress_reporter"] = progress_reporter

    # Start training
    analysis = tune.run(
        trainer,
        name=exp_name,
        checkpoint_freq=checkpoint_freq,
        checkpoint_at_end=True if "checkpoint_at_end" not in kwargs else kwargs.pop("checkpoint_at_end"),
        stop=stop,
        config=config,
        max_failures=max_failures if not test_mode else 0,
        reuse_actors=False,
        local_dir=".",
        **kwargs
    )
    return analysis


def get_train_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default="generalization_experiment")
    parser.add_argument("--num-gpus", type=int, default=0)
    return parser


if __name__ == '__main__':
    args = get_train_parser().parse_args()
    exp_name = args.exp_name
    stop = int(1000_0000)  # Stop after 10M environment steps in total.
    config = dict(

        # ===== Training Environment =====
        # Train the policies in scenario sets with different numbers of scenarios.
        env=MetaDriveEnv,
        env_config=dict(
            environment_num=tune.grid_search([1, 5, 10, 20, 50, 100, 300, 1000]),
            start_seed=tune.grid_search([5000, 6000, 7000]),
            random_traffic=False,
        ),

        # ===== Evaluation =====
        # Evaluate the trained policies in 200 unseen scenarios.
        evaluation_interval=2,
        evaluation_num_episodes=40,
        metrics_smoothing_episodes=200,
        evaluation_config=dict(env_config=dict(environment_num=200, start_seed=0)),
        evaluation_num_workers=5,

        # ===== Training =====
        # Hyper-parameters for PPO
        horizon=1000,
        rollout_fragment_length=200,
        sgd_minibatch_size=256,
        train_batch_size=20000,
        num_sgd_iter=10,
        lr=3e-4,
        num_workers=5,
        # "lambda" is a reserved keyword in Python, so the GAE coefficient is passed via dict unpacking.
        **{"lambda": 0.95},

        # ===== Resources Specification =====
        # Give the trainer process a fractional GPU when GPUs are available.
        num_gpus=0.25 if args.num_gpus != 0 else 0,
        num_cpus_per_worker=0.2,
        num_cpus_for_driver=0.5,
    )

    train(
        "PPO",
        exp_name=exp_name,
        keep_checkpoints_num=5,
        stop=stop,
        config=config,
        num_gpus=args.num_gpus,
    )
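
Not part of the committed script: once tune.run returns, the resulting analysis object can be queried for the best trial and checkpoint. A minimal sketch, assuming the Tune API of ray 1.2.0:

    # Hypothetical follow-up; "analysis" is the object returned by train() above.
    best_trial = analysis.get_best_trial(metric="episode_reward_mean", mode="max")
    print("Best trial config:", best_trial.config)
    print("Best checkpoint:", analysis.get_best_checkpoint(best_trial, metric="episode_reward_mean", mode="max"))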