Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 9 additions & 68 deletions torchx/cli/cmd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
# LICENSE file in the root directory of this source tree.

import argparse
import ast
import glob
import importlib
import os
from dataclasses import dataclass
from inspect import getmembers, isfunction
from typing import Dict, Iterable, List, Optional, Type, Union
from typing import Dict, List, Optional, Union

import torchx.specs as specs
from pyre_extensions import none_throws
Expand All @@ -32,63 +31,6 @@ def parse_args_children(arg: str) -> Dict[str, Union[str, List[str]]]:
return conf


class UnsupportFeatureError(Exception):
def __init__(self, name: str) -> None:
super().__init__(f"Using unsupported feature {name} in config.")


class ConfValidator(ast.NodeVisitor):
IMPORT_ALLOWLIST: Iterable[str] = (
"torchx",
"torchelastic.tsm",
"os.path",
"pytorch.elastic.torchelastic.tsm",
)

FEATURE_BLOCKLIST: Iterable[Type[object]] = (
# statements
ast.FunctionDef,
ast.ClassDef,
ast.Return,
ast.Delete,
ast.For,
ast.AsyncFor,
ast.While,
ast.If,
ast.With,
ast.AsyncWith,
ast.Raise,
ast.Try,
ast.Global,
ast.Nonlocal,
# expressions
ast.ListComp,
ast.SetComp,
ast.DictComp,
# ast.GeneratorExp,
)

def visit(self, node: ast.AST) -> None:
if node.__class__ in self.FEATURE_BLOCKLIST:
raise UnsupportFeatureError(node.__class__.__name__)

super().visit(node)

def _validate_import_path(self, names: List[str]) -> None:
for name in names:
if not any(name.startswith(prefix) for prefix in self.IMPORT_ALLOWLIST):
raise ImportError(
f"import {name} not in allowed import prefixes {self.IMPORT_ALLOWLIST}"
)

def visit_Import(self, node: ast.Import) -> None:
self._validate_import_path([alias.name for alias in node.names])

def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if module := node.module:
self._validate_import_path([module])


def _parse_run_config(arg: str) -> specs.RunConfig:
conf = specs.RunConfig()
for key, value in parse_args_children(arg).items():
Expand Down Expand Up @@ -122,18 +64,17 @@ def _get_component_definition(module: str, function_name: str) -> str:
return f"{module}.{function_name}"


def _get_components_from_file(filepath: str) -> List[BuiltinComponent]:
if filepath.endswith("__init__.py"):
return []
def _to_relative(filepath: str) -> str:
if os.path.isabs(filepath):
if str(COMPONENTS_DIR) not in filepath:
return []
# make path torchx/components/$suffix out of the abs
rel_path = filepath.split(str(COMPONENTS_DIR))[1]
if rel_path[1:].startswith("test"):
return []
components_path = f"{str(COMPONENTS_DIR)}{rel_path}"
return f"{str(COMPONENTS_DIR)}{rel_path}"
else:
components_path = os.path.join(str(COMPONENTS_DIR), filepath)
return os.path.join(str(COMPONENTS_DIR), filepath)


def _get_components_from_file(filepath: str) -> List[BuiltinComponent]:
components_path = _to_relative(filepath)
components_module_path = _to_module(components_path)
module = importlib.import_module(components_module_path)
functions = getmembers(module, isfunction)
Expand Down
17 changes: 4 additions & 13 deletions torchx/cli/test/cmd_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,12 @@ def setUp(self) -> None:
def tearDown(self) -> None:
shutil.rmtree(self.tmpdir)

def test_run_with_builtin(self) -> None:
foobar_txt = str(self.tmpdir / "foobar.txt")
args = self.parser.parse_args(
["--scheduler", "local", "tests.touch", "--file", foobar_txt]
)

self.cmd_run.run(args)
self.assertTrue(os.path.isfile(foobar_txt))

def test_run_with_user_conf_abs_path(self) -> None:
args = self.parser.parse_args(
[
"--scheduler",
"local",
str(Path(__file__).parent / "examples/test.py:touch"),
str(Path(__file__).parent / "components.py:touch"),
"--file",
str(self.tmpdir / "foobar.txt"),
]
Expand All @@ -62,12 +53,12 @@ def test_run_with_user_conf_abs_path(self) -> None:

def test_run_with_relpath(self) -> None:
# should pick up test/examples/touch.torchx (not the builtin)
with cwd(str(Path(__file__).parent / "examples")):
with cwd(str(Path(__file__).parent)):
args = self.parser.parse_args(
[
"--scheduler",
"local",
"tests.touch_v2",
str(Path(__file__).parent / "components.py:touch_v2"),
"--file",
str(self.tmpdir / "foobar.txt"),
]
Expand Down Expand Up @@ -95,7 +86,7 @@ def test_run_dryrun(self, mock_runner_run: MagicMock) -> None:
"--verbose",
"--scheduler",
"local",
"tests.echo",
"utils.echo",
]
)
self.cmd_run.run(args)
Expand Down
19 changes: 0 additions & 19 deletions torchx/cli/test/examples/test.py → torchx/cli/test/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,6 @@
import torchx.specs as specs


def echo(msg: str = "hello world", image: str = "/tmp") -> specs.AppDef:
"""Echos a message

Args:
msg: Message to echo
image: Image to run
"""

echo = specs.Role(
name="echo",
image=image,
entrypoint="/bin/echo",
args=[msg],
num_replicas=1,
)

return specs.AppDef(name="echo", roles=[echo])


def touch(file: str) -> specs.AppDef:
"""Echos a message

Expand Down
8 changes: 4 additions & 4 deletions torchx/cli/test/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

_root: Path = Path(__file__).parent

_SIMPLE_CONF: str = "tests.simple"
_SIMPLE_CONF: str = str(Path(__file__).parent / "components.py:simple")


class CLITest(unittest.TestCase):
Expand All @@ -25,11 +25,11 @@ def test_run_abs_config_path(self) -> None:
"run",
"--scheduler",
"local",
str(_root / "examples" / "test.py:simple"),
str(_root / "components.py:simple"),
"--num_trainers",
"2",
"--trainer_image",
str(_root / "examples" / "container"),
str(_root / "container"),
]
)

Expand All @@ -43,7 +43,7 @@ def test_run_builtin_config(self) -> None:
"--num_trainers",
"2",
"--trainer_image",
str(_root / "examples" / "container"),
str(_root / "container"),
]
)

Expand Down
100 changes: 0 additions & 100 deletions torchx/components/tests.py

This file was deleted.

1 change: 0 additions & 1 deletion torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def run_from_path(
AppNotReRunnableException: if the session/scheduler does not support re-running attached apps
ValueError: if the ``component_path`` is failed to resolve.
"""

app_fn = entrypoints.load("torchx.components", component_path, default=NONE)
if app_fn != NONE:
app = from_function(app_fn, app_args)
Expand Down