Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions torchx/specs/file_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,18 @@ class TorchFunctionVisitor(ast.NodeVisitor):

"""

def __init__(self, component_function_name: str) -> None:
self.validators = [
TorchxFunctionArgsValidator(),
TorchxReturnValidator(),
]
def __init__(
self,
component_function_name: str,
validators: Optional[List[TorchxFunctionValidator]],
) -> None:
if validators is None:
self.validators: List[TorchxFunctionValidator] = [
TorchxFunctionArgsValidator(),
TorchxReturnValidator(),
]
else:
self.validators = validators
self.linter_errors: List[LinterMessage] = []
self.component_function_name = component_function_name
self.visited_function = False
Expand All @@ -264,7 +271,11 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
self.linter_errors += validator.validate(node)


def validate(path: str, component_function: str) -> List[LinterMessage]:
def validate(
path: str,
component_function: str,
validators: Optional[List[TorchxFunctionValidator]],
) -> List[LinterMessage]:
"""
Validates the function to make sure it complies the component standard.

Expand Down Expand Up @@ -293,7 +304,7 @@ def validate(path: str, component_function: str) -> List[LinterMessage]:
severity="error",
)
return [linter_message]
visitor = TorchFunctionVisitor(component_function)
visitor = TorchFunctionVisitor(component_function, validators)
visitor.visit(module)
linter_errors = visitor.linter_errors
if not visitor.visited_function:
Expand Down
79 changes: 53 additions & 26 deletions torchx/specs/finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Callable, Dict, Generator, List, Optional, Union

from torchx.specs import AppDef
from torchx.specs.file_linter import get_fn_docstring, validate
from torchx.specs.file_linter import get_fn_docstring, TorchxFunctionValidator, validate
from torchx.util import entrypoints
from torchx.util.io import read_conf_file
from torchx.util.types import none_throws
Expand Down Expand Up @@ -59,7 +59,9 @@ class _Component:

class ComponentsFinder(abc.ABC):
@abc.abstractmethod
def find(self) -> List[_Component]:
def find(
self, validators: Optional[List[TorchxFunctionValidator]]
) -> List[_Component]:
"""
Retrieves a set of components. A component is defined as a python
function that conforms to ``torchx.specs.file_linter`` linter.
Expand Down Expand Up @@ -203,10 +205,12 @@ def _iter_modules_recursive(
else:
yield self._try_import(module_info.name)

def find(self) -> List[_Component]:
def find(
self, validators: Optional[List[TorchxFunctionValidator]]
) -> List[_Component]:
components = []
for m in self._iter_modules_recursive(self.base_module):
components += self._get_components_from_module(m)
components += self._get_components_from_module(m, validators)
return components

def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
Expand All @@ -221,7 +225,9 @@ def _try_import(self, module: Union[str, ModuleType]) -> ModuleType:
else:
return module

def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
def _get_components_from_module(
self, module: ModuleType, validators: Optional[List[TorchxFunctionValidator]]
) -> List[_Component]:
functions = getmembers(module, isfunction)
component_defs = []

Expand All @@ -230,7 +236,7 @@ def _get_components_from_module(self, module: ModuleType) -> List[_Component]:
module_path = os.path.abspath(module_path)
rel_module_name = module_relname(module, relative_to=self.base_module)
for function_name, function in functions:
linter_errors = validate(module_path, function_name)
linter_errors = validate(module_path, function_name, validators)
component_desc, _ = get_fn_docstring(function)

# remove empty string to deal with group=""
Expand All @@ -255,13 +261,20 @@ def __init__(self, filepath: str, function_name: str) -> None:
self._filepath = filepath
self._function_name = function_name

def _get_validation_errors(self, path: str, function_name: str) -> List[str]:
linter_errors = validate(path, function_name)
def _get_validation_errors(
self,
path: str,
function_name: str,
validators: Optional[List[TorchxFunctionValidator]],
) -> List[str]:
linter_errors = validate(path, function_name, validators)
return [linter_error.description for linter_error in linter_errors]

def find(self) -> List[_Component]:
def find(
self, validators: Optional[List[TorchxFunctionValidator]]
) -> List[_Component]:
validation_errors = self._get_validation_errors(
self._filepath, self._function_name
self._filepath, self._function_name, validators
)

file_source = read_conf_file(self._filepath)
Expand All @@ -284,7 +297,9 @@ def find(self) -> List[_Component]:
]


def _load_custom_components() -> List[_Component]:
def _load_custom_components(
validators: Optional[List[TorchxFunctionValidator]],
) -> List[_Component]:
component_modules = {
name: load_fn()
for name, load_fn in
Expand All @@ -303,11 +318,13 @@ def _load_custom_components() -> List[_Component]:
# _0 = torchx.components.dist
# _1 = torchx.components.utils
group = "" if group.startswith("_") else group
components += ModuleComponentsFinder(module, group).find()
components += ModuleComponentsFinder(module, group).find(validators)
return components


def _load_components() -> Dict[str, _Component]:
def _load_components(
validators: Optional[List[TorchxFunctionValidator]],
) -> Dict[str, _Component]:
"""
Loads either the custom component defs from the entrypoint ``[torchx.components]``
or the default builtins from ``torchx.components`` module.
Expand All @@ -318,37 +335,43 @@ def _load_components() -> Dict[str, _Component]:

"""

components = _load_custom_components()
components = _load_custom_components(validators)
if not components:
components = ModuleComponentsFinder("torchx.components", "").find()
components = ModuleComponentsFinder("torchx.components", "").find(validators)
return {c.name: c for c in components}


_components: Optional[Dict[str, _Component]] = None


def _find_components() -> Dict[str, _Component]:
def _find_components(
validators: Optional[List[TorchxFunctionValidator]],
) -> Dict[str, _Component]:
global _components
if not _components:
_components = _load_components()
_components = _load_components(validators)
return none_throws(_components)


def _is_custom_component(component_name: str) -> bool:
return ":" in component_name


def _find_custom_components(name: str) -> Dict[str, _Component]:
def _find_custom_components(
name: str, validators: Optional[List[TorchxFunctionValidator]]
) -> Dict[str, _Component]:
if ":" not in name:
raise ValueError(
f"Invalid custom component: {name}, valid template : `FILEPATH`:`FUNCTION_NAME`"
)
filepath, component_name = name.split(":")
components = CustomComponentsFinder(filepath, component_name).find()
components = CustomComponentsFinder(filepath, component_name).find(validators)
return {component.name: component for component in components}


def get_components() -> Dict[str, _Component]:
def get_components(
validators: Optional[List[TorchxFunctionValidator]] = None,
) -> Dict[str, _Component]:
"""
Returns all custom components registered via ``[torchx.components]`` entrypoints
OR builtin components that ship with TorchX (but not both).
Expand Down Expand Up @@ -395,23 +418,25 @@ def get_components() -> Dict[str, _Component]:
"""

valid_components: Dict[str, _Component] = {}
for component_name, component in _find_components().items():
for component_name, component in _find_components(validators).items():
if len(component.validation_errors) == 0:
valid_components[component_name] = component
return valid_components


def get_component(name: str) -> _Component:
def get_component(
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
) -> _Component:
"""
Retrieves components by the provided name.

Returns:
Component or None if no component with ``name`` exists
"""
if _is_custom_component(name):
components = _find_custom_components(name)
components = _find_custom_components(name, validators)
else:
components = _find_components()
components = _find_components(validators)
if name not in components:
raise ComponentNotFoundException(
f"Component `{name}` not found. Please make sure it is one of the "
Expand All @@ -428,7 +453,9 @@ def get_component(name: str) -> _Component:
return component


def get_builtin_source(name: str) -> str:
def get_builtin_source(
name: str, validators: Optional[List[TorchxFunctionValidator]] = None
) -> str:
"""
Returns a string of the the builtin component's function source code
with all the import statements. Intended to be used to make a copy
Expand All @@ -446,7 +473,7 @@ def get_builtin_source(name: str) -> str:
are optimized and formatting adheres to your organization's standards.
"""

component = get_component(name)
component = get_component(name, validators)
fn = component.fn
fn_name = component.name.split(".")[-1]

Expand Down
32 changes: 20 additions & 12 deletions torchx/specs/test/file_linter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,22 +121,21 @@ def test_syntax_error(self) -> None:
content = "!!foo====bar"
with patch("torchx.specs.file_linter.read_conf_file") as read_conf_file_mock:
read_conf_file_mock.return_value = content
errors = validate(self._path, "unknown_function")
errors = validate(self._path, "unknown_function", None)
self.assertEqual(1, len(errors))
self.assertEqual("invalid syntax", errors[0].description)

def test_validate_varargs_kwargs_fn(self) -> None:
linter_errors = validate(
self._path,
"_test_invalid_fn_with_varags_and_kwargs",
self._path, "_test_invalid_fn_with_varags_and_kwargs", None
)
self.assertEqual(1, len(linter_errors))
self.assertTrue(
"Arg args missing type annotation", linter_errors[0].description
)

def test_validate_no_return(self) -> None:
linter_errors = validate(self._path, "_test_fn_no_return")
linter_errors = validate(self._path, "_test_fn_no_return", None)
self.assertEqual(1, len(linter_errors))
expected_desc = (
"Function: _test_fn_no_return missing return annotation or "
Expand All @@ -145,20 +144,32 @@ def test_validate_no_return(self) -> None:
self.assertEqual(expected_desc, linter_errors[0].description)

def test_validate_incorrect_return(self) -> None:
linter_errors = validate(self._path, "_test_fn_return_int")
linter_errors = validate(self._path, "_test_fn_return_int", None)
self.assertEqual(1, len(linter_errors))
expected_desc = (
"Function: _test_fn_return_int has incorrect return annotation, "
"supported annotation: AppDef"
)
self.assertEqual(expected_desc, linter_errors[0].description)

def test_no_validators_has_no_validation(self) -> None:
linter_errors = validate(self._path, "_test_fn_return_int", [])
self.assertEqual(0, len(linter_errors))

linter_errors = validate(self._path, "_test_fn_no_return", [])
self.assertEqual(0, len(linter_errors))

linter_errors = validate(
self._path, "_test_invalid_fn_with_varags_and_kwargs", []
)
self.assertEqual(0, len(linter_errors))

def test_validate_empty_fn(self) -> None:
linter_errors = validate(self._path, "_test_empty_fn")
linter_errors = validate(self._path, "_test_empty_fn", None)
self.assertEqual(0, len(linter_errors))

def test_validate_args_no_type_defs(self) -> None:
linter_errors = validate(self._path, "_test_args_no_type_defs")
linter_errors = validate(self._path, "_test_args_no_type_defs", None)
print(linter_errors)
self.assertEqual(2, len(linter_errors))
self.assertEqual(
Expand All @@ -169,10 +180,7 @@ def test_validate_args_no_type_defs(self) -> None:
)

def test_validate_args_no_type_defs_complex(self) -> None:
linter_errors = validate(
self._path,
"_test_args_dict_list_complex_types",
)
linter_errors = validate(self._path, "_test_args_dict_list_complex_types", None)
self.assertEqual(5, len(linter_errors))
self.assertEqual(
"Arg arg0 missing type annotation", linter_errors[0].description
Expand Down Expand Up @@ -210,7 +218,7 @@ def test_validate_docstring_no_docs(self) -> None:
self.assertEqual(" ", param_desc["arg0"])

def test_validate_unknown_function(self) -> None:
linter_errors = validate(self._path, "unknown_function")
linter_errors = validate(self._path, "unknown_function", None)
self.assertEqual(1, len(linter_errors))
self.assertEqual(
"Function unknown_function not found", linter_errors[0].description
Expand Down
Loading
Loading