diff --git a/torchx/specs/file_linter.py b/torchx/specs/file_linter.py index f7f1676b6..90b1df064 100644 --- a/torchx/specs/file_linter.py +++ b/torchx/specs/file_linter.py @@ -244,11 +244,18 @@ class TorchFunctionVisitor(ast.NodeVisitor): """ - def __init__(self, component_function_name: str) -> None: - self.validators = [ - TorchxFunctionArgsValidator(), - TorchxReturnValidator(), - ] + def __init__( + self, + component_function_name: str, + validators: Optional[List[TorchxFunctionValidator]], + ) -> None: + if validators is None: + self.validators: List[TorchxFunctionValidator] = [ + TorchxFunctionArgsValidator(), + TorchxReturnValidator(), + ] + else: + self.validators = validators self.linter_errors: List[LinterMessage] = [] self.component_function_name = component_function_name self.visited_function = False @@ -264,7 +271,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None: self.linter_errors += validator.validate(node) -def validate(path: str, component_function: str) -> List[LinterMessage]: +def validate( + path: str, + component_function: str, + validators: Optional[List[TorchxFunctionValidator]], +) -> List[LinterMessage]: """ Validates the function to make sure it complies the component standard. @@ -293,7 +304,7 @@ def validate(path: str, component_function: str) -> List[LinterMessage]: severity="error", ) return [linter_message] - visitor = TorchFunctionVisitor(component_function) + visitor = TorchFunctionVisitor(component_function, validators) visitor.visit(module) linter_errors = visitor.linter_errors if not visitor.visited_function: diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index d9f285ade..980aa9473 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -19,7 +19,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union from torchx.specs import AppDef -from torchx.specs.file_linter import get_fn_docstring, validate +from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate from torchx.util import entrypoints from torchx.util.io import read_conf_file from torchx.util.types import none_throws @@ -59,7 +59,9 @@ class _Component: class ComponentsFinder(abc.ABC): @abc.abstractmethod - def find(self) -> List[_Component]: + def find( + self, validators: Optional[List[TorchxFunctionValidator]] + ) -> List[_Component]: """ Retrieves a set of components. A component is defined as a python function that conforms to ``torchx.specs.file_linter`` linter. @@ -203,10 +205,12 @@ def _iter_modules_recursive( else: yield self._try_import(module_info.name) - def find(self) -> List[_Component]: + def find( + self, validators: Optional[List[TorchxFunctionValidator]] + ) -> List[_Component]: components = [] for m in self._iter_modules_recursive(self.base_module): - components += self._get_components_from_module(m) + components += self._get_components_from_module(m, validators) return components def _try_import(self, module: Union[str, ModuleType]) -> ModuleType: @@ -221,7 +225,9 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType: else: return module - def _get_components_from_module(self, module: ModuleType) -> List[_Component]: + def _get_components_from_module( + self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]] + ) -> List[_Component]: functions = getmembers(module, isfunction) component_defs = [] @@ -230,7 +236,7 @@ def _get_components_from_module(self, module: ModuleType) -> List[_Component]: module_path = os.path.abspath(module_path) rel_module_name = module_relname(module, relative_to=self.base_module) for function_name, function in functions: - linter_errors = validate(module_path, function_name) + linter_errors = validate(module_path, function_name, validators) component_desc, _ = get_fn_docstring(function) # remove empty string to deal with group="" @@ -255,13 +261,20 @@ def __init__(self, filepath: str, function_name: str) -> None: self._filepath = filepath self._function_name = function_name - def _get_validation_errors(self, path: str, function_name: str) -> List[str]: - linter_errors = validate(path, function_name) + def _get_validation_errors( + self, + path: str, + function_name: str, + validators: Optional[List[TorchxFunctionValidator]], + ) -> List[str]: + linter_errors = validate(path, function_name, validators) return [linter_error.description for linter_error in linter_errors] - def find(self) -> List[_Component]: + def find( + self, validators: Optional[List[TorchxFunctionValidator]] + ) -> List[_Component]: validation_errors = self._get_validation_errors( - self._filepath, self._function_name + self._filepath, self._function_name, validators ) file_source = read_conf_file(self._filepath) @@ -284,7 +297,9 @@ def find(self) -> List[_Component]: ] -def _load_custom_components() -> List[_Component]: +def _load_custom_components( + validators: Optional[List[TorchxFunctionValidator]], +) -> List[_Component]: component_modules = { name: load_fn() for name, load_fn in @@ -303,11 +318,13 @@ def _load_custom_components() -> List[_Component]: # _0 = torchx.components.dist # _1 = torchx.components.utils group = "" if group.startswith("_") else group - components += ModuleComponentsFinder(module, group).find() + components += ModuleComponentsFinder(module, group).find(validators) return components -def _load_components() -> Dict[str, _Component]: +def _load_components( + validators: Optional[List[TorchxFunctionValidator]], +) -> Dict[str, _Component]: """ Loads either the custom component defs from the entrypoint ``[torchx.components]`` or the default builtins from ``torchx.components`` module. @@ -318,19 +335,21 @@ def _load_components() -> Dict[str, _Component]: """ - components = _load_custom_components() + components = _load_custom_components(validators) if not components: - components = ModuleComponentsFinder("torchx.components", "").find() + components = ModuleComponentsFinder("torchx.components", "").find(validators) return {c.name: c for c in components} _components: Optional[Dict[str, _Component]] = None -def _find_components() -> Dict[str, _Component]: +def _find_components( + validators: Optional[List[TorchxFunctionValidator]], +) -> Dict[str, _Component]: global _components if not _components: - _components = _load_components() + _components = _load_components(validators) return none_throws(_components) @@ -338,17 +357,21 @@ def _is_custom_component(component_name: str) -> bool: return ":" in component_name -def _find_custom_components(name: str) -> Dict[str, _Component]: +def _find_custom_components( + name: str, validators: Optional[List[TorchxFunctionValidator]] +) -> Dict[str, _Component]: if ":" not in name: raise ValueError( f"Invalid custom component: {name}, valid template : `FILEPATH`:`FUNCTION_NAME`" ) filepath, component_name = name.split(":") - components = CustomComponentsFinder(filepath, component_name).find() + components = CustomComponentsFinder(filepath, component_name).find(validators) return {component.name: component for component in components} -def get_components() -> Dict[str, _Component]: +def get_components( + validators: Optional[List[TorchxFunctionValidator]] = None, +) -> Dict[str, _Component]: """ Returns all custom components registered via ``[torchx.components]`` entrypoints OR builtin components that ship with TorchX (but not both). @@ -395,13 +418,15 @@ def get_components() -> Dict[str, _Component]: """ valid_components: Dict[str, _Component] = {} - for component_name, component in _find_components().items(): + for component_name, component in _find_components(validators).items(): if len(component.validation_errors) == 0: valid_components[component_name] = component return valid_components -def get_component(name: str) -> _Component: +def get_component( + name: str, validators: Optional[List[TorchxFunctionValidator]] = None +) -> _Component: """ Retrieves components by the provided name. @@ -409,9 +434,9 @@ def get_component(name: str) -> _Component: Component or None if no component with ``name`` exists """ if _is_custom_component(name): - components = _find_custom_components(name) + components = _find_custom_components(name, validators) else: - components = _find_components() + components = _find_components(validators) if name not in components: raise ComponentNotFoundException( f"Component `{name}` not found. Please make sure it is one of the " @@ -428,7 +453,9 @@ def get_component(name: str) -> _Component: return component -def get_builtin_source(name: str) -> str: +def get_builtin_source( + name: str, validators: Optional[List[TorchxFunctionValidator]] = None +) -> str: """ Returns a string of the the builtin component's function source code with all the import statements. Intended to be used to make a copy @@ -446,7 +473,7 @@ def get_builtin_source(name: str) -> str: are optimized and formatting adheres to your organization's standards. """ - component = get_component(name) + component = get_component(name, validators) fn = component.fn fn_name = component.name.split(".")[-1] diff --git a/torchx/specs/test/file_linter_test.py b/torchx/specs/test/file_linter_test.py index adf66013d..5428ce95c 100644 --- a/torchx/specs/test/file_linter_test.py +++ b/torchx/specs/test/file_linter_test.py @@ -121,14 +121,13 @@ def test_syntax_error(self) -> None: content = "!!foo====bar" with patch("torchx.specs.file_linter.read_conf_file") as read_conf_file_mock: read_conf_file_mock.return_value = content - errors = validate(self._path, "unknown_function") + errors = validate(self._path, "unknown_function", None) self.assertEqual(1, len(errors)) self.assertEqual("invalid syntax", errors[0].description) def test_validate_varargs_kwargs_fn(self) -> None: linter_errors = validate( - self._path, - "_test_invalid_fn_with_varags_and_kwargs", + self._path, "_test_invalid_fn_with_varags_and_kwargs", None ) self.assertEqual(1, len(linter_errors)) self.assertTrue( @@ -136,7 +135,7 @@ def test_validate_varargs_kwargs_fn(self) -> None: ) def test_validate_no_return(self) -> None: - linter_errors = validate(self._path, "_test_fn_no_return") + linter_errors = validate(self._path, "_test_fn_no_return", None) self.assertEqual(1, len(linter_errors)) expected_desc = ( "Function: _test_fn_no_return missing return annotation or " @@ -145,7 +144,7 @@ def test_validate_no_return(self) -> None: self.assertEqual(expected_desc, linter_errors[0].description) def test_validate_incorrect_return(self) -> None: - linter_errors = validate(self._path, "_test_fn_return_int") + linter_errors = validate(self._path, "_test_fn_return_int", None) self.assertEqual(1, len(linter_errors)) expected_desc = ( "Function: _test_fn_return_int has incorrect return annotation, " @@ -153,12 +152,24 @@ def test_validate_incorrect_return(self) -> None: ) self.assertEqual(expected_desc, linter_errors[0].description) + def test_no_validators_has_no_validation(self) -> None: + linter_errors = validate(self._path, "_test_fn_return_int", []) + self.assertEqual(0, len(linter_errors)) + + linter_errors = validate(self._path, "_test_fn_no_return", []) + self.assertEqual(0, len(linter_errors)) + + linter_errors = validate( + self._path, "_test_invalid_fn_with_varags_and_kwargs", [] + ) + self.assertEqual(0, len(linter_errors)) + def test_validate_empty_fn(self) -> None: - linter_errors = validate(self._path, "_test_empty_fn") + linter_errors = validate(self._path, "_test_empty_fn", None) self.assertEqual(0, len(linter_errors)) def test_validate_args_no_type_defs(self) -> None: - linter_errors = validate(self._path, "_test_args_no_type_defs") + linter_errors = validate(self._path, "_test_args_no_type_defs", None) print(linter_errors) self.assertEqual(2, len(linter_errors)) self.assertEqual( @@ -169,10 +180,7 @@ def test_validate_args_no_type_defs(self) -> None: ) def test_validate_args_no_type_defs_complex(self) -> None: - linter_errors = validate( - self._path, - "_test_args_dict_list_complex_types", - ) + linter_errors = validate(self._path, "_test_args_dict_list_complex_types", None) self.assertEqual(5, len(linter_errors)) self.assertEqual( "Arg arg0 missing type annotation", linter_errors[0].description @@ -210,7 +218,7 @@ def test_validate_docstring_no_docs(self) -> None: self.assertEqual(" ", param_desc["arg0"]) def test_validate_unknown_function(self) -> None: - linter_errors = validate(self._path, "unknown_function") + linter_errors = validate(self._path, "unknown_function", None) self.assertEqual(1, len(linter_errors)) self.assertEqual( "Function unknown_function not found", linter_errors[0].description diff --git a/torchx/specs/test/finder_test.py b/torchx/specs/test/finder_test.py index 812ab1c5c..69dfecf01 100644 --- a/torchx/specs/test/finder_test.py +++ b/torchx/specs/test/finder_test.py @@ -106,13 +106,13 @@ def test_get_unknown_component_by_name(self, _: MagicMock) -> None: @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) def test_get_invalid_component(self, _: MagicMock) -> None: - components = _load_components() + components = _load_components(None) foobar_component = components["invalid_component"] self.assertEqual(1, len(foobar_component.validation_errors)) @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) def test_get_entrypoints_components(self, _: MagicMock) -> None: - components = _load_components() + components = _load_components(None) foobar_component = components["_test_component"] self.assertEqual(_test_component, foobar_component.fn) self.assertEqual("_test_component", foobar_component.fn_name) @@ -132,7 +132,7 @@ def test_get_entrypoints_components(self, _: MagicMock) -> None: ), ) def test_load_custom_components(self, _: MagicMock) -> None: - components = _load_components() + components = _load_components(None) # the name of the appdefs returned by each component # is the expected component name @@ -155,7 +155,7 @@ def test_load_custom_components(self, _: MagicMock) -> None: ), ) def test_load_custom_components_nogroup(self, _: MagicMock) -> None: - components = _load_components() + components = _load_components(None) # test component names are hardcoded expecting # test.components.* to be grouped under foo.* @@ -166,16 +166,17 @@ def test_load_custom_components_nogroup(self, _: MagicMock) -> None: self.assertEqual(expected_name, actual_name) def test_load_builtins(self) -> None: - components = _load_components() + components = _load_components(None) # if nothing registered in entrypoints, then builtins should be loaded expected = { - c.name for c in ModuleComponentsFinder("torchx.components", group="").find() + c.name + for c in ModuleComponentsFinder("torchx.components", group="").find(None) } self.assertEqual(components.keys(), expected) def test_load_builtin_echo(self) -> None: - components = _load_components() + components = _load_components(None) self.assertTrue(len(components) > 1) component = components["utils.echo"] self.assertEqual("utils.echo", component.name) @@ -194,7 +195,7 @@ class CustomComponentsFinderTest(unittest.TestCase): def test_find_components(self) -> None: components = CustomComponentsFinder( current_file_path(), "_test_component" - ).find() + ).find(None) self.assertEqual(1, len(components)) component = components[0] self.assertEqual(f"{current_file_path()}:_test_component", component.name) @@ -205,7 +206,7 @@ def test_find_components(self) -> None: def test_find_components_without_docstring(self) -> None: components = CustomComponentsFinder( current_file_path(), "_test_component_without_docstring" - ).find() + ).find(None) self.assertEqual(1, len(components)) component = components[0] self.assertEqual(