Skip to content
This repository was archived by the owner on Sep 26, 2025. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions scripts/rsl_rl/cli_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,36 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol

# load the default configuration
rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
return rslrl_cfg


def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
"""Update configuration for RSL-RL agent based on inputs.

Args:
agent_cfg: The configuration for RSL-RL agent.
args_cli: The command line arguments.

Returns:
The updated configuration for RSL-RL agent based on inputs.
"""
# override the default configuration with CLI arguments
if args_cli.seed is not None:
rslrl_cfg.seed = args_cli.seed
agent_cfg.seed = args_cli.seed
if args_cli.resume is not None:
rslrl_cfg.resume = args_cli.resume
agent_cfg.resume = args_cli.resume
if args_cli.load_run is not None:
rslrl_cfg.load_run = args_cli.load_run
agent_cfg.load_run = args_cli.load_run
if args_cli.checkpoint is not None:
rslrl_cfg.load_checkpoint = args_cli.checkpoint
agent_cfg.load_checkpoint = args_cli.checkpoint
if args_cli.run_name is not None:
rslrl_cfg.run_name = args_cli.run_name
agent_cfg.run_name = args_cli.run_name
if args_cli.logger is not None:
rslrl_cfg.logger = args_cli.logger
agent_cfg.logger = args_cli.logger
# set the project name for wandb and neptune
if rslrl_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
rslrl_cfg.wandb_project = args_cli.log_project_name
rslrl_cfg.neptune_project = args_cli.log_project_name
if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
agent_cfg.wandb_project = args_cli.log_project_name
agent_cfg.neptune_project = args_cli.log_project_name

return rslrl_cfg
return agent_cfg
31 changes: 16 additions & 15 deletions scripts/rsl_rl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Launch Isaac Sim Simulator first."""

import argparse
import sys

from omni.isaac.lab.app import AppLauncher

Expand All @@ -14,9 +15,6 @@
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument(
"--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations."
)
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
Expand All @@ -25,11 +23,15 @@
cli_args.add_rsl_rl_args(parser)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
args_cli = parser.parse_args()
args_cli, hydra_args = parser.parse_known_args()

# always enable cameras to record video
if args_cli.video:
args_cli.enable_cameras = True

# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args

# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
Expand All @@ -46,10 +48,11 @@
# Import extensions to set up environment tasks
import ext_template.tasks # noqa: F401

from omni.isaac.lab.envs import ManagerBasedRLEnvCfg
from omni.isaac.lab.envs import DirectRLEnvCfg, ManagerBasedRLEnvCfg
from omni.isaac.lab.utils.dict import print_dict
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
from omni.isaac.lab_tasks.utils import get_checkpoint_path, parse_env_cfg
from omni.isaac.lab_tasks.utils import get_checkpoint_path
from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
from omni.isaac.lab_tasks.utils.wrappers.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper

torch.backends.cuda.matmul.allow_tf32 = True
Expand All @@ -58,13 +61,15 @@
torch.backends.cudnn.benchmark = False


def main():
@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
"""Train with RSL-RL agent."""
# parse configuration
env_cfg: ManagerBasedRLEnvCfg = parse_env_cfg(
args_cli.task, device=args_cli.device, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
# override configurations with non-hydra CLI arguments
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
agent_cfg.max_iterations = (
args_cli.max_iterations if args_cli.max_iterations is not None else agent_cfg.max_iterations
)
agent_cfg: RslRlOnPolicyRunnerCfg = cli_args.parse_rsl_rl_cfg(args_cli.task, args_cli)

# specify directory for logging experiments
log_root_path = os.path.join("logs", "rsl_rl", agent_cfg.experiment_name)
Expand All @@ -76,10 +81,6 @@ def main():
log_dir += f"_{agent_cfg.run_name}"
log_dir = os.path.join(log_root_path, log_dir)

# max iterations for training
if args_cli.max_iterations:
agent_cfg.max_iterations = args_cli.max_iterations

# create isaac environment
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
# wrap for video recording
Expand Down