diff --git a/src/codegen/sdk/python/assignment.py b/src/codegen/sdk/python/assignment.py index f295f741c..877498235 100644 --- a/src/codegen/sdk/python/assignment.py +++ b/src/codegen/sdk/python/assignment.py @@ -2,8 +2,11 @@ from typing import TYPE_CHECKING +from codegen.sdk.codebase.transactions import RemoveTransaction, TransactionPriority from codegen.sdk.core.assignment import Assignment +from codegen.sdk.core.autocommit.decorators import remover from codegen.sdk.core.expressions.multi_expression import MultiExpression +from codegen.sdk.core.symbol_groups.collection import Collection from codegen.sdk.extensions.autocommit import reader from codegen.sdk.python.symbol import PySymbol from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup @@ -96,3 +99,70 @@ def inline_comment(self) -> PyCommentGroup | None: """ # HACK: This is a temporary solution until comments are fixed return PyCommentGroup.from_symbol_inline_comments(self, self.ts_node.parent) + + @remover + def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None: + """Deletes this assignment and its related extended nodes (e.g. decorators, comments). + + + Removes the current node and its extended nodes (e.g. decorators, comments) from the codebase. + After removing the node, it handles cleanup of any surrounding formatting based on the context. + + Args: + delete_formatting (bool): Whether to delete surrounding whitespace and formatting. Defaults to True. + priority (int): Priority of the removal transaction. Higher priority transactions are executed first. Defaults to 0. + dedupe (bool): Whether to deduplicate removal transactions at the same location. Defaults to True. + + Returns: + None + """ + if getattr(self.parent, "assignments", None) and len(self.parent.assignments) > 1: + # Unpacking assignments + name = self.get_name() + if isinstance(self.value, Collection): + # Tuples + transaction_count = [ + any( + self.transaction_manager.get_transactions_at_range( + self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=TransactionPriority.Remove + ) + ) + for asgnmt in self.parent.assignments + ].count(True) + # Check for existing transactions + if transaction_count < len(self.parent.assignments) - 1: + idx = self.parent.left.index(name) + value = self.value[idx] + removal_queue_values = getattr(self.parent, "removal_queue", []) + self.parent.removal_queue = removal_queue_values + removal_queue_values.append(str(value)) + if len(self.value) - transaction_count == 2: + remainder = str(next(x for x in self.value if x not in removal_queue_values)) + r_t = RemoveTransaction(self.value.start_byte, self.value.end_byte, self.file, priority=priority) + self.transaction_manager.add_transaction(r_t) + self.value.insert_at(self.value.start_byte, remainder, priority=priority) + else: + value.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) + name.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) + return + else: + transaction_count = [ + any( + self.transaction_manager.get_transactions_at_range( + self.file.path, start_byte=asgnmt.get_name().start_byte, end_byte=asgnmt.get_name().end_byte, transaction_order=TransactionPriority.Edit + ) + ) + for asgnmt in self.parent.assignments + ].count(True) + throwaway = [asgnmt.name == "_" for asgnmt in self.parent.assignments].count(True) + if transaction_count + throwaway < len(self.parent.assignments) - 1: + name.edit("_", priority=priority, dedupe=dedupe) + return + if getattr(self.parent, "removal_queue", None): + for node in self.extended_nodes: + transactions = self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=node.start_byte, end_byte=node.end_byte) + for transaction in transactions: + self.transaction_manager.queued_transactions[self.file.path].remove(transaction) + + for node in self.extended_nodes: + node._remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe) diff --git a/src/codegen/sdk/python/file.py b/src/codegen/sdk/python/file.py index ecaa4da7b..74ec2f10f 100644 --- a/src/codegen/sdk/python/file.py +++ b/src/codegen/sdk/python/file.py @@ -197,3 +197,28 @@ def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[P ret[file.name] = file return ret return super().valid_import_names + + def get_node_from_wildcard_chain(self, symbol_name: str): + node = None + if node := self.get_node_by_name(symbol_name): + return node + + if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: + for wildcard_import in wildcard_imports: + if imp_resolution := wildcard_import.resolve_import(): + node = imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name) + + return node + + def get_node_wildcard_resolves_for(self, symbol_name: str): + node = None + if node := self.get_node_by_name(symbol_name): + return node + + if wildcard_imports := {imp for imp in self.imports if imp.is_wildcard_import()}: + for wildcard_import in wildcard_imports: + if imp_resolution := wildcard_import.resolve_import(): + if imp_resolution.from_file.get_node_from_wildcard_chain(symbol_name=symbol_name): + node = wildcard_import + + return node diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index a80bb2ada..f87bbb5c2 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -1,15 +1,19 @@ from __future__ import annotations import os -from typing import TYPE_CHECKING +from collections.abc import Generator +from typing import TYPE_CHECKING, Self, override from codegen.sdk.core.autocommit import reader from codegen.sdk.core.expressions import Name from codegen.sdk.core.import_resolution import ExternalImportResolver, Import, ImportResolution from codegen.sdk.enums import ImportType, NodeType +from codegen.sdk.extensions.resolution import ResolutionStack from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -28,6 +32,10 @@ class PyImport(Import["PyFile"]): """Extends Import for Python codebases.""" + def __init__(self, ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type=ImportType.UNKNOWN): + super().__init__(ts_node, file_node_id, G, parent, module_node, name_node, alias_node, import_type) + self.requesting_names = set() + @reader def is_module_import(self) -> bool: """Determines if the import is a module-level or wildcard import. @@ -117,13 +125,13 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | filepath = module_source.replace(".", "/") + ".py" filepath = os.path.join(base_path, filepath) if file := self.ctx.get_file(filepath): - symbol = file.get_node_by_name(symbol_name) + symbol = file.get_node_wildcard_resolves_for(symbol_name) return ImportResolution(from_file=file, symbol=symbol) # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.ctx.get_file(filepath): - symbol = from_file.get_node_by_name(symbol_name) + symbol = from_file.get_node_wildcard_resolves_for(symbol_name) return ImportResolution(from_file=from_file, symbol=symbol) # =====[ Case: Can't resolve the import ]===== @@ -133,6 +141,11 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if base_path == "src": # Try "test" next return self.resolve_import(base_path="test", add_module_name=add_module_name) + if base_path == "test" and module_source: + # Try to resolve assuming package nested in repo + possible_package_base_path = module_source.split(".")[0] + if possible_package_base_path not in ("test", "src"): + return self.resolve_import(base_path=possible_package_base_path, add_module_name=add_module_name) # if not G_override: # for resolver in ctx.import_resolvers: @@ -232,6 +245,33 @@ def from_future_import_statement(cls, import_statement: TSNode, file_node_id: No imports.append(imp) return imports + @reader + @noapidoc + @override + def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: + """Resolve the types used by this import.""" + ix_seen = set() + + aliased = self.is_aliased_import() + if imported := self._imported_symbol(resolve_exports=True): + if getattr(imported, "is_wildcard_import", False): + imported.set_requesting_names(self) + yield from self.with_resolution_frame(imported, direct=False, aliased=aliased) + else: + yield ResolutionStack(self, aliased=aliased) + + if self.is_wildcard_import(): + for name, wildcard_import in self.names: + if name in self.requesting_names: + yield from [frame.parent_frame for frame in wildcard_import.resolved_type_frames] + + @noapidoc + def set_requesting_names(self, requester: PyImport): + if requester.is_wildcard_import(): + self.requesting_names.update(requester.requesting_names) + else: + self.requesting_names.add(requester.name) + @property @reader def import_specifier(self) -> Editable: diff --git a/tests/unit/codegen/sdk/python/expressions/test_unpacking.py b/tests/unit/codegen/sdk/python/expressions/test_unpacking.py new file mode 100644 index 000000000..098ef3bdf --- /dev/null +++ b/tests/unit/codegen/sdk/python/expressions/test_unpacking.py @@ -0,0 +1,116 @@ +from codegen.sdk.codebase.factory.get_session import get_codebase_session + + +def test_remove_unpacking_assignment(tmpdir) -> None: + # language=python + content = """foo,bar,buzz = (a, b, c)""" + + with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase: + file1 = codebase.get_file("test1.py") + file2 = codebase.get_file("test2.py") + file3 = codebase.get_file("test3.py") + + foo = file1.get_symbol("foo") + foo.remove() + codebase.commit() + + assert len(file1.symbols) == 2 + statement = file1.symbols[0].parent + assert len(statement.assignments) == 2 + assert len(statement.value) == 2 + assert file1.source == """bar,buzz = (b, c)""" + bar = file2.get_symbol("bar") + bar.remove() + codebase.commit() + assert len(file2.symbols) == 2 + statement = file2.symbols[0].parent + assert len(statement.assignments) == 2 + assert len(statement.value) == 2 + assert file2.source == """foo,buzz = (a, c)""" + + buzz = file3.get_symbol("buzz") + buzz.remove() + codebase.commit() + + assert len(file3.symbols) == 2 + statement = file3.symbols[0].parent + assert len(statement.assignments) == 2 + assert len(statement.value) == 2 + assert file3.source == """foo,bar = (a, b)""" + + file1_bar = file1.get_symbol("bar") + + file1_bar.remove() + codebase.commit() + assert file1.source == """buzz = c""" + + file1_buzz = file1.get_symbol("buzz") + file1_buzz.remove() + + codebase.commit() + assert len(file1.symbols) == 0 + assert file1.source == """""" + + +def test_remove_unpacking_assignment_funct(tmpdir) -> None: + # language=python + content = """foo,bar,buzz = f()""" + + with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase: + file1 = codebase.get_file("test1.py") + file2 = codebase.get_file("test2.py") + file3 = codebase.get_file("test3.py") + + foo = file1.get_symbol("foo") + foo.remove() + codebase.commit() + + assert len(file1.symbols) == 3 + statement = file1.symbols[0].parent + assert len(statement.assignments) == 3 + assert file1.source == """_,bar,buzz = f()""" + bar = file2.get_symbol("bar") + bar.remove() + codebase.commit() + assert len(file2.symbols) == 3 + statement = file2.symbols[0].parent + assert len(statement.assignments) == 3 + assert file2.source == """foo,_,buzz = f()""" + + buzz = file3.get_symbol("buzz") + buzz.remove() + codebase.commit() + + assert len(file3.symbols) == 3 + statement = file3.symbols[0].parent + assert len(statement.assignments) == 3 + assert file3.source == """foo,bar,_ = f()""" + + file1_bar = file1.get_symbol("bar") + file1_buzz = file1.get_symbol("buzz") + + file1_bar.remove() + file1_buzz.remove() + codebase.commit() + assert len(file1.symbols) == 0 + assert file1.source == """""" + + +def test_remove_unpacking_assignment_num(tmpdir) -> None: + # language=python + content = """foo,bar,buzz = (1, 2, 3)""" + + with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content}) as codebase: + file1 = codebase.get_file("test1.py") + + foo = file1.get_symbol("foo") + buzz = file1.get_symbol("buzz") + + foo.remove() + buzz.remove() + codebase.commit() + + assert len(file1.symbols) == 1 + statement = file1.symbols[0].parent + assert len(statement.assignments) == 1 + assert file1.source == """bar = 2""" diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 6fd9cbe7b..b9725a09e 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -215,6 +215,206 @@ def func_1(): assert call_site.file == consumer_file +def test_import_resolution_init_wildcard(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly""" + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import TEST_CONST + test3=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + test3 = file3.get_symbol("test3") + test3_import = file3.get_import("TEST_CONST") + + assert len(symb.usages) == 3 + assert symb.symbol_usages == [test, test3, test3_import] + + +def test_import_resolution_wildcard_func(tmpdir: str) -> None: + """Tests that named import from a file with wildcard resolves properly""" + # language=python + content1 = """ + def foo(): + pass + def bar(): + pass + """ + content2 = """ + from testa import * + + foo() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"testa.py": content1, "testb.py": content2}) as codebase: + testa: SourceFile = codebase.get_file("testa.py") + testb: SourceFile = codebase.get_file("testb.py") + + foo = testa.get_symbol("foo") + bar = testa.get_symbol("bar") + assert len(foo.usages) == 1 + assert len(foo.call_sites) == 1 + + assert len(bar.usages) == 0 + assert len(bar.call_sites) == 0 + assert len(testb.function_calls) == 1 + + +def test_import_resolution_chaining_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + content1 = """TEST_CONST=2 + foo=9 + """ + content2 = """from testdir.test1 import * + bar=foo + test=TEST_CONST""" + content3 = """from testdir import * + test3=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={"testdir/test1.py": content1, "testdir/__init__.py": content2, "test3.py": content3}) as codebase: + file1: SourceFile = codebase.get_file("testdir/test1.py") + file2: SourceFile = codebase.get_file("testdir/__init__.py") + file3: SourceFile = codebase.get_file("test3.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + bar = file2.get_symbol("bar") + mid_import = file2.get_import("testdir.test1") + test3 = file3.get_symbol("test3") + + assert len(symb.usages) == 2 + assert symb.symbol_usages == [test, test3] + assert mid_import.symbol_usages == [test, bar, test3] + + +def test_import_resolution_init_deep_nested_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files = { + "test/nest/nest2/test1.py": """test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test/nest/nest2/__init__.py": """from .test1 import * + t1=test_used_parent + """, + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ + from test import * + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") + + main_test = main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const = deepest_layer.get_symbol("test_const") + test_not_used = deepest_layer.get_symbol("test_not_used") + test_used_parent = deepest_layer.get_symbol("test_used_parent") + + assert len(test_const.usages) == 1 + assert test_const.usages[0].usage_symbol == main_test + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 + + +def test_import_resolution_chaining_many_wildcards(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files = { + "test1.py": """ + test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test2.py": """from test1 import * + t1=test_used_parent + """, + "test3.py": """from test2 import *""", + "test4.py": """from test3 import *""", + "main.py": """ + from test4 import * + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + furthest_layer: SourceFile = codebase.get_file("test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test2.py") + + main_test = main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const = furthest_layer.get_symbol("test_const") + test_not_used = furthest_layer.get_symbol("test_not_used") + test_used_parent = furthest_layer.get_symbol("test_used_parent") + + assert len(test_const.usages) == 1 + assert test_const.usages[0].usage_symbol == main_test + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 + + +def test_import_resolution_init_deep_nested_wildcards_named(tmpdir: str) -> None: + """Tests that chaining wildcard imports resolves properly""" + # language=python + + files = { + "test/nest/nest2/test1.py": """test_const=5 + test_not_used=2 + test_used_parent=5 + """, + "test/nest/nest2/__init__.py": """from .test1 import * + t1=test_used_parent + """, + "test/nest/__init__.py": """from .nest2 import *""", + "test/__init__.py": """from .nest import *""", + "main.py": """ + from test import test_const + main_test=test_const + """, + } + with get_codebase_session(tmpdir=tmpdir, files=files) as codebase: + deepest_layer: SourceFile = codebase.get_file("test/nest/nest2/test1.py") + main: SourceFile = codebase.get_file("main.py") + parent_file: SourceFile = codebase.get_file("test/nest/nest2/__init__.py") + test_nest: SourceFile = codebase.get_file("test/__init__.py") + + main_test = main.get_symbol("main_test") + t1 = parent_file.get_symbol("t1") + test_const = deepest_layer.get_symbol("test_const") + test_not_used = deepest_layer.get_symbol("test_not_used") + test_used_parent = deepest_layer.get_symbol("test_used_parent") + + test_const_imp = main.get_import("test_const") + test_const_imp_2 = test_nest.get_import(".nest") + + assert len(test_const.usages) == 3 + assert test_const.usages[0].usage_symbol == main_test + assert test_const.usages[1].usage_symbol == test_const_imp + assert test_const.usages[2].usage_symbol == test_const_imp_2 + + assert len(test_not_used.usages) == 0 + assert len(test_used_parent.usages) == 1 + assert test_used_parent.usages[0].usage_symbol == t1 + + def test_import_resolution_circular(tmpdir: str) -> None: """Tests function.usages returns usages from file imports""" # language=python @@ -343,28 +543,21 @@ def some_func(): assert len(some_func.symbol_usages) > 0 -def test_import_wildcard_preserves_import_resolution(tmpdir: str) -> None: - """Tests importing from a file that contains a wildcard import doesn't break further resolution. - This could occur depending on to_resolve ordering, if the outer file is processed first _wildcards will not be filled in time. - """ +def test_import_nested_installable_resolution(tmpdir: str) -> None: + """Tests that a nested installable resolves internally instead of as external""" # language=python - with get_codebase_session( - tmpdir, - files={ - "testdir/sub/file.py": """ - test_const=5 - b=2 - """, - "testdir/file.py": """ - from testdir.sub.file import * - c=b - """, - "file.py": """ - from testdir.file import test_const - test = test_const - """, - }, - ) as codebase: - mainfile: SourceFile = codebase.get_file("file.py") - - assert len(mainfile.ctx.edges) == 5 + content1 = """ + TEST_CONST=5 + """ + content2 = """from test_pack.test import TEST_CONST + test=TEST_CONST""" + with get_codebase_session(tmpdir=tmpdir, files={"test_pack/test_pack/test.py": content1, "test1.py": content2}) as codebase: + file1: SourceFile = codebase.get_file("test_pack/test_pack/test.py") + file2: SourceFile = codebase.get_file("test1.py") + + symb = file1.get_symbol("TEST_CONST") + test = file2.get_symbol("test") + test_import = file2.get_import("TEST_CONST") + + assert len(symb.usages) == 2 + assert symb.symbol_usages == [test, test_import]