From e21cfcba4c31ddff404f5a3e463aa2c337e4612a Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Sun, 25 Aug 2024 23:21:14 +0800 Subject: [PATCH] Adds the Hydra configuration system for RL training --- scripts/rsl_rl/cli_args.py | 33 +++++++++++++++++++++++---------- scripts/rsl_rl/train.py | 31 ++++++++++++++++--------------- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/scripts/rsl_rl/cli_args.py b/scripts/rsl_rl/cli_args.py index ea91c7af..d0689c65 100644 --- a/scripts/rsl_rl/cli_args.py +++ b/scripts/rsl_rl/cli_args.py @@ -47,23 +47,36 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol # load the default configuration rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point") + rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli) + return rslrl_cfg + + +def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace): + """Update configuration for RSL-RL agent based on inputs. + Args: + agent_cfg: The configuration for RSL-RL agent. + args_cli: The command line arguments. + + Returns: + The updated configuration for RSL-RL agent based on inputs. + """ # override the default configuration with CLI arguments if args_cli.seed is not None: - rslrl_cfg.seed = args_cli.seed + agent_cfg.seed = args_cli.seed if args_cli.resume is not None: - rslrl_cfg.resume = args_cli.resume + agent_cfg.resume = args_cli.resume if args_cli.load_run is not None: - rslrl_cfg.load_run = args_cli.load_run + agent_cfg.load_run = args_cli.load_run if args_cli.checkpoint is not None: - rslrl_cfg.load_checkpoint = args_cli.checkpoint + agent_cfg.load_checkpoint = args_cli.checkpoint if args_cli.run_name is not None: - rslrl_cfg.run_name = args_cli.run_name + agent_cfg.run_name = args_cli.run_name if args_cli.logger is not None: - rslrl_cfg.logger = args_cli.logger + agent_cfg.logger = args_cli.logger # set the project name for wandb and neptune - if rslrl_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name: - rslrl_cfg.wandb_project = args_cli.log_project_name - rslrl_cfg.neptune_project = args_cli.log_project_name + if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name: + agent_cfg.wandb_project = args_cli.log_project_name + agent_cfg.neptune_project = args_cli.log_project_name - return rslrl_cfg + return agent_cfg diff --git a/scripts/rsl_rl/train.py b/scripts/rsl_rl/train.py index c1638831..171ecb5f 100644 --- a/scripts/rsl_rl/train.py +++ b/scripts/rsl_rl/train.py @@ -3,6 +3,7 @@ """Launch Isaac Sim Simulator first.""" import argparse +import sys from omni.isaac.lab.app import AppLauncher @@ -14,9 +15,6 @@ parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.") parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).") parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).") -parser.add_argument( - "--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations." -) parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.") parser.add_argument("--task", type=str, default=None, help="Name of the task.") parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment") @@ -25,11 +23,15 @@ cli_args.add_rsl_rl_args(parser) # append AppLauncher cli args AppLauncher.add_app_launcher_args(parser) -args_cli = parser.parse_args() +args_cli, hydra_args = parser.parse_known_args() + # always enable cameras to record video if args_cli.video: args_cli.enable_cameras = True +# clear out sys.argv for Hydra +sys.argv = [sys.argv[0]] + hydra_args + # launch omniverse app app_launcher = AppLauncher(args_cli) simulation_app = app_launcher.app @@ -46,10 +48,11 @@ # Import extensions to set up environment tasks import ext_template.tasks # noqa: F401 -from omni.isaac.lab.envs import ManagerBasedRLEnvCfg +from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg from omni.isaac.lab.utils.dict import print_dict from omni.isaac.lab.utils.io import dump_pickle, dump_yaml -from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg +from omni.isaac.lab_tasks.utils import get_checkpoint_path +from omni.isaac.lab_tasks.utils.hydra import hydra_task_config from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper torch.backends.cuda.matmul.allow_tf32 = True @@ -58,13 +61,15 @@ torch.backends.cudnn.benchmark = False -def main(): +@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point") +def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg): """Train with RSL-RL agent.""" - # parse configuration - env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg( - args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric + # override configurations with non-hydra CLI arguments + agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli) + env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs + agent_cfg.max_iterations = ( + args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg.max_iterations ) - agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli) # specify directory for logging experiments log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name) @@ -76,10 +81,6 @@ def main(): log_dir += f"_{agent_cfg.run_name}" log_dir = os.path.join(log_root_path, log_dir) - # max iterations for training - if args_cli.max_iterations: - agent_cfg.max_iterations = args_cli.max_iterations - # create isaac environment env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None) # wrap for video recording