Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import sys
import threading
from collections import Counter
from dataclasses import asdict
from dataclasses import asdict, dataclass, field, fields, MISSING as DATACLASS_MISSING
from itertools import groupby
from pathlib import Path
from pprint import pformat
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

import torchx.specs as specs
from torchx.cli.argparse_util import ArgOnceAction, torchxconfig_run
Expand All @@ -25,6 +25,7 @@
from torchx.runner import config, get_runner, Runner
from torchx.runner.config import load_sections
from torchx.schedulers import get_default_scheduler_name, get_scheduler_factories
from torchx.specs import CfgVal
from torchx.specs.finder import (
_Component,
ComponentNotFoundException,
Expand All @@ -44,6 +45,68 @@
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class TorchXRunArgs:
    """Container for the arguments of one run command.

    ``component_name``, ``scheduler`` and ``scheduler_args`` have no default
    and must always be supplied; every other field is optional.
    """

    # Required fields (no default).
    component_name: str
    scheduler: str
    scheduler_args: Dict[str, Any]
    # Optional fields.
    scheduler_cfg: Dict[str, CfgVal] = field(default_factory=dict)
    dryrun: bool = False
    wait: bool = False
    log: bool = False
    # Computed per instance: a plain ``f"file://{Path.cwd()}"`` default would
    # be evaluated once, when the class is defined, and would keep pointing at
    # the import-time directory even after the process changes directory.
    workspace: str = field(default_factory=lambda: f"file://{Path.cwd()}")
    parent_run_id: Optional[str] = None
    tee_logs: bool = False
    # Component arguments in mapping form and in raw string-list form.
    component_args: Dict[str, Any] = field(default_factory=dict)
    component_args_str: List[str] = field(default_factory=list)


def torchx_run_args_from_json(json_data: Dict[str, Any]) -> TorchXRunArgs:
    """Build a ``TorchXRunArgs`` from a JSON-style mapping, validating its keys.

    Args:
        json_data: Mapping from ``TorchXRunArgs`` field names to values.

    Returns:
        A ``TorchXRunArgs`` populated from ``json_data``. Fields that are not
        present keep their dataclass defaults.

    Raises:
        ValueError: If a field without a default is missing from
            ``json_data``, or if ``json_data`` contains a key that is not a
            ``TorchXRunArgs`` field.
    """
    run_arg_fields = fields(TorchXRunArgs)
    known_fields = {f.name for f in run_arg_fields}
    # A field is required when it has neither a default nor a default factory.
    required_fields = {
        f.name
        for f in run_arg_fields
        if f.default is DATACLASS_MISSING and f.default_factory is DATACLASS_MISSING
    }

    missing_fields = required_fields - json_data.keys()
    if missing_fields:
        # Sorted so the message is stable; set iteration order is not.
        raise ValueError(
            "The following required fields are missing: "
            f"{', '.join(sorted(missing_fields))}"
        )

    # Fail if there are fields that aren't part of the run command
    extra_fields = json_data.keys() - known_fields
    if extra_fields:
        # One message string: passing the two sentences as separate arguments
        # would make ``str(exc)`` render as a tuple.
        raise ValueError(
            "The following fields are not part of the run command: "
            f"{', '.join(sorted(extra_fields))}. "
            "Please check your JSON and try launching again."
        )

    # Every key is a known field at this point, so no filtering is needed.
    return TorchXRunArgs(**json_data)


def torchx_run_args_from_argparse(
    args: argparse.Namespace,
    component_name: str,
    component_args: List[str],
    scheduler_cfg: Dict[str, CfgVal],
) -> TorchXRunArgs:
    """Assemble a ``TorchXRunArgs`` from parsed command-line arguments.

    Args:
        args: Parsed namespace; its ``scheduler``, ``dryrun``, ``wait``,
            ``log``, ``workspace``, ``parent_run_id`` and ``tee_logs``
            attributes are read.
        component_name: Stored as ``component_name``.
        component_args: Stored as ``component_args_str``.
        scheduler_cfg: Stored as ``scheduler_cfg``.

    Returns:
        The populated ``TorchXRunArgs``. ``scheduler_args`` is always an
        empty dict and ``component_args`` keeps its dataclass default.
    """
    run_args = TorchXRunArgs(
        # Values handed in directly by the caller.
        component_name=component_name,
        component_args_str=component_args,
        scheduler_cfg=scheduler_cfg,
        # No structured scheduler arguments are taken from the namespace.
        scheduler_args={},
        # Values read from the parsed namespace.
        scheduler=args.scheduler,
        dryrun=args.dryrun,
        wait=args.wait,
        log=args.log,
        workspace=args.workspace,
        parent_run_id=args.parent_run_id,
        tee_logs=args.tee_logs,
    )
    return run_args


def _parse_component_name_and_args(
component_name_and_args: List[str],
subparser: argparse.ArgumentParser,
Expand Down
196 changes: 193 additions & 3 deletions torchx/cli/test/cmd_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,30 @@
import argparse
import dataclasses
import io

import os
import shutil
import signal
import tempfile
import unittest
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
from typing import Dict, Generator
from unittest.mock import MagicMock, patch

from torchx.cli.argparse_util import ArgOnceAction, torchxconfig
from torchx.cli.cmd_run import _parse_component_name_and_args, CmdBuiltins, CmdRun
from torchx.cli.cmd_run import (
_parse_component_name_and_args,
CmdBuiltins,
CmdRun,
torchx_run_args_from_argparse,
torchx_run_args_from_json,
TorchXRunArgs,
)
from torchx.runner.config import ENV_TORCHXCONFIG
from torchx.schedulers.local_scheduler import SignalException

from torchx.specs import AppDryRunInfo
from torchx.specs import AppDryRunInfo, CfgVal


@contextmanager
Expand Down Expand Up @@ -384,3 +392,185 @@ def test_print_builtin(self) -> None:

cmd_builtins.run(parser.parse_args(["--print", "dist.ddp"]))
# nothing to assert, just make sure it runs


class TorchXRunArgsTest(unittest.TestCase):
    """Tests for building ``TorchXRunArgs`` from JSON data and from argparse."""

    def _assert_attrs(self, obj: object, **expected: object) -> None:
        """Check each named attribute of ``obj`` against its expected value."""
        for attr_name, expected_value in expected.items():
            self.assertEqual(getattr(obj, attr_name), expected_value)

    def test_torchx_run_args_from_json(self) -> None:
        # Required fields only: every optional field falls back to its default.
        from_required = torchx_run_args_from_json(
            {
                "scheduler": "local",
                "scheduler_args": {"cluster": "test"},
                "component_name": "test_component",
            }
        )
        self.assertIsInstance(from_required, TorchXRunArgs)
        self._assert_attrs(
            from_required,
            scheduler="local",
            scheduler_args={"cluster": "test"},
            component_name="test_component",
            dryrun=False,
            wait=False,
            log=False,
            workspace=f"file://{Path.cwd()}",
            parent_run_id=None,
            tee_logs=False,
            component_args={},
            component_args_str=[],
        )

        # Optional fields supplied explicitly are carried through unchanged.
        from_all = torchx_run_args_from_json(
            {
                "scheduler": "k8s",
                "scheduler_args": {"namespace": "default"},
                "component_name": "my_component",
                "component_args": {"param": "test"},
                "component_args_str": ["--param test"],
                "dryrun": True,
                "wait": True,
                "log": True,
                "workspace": "file:///custom/path",
                "parent_run_id": "parent123",
                "tee_logs": True,
            }
        )
        self._assert_attrs(
            from_all,
            scheduler="k8s",
            scheduler_args={"namespace": "default"},
            component_name="my_component",
            component_args={"param": "test"},
            component_args_str=["--param test"],
            dryrun=True,
            wait=True,
            log=True,
            workspace="file:///custom/path",
            parent_run_id="parent123",
            tee_logs=True,
        )

        # Dropping any single required field is reported by name.
        complete = {
            "scheduler": "local",
            "scheduler_args": {"cluster": "test"},
            "component_name": "test_component",
        }
        for absent in ("scheduler", "component_name", "scheduler_args"):
            incomplete = {k: v for k, v in complete.items() if k != absent}
            with self.assertRaises(ValueError) as cm:
                torchx_run_args_from_json(incomplete)
            self.assertEqual(
                f"The following required fields are missing: {absent}",
                cm.exception.args[0],
            )

        # Several required fields missing at once: all of them are listed.
        with self.assertRaises(ValueError) as cm:
            torchx_run_args_from_json({"dryrun": True, "wait": False})
        missing_msg = str(cm.exception)
        for fragment in (
            "The following required fields are missing:",
            "scheduler",
            "scheduler_args",
            "component_name",
        ):
            self.assertIn(fragment, missing_msg)

        # Keys that are not TorchXRunArgs fields are rejected with ValueError.
        with self.assertRaises(ValueError) as cm:
            torchx_run_args_from_json(
                {
                    "scheduler": "local",
                    "scheduler_args": {"cluster": "test"},
                    "component_name": "test_component",
                    "component_args": {"arg1": "value1"},
                    "unknown_field": "should_be_ignored",
                    "another_unknown": 123,
                }
            )
        unknown_msg = cm.exception.args[0]
        for fragment in (
            "The following fields are not part of the run command:",
            "unknown_field",
            "another_unknown",
        ):
            self.assertIn(fragment, unknown_msg)

        # An empty mapping is missing every required field.
        with self.assertRaises(ValueError) as cm:
            torchx_run_args_from_json({})
        self.assertIn("The following required fields are missing:", str(cm.exception))

        # Smallest valid input: the required fields, with empty scheduler_args.
        from_minimal = torchx_run_args_from_json(
            {
                "scheduler": "local",
                "scheduler_args": {},
                "component_name": "minimal_component",
            }
        )
        self._assert_attrs(
            from_minimal,
            scheduler="local",
            scheduler_args={},
            component_name="minimal_component",
            component_args={},
            component_args_str=[],
        )

    def test_torchx_run_args_from_argparse(self) -> None:
        # Less critical than the JSON path: when the dataclass is built from
        # argparse, the parser has already dealt with missing arguments.
        parsed_args = argparse.Namespace(
            scheduler="k8s",
            dryrun=True,
            wait=False,
            log=True,
            workspace="file:///custom/workspace",
            parent_run_id="parent_123",
            tee_logs=False,
        )
        scheduler_cfg: Dict[str, CfgVal] = {
            "cluster": "test_cluster",
            "namespace": "default",
        }

        built = torchx_run_args_from_argparse(
            args=parsed_args,
            component_name="test_component",
            component_args=["--param1", "value1", "--param2", "value2"],
            scheduler_cfg=scheduler_cfg,
        )

        self.assertIsInstance(built, TorchXRunArgs)
        self._assert_attrs(
            built,
            component_name="test_component",
            scheduler="k8s",
            scheduler_args={},
            scheduler_cfg={"cluster": "test_cluster", "namespace": "default"},
            dryrun=True,
            wait=False,
            log=True,
            workspace="file:///custom/workspace",
            parent_run_id="parent_123",
            tee_logs=False,
            component_args={},
            component_args_str=["--param1", "value1", "--param2", "value2"],
        )
Loading