From 5cb0c3e698d7d10224231593b40e9bd77ee4811e Mon Sep 17 00:00:00 2001 From: Daniel Ohayon Date: Thu, 5 Jun 2025 19:40:00 -0700 Subject: [PATCH] handle components whose implementation lives in a different file (#1075) Summary: Pull Request resolved: https://github.com/pytorch/torchx/pull/1075 Add support for cases like: ```lang=python # some_file.py # ==================== def my_component(...) -> specs.AppDef: ... # other_file.py # ==================== from some_file import my_component ``` where the component is invoked with `torchx run ... other_file.py:my_component` This was currently failing with a validation error because in the step where we inspect the AST of the component, we assume that the file where the component is being looked up is the same as the file where it is implemented. Reviewed By: kiukchung Differential Revision: D75496839 --- torchx/specs/finder.py | 23 ++++++++++++++++++++--- torchx/specs/test/finder_test.py | 5 +++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/torchx/specs/finder.py b/torchx/specs/finder.py index 4258ec7a7..66900e7a4 100644 --- a/torchx/specs/finder.py +++ b/torchx/specs/finder.py @@ -274,12 +274,23 @@ def _get_validation_errors( linter_errors = validate(path, function_name, validators) return [linter_error.description for linter_error in linter_errors] + def _get_path_to_function_decl( + self, function: Callable[..., Any] # pyre-ignore[2] + ) -> str: + """ + Attempts to return the path to the file where the function is implemented. + This can be different from the path where the function is looked up, for example if we have: + my_component defined in some_file.py, imported in other_file.py + and the component is invoked as other_file.py:my_component + """ + path_to_function_decl = inspect.getabsfile(function) + if path_to_function_decl is None or not os.path.isfile(path_to_function_decl): + return self._filepath + return path_to_function_decl + def find( self, validators: Optional[List[TorchxFunctionValidator]] ) -> List[_Component]: - validation_errors = self._get_validation_errors( - self._filepath, self._function_name, validators - ) file_source = read_conf_file(self._filepath) namespace = copy.copy(globals()) @@ -292,6 +303,12 @@ def find( ) app_fn = namespace[self._function_name] fn_desc, _ = get_fn_docstring(app_fn) + + func_path = self._get_path_to_function_decl(app_fn) + validation_errors = self._get_validation_errors( + func_path, self._function_name, validators + ) + return [ _Component( name=f"{self._filepath}:{self._function_name}", diff --git a/torchx/specs/test/finder_test.py b/torchx/specs/test/finder_test.py index 69dfecf01..18f01b4c5 100644 --- a/torchx/specs/test/finder_test.py +++ b/torchx/specs/test/finder_test.py @@ -29,6 +29,7 @@ get_components, ModuleComponentsFinder, ) +from torchx.specs.test.components.a import comp_a from torchx.util.test.entrypoints_test import EntryPoint_from_text from torchx.util.types import none_throws @@ -238,6 +239,10 @@ def test_get_component_invalid(self) -> None: with self.assertRaises(ComponentValidationException): get_component(f"{current_file_path()}:invalid_component") + def test_get_component_imported_from_other_file(self) -> None: + component = get_component(f"{current_file_path()}:comp_a") + self.assertListEqual([], component.validation_errors) + class GetBuiltinSourceTest(unittest.TestCase): def setUp(self) -> None: