Skip to content

Commit

Permalink
[RLlib] DreamerV3: rllib train and README.md fixes. (ray-project#38259)
Browse files Browse the repository at this point in the history

Signed-off-by: e428265 <arvind.chandramouli@lmco.com>
  • Loading branch information
sven1977 authored and arvind-chandra committed Aug 31, 2023
1 parent a139326 commit 50b9323
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
6 changes: 3 additions & 3 deletions rllib/algorithms/dreamerv3/README.md
Expand Up @@ -62,13 +62,13 @@ in combination with the following scripts and command lines in order to run RLli
### Atari100k
```shell
$ cd ray/rllib
$ rllib train --file ../tuned_examples/dreamerv3/atari_100k.py --env ALE/Pong-v5
$ rllib train file tuned_examples/dreamerv3/atari_100k.py --env ALE/Pong-v5
```

### DeepMind Control Suite (vision)
```shell
$ cd ray/rllib/tests
$ rllib train --file ../tuned_examples/dreamerv3/dm_control_suite_vision.py --env DMC/cartpole/swingup
$ cd ray/rllib
$ rllib train file tuned_examples/dreamerv3/dm_control_suite_vision.py --env DMC/cartpole/swingup
```
Other `--env` options for the DM Control Suite would be `--env DMC/hopper/hop`, `--env DMC/walker/walk`, etc.
Note that you can also switch on WandB logging with the above script via the options
Expand Down
17 changes: 13 additions & 4 deletions rllib/common.py
Expand Up @@ -144,13 +144,15 @@ def get_help(key: str) -> str:
ray_num_nodes="Emulate multiple cluster nodes for debugging.",
ray_object_store_memory="--object-store-memory to use if starting a new cluster.",
upload_dir="Optional URI to sync training results to (e.g. s3://bucket).",
trace="Whether to attempt to enable tracing for eager mode.",
trace="Whether to attempt to enable eager-tracing for framework=tf2.",
torch="Whether to use PyTorch (instead of tf) as the DL framework. "
"This argument is deprecated, please use --framework to select 'torch' "
"as backend.",
eager="Whether to attempt to enable TensorFlow eager execution. "
"This argument is deprecated, please choose 'tf2' in "
"--framework to select eager mode.",
wandb_key="An optional WandB API key for logging all results to your WandB "
"account.",
wandb_project="An optional project name under which to store the training results.",
wandb_run_name="An optional name for the specific run under which to store the "
"training results.",
)


Expand Down Expand Up @@ -247,6 +249,13 @@ class CLIArguments:
RayObjectStoreMemory = typer.Option(
None, help=train_help.get("ray_object_store_memory")
)
WandBKey = typer.Option(None, "--wandb-key", help=train_help.get("wandb_key"))
WandBProject = typer.Option(
    None, "--wandb-project", help=train_help.get("wandb_project")
)
WandBRunName = typer.Option(
    None, "--wandb-run-name", help=train_help.get("wandb_run_name")
)
# __cli_train_end__

# Eval arguments
Expand Down
13 changes: 10 additions & 3 deletions rllib/train.py
Expand Up @@ -132,6 +132,8 @@ def file(
config_file: str = cli.ConfigFile,
# stopping conditions
stop: Optional[str] = cli.Stop,
# Environment override.
env: Optional[str] = cli.Env,
# Checkpointing
checkpoint_freq: int = cli.CheckpointFreq,
checkpoint_at_end: bool = cli.CheckpointAtEnd,
Expand All @@ -143,9 +145,9 @@ def file(
framework: FrameworkEnum = cli.Framework,
trace: bool = cli.Trace,
# WandB options.
wandb_key: Optional[str] = None,
wandb_project: Optional[str] = None,
wandb_run_name: Optional[str] = None,
wandb_key: Optional[str] = cli.WandBKey,
wandb_project: Optional[str] = cli.WandBProject,
wandb_run_name: Optional[str] = cli.WandBRunName,
# Ray cluster options.
local_mode: bool = cli.LocalMode,
ray_address: Optional[str] = cli.RayAddress,
Expand Down Expand Up @@ -193,6 +195,11 @@ def file(
experiment = experiments[exp_name]
algo = experiment["run"]

# Override the env from the config by the value given on the command line.
if env is not None:
experiment["env"] = env

# WandB logging support.
callbacks = None
if wandb_key is not None:
project = wandb_project or (
Expand Down

0 comments on commit 50b9323

Please sign in to comment.