diff --git a/torchx/cli/cmd_run.py b/torchx/cli/cmd_run.py index 787ee13dd..dd774203c 100644 --- a/torchx/cli/cmd_run.py +++ b/torchx/cli/cmd_run.py @@ -12,11 +12,11 @@ import sys import threading from collections import Counter -from dataclasses import asdict +from dataclasses import asdict, dataclass, field, fields, MISSING as DATACLASS_MISSING from itertools import groupby from pathlib import Path from pprint import pformat -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torchx.specs as specs from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run @@ -25,6 +25,7 @@ from torchx.runner import config, get_runner, Runner from torchx.runner.config import load_sections from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories +from torchx.specs import CfgVal from torchx.specs.finder import ( _Component, ComponentNotFoundException, @@ -44,6 +45,68 @@ logger: logging.Logger = logging.getLogger(__name__) +@dataclass +class TorchXRunArgs: + component_name: str + scheduler: str + scheduler_args: Dict[str, Any] + scheduler_cfg: Dict[str, CfgVal] = field(default_factory=dict) + dryrun: bool = False + wait: bool = False + log: bool = False + workspace: str = f"file://{Path.cwd()}" + parent_run_id: Optional[str] = None + tee_logs: bool = False + component_args: Dict[str, Any] = field(default_factory=dict) + component_args_str: List[str] = field(default_factory=list) + + +def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs: + all_fields = [f.name for f in fields(TorchXRunArgs)] + required_fields = { + f.name + for f in fields(TorchXRunArgs) + if f.default is DATACLASS_MISSING and f.default_factory is DATACLASS_MISSING + } + missing_fields = required_fields - json_data.keys() + if missing_fields: + raise ValueError( + f"The following required fields are missing: {', '.join(missing_fields)}" + ) + + # Fail if there are fields that aren't part of the run command + filtered_json_data = {k: v for k, v in json_data.items() if k in all_fields} + extra_fields = set(json_data.keys()) - set(all_fields) + if extra_fields: + raise ValueError( + f"The following fields are not part of the run command: {', '.join(extra_fields)}.", + "Please check your JSON and try launching again.", + ) + + return TorchXRunArgs(**filtered_json_data) + + +def torchx_run_args_from_argparse( + args: argparse.Namespace, + component_name: str, + component_args: List[str], + scheduler_cfg: Dict[str, CfgVal], +) -> TorchXRunArgs: + return TorchXRunArgs( + component_name=component_name, + scheduler=args.scheduler, + scheduler_args={}, + scheduler_cfg=scheduler_cfg, + dryrun=args.dryrun, + wait=args.wait, + log=args.log, + workspace=args.workspace, + parent_run_id=args.parent_run_id, + tee_logs=args.tee_logs, + component_args_str=component_args, + ) + + def _parse_component_name_and_args( component_name_and_args: List[str], subparser: argparse.ArgumentParser, diff --git a/torchx/cli/test/cmd_run_test.py b/torchx/cli/test/cmd_run_test.py index 38b3d4b43..ed64f8fee 100644 --- a/torchx/cli/test/cmd_run_test.py +++ b/torchx/cli/test/cmd_run_test.py @@ -10,6 +10,7 @@ import argparse import dataclasses import io + import os import shutil import signal @@ -17,15 +18,22 @@ import unittest from contextlib import contextmanager from pathlib import Path -from typing import Generator +from typing import Dict, Generator from unittest.mock import MagicMock, patch from torchx.cli.argparse_util import ArgOnceAction, torchxconfig -from torchx.cli.cmd_run import _parse_component_name_and_args, CmdBuiltins, CmdRun +from torchx.cli.cmd_run import ( + _parse_component_name_and_args, + CmdBuiltins, + CmdRun, + torchx_run_args_from_argparse, + torchx_run_args_from_json, + TorchXRunArgs, +) from torchx.runner.config import ENV_TORCHXCONFIG from torchx.schedulers.local_scheduler import SignalException -from torchx.specs import AppDryRunInfo +from torchx.specs import AppDryRunInfo, CfgVal @contextmanager @@ -384,3 +392,185 @@ def test_print_builtin(self) -> None: cmd_builtins.run(parser.parse_args(["--print", "dist.ddp"])) # nothing to assert, just make sure it runs + + +class TorchXRunArgsTest(unittest.TestCase): + def test_torchx_run_args_from_json(self) -> None: + # Test valid input with all required fields + json_data = { + "scheduler": "local", + "scheduler_args": {"cluster": "test"}, + "component_name": "test_component", + } + result = torchx_run_args_from_json(json_data) + + self.assertIsInstance(result, TorchXRunArgs) + self.assertEqual(result.scheduler, "local") + self.assertEqual(result.scheduler_args, {"cluster": "test"}) + self.assertEqual(result.component_name, "test_component") + # Check defaults are set + self.assertEqual(result.dryrun, False) + self.assertEqual(result.wait, False) + self.assertEqual(result.log, False) + self.assertEqual(result.workspace, f"file://{Path.cwd()}") + self.assertEqual(result.parent_run_id, None) + self.assertEqual(result.tee_logs, False) + self.assertEqual(result.component_args, {}) + self.assertEqual(result.component_args_str, []) + + # Test valid input with optional fields provided + json_data_with_optionals = { + "scheduler": "k8s", + "scheduler_args": {"namespace": "default"}, + "component_name": "my_component", + "component_args": {"param": "test"}, + "component_args_str": ["--param test"], + "dryrun": True, + "wait": True, + "log": True, + "workspace": "file:///custom/path", + "parent_run_id": "parent123", + "tee_logs": True, + } + result2 = torchx_run_args_from_json(json_data_with_optionals) + + self.assertEqual(result2.scheduler, "k8s") + self.assertEqual(result2.scheduler_args, {"namespace": "default"}) + self.assertEqual(result2.component_name, "my_component") + self.assertEqual(result2.component_args, {"param": "test"}) + self.assertEqual(result2.component_args_str, ["--param test"]) + self.assertEqual(result2.dryrun, True) + self.assertEqual(result2.wait, True) + self.assertEqual(result2.log, True) + self.assertEqual(result2.workspace, "file:///custom/path") + self.assertEqual(result2.parent_run_id, "parent123") + self.assertEqual(result2.tee_logs, True) + + # Test missing required field - scheduler + json_missing_scheduler = { + "scheduler_args": {"cluster": "test"}, + "component_name": "test_component", + } + with self.assertRaises(ValueError) as cm: + torchx_run_args_from_json(json_missing_scheduler) + self.assertEqual( + "The following required fields are missing: scheduler", cm.exception.args[0] + ) + + # Test missing required field - component_name + json_missing_component = { + "scheduler": "local", + "scheduler_args": {"cluster": "test"}, + } + with self.assertRaises(ValueError) as cm: + torchx_run_args_from_json(json_missing_component) + self.assertEqual( + "The following required fields are missing: component_name", + cm.exception.args[0], + ) + + # Test missing required field - scheduler_args + json_missing_scheduler_args = { + "scheduler": "local", + "component_name": "test_component", + } + with self.assertRaises(ValueError) as cm: + torchx_run_args_from_json(json_missing_scheduler_args) + self.assertEqual( + "The following required fields are missing: scheduler_args", + cm.exception.args[0], + ) + + # Test missing multiple required fields + json_missing_multiple = {"dryrun": True, "wait": False} + with self.assertRaises(ValueError) as cm: + torchx_run_args_from_json(json_missing_multiple) + error_msg = str(cm.exception) + self.assertIn("The following required fields are missing:", error_msg) + self.assertIn("scheduler", error_msg) + self.assertIn("scheduler_args", error_msg) + self.assertIn("component_name", error_msg) + + # Test unknown fields cause ValueError + json_with_unknown = { + "scheduler": "local", + "scheduler_args": {"cluster": "test"}, + "component_name": "test_component", + "component_args": {"arg1": "value1"}, + "unknown_field": "should_be_ignored", + "another_unknown": 123, + } + + with self.assertRaises(ValueError) as cm: + torchx_run_args_from_json(json_with_unknown) + self.assertIn( + "The following fields are not part of the run command:", + cm.exception.args[0], + ) + self.assertIn("unknown_field", cm.exception.args[0]) + self.assertIn("another_unknown", cm.exception.args[0]) + + # Test empty JSON + with self.assertRaises(ValueError) as cm: + torchx_run_args_from_json({}) + self.assertIn("The following required fields are missing:", str(cm.exception)) + + # Test minimal valid input (only required fields) + json_minimal = { + "scheduler": "local", + "scheduler_args": {}, + "component_name": "minimal_component", + } + result4 = torchx_run_args_from_json(json_minimal) + + self.assertEqual(result4.scheduler, "local") + self.assertEqual(result4.scheduler_args, {}) + self.assertEqual(result4.component_name, "minimal_component") + self.assertEqual(result4.component_args, {}) + self.assertEqual(result4.component_args_str, []) + + def test_torchx_run_args_from_argparse(self) -> None: + # This test case isn't as important, since if the dataclass is being + # init with argparse, the argparsing will have handled most of the missing + # logic etc + # Create a mock argparse.Namespace object + args = argparse.Namespace() + args.scheduler = "k8s" + args.dryrun = True + args.wait = False + args.log = True + args.workspace = "file:///custom/workspace" + args.parent_run_id = "parent_123" + args.tee_logs = False + + component_name = "test_component" + component_args = ["--param1", "value1", "--param2", "value2"] + scheduler_cfg: Dict[str, CfgVal] = { + "cluster": "test_cluster", + "namespace": "default", + } + + result = torchx_run_args_from_argparse( + args=args, + component_name=component_name, + component_args=component_args, + scheduler_cfg=scheduler_cfg, + ) + + self.assertIsInstance(result, TorchXRunArgs) + self.assertEqual(result.component_name, "test_component") + self.assertEqual(result.scheduler, "k8s") + self.assertEqual(result.scheduler_args, {}) + self.assertEqual( + result.scheduler_cfg, {"cluster": "test_cluster", "namespace": "default"} + ) + self.assertEqual(result.dryrun, True) + self.assertEqual(result.wait, False) + self.assertEqual(result.log, True) + self.assertEqual(result.workspace, "file:///custom/workspace") + self.assertEqual(result.parent_run_id, "parent_123") + self.assertEqual(result.tee_logs, False) + self.assertEqual(result.component_args, {}) + self.assertEqual( + result.component_args_str, ["--param1", "value1", "--param2", "value2"] + )