Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions codeflash/languages/python/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,14 +703,39 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
def _get_expr_name(node: ast.AST | None) -> str | None:
if node is None:
return None
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
parent_name = _get_expr_name(node.value)
return node.attr if parent_name is None else f"{parent_name}.{node.attr}"
if isinstance(node, ast.Call):
return _get_expr_name(node.func)
return None

# Iteratively collect attribute parts and skip Call nodes to avoid recursion.
parts: list[str] = []
current = node
# Walk down attribute/call chain collecting attribute names.
while True:
if isinstance(current, ast.Attribute):
# collect attrs in reverse (will join later)
parts.append(current.attr)
current = current.value
continue
if isinstance(current, ast.Call):
current = current.func
continue
if isinstance(current, ast.Name):
# If we reached a base name, include it at the front.
base_name = current.id
else:
base_name = None
break

if not parts:
# No attribute parts collected: return base name or None (matches original).
return base_name

# parts were collected from outermost to innermost attr (append order),
# but we want base-first order. Reverse to get innermost-first, then prepend base if present.
parts.reverse()
if base_name is not None:
parts.insert(0, base_name)
# Join parts with dots. If base_name is None, this still returns the joined attrs,
# which matches the original behavior where an Attribute with non-name base returns attr(s).
return ".".join(parts)


def _collect_import_aliases(module_tree: ast.Module) -> dict[str, str]:
Expand All @@ -735,10 +760,13 @@ def _expr_matches_name(node: ast.AST | None, import_aliases: dict[str, str], suf
expr_name = _get_expr_name(node)
if expr_name is None:
return False
if expr_name == suffix or expr_name.endswith(f".{suffix}"):

# Precompute ".suffix" to avoid repeated f-string allocations.
suffix_dot = "." + suffix
if expr_name == suffix or expr_name.endswith(suffix_dot):
return True
resolved_name = import_aliases.get(expr_name)
return resolved_name is not None and (resolved_name == suffix or resolved_name.endswith(f".{suffix}"))
return resolved_name is not None and (resolved_name == suffix or resolved_name.endswith(suffix_dot))


def _get_node_source(node: ast.AST | None, module_source: str, fallback: str = "...") -> str:
Expand Down
Loading