Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
from codeflash.context.unused_definition_remover import (
collect_top_level_defs_with_usages,
extract_names_from_targets,
remove_unused_definitions_by_function_names,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
from codeflash.models.models import (
CodeContextType,
Expand All @@ -29,6 +33,8 @@
from jedi.api.classes import Name
from libcst import CSTNode

from codeflash.context.unused_definition_remover import UsageInfo


def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
Expand Down Expand Up @@ -498,8 +504,10 @@ def parse_code_and_prune_cst(
) -> str:
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
module = cst.parse_module(code)
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)

if code_context_type == CodeContextType.READ_WRITABLE:
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
elif code_context_type == CodeContextType.READ_ONLY:
filtered_node, found_target = prune_cst_for_read_only_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
Expand All @@ -524,7 +532,7 @@ def parse_code_and_prune_cst(


def prune_cst_for_read_writable_code( # noqa: PLR0911
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.

Expand Down Expand Up @@ -569,6 +577,21 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911

return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target

if isinstance(node, cst.Assign):
for target in node.targets:
names = extract_names_from_targets(target.target)
for name in names:
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
return node, True
return None, False

if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
names = extract_names_from_targets(node.target)
for name in names:
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
return node, True
return None, False

# For other nodes, we preserve them only if they contain target functions in their children.
section_names = get_section_names(node)
if not section_names:
Expand All @@ -583,7 +606,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_read_writable_code(child, target_functions, prefix)
filtered, found_target = prune_cst_for_read_writable_code(
child, target_functions, defs_with_usages, prefix
)
if filtered:
new_children.append(filtered)
section_found_target |= found_target
Expand All @@ -592,15 +617,16 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
found_any_target = True
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_read_writable_code(original_content, target_functions, prefix)
filtered, found_target = prune_cst_for_read_writable_code(
original_content, target_functions, defs_with_usages, prefix
)
if found_target:
found_any_target = True
if filtered:
updates[section] = filtered

if not found_any_target:
return None, False

return (node.with_changes(**updates) if updates else node), True


Expand Down
118 changes: 81 additions & 37 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

import libcst as cst

Expand Down Expand Up @@ -52,46 +52,64 @@ def collect_top_level_definitions(
node: cst.CSTNode, definitions: Optional[dict[str, UsageInfo]] = None
) -> dict[str, UsageInfo]:
"""Recursively collect all top-level variable, function, and class definitions."""
# Locally bind types and helpers for faster lookup
FunctionDef = cst.FunctionDef # noqa: N806
ClassDef = cst.ClassDef # noqa: N806
Assign = cst.Assign # noqa: N806
AnnAssign = cst.AnnAssign # noqa: N806
AugAssign = cst.AugAssign # noqa: N806
IndentedBlock = cst.IndentedBlock # noqa: N806

if definitions is None:
definitions = {}

# Handle top-level function definitions
if isinstance(node, cst.FunctionDef):
# Speed: one type() lookup bound to a local replaces the repeated isinstance() calls; the exact-type checks below do not match subclasses
node_type = type(node)
# Fast path: function def
if node_type is FunctionDef:
name = node.name.value
definitions[name] = UsageInfo(
name=name,
used_by_qualified_function=False, # Will be marked later if in qualified functions
)
return definitions

# Handle top-level class definitions
if isinstance(node, cst.ClassDef):
# Fast path: class def
if node_type is ClassDef:
name = node.name.value
definitions[name] = UsageInfo(name=name)

# Also collect method definitions within the class
if hasattr(node, "body") and isinstance(node.body, cst.IndentedBlock):
for statement in node.body.body:
if isinstance(statement, cst.FunctionDef):
method_name = f"{name}.{statement.name.value}"
# Collect class methods
body = getattr(node, "body", None)
if body is not None and type(body) is IndentedBlock:
statements = body.body
# Build the "ClassName." prefix once instead of formatting it for every method
prefix = name + "."
for statement in statements:
if type(statement) is FunctionDef:
method_name = prefix + statement.name.value
definitions[method_name] = UsageInfo(name=method_name)

return definitions

# Handle top-level variable assignments
if isinstance(node, cst.Assign):
for target in node.targets:
# Fast path: assignment
if node_type is Assign:
# Record every name bound by each assignment target (extract_names_from_targets handles the unpacking)
targets = node.targets
append_def = definitions.__setitem__
for target in targets:
names = extract_names_from_targets(target.target)
for name in names:
definitions[name] = UsageInfo(name=name)
append_def(name, UsageInfo(name=name))
return definitions

if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
if isinstance(node.target, cst.Name):
name = node.target.value
if node_type is AnnAssign or node_type is AugAssign:
tgt = node.target
if type(tgt) is cst.Name:
name = tgt.value
definitions[name] = UsageInfo(name=name)
else:
names = extract_names_from_targets(node.target)
names = extract_names_from_targets(tgt)
for name in names:
definitions[name] = UsageInfo(name=name)
return definitions
Expand All @@ -100,12 +118,15 @@ def collect_top_level_definitions(
section_names = get_section_names(node)

if section_names:
getattr_ = getattr
for section in section_names:
original_content = getattr(node, section, None)
original_content = getattr_(node, section, None)
# The isinstance check below separates list/tuple sections from single-node sections
# If section contains a list of nodes
if isinstance(original_content, (list, tuple)):
defs = definitions # Move out for minor speed
for child in original_content:
collect_top_level_definitions(child, definitions)
collect_top_level_definitions(child, defs)
# If section contains a single node
elif original_content is not None:
collect_top_level_definitions(original_content, definitions)
Expand All @@ -122,6 +143,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
class DependencyCollector(cst.CSTVisitor):
"""Collects dependencies between definitions using the visitor pattern with depth tracking."""

METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)

def __init__(self, definitions: dict[str, UsageInfo]) -> None:
super().__init__()
self.definitions = definitions
Expand Down Expand Up @@ -259,8 +282,12 @@ def visit_Name(self, node: cst.Name) -> None:
if self.processing_variable and name in self.current_variable_names:
return

# Check if name is a top-level definition we're tracking
if name in self.definitions and name != self.current_top_level_name:
# Skip if we are referencing a class attribute and not a top-level definition
if self.class_depth > 0:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if parent is not None and isinstance(parent, cst.Attribute):
return
self.definitions[self.current_top_level_name].dependencies.add(name)


Expand Down Expand Up @@ -293,13 +320,20 @@ def _expand_qualified_functions(self) -> set[str]:

def mark_used_definitions(self) -> None:
    """Find all qualified functions and mark them and their dependencies as used.

    Only names in ``self.expanded_qualified_functions`` that are also keys of
    ``self.definitions`` are considered. Each matching entry gets
    ``used_by_qualified_function = True``, and every name in its
    ``dependencies`` is handed to ``mark_as_used_recursively``.
    """
    defs = self.definitions
    # Resolve the requested names up front; names without a tracked
    # definition are ignored. A plain membership filter works for any
    # iterable of names, so no type-dependent branching is needed.
    requested = [defs[name] for name in self.expanded_qualified_functions if name in defs]

    # For each specified function, mark it and all its dependencies as used
    for usage in requested:
        usage.used_by_qualified_function = True
        for dep in usage.dependencies:
            self.mark_as_used_recursively(dep)

def mark_as_used_recursively(self, name: str) -> None:
Expand Down Expand Up @@ -457,6 +491,25 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
return node, False


def collect_top_level_defs_with_usages(
    code: Union[str, cst.Module], qualified_function_names: set[str]
) -> dict[str, UsageInfo]:
    """Build the usage map of top-level definitions for a module.

    Args:
        code: Either source text, which is parsed with ``cst.parse_module``,
            or an already-parsed ``cst.Module``, which is used as-is.
        qualified_function_names: Names handed to
            ``QualifiedFunctionUsageMarker`` as the functions whose usages
            should be marked.

    Returns:
        The mapping produced by ``collect_top_level_definitions``, after the
        dependency collector has visited the module and the usage marker has
        run over it.

    """
    if isinstance(code, cst.Module):
        parsed = code
    else:
        parsed = cst.parse_module(code)

    # Step 1: every top-level class, variable and function gets an entry.
    usage_by_name = collect_top_level_definitions(parsed)

    # Step 2: record which definitions reference which. The collector is
    # driven through a MetadataWrapper rather than by visiting the module
    # directly.
    cst.MetadataWrapper(parsed).visit(DependencyCollector(usage_by_name))

    # Step 3: flag the requested functions and, recursively, what they use.
    QualifiedFunctionUsageMarker(usage_by_name, qualified_function_names).mark_used_definitions()
    return usage_by_name


def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
"""Analyze a file and remove top level definitions not used by specified functions.

Expand All @@ -476,19 +529,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
return code

try:
# Collect all definitions (top level classes, variables or function)
definitions = collect_top_level_definitions(module)

# Collect dependencies between definitions using the visitor pattern
dependency_collector = DependencyCollector(definitions)
module.visit(dependency_collector)

# Mark definitions used by specified functions, and their dependencies recursively
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
usage_marker.mark_used_definitions()
defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names)

# Apply the recursive removal transformation
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)

return modified_module.code if modified_module else "" # noqa: TRY300
except Exception as e:
Expand Down
Loading
Loading