Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,34 @@ def _analyze_imports_in_optimized_code(
return dict(imported_names_map)


def find_target_node(
    root: ast.AST, function_to_optimize: FunctionToOptimize
) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Return the AST node of the function to optimize, resolved through its parents.

    The lookup descends from ``root`` one parent at a time, so two functions that
    share a name but live in different scopes (e.g. ``A.run`` and ``B.run``) are
    told apart. Only the direct ``body`` of each scope is inspected; definitions
    buried inside ``if``/``try``/``with`` statements are not searched.

    Args:
        root: Parsed tree (normally an ``ast.Module``) to search in.
        function_to_optimize: Supplies ``function_name`` and the ordered
            ``parents`` chain (outermost first). Each parent exposes ``name``
            and, optionally, ``type`` (``"ClassDef"``, ``"FunctionDef"`` or
            ``"AsyncFunctionDef"``).

    Returns:
        The matching ``FunctionDef``/``AsyncFunctionDef`` node, or ``None`` when
        any parent or the function itself cannot be found.
    """
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)

    scope: Optional[ast.AST] = root
    for parent in function_to_optimize.parents:
        # Honour the declared kind of the parent: a function-scoped parent must
        # match a (possibly async) function definition, anything else is treated
        # as a class, which is what the lookup matched historically.
        if getattr(parent, "type", "ClassDef") in ("FunctionDef", "AsyncFunctionDef"):
            parent_types = function_types
        else:
            parent_types = (ast.ClassDef,)

        # `or ()` covers both a missing `body` attribute and an empty body.
        scope = next(
            (
                child
                for child in getattr(scope, "body", None) or ()
                if isinstance(child, parent_types) and child.name == parent.name
            ),
            None,
        )
        if scope is None:
            return None

    # `scope` is now the root (no parents) or the innermost parent; the target
    # must be one of its direct children.
    target_name = function_to_optimize.function_name
    return next(
        (
            child
            for child in getattr(scope, "body", None) or ()
            if isinstance(child, function_types) and child.name == target_name
        ),
        None,
    )


def detect_unused_helper_functions(
function_to_optimize: FunctionToOptimize,
code_context: CodeOptimizationContext,
Expand Down Expand Up @@ -641,11 +669,7 @@ def detect_unused_helper_functions(
optimized_ast = ast.parse(optimized_code)

# Find the optimized entrypoint function
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
entrypoint_function_ast = node
break
entrypoint_function_ast = find_target_node(optimized_ast, function_to_optimize)

if not entrypoint_function_ast:
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
Expand Down
151 changes: 150 additions & 1 deletion tests/test_unused_helper_revert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from codeflash.context.unused_definition_remover import detect_unused_helper_functions
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeStringsMarkdown
from codeflash.models.models import CodeStringsMarkdown, FunctionParent
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.verification.verification_utils import TestConfig

Expand Down Expand Up @@ -1460,3 +1460,152 @@ def calculate_class(cls, n):
import shutil

shutil.rmtree(temp_dir, ignore_errors=True)


def test_unused_helper_detection_with_duplicated_function_name_in_different_classes():
"""Test detection when two classes in the same file define a method with the same name.

The target is ``LspMarkdownMessage.serialize``; ``LspMessage.serialize`` appears
earlier in the file and calls neither helper, so the entrypoint has to be
resolved through its parent class rather than by function name alone.
"""
temp_dir = Path(tempfile.mkdtemp())

try:
# Main file: both classes define `serialize`, but only
# LspMarkdownMessage.serialize calls the two imported helpers.
main_file = temp_dir / "main.py"
main_file.write_text("""from __future__ import annotations
import json
from helpers import replace_quotes_with_backticks, simplify_worktree_paths
from dataclasses import asdict, dataclass

@dataclass
class LspMessage:

def serialize(self) -> str:
data = self._loop_through(asdict(self))
# Important: keep type as the first key, for making it easy and fast for the client to know if this is a lsp message before parsing it
ordered = {"type": self.type(), **data}
return (
message_delimiter
+ json.dumps(ordered)
+ message_delimiter
)


@dataclass
class LspMarkdownMessage(LspMessage):

def serialize(self) -> str:
self.markdown = simplify_worktree_paths(self.markdown)
self.markdown = replace_quotes_with_backticks(self.markdown)
return super().serialize()
""")

# Helpers file: defines the two helper functions imported by main.py
helpers_file = temp_dir / "helpers.py"
helpers_file.write_text("""def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
path_in_msg = worktree_path_regex.search(msg)
if path_in_msg:
last_part_of_path = path_in_msg.group(0).split("/")[-1]
if highlight:
last_part_of_path = f"`{last_part_of_path}`"
return msg.replace(path_in_msg.group(0), last_part_of_path)
return msg


def replace_quotes_with_backticks(text: str) -> str:
# double-quoted strings
text = _double_quote_pat.sub(r"`\1`", text)
# single-quoted strings
return _single_quote_pat.sub(r"`\1`", text)
""")

# Optimized version: every function is rewritten, and
# LspMarkdownMessage.serialize still calls both helpers
optimized_code = """
```python:main.py
from __future__ import annotations

import json
from dataclasses import asdict, dataclass

from codeflash.lsp.helpers import (replace_quotes_with_backticks,
simplify_worktree_paths)


@dataclass
class LspMessage:

def serialize(self) -> str:
# Use local variable to minimize lookup costs and avoid unnecessary dictionary unpacking
data = self._loop_through(asdict(self))
msg_type = self.type()
ordered = {'type': msg_type}
ordered.update(data)
return (
message_delimiter
+ json.dumps(ordered)
+ message_delimiter # \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
)

@dataclass
class LspMarkdownMessage(LspMessage):

def serialize(self) -> str:
# Side effect required, must preserve for behavioral correctness
self.markdown = simplify_worktree_paths(self.markdown)
self.markdown = replace_quotes_with_backticks(self.markdown)
return super().serialize()
```
```python:helpers.py
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str: # noqa: FBT001, FBT002
m = worktree_path_regex.search(msg)
if m:
# More efficient way to get last path part
last_part_of_path = m.group(0).rpartition('/')[-1]
if highlight:
last_part_of_path = f"`{last_part_of_path}`"
return msg.replace(m.group(0), last_part_of_path)
return msg

def replace_quotes_with_backticks(text: str) -> str:
# Efficient string substitution, reduces intermediate string allocations
return _single_quote_pat.sub(
r"`\1`",
_double_quote_pat.sub(r"`\1`", text),
)
```
"""

# Create test config
test_cfg = TestConfig(
tests_root=temp_dir / "tests",
tests_project_rootdir=temp_dir,
project_root_path=temp_dir,
test_framework="pytest",
pytest_cmd="pytest",
)

# Create FunctionToOptimize instance for `serialize` with LspMarkdownMessage
# as its only parent, i.e. the second `serialize` defined in main.py
function_to_optimize = FunctionToOptimize(
file_path=main_file, function_name="serialize", qualified_name="serialize", parents=[
FunctionParent(name="LspMarkdownMessage", type="ClassDef"),
]
)

optimizer = FunctionOptimizer(
function_to_optimize=function_to_optimize,
test_cfg=test_cfg,
function_to_optimize_source_code=main_file.read_text(),
)

ctx_result = optimizer.get_code_optimization_context()
assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

code_context = ctx_result.unwrap()

unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))

# Both helpers are still called by the optimized LspMarkdownMessage.serialize,
# so none may be reported. Matching on the name `serialize` alone would hit
# LspMessage.serialize first, which calls neither helper.
unused_names = {uh.qualified_name for uh in unused_helpers}
assert len(unused_names) == 0 # no unused helpers

finally:
# Cleanup: remove the temporary project directory; errors are ignored
import shutil

shutil.rmtree(temp_dir, ignore_errors=True)
Loading