diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index 0c0176abb..9dac52281 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -379,12 +379,16 @@ def _get_torchx_stdin_args( if not args.stdin: return None if self._stdin_data_json is None: - self._stdin_data_json = self.torchx_json_from_stdin() + self._stdin_data_json = self.torchx_json_from_stdin(args) return self._stdin_data_json - def torchx_json_from_stdin(self) -> Dict[str, Any]: + def torchx_json_from_stdin( + self, args: Optional[argparse.Namespace] = None + ) -> Dict[str, Any]: try: stdin_data_json = json.load(sys.stdin) + if args and args.dryrun: + stdin_data_json["dryrun"] = True if not isinstance(stdin_data_json, dict): logger.error( "Invalid JSON input for `torchx run` command. Expected a dictionary." @@ -413,6 +417,8 @@ def verify_no_extra_args(self, args: argparse.Namespace) -> None: continue if action.dest == "help": # Skip help continue + if action.dest == "dryrun": # Skip dryrun + continue current_value = getattr(args, action.dest, None) default_value = action.default diff --git a/torchx/cli/test/cmd_run_test.py b/torchx/cli/test/cmd_run_test.py index 8fc632a67..31bf46cb9 100644 --- a/torchx/cli/test/cmd_run_test.py +++ b/torchx/cli/test/cmd_run_test.py @@ -393,12 +393,17 @@ def test_verify_no_extra_args_stdin_with_scheduler(self) -> None: def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None: """Test that boolean flags conflict with stdin.""" - boolean_flags = ["--dryrun", "--wait", "--log", "--tee_logs"] + boolean_flags = ["--wait", "--log", "--tee_logs"] for flag in boolean_flags: args = self.parser.parse_args(["--stdin", flag]) with self.assertRaises(SystemExit): self.cmd_run.verify_no_extra_args(args) + def test_verify_no_extra_args_stdin_dryrun_pass(self) -> None: + """Test that dryrun is allowed.""" + args = self.parser.parse_args(["--stdin", "--dryrun"]) + self.cmd_run.verify_no_extra_args(args) + def test_verify_no_extra_args_stdin_with_value_args(self) -> None: """Test that arguments with values conflict with stdin.""" args = self.parser.parse_args(["--stdin", "--workspace", "/custom/path"])