From 13b0a985be1bc5de938bd403b208943ea961aedb Mon Sep 17 00:00:00 2001 From: bagel897 Date: Wed, 29 Jan 2025 15:01:35 -0800 Subject: [PATCH 1/2] Add test --- .../test_import_resolution.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 0da5ae316..d2c12f0af 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -249,3 +249,36 @@ def c_sym(): assert "c_sym" in b_file.valid_symbol_names assert "a_sym" in c_file.valid_symbol_names assert "b_sym" in c_file.valid_symbol_names.keys() + + +def test_import_resolution_nested_module(tmpdir: str) -> None: + """Tests import resolution works with nested module imports""" + # language=python + with get_codebase_session( + tmpdir, + files={ + "a/b/c.py": """ +def d(): + pass +""", + "consumer.py": """ +from a import b + +b.c.d() +""", + }, + ) as codebase: + consumer_file: SourceFile = codebase.get_file("consumer.py") + c_file: SourceFile = codebase.get_file("a/b/c.py") + + # Verify import resolution + assert len(consumer_file.imports) == 1 + import_stmt = consumer_file.imports[0] + resolution = import_stmt.resolve_import() + assert resolution is not None + + # Verify function call resolution + d_func = c_file.get_function("d") + call_sites = d_func.call_sites + assert len(call_sites) == 1 + assert call_sites[0].file == consumer_file From 65f343d5007bf6b2228f618acd7f53d507a881b9 Mon Sep 17 00:00:00 2001 From: bagel897 Date: Wed, 29 Jan 2025 15:36:23 -0800 Subject: [PATCH 2/2] Namespace modules --- src/codegen/sdk/core/external_module.py | 8 +++-- src/codegen/sdk/core/import_resolution.py | 16 +++++++-- src/codegen/sdk/python/file.py | 20 ++++++++++- src/codegen/sdk/python/import_resolution.py | 17 ++++++---- .../sdk/typescript/import_resolution.py | 2 +- .../test_import_resolution.py | 34 +++++++++++++++++-- 6 files changed, 80 insertions(+), 17 deletions(-) diff --git a/src/codegen/sdk/core/external_module.py b/src/codegen/sdk/core/external_module.py index 9a61b297f..efa41d751 100644 --- a/src/codegen/sdk/core/external_module.py +++ b/src/codegen/sdk/core/external_module.py @@ -35,14 +35,16 @@ class ExternalModule( """ node_type: Literal[NodeType.EXTERNAL] = NodeType.EXTERNAL + _import: Import | None = None - def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name, import_node: Import | None = None) -> None: self.node_id = G.add_node(self) super().__init__(ts_node, file_node_id, G, None) self._name_node = import_name self.return_type = StubPlaceholder(parent=self) assert self._idx_key not in self.G._ext_module_idx self.G._ext_module_idx[self._idx_key] = self.node_id + self._import = import_node @property def _idx_key(self) -> str: @@ -68,7 +70,7 @@ def from_import(cls, imp: Import) -> ExternalModule: Returns: ExternalModule: A new ExternalModule instance representing the external module. """ - return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node) + return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node, imp) @property @reader @@ -136,7 +138,7 @@ def viz(self) -> VizNode: @noapidoc @reader def resolve_attribute(self, name: str) -> ExternalModule | None: - return self + return self._import.resolve_attribute(name) or self @noapidoc @commiter diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py index 1c35dcf2c..9003b262e 100644 --- a/src/codegen/sdk/core/import_resolution.py +++ b/src/codegen/sdk/core/import_resolution.py @@ -11,6 +11,7 @@ from codegen.sdk.core.expressions.name import Name from codegen.sdk.core.external_module import ExternalModule from codegen.sdk.core.interfaces.chainable import Chainable +from codegen.sdk.core.interfaces.has_attribute import HasAttribute from codegen.sdk.core.interfaces.usable import Usable from codegen.sdk.core.statements.import_statement import ImportStatement from codegen.sdk.enums import EdgeType, ImportType, NodeType @@ -57,7 +58,7 @@ class ImportResolution(Generic[TSourceFile]): @apidoc -class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile]): +class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile], HasAttribute[TSourceFile]): """Represents a single symbol being imported. For example, this is one `Import` in Python (and similar applies to Typescript, etc.): @@ -115,7 +116,7 @@ def __rich_repr__(self) -> rich.repr.Result: @noapidoc @abstractmethod - def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSourceFile] | None: + def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSourceFile] | None: """Resolves the import to a symbol defined outside the file. Returns an ImportResolution object. @@ -662,6 +663,17 @@ def remove_if_unused(self) -> None: ): self.remove() + @noapidoc + @reader + def resolve_attribute(self, attribute: str) -> TSourceFile | None: + # Handles implicit namespace imports in python + if not isinstance(self._imported_symbol(), ExternalModule): + return None + resolved = self.resolve_import(add_module_name=attribute) + if resolved: + return resolved.symbol or resolved.from_file + return None + TImport = TypeVar("TImport", bound="Import") diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index 36595e165..a89bfe65d 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -6,7 +6,7 @@ from codegen.sdk.core.file import SourceFile from codegen.sdk.core.interface import Interface from codegen.sdk.enums import ImportType, ProgrammingLanguage -from codegen.sdk.extensions.utils import iter_all_descendants +from codegen.sdk.extensions.utils import cached_property, iter_all_descendants from codegen.sdk.python import PyAssignment from codegen.sdk.python.class_definition import PyClass from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock @@ -20,6 +20,7 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_graph import CodebaseGraph + from codegen.sdk.core.import_resolution import WildcardImport from codegen.sdk.python.symbol import PySymbol @@ -173,3 +174,20 @@ def add_import_from_import_string(self, import_string: str) -> None: def remove_unused_exports(self) -> None: """Removes unused exports from the file. NO-OP for python""" pass + + @cached_property + @noapidoc + @reader(cache=True) + def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[PyImport]]: + """Returns a dict mapping name => Symbol (or import) in this file that can be imported from + another file. + """ + if self.name == "__init__": + ret = {} + if self.directory: + for file in self.directory: + if file.name == "__init__": + continue + ret[file.name] = file + return ret + return super().valid_import_names diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 070d815d6..64e3b5f1c 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -82,10 +82,13 @@ def imported_exports(self) -> list[Exportable]: @noapidoc @reader - def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFile] | None: + def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[PyFile] | None: base_path = base_path or self.G.projects[0].base_path or "" module_source = self.module.source if self.module else "" - + symbol_name = self.symbol_name.source if self.symbol_name else "" + if add_module_name: + module_source += f".{symbol_name}" + symbol_name = add_module_name # If import is relative, convert to absolute path if module_source.startswith("."): module_source = self._relative_to_absolute_import(module_source) @@ -99,7 +102,7 @@ def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFil # `from a.b.c import foo` filepath = os.path.join( base_path, - module_source.replace(".", "/") + "/" + self.symbol_name.source + ".py", + module_source.replace(".", "/") + "/" + symbol_name + ".py", ) if file := self.G.get_file(filepath): return ImportResolution(from_file=file, symbol=None, imports_file=True) @@ -114,22 +117,22 @@ def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFil filepath = module_source.replace(".", "/") + ".py" filepath = os.path.join(base_path, filepath) if file := self.G.get_file(filepath): - symbol = file.get_node_by_name(self.symbol_name.source) + symbol = file.get_node_by_name(symbol_name) return ImportResolution(from_file=file, symbol=symbol) # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.G.get_file(filepath): - symbol = from_file.get_node_by_name(self.symbol_name.source) + symbol = from_file.get_node_by_name(symbol_name) return ImportResolution(from_file=from_file, symbol=symbol) # =====[ Case: Can't resolve the import ]===== if base_path == "": # Try to resolve with "src" as the base path - return self.resolve_import(base_path="src") + return self.resolve_import(base_path="src", add_module_name=add_module_name) if base_path == "src": # Try "test" next - return self.resolve_import(base_path="test") + return self.resolve_import(base_path="test", add_module_name=add_module_name) # if not G_override: # for resolver in G.import_resolvers: diff --git a/src/codegen/sdk/typescript/import_resolution.py b/src/codegen/sdk/typescript/import_resolution.py index 810843483..9084b3e09 100644 --- a/src/codegen/sdk/typescript/import_resolution.py +++ b/src/codegen/sdk/typescript/import_resolution.py @@ -197,7 +197,7 @@ def resolved_symbol(self) -> Symbol | ExternalModule | TSFile | None: return resolved_symbol @reader - def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSFile] | None: + def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSFile] | None: """Resolves an import statement to its target file and symbol. This method is used by GraphBuilder to resolve import statements to their target files and symbols. It handles both relative and absolute imports, diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index d2c12f0af..07587393a 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -273,9 +273,37 @@ def d(): # Verify import resolution assert len(consumer_file.imports) == 1 - import_stmt = consumer_file.imports[0] - resolution = import_stmt.resolve_import() - assert resolution is not None + + # Verify function call resolution + d_func = c_file.get_function("d") + call_sites = d_func.call_sites + assert len(call_sites) == 1 + assert call_sites[0].file == consumer_file + + +def test_import_resolution_nested_module_init(tmpdir: str) -> None: + """Tests import resolution works with nested module imports""" + # language=python + with get_codebase_session( + tmpdir, + files={ + "a/b/c.py": """ +def d(): + pass +""", + "a/b/__init__.py": """""", + "consumer.py": """ +from a import b + +b.c.d() +""", + }, + ) as codebase: + consumer_file: SourceFile = codebase.get_file("consumer.py") + c_file: SourceFile = codebase.get_file("a/b/c.py") + + # Verify import resolution + assert len(consumer_file.imports) == 1 # Verify function call resolution d_func = c_file.get_function("d")