diff --git a/src/codegen/sdk/python/assignment.py b/src/codegen/sdk/python/assignment.py
index f295f741c..7a521e67b 100644
--- a/src/codegen/sdk/python/assignment.py
+++ b/src/codegen/sdk/python/assignment.py
@@ -1,9 +1,13 @@
 from __future__ import annotations
 
+from collections.abc import Collection
 from typing import TYPE_CHECKING
 
+from codegen.sdk.codebase.transactions import RemoveTransaction, TransactionPriority
 from codegen.sdk.core.assignment import Assignment
+from codegen.sdk.core.autocommit.decorators import remover
 from codegen.sdk.core.expressions.multi_expression import MultiExpression
+from codegen.sdk.core.statements.assignment_statement import AssignmentStatement
 from codegen.sdk.extensions.autocommit import reader
 from codegen.sdk.python.symbol import PySymbol
 from codegen.sdk.python.symbol_groups.comment_group import PyCommentGroup
@@ -96,3 +100,97 @@ def inline_comment(self) -> PyCommentGroup | None:
         """
         # HACK: This is a temporary solution until comments are fixed
         return PyCommentGroup.from_symbol_inline_comments(self, self.ts_node.parent)
+
+    @noapidoc
+    def _partial_remove_when_tuple(self, name, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None:
+        """Removes one target and its matching value from an unpacking assignment.
+
+        Used when the right-hand side is itself a collection of values, e.g.
+        ``a, b, c = (1, 2, 3)``. The value at the same position as ``name`` is removed
+        together with ``name``. When exactly one value would remain, the whole
+        right-hand side is replaced by that value, so ``b = 2`` is produced rather
+        than ``b = (2)``.
+
+        Args:
+            name: Name node of this assignment on the left-hand side of the parent statement.
+            delete_formatting (bool): Forwarded to the underlying ``remove`` calls.
+            priority (int): Priority used for every transaction queued here.
+            dedupe (bool): Forwarded to the underlying ``remove`` calls.
+        """
+        scheduled = self.parent._values_scheduled_for_removal
+        value = self.value[self.parent.left.index(name)]
+        if value in scheduled:
+            # Already queued by an earlier remove() of this assignment. Recording it twice would
+            # inflate the count below and collapse the right-hand side too early.
+            return
+        scheduled.append(value)
+        if len(self.value) - len(scheduled) == 1:
+            # Special case: one value survives, so replace the whole collection (brackets included) with it
+            remainder = str(next(x for x in self.value if x not in scheduled))
+            r_t = RemoveTransaction(self.value.start_byte, self.value.end_byte, self.file, priority=priority)
+            self.transaction_manager.add_transaction(r_t)
+            self.value.insert_at(self.value.start_byte, remainder, priority=priority)
+        else:
+            # Normal case: remove just this one value
+            value.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
+        # Remove assignment name
+        name.remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
+
+    @noapidoc
+    def _active_transactions_on_assignment_names(self, transaction_order: TransactionPriority) -> int:
+        """Counts the assignments of the parent statement whose name has a pending transaction.
+
+        Args:
+            transaction_order (TransactionPriority): Transaction order passed to the transaction manager lookup.
+
+        Returns:
+            int: Number of assignments with at least one truthy transaction on their name's byte range.
+        """
+        names = (asgnmt.get_name() for asgnmt in self.parent.assignments)
+        return sum(
+            1
+            for name in names
+            if any(self.transaction_manager.get_transactions_at_range(self.file.path, start_byte=name.start_byte, end_byte=name.end_byte, transaction_order=transaction_order))
+        )
+
+    @remover
+    def remove(self, delete_formatting: bool = True, priority: int = 0, dedupe: bool = True) -> None:
+        """Deletes this assignment, keeping the rest of an unpacking statement where possible.
+
+        With the ``unpacking_assignment_partial_removal`` feature flag on and a parent statement that
+        holds more than one assignment, only this assignment is dropped:
+
+        - if the value is a collection, the name and its matching value are removed;
+        - otherwise the name is replaced by ``_``.
+
+        Once every other assignment is already scheduled for removal, has a pending edit on its name or
+        is named ``_``, and in every case not covered above, the parent class implementation is used.
+
+        Args:
+            delete_formatting (bool): Whether to delete surrounding whitespace and formatting. Defaults to True.
+            priority (int): Priority of the removal transaction. Higher priority transactions are executed first. Defaults to 0.
+            dedupe (bool): Whether to deduplicate removal transactions at the same location. Defaults to True.
+
+        Returns:
+            None
+        """
+        partial_removal = self.ctx.config.feature_flags.unpacking_assignment_partial_removal
+        if partial_removal and isinstance(self.parent, AssignmentStatement) and len(self.parent.assignments) > 1:
+            # Unpacking assignments
+            name = self.get_name()
+            others = len(self.parent.assignments) - 1
+            if isinstance(self.value, Collection):
+                if len(self.parent._values_scheduled_for_removal) < others:
+                    self._partial_remove_when_tuple(name, delete_formatting, priority, dedupe)
+                    return
+                # Every other value is already scheduled: reset the bookkeeping and fall through to the regular removal
+                self.parent._values_scheduled_for_removal = []
+            else:
+                transaction_count = self._active_transactions_on_assignment_names(TransactionPriority.Edit)
+                throwaway = sum(asgnmt.name == "_" for asgnmt in self.parent.assignments)
+                # Only edit if we didn't already omit all the other assignments, otherwise just remove the whole thing
+                if transaction_count + throwaway < others:
+                    name.edit("_", priority=priority, dedupe=dedupe)
+                    return
+
+        super().remove(delete_formatting=delete_formatting, priority=priority, dedupe=dedupe)
diff --git a/src/codegen/sdk/python/statements/assignment_statement.py b/src/codegen/sdk/python/statements/assignment_statement.py
index a53690fcd..ad433ab3c 100644
--- a/src/codegen/sdk/python/statements/assignment_statement.py
+++ b/src/codegen/sdk/python/statements/assignment_statement.py
@@ -30,6 +30,12 @@ class PyAssignmentStatement(AssignmentStatement["PyCodeBlock", PyAssignment]):
 
     assignment_types = {"assignment", "augmented_assignment", "named_expression"}
 
+    def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int, assignment_node: TSNode) -> None:
+        super().__init__(ts_node, file_node_id, ctx, parent, pos, assignment_node)
+        # Right-hand-side values of an unpacking assignment whose removal is already queued;
+        # PyAssignment appends to and resets this list while individual targets are removed.
+        self._values_scheduled_for_removal: list = []
+
     @classmethod
     def from_assignment(cls, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int, assignment_node: TSNode) -> PyAssignmentStatement:
         """Creates a PyAssignmentStatement instance from a TreeSitter assignment node.
diff --git a/src/codegen/shared/configs/models/feature_flags.py b/src/codegen/shared/configs/models/feature_flags.py
index 8cecd1dca..f1030f9fb 100644
--- a/src/codegen/shared/configs/models/feature_flags.py
+++ b/src/codegen/shared/configs/models/feature_flags.py
@@ -28,6 +28,7 @@ class CodebaseFeatureFlags(BaseSettings):
     generics: bool = True
     import_resolution_overrides: dict[str, str] = Field(default_factory=lambda: {})
     typescript: TypescriptConfig = Field(default_factory=TypescriptConfig)
+    unpacking_assignment_partial_removal: bool = True  # gates the unpacking-aware branch of PyAssignment.remove()
 
 
 class FeatureFlagsConfig(BaseModel):
diff --git a/tests/unit/codegen/sdk/python/expressions/test_unpacking.py b/tests/unit/codegen/sdk/python/expressions/test_unpacking.py
new file mode 100644
index 000000000..cdf853e37
--- /dev/null
+++ b/tests/unit/codegen/sdk/python/expressions/test_unpacking.py
@@ -0,0 +1,157 @@
+from codegen.sdk.codebase.factory.get_session import get_codebase_session
+
+
+def test_remove_unpacking_assignment(tmpdir) -> None:  # tuple on the right-hand side: names and values go in pairs
+    # language=python
+    content = """foo,bar,buzz = (a, b, c)"""
+
+    with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase:  # one identical file per removed target
+        file1 = codebase.get_file("test1.py")
+        file2 = codebase.get_file("test2.py")
+        file3 = codebase.get_file("test3.py")
+
+        foo = file1.get_symbol("foo")
+        foo.remove()  # first target
+        codebase.commit()
+
+        assert len(file1.symbols) == 2
+        statement = file1.symbols[0].parent
+        assert len(statement.assignments) == 2
+        assert len(statement.value) == 2
+        assert file1.source == """bar,buzz = (b, c)"""
+        bar = file2.get_symbol("bar")
+        bar.remove()  # middle target
+        codebase.commit()
+        assert len(file2.symbols) == 2
+        statement = file2.symbols[0].parent
+        assert len(statement.assignments) == 2
+        assert len(statement.value) == 2
+        assert file2.source == """foo,buzz = (a, c)"""
+
+        buzz = file3.get_symbol("buzz")
+        buzz.remove()  # last target
+        codebase.commit()
+
+        assert len(file3.symbols) == 2
+        statement = file3.symbols[0].parent
+        assert len(statement.assignments) == 2
+        assert len(statement.value) == 2
+        assert file3.source == """foo,bar = (a, b)"""
+
+        file1_bar = file1.get_symbol("bar")  # file1 already lost "foo" above
+
+        file1_bar.remove()
+        codebase.commit()
+        assert file1.source == """buzz = c"""  # a single value is left, so the brackets are gone too
+
+        file1_buzz = file1.get_symbol("buzz")
+        file1_buzz.remove()  # last remaining assignment of the statement
+
+        codebase.commit()
+        assert len(file1.symbols) == 0
+        assert file1.source == """"""
+
+
+def test_remove_unpacking_assignment_funct(tmpdir) -> None:  # call on the right-hand side: removed targets become "_"
+    # language=python
+    content = """foo,bar,buzz = f()"""
+
+    with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content, "test3.py": content}) as codebase:  # one identical file per removed target
+        file1 = codebase.get_file("test1.py")
+        file2 = codebase.get_file("test2.py")
+        file3 = codebase.get_file("test3.py")
+
+        foo = file1.get_symbol("foo")
+        foo.remove()
+        codebase.commit()
+
+        assert len(file1.symbols) == 3  # the "_" placeholder still counts as a symbol
+        statement = file1.symbols[0].parent
+        assert len(statement.assignments) == 3
+        assert file1.source == """_,bar,buzz = f()"""
+        bar = file2.get_symbol("bar")
+        bar.remove()
+        codebase.commit()
+        assert len(file2.symbols) == 3
+        statement = file2.symbols[0].parent
+        assert len(statement.assignments) == 3
+        assert file2.source == """foo,_,buzz = f()"""
+
+        buzz = file3.get_symbol("buzz")
+        buzz.remove()
+        codebase.commit()
+
+        assert len(file3.symbols) == 3
+        statement = file3.symbols[0].parent
+        assert len(statement.assignments) == 3
+        assert file3.source == """foo,bar,_ = f()"""
+
+        file1_bar = file1.get_symbol("bar")  # file1 currently reads "_,bar,buzz = f()"
+        file1_buzz = file1.get_symbol("buzz")
+
+        file1_bar.remove()
+        file1_buzz.remove()  # no named target is left after this one
+        codebase.commit()
+        assert len(file1.symbols) == 0
+        assert file1.source == """"""
+
+
+def test_remove_unpacking_assignment_num(tmpdir) -> None:  # several removals per commit; the value 2 appears twice
+    # language=python
+    content = """a,b,c,d,e,f = (1, 2, 2, 4, 5, 3)"""
+
+    with get_codebase_session(tmpdir=tmpdir, files={"test1.py": content, "test2.py": content}) as codebase:
+        file1 = codebase.get_file("test1.py")
+
+        a = file1.get_symbol("a")
+        d = file1.get_symbol("d")
+
+        a.remove()  # two non-adjacent targets removed in the same commit
+        d.remove()
+        codebase.commit()
+
+        assert len(file1.symbols) == 4
+        statement = file1.symbols[0].parent
+        assert len(statement.assignments) == 4
+        assert file1.source == """b,c,e,f = (2, 2, 5, 3)"""
+
+        e = file1.get_symbol("e")
+        c = file1.get_symbol("c")
+
+        e.remove()  # removal order differs from source order
+        c.remove()
+        codebase.commit()
+
+        assert len(file1.symbols) == 2
+        statement = file1.symbols[0].parent
+        assert len(statement.assignments) == 2
+        assert file1.source == """b,f = (2, 3)"""
+
+        f = file1.get_symbol("f")
+
+        f.remove()
+        codebase.commit()
+
+        assert len(file1.symbols) == 1
+        statement = file1.symbols[0].parent
+        assert len(statement.assignments) == 1
+        assert file1.source == """b = 2"""  # a single value is left, so the brackets are gone too
+        file2 = codebase.get_file("test2.py")  # untouched copy: every target is removed in one commit
+        a = file2.get_symbol("a")
+        d = file2.get_symbol("d")
+        e = file2.get_symbol("e")
+        c = file2.get_symbol("c")
+        f = file2.get_symbol("f")
+        b = file2.get_symbol("b")
+
+        a.remove()
+        b.remove()
+        c.remove()
+        d.remove()
+        e.remove()
+        f.remove()
+
+        codebase.commit()
+
+        assert len(file2.symbols) == 0
+        assert file2.source == """"""