diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py index 408c15a84..ea5b8fc95 100644 --- a/src/codegen/sdk/core/function.py +++ b/src/codegen/sdk/core/function.py @@ -148,7 +148,7 @@ def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = if symbol.name == name and (start_byte is None or (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte): yield symbol return - yield from super().resolve_name(name, start_byte) + yield from super().resolve_name(name, start_byte, strict=strict) @cached_property @noapidoc diff --git a/src/codegen/sdk/core/interfaces/conditional_block.py b/src/codegen/sdk/core/interfaces/conditional_block.py index a11990908..2689badc3 100644 --- a/src/codegen/sdk/core/interfaces/conditional_block.py +++ b/src/codegen/sdk/core/interfaces/conditional_block.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from codegen.sdk.core.statements.statement import Statement +from codegen.shared.decorators.docs import noapidoc class ConditionalBlock(Statement, ABC): @@ -9,9 +10,12 @@ class ConditionalBlock(Statement, ABC): @property @abstractmethod + @noapidoc def other_possible_blocks(self) -> Sequence["ConditionalBlock"]: """Should return all other "branches" that might be executed instead.""" @property + @noapidoc def end_byte_for_condition_block(self) -> int: + """Returns the end byte for the specific condition block""" return self.end_byte diff --git a/src/codegen/sdk/core/statements/for_loop_statement.py b/src/codegen/sdk/core/statements/for_loop_statement.py index d884a52d0..f8753cdac 100644 --- a/src/codegen/sdk/core/statements/for_loop_statement.py +++ b/src/codegen/sdk/core/statements/for_loop_statement.py @@ -57,4 +57,4 @@ def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = return yield frame.top.node return - yield from super().resolve_name(name, start_byte) + yield from super().resolve_name(name, start_byte, strict=strict) diff --git a/src/codegen/sdk/core/statements/if_block_statement.py b/src/codegen/sdk/core/statements/if_block_statement.py index e3becc13c..5d6a99fe7 100644 --- a/src/codegen/sdk/core/statements/if_block_statement.py +++ b/src/codegen/sdk/core/statements/if_block_statement.py @@ -276,9 +276,10 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - self.remove() @property + @noapidoc def other_possible_blocks(self) -> Sequence[ConditionalBlock]: if self.is_if_statement: - return self._main_if_block.alternative_blocks + return self.alternative_blocks elif self.is_elif_statement: main = self._main_if_block statements = [main] @@ -293,6 +294,7 @@ def other_possible_blocks(self) -> Sequence[ConditionalBlock]: return [main, *main.elif_statements] @property + @noapidoc def end_byte_for_condition_block(self) -> int: if self.is_if_statement: return self.consequence_block.end_byte diff --git a/src/codegen/sdk/core/statements/switch_case.py b/src/codegen/sdk/core/statements/switch_case.py index d293ad034..46f83af64 100644 --- a/src/codegen/sdk/core/statements/switch_case.py +++ b/src/codegen/sdk/core/statements/switch_case.py @@ -37,5 +37,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa super()._compute_dependencies(usage_type, dest) @property + @noapidoc def other_possible_blocks(self) -> list[ConditionalBlock]: + """Returns the end byte for the specific condition block""" return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/python/function.py b/src/codegen/sdk/python/function.py index 77d7e623d..0ab63f114 100644 --- a/src/codegen/sdk/python/function.py +++ b/src/codegen/sdk/python/function.py @@ -131,7 +131,7 @@ def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = if name == "super()": yield self.parent_class return - yield from super().resolve_name(name, start_byte) + yield from super().resolve_name(name, start_byte, strict=strict) @noapidoc @commiter diff --git a/src/codegen/sdk/python/statements/catch_statement.py b/src/codegen/sdk/python/statements/catch_statement.py index 9ebee3f3f..f5b36bd2b 100644 --- a/src/codegen/sdk/python/statements/catch_statement.py +++ b/src/codegen/sdk/python/statements/catch_statement.py @@ -5,7 +5,7 @@ from codegen.sdk.core.statements.catch_statement import CatchStatement from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock from codegen.sdk.python.statements.block_statement import PyBlockStatement -from codegen.shared.decorators.docs import py_apidoc +from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: from tree_sitter import Node as PyNode @@ -29,5 +29,6 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, self.condition = self.children[0] @property + @noapidoc def other_possible_blocks(self) -> list[ConditionalBlock]: return [clause for clause in self.parent.except_clauses if clause != self] + [self.parent] diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py index d5e1298fc..1140ccc38 100644 --- a/src/codegen/sdk/python/statements/match_case.py +++ b/src/codegen/sdk/python/statements/match_case.py @@ -6,7 +6,7 @@ from codegen.sdk.core.statements.switch_case import SwitchCase from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock from codegen.sdk.python.statements.block_statement import PyBlockStatement -from codegen.shared.decorators.docs import py_apidoc +from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -23,5 +23,6 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext" self.condition = self.child_by_field_name("alternative") @property + @noapidoc def other_possible_blocks(self) -> list["ConditionalBlock"]: return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/python/statements/try_catch_statement.py b/src/codegen/sdk/python/statements/try_catch_statement.py index b54051f96..fda319130 100644 --- a/src/codegen/sdk/python/statements/try_catch_statement.py +++ b/src/codegen/sdk/python/statements/try_catch_statement.py @@ -101,10 +101,12 @@ def nested_code_blocks(self) -> list[PyCodeBlock]: return nested_blocks @property + @noapidoc def other_possible_blocks(self) -> Sequence[ConditionalBlock]: return self.except_clauses @property + @noapidoc def end_byte_for_condition_block(self) -> int: if self.code_block: return self.code_block.end_byte diff --git a/src/codegen/sdk/typescript/function.py b/src/codegen/sdk/typescript/function.py index 5882bec74..ee71ee9db 100644 --- a/src/codegen/sdk/typescript/function.py +++ b/src/codegen/sdk/typescript/function.py @@ -378,7 +378,7 @@ def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = if name == "this": yield self.parent_class return - yield from super().resolve_name(name, start_byte) + yield from super().resolve_name(name, start_byte, strict=strict) @staticmethod def is_valid_node(node: TSNode) -> bool: diff --git a/src/codegen/sdk/typescript/statements/catch_statement.py b/src/codegen/sdk/typescript/statements/catch_statement.py index ed46d2efc..e6027d3a7 100644 --- a/src/codegen/sdk/typescript/statements/catch_statement.py +++ b/src/codegen/sdk/typescript/statements/catch_statement.py @@ -4,7 +4,7 @@ from codegen.sdk.core.statements.catch_statement import CatchStatement from codegen.sdk.typescript.statements.block_statement import TSBlockStatement -from codegen.shared.decorators.docs import apidoc +from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: from tree_sitter import Node as TSNode @@ -31,5 +31,6 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, self.condition = self.child_by_field_name("parameter") @property + @noapidoc def other_possible_blocks(self) -> list[ConditionalBlock]: return [self.parent] diff --git a/src/codegen/sdk/typescript/statements/try_catch_statement.py b/src/codegen/sdk/typescript/statements/try_catch_statement.py index 8f499da04..315f9f33c 100644 --- a/src/codegen/sdk/typescript/statements/try_catch_statement.py +++ b/src/codegen/sdk/typescript/statements/try_catch_statement.py @@ -96,6 +96,7 @@ def nested_code_blocks(self) -> list[TSCodeBlock]: return nested_blocks @property + @noapidoc def other_possible_blocks(self) -> Sequence[ConditionalBlock]: if self.catch: return [self.catch] @@ -103,6 +104,7 @@ def other_possible_blocks(self) -> Sequence[ConditionalBlock]: return [] @property + @noapidoc def end_byte_for_condition_block(self) -> int: if self.code_block: return self.code_block.end_byte diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py index 22f8af23f..6e38b9a45 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py @@ -223,6 +223,24 @@ def test_if_else_reassigment_handling_partial_if(tmpdir) -> None: assert usage.match == pyspark_arg +def test_if_else_reassigment_handling_solo_if(tmpdir) -> None: + content = """ + PYSPARK = "TEST" + if True: + PYSPARK = True + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + def test_if_else_reassigment_handling_double(tmpdir) -> None: content = """ if False: @@ -266,3 +284,24 @@ def test_if_else_reassigment_handling_nested_usage(tmpdir) -> None: second = file.symbols[1] assert len(first.usages) == 0 assert second.usages[0].match == pyspark_arg + + +def test_if_else_reassigment_inside_func_with_external_element(tmpdir) -> None: + content = """ + PYSPARK="0" + def foo(): + if True: + PYSPARK = True + else: + PYSPARK = False + print(PYSPARK) + + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + func = file.get_function("foo") + for assign in func.valid_symbol_names[:-1]: + assign.usages[0] == pyspark_arg