Skip to content
29 changes: 26 additions & 3 deletions src/codegen/sdk/core/expressions/name.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from codegen.sdk.core.autocommit import reader, writer
from codegen.sdk.core.dataclasses.usage import UsageKind
from codegen.sdk.core.expressions.expression import Expression
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
from codegen.sdk.core.interfaces.resolvable import Resolvable
from codegen.sdk.extensions.autocommit import commiter
from codegen.shared.decorators.docs import apidoc, noapidoc

if TYPE_CHECKING:
from codegen.sdk.core.import_resolution import Import, WildcardImport
from codegen.sdk.core.interfaces.has_name import HasName

from codegen.sdk.core.symbol import Symbol

Parent = TypeVar("Parent", bound="Expression")

Expand All @@ -29,10 +31,9 @@ class Name(Expression[Parent], Resolvable, Generic[Parent]):
@override
def _resolved_types(self) -> Generator[ResolutionStack[Self], None, None]:
"""Resolve the types used by this symbol."""
if used := self.resolve_name(self.source, self.start_byte):
for used in self.resolve_name(self.source, self.start_byte):
yield from self.with_resolution_frame(used)

@noapidoc
@commiter
def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName | None "] = None) -> None:
"""Compute the dependencies of the export object."""
Expand All @@ -48,3 +49,25 @@ def _compute_dependencies(self, usage_type: UsageKind, dest: Optional["HasName |
def rename_if_matching(self, old: str, new: str):
if self.source == old:
self.edit(new)

@noapidoc
@reader
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator["Symbol | Import | WildcardImport"]:
resolved_name = next(super().resolve_name(name, start_byte or self.start_byte, strict=strict), None)
if resolved_name:
yield resolved_name
else:
return

if hasattr(resolved_name, "parent") and (conditional_parent := resolved_name.parent_of_type(ConditionalBlock)):
top_of_conditional = conditional_parent.start_byte
if self.parent_of_type(ConditionalBlock) == conditional_parent:
# Use in the same block, should only depend on the inside of the block
return
for other_conditional in conditional_parent.other_possible_blocks:
if cond_name := next(other_conditional.resolve_name(name, start_byte=other_conditional.end_byte_for_condition_block), None):
if cond_name.start_byte >= other_conditional.start_byte:
yield cond_name
top_of_conditional = min(top_of_conditional, other_conditional.start_byte)

yield from self.resolve_name(name, top_of_conditional, strict=False)
42 changes: 35 additions & 7 deletions src/codegen/sdk/core/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import resource
import sys
from abc import abstractmethod
from collections.abc import Sequence
from collections.abc import Generator, Sequence
from functools import cached_property
from os import PathLike
from pathlib import Path
Expand Down Expand Up @@ -744,7 +744,7 @@ def get_symbol(self, name: str) -> Symbol | None:
Returns:
Symbol | None: The found symbol, or None if not found.
"""
if symbol := self.resolve_name(name, self.end_byte):
if symbol := next(self.resolve_name(name, self.end_byte), None):
if isinstance(symbol, Symbol):
return symbol
return next((x for x in self.symbols if x.name == name), None)
Expand Down Expand Up @@ -819,7 +819,7 @@ def get_class(self, name: str) -> TClass | None:
Returns:
TClass | None: The matching Class object if found, None otherwise.
"""
if symbol := self.resolve_name(name, self.end_byte):
if symbol := next(self.resolve_name(name, self.end_byte), None):
if isinstance(symbol, Class):
return symbol

Expand Down Expand Up @@ -880,13 +880,41 @@ def valid_symbol_names(self) -> dict[str, Symbol | TImport | WildcardImport[TImp

@noapidoc
@reader
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
"""Resolves a name to a symbol, import, or wildcard import within the file's scope.

Performs name resolution by first checking the file's valid symbols and imports. When a start_byte
is provided, ensures proper scope handling by only resolving to symbols that are defined before
that position in the file.

Args:
name (str): The name to resolve.
start_byte (int | None): If provided, only resolves to symbols defined before this byte position
in the file. Used for proper scope handling. Defaults to None.
strict (bool): When True and using start_byte, only yields symbols if found in the correct scope.
When False, allows falling back to global scope. Defaults to True.

Yields:
Symbol | Import | WildcardImport: The resolved symbol, import, or wildcard import that matches
the name and scope requirements. Yields at most one result.
"""
if resolved := self.valid_symbol_names.get(name):
# If we have a start_byte and the resolved symbol is after it,
# we need to look for earlier definitions of the symbol
if start_byte is not None and resolved.end_byte > start_byte:
for symbol in self.symbols:
# Search backwards through symbols to find the most recent definition
# that comes before our start_byte position
for symbol in reversed(self.symbols):
if symbol.start_byte <= start_byte and symbol.name == name:
return symbol
return resolved
yield symbol
return
# If strict mode and no valid symbol found, return nothing
if not strict:
return
# Either no start_byte constraint or symbol is before start_byte
yield resolved
return
return

@property
@reader
Expand Down
7 changes: 4 additions & 3 deletions src/codegen/sdk/core/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ def is_async(self) -> bool:

@noapidoc
@reader
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
from codegen.sdk.core.class_definition import Class

for symbol in self.valid_symbol_names:
if symbol.name == name and (start_byte is None or (symbol.start_byte if isinstance(symbol, Class | Function) else symbol.end_byte) <= start_byte):
return symbol
return super().resolve_name(name, start_byte)
yield symbol
return
yield from super().resolve_name(name, start_byte)

@cached_property
@noapidoc
Expand Down
17 changes: 17 additions & 0 deletions src/codegen/sdk/core/interfaces/conditional_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence

from codegen.sdk.core.statements.statement import Statement


class ConditionalBlock(Statement, ABC):
"""An interface for any code block that might not be executed in the code, e.g if block/else block/try block/catch block ect."""

@property
@abstractmethod
def other_possible_blocks(self) -> Sequence["ConditionalBlock"]:
"""Should return all other "branches" that might be executed instead."""

@property
def end_byte_for_condition_block(self) -> int:
return self.end_byte
7 changes: 4 additions & 3 deletions src/codegen/sdk/core/interfaces/editable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,10 +1003,11 @@ def viz(self) -> VizNode:

@noapidoc
@reader
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
if self.parent is not None:
return self.parent.resolve_name(name, start_byte or self.start_byte)
return self.file.resolve_name(name, start_byte or self.start_byte)
yield from self.parent.resolve_name(name, start_byte or self.start_byte, strict=strict)
else:
yield from self.file.resolve_name(name, start_byte or self.start_byte, strict=strict)

@cached_property
@noapidoc
Expand Down
3 changes: 2 additions & 1 deletion src/codegen/sdk/core/statements/catch_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, Generic, Self, TypeVar

from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
from codegen.sdk.core.statements.block_statement import BlockStatement
from codegen.sdk.extensions.autocommit import commiter
from codegen.shared.decorators.docs import apidoc, noapidoc
Expand All @@ -17,7 +18,7 @@


@apidoc
class CatchStatement(BlockStatement[Parent], Generic[Parent]):
class CatchStatement(ConditionalBlock, BlockStatement[Parent], Generic[Parent]):
"""Abstract representation catch clause.

Attributes:
Expand Down
18 changes: 12 additions & 6 deletions src/codegen/sdk/core/statements/for_loop_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from codegen.shared.decorators.docs import apidoc, noapidoc

if TYPE_CHECKING:
from collections.abc import Generator

from codegen.sdk.core.detached_symbols.code_block import CodeBlock
from codegen.sdk.core.expressions import Expression
from codegen.sdk.core.import_resolution import Import, WildcardImport
Expand All @@ -36,19 +38,23 @@ class ForLoopStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]):

@noapidoc
@reader
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
if self.item and isinstance(self.iterable, Chainable):
if start_byte is None or start_byte > self.iterable.end_byte:
if name == self.item:
for frame in self.iterable.resolved_type_frames:
if frame.generics:
return next(iter(frame.generics.values()))
return frame.top.node
yield next(iter(frame.generics.values()))
return
yield frame.top.node
return
elif isinstance(self.item, Collection):
for idx, item in enumerate(self.item):
if item == name:
for frame in self.iterable.resolved_type_frames:
if frame.generics and len(frame.generics) > idx:
return list(frame.generics.values())[idx]
return frame.top.node
return super().resolve_name(name, start_byte)
yield list(frame.generics.values())[idx]
return
yield frame.top.node
return
yield from super().resolve_name(name, start_byte)
28 changes: 27 additions & 1 deletion src/codegen/sdk/core/statements/if_block_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,14 @@
from codegen.sdk.core.autocommit import reader, writer
from codegen.sdk.core.dataclasses.usage import UsageKind
from codegen.sdk.core.function import Function
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
from codegen.sdk.core.statements.statement import Statement, StatementType
from codegen.sdk.extensions.autocommit import commiter
from codegen.shared.decorators.docs import apidoc, noapidoc

if TYPE_CHECKING:
from collections.abc import Sequence

from codegen.sdk.core.detached_symbols.code_block import CodeBlock
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
from codegen.sdk.core.expressions import Expression
Expand All @@ -26,7 +29,7 @@


@apidoc
class IfBlockStatement(Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]):
class IfBlockStatement(ConditionalBlock, Statement[TCodeBlock], Generic[TCodeBlock, TIfBlockStatement]):
"""Abstract representation of the if/elif/else if/else statement block.

For example, if there is a code block like:
Expand Down Expand Up @@ -271,3 +274,26 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) -
self.remove_byte_range(self.ts_node.start_byte, remove_end)
else:
self.remove()

@property
def other_possible_blocks(self) -> Sequence[ConditionalBlock]:
if self.is_if_statement:
return self._main_if_block.alternative_blocks
elif self.is_elif_statement:
main = self._main_if_block
statements = [main]
if main.else_statement:
statements.append(main.else_statement)
for statement in main.elif_statements:
if statement != self:
statements.append(statement)
return statements
else:
main = self._main_if_block
return [main, *main.elif_statements]

@property
def end_byte_for_condition_block(self) -> int:
if self.is_if_statement:
return self.consequence_block.end_byte
return self.end_byte
7 changes: 6 additions & 1 deletion src/codegen/sdk/core/statements/switch_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, Generic, Self, TypeVar

from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
from codegen.sdk.core.statements.block_statement import BlockStatement
from codegen.sdk.extensions.autocommit import commiter
from codegen.shared.decorators.docs import apidoc, noapidoc
Expand All @@ -18,7 +19,7 @@


@apidoc
class SwitchCase(BlockStatement[Parent], Generic[Parent]):
class SwitchCase(ConditionalBlock, BlockStatement[Parent], Generic[Parent]):
"""Abstract representation for a switch case.

Attributes:
Expand All @@ -34,3 +35,7 @@ def _compute_dependencies(self, usage_type: UsageKind | None = None, dest: HasNa
if self.condition:
self.condition._compute_dependencies(usage_type, dest)
super()._compute_dependencies(usage_type, dest)

@property
def other_possible_blocks(self) -> list[ConditionalBlock]:
return [case for case in self.parent.cases if case != self]
3 changes: 2 additions & 1 deletion src/codegen/sdk/core/statements/try_catch_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from abc import ABC
from typing import TYPE_CHECKING, Generic, TypeVar

from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
from codegen.sdk.core.interfaces.has_block import HasBlock
from codegen.sdk.core.statements.block_statement import BlockStatement
from codegen.sdk.core.statements.statement import StatementType
Expand All @@ -16,7 +17,7 @@


@apidoc
class TryCatchStatement(BlockStatement[Parent], HasBlock, ABC, Generic[Parent]):
class TryCatchStatement(ConditionalBlock, BlockStatement[Parent], HasBlock, ABC, Generic[Parent]):
"""Abstract representation of the try catch statement block.

Attributes:
Expand Down
12 changes: 8 additions & 4 deletions src/codegen/sdk/python/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from codegen.shared.logging.get_logger import get_logger

if TYPE_CHECKING:
from collections.abc import Generator

from tree_sitter import Node as TSNode

from codegen.sdk.codebase.codebase_context import CodebaseContext
Expand Down Expand Up @@ -119,15 +121,17 @@ def is_class_method(self) -> bool:

@noapidoc
@reader
def resolve_name(self, name: str, start_byte: int | None = None) -> Symbol | Import | WildcardImport | None:
def resolve_name(self, name: str, start_byte: int | None = None, strict: bool = True) -> Generator[Symbol | Import | WildcardImport]:
if self.is_method:
if not self.is_static_method:
if len(self.parameters.symbols) > 0:
if name == self.parameters[0].name:
return self.parent_class
yield self.parent_class
return
if name == "super()":
return self.parent_class
return super().resolve_name(name, start_byte)
yield self.parent_class
return
yield from super().resolve_name(name, start_byte)

@noapidoc
@commiter
Expand Down
6 changes: 5 additions & 1 deletion src/codegen/sdk/python/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,11 @@ def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str)
"""
for resolve_path in resolve_paths:
filepath_new: str = os.path.join(resolve_path, filepath)
if file := self.ctx.get_file(filepath_new):
try:
file = self.ctx.get_file(filepath_new)
except AssertionError as e:
file = None
if file:
return file

return None
Expand Down
5 changes: 5 additions & 0 deletions src/codegen/sdk/python/statements/catch_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from tree_sitter import Node as PyNode

from codegen.sdk.codebase.codebase_context import CodebaseContext
from codegen.sdk.core.interfaces.conditional_block import ConditionalBlock
from codegen.sdk.core.node_id_factory import NodeId


Expand All @@ -26,3 +27,7 @@ class PyCatchStatement(CatchStatement[PyCodeBlock], PyBlockStatement):
def __init__(self, ts_node: PyNode, file_node_id: NodeId, ctx: CodebaseContext, parent: PyCodeBlock, pos: int | None = None) -> None:
super().__init__(ts_node, file_node_id, ctx, parent, pos)
self.condition = self.children[0]

@property
def other_possible_blocks(self) -> list[ConditionalBlock]:
return [clause for clause in self.parent.except_clauses if clause != self] + [self.parent]
1 change: 0 additions & 1 deletion src/codegen/sdk/python/statements/if_block_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from codegen.sdk.core.node_id_factory import NodeId
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock


Parent = TypeVar("Parent", bound="PyCodeBlock")


Expand Down
Loading
Loading