From 50b9323c57982099cdeadcb8f554651670fc95ea Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 10 Aug 2023 17:32:30 +0200 Subject: [PATCH] [RLlib] DreamerV3: `rllib train` and README.md fixes. (#38259) Signed-off-by: e428265 --- rllib/algorithms/dreamerv3/README.md | 6 +++--- rllib/common.py | 17 +++++++++++++---- rllib/train.py | 13 ++++++++++--- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/rllib/algorithms/dreamerv3/README.md b/rllib/algorithms/dreamerv3/README.md index 640d16b368f14..3d481f6261363 100644 --- a/rllib/algorithms/dreamerv3/README.md +++ b/rllib/algorithms/dreamerv3/README.md @@ -62,13 +62,13 @@ in combination with the following scripts and command lines in order to run RLli ### Atari100k ```shell $ cd ray/rllib -$ rllib train --file ../tuned_examples/dreamerv3/atari_100k.py --env ALE/Pong-v5 +$ rllib train file tuned_examples/dreamerv3/atari_100k.py --env ALE/Pong-v5 ``` ### DeepMind Control Suite (vision) ```shell -$ cd ray/rllib/tests -$ rllib train --file ../tuned_examples/dreamerv3/dm_control_suite_vision.py --env DMC/cartpole/swingup +$ cd ray/rllib +$ rllib train file tuned_examples/dreamerv3/dm_control_suite_vision.py --env DMC/cartpole/swingup ``` Other `--env` options for the DM Control Suite would be `--env DMC/hopper/hop`, `--env DMC/walker/walk`, etc.. Note that you can also switch on WandB logging with the above script via the options diff --git a/rllib/common.py b/rllib/common.py index a552958101f39..b2b73a44feb5b 100644 --- a/rllib/common.py +++ b/rllib/common.py @@ -144,13 +144,15 @@ def get_help(key: str) -> str: ray_num_nodes="Emulate multiple cluster nodes for debugging.", ray_object_store_memory="--object-store-memory to use if starting a new cluster.", upload_dir="Optional URI to sync training results to (e.g. s3://bucket).", - trace="Whether to attempt to enable tracing for eager mode.", + trace="Whether to attempt to enable eager-tracing for framework=tf2.", torch="Whether to use PyTorch (instead of tf) as the DL framework. " "This argument is deprecated, please use --framework to select 'torch'" "as backend.", - eager="Whether to attempt to enable TensorFlow eager execution. " - "This argument is deprecated, please choose 'tf2' in " - "--framework to run select eager mode.", + wandb_key="An optional WandB API key for logging all results to your WandB " + "account.", + wandb_project="An optional project name under which to store the training results.", + wandb_run_name="An optional name for the specific run under which to store the " + "training results.", ) @@ -247,6 +249,13 @@ class CLIArguments: RayObjectStoreMemory = typer.Option( None, help=train_help.get("ray_object_store_memory") ) + WandBKey = typer.Option(None, "--wandb-key", help=train_help.get("wandb_key")) + WandBProject = typer.Option( + None, "--wandb-project", help=eval_help.get("wandb_project") + ) + WandBRunName = typer.Option( + None, "--wandb-run-name", help=eval_help.get("wandb_run_name") + ) # __cli_train_end__ # Eval arguments diff --git a/rllib/train.py b/rllib/train.py index 566db230e5245..2286c384b7cbf 100755 --- a/rllib/train.py +++ b/rllib/train.py @@ -132,6 +132,8 @@ def file( config_file: str = cli.ConfigFile, # stopping conditions stop: Optional[str] = cli.Stop, + # Environment override. + env: Optional[str] = cli.Env, # Checkpointing checkpoint_freq: int = cli.CheckpointFreq, checkpoint_at_end: bool = cli.CheckpointAtEnd, @@ -143,9 +145,9 @@ def file( framework: FrameworkEnum = cli.Framework, trace: bool = cli.Trace, # WandB options. - wandb_key: Optional[str] = None, - wandb_project: Optional[str] = None, - wandb_run_name: Optional[str] = None, + wandb_key: Optional[str] = cli.WandBKey, + wandb_project: Optional[str] = cli.WandBProject, + wandb_run_name: Optional[str] = cli.WandBRunName, # Ray cluster options. local_mode: bool = cli.LocalMode, ray_address: Optional[str] = cli.RayAddress, @@ -193,6 +195,11 @@ def file( experiment = experiments[exp_name] algo = experiment["run"] + # Override the env from the config by the value given on the command line. + if env is not None: + experiment["env"] = env + + # WandB logging support. callbacks = None if wandb_key is not None: project = wandb_project or (