From b466a8e4b73dfec7ea6ec821d1c76ec13f4c2eb3 Mon Sep 17 00:00:00 2001 From: tomcodegen Date: Fri, 21 Feb 2025 15:10:48 -0800 Subject: [PATCH 1/7] fix --- src/codegen/sdk/python/file.py | 26 +++ src/codegen/sdk/python/import_resolution.py | 43 +++- .../test_import_resolution.py | 210 +++++++++++++++++- 3 files changed, 273 insertions(+), 6 deletions(-) diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index ecaa4da7b..a202b957a 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -197,3 +197,29 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P ret[file.name] = file return ret return super().valid_import_names + + + def get_node_from_wildcard_chain(self, symbol_name: str): + node = None + if node := self.get_node_by_name(symbol_name): + return node + + if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: + for wildcard_import in wildcard_imports: + if imp_resolution := wildcard_import.resolve_import(): + node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name) + + return node + + def get_node_wildcard_resolves_for(self, symbol_name: str): + node = None + if node := self.get_node_by_name(symbol_name): + return node + + if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: + for wildcard_import in wildcard_imports: + if imp_resolution := wildcard_import.resolve_import(): + if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name): + node = wildcard_import + + return node diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index a80bb2ada..6261d79dd 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -1,15 +1,18 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self, override from codegen.sdk.core.autocommit import reader from codegen.sdk.core.expressions import Name from codegen.sdk.core.import_resolution import ExternalImportResolver, Import, ImportResolution from codegen.sdk.enums import ImportType, NodeType +from codegen.sdk.extensions.resolution import ResolutionStack from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -28,6 +31,11 @@ class PyImport(Import["PyFile"]): """Extends Import for Python codebases.""" + def __init__(self, ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type=ImportType.UNKNOWN): + super().__init__(ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type) + self._requesting_names = set() + + @reader def is_module_import(self) -> bool: """Determines if the import is a module-level or wildcard import. @@ -117,13 +125,13 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | filepath = module_source.replace(".", "/") + ".py" filepath = os.path.join(base_path, filepath) if file := self.ctx.get_file(filepath): - symbol = file.get_node_by_name(symbol_name) + symbol = file.get_node_wildcard_resolves_for(symbol_name) return ImportResolution(from_file=file, symbol=symbol) # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.ctx.get_file(filepath): - symbol = from_file.get_node_by_name(symbol_name) + symbol = from_file.get_node_wildcard_resolves_for(symbol_name) return ImportResolution(from_file=from_file, symbol=symbol) # =====[ Case: Can't resolve the import ]===== @@ -232,6 +240,35 @@ def from_future_import_statement(cls, import_statement: TSNode, file_node_id: No imports.append(imp) return imports + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + """Resolve the types used by this import.""" + ix_seen = set() + + aliased = self.is_aliased_import() + if imported := self._imported_symbol(resolve_exports=True): + if isinstance(imported,PyImport) and imported.is_wildcard_import: + imported.set_requesting_names(self) + yield from self.with_resolution_frame(imported, direct=False, aliased=aliased) + else: + yield ResolutionStack(self, aliased=aliased) + + if self.is_wildcard_import(): + for name, wildcard_import in self.names: + if name in self._requesting_names: + yield from [frame.parent_frame for frame in wildcard_import.resolved_type_frames] + + + @noapidoc + def set_requesting_names(self, requester: PyImport): + if requester.is_wildcard_import(): + self._requesting_names.update(requester._requesting_names) + else: + self._requesting_names.add(requester.name) + + @property @reader def import_specifier(self) -> Editable: diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 6fd9cbe7b..caec59592 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -215,6 +215,209 @@ def func_1(): assert call_site.file == consumer_file + +def test_import_resolution_init_wildcard(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly""" + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import TEST_CONST + test3=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + test3 = file3.get_symbol("test3") + test3_import = file3.get_import("TEST_CONST") + + assert len(symb.usages) == 3 + assert symb.symbol_usages == [test, test3, test3_import] + + +def test_import_resolution_wildcard_func(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly""" + # language=python + content1 = """ + def foo(): + pass + def bar(): + pass + """ + content2 = """ + from testa import * + + foo() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"testa.py": content1, "testb.py": content2}) as codebase: + testa: SourceFile = codebase.get_file("testa.py") + testb: SourceFile = codebase.get_file("testb.py") + + foo = testa.get_symbol("foo") + bar = testa.get_symbol("bar") + assert len(foo.usages)==1 + assert len(foo.call_sites)==1 + + assert len(bar.usages)==0 + assert len(bar.call_sites)==0 + assert len(testb.function_calls)==1 + + + +def test_import_resolution_chaining_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import * + test3=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + bar = file2.get_symbol("bar") + mid_import = file2.get_import("testdir.test1") + test3 = file3.get_symbol("test3") + + assert len(symb.usages) == 2 + assert symb.symbol_usages == [test, test3] + assert mid_import.symbol_usages == [test, bar, test3] + + +def test_import_resolution_init_deep_nested_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files={"test/nest/nest2/test1.py": """test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test/nest/nest2/__init__.py": """from .test1 import * + t1=test_used_parent + """, + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ + from test import * + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") + + main_test=main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const=deepest_layer.get_symbol("test_const") + test_not_used=deepest_layer.get_symbol("test_not_used") + test_used_parent=deepest_layer.get_symbol("test_used_parent") + + assert len(test_const.usages)==1 + assert test_const.usages[0].usage_symbol==main_test + assert len(test_not_used.usages)==0 + assert len(test_used_parent.usages)==1 + assert test_used_parent.usages[0].usage_symbol==t1 + + +def test_import_resolution_chaining_many_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files={"test1.py": """ + test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test2.py": """from test1 import * + t1=test_used_parent + """, + "test3.py": """from test2 import *""", + "test4.py": """from test3 import *""", + "main.py": """ + from test4 import * + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + furthest_layer: SourceFile = codebase.get_file("test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test2.py") + + main_test=main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const=furthest_layer.get_symbol("test_const") + test_not_used=furthest_layer.get_symbol("test_not_used") + test_used_parent=furthest_layer.get_symbol("test_used_parent") + + assert len(test_const.usages)==1 + assert test_const.usages[0].usage_symbol==main_test + assert len(test_not_used.usages)==0 + assert len(test_used_parent.usages)==1 + assert test_used_parent.usages[0].usage_symbol==t1 + + +def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files={"test/nest/nest2/test1.py": """test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test/nest/nest2/__init__.py": """from .test1 import * + t1=test_used_parent + """, + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ + from test import test_const + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") + test_nest: SourceFile = codebase.get_file("test/__init__.py") + + + main_test=main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const=deepest_layer.get_symbol("test_const") + test_not_used=deepest_layer.get_symbol("test_not_used") + test_used_parent=deepest_layer.get_symbol("test_used_parent") + + test_const_imp = main.get_import("test_const") + test_const_imp_2 = test_nest.get_import(".nest") + + + assert len(test_const.usages)==3 + assert test_const.usages[0].usage_symbol==main_test + assert test_const.usages[1].usage_symbol==test_const_imp + assert test_const.usages[2].usage_symbol==test_const_imp_2 + + assert len(test_not_used.usages)==0 + assert len(test_used_parent.usages)==1 + assert test_used_parent.usages[0].usage_symbol==t1 + + + + def test_import_resolution_circular(tmpdir: str) -> None: """Tests function.usages returns usages from file imports""" # language=python @@ -342,8 +545,7 @@ def some_func(): assert len(some_func.usages) > 0 assert len(some_func.symbol_usages) > 0 - -def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None: +def test_import_wildcard_preserves_import_resultion(tmpdir: str) -> None: """Tests importing from a file that contains a wildcard import doesn't break further resolution. This could occur depending on to_resolve ordering, if the outer file is processed first _wildcards will not be filled in time. """ @@ -367,4 +569,6 @@ def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None: ) as codebase: mainfile: SourceFile = codebase.get_file("file.py") - assert len(mainfile.ctx.edges) == 5 + assert len(mainfile.ctx.edges)==12 + + From 80c84e523529be72d9ee13b20b555f6b632c7adb Mon Sep 17 00:00:00 2001 From: tomcodgen <191515280+tomcodgen@users.noreply.github.com> Date: Fri, 21 Feb 2025 23:12:33 +0000 Subject: [PATCH 2/7] Automated pre-commit update --- src/codegen/sdk/python/file.py | 1 - src/codegen/sdk/python/import_resolution.py | 5 +- .../test_import_resolution.py | 118 +++++++++--------- 3 files changed, 58 insertions(+), 66 deletions(-) diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index a202b957a..74ec2f10f 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -198,7 +198,6 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P return ret return super().valid_import_names - def get_node_from_wildcard_chain(self, symbol_name: str): node = None if node := self.get_node_by_name(symbol_name): diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 6261d79dd..12dd3b086 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -35,7 +35,6 @@ def __init__(self, ts_node, file_node_id, G, parent, module_node, name_node, ali super().__init__(ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type) self._requesting_names = set() - @reader def is_module_import(self) -> bool: """Determines if the import is a module-level or wildcard import. @@ -249,7 +248,7 @@ def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: aliased = self.is_aliased_import() if imported := self._imported_symbol(resolve_exports=True): - if isinstance(imported,PyImport) and imported.is_wildcard_import: + if isinstance(imported, PyImport) and imported.is_wildcard_import: imported.set_requesting_names(self) yield from self.with_resolution_frame(imported, direct=False, aliased=aliased) else: @@ -260,7 +259,6 @@ def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: if name in self._requesting_names: yield from [frame.parent_frame for frame in wildcard_import.resolved_type_frames] - @noapidoc def set_requesting_names(self, requester: PyImport): if requester.is_wildcard_import(): @@ -268,7 +266,6 @@ def set_requesting_names(self, requester: PyImport): else: self._requesting_names.add(requester.name) - @property @reader def import_specifier(self) -> Editable: diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index caec59592..eb96f25df 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -215,7 +215,6 @@ def func_1(): assert call_site.file == consumer_file - def test_import_resolution_init_wildcard(tmpdir: str) -> None: """Tests that named import from a file with wildcard resolves properly""" # language=python @@ -262,13 +261,12 @@ def bar(): foo = testa.get_symbol("foo") bar = testa.get_symbol("bar") - assert len(foo.usages)==1 - assert len(foo.call_sites)==1 - - assert len(bar.usages)==0 - assert len(bar.call_sites)==0 - assert len(testb.function_calls)==1 + assert len(foo.usages) == 1 + assert len(foo.call_sites) == 1 + assert len(bar.usages) == 0 + assert len(bar.call_sites) == 0 + assert len(testb.function_calls) == 1 def test_import_resolution_chaining_wildcards(tmpdir: str) -> None: @@ -302,120 +300,119 @@ def test_import_resolution_init_deep_nested_wildcards(tmpdir: str) -> None: """Tests that chaining wildcard imports resolves properly""" # language=python - files={"test/nest/nest2/test1.py": """test_const=5 + files = { + "test/nest/nest2/test1.py": """test_const=5 test_not_used=2 test_used_parent=5 """, - "test/nest/nest2/__init__.py": """from .test1 import * + "test/nest/nest2/__init__.py": """from .test1 import * t1=test_used_parent """, - "test/nest/__init__.py": """from .nest2 import *""", - "test/__init__.py": """from .nest import *""", - "main.py": """ + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ from test import * main_test=test_const """, - } + } with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") main: SourceFile = codebase.get_file("main.py") parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") - main_test=main.get_symbol("main_test") + main_test = main.get_symbol("main_test") t1 = parent_file.get_symbol("t1") - test_const=deepest_layer.get_symbol("test_const") - test_not_used=deepest_layer.get_symbol("test_not_used") - test_used_parent=deepest_layer.get_symbol("test_used_parent") + test_const = deepest_layer.get_symbol("test_const") + test_not_used = deepest_layer.get_symbol("test_not_used") + test_used_parent = deepest_layer.get_symbol("test_used_parent") - assert len(test_const.usages)==1 - assert test_const.usages[0].usage_symbol==main_test - assert len(test_not_used.usages)==0 - assert len(test_used_parent.usages)==1 - assert test_used_parent.usages[0].usage_symbol==t1 + assert len(test_const.usages) == 1 + assert test_const.usages[0].usage_symbol == main_test + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 def test_import_resolution_chaining_many_wildcards(tmpdir: str) -> None: """Tests that chaining wildcard imports resolves properly""" # language=python - files={"test1.py": """ + files = { + "test1.py": """ test_const=5 test_not_used=2 test_used_parent=5 """, - "test2.py": """from test1 import * + "test2.py": """from test1 import * t1=test_used_parent """, - "test3.py": """from test2 import *""", - "test4.py": """from test3 import *""", - "main.py": """ + "test3.py": """from test2 import *""", + "test4.py": """from test3 import *""", + "main.py": """ from test4 import * main_test=test_const """, - } + } with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: furthest_layer: SourceFile = codebase.get_file("test1.py") main: SourceFile = codebase.get_file("main.py") parent_file: SourceFile = codebase.get_file("test2.py") - main_test=main.get_symbol("main_test") + main_test = main.get_symbol("main_test") t1 = parent_file.get_symbol("t1") - test_const=furthest_layer.get_symbol("test_const") - test_not_used=furthest_layer.get_symbol("test_not_used") - test_used_parent=furthest_layer.get_symbol("test_used_parent") + test_const = furthest_layer.get_symbol("test_const") + test_not_used = furthest_layer.get_symbol("test_not_used") + test_used_parent = furthest_layer.get_symbol("test_used_parent") - assert len(test_const.usages)==1 - assert test_const.usages[0].usage_symbol==main_test - assert len(test_not_used.usages)==0 - assert len(test_used_parent.usages)==1 - assert test_used_parent.usages[0].usage_symbol==t1 + assert len(test_const.usages) == 1 + assert test_const.usages[0].usage_symbol == main_test + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None: """Tests that chaining wildcard imports resolves properly""" # language=python - files={"test/nest/nest2/test1.py": """test_const=5 + files = { + "test/nest/nest2/test1.py": """test_const=5 test_not_used=2 test_used_parent=5 """, - "test/nest/nest2/__init__.py": """from .test1 import * + "test/nest/nest2/__init__.py": """from .test1 import * t1=test_used_parent """, - "test/nest/__init__.py": """from .nest2 import *""", - "test/__init__.py": """from .nest import *""", - "main.py": """ + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ from test import test_const main_test=test_const """, - } + } with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") main: SourceFile = codebase.get_file("main.py") parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") test_nest: SourceFile = codebase.get_file("test/__init__.py") - - main_test=main.get_symbol("main_test") + main_test = main.get_symbol("main_test") t1 = parent_file.get_symbol("t1") - test_const=deepest_layer.get_symbol("test_const") - test_not_used=deepest_layer.get_symbol("test_not_used") - test_used_parent=deepest_layer.get_symbol("test_used_parent") + test_const = deepest_layer.get_symbol("test_const") + test_not_used = deepest_layer.get_symbol("test_not_used") + test_used_parent = deepest_layer.get_symbol("test_used_parent") test_const_imp = main.get_import("test_const") test_const_imp_2 = test_nest.get_import(".nest") + assert len(test_const.usages) == 3 + assert test_const.usages[0].usage_symbol == main_test + assert test_const.usages[1].usage_symbol == test_const_imp + assert test_const.usages[2].usage_symbol == test_const_imp_2 - assert len(test_const.usages)==3 - assert test_const.usages[0].usage_symbol==main_test - assert test_const.usages[1].usage_symbol==test_const_imp - assert test_const.usages[2].usage_symbol==test_const_imp_2 - - assert len(test_not_used.usages)==0 - assert len(test_used_parent.usages)==1 - assert test_used_parent.usages[0].usage_symbol==t1 - - + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 def test_import_resolution_circular(tmpdir: str) -> None: @@ -545,6 +542,7 @@ def some_func(): assert len(some_func.usages) > 0 assert len(some_func.symbol_usages) > 0 + def test_import_wildcard_preserves_import_resultion(tmpdir: str) -> None: """Tests importing from a file that contains a wildcard import doesn't break further resolution. This could occur depending on to_resolve ordering, if the outer file is processed first _wildcards will not be filled in time. @@ -569,6 +567,4 @@ def test_import_wildcard_preserves_import_resultion(tmpdir: str) -> None: ) as codebase: mainfile: SourceFile = codebase.get_file("file.py") - assert len(mainfile.ctx.edges)==12 - - + assert len(mainfile.ctx.edges) == 12 From c0b72848ca422caeada59dff702f8ad2dd3d0efe Mon Sep 17 00:00:00 2001 From: tomcodegen Date: Fri, 21 Feb 2025 16:05:21 -0800 Subject: [PATCH 3/7] bot changes --- src/codegen/sdk/python/file.py | 33 +++++++++++++++++-- .../test_import_resolution.py | 2 +- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index 74ec2f10f..cec7dc4d3 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -198,7 +198,20 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P return ret return super().valid_import_names - def get_node_from_wildcard_chain(self, symbol_name: str): + @noapidoc + def get_node_from_wildcard_chain(self, symbol_name: str) -> PySymbol | None: + """Recursively searches for a symbol through wildcard import chains. + + Attempts to find a symbol by name in the current file, and if not found, recursively searches + through any wildcard imports (from x import *) to find the symbol in imported modules. + + Args: + symbol_name (str): The name of the symbol to search for. + + Returns: + PySymbol | None: The found symbol if it exists in this file or any of its wildcard + imports, None otherwise. + """ node = None if node := self.get_node_by_name(symbol_name): return node @@ -210,7 +223,23 @@ def get_node_from_wildcard_chain(self, symbol_name: str): return node - def get_node_wildcard_resolves_for(self, symbol_name: str): + @noapidoc + def get_node_wildcard_resolves_for(self, symbol_name: str) -> PyImport | PySymbol | None: + """Finds the wildcard import that resolves a given symbol name. + + Searches for a symbol by name, first in the current file, then through wildcard imports. + Unlike get_node_from_wildcard_chain, this returns the wildcard import that contains + the symbol rather than the symbol itself. + + Args: + symbol_name (str): The name of the symbol to search for. + + Returns: + PyImport | PySymbol | None: + - PySymbol if the symbol is found directly in this file + - PyImport if the symbol is found through a wildcard import + - None if the symbol cannot be found + """ node = None if node := self.get_node_by_name(symbol_name): return node diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index eb96f25df..350e0f646 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -543,7 +543,7 @@ def some_func(): assert len(some_func.symbol_usages) > 0 -def test_import_wildcard_preserves_import_resultion(tmpdir: str) -> None: +def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None: """Tests importing from a file that contains a wildcard import doesn't break further resolution. This could occur depending on to_resolve ordering, if the outer file is processed first _wildcards will not be filled in time. """ From 1e754e47a37ec2e902426ca5a87e08339c79f05f Mon Sep 17 00:00:00 2001 From: tkucar Date: Tue, 25 Feb 2025 18:36:54 +0100 Subject: [PATCH 4/7] cleanup --- src/codegen/sdk/core/import_resolution.py | 15 +++++ src/codegen/sdk/python/import_resolution.py | 50 ++++---------- .../test_import_resolution.py | 65 +++++++++++++++++-- 3 files changed, 87 insertions(+), 43 deletions(-) diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py index 6022aff3e..42f2c5b10 100644 --- a/src/codegen/sdk/core/import_resolution.py +++ b/src/codegen/sdk/core/import_resolution.py @@ -324,6 +324,7 @@ def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalMo """Returns the symbol directly being imported, including an indirect import and an External Module. """ + from codegen.sdk.python.file import PyFile from codegen.sdk.typescript.file import TSFile symbol = next(iter(self.ctx.successors(self.node_id, edge_type=EdgeType.IMPORT_SYMBOL_RESOLUTION, sort=False)), None) @@ -341,6 +342,15 @@ def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalMo if self.import_type == ImportType.NAMED_EXPORT: if export := symbol.valid_import_names.get(name, None): return export + elif resolve_exports and isinstance(symbol,PyFile): + name = self.symbol_name.source if self.symbol_name else "" + if self.import_type == ImportType.NAMED_EXPORT: + if symbol.name==name: + return symbol + if imp:= symbol.valid_import_names.get(name,None): + return imp + + if symbol is not self: return symbol @@ -632,6 +642,11 @@ def _compute_dependencies(self, *args, **kwargs) -> None: # if used_frame.parent_frame: # used_frame.parent_frame.add_usage(self.symbol_name or self.module, SymbolUsageType.IMPORTED_WILDCARD, self, self.ctx) # else: + if isinstance(self, Import) and self.import_type==ImportType.NAMED_EXPORT: + #It could be a wildcard import downstream, hence we have to pop the cache + if file := self.from_file: + file.invalidate() + for used_frame in self.resolved_type_frames: if used_frame.parent_frame: used_frame.parent_frame.add_usage(self._unique_node, UsageKind.IMPORTED, self, self.ctx) diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 12dd3b086..b90e3d314 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -1,18 +1,15 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING, Self, override +from typing import TYPE_CHECKING from codegen.sdk.core.autocommit import reader from codegen.sdk.core.expressions import Name from codegen.sdk.core.import_resolution import ExternalImportResolver, Import, ImportResolution from codegen.sdk.enums import ImportType, NodeType -from codegen.sdk.extensions.resolution import ResolutionStack from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: - from collections.abc import Generator - from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -31,10 +28,6 @@ class PyImport(Import["PyFile"]): """Extends Import for Python codebases.""" - def __init__(self, ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type=ImportType.UNKNOWN): - super().__init__(ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type) - self._requesting_names = set() - @reader def is_module_import(self) -> bool: """Determines if the import is a module-level or wildcard import. @@ -124,14 +117,20 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | filepath = module_source.replace(".", "/") + ".py" filepath = os.path.join(base_path, filepath) if file := self.ctx.get_file(filepath): - symbol = file.get_node_wildcard_resolves_for(symbol_name) - return ImportResolution(from_file=file, symbol=symbol) + symbol = file.get_node_by_name(symbol_name) + if symbol is None: + return ImportResolution(from_file=file, symbol=None,imports_file=True) + else: + return ImportResolution(from_file=file, symbol=symbol) # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.ctx.get_file(filepath): - symbol = from_file.get_node_wildcard_resolves_for(symbol_name) - return ImportResolution(from_file=from_file, symbol=symbol) + symbol = from_file.get_node_by_name(symbol_name) + if symbol is None: + return ImportResolution(from_file=from_file, symbol=None,imports_file=True) + else: + return ImportResolution(from_file=from_file, symbol=symbol) # =====[ Case: Can't resolve the import ]===== if base_path == "": @@ -239,33 +238,6 @@ def from_future_import_statement(cls, import_statement: TSNode, file_node_id: No imports.append(imp) return imports - @reader - @noapidoc - @override - def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: - """Resolve the types used by this import.""" - ix_seen = set() - - aliased = self.is_aliased_import() - if imported := self._imported_symbol(resolve_exports=True): - if isinstance(imported, PyImport) and imported.is_wildcard_import: - imported.set_requesting_names(self) - yield from self.with_resolution_frame(imported, direct=False, aliased=aliased) - else: - yield ResolutionStack(self, aliased=aliased) - - if self.is_wildcard_import(): - for name, wildcard_import in self.names: - if name in self._requesting_names: - yield from [frame.parent_frame for frame in wildcard_import.resolved_type_frames] - - @noapidoc - def set_requesting_names(self, requester: PyImport): - if requester.is_wildcard_import(): - self._requesting_names.update(requester._requesting_names) - else: - self._requesting_names.add(requester.name) - @property @reader def import_specifier(self) -> Editable: diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 350e0f646..c3a8394f9 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -403,12 +403,10 @@ def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None test_used_parent = deepest_layer.get_symbol("test_used_parent") test_const_imp = main.get_import("test_const") - test_const_imp_2 = test_nest.get_import(".nest") - assert len(test_const.usages) == 3 + assert len(test_const.usages) == 2 assert test_const.usages[0].usage_symbol == main_test assert test_const.usages[1].usage_symbol == test_const_imp - assert test_const.usages[2].usage_symbol == test_const_imp_2 assert len(test_not_used.usages) == 0 assert len(test_used_parent.usages) == 1 @@ -567,4 +565,63 @@ def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None: ) as codebase: mainfile: SourceFile = codebase.get_file("file.py") - assert len(mainfile.ctx.edges) == 12 + assert len(mainfile.ctx.edges) == 10 + + +def test_import_resolution_init_wildcard_no_dupe(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly and doesn't + result in duplicate usages + """ + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import TEST_CONST + test3=TEST_CONST""" + content4 = """from testdir import foo + test4=foo""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3, "test4.py": content4}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + test3 = file3.get_symbol("test3") + test3_import = file3.get_import("TEST_CONST") + + assert len(symb.usages) == 3 + assert symb.symbol_usages == [test, test3, test3_import] + + +def test_import_resolution_init_wildcard_chainging_deep(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly and doesn't + result in duplicate usages + """ + # language=python + content1 = """TEST_CONST=2 + """ + content2 = """from .file1 import *""" + content3 = """from .dir import *""" + content4 = """from .dir import TEST_CONST + test1=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={ + "dir/dir/dir/dir/file1.py": content1, + "dir/dir/dir/dir/__init__.py": content2, + "dir/dir/dir/__init__.py": content3, + "dir/dir/__init__.py": content3, + "dir/__init__.py": content3, + "file2.py": content4 + }) as codebase: + file1: SourceFile = codebase.get_file("dir/dir/dir/dir/file1.py") + file2: SourceFile = codebase.get_file("file2.py") + + symb = file1.get_symbol("TEST_CONST") + test1 = file2.get_symbol("test1") + imp = file2.get_import("TEST_CONST") + + assert len(symb.usages) == 2 + assert symb.symbol_usages == [test1, imp] From 6f8fa8974752efca12b41edb104914817acb879c Mon Sep 17 00:00:00 2001 From: tomcodgen <191515280+tomcodgen@users.noreply.github.com> Date: Tue, 25 Feb 2025 17:37:51 +0000 Subject: [PATCH 5/7] Automated pre-commit update --- src/codegen/sdk/core/import_resolution.py | 11 +++++------ src/codegen/sdk/python/import_resolution.py | 4 ++-- .../test_import_resolution.py | 19 +++++++++++-------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/codegen/sdk/core/import_resolution.py b/src/codegen/sdk/core/import_resolution.py index 42f2c5b10..f0afb6e7a 100644 --- a/src/codegen/sdk/core/import_resolution.py +++ b/src/codegen/sdk/core/import_resolution.py @@ -342,15 +342,14 @@ def _imported_symbol(self, resolve_exports: bool = False) -> Symbol | ExternalMo if self.import_type == ImportType.NAMED_EXPORT: if export := symbol.valid_import_names.get(name, None): return export - elif resolve_exports and isinstance(symbol,PyFile): + elif resolve_exports and isinstance(symbol, PyFile): name = self.symbol_name.source if self.symbol_name else "" if self.import_type == ImportType.NAMED_EXPORT: - if symbol.name==name: + if symbol.name == name: return symbol - if imp:= symbol.valid_import_names.get(name,None): + if imp := symbol.valid_import_names.get(name, None): return imp - if symbol is not self: return symbol @@ -642,8 +641,8 @@ def _compute_dependencies(self, *args, **kwargs) -> None: # if used_frame.parent_frame: # used_frame.parent_frame.add_usage(self.symbol_name or self.module, SymbolUsageType.IMPORTED_WILDCARD, self, self.ctx) # else: - if isinstance(self, Import) and self.import_type==ImportType.NAMED_EXPORT: - #It could be a wildcard import downstream, hence we have to pop the cache + if isinstance(self, Import) and self.import_type == ImportType.NAMED_EXPORT: + # It could be a wildcard import downstream, hence we have to pop the cache if file := self.from_file: file.invalidate() diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index b90e3d314..bae0478ee 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -119,7 +119,7 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if file := self.ctx.get_file(filepath): symbol = file.get_node_by_name(symbol_name) if symbol is None: - return ImportResolution(from_file=file, symbol=None,imports_file=True) + return ImportResolution(from_file=file, symbol=None, imports_file=True) else: return ImportResolution(from_file=file, symbol=symbol) @@ -128,7 +128,7 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if from_file := self.ctx.get_file(filepath): symbol = from_file.get_node_by_name(symbol_name) if symbol is None: - return ImportResolution(from_file=from_file, symbol=None,imports_file=True) + return ImportResolution(from_file=from_file, symbol=None, imports_file=True) else: return ImportResolution(from_file=from_file, symbol=symbol) diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index c3a8394f9..63c58b5af 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -608,14 +608,17 @@ def test_import_resolution_init_wildcard_chainging_deep(tmpdir: str) -> None: content3 = """from .dir import *""" content4 = """from .dir import TEST_CONST test1=TEST_CONST""" - with get_codebase_session(tmpdir=tmpdir, files={ - "dir/dir/dir/dir/file1.py": content1, - "dir/dir/dir/dir/__init__.py": content2, - "dir/dir/dir/__init__.py": content3, - "dir/dir/__init__.py": content3, - "dir/__init__.py": content3, - "file2.py": content4 - }) as codebase: + with get_codebase_session( + tmpdir=tmpdir, + files={ + "dir/dir/dir/dir/file1.py": content1, + "dir/dir/dir/dir/__init__.py": content2, + "dir/dir/dir/__init__.py": content3, + "dir/dir/__init__.py": content3, + "dir/__init__.py": content3, + "file2.py": content4, + }, + ) as codebase: file1: SourceFile = codebase.get_file("dir/dir/dir/dir/file1.py") file2: SourceFile = codebase.get_file("file2.py") From a443fdd2608986e85020a0a46bf0fdc601498fd5 Mon Sep 17 00:00:00 2001 From: tkucar Date: Wed, 26 Feb 2025 00:02:48 +0100 Subject: [PATCH 6/7] fix --- src/codegen/sdk/python/import_resolution.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index bae0478ee..360ffa5c9 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -119,7 +119,11 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if file := self.ctx.get_file(filepath): symbol = file.get_node_by_name(symbol_name) if symbol is None: - return ImportResolution(from_file=file, symbol=None, imports_file=True) + if file.get_node_from_wildcard_chain(symbol_name): + return ImportResolution(from_file=file, symbol=None, imports_file=True) + else: + #This is most likely a broken import + return ImportResolution(from_file=file, symbol=None) else: return ImportResolution(from_file=file, symbol=symbol) @@ -128,10 +132,16 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if from_file := self.ctx.get_file(filepath): symbol = from_file.get_node_by_name(symbol_name) if symbol is None: - return ImportResolution(from_file=from_file, symbol=None, imports_file=True) + if from_file.get_node_from_wildcard_chain(symbol_name): + return ImportResolution(from_file=from_file, symbol=None, imports_file=True) + else: + #This is most likely a broken import + return ImportResolution(from_file=from_file, symbol=None) + else: return ImportResolution(from_file=from_file, symbol=symbol) + # =====[ Case: Can't resolve the import ]===== if base_path == "": # Try to resolve with "src" as the base path From 1b690ce843d69ebf716c75a88190762772876858 Mon Sep 17 00:00:00 2001 From: tomcodgen <191515280+tomcodgen@users.noreply.github.com> Date: Tue, 25 Feb 2025 23:04:09 +0000 Subject: [PATCH 7/7] Automated pre-commit update --- src/codegen/sdk/python/import_resolution.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 360ffa5c9..46f9de63c 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -122,7 +122,7 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if file.get_node_from_wildcard_chain(symbol_name): return ImportResolution(from_file=file, symbol=None, imports_file=True) else: - #This is most likely a broken import + # This is most likely a broken import return ImportResolution(from_file=file, symbol=None) else: return ImportResolution(from_file=file, symbol=symbol) @@ -135,13 +135,12 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if from_file.get_node_from_wildcard_chain(symbol_name): return ImportResolution(from_file=from_file, symbol=None, imports_file=True) else: - #This is most likely a broken import + # This is most likely a broken import return ImportResolution(from_file=from_file, symbol=None) else: return ImportResolution(from_file=from_file, symbol=symbol) - # =====[ Case: Can't resolve the import ]===== if base_path == "": # Try to resolve with "src" as the base path