diff --git a/src/codegen/sdk/core/expressions/name.py b/src/codegen/sdk/core/expressions/name.py index 3ee1b6411..df5ef6872 100644 --- a/src/codegen/sdk/core/expressions/name.py +++ b/src/codegen/sdk/core/expressions/name.py @@ -5,13 +5,15 @@ from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.expressions.expression import Expression +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.resolvable import Resolvable from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from codegen.sdk.core.import_resolution import Import, WildcardImport from codegen.sdk.core.interfaces.has_name import HasName - + from codegen.sdk.core.symbol import Symbol Parent = TypeVar("Parent", bound="Expression") @@ -29,10 +31,9 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]): @override def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]: """Resolve the types used by this symbol.""" - if used := self.resolve_name(self.source, self.start_byte): + for used in self.resolve_name(self.source, self.start_byte): yield from self.with_resolution_frame(used) - @noapidoc @commiter def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None: """Compute the dependencies of the export object.""" @@ -48,3 +49,25 @@ def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | def rename_if_matching(self, old: str, new: str): if self.source == old: self.edit(new) + + @noapidoc + @reader + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]: + resolved_name = next(super().resolve_name(name, start_byte or self.start_byte, strict=strict), None) + if resolved_name: + yield resolved_name + else: + return + + if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)): + top_of_conditional = conditional_parent.start_byte + if self.parent_of_type(ConditionalBlock) == conditional_parent: + # Use in the same block, should only depend on the inside of the block + return + for other_conditional in conditional_parent.other_possible_blocks: + if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None): + if cond_name.start_byte >= other_conditional.start_byte: + yield cond_name + top_of_conditional = min(top_of_conditional, other_conditional.start_byte) + + yield from self.resolve_name(name, top_of_conditional, strict=False) diff --git a/src/codegen/sdk/core/file.py b/src/codegen/sdk/core/file.py index 8ad9e1385..12bcab303 100644 --- a/src/codegen/sdk/core/file.py +++ b/src/codegen/sdk/core/file.py @@ -3,7 +3,7 @@ import resource import sys from abc import abstractmethod -from collections.abc import Sequence +from collections.abc import Generator, Sequence from functools import cached_property from os import PathLike from pathlib import Path @@ -744,7 +744,7 @@ def get_symbol(self, name: str) -> Symbol | None: Returns: Symbol | None: The found symbol, or None if not found. """ - if symbol := self.resolve_name(name, self.end_byte): + if symbol := next(self.resolve_name(name, self.end_byte), None): if isinstance(symbol, Symbol): return symbol return next((x for x in self.symbols if x.name == name), None) @@ -819,7 +819,7 @@ def get_class(self, name: str) -> TClass | None: Returns: TClass | None: The matching Class object if found, None otherwise. """ - if symbol := self.resolve_name(name, self.end_byte): + if symbol := next(self.resolve_name(name, self.end_byte), None): if isinstance(symbol, Class): return symbol @@ -880,13 +880,41 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: + """Resolves a name to a symbol, import, or wildcard import within the file's scope. + + Performs name resolution by first checking the file's valid symbols and imports. When a start_byte + is provided, ensures proper scope handling by only resolving to symbols that are defined before + that position in the file. + + Args: + name (str): The name to resolve. + start_byte (int | None): If provided, only resolves to symbols defined before this byte position + in the file. Used for proper scope handling. Defaults to None. + strict (bool): When True and using start_byte, only yields symbols if found in the correct scope. + When False, allows falling back to global scope. Defaults to True. + + Yields: + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import that matches + the name and scope requirements. Yields at most one result. + """ if resolved := self.valid_symbol_names.get(name): + # If we have a start_byte and the resolved symbol is after it, + # we need to look for earlier definitions of the symbol if start_byte is not None and resolved.end_byte > start_byte: - for symbol in self.symbols: + # Search backwards through symbols to find the most recent definition + # that comes before our start_byte position + for symbol in reversed(self.symbols): if symbol.start_byte <= start_byte and symbol.name == name: - return symbol - return resolved + yield symbol + return + # If strict mode and no valid symbol found, return nothing + if not strict: + return + # Either no start_byte constraint or symbol is before start_byte + yield resolved + return + return @property @reader diff --git a/src/codegen/sdk/core/function.py b/src/codegen/sdk/core/function.py index 21aa3c1df..408c15a84 100644 --- a/src/codegen/sdk/core/function.py +++ b/src/codegen/sdk/core/function.py @@ -141,13 +141,14 @@ def is_async(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: from codegen.sdk.core.class_definition import Class for symbol in self.valid_symbol_names: if symbol.name == name and (start_byte is None or (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte): - return symbol - return super().resolve_name(name, start_byte) + yield symbol + return + yield from super().resolve_name(name, start_byte) @cached_property @noapidoc diff --git a/src/codegen/sdk/core/interfaces/conditional_block.py b/src/codegen/sdk/core/interfaces/conditional_block.py new file mode 100644 index 000000000..a11990908 --- /dev/null +++ b/src/codegen/sdk/core/interfaces/conditional_block.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod +from collections.abc import Sequence + +from codegen.sdk.core.statements.statement import Statement + + +class ConditionalBlock(Statement, ABC): + """An interface for any code block that might not be executed in the code, e.g if block/else block/try block/catch block ect.""" + + @property + @abstractmethod + def other_possible_blocks(self) -> Sequence["ConditionalBlock"]: + """Should return all other "branches" that might be executed instead.""" + + @property + def end_byte_for_condition_block(self) -> int: + return self.end_byte diff --git a/src/codegen/sdk/core/interfaces/editable.py b/src/codegen/sdk/core/interfaces/editable.py index f59037144..22ae37f51 100644 --- a/src/codegen/sdk/core/interfaces/editable.py +++ b/src/codegen/sdk/core/interfaces/editable.py @@ -1003,10 +1003,11 @@ def viz(self) -> VizNode: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.parent is not None: - return self.parent.resolve_name(name, start_byte or self.start_byte) - return self.file.resolve_name(name, start_byte or self.start_byte) + yield from self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict) + else: + yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict) @cached_property @noapidoc diff --git a/src/codegen/sdk/core/statements/catch_statement.py b/src/codegen/sdk/core/statements/catch_statement.py index e9e96fa09..6d7b36071 100644 --- a/src/codegen/sdk/core/statements/catch_statement.py +++ b/src/codegen/sdk/core/statements/catch_statement.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Generic, Self, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc @@ -17,7 +18,7 @@ @apidoc -class CatchStatement(BlockStatement[Parent], Generic[Parent]): +class CatchStatement(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): """Abstract representation catch clause. Attributes: diff --git a/src/codegen/sdk/core/statements/for_loop_statement.py b/src/codegen/sdk/core/statements/for_loop_statement.py index e6c6bc4b4..d884a52d0 100644 --- a/src/codegen/sdk/core/statements/for_loop_statement.py +++ b/src/codegen/sdk/core/statements/for_loop_statement.py @@ -12,6 +12,8 @@ from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from collections.abc import Generator + from codegen.sdk.core.detached_symbols.code_block import CodeBlock from codegen.sdk.core.expressions import Expression from codegen.sdk.core.import_resolution import Import, WildcardImport @@ -36,19 +38,23 @@ class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.item and isinstance(self.iterable, Chainable): if start_byte is None or start_byte > self.iterable.end_byte: if name == self.item: for frame in self.iterable.resolved_type_frames: if frame.generics: - return next(iter(frame.generics.values())) - return frame.top.node + yield next(iter(frame.generics.values())) + return + yield frame.top.node + return elif isinstance(self.item, Collection): for idx, item in enumerate(self.item): if item == name: for frame in self.iterable.resolved_type_frames: if frame.generics and len(frame.generics) > idx: - return list(frame.generics.values())[idx] - return frame.top.node - return super().resolve_name(name, start_byte) + yield list(frame.generics.values())[idx] + return + yield frame.top.node + return + yield from super().resolve_name(name, start_byte) diff --git a/src/codegen/sdk/core/statements/if_block_statement.py b/src/codegen/sdk/core/statements/if_block_statement.py index 31e2fcbe8..e3becc13c 100644 --- a/src/codegen/sdk/core/statements/if_block_statement.py +++ b/src/codegen/sdk/core/statements/if_block_statement.py @@ -8,11 +8,14 @@ from codegen.sdk.core.autocommit import reader, writer from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.function import Function +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.statement import Statement, StatementType from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc if TYPE_CHECKING: + from collections.abc import Sequence + from codegen.sdk.core.detached_symbols.code_block import CodeBlock from codegen.sdk.core.detached_symbols.function_call import FunctionCall from codegen.sdk.core.expressions import Expression @@ -26,7 +29,7 @@ @apidoc -class IfBlockStatement(Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): +class IfBlockStatement(ConditionalBlock, Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]): """Abstract representation of the if/elif/else if/else statement block. For example, if there is a code block like: @@ -271,3 +274,26 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - self.remove_byte_range(self.ts_node.start_byte, remove_end) else: self.remove() + + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + if self.is_if_statement: + return self._main_if_block.alternative_blocks + elif self.is_elif_statement: + main = self._main_if_block + statements = [main] + if main.else_statement: + statements.append(main.else_statement) + for statement in main.elif_statements: + if statement != self: + statements.append(statement) + return statements + else: + main = self._main_if_block + return [main, *main.elif_statements] + + @property + def end_byte_for_condition_block(self) -> int: + if self.is_if_statement: + return self.consequence_block.end_byte + return self.end_byte diff --git a/src/codegen/sdk/core/statements/switch_case.py b/src/codegen/sdk/core/statements/switch_case.py index 3ebafb57e..d293ad034 100644 --- a/src/codegen/sdk/core/statements/switch_case.py +++ b/src/codegen/sdk/core/statements/switch_case.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Generic, Self, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.extensions.autocommit import commiter from codegen.shared.decorators.docs import apidoc, noapidoc @@ -18,7 +19,7 @@ @apidoc -class SwitchCase(BlockStatement[Parent], Generic[Parent]): +class SwitchCase(ConditionalBlock, BlockStatement[Parent], Generic[Parent]): """Abstract representation for a switch case. Attributes: @@ -34,3 +35,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa if self.condition: self.condition._compute_dependencies(usage_type, dest) super()._compute_dependencies(usage_type, dest) + + @property + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/core/statements/try_catch_statement.py b/src/codegen/sdk/core/statements/try_catch_statement.py index 1371a2d76..177ddde68 100644 --- a/src/codegen/sdk/core/statements/try_catch_statement.py +++ b/src/codegen/sdk/core/statements/try_catch_statement.py @@ -3,6 +3,7 @@ from abc import ABC from typing import TYPE_CHECKING, Generic, TypeVar +from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_block import HasBlock from codegen.sdk.core.statements.block_statement import BlockStatement from codegen.sdk.core.statements.statement import StatementType @@ -16,7 +17,7 @@ @apidoc -class TryCatchStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): +class TryCatchStatement(ConditionalBlock, BlockStatement[Parent], HasBlock, ABC, Generic[Parent]): """Abstract representation of the try catch statement block. Attributes: diff --git a/src/codegen/sdk/python/function.py b/src/codegen/sdk/python/function.py index 02a9dd55b..77d7e623d 100644 --- a/src/codegen/sdk/python/function.py +++ b/src/codegen/sdk/python/function.py @@ -19,6 +19,8 @@ from codegen.shared.logging.get_logger import get_logger if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -119,15 +121,17 @@ def is_class_method(self) -> bool: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: if self.is_method: if not self.is_static_method: if len(self.parameters.symbols) > 0: if name == self.parameters[0].name: - return self.parent_class + yield self.parent_class + return if name == "super()": - return self.parent_class - return super().resolve_name(name, start_byte) + yield self.parent_class + return + yield from super().resolve_name(name, start_byte) @noapidoc @commiter diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index f8066c583..5c2a1f640 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -211,7 +211,11 @@ def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str) """ for resolve_path in resolve_paths: filepath_new: str = os.path.join(resolve_path, filepath) - if file := self.ctx.get_file(filepath_new): + try: + file = self.ctx.get_file(filepath_new) + except AssertionError as e: + file = None + if file: return file return None diff --git a/src/codegen/sdk/python/statements/catch_statement.py b/src/codegen/sdk/python/statements/catch_statement.py index 3bbea1b46..9ebee3f3f 100644 --- a/src/codegen/sdk/python/statements/catch_statement.py +++ b/src/codegen/sdk/python/statements/catch_statement.py @@ -11,6 +11,7 @@ from tree_sitter import Node as PyNode from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.node_id_factory import NodeId @@ -26,3 +27,7 @@ class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement): def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.children[0] + + @property + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [clause for clause in self.parent.except_clauses if clause != self] + [self.parent] diff --git a/src/codegen/sdk/python/statements/if_block_statement.py b/src/codegen/sdk/python/statements/if_block_statement.py index 54585b9e7..dc73b21dd 100644 --- a/src/codegen/sdk/python/statements/if_block_statement.py +++ b/src/codegen/sdk/python/statements/if_block_statement.py @@ -14,7 +14,6 @@ from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock - Parent = TypeVar("Parent", bound="PyCodeBlock") diff --git a/src/codegen/sdk/python/statements/match_case.py b/src/codegen/sdk/python/statements/match_case.py index 69528fbba..d5e1298fc 100644 --- a/src/codegen/sdk/python/statements/match_case.py +++ b/src/codegen/sdk/python/statements/match_case.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.python.statements.match_statement import PyMatchStatement @@ -20,3 +21,7 @@ class PyMatchCase(SwitchCase[PyCodeBlock["PyMatchStatement"]], PyBlockStatement) def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: PyCodeBlock, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("alternative") + + @property + def other_possible_blocks(self) -> list["ConditionalBlock"]: + return [case for case in self.parent.cases if case != self] diff --git a/src/codegen/sdk/python/statements/match_statement.py b/src/codegen/sdk/python/statements/match_statement.py index 804ff2029..59f01164c 100644 --- a/src/codegen/sdk/python/statements/match_statement.py +++ b/src/codegen/sdk/python/statements/match_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.children_by_field_name("alternative"): - self.cases.append(PyMatchCase(node, file_node_id, ctx, self.parent, self.index)) + self.cases.append(PyMatchCase(node, file_node_id, ctx, self, self.index)) diff --git a/src/codegen/sdk/python/statements/try_catch_statement.py b/src/codegen/sdk/python/statements/try_catch_statement.py index c4a4827b3..b54051f96 100644 --- a/src/codegen/sdk/python/statements/try_catch_statement.py +++ b/src/codegen/sdk/python/statements/try_catch_statement.py @@ -9,11 +9,14 @@ from codegen.shared.decorators.docs import noapidoc, py_apidoc if TYPE_CHECKING: + from collections.abc import Sequence + from tree_sitter import Node as PyNode from codegen.sdk.codebase.codebase_context import CodebaseContext from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.detached_symbols.function_call import FunctionCall + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.importable import Importable from codegen.sdk.core.node_id_factory import NodeId @@ -96,3 +99,14 @@ def nested_code_blocks(self) -> list[PyCodeBlock]: if self.finalizer: nested_blocks.append(self.finalizer.code_block) return nested_blocks + + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + return self.except_clauses + + @property + def end_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.end_byte + else: + return self.end_byte diff --git a/src/codegen/sdk/typescript/export.py b/src/codegen/sdk/typescript/export.py index 44703749c..36c499358 100644 --- a/src/codegen/sdk/typescript/export.py +++ b/src/codegen/sdk/typescript/export.py @@ -204,7 +204,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa if frame.parent_frame: frame.parent_frame.add_usage(self._name_node or self, UsageKind.EXPORTED_SYMBOL, self, self.ctx) elif self._exported_symbol: - if not self.resolve_name(self._exported_symbol.source): + if not next(self.resolve_name(self._exported_symbol.source), None): self._exported_symbol._compute_dependencies(UsageKind.BODY, dest=dest or self) elif self.value: self.value._compute_dependencies(UsageKind.EXPORTED_SYMBOL, self) @@ -218,7 +218,7 @@ def compute_export_dependencies(self) -> None: self.ctx.add_edge(self.node_id, self.declared_symbol.node_id, type=EdgeType.EXPORT) elif self._exported_symbol is not None: symbol_name = self._exported_symbol.source - if (used_node := self.resolve_name(symbol_name)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): + if (used_node := next(self.resolve_name(symbol_name), None)) and isinstance(used_node, Importable) and self.ctx.has_node(used_node.node_id): self.ctx.add_edge(self.node_id, used_node.node_id, type=EdgeType.EXPORT) elif self.value is not None: if isinstance(self.value, Chainable): diff --git a/src/codegen/sdk/typescript/function.py b/src/codegen/sdk/typescript/function.py index a7be7b28f..5882bec74 100644 --- a/src/codegen/sdk/typescript/function.py +++ b/src/codegen/sdk/typescript/function.py @@ -19,6 +19,8 @@ from codegen.shared.logging.get_logger import get_logger if TYPE_CHECKING: + from collections.abc import Generator + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext @@ -358,7 +360,7 @@ def arrow_to_named(self, name: str | None = None) -> None: @noapidoc @reader - def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None: + def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]: """Resolves the name of a symbol in the function. This method resolves the name of a symbol in the function. If the name is "this", it returns the parent class. @@ -367,14 +369,16 @@ def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Imp Args: name (str): The name of the symbol to resolve. start_byte (int | None): The start byte of the symbol to resolve. + strict (bool): If True considers candidates that don't satisfy start byte if none do. Returns: - Symbol | Import | WildcardImport | None: The resolved symbol, import, or wildcard import, or None if not found. + Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import, or None if not found. """ if self.is_method: if name == "this": - return self.parent_class - return super().resolve_name(name, start_byte) + yield self.parent_class + return + yield from super().resolve_name(name, start_byte) @staticmethod def is_valid_node(node: TSNode) -> bool: diff --git a/src/codegen/sdk/typescript/statements/catch_statement.py b/src/codegen/sdk/typescript/statements/catch_statement.py index c6dc10bae..ed46d2efc 100644 --- a/src/codegen/sdk/typescript/statements/catch_statement.py +++ b/src/codegen/sdk/typescript/statements/catch_statement.py @@ -10,10 +10,10 @@ from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.typescript.detached_symbols.code_block import TSCodeBlock - Parent = TypeVar("Parent", bound="TSCodeBlock") @@ -29,3 +29,7 @@ class TSCatchStatement(CatchStatement[Parent], TSBlockStatement, Generic[Parent] def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, parent: Parent, pos: int | None = None) -> None: super().__init__(ts_node, file_node_id, ctx, parent, pos) self.condition = self.child_by_field_name("parameter") + + @property + def other_possible_blocks(self) -> list[ConditionalBlock]: + return [self.parent] diff --git a/src/codegen/sdk/typescript/statements/switch_statement.py b/src/codegen/sdk/typescript/statements/switch_statement.py index 914bde227..0dbec180f 100644 --- a/src/codegen/sdk/typescript/statements/switch_statement.py +++ b/src/codegen/sdk/typescript/statements/switch_statement.py @@ -24,4 +24,4 @@ def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: CodebaseContext, code_block = self.ts_node.child_by_field_name("body") self.cases = [] for node in code_block.named_children: - self.cases.append(TSSwitchCase(node, file_node_id, ctx, self.parent)) + self.cases.append(TSSwitchCase(node, file_node_id, ctx, self)) diff --git a/src/codegen/sdk/typescript/statements/try_catch_statement.py b/src/codegen/sdk/typescript/statements/try_catch_statement.py index aa24178d2..8f499da04 100644 --- a/src/codegen/sdk/typescript/statements/try_catch_statement.py +++ b/src/codegen/sdk/typescript/statements/try_catch_statement.py @@ -9,11 +9,14 @@ from codegen.shared.decorators.docs import noapidoc, ts_apidoc if TYPE_CHECKING: + from collections.abc import Sequence + from tree_sitter import Node as TSNode from codegen.sdk.codebase.codebase_context import CodebaseContext from codegen.sdk.core.dataclasses.usage import UsageKind from codegen.sdk.core.detached_symbols.function_call import FunctionCall + from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock from codegen.sdk.core.interfaces.has_name import HasName from codegen.sdk.core.interfaces.importable import Importable from codegen.sdk.core.node_id_factory import NodeId @@ -91,3 +94,17 @@ def nested_code_blocks(self) -> list[TSCodeBlock]: if self.finalizer: nested_blocks.append(self.finalizer.code_block) return nested_blocks + + @property + def other_possible_blocks(self) -> Sequence[ConditionalBlock]: + if self.catch: + return [self.catch] + else: + return [] + + @property + def end_byte_for_condition_block(self) -> int: + if self.code_block: + return self.code_block.end_byte + else: + return self.end_byte diff --git a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py index 23a3a7e6a..22f8af23f 100644 --- a/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py +++ b/tests/unit/codegen/sdk/python/statements/if_block_statement/test_if_block_statement_properties.py @@ -126,3 +126,143 @@ def foo(): assert len(alt_blocks[2].alternative_blocks) == 0 assert len(alt_blocks[2].elif_statements) == 0 assert alt_blocks[2].else_statement is None + + +def test_if_else_reassigment_handling(tmpdir) -> None: + content = """ + + if True: + PYSPARK = True + elif False: + PYSPARK = False + else: + PYSPARK = None + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_function(tmpdir) -> None: + content = """ + if True: + def foo(): + print('t') + elif False: + def foo(): + print('t') + else: + def foo(): + print('t') + foo() + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + foo = file.get_function("foo") + funct_call = file.function_calls[3] + for funct in file.functions: + usage = funct.usages[0] + assert usage.match == funct_call + + +def test_if_else_reassigment_handling_inside_func(tmpdir) -> None: + content = """ + def foo(a): + a = 1 + if xyz: + b = 1 + else: + b = 2 + f(a) # a resolves to 1 name + f(b) # b resolves to 2 possible names + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + foo = file.get_function("foo") + assert foo + assert len(foo.parameters[0].usages) == 0 + funct_call_a = foo.function_calls[0].args[0] + funct_call_b = foo.function_calls[1] + for symbol in file.symbols(True): + if symbol.name == "a": + assert len(symbol.usages) == 1 + symbol.usages[0].match == funct_call_a + elif symbol.name == "b": + assert len(symbol.usages) == 1 + symbol.usages[0].match == funct_call_b + + +def test_if_else_reassigment_handling_partial_if(tmpdir) -> None: + content = """ + PYSPARK = "TEST" + if True: + PYSPARK = True + elif None: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_double(tmpdir) -> None: + content = """ + if False: + PYSPARK = "TEST1" + elif True: + PYSPARK = "TEST2" + + if True: + PYSPARK = True + elif None: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg + + +def test_if_else_reassigment_handling_nested_usage(tmpdir) -> None: + content = """ + if True: + PYSPARK = True + elif None: + PYSPARK = False + print(PYSPARK) + + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + first = file.symbols[0] + second = file.symbols[1] + assert len(first.usages) == 0 + assert second.usages[0].match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py index 0a1928551..76bb5d0f4 100644 --- a/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py +++ b/tests/unit/codegen/sdk/python/statements/match_statement/test_try_catch_statement.py @@ -75,3 +75,23 @@ def risky(): assert not file.function_calls[0].is_wrapped_in(TryCatchStatement) assert file.function_calls[1].is_wrapped_in(TryCatchStatement) assert file.function_calls[2].is_wrapped_in(TryCatchStatement) + + +def test_try_except_reassigment_handling(tmpdir) -> None: + content = """ + try: + PYSPARK = True # This gets removed even though there is a later use + except ImportError: + PYSPARK = False + + print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols: + usage = symb.usages[0] + assert usage.match == pyspark_arg diff --git a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py index 3a07f1d81..be972cffd 100644 --- a/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py +++ b/tests/unit/codegen/sdk/python/statements/try_catch_statement/test_match_statement.py @@ -53,3 +53,27 @@ def risky(): assert len(dependencies) == 1 global_var = file.get_global_var("risky_var") assert dependencies[0] == global_var + + +def test_match_reassigment_handling(tmpdir) -> None: + content = """ +filter = 1 +match filter: + case 1: + PYSPARK=True + case 2: + PYSPARK=False + case _: + PYSPARK=None + +print(PYSPARK) + """ + + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file("test.py") + symbo = file.get_symbol("PYSPARK") + funct_call = file.function_calls[0] + pyspark_arg = funct_call.args.children[0] + for symb in file.symbols[1:]: + usage = symb.usages[0] + assert usage.match == pyspark_arg