From a1357406b8976a6c1b866c81262fc09c0ffdf0ef Mon Sep 17 00:00:00 2001 From: tkucar Date: Tue, 4 Mar 2025 21:12:55 +0100 Subject: [PATCH 1/8] conditional assigment fix --- src/codegen/sdk/core/assignment.py | 19 ++++++++++++++ src/codegen/sdk/core/expressions/name.py | 10 ++++++++ .../sdk/python/statements/match_case.py | 3 ++- .../sdk/python/statements/match_statement.py | 2 +- .../sdk/typescript/statements/switch_case.py | 3 ++- .../typescript/statements/switch_statement.py | 2 +- .../test_if_block_statement_properties.py | 21 ++++++++++++++++ .../test_try_catch_statement.py | 19 ++++++++++++++ .../test_match_statement.py | 25 +++++++++++++++++++ 9 files changed, 100 insertions(+), 4 deletions(-) diff --git a/src/codegen/sdk/core/assignment.py b/src/codegen/sdk/core/assignment.py index 116fca79d..9d2b409c0 100644 --- a/src/codegen/sdk/core/assignment.py +++ b/src/codegen/sdk/core/assignment.py @@ -285,3 +285,22 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - for usage in self.usages: if usage.match == self.name: usage.match.reduce_condition(bool_condition) + + @noapidoc + def nested_blocks_for_conditional_parent(self): + from codegen.sdk.core.statements.if_block_statement import IfBlockStatement + from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement + from codegen.sdk.python.statements.match_statement import PyMatchCase + from codegen.sdk.typescript.statements.switch_statement import TSSwitchCase + + conditionals = {TryCatchStatement,IfBlockStatement,PyMatchCase,TSSwitchCase} + if parent:= self.parent_of_types(conditionals): + match parent: + case IfBlockStatement(): + if parent._main_if_block: + return parent._main_if_block.nested_code_blocks + case PyMatchCase() | TSSwitchCase(): + return parent.match_statement.nested_code_blocks + case _: + return parent.nested_code_blocks + diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index 3ee1b6411..16206aef8 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -24,6 +24,7 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]): composed of a name. """ + @reader @noapidoc @override @@ -31,6 +32,15 @@ def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: """Resolve the types used by this symbol.""" if used := self.resolve_name(self.source, self.start_byte): yield from self.with_resolution_frame(used) + from codegen.sdk.core.assignment import Assignment + if isinstance(used,Assignment): + if nested_blocks := used.nested_blocks_for_conditional_parent(): + blocks = nested_blocks[:-1] + for block in blocks: + for assignment in block.local_var_assignments: + if assignment.name==self.source: + yield from self.with_resolution_frame(assignment) + @noapidoc @commiter diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py index 69528fbba..a6a4c71ce 100644 --- a/src/codegen/sdk/python/statements/match_case.py +++ b/src/codegen/sdk/python/statements/match_case.py @@ -17,6 +17,7 @@ class PyMatchCase(SwitchCase[PyCodeBlock["PyMatchStatement"]], PyBlockStatement): """Python match case.""" - def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, pos: int | None = None) -> None: + def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, match_statement:"PyMatchStatement", pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("alternative") + self.match_statement=match_statement diff --git a/src/codegen/sdk/python/statements/match_statement.py b/src/codegen/sdk/python/statements/match_statement.py index 804ff2029..ac52a652e 100644 --- a/src/codegen/sdk/python/statements/match_statement.py +++ b/src/codegen/sdk/python/statements/match_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.children_by_field_name("alternative"): - self.cases.append(PyMatchCase(node, file_node_id, ctx, self.parent, self.index)) + self.cases.append(PyMatchCase(node, file_node_id, ctx, self.parent, self, self.index)) diff --git a/src/codegen/sdk/typescript/statements/switch_case.py b/src/codegen/sdk/typescript/statements/switch_case.py index cdd43e1dd..031f0137b 100644 --- a/src/codegen/sdk/typescript/statements/switch_case.py +++ b/src/codegen/sdk/typescript/statements/switch_case.py @@ -23,7 +23,8 @@ class TSSwitchCase(SwitchCase[TSCodeBlock["TSSwitchStatement"]], TSBlockStatemen default: bool - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock, pos: int | None = None) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock,match_statement, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("value") self.default = self.ts_node.type == "switch_default" + self.match_statement=match_statement diff --git a/src/codegen/sdk/typescript/statements/switch_statement.py b/src/codegen/sdk/typescript/statements/switch_statement.py index 914bde227..ac63a2918 100644 --- a/src/codegen/sdk/typescript/statements/switch_statement.py +++ b/src/codegen/sdk/typescript/statements/switch_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.named_children: - self.cases.append(TSSwitchCase(node, file_node_id, ctx, self.parent)) + self.cases.append(TSSwitchCase(node, file_node_id, ctx, self.parent,self)) diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py index 23a3a7e6a..59f52e85c 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py @@ -126,3 +126,24 @@ def foo(): assert len(alt_blocks[2].alternative_blocks) == 0 assert len(alt_blocks[2].elif_statements) == 0 assert alt_blocks[2].else_statement is None + +def test_if_else_reassigment_handling(tmpdir) -> None: + content=""" + if True: + PYSPARK = True + elif False: + PYSPARK = False + else: + PYSPARK = None + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call=file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match==pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py index 0a1928551..e55f57f33 100644 --- a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py +++ b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py @@ -75,3 +75,22 @@ def risky(): assert not file.function_calls[0].is_wrapped_in(TryCatchStatement) assert file.function_calls[1].is_wrapped_in(TryCatchStatement) assert file.function_calls[2].is_wrapped_in(TryCatchStatement) + +def test_try_except_reassigment_handling(tmpdir) -> None: + content=""" + try: + PYSPARK = True # This gets removed even though there is a later use + except ImportError: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call=file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match==pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py index 3a07f1d81..ad4a7df00 100644 --- a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py +++ b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py @@ -53,3 +53,28 @@ def risky(): assert len(dependencies) == 1 global_var = file.get_global_var("risky_var") assert dependencies[0] == global_var + + + +def test_match_reassigment_handling(tmpdir) -> None: + content=""" +filter = 1 +match filter: + case 1: + PYSPARK=True + case 2: + PYSPARK=False + case _: + PYSPARK=None + +print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call=file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols[1:]: + usage = symb.usages[0] + assert usage.match==pyspark_arg From 3a5dd3aaf1a08050a6e16c535e45974b27f72d93 Mon Sep 17 00:00:00 2001 From: tomcodgen <191515280+tomcodgen@users.noreply.github.com> Date: Tue, 4 Mar 2025 20:14:44 +0000 Subject: [PATCH 2/8] Automated pre-commit update --- src/codegen/sdk/core/assignment.py | 5 ++--- src/codegen/sdk/core/expressions/name.py | 7 +++---- src/codegen/sdk/python/statements/match_case.py | 4 ++-- src/codegen/sdk/typescript/statements/switch_case.py | 4 ++-- src/codegen/sdk/typescript/statements/switch_statement.py | 2 +- .../test_if_block_statement_properties.py | 7 ++++--- .../statements/match_statement/test_try_catch_statement.py | 7 ++++--- .../statements/try_catch_statement/test_match_statement.py | 7 +++---- 8 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/codegen/sdk/core/assignment.py b/src/codegen/sdk/core/assignment.py index 9d2b409c0..87773e404 100644 --- a/src/codegen/sdk/core/assignment.py +++ b/src/codegen/sdk/core/assignment.py @@ -293,8 +293,8 @@ def nested_blocks_for_conditional_parent(self): from codegen.sdk.python.statements.match_statement import PyMatchCase from codegen.sdk.typescript.statements.switch_statement import TSSwitchCase - conditionals = {TryCatchStatement,IfBlockStatement,PyMatchCase,TSSwitchCase} - if parent:= self.parent_of_types(conditionals): + conditionals = {TryCatchStatement, IfBlockStatement, PyMatchCase, TSSwitchCase} + if parent := self.parent_of_types(conditionals): match parent: case IfBlockStatement(): if parent._main_if_block: @@ -303,4 +303,3 @@ def nested_blocks_for_conditional_parent(self): return parent.match_statement.nested_code_blocks case _: return parent.nested_code_blocks - diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index 16206aef8..196473675 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -24,7 +24,6 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]): composed of a name. """ - @reader @noapidoc @override @@ -33,15 +32,15 @@ def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: if used := self.resolve_name(self.source, self.start_byte): yield from self.with_resolution_frame(used) from codegen.sdk.core.assignment import Assignment - if isinstance(used,Assignment): + + if isinstance(used, Assignment): if nested_blocks := used.nested_blocks_for_conditional_parent(): blocks = nested_blocks[:-1] for block in blocks: for assignment in block.local_var_assignments: - if assignment.name==self.source: + if assignment.name == self.source: yield from self.with_resolution_frame(assignment) - @noapidoc @commiter def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None: diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py index a6a4c71ce..9f3638b7c 100644 --- a/src/codegen/sdk/python/statements/match_case.py +++ b/src/codegen/sdk/python/statements/match_case.py @@ -17,7 +17,7 @@ class PyMatchCase(SwitchCase[PyCodeBlock["PyMatchStatement"]], PyBlockStatement): """Python match case.""" - def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, match_statement:"PyMatchStatement", pos: int | None = None) -> None: + def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, match_statement: "PyMatchStatement", pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("alternative") - self.match_statement=match_statement + self.match_statement = match_statement diff --git a/src/codegen/sdk/typescript/statements/switch_case.py b/src/codegen/sdk/typescript/statements/switch_case.py index 031f0137b..fd796a213 100644 --- a/src/codegen/sdk/typescript/statements/switch_case.py +++ b/src/codegen/sdk/typescript/statements/switch_case.py @@ -23,8 +23,8 @@ class TSSwitchCase(SwitchCase[TSCodeBlock["TSSwitchStatement"]], TSBlockStatemen default: bool - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock,match_statement, pos: int | None = None) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock, match_statement, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("value") self.default = self.ts_node.type == "switch_default" - self.match_statement=match_statement + self.match_statement = match_statement diff --git a/src/codegen/sdk/typescript/statements/switch_statement.py b/src/codegen/sdk/typescript/statements/switch_statement.py index ac63a2918..f880a113a 100644 --- a/src/codegen/sdk/typescript/statements/switch_statement.py +++ b/src/codegen/sdk/typescript/statements/switch_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.named_children: - self.cases.append(TSSwitchCase(node, file_node_id, ctx, self.parent,self)) + self.cases.append(TSSwitchCase(node, file_node_id, ctx, self.parent, self)) diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py index 59f52e85c..6f8392304 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py @@ -127,8 +127,9 @@ def foo(): assert len(alt_blocks[2].elif_statements) == 0 assert alt_blocks[2].else_statement is None + def test_if_else_reassigment_handling(tmpdir) -> None: - content=""" + content = """ if True: PYSPARK = True elif False: @@ -142,8 +143,8 @@ def test_if_else_reassigment_handling(tmpdir) -> None: with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") symbo = file.get_symbol("PYSPARK") - funct_call=file.function_calls[0] + funct_call = file.function_calls[0] pyspark_arg = funct_call.args.children[0] for symb in file.symbols: usage = symb.usages[0] - assert usage.match==pyspark_arg + assert usage.match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py index e55f57f33..76bb5d0f4 100644 --- a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py +++ b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py @@ -76,8 +76,9 @@ def risky(): assert file.function_calls[1].is_wrapped_in(TryCatchStatement) assert file.function_calls[2].is_wrapped_in(TryCatchStatement) + def test_try_except_reassigment_handling(tmpdir) -> None: - content=""" + content = """ try: PYSPARK = True # This gets removed even though there is a later use except ImportError: @@ -89,8 +90,8 @@ def test_try_except_reassigment_handling(tmpdir) -> None: with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") symbo = file.get_symbol("PYSPARK") - funct_call=file.function_calls[0] + funct_call = file.function_calls[0] pyspark_arg = funct_call.args.children[0] for symb in file.symbols: usage = symb.usages[0] - assert usage.match==pyspark_arg + assert usage.match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py index ad4a7df00..be972cffd 100644 --- a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py +++ b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py @@ -55,9 +55,8 @@ def risky(): assert dependencies[0] == global_var - def test_match_reassigment_handling(tmpdir) -> None: - content=""" + content = """ filter = 1 match filter: case 1: @@ -73,8 +72,8 @@ def test_match_reassigment_handling(tmpdir) -> None: with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: file = codebase.get_file("test.py") symbo = file.get_symbol("PYSPARK") - funct_call=file.function_calls[0] + funct_call = file.function_calls[0] pyspark_arg = funct_call.args.children[0] for symb in file.symbols[1:]: usage = symb.usages[0] - assert usage.match==pyspark_arg + assert usage.match == pyspark_arg From 67dbce2db05e661e33d6ae3fa873fbfadd1691c2 Mon Sep 17 00:00:00 2001 From: tomcodegen Date: Fri, 7 Mar 2025 16:13:24 -0800 Subject: [PATCH 3/8] changes --- src/codegen/sdk/core/assignment.py | 17 --- src/codegen/sdk/core/expressions/name.py | 44 ++++--- src/codegen/sdk/core/file.py | 19 +-- src/codegen/sdk/core/function.py | 7 +- .../sdk/core/interfaces/conditional_block.py | 17 +++ src/codegen/sdk/core/interfaces/editable.py | 7 +- .../sdk/core/statements/catch_statement.py | 3 +- .../sdk/core/statements/for_loop_statement.py | 18 ++- .../sdk/core/statements/if_block_statement.py | 29 ++++- .../sdk/core/statements/switch_case.py | 7 +- .../core/statements/try_catch_statement.py | 3 +- src/codegen/sdk/python/function.py | 12 +- .../sdk/python/statements/catch_statement.py | 6 +- .../python/statements/if_block_statement.py | 1 - .../sdk/python/statements/match_case.py | 10 +- .../sdk/python/statements/match_statement.py | 2 +- .../python/statements/try_catch_statement.py | 14 +++ src/codegen/sdk/typescript/export.py | 4 +- src/codegen/sdk/typescript/function.py | 9 +- .../typescript/statements/catch_statement.py | 6 +- .../sdk/typescript/statements/switch_case.py | 3 +- .../typescript/statements/switch_statement.py | 2 +- .../statements/try_catch_statement.py | 16 +++ .../test_if_block_statement_properties.py | 118 ++++++++++++++++++ uv.lock | 2 +- 25 files changed, 303 insertions(+), 73 deletions(-) create mode 100644 src/codegen/sdk/core/interfaces/conditional_block.py diff --git a/src/codegen/sdk/core/assignment.py b/src/codegen/sdk/core/assignment.py index 87773e404..dc939da27 100644 --- a/src/codegen/sdk/core/assignment.py +++ b/src/codegen/sdk/core/assignment.py @@ -286,20 +286,3 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - if usage.match == self.name: usage.match.reduce_condition(bool_condition) - @noapidoc - def nested_blocks_for_conditional_parent(self): - from codegen.sdk.core.statements.if_block_statement import IfBlockStatement - from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement - from codegen.sdk.python.statements.match_statement import PyMatchCase - from codegen.sdk.typescript.statements.switch_statement import TSSwitchCase - - conditionals = {TryCatchStatement, IfBlockStatement, PyMatchCase, TSSwitchCase} - if parent := self.parent_of_types(conditionals): - match parent: - case IfBlockStatement(): - if parent._main_if_block: - return parent._main_if_block.nested_code_blocks - case PyMatchCase() | TSSwitchCase(): - return parent.match_statement.nested_code_blocks - case _: - return parent.nested_code_blocks diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index 196473675..ba49595ae 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -5,13 +5,15 @@ from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.expressions.expression import Expression +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.resolvable import Resolvable from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from codegen.sdk.core.import_resolution import Import, WildcardImport from codegen.sdk.core.interfaces.has_name import HasName - + from codegen.sdk.core.symbol import Symbol Parent = TypeVar("Parent", bound="Expression") @@ -29,19 +31,9 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]): @override def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: """Resolve the types used by this symbol.""" - if used := self.resolve_name(self.source, self.start_byte): - yield from self.with_resolution_frame(used) - from codegen.sdk.core.assignment import Assignment - - if isinstance(used, Assignment): - if nested_blocks := used.nested_blocks_for_conditional_parent(): - blocks = nested_blocks[:-1] - for block in blocks: - for assignment in block.local_var_assignments: - if assignment.name == self.source: - yield from self.with_resolution_frame(assignment) - - @noapidoc + for used in self.resolve_name(self.source, self.start_byte): + if used: + yield from self.with_resolution_frame(used) @commiter def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None: """Compute the dependencies of the export object.""" @@ -57,3 +49,27 @@ def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | def rename_if_matching(self, old: str, new: str): if self.source == old: self.edit(new) + + @noapidoc + @reader + def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator["Symbol | Import | WildcardImport | None"]: + if self.parent is not None: + resolved_name = next(self.parent.resolve_name(name, start_byte or self.start_byte,strict=strict),None) + else: + resolved_name = next(self.file.resolve_name(name, start_byte or self.start_byte,strict=strict),None) + + yield resolved_name + + if resolved_name is not None: + if hasattr(resolved_name,'parent') and (conditional_parent:=resolved_name.parent_of_type(ConditionalBlock)): + top_of_conditional= conditional_parent.start_byte + if self.parent_of_type(ConditionalBlock) == conditional_parent: + #Use in the same block, should only depend on the inside of the block + return + for other_conditional in conditional_parent.other_possible_blocks: + if cond_name:=next(other_conditional.resolve_name(name,start_byte=other_conditional.end_byte_for_condition_block),None): + if cond_name.start_byte>=other_conditional.start_byte: + yield cond_name + top_of_conditional= min(top_of_conditional,other_conditional.start_byte) + + yield from self.resolve_name(name,top_of_conditional,strict=True) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 8ad9e1385..f40213473 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -3,7 +3,7 @@ import resource import sys from abc import abstractmethod -from collections.abc import Sequence +from collections.abc import Generator, Sequence from functools import cached_property from os import PathLike from pathlib import Path @@ -744,7 +744,7 @@ def get_symbol(self, name: str) -> Symbol | None: Returns: Symbol | None: The found symbol, or None if not found. """ - if symbol := self.resolve_name(name, self.end_byte): + if symbol := next(self.resolve_name(name, self.end_byte),None): if isinstance(symbol, Symbol): return symbol return next((x for x in self.symbols if x.name == name), None) @@ -819,7 +819,7 @@ def get_class(self, name: str) -> TClass | None: Returns: TClass | None: The matching Class object if found, None otherwise. """ - if symbol := self.resolve_name(name, self.end_byte): + if symbol := next(self.resolve_name(name, self.end_byte),None): if isinstance(symbol, Class): return symbol @@ -880,13 +880,18 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: if resolved := self.valid_symbol_names.get(name): if start_byte is not None and resolved.end_byte > start_byte: - for symbol in self.symbols: + for symbol in reversed(self.symbols): if symbol.start_byte <= start_byte and symbol.name == name: - return symbol - return resolved + yield symbol + return + if strict: + return + yield resolved + return + return @property @reader diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py index 21aa3c1df..efd2b330b 100644 --- a/src/codegen/sdk/core/function.py +++ b/src/codegen/sdk/core/function.py @@ -141,13 +141,14 @@ def is_async(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: from codegen.sdk.core.class_definition import Class for symbol in self.valid_symbol_names: if symbol.name == name and (start_byte is None or (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte): - return symbol - return super().resolve_name(name, start_byte) + yield symbol + return + yield from super().resolve_name(name, start_byte) @cached_property @noapidoc diff --git a/src/codegen/sdk/core/interfaces/conditional_block.py b/src/codegen/sdk/core/interfaces/conditional_block.py new file mode 100644 index 000000000..96e6520de --- /dev/null +++ b/src/codegen/sdk/core/interfaces/conditional_block.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence + +from codegen.sdk.core.statements.statement import Statement + + +class ConditionalBlock(Statement,ABC): + """An interface for any code block that might not be executed in the code, e.g if block/else block/try block/catch block ect.""" + + @property + @abstractmethod + def other_possible_blocks(self) -> Sequence["ConditionalBlock"]: + """Should return all other "branches" that might be executed instead.""" + + @property + def end_byte_for_condition_block(self) -> int: + return self.end_byte diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index f59037144..773fd871a 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -1003,10 +1003,11 @@ def viz(self) -> VizNode: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport|None]: if self.parent is not None: - return self.parent.resolve_name(name, start_byte or self.start_byte) - return self.file.resolve_name(name, start_byte or self.start_byte) + yield from self.parent.resolve_name(name, start_byte or self.start_byte,strict=strict) + yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict) + return @cached_property @noapidoc diff --git a/src/codegen/sdk/core/statements/catch_statement.py b/src/codegen/sdk/core/statements/catch_statement.py index e9e96fa09..3df7839a7 100644 --- a/src/codegen/sdk/core/statements/catch_statement.py +++ b/src/codegen/sdk/core/statements/catch_statement.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Generic, Self, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc @@ -17,7 +18,7 @@ @apidoc -class CatchStatement(BlockStatement[Parent], Generic[Parent]): +class CatchStatement(ConditionalBlock,BlockStatement[Parent], Generic[Parent]): """Abstract representation catch clause. Attributes: diff --git a/src/codegen/sdk/core/statements/for_loop_statement.py b/src/codegen/sdk/core/statements/for_loop_statement.py index e6c6bc4b4..67bab5a02 100644 --- a/src/codegen/sdk/core/statements/for_loop_statement.py +++ b/src/codegen/sdk/core/statements/for_loop_statement.py @@ -12,6 +12,8 @@ from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from collections.abc import Generator + from codegen.sdk.core.detached_symbols.code_block import CodeBlock from codegen.sdk.core.expressions import Expression from codegen.sdk.core.import_resolution import Import, WildcardImport @@ -36,19 +38,23 @@ class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: if self.item and isinstance(self.iterable, Chainable): if start_byte is None or start_byte > self.iterable.end_byte: if name == self.item: for frame in self.iterable.resolved_type_frames: if frame.generics: - return next(iter(frame.generics.values())) - return frame.top.node + yield next(iter(frame.generics.values())) + return + yield frame.top.node + return elif isinstance(self.item, Collection): for idx, item in enumerate(self.item): if item == name: for frame in self.iterable.resolved_type_frames: if frame.generics and len(frame.generics) > idx: - return list(frame.generics.values())[idx] - return frame.top.node - return super().resolve_name(name, start_byte) + yield list(frame.generics.values())[idx] + return + yield frame.top.node + return + yield from super().resolve_name(name, start_byte) diff --git a/src/codegen/sdk/core/statements/if_block_statement.py b/src/codegen/sdk/core/statements/if_block_statement.py index 31e2fcbe8..de9ab443b 100644 --- a/src/codegen/sdk/core/statements/if_block_statement.py +++ b/src/codegen/sdk/core/statements/if_block_statement.py @@ -8,11 +8,14 @@ from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.function import Function +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.statement import Statement, StatementType from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from collections.abc import Sequence + from codegen.sdk.core.detached_symbols.code_block import CodeBlock from codegen.sdk.core.detached_symbols.function_call import FunctionCall from codegen.sdk.core.expressions import Expression @@ -26,7 +29,7 @@ @apidoc -class IfBlockStatement(Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): +class IfBlockStatement(ConditionalBlock,Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): """Abstract representation of the if/elif/else if/else statement block. For example, if there is a code block like: @@ -271,3 +274,27 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - self.remove_byte_range(self.ts_node.start_byte, remove_end) else: self.remove() + + + @property + def other_possible_blocks(self)-> Sequence[ConditionalBlock]: + if self.is_if_statement: + return self._main_if_block.alternative_blocks + elif self.is_elif_statement: + main = self._main_if_block + statements = [main] + if main.else_statement: + statements.append(main.else_statement) + for statement in main.elif_statements: + if statement!=self: + statements.append(statement) + return statements + else: + main = self._main_if_block + return [main, *main.elif_statements] + + @property + def end_byte_for_condition_block(self) -> int: + if self.is_if_statement: + return self.consequence_block.end_byte + return self.end_byte diff --git a/src/codegen/sdk/core/statements/switch_case.py b/src/codegen/sdk/core/statements/switch_case.py index 3ebafb57e..514b3e75e 100644 --- a/src/codegen/sdk/core/statements/switch_case.py +++ b/src/codegen/sdk/core/statements/switch_case.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Generic, Self, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc @@ -18,7 +19,7 @@ @apidoc -class SwitchCase(BlockStatement[Parent], Generic[Parent]): +class SwitchCase(ConditionalBlock,BlockStatement[Parent], Generic[Parent]): """Abstract representation for a switch case. Attributes: @@ -34,3 +35,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa if self.condition: self.condition._compute_dependencies(usage_type, dest) super()._compute_dependencies(usage_type, dest) + + @property + def other_possible_blocks(self)-> list[ConditionalBlock]: + return [case for case in self.parent.cases if case!=self] diff --git a/src/codegen/sdk/core/statements/try_catch_statement.py b/src/codegen/sdk/core/statements/try_catch_statement.py index 1371a2d76..181ad2c91 100644 --- a/src/codegen/sdk/core/statements/try_catch_statement.py +++ b/src/codegen/sdk/core/statements/try_catch_statement.py @@ -3,6 +3,7 @@ from abc import ABC from typing import TYPE_CHECKING, Generic, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_block import HasBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.core.statements.statement import StatementType @@ -16,7 +17,7 @@ @apidoc -class TryCatchStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): +class TryCatchStatement(ConditionalBlock,BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): """Abstract representation of the try catch statement block. Attributes: diff --git a/src/codegen/sdk/python/function.py b/src/codegen/sdk/python/function.py index 02a9dd55b..ebe88a485 100644 --- a/src/codegen/sdk/python/function.py +++ b/src/codegen/sdk/python/function.py @@ -19,6 +19,8 @@ from codegen.shared.logging.get_logger import get_logger if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -119,15 +121,17 @@ def is_class_method(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: if self.is_method: if not self.is_static_method: if len(self.parameters.symbols) > 0: if name == self.parameters[0].name: - return self.parent_class + yield self.parent_class + return if name == "super()": - return self.parent_class - return super().resolve_name(name, start_byte) + yield self.parent_class + return + yield from super().resolve_name(name, start_byte) @noapidoc @commiter diff --git a/src/codegen/sdk/python/statements/catch_statement.py b/src/codegen/sdk/python/statements/catch_statement.py index 3bbea1b46..04d57fa73 100644 --- a/src/codegen/sdk/python/statements/catch_statement.py +++ b/src/codegen/sdk/python/statements/catch_statement.py @@ -11,9 +11,9 @@ from tree_sitter import Node as PyNode from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.node_id_factory import NodeId - @py_apidoc class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement): """Python catch clause. @@ -26,3 +26,7 @@ class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement): def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.children[0] + + @property + def other_possible_blocks(self)-> list[ConditionalBlock]: + return [clause for clause in self.parent.except_clauses if clause!=self]+[self.parent] diff --git a/src/codegen/sdk/python/statements/if_block_statement.py b/src/codegen/sdk/python/statements/if_block_statement.py index 54585b9e7..dc73b21dd 100644 --- a/src/codegen/sdk/python/statements/if_block_statement.py +++ b/src/codegen/sdk/python/statements/if_block_statement.py @@ -14,7 +14,6 @@ from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - Parent = TypeVar("Parent", bound="PyCodeBlock") diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py index 9f3638b7c..02aee4e9f 100644 --- a/src/codegen/sdk/python/statements/match_case.py +++ b/src/codegen/sdk/python/statements/match_case.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.python.statements.match_statement import PyMatchStatement @@ -17,7 +18,12 @@ class PyMatchCase(SwitchCase[PyCodeBlock["PyMatchStatement"]], PyBlockStatement): """Python match case.""" - def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, match_statement: "PyMatchStatement", pos: int | None = None) -> None: + def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("alternative") - self.match_statement = match_statement + + + + @property + def other_possible_blocks(self)-> list["ConditionalBlock"]: + return [case for case in self.parent.cases if case!=self] diff --git a/src/codegen/sdk/python/statements/match_statement.py b/src/codegen/sdk/python/statements/match_statement.py index ac52a652e..59f01164c 100644 --- a/src/codegen/sdk/python/statements/match_statement.py +++ b/src/codegen/sdk/python/statements/match_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.children_by_field_name("alternative"): - self.cases.append(PyMatchCase(node, file_node_id, ctx, self.parent, self, self.index)) + self.cases.append(PyMatchCase(node, file_node_id, ctx, self, self.index)) diff --git a/src/codegen/sdk/python/statements/try_catch_statement.py b/src/codegen/sdk/python/statements/try_catch_statement.py index c4a4827b3..b54051f96 100644 --- a/src/codegen/sdk/python/statements/try_catch_statement.py +++ b/src/codegen/sdk/python/statements/try_catch_statement.py @@ -9,11 +9,14 @@ from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: + from collections.abc import Sequence + from tree_sitter import Node as PyNode from codegen.sdk.codebase.codebase_context import CodebaseContext from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.detached_symbols.function_call import FunctionCall + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.importable import Importable from codegen.sdk.core.node_id_factory import NodeId @@ -96,3 +99,14 @@ def nested_code_blocks(self) -> list[PyCodeBlock]: if self.finalizer: nested_blocks.append(self.finalizer.code_block) return nested_blocks + + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + return self.except_clauses + + @property + def end_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.end_byte + else: + return self.end_byte diff --git a/src/codegen/sdk/typescript/export.py b/src/codegen/sdk/typescript/export.py index 44703749c..ec2009816 100644 --- a/src/codegen/sdk/typescript/export.py +++ b/src/codegen/sdk/typescript/export.py @@ -204,7 +204,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa if frame.parent_frame: frame.parent_frame.add_usage(self._name_node or self, UsageKind.EXPORTED_SYMBOL, self, self.ctx) elif self._exported_symbol: - if not self.resolve_name(self._exported_symbol.source): + if not next(self.resolve_name(self._exported_symbol.source),None): self._exported_symbol._compute_dependencies(UsageKind.BODY, dest=dest or self) elif self.value: self.value._compute_dependencies(UsageKind.EXPORTED_SYMBOL, self) @@ -218,7 +218,7 @@ def compute_export_dependencies(self) -> None: self.ctx.add_edge(self.node_id, self.declared_symbol.node_id, type=EdgeType.EXPORT) elif self._exported_symbol is not None: symbol_name = self._exported_symbol.source - if (used_node := self.resolve_name(symbol_name)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): + if (used_node := next(self.resolve_name(symbol_name),None)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): self.ctx.add_edge(self.node_id, used_node.node_id, type=EdgeType.EXPORT) elif self.value is not None: if isinstance(self.value, Chainable): diff --git a/src/codegen/sdk/typescript/function.py b/src/codegen/sdk/typescript/function.py index a7be7b28f..a6d76b649 100644 --- a/src/codegen/sdk/typescript/function.py +++ b/src/codegen/sdk/typescript/function.py @@ -19,6 +19,8 @@ from codegen.shared.logging.get_logger import get_logger if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -358,7 +360,7 @@ def arrow_to_named(self, name: str | None = None) -> None: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: """Resolves the name of a symbol in the function. This method resolves the name of a symbol in the function. If the name is "this", it returns the parent class. @@ -373,8 +375,9 @@ def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Imp """ if self.is_method: if name == "this": - return self.parent_class - return super().resolve_name(name, start_byte) + yield self.parent_class + return + yield from super().resolve_name(name, start_byte) @staticmethod def is_valid_node(node: TSNode) -> bool: diff --git a/src/codegen/sdk/typescript/statements/catch_statement.py b/src/codegen/sdk/typescript/statements/catch_statement.py index c6dc10bae..4f534a6b4 100644 --- a/src/codegen/sdk/typescript/statements/catch_statement.py +++ b/src/codegen/sdk/typescript/statements/catch_statement.py @@ -10,10 +10,10 @@ from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - Parent = TypeVar("Parent", bound="TSCodeBlock") @@ -29,3 +29,7 @@ class TSCatchStatement(CatchStatement[Parent], TSBlockStatement, Generic[Parent] def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("parameter") + + @property + def other_possible_blocks(self)-> list[ConditionalBlock]: + return [self.parent] diff --git a/src/codegen/sdk/typescript/statements/switch_case.py b/src/codegen/sdk/typescript/statements/switch_case.py index fd796a213..cdd43e1dd 100644 --- a/src/codegen/sdk/typescript/statements/switch_case.py +++ b/src/codegen/sdk/typescript/statements/switch_case.py @@ -23,8 +23,7 @@ class TSSwitchCase(SwitchCase[TSCodeBlock["TSSwitchStatement"]], TSBlockStatemen default: bool - def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock, match_statement, pos: int | None = None) -> None: + def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: TSCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("value") self.default = self.ts_node.type == "switch_default" - self.match_statement = match_statement diff --git a/src/codegen/sdk/typescript/statements/switch_statement.py b/src/codegen/sdk/typescript/statements/switch_statement.py index f880a113a..0dbec180f 100644 --- a/src/codegen/sdk/typescript/statements/switch_statement.py +++ b/src/codegen/sdk/typescript/statements/switch_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.named_children: - self.cases.append(TSSwitchCase(node, file_node_id, ctx, self.parent, self)) + self.cases.append(TSSwitchCase(node, file_node_id, ctx, self)) diff --git a/src/codegen/sdk/typescript/statements/try_catch_statement.py b/src/codegen/sdk/typescript/statements/try_catch_statement.py index aa24178d2..e1099da06 100644 --- a/src/codegen/sdk/typescript/statements/try_catch_statement.py +++ b/src/codegen/sdk/typescript/statements/try_catch_statement.py @@ -9,11 +9,14 @@ from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from collections.abc import Sequence + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.detached_symbols.function_call import FunctionCall + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.importable import Importable from codegen.sdk.core.node_id_factory import NodeId @@ -91,3 +94,16 @@ def nested_code_blocks(self) -> list[TSCodeBlock]: if self.finalizer: nested_blocks.append(self.finalizer.code_block) return nested_blocks + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + if self.catch: + return [self.catch] + else: + return [] + + @property + def end_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.end_byte + else: + return self.end_byte diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py index 6f8392304..ef675ca5f 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py @@ -128,8 +128,10 @@ def foo(): assert alt_blocks[2].else_statement is None + def test_if_else_reassigment_handling(tmpdir) -> None: content = """ + if True: PYSPARK = True elif False: @@ -148,3 +150,119 @@ def test_if_else_reassigment_handling(tmpdir) -> None: for symb in file.symbols: usage = symb.usages[0] assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_function(tmpdir) -> None: + content = """ + if True: + def foo(): + print('t') + elif False: + def foo(): + print('t') + else: + def foo(): + print('t') + foo() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + foo = file.get_function("foo") + funct_call = file.function_calls[3] + for funct in file.functions: + usage = funct.usages[0] + assert usage.match == funct_call + + +def test_if_else_reassigment_handling_inside_func(tmpdir) -> None: + content = """ + def foo(a): + a = 1 + if xyz: + b = 1 + else: + b = 2 + f(a) # a resolves to 1 name + f(b) # b resolves to 2 possible names + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + foo = file.get_function("foo") + assert foo + assert len(foo.parameters[0].usages)==0 + funct_call_a = foo.function_calls[0].args[0] + funct_call_b = foo.function_calls[1] + for symbol in file.symbols(True): + if symbol.name=='a': + assert len(symbol.usages)==1 + symbol.usages[0].match==funct_call_a + elif symbol.name=='b': + assert len(symbol.usages)==1 + symbol.usages[0].match==funct_call_b + + +def test_if_else_reassigment_handling_partial_if(tmpdir) -> None: + content = """ + PYSPARK = "TEST" + if True: + PYSPARK = True + elif None: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_double(tmpdir) -> None: + content = """ + if False: + PYSPARK = "TEST1" + elif True: + PYSPARK = "TEST2" + + if True: + PYSPARK = True + elif None: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + +def test_if_else_reassigment_handling_nested_usage(tmpdir) -> None: + content = """ + if True: + PYSPARK = True + elif None: + PYSPARK = False + print(PYSPARK) + + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + first=file.symbols[0] + second=file.symbols[1] + assert len(first.usages)==0 + assert second.usages[0].match == pyspark_arg diff --git a/uv.lock b/uv.lock index 9796f29bb..3ef32e3e8 100644 --- a/uv.lock +++ b/uv.lock @@ -690,7 +690,7 @@ requires-dist = [ { name = "langchain-anthropic", specifier = ">=0.3.7" }, { name = "langchain-core" }, { name = "langchain-openai" }, - { name = "langchain-xai" }, + { name = "langchain-xai", specifier = ">=0.2.1" }, { name = "langgraph" }, { name = "langgraph-prebuilt" }, { name = "langsmith" }, From cdcb9a0992050cac7d0e46efbf0b162e91d48d1b Mon Sep 17 00:00:00 2001 From: tomcodgen <191515280+tomcodgen@users.noreply.github.com> Date: Sat, 8 Mar 2025 00:16:13 +0000 Subject: [PATCH 4/8] Automated pre-commit update --- src/codegen/sdk/core/assignment.py | 1 - src/codegen/sdk/core/expressions/name.py | 21 +++++++++--------- src/codegen/sdk/core/file.py | 6 ++--- src/codegen/sdk/core/function.py | 2 +- .../sdk/core/interfaces/conditional_block.py | 2 +- src/codegen/sdk/core/interfaces/editable.py | 4 ++-- .../sdk/core/statements/catch_statement.py | 2 +- .../sdk/core/statements/for_loop_statement.py | 2 +- .../sdk/core/statements/if_block_statement.py | 7 +++--- .../sdk/core/statements/switch_case.py | 6 ++--- .../core/statements/try_catch_statement.py | 2 +- src/codegen/sdk/python/function.py | 2 +- .../sdk/python/statements/catch_statement.py | 5 +++-- .../sdk/python/statements/match_case.py | 6 ++--- src/codegen/sdk/typescript/export.py | 4 ++-- src/codegen/sdk/typescript/function.py | 2 +- .../typescript/statements/catch_statement.py | 2 +- .../statements/try_catch_statement.py | 1 + .../test_if_block_statement_properties.py | 22 +++++++++---------- 19 files changed, 49 insertions(+), 50 deletions(-) diff --git a/src/codegen/sdk/core/assignment.py b/src/codegen/sdk/core/assignment.py index dc939da27..116fca79d 100644 --- a/src/codegen/sdk/core/assignment.py +++ b/src/codegen/sdk/core/assignment.py @@ -285,4 +285,3 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - for usage in self.usages: if usage.match == self.name: usage.match.reduce_condition(bool_condition) - diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index ba49595ae..12c18b617 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -34,6 +34,7 @@ def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: for used in self.resolve_name(self.source, self.start_byte): if used: yield from self.with_resolution_frame(used) + @commiter def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None: """Compute the dependencies of the export object.""" @@ -52,24 +53,24 @@ def rename_if_matching(self, old: str, new: str): @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator["Symbol | Import | WildcardImport | None"]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator["Symbol | Import | WildcardImport | None"]: if self.parent is not None: - resolved_name = next(self.parent.resolve_name(name, start_byte or self.start_byte,strict=strict),None) + resolved_name = next(self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict), None) else: - resolved_name = next(self.file.resolve_name(name, start_byte or self.start_byte,strict=strict),None) + resolved_name = next(self.file.resolve_name(name, start_byte or self.start_byte, strict=strict), None) yield resolved_name if resolved_name is not None: - if hasattr(resolved_name,'parent') and (conditional_parent:=resolved_name.parent_of_type(ConditionalBlock)): - top_of_conditional= conditional_parent.start_byte + if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)): + top_of_conditional = conditional_parent.start_byte if self.parent_of_type(ConditionalBlock) == conditional_parent: - #Use in the same block, should only depend on the inside of the block + # Use in the same block, should only depend on the inside of the block return for other_conditional in conditional_parent.other_possible_blocks: - if cond_name:=next(other_conditional.resolve_name(name,start_byte=other_conditional.end_byte_for_condition_block),None): - if cond_name.start_byte>=other_conditional.start_byte: + if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None): + if cond_name.start_byte >= other_conditional.start_byte: yield cond_name - top_of_conditional= min(top_of_conditional,other_conditional.start_byte) + top_of_conditional = min(top_of_conditional, other_conditional.start_byte) - yield from self.resolve_name(name,top_of_conditional,strict=True) + yield from self.resolve_name(name, top_of_conditional, strict=True) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index f40213473..1613e7f5d 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -744,7 +744,7 @@ def get_symbol(self, name: str) -> Symbol | None: Returns: Symbol | None: The found symbol, or None if not found. """ - if symbol := next(self.resolve_name(name, self.end_byte),None): + if symbol := next(self.resolve_name(name, self.end_byte), None): if isinstance(symbol, Symbol): return symbol return next((x for x in self.symbols if x.name == name), None) @@ -819,7 +819,7 @@ def get_class(self, name: str) -> TClass | None: Returns: TClass | None: The matching Class object if found, None otherwise. """ - if symbol := next(self.resolve_name(name, self.end_byte),None): + if symbol := next(self.resolve_name(name, self.end_byte), None): if isinstance(symbol, Class): return symbol @@ -880,7 +880,7 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: if resolved := self.valid_symbol_names.get(name): if start_byte is not None and resolved.end_byte > start_byte: for symbol in reversed(self.symbols): diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py index efd2b330b..0a529c1f1 100644 --- a/src/codegen/sdk/core/function.py +++ b/src/codegen/sdk/core/function.py @@ -141,7 +141,7 @@ def is_async(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: from codegen.sdk.core.class_definition import Class for symbol in self.valid_symbol_names: diff --git a/src/codegen/sdk/core/interfaces/conditional_block.py b/src/codegen/sdk/core/interfaces/conditional_block.py index 96e6520de..a11990908 100644 --- a/src/codegen/sdk/core/interfaces/conditional_block.py +++ b/src/codegen/sdk/core/interfaces/conditional_block.py @@ -4,7 +4,7 @@ from codegen.sdk.core.statements.statement import Statement -class ConditionalBlock(Statement,ABC): +class ConditionalBlock(Statement, ABC): """An interface for any code block that might not be executed in the code, e.g if block/else block/try block/catch block ect.""" @property diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index 773fd871a..3f111bf6a 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -1003,9 +1003,9 @@ def viz(self) -> VizNode: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport|None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: if self.parent is not None: - yield from self.parent.resolve_name(name, start_byte or self.start_byte,strict=strict) + yield from self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict) yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict) return diff --git a/src/codegen/sdk/core/statements/catch_statement.py b/src/codegen/sdk/core/statements/catch_statement.py index 3df7839a7..6d7b36071 100644 --- a/src/codegen/sdk/core/statements/catch_statement.py +++ b/src/codegen/sdk/core/statements/catch_statement.py @@ -18,7 +18,7 @@ @apidoc -class CatchStatement(ConditionalBlock,BlockStatement[Parent], Generic[Parent]): +class CatchStatement(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): """Abstract representation catch clause. Attributes: diff --git a/src/codegen/sdk/core/statements/for_loop_statement.py b/src/codegen/sdk/core/statements/for_loop_statement.py index 67bab5a02..dcd5e4978 100644 --- a/src/codegen/sdk/core/statements/for_loop_statement.py +++ b/src/codegen/sdk/core/statements/for_loop_statement.py @@ -38,7 +38,7 @@ class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: if self.item and isinstance(self.iterable, Chainable): if start_byte is None or start_byte > self.iterable.end_byte: if name == self.item: diff --git a/src/codegen/sdk/core/statements/if_block_statement.py b/src/codegen/sdk/core/statements/if_block_statement.py index de9ab443b..e3becc13c 100644 --- a/src/codegen/sdk/core/statements/if_block_statement.py +++ b/src/codegen/sdk/core/statements/if_block_statement.py @@ -29,7 +29,7 @@ @apidoc -class IfBlockStatement(ConditionalBlock,Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): +class IfBlockStatement(ConditionalBlock, Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): """Abstract representation of the if/elif/else if/else statement block. For example, if there is a code block like: @@ -275,9 +275,8 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - else: self.remove() - @property - def other_possible_blocks(self)-> Sequence[ConditionalBlock]: + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: if self.is_if_statement: return self._main_if_block.alternative_blocks elif self.is_elif_statement: @@ -286,7 +285,7 @@ def other_possible_blocks(self)-> Sequence[ConditionalBlock]: if main.else_statement: statements.append(main.else_statement) for statement in main.elif_statements: - if statement!=self: + if statement != self: statements.append(statement) return statements else: diff --git a/src/codegen/sdk/core/statements/switch_case.py b/src/codegen/sdk/core/statements/switch_case.py index 514b3e75e..d293ad034 100644 --- a/src/codegen/sdk/core/statements/switch_case.py +++ b/src/codegen/sdk/core/statements/switch_case.py @@ -19,7 +19,7 @@ @apidoc -class SwitchCase(ConditionalBlock,BlockStatement[Parent], Generic[Parent]): +class SwitchCase(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): """Abstract representation for a switch case. Attributes: @@ -37,5 +37,5 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa super()._compute_dependencies(usage_type, dest) @property - def other_possible_blocks(self)-> list[ConditionalBlock]: - return [case for case in self.parent.cases if case!=self] + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/core/statements/try_catch_statement.py b/src/codegen/sdk/core/statements/try_catch_statement.py index 181ad2c91..177ddde68 100644 --- a/src/codegen/sdk/core/statements/try_catch_statement.py +++ b/src/codegen/sdk/core/statements/try_catch_statement.py @@ -17,7 +17,7 @@ @apidoc -class TryCatchStatement(ConditionalBlock,BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): +class TryCatchStatement(ConditionalBlock, BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): """Abstract representation of the try catch statement block. Attributes: diff --git a/src/codegen/sdk/python/function.py b/src/codegen/sdk/python/function.py index ebe88a485..3a5749e62 100644 --- a/src/codegen/sdk/python/function.py +++ b/src/codegen/sdk/python/function.py @@ -121,7 +121,7 @@ def is_class_method(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: if self.is_method: if not self.is_static_method: if len(self.parameters.symbols) > 0: diff --git a/src/codegen/sdk/python/statements/catch_statement.py b/src/codegen/sdk/python/statements/catch_statement.py index 04d57fa73..9ebee3f3f 100644 --- a/src/codegen/sdk/python/statements/catch_statement.py +++ b/src/codegen/sdk/python/statements/catch_statement.py @@ -14,6 +14,7 @@ from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.node_id_factory import NodeId + @py_apidoc class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement): """Python catch clause. @@ -28,5 +29,5 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, self.condition = self.children[0] @property - def other_possible_blocks(self)-> list[ConditionalBlock]: - return [clause for clause in self.parent.except_clauses if clause!=self]+[self.parent] + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [clause for clause in self.parent.except_clauses if clause != self] + [self.parent] diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py index 02aee4e9f..d5e1298fc 100644 --- a/src/codegen/sdk/python/statements/match_case.py +++ b/src/codegen/sdk/python/statements/match_case.py @@ -22,8 +22,6 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext" super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("alternative") - - @property - def other_possible_blocks(self)-> list["ConditionalBlock"]: - return [case for case in self.parent.cases if case!=self] + def other_possible_blocks(self) -> list["ConditionalBlock"]: + return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/typescript/export.py b/src/codegen/sdk/typescript/export.py index ec2009816..36c499358 100644 --- a/src/codegen/sdk/typescript/export.py +++ b/src/codegen/sdk/typescript/export.py @@ -204,7 +204,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa if frame.parent_frame: frame.parent_frame.add_usage(self._name_node or self, UsageKind.EXPORTED_SYMBOL, self, self.ctx) elif self._exported_symbol: - if not next(self.resolve_name(self._exported_symbol.source),None): + if not next(self.resolve_name(self._exported_symbol.source), None): self._exported_symbol._compute_dependencies(UsageKind.BODY, dest=dest or self) elif self.value: self.value._compute_dependencies(UsageKind.EXPORTED_SYMBOL, self) @@ -218,7 +218,7 @@ def compute_export_dependencies(self) -> None: self.ctx.add_edge(self.node_id, self.declared_symbol.node_id, type=EdgeType.EXPORT) elif self._exported_symbol is not None: symbol_name = self._exported_symbol.source - if (used_node := next(self.resolve_name(symbol_name),None)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): + if (used_node := next(self.resolve_name(symbol_name), None)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): self.ctx.add_edge(self.node_id, used_node.node_id, type=EdgeType.EXPORT) elif self.value is not None: if isinstance(self.value, Chainable): diff --git a/src/codegen/sdk/typescript/function.py b/src/codegen/sdk/typescript/function.py index a6d76b649..89b56d9c6 100644 --- a/src/codegen/sdk/typescript/function.py +++ b/src/codegen/sdk/typescript/function.py @@ -360,7 +360,7 @@ def arrow_to_named(self, name: str | None = None) -> None: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None,strict:bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: """Resolves the name of a symbol in the function. This method resolves the name of a symbol in the function. If the name is "this", it returns the parent class. diff --git a/src/codegen/sdk/typescript/statements/catch_statement.py b/src/codegen/sdk/typescript/statements/catch_statement.py index 4f534a6b4..ed46d2efc 100644 --- a/src/codegen/sdk/typescript/statements/catch_statement.py +++ b/src/codegen/sdk/typescript/statements/catch_statement.py @@ -31,5 +31,5 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, self.condition = self.child_by_field_name("parameter") @property - def other_possible_blocks(self)-> list[ConditionalBlock]: + def other_possible_blocks(self) -> list[ConditionalBlock]: return [self.parent] diff --git a/src/codegen/sdk/typescript/statements/try_catch_statement.py b/src/codegen/sdk/typescript/statements/try_catch_statement.py index e1099da06..8f499da04 100644 --- a/src/codegen/sdk/typescript/statements/try_catch_statement.py +++ b/src/codegen/sdk/typescript/statements/try_catch_statement.py @@ -94,6 +94,7 @@ def nested_code_blocks(self) -> list[TSCodeBlock]: if self.finalizer: nested_blocks.append(self.finalizer.code_block) return nested_blocks + @property def other_possible_blocks(self) -> Sequence[ConditionalBlock]: if self.catch: diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py index ef675ca5f..22f8af23f 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py @@ -128,7 +128,6 @@ def foo(): assert alt_blocks[2].else_statement is None - def test_if_else_reassigment_handling(tmpdir) -> None: content = """ @@ -191,16 +190,16 @@ def foo(a): file = codebase.get_file("test.py") foo = file.get_function("foo") assert foo - assert len(foo.parameters[0].usages)==0 + assert len(foo.parameters[0].usages) == 0 funct_call_a = foo.function_calls[0].args[0] funct_call_b = foo.function_calls[1] for symbol in file.symbols(True): - if symbol.name=='a': - assert len(symbol.usages)==1 - symbol.usages[0].match==funct_call_a - elif symbol.name=='b': - assert len(symbol.usages)==1 - symbol.usages[0].match==funct_call_b + if symbol.name == "a": + assert len(symbol.usages) == 1 + symbol.usages[0].match == funct_call_a + elif symbol.name == "b": + assert len(symbol.usages) == 1 + symbol.usages[0].match == funct_call_b def test_if_else_reassigment_handling_partial_if(tmpdir) -> None: @@ -248,6 +247,7 @@ def test_if_else_reassigment_handling_double(tmpdir) -> None: usage = symb.usages[0] assert usage.match == pyspark_arg + def test_if_else_reassigment_handling_nested_usage(tmpdir) -> None: content = """ if True: @@ -262,7 +262,7 @@ def test_if_else_reassigment_handling_nested_usage(tmpdir) -> None: file = codebase.get_file("test.py") funct_call = file.function_calls[0] pyspark_arg = funct_call.args.children[0] - first=file.symbols[0] - second=file.symbols[1] - assert len(first.usages)==0 + first = file.symbols[0] + second = file.symbols[1] + assert len(first.usages) == 0 assert second.usages[0].match == pyspark_arg From 68115295878a09d9a2ad59e4dbd574c2850b3f23 Mon Sep 17 00:00:00 2001 From: tkucar Date: Tue, 11 Mar 2025 18:00:41 +0100 Subject: [PATCH 5/8] patch --- src/codegen/sdk/core/expressions/name.py | 37 +++++++++---------- src/codegen/sdk/core/file.py | 4 +- src/codegen/sdk/core/function.py | 2 +- src/codegen/sdk/core/interfaces/editable.py | 6 +-- .../sdk/core/statements/for_loop_statement.py | 2 +- src/codegen/sdk/python/function.py | 2 +- src/codegen/sdk/typescript/function.py | 5 ++- uv.lock | 2 + 8 files changed, 30 insertions(+), 30 deletions(-) diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index 12c18b617..eddb87306 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -32,8 +32,7 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]): def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: """Resolve the types used by this symbol.""" for used in self.resolve_name(self.source, self.start_byte): - if used: - yield from self.with_resolution_frame(used) + yield from self.with_resolution_frame(used) @commiter def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None: @@ -53,24 +52,22 @@ def rename_if_matching(self, old: str, new: str): @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator["Symbol | Import | WildcardImport | None"]: - if self.parent is not None: - resolved_name = next(self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict), None) + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]: + resolved_name = next(super().resolve_name(name,start_byte or self.start_byte, strict=strict),None) + if resolved_name: + yield resolved_name else: - resolved_name = next(self.file.resolve_name(name, start_byte or self.start_byte, strict=strict), None) + return - yield resolved_name + if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)): + top_of_conditional = conditional_parent.start_byte + if self.parent_of_type(ConditionalBlock) == conditional_parent: + # Use in the same block, should only depend on the inside of the block + return + for other_conditional in conditional_parent.other_possible_blocks: + if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None): + if cond_name.start_byte >= other_conditional.start_byte: + yield cond_name + top_of_conditional = min(top_of_conditional, other_conditional.start_byte) - if resolved_name is not None: - if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)): - top_of_conditional = conditional_parent.start_byte - if self.parent_of_type(ConditionalBlock) == conditional_parent: - # Use in the same block, should only depend on the inside of the block - return - for other_conditional in conditional_parent.other_possible_blocks: - if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None): - if cond_name.start_byte >= other_conditional.start_byte: - yield cond_name - top_of_conditional = min(top_of_conditional, other_conditional.start_byte) - - yield from self.resolve_name(name, top_of_conditional, strict=True) + yield from self.resolve_name(name, top_of_conditional, strict=False) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 1613e7f5d..a7252a724 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -880,14 +880,14 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if resolved := self.valid_symbol_names.get(name): if start_byte is not None and resolved.end_byte > start_byte: for symbol in reversed(self.symbols): if symbol.start_byte <= start_byte and symbol.name == name: yield symbol return - if strict: + if not strict: return yield resolved return diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py index 0a529c1f1..408c15a84 100644 --- a/src/codegen/sdk/core/function.py +++ b/src/codegen/sdk/core/function.py @@ -141,7 +141,7 @@ def is_async(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: from codegen.sdk.core.class_definition import Class for symbol in self.valid_symbol_names: diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index 3f111bf6a..22ae37f51 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -1003,11 +1003,11 @@ def viz(self) -> VizNode: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.parent is not None: yield from self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict) - yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict) - return + else: + yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict) @cached_property @noapidoc diff --git a/src/codegen/sdk/core/statements/for_loop_statement.py b/src/codegen/sdk/core/statements/for_loop_statement.py index dcd5e4978..d884a52d0 100644 --- a/src/codegen/sdk/core/statements/for_loop_statement.py +++ b/src/codegen/sdk/core/statements/for_loop_statement.py @@ -38,7 +38,7 @@ class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.item and isinstance(self.iterable, Chainable): if start_byte is None or start_byte > self.iterable.end_byte: if name == self.item: diff --git a/src/codegen/sdk/python/function.py b/src/codegen/sdk/python/function.py index 3a5749e62..77d7e623d 100644 --- a/src/codegen/sdk/python/function.py +++ b/src/codegen/sdk/python/function.py @@ -121,7 +121,7 @@ def is_class_method(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.is_method: if not self.is_static_method: if len(self.parameters.symbols) > 0: diff --git a/src/codegen/sdk/typescript/function.py b/src/codegen/sdk/typescript/function.py index 89b56d9c6..5882bec74 100644 --- a/src/codegen/sdk/typescript/function.py +++ b/src/codegen/sdk/typescript/function.py @@ -360,7 +360,7 @@ def arrow_to_named(self, name: str | None = None) -> None: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = False) -> Generator[Symbol | Import | WildcardImport | None]: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: """Resolves the name of a symbol in the function. This method resolves the name of a symbol in the function. If the name is "this", it returns the parent class. @@ -369,9 +369,10 @@ def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = Args: name (str): The name of the symbol to resolve. start_byte (int | None): The start byte of the symbol to resolve. + strict (bool): If True considers candidates that don't satisfy start byte if none do. Returns: - Symbol | Import | WildcardImport | None: The resolved symbol, import, or wildcard import, or None if not found. + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import, or None if not found. """ if self.is_method: if name == "this": diff --git a/uv.lock b/uv.lock index 8e0364f57..c029ab5e6 100644 --- a/uv.lock +++ b/uv.lock @@ -531,6 +531,7 @@ wheels = [ name = "codegen" source = { editable = "." } dependencies = [ + { name = "anthropic" }, { name = "astor" }, { name = "click" }, { name = "codeowners" }, @@ -659,6 +660,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "anthropic" }, { name = "astor", specifier = ">=0.8.1,<1.0.0" }, { name = "attrs", marker = "extra == 'lsp'", specifier = ">=25.1.0" }, { name = "click", specifier = ">=8.1.7" }, From 17b4cc0320a37054c21ba172dfa94baba0217802 Mon Sep 17 00:00:00 2001 From: tomcodgen <191515280+tomcodgen@users.noreply.github.com> Date: Tue, 11 Mar 2025 17:01:41 +0000 Subject: [PATCH 6/8] Automated pre-commit update --- src/codegen/sdk/core/expressions/name.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index eddb87306..df5ef6872 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -53,7 +53,7 @@ def rename_if_matching(self, old: str, new: str): @noapidoc @reader def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]: - resolved_name = next(super().resolve_name(name,start_byte or self.start_byte, strict=strict),None) + resolved_name = next(super().resolve_name(name, start_byte or self.start_byte, strict=strict), None) if resolved_name: yield resolved_name else: From f69a3e6df75552e7e4cc3f9f90d21bc1205e8be7 Mon Sep 17 00:00:00 2001 From: tkucar Date: Wed, 12 Mar 2025 17:20:10 +0100 Subject: [PATCH 7/8] comment --- src/codegen/sdk/core/file.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index a7252a724..12bcab303 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -881,14 +881,37 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp @noapidoc @reader def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: + """Resolves a name to a symbol, import, or wildcard import within the file's scope. + + Performs name resolution by first checking the file's valid symbols and imports. When a start_byte + is provided, ensures proper scope handling by only resolving to symbols that are defined before + that position in the file. + + Args: + name (str): The name to resolve. + start_byte (int | None): If provided, only resolves to symbols defined before this byte position + in the file. Used for proper scope handling. Defaults to None. + strict (bool): When True and using start_byte, only yields symbols if found in the correct scope. + When False, allows falling back to global scope. Defaults to True. + + Yields: + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import that matches + the name and scope requirements. Yields at most one result. + """ if resolved := self.valid_symbol_names.get(name): + # If we have a start_byte and the resolved symbol is after it, + # we need to look for earlier definitions of the symbol if start_byte is not None and resolved.end_byte > start_byte: + # Search backwards through symbols to find the most recent definition + # that comes before our start_byte position for symbol in reversed(self.symbols): if symbol.start_byte <= start_byte and symbol.name == name: yield symbol return + # If strict mode and no valid symbol found, return nothing if not strict: return + # Either no start_byte constraint or symbol is before start_byte yield resolved return return From 28302003576a426969e4c4fe2ac1a45fee54c2ca Mon Sep 17 00:00:00 2001 From: tkucar Date: Wed, 12 Mar 2025 19:26:48 +0100 Subject: [PATCH 8/8] fix --- src/codegen/sdk/python/import_resolution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index f8066c583..5c2a1f640 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -211,7 +211,11 @@ def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str) """ for resolve_path in resolve_paths: filepath_new: str = os.path.join(resolve_path, filepath) - if file := self.ctx.get_file(filepath_new): + try: + file = self.ctx.get_file(filepath_new) + except AssertionError as e: + file = None + if file: return file return None