Skip to content

Commit 33fca61

Browse files
authored
Add a dataclass to represent torchx run command args (#1105)
Differential Revision: D80279871 Pull Request resolved: #1106
1 parent ca3aa34 commit 33fca61

File tree

2 files changed

+258
-5
lines changed

2 files changed

+258
-5
lines changed

torchx/cli/cmd_run.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import sys
1313
import threading
1414
from collections import Counter
15-
from dataclasses import asdict
15+
from dataclasses import asdict, dataclass, field, fields, MISSING as DATACLASS_MISSING
1616
from itertools import groupby
1717
from pathlib import Path
1818
from pprint import pformat
19-
from typing import Dict, List, Optional, Tuple
19+
from typing import Any, Dict, List, Optional, Tuple
2020

2121
import torchx.specs as specs
2222
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
@@ -25,6 +25,7 @@
2525
from torchx.runner import config, get_runner, Runner
2626
from torchx.runner.config import load_sections
2727
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
28+
from torchx.specs import CfgVal
2829
from torchx.specs.finder import (
2930
_Component,
3031
ComponentNotFoundException,
@@ -44,6 +45,68 @@
4445
logger: logging.Logger = logging.getLogger(__name__)
4546

4647

48+
@dataclass
49+
class TorchXRunArgs:
50+
component_name: str
51+
scheduler: str
52+
scheduler_args: Dict[str, Any]
53+
scheduler_cfg: Dict[str, CfgVal] = field(default_factory=dict)
54+
dryrun: bool = False
55+
wait: bool = False
56+
log: bool = False
57+
workspace: str = f"file://{Path.cwd()}"
58+
parent_run_id: Optional[str] = None
59+
tee_logs: bool = False
60+
component_args: Dict[str, Any] = field(default_factory=dict)
61+
component_args_str: List[str] = field(default_factory=list)
62+
63+
64+
def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:
65+
all_fields = [f.name for f in fields(TorchXRunArgs)]
66+
required_fields = {
67+
f.name
68+
for f in fields(TorchXRunArgs)
69+
if f.default is DATACLASS_MISSING and f.default_factory is DATACLASS_MISSING
70+
}
71+
missing_fields = required_fields - json_data.keys()
72+
if missing_fields:
73+
raise ValueError(
74+
f"The following required fields are missing: {', '.join(missing_fields)}"
75+
)
76+
77+
# Fail if there are fields that aren't part of the run command
78+
filtered_json_data = {k: v for k, v in json_data.items() if k in all_fields}
79+
extra_fields = set(json_data.keys()) - set(all_fields)
80+
if extra_fields:
81+
raise ValueError(
82+
f"The following fields are not part of the run command: {', '.join(extra_fields)}.",
83+
"Please check your JSON and try launching again.",
84+
)
85+
86+
return TorchXRunArgs(**filtered_json_data)
87+
88+
89+
def torchx_run_args_from_argparse(
90+
args: argparse.Namespace,
91+
component_name: str,
92+
component_args: List[str],
93+
scheduler_cfg: Dict[str, CfgVal],
94+
) -> TorchXRunArgs:
95+
return TorchXRunArgs(
96+
component_name=component_name,
97+
scheduler=args.scheduler,
98+
scheduler_args={},
99+
scheduler_cfg=scheduler_cfg,
100+
dryrun=args.dryrun,
101+
wait=args.wait,
102+
log=args.log,
103+
workspace=args.workspace,
104+
parent_run_id=args.parent_run_id,
105+
tee_logs=args.tee_logs,
106+
component_args_str=component_args,
107+
)
108+
109+
47110
def _parse_component_name_and_args(
48111
component_name_and_args: List[str],
49112
subparser: argparse.ArgumentParser,

torchx/cli/test/cmd_run_test.py

Lines changed: 193 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,30 @@
1010
import argparse
1111
import dataclasses
1212
import io
13+
1314
import os
1415
import shutil
1516
import signal
1617
import tempfile
1718
import unittest
1819
from contextlib import contextmanager
1920
from pathlib import Path
20-
from typing import Generator
21+
from typing import Dict, Generator
2122
from unittest.mock import MagicMock, patch
2223

2324
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig
24-
from torchx.cli.cmd_run import _parse_component_name_and_args, CmdBuiltins, CmdRun
25+
from torchx.cli.cmd_run import (
26+
_parse_component_name_and_args,
27+
CmdBuiltins,
28+
CmdRun,
29+
torchx_run_args_from_argparse,
30+
torchx_run_args_from_json,
31+
TorchXRunArgs,
32+
)
2533
from torchx.runner.config import ENV_TORCHXCONFIG
2634
from torchx.schedulers.local_scheduler import SignalException
2735

28-
from torchx.specs import AppDryRunInfo
36+
from torchx.specs import AppDryRunInfo, CfgVal
2937

3038

3139
@contextmanager
@@ -384,3 +392,185 @@ def test_print_builtin(self) -> None:
384392

385393
cmd_builtins.run(parser.parse_args(["--print", "dist.ddp"]))
386394
# nothing to assert, just make sure it runs
395+
396+
397+
class TorchXRunArgsTest(unittest.TestCase):
398+
def test_torchx_run_args_from_json(self) -> None:
399+
# Test valid input with all required fields
400+
json_data = {
401+
"scheduler": "local",
402+
"scheduler_args": {"cluster": "test"},
403+
"component_name": "test_component",
404+
}
405+
result = torchx_run_args_from_json(json_data)
406+
407+
self.assertIsInstance(result, TorchXRunArgs)
408+
self.assertEqual(result.scheduler, "local")
409+
self.assertEqual(result.scheduler_args, {"cluster": "test"})
410+
self.assertEqual(result.component_name, "test_component")
411+
# Check defaults are set
412+
self.assertEqual(result.dryrun, False)
413+
self.assertEqual(result.wait, False)
414+
self.assertEqual(result.log, False)
415+
self.assertEqual(result.workspace, f"file://{Path.cwd()}")
416+
self.assertEqual(result.parent_run_id, None)
417+
self.assertEqual(result.tee_logs, False)
418+
self.assertEqual(result.component_args, {})
419+
self.assertEqual(result.component_args_str, [])
420+
421+
# Test valid input with optional fields provided
422+
json_data_with_optionals = {
423+
"scheduler": "k8s",
424+
"scheduler_args": {"namespace": "default"},
425+
"component_name": "my_component",
426+
"component_args": {"param": "test"},
427+
"component_args_str": ["--param test"],
428+
"dryrun": True,
429+
"wait": True,
430+
"log": True,
431+
"workspace": "file:///custom/path",
432+
"parent_run_id": "parent123",
433+
"tee_logs": True,
434+
}
435+
result2 = torchx_run_args_from_json(json_data_with_optionals)
436+
437+
self.assertEqual(result2.scheduler, "k8s")
438+
self.assertEqual(result2.scheduler_args, {"namespace": "default"})
439+
self.assertEqual(result2.component_name, "my_component")
440+
self.assertEqual(result2.component_args, {"param": "test"})
441+
self.assertEqual(result2.component_args_str, ["--param test"])
442+
self.assertEqual(result2.dryrun, True)
443+
self.assertEqual(result2.wait, True)
444+
self.assertEqual(result2.log, True)
445+
self.assertEqual(result2.workspace, "file:///custom/path")
446+
self.assertEqual(result2.parent_run_id, "parent123")
447+
self.assertEqual(result2.tee_logs, True)
448+
449+
# Test missing required field - scheduler
450+
json_missing_scheduler = {
451+
"scheduler_args": {"cluster": "test"},
452+
"component_name": "test_component",
453+
}
454+
with self.assertRaises(ValueError) as cm:
455+
torchx_run_args_from_json(json_missing_scheduler)
456+
self.assertEqual(
457+
"The following required fields are missing: scheduler", cm.exception.args[0]
458+
)
459+
460+
# Test missing required field - component_name
461+
json_missing_component = {
462+
"scheduler": "local",
463+
"scheduler_args": {"cluster": "test"},
464+
}
465+
with self.assertRaises(ValueError) as cm:
466+
torchx_run_args_from_json(json_missing_component)
467+
self.assertEqual(
468+
"The following required fields are missing: component_name",
469+
cm.exception.args[0],
470+
)
471+
472+
# Test missing required field - scheduler_args
473+
json_missing_scheduler_args = {
474+
"scheduler": "local",
475+
"component_name": "test_component",
476+
}
477+
with self.assertRaises(ValueError) as cm:
478+
torchx_run_args_from_json(json_missing_scheduler_args)
479+
self.assertEqual(
480+
"The following required fields are missing: scheduler_args",
481+
cm.exception.args[0],
482+
)
483+
484+
# Test missing multiple required fields
485+
json_missing_multiple = {"dryrun": True, "wait": False}
486+
with self.assertRaises(ValueError) as cm:
487+
torchx_run_args_from_json(json_missing_multiple)
488+
error_msg = str(cm.exception)
489+
self.assertIn("The following required fields are missing:", error_msg)
490+
self.assertIn("scheduler", error_msg)
491+
self.assertIn("scheduler_args", error_msg)
492+
self.assertIn("component_name", error_msg)
493+
494+
# Test unknown fields cause ValueError
495+
json_with_unknown = {
496+
"scheduler": "local",
497+
"scheduler_args": {"cluster": "test"},
498+
"component_name": "test_component",
499+
"component_args": {"arg1": "value1"},
500+
"unknown_field": "should_be_ignored",
501+
"another_unknown": 123,
502+
}
503+
504+
with self.assertRaises(ValueError) as cm:
505+
torchx_run_args_from_json(json_with_unknown)
506+
self.assertIn(
507+
"The following fields are not part of the run command:",
508+
cm.exception.args[0],
509+
)
510+
self.assertIn("unknown_field", cm.exception.args[0])
511+
self.assertIn("another_unknown", cm.exception.args[0])
512+
513+
# Test empty JSON
514+
with self.assertRaises(ValueError) as cm:
515+
torchx_run_args_from_json({})
516+
self.assertIn("The following required fields are missing:", str(cm.exception))
517+
518+
# Test minimal valid input (only required fields)
519+
json_minimal = {
520+
"scheduler": "local",
521+
"scheduler_args": {},
522+
"component_name": "minimal_component",
523+
}
524+
result4 = torchx_run_args_from_json(json_minimal)
525+
526+
self.assertEqual(result4.scheduler, "local")
527+
self.assertEqual(result4.scheduler_args, {})
528+
self.assertEqual(result4.component_name, "minimal_component")
529+
self.assertEqual(result4.component_args, {})
530+
self.assertEqual(result4.component_args_str, [])
531+
532+
def test_torchx_run_args_from_argparse(self) -> None:
533+
# This test case isn't as important, since if the dataclass is being
534+
# init with argparse, the argparsing will have handled most of the missing
535+
# logic etc
536+
# Create a mock argparse.Namespace object
537+
args = argparse.Namespace()
538+
args.scheduler = "k8s"
539+
args.dryrun = True
540+
args.wait = False
541+
args.log = True
542+
args.workspace = "file:///custom/workspace"
543+
args.parent_run_id = "parent_123"
544+
args.tee_logs = False
545+
546+
component_name = "test_component"
547+
component_args = ["--param1", "value1", "--param2", "value2"]
548+
scheduler_cfg: Dict[str, CfgVal] = {
549+
"cluster": "test_cluster",
550+
"namespace": "default",
551+
}
552+
553+
result = torchx_run_args_from_argparse(
554+
args=args,
555+
component_name=component_name,
556+
component_args=component_args,
557+
scheduler_cfg=scheduler_cfg,
558+
)
559+
560+
self.assertIsInstance(result, TorchXRunArgs)
561+
self.assertEqual(result.component_name, "test_component")
562+
self.assertEqual(result.scheduler, "k8s")
563+
self.assertEqual(result.scheduler_args, {})
564+
self.assertEqual(
565+
result.scheduler_cfg, {"cluster": "test_cluster", "namespace": "default"}
566+
)
567+
self.assertEqual(result.dryrun, True)
568+
self.assertEqual(result.wait, False)
569+
self.assertEqual(result.log, True)
570+
self.assertEqual(result.workspace, "file:///custom/workspace")
571+
self.assertEqual(result.parent_run_id, "parent_123")
572+
self.assertEqual(result.tee_logs, False)
573+
self.assertEqual(result.component_args, {})
574+
self.assertEqual(
575+
result.component_args_str, ["--param1", "value1", "--param2", "value2"]
576+
)

0 commit comments

Comments
 (0)