Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,16 @@ def _get_torchx_stdin_args(
if not args.stdin:
return None
if self._stdin_data_json is None:
self._stdin_data_json = self.torchx_json_from_stdin()
self._stdin_data_json = self.torchx_json_from_stdin(args)
return self._stdin_data_json

def torchx_json_from_stdin(self) -> Dict[str, Any]:
def torchx_json_from_stdin(
self, args: Optional[argparse.Namespace] = None
) -> Dict[str, Any]:
try:
stdin_data_json = json.load(sys.stdin)
if args and args.dryrun:
stdin_data_json["dryrun"] = True
if not isinstance(stdin_data_json, dict):
logger.error(
"Invalid JSON input for `torchx run` command. Expected a dictionary."
Expand Down Expand Up @@ -413,6 +417,8 @@ def verify_no_extra_args(self, args: argparse.Namespace) -> None:
continue
if action.dest == "help": # Skip help
continue
if action.dest == "dryrun": # Skip dryrun
continue

current_value = getattr(args, action.dest, None)
default_value = action.default
Expand Down
7 changes: 6 additions & 1 deletion torchx/cli/test/cmd_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,12 +393,17 @@ def test_verify_no_extra_args_stdin_with_scheduler(self) -> None:

def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:
"""Test that boolean flags conflict with stdin."""
boolean_flags = ["--dryrun", "--wait", "--log", "--tee_logs"]
boolean_flags = ["--wait", "--log", "--tee_logs"]
for flag in boolean_flags:
args = self.parser.parse_args(["--stdin", flag])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_dryrun_pass(self) -> None:
"""Test that dryrun is allowed."""
args = self.parser.parse_args(["--stdin", "--dryrun"])
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
"""Test that arguments with values conflict with stdin."""
args = self.parser.parse_args(["--stdin", "--workspace", "/custom/path"])
Expand Down
Loading