Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 112 additions & 20 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

import argparse
import json
import logging
import os
import sys
Expand Down Expand Up @@ -41,6 +42,12 @@
"missing component name, either provide it from the CLI or in .torchxconfig"
)

LOCAL_SCHEDULER_WARNING_MSG = (
"`local` scheduler is deprecated and will be"
" removed in the near future,"
" please use other variants of the local scheduler"
" (e.g. `local_cwd`)"
)

logger: logging.Logger = logging.getLogger(__name__)

Expand All @@ -54,7 +61,7 @@ class TorchXRunArgs:
dryrun: bool = False
wait: bool = False
log: bool = False
workspace: str = f"file://{Path.cwd()}"
workspace: str = ""
parent_run_id: Optional[str] = None
tee_logs: bool = False
component_args: Dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -83,7 +90,10 @@ def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:
"Please check your JSON and try launching again.",
)

return TorchXRunArgs(**filtered_json_data)
torchx_args = TorchXRunArgs(**filtered_json_data)
if torchx_args.workspace == "":
torchx_args.workspace = f"file://{Path.cwd()}"
return torchx_args


def torchx_run_args_from_argparse(
Expand Down Expand Up @@ -256,35 +266,35 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
default=False,
help="Add additional prefix to log lines to indicate which replica is printing the log",
)
subparser.add_argument(
"--stdin",
action="store_true",
default=False,
help="Read JSON input from stdin to parse into torchx run args and run the component.",
)
subparser.add_argument(
"component_name_and_args",
nargs=argparse.REMAINDER,
)

def _run(self, runner: Runner, args: argparse.Namespace) -> None:
def _run_inner(self, runner: Runner, args: TorchXRunArgs) -> None:
if args.scheduler == "local":
logger.warning(
"`local` scheduler is deprecated and will be"
" removed in the near future,"
" please use other variants of the local scheduler"
" (e.g. `local_cwd`)"
)

cfg = dict(runner.cfg_from_str(args.scheduler, args.scheduler_args))
config.apply(scheduler=args.scheduler, cfg=cfg)
logger.warning(LOCAL_SCHEDULER_WARNING_MSG)

component, component_args = _parse_component_name_and_args(
args.component_name_and_args,
none_throws(self._subparser),
config.apply(scheduler=args.scheduler, cfg=args.scheduler_cfg)
component_args = (
args.component_args_str
if args.component_args_str != []
else args.component_args
)
try:
if args.dryrun:
dryrun_info = runner.dryrun_component(
component,
args.component_name,
component_args,
args.scheduler,
workspace=args.workspace,
cfg=cfg,
cfg=args.scheduler_cfg,
parent_run_id=args.parent_run_id,
)
print(
Expand All @@ -295,11 +305,11 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
print("\n=== SCHEDULER REQUEST ===\n" f"{dryrun_info}")
else:
app_handle = runner.run_component(
component,
args.component_name,
component_args,
args.scheduler,
workspace=args.workspace,
cfg=cfg,
cfg=args.scheduler_cfg,
parent_run_id=args.parent_run_id,
)
# DO NOT delete this line. It is used by slurm tests to retrieve the app id
Expand All @@ -320,7 +330,9 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
)

except (ComponentValidationException, ComponentNotFoundException) as e:
error_msg = f"\nFailed to run component `{component}` got errors: \n {e}"
error_msg = (
f"\nFailed to run component `{args.component_name}` got errors: \n {e}"
)
logger.error(error_msg)
sys.exit(1)
except specs.InvalidRunConfigException as e:
Expand All @@ -335,6 +347,86 @@ def _run(self, runner: Runner, args: argparse.Namespace) -> None:
print(error_msg % (e, args.scheduler, args.scheduler), file=sys.stderr)
sys.exit(1)

def _run_from_cli_args(self, runner: Runner, args: argparse.Namespace) -> None:
scheduler_opts = runner.scheduler_run_opts(args.scheduler)
cfg = scheduler_opts.cfg_from_str(args.scheduler_args)

component, component_args = _parse_component_name_and_args(
args.component_name_and_args,
none_throws(self._subparser),
)
torchx_run_args = torchx_run_args_from_argparse(
args, component, component_args, cfg
)
self._run_inner(runner, torchx_run_args)

def _run_from_stdin_args(self, runner: Runner, stdin_data: Dict[str, Any]) -> None:
torchx_run_args = torchx_run_args_from_json(stdin_data)
scheduler_opts = runner.scheduler_run_opts(torchx_run_args.scheduler)
cfg = scheduler_opts.cfg_from_json_repr(
json.dumps(torchx_run_args.scheduler_args)
)
torchx_run_args.scheduler_cfg = cfg
self._run_inner(runner, torchx_run_args)

def torchx_json_from_stdin(self) -> Dict[str, Any]:
try:
stdin_data_json = json.load(sys.stdin)
if not isinstance(stdin_data_json, dict):
logger.error(
"Invalid JSON input for `torchx run` command. Expected a dictionary."
)
sys.exit(1)
return stdin_data_json
except (json.JSONDecodeError, EOFError):
logger.error(
"Unable to parse JSON input for `torchx run` command, please make sure it's a valid JSON input."
)
sys.exit(1)

def verify_no_extra_args(self, args: argparse.Namespace) -> None:
"""
Verifies that only --stdin was provided when using stdin mode.
"""
if not args.stdin:
return

subparser = none_throws(self._subparser)
conflicting_args = []

# Check each argument against its default value
for action in subparser._actions:
if action.dest == "stdin": # Skip stdin itself
continue
if action.dest == "help": # Skip help
continue

current_value = getattr(args, action.dest, None)
default_value = action.default

# For arguments that differ from default
if current_value != default_value:
# Handle special cases where non-default doesn't mean explicitly set
if action.dest == "component_name_and_args" and current_value == []:
continue # Empty list is still default
print(f"*********\n {default_value} = {current_value}")
conflicting_args.append(f"--{action.dest.replace('_', '-')}")

if conflicting_args:
subparser.error(
f"Cannot specify {', '.join(conflicting_args)} when using --stdin. "
"All configuration should be provided in JSON input."
)

def _run(self, runner: Runner, args: argparse.Namespace) -> None:
# Verify no conflicting arguments when using to loop over the stdin
self.verify_no_extra_args(args)
if args.stdin:
stdin_data_json = self.torchx_json_from_stdin()
self._run_from_stdin_args(runner, stdin_data_json)
else:
self._run_from_cli_args(runner, args)

def run(self, args: argparse.Namespace) -> None:
os.environ["TORCHX_CONTEXT_NAME"] = os.getenv("TORCHX_CONTEXT_NAME", "cli_run")
component_defaults = load_sections(prefix="component")
Expand Down
105 changes: 96 additions & 9 deletions torchx/cli/test/cmd_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import argparse
import dataclasses
import io

import os
import shutil
import signal
Expand Down Expand Up @@ -67,14 +66,12 @@ def tearDown(self) -> None:
torchxconfig.called_args = set()

def test_run_with_multiple_scheduler_args(self) -> None:

args = ["--scheduler_args", "first_args", "--scheduler_args", "second_args"]
with self.assertRaises(SystemExit) as cm:
self.parser.parse_args(args)
self.assertEqual(cm.exception.code, 1)

def test_run_with_multiple_schedule_args(self) -> None:

args = [
"--scheduler",
"local_cwd",
Expand Down Expand Up @@ -179,13 +176,13 @@ def test_conf_file_missing(self) -> None:
with patch(
"torchx.runner.config.DEFAULT_CONFIG_DIRS", return_value=[self.tmpdir]
):
args = self.parser.parse_args(
[
"--scheduler",
"local_cwd",
]
)
with self.assertRaises(SystemExit):
args = self.parser.parse_args(
[
"--scheduler",
"local_cwd",
]
)
self.cmd_run.run(args)

@patch("torchx.runner.Runner.run")
Expand Down Expand Up @@ -364,6 +361,96 @@ def test_parse_component_name_and_args_with_default(self) -> None:
_parse_component_name_and_args(["-m", "hello"], sp, dirs),
)

def test_verify_no_extra_args_stdin_only(self) -> None:
"""Test that only --stdin is allowed when using stdin mode."""
args = self.parser.parse_args(["--stdin"])
# Should not raise any exception
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_no_stdin(self) -> None:
"""Test that verification is skipped when not using stdin."""
args = self.parser.parse_args(["--scheduler", "local_cwd", "utils.echo"])
# Should not raise any exception
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_component_name(self) -> None:
"""Test that component name conflicts with stdin."""
args = self.parser.parse_args(["--stdin", "utils.echo"])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_scheduler_args(self) -> None:
"""Test that scheduler_args conflicts with stdin."""
args = self.parser.parse_args(["--stdin", "--scheduler_args", "cluster=test"])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_scheduler(self) -> None:
"""Test that non-default scheduler conflicts with stdin."""
args = self.parser.parse_args(["--stdin", "--scheduler", "kubernetes"])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_boolean_flags(self) -> None:
"""Test that boolean flags conflict with stdin."""
boolean_flags = ["--dryrun", "--wait", "--log", "--tee_logs"]
for flag in boolean_flags:
args = self.parser.parse_args(["--stdin", flag])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_value_args(self) -> None:
"""Test that arguments with values conflict with stdin."""
args = self.parser.parse_args(["--stdin", "--workspace", "file:///custom/path"])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

args = self.parser.parse_args(["--stdin", "--parent_run_id", "experiment_123"])
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_multiple_conflicts(self) -> None:
"""Test that multiple conflicting arguments with stdin are detected."""
args = self.parser.parse_args(
["--stdin", "--dryrun", "--wait", "--scheduler_args", "cluster=test"]
)
with self.assertRaises(SystemExit):
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_default_scheduler(self) -> None:
"""Test that using default scheduler with stdin doesn't conflict."""
# Get the default scheduler and use it explicitly - should not conflict
from torchx.schedulers import get_default_scheduler_name

default_scheduler = get_default_scheduler_name()

args = self.parser.parse_args(["--stdin", "--scheduler", default_scheduler])
# Should not raise any exception since it's the same as default
self.cmd_run.verify_no_extra_args(args)

def test_verify_no_extra_args_stdin_with_default_workspace(self) -> None:
"""Test that using default workspace with stdin doesn't conflict."""
# Get the actual default workspace from a fresh parser
fresh_parser = argparse.ArgumentParser()
fresh_cmd_run = CmdRun()
fresh_cmd_run.add_arguments(fresh_parser)

# Find the workspace argument's default value
workspace_default = None
for action in fresh_parser._actions:
if action.dest == "workspace":
workspace_default = action.default
break

self.assertIsNotNone(
workspace_default, "workspace argument should have a default"
)

# Use the actual default - this should not conflict with stdin
args = fresh_parser.parse_args(["--stdin", "--workspace", workspace_default])
# Should not raise any exception since it's the same as default
fresh_cmd_run.verify_no_extra_args(args)


class CmdBuiltinTest(unittest.TestCase):
def test_run(self) -> None:
Expand Down
10 changes: 7 additions & 3 deletions torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Type,
TYPE_CHECKING,
TypeVar,
Union,
)

from torchx.runner.events import log_event
Expand Down Expand Up @@ -167,7 +168,7 @@ def close(self) -> None:
def run_component(
self,
component: str,
component_args: List[str],
component_args: Union[list[str], dict[str, Any]],
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[str] = None,
Expand Down Expand Up @@ -226,7 +227,7 @@ def run_component(
def dryrun_component(
self,
component: str,
component_args: List[str],
component_args: Union[list[str], dict[str, Any]],
scheduler: str,
cfg: Optional[Mapping[str, CfgVal]] = None,
workspace: Optional[str] = None,
Expand All @@ -237,10 +238,13 @@ def dryrun_component(
component, but just returns what "would" have run.
"""
component_def = get_component(component)
args_from_cli = component_args if isinstance(component_args, list) else []
args_from_json = component_args if isinstance(component_args, dict) else {}
app = materialize_appdef(
component_def.fn,
component_args,
args_from_cli,
self._component_defaults.get(component, None),
args_from_json,
)
return self.dryrun(
app,
Expand Down
3 changes: 3 additions & 0 deletions torchx/specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,5 +225,8 @@ def gpu_x_1() -> Dict[str, Resource]:
"make_app_handle",
"materialize_appdef",
"parse_mounts",
"torchx_run_args_from_argparse",
"torchx_run_args_from_json",
"TorchXRunArgs",
"ALL",
]
Loading
Loading