diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py index 6022aff3e..f0afb6e7a 100644 --- a/src/codegen/sdk/core/import_resolution.py +++ b/src/codegen/sdk/core/import_resolution.py @@ -324,6 +324,7 @@ def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalMo """Returns the symbol directly being imported, including an indirect import and an External Module. """ + from codegen.sdk.python.file import PyFile from codegen.sdk.typescript.file import TSFile symbol = next(iter(self.ctx.successors(self.node_id, edge_type=EdgeType.IMPORT_SYMBOL_RESOLUTION, sort=False)), None) @@ -341,6 +342,14 @@ def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalMo if self.import_type == ImportType.NAMED_EXPORT: if export := symbol.valid_import_names.get(name, None): return export + elif resolve_exports and isinstance(symbol, PyFile): + name = self.symbol_name.source if self.symbol_name else "" + if self.import_type == ImportType.NAMED_EXPORT: + if symbol.name == name: + return symbol + if imp := symbol.valid_import_names.get(name, None): + return imp + if symbol is not self: return symbol @@ -632,6 +641,11 @@ def _compute_dependencies(self, *args, **kwargs) -> None: # if used_frame.parent_frame: # used_frame.parent_frame.add_usage(self.symbol_name or self.module, SymbolUsageType.IMPORTED_WILDCARD, self, self.ctx) # else: + if isinstance(self, Import) and self.import_type == ImportType.NAMED_EXPORT: + # It could be a wildcard import downstream, hence we have to pop the cache + if file := self.from_file: + file.invalidate() + for used_frame in self.resolved_type_frames: if used_frame.parent_frame: used_frame.parent_frame.add_usage(self._unique_node, UsageKind.IMPORTED, self, self.ctx) diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index ecaa4da7b..cec7dc4d3 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -197,3 +197,57 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P ret[file.name] = file return ret return super().valid_import_names + + @noapidoc + def get_node_from_wildcard_chain(self, symbol_name: str) -> PySymbol | None: + """Recursively searches for a symbol through wildcard import chains. + + Attempts to find a symbol by name in the current file, and if not found, recursively searches + through any wildcard imports (from x import *) to find the symbol in imported modules. + + Args: + symbol_name (str): The name of the symbol to search for. + + Returns: + PySymbol | None: The found symbol if it exists in this file or any of its wildcard + imports, None otherwise. + """ + node = None + if node := self.get_node_by_name(symbol_name): + return node + + if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: + for wildcard_import in wildcard_imports: + if imp_resolution := wildcard_import.resolve_import(): + node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name) + + return node + + @noapidoc + def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbol | None: + """Finds the wildcard import that resolves a given symbol name. + + Searches for a symbol by name, first in the current file, then through wildcard imports. + Unlike get_node_from_wildcard_chain, this returns the wildcard import that contains + the symbol rather than the symbol itself. + + Args: + symbol_name (str): The name of the symbol to search for. + + Returns: + PyImport | PySymbol | None: + - PySymbol if the symbol is found directly in this file + - PyImport if the symbol is found through a wildcard import + - None if the symbol cannot be found + """ + node = None + if node := self.get_node_by_name(symbol_name): + return node + + if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: + for wildcard_import in wildcard_imports: + if imp_resolution := wildcard_import.resolve_import(): + if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name): + node = wildcard_import + + return node diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index a80bb2ada..46f9de63c 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -118,13 +118,28 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | filepath = os.path.join(base_path, filepath) if file := self.ctx.get_file(filepath): symbol = file.get_node_by_name(symbol_name) - return ImportResolution(from_file=file, symbol=symbol) + if symbol is None: + if file.get_node_from_wildcard_chain(symbol_name): + return ImportResolution(from_file=file, symbol=None, imports_file=True) + else: + # This is most likely a broken import + return ImportResolution(from_file=file, symbol=None) + else: + return ImportResolution(from_file=file, symbol=symbol) # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.ctx.get_file(filepath): symbol = from_file.get_node_by_name(symbol_name) - return ImportResolution(from_file=from_file, symbol=symbol) + if symbol is None: + if from_file.get_node_from_wildcard_chain(symbol_name): + return ImportResolution(from_file=from_file, symbol=None, imports_file=True) + else: + # This is most likely a broken import + return ImportResolution(from_file=from_file, symbol=None) + + else: + return ImportResolution(from_file=from_file, symbol=symbol) # =====[ Case: Can't resolve the import ]===== if base_path == "": diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 6fd9cbe7b..63c58b5af 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -215,6 +215,204 @@ def func_1(): assert call_site.file == consumer_file +def test_import_resolution_init_wildcard(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly""" + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import TEST_CONST + test3=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + test3 = file3.get_symbol("test3") + test3_import = file3.get_import("TEST_CONST") + + assert len(symb.usages) == 3 + assert symb.symbol_usages == [test, test3, test3_import] + + +def test_import_resolution_wildcard_func(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly""" + # language=python + content1 = """ + def foo(): + pass + def bar(): + pass + """ + content2 = """ + from testa import * + + foo() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"testa.py": content1, "testb.py": content2}) as codebase: + testa: SourceFile = codebase.get_file("testa.py") + testb: SourceFile = codebase.get_file("testb.py") + + foo = testa.get_symbol("foo") + bar = testa.get_symbol("bar") + assert len(foo.usages) == 1 + assert len(foo.call_sites) == 1 + + assert len(bar.usages) == 0 + assert len(bar.call_sites) == 0 + assert len(testb.function_calls) == 1 + + +def test_import_resolution_chaining_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import * + test3=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + bar = file2.get_symbol("bar") + mid_import = file2.get_import("testdir.test1") + test3 = file3.get_symbol("test3") + + assert len(symb.usages) == 2 + assert symb.symbol_usages == [test, test3] + assert mid_import.symbol_usages == [test, bar, test3] + + +def test_import_resolution_init_deep_nested_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files = { + "test/nest/nest2/test1.py": """test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test/nest/nest2/__init__.py": """from .test1 import * + t1=test_used_parent + """, + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ + from test import * + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") + + main_test = main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const = deepest_layer.get_symbol("test_const") + test_not_used = deepest_layer.get_symbol("test_not_used") + test_used_parent = deepest_layer.get_symbol("test_used_parent") + + assert len(test_const.usages) == 1 + assert test_const.usages[0].usage_symbol == main_test + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 + + +def test_import_resolution_chaining_many_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files = { + "test1.py": """ + test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test2.py": """from test1 import * + t1=test_used_parent + """, + "test3.py": """from test2 import *""", + "test4.py": """from test3 import *""", + "main.py": """ + from test4 import * + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + furthest_layer: SourceFile = codebase.get_file("test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test2.py") + + main_test = main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const = furthest_layer.get_symbol("test_const") + test_not_used = furthest_layer.get_symbol("test_not_used") + test_used_parent = furthest_layer.get_symbol("test_used_parent") + + assert len(test_const.usages) == 1 + assert test_const.usages[0].usage_symbol == main_test + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 + + +def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files = { + "test/nest/nest2/test1.py": """test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test/nest/nest2/__init__.py": """from .test1 import * + t1=test_used_parent + """, + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ + from test import test_const + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") + test_nest: SourceFile = codebase.get_file("test/__init__.py") + + main_test = main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const = deepest_layer.get_symbol("test_const") + test_not_used = deepest_layer.get_symbol("test_not_used") + test_used_parent = deepest_layer.get_symbol("test_used_parent") + + test_const_imp = main.get_import("test_const") + + assert len(test_const.usages) == 2 + assert test_const.usages[0].usage_symbol == main_test + assert test_const.usages[1].usage_symbol == test_const_imp + + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 + + def test_import_resolution_circular(tmpdir: str) -> None: """Tests function.usages returns usages from file imports""" # language=python @@ -367,4 +565,66 @@ def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None: ) as codebase: mainfile: SourceFile = codebase.get_file("file.py") - assert len(mainfile.ctx.edges) == 5 + assert len(mainfile.ctx.edges) == 10 + + +def test_import_resolution_init_wildcard_no_dupe(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly and doesn't + result in duplicate usages + """ + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import TEST_CONST + test3=TEST_CONST""" + content4 = """from testdir import foo + test4=foo""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3, "test4.py": content4}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + test3 = file3.get_symbol("test3") + test3_import = file3.get_import("TEST_CONST") + + assert len(symb.usages) == 3 + assert symb.symbol_usages == [test, test3, test3_import] + + +def test_import_resolution_init_wildcard_chainging_deep(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly and doesn't + result in duplicate usages + """ + # language=python + content1 = """TEST_CONST=2 + """ + content2 = """from .file1 import *""" + content3 = """from .dir import *""" + content4 = """from .dir import TEST_CONST + test1=TEST_CONST""" + with get_codebase_session( + tmpdir=tmpdir, + files={ + "dir/dir/dir/dir/file1.py": content1, + "dir/dir/dir/dir/__init__.py": content2, + "dir/dir/dir/__init__.py": content3, + "dir/dir/__init__.py": content3, + "dir/__init__.py": content3, + "file2.py": content4, + }, + ) as codebase: + file1: SourceFile = codebase.get_file("dir/dir/dir/dir/file1.py") + file2: SourceFile = codebase.get_file("file2.py") + + symb = file1.get_symbol("TEST_CONST") + test1 = file2.get_symbol("test1") + imp = file2.get_import("TEST_CONST") + + assert len(symb.usages) == 2 + assert symb.symbol_usages == [test1, imp]