diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index e90498936..624102b73 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -72,6 +72,33 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]: return True +def find_insertion_index_after_imports(node: cst.Module) -> int: + """Find the position of the last import statement in the top-level of the module.""" + insert_index = 0 + for i, stmt in enumerate(node.body): + is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any( + isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body + ) + + is_conditional_import = isinstance(stmt, cst.If) and all( + isinstance(inner, cst.SimpleStatementLine) + and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body) + for inner in stmt.body.body + ) + + if is_top_level_import or is_conditional_import: + insert_index = i + 1 + + # Stop scanning once we reach a class or function definition. + # Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file. + # Without this check, a stray import later in the file + # would incorrectly shift our insertion index below actual code definitions. + if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)): + break + + return insert_index + + class GlobalAssignmentTransformer(cst.CSTTransformer): """Transforms global assignments in the original file with those from the new file.""" @@ -122,32 +149,6 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c return updated_node - def _find_insertion_index(self, updated_node: cst.Module) -> int: - """Find the position of the last import statement in the top-level of the module.""" - insert_index = 0 - for i, stmt in enumerate(updated_node.body): - is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any( - isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body - ) - - is_conditional_import = isinstance(stmt, cst.If) and all( - isinstance(inner, cst.SimpleStatementLine) - and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body) - for inner in stmt.body.body - ) - - if is_top_level_import or is_conditional_import: - insert_index = i + 1 - - # Stop scanning once we reach a class or function definition. - # Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file. - # Without this check, a stray import later in the file - # would incorrectly shift our insertion index below actual code definitions. - if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)): - break - - return insert_index - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 # Add any new assignments that weren't in the original file new_statements = list(updated_node.body) @@ -161,7 +162,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c if assignments_to_append: # after last top-level imports - insert_index = self._find_insertion_index(updated_node) + insert_index = find_insertion_index_after_imports(updated_node) assignment_lines = [ cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()]) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index b3ee6e34e..47aa2e75f 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -3,13 +3,18 @@ import ast from collections import defaultdict from functools import lru_cache +from itertools import chain from typing import TYPE_CHECKING, Optional, TypeVar import libcst as cst from libcst.metadata import PositionProvider from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module +from codeflash.code_utils.code_extractor import ( + add_global_assignments, + add_needed_imports_from_module, + find_insertion_index_after_imports, +) from codeflash.code_utils.config_parser import find_conftest_files from codeflash.code_utils.formatter import sort_imports from codeflash.code_utils.line_profile_utils import ImportAdder @@ -249,6 +254,7 @@ def __init__( ] = {} # keys are (class_name, function_name) self.new_functions: list[cst.FunctionDef] = [] self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list) + self.new_classes: list[cst.ClassDef] = [] self.current_class = None self.modified_init_functions: dict[str, cst.FunctionDef] = {} @@ -271,6 +277,10 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool: self.current_class = node.name.value parents = (FunctionParent(name=node.name.value, type="ClassDef"),) + + if (node.name.value, ()) not in self.preexisting_objects: + self.new_classes.append(node) + for child_node in node.body.body: if ( self.preexisting_objects @@ -290,6 +300,7 @@ class OptimFunctionReplacer(cst.CSTTransformer): def __init__( self, modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None, + new_classes: Optional[list[cst.ClassDef]] = None, new_functions: Optional[list[cst.FunctionDef]] = None, new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None, modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None, @@ -297,6 +308,7 @@ def __init__( super().__init__() self.modified_functions = modified_functions if modified_functions is not None else {} self.new_functions = new_functions if new_functions is not None else [] + self.new_classes = new_classes if new_classes is not None else [] self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list) self.modified_init_functions: dict[str, cst.FunctionDef] = ( modified_init_functions if modified_init_functions is not None else {} @@ -335,19 +347,33 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 node = updated_node max_function_index = None - class_index = None + max_class_index = None for index, _node in enumerate(node.body): if isinstance(_node, cst.FunctionDef): max_function_index = index if isinstance(_node, cst.ClassDef): - class_index = index + max_class_index = index + + if self.new_classes: + existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)} + + unique_classes = [ + new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names + ] + if unique_classes: + new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node) + new_body = list( + chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:]) + ) + node = node.with_changes(body=new_body) + if max_function_index is not None: node = node.with_changes( body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :]) ) - elif class_index is not None: + elif max_class_index is not None: node = node.with_changes( - body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :]) + body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :]) ) else: node = node.with_changes(body=(*self.new_functions, *node.body)) @@ -373,18 +399,20 @@ def replace_functions_in_file( parsed_function_names.append((class_name, function_name)) # Collect functions we want to modify from the optimized code - module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code)) + optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code)) + original_module = cst.parse_module(source_code) + visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names)) - module.visit(visitor) + optimized_module.visit(visitor) # Replace these functions in the original code transformer = OptimFunctionReplacer( modified_functions=visitor.modified_functions, + new_classes=visitor.new_classes, new_functions=visitor.new_functions, new_class_functions=visitor.new_class_functions, modified_init_functions=visitor.modified_init_functions, ) - original_module = cst.parse_module(source_code) modified_tree = original_module.visit(transformer) return modified_tree.code diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 86e5f989d..04d83f13f 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -215,7 +215,7 @@ def new_function(self, value: cst.Name): return other_function(self.name) def new_function2(value): return value - """ +""" original_code = """import libcst as cst from typing import Mandatory @@ -230,19 +230,28 @@ def other_function(st): print("Salut monde") """ - expected = """from typing import Mandatory + expected = """import libcst as cst +from typing import Mandatory + +class NewClass: + def __init__(self, name): + self.name = name + def new_function(self, value: cst.Name): + return other_function(self.name) + def new_function2(value): + return value print("Au revoir") def yet_another_function(values): return len(values) -def other_function(st): - return(st * 2) - def totally_new_function(value): return value +def other_function(st): + return(st * 2) + print("Salut monde") """ @@ -279,7 +288,7 @@ def new_function(self, value): return other_function(self.name) def new_function2(value): return value - """ +""" original_code = """import libcst as cst from typing import Mandatory @@ -296,17 +305,25 @@ def other_function(st): """ expected = """from typing import Mandatory +class NewClass: + def __init__(self, name): + self.name = name + def new_function(self, value): + return other_function(self.name) + def new_function2(value): + return value + print("Au revoir") def yet_another_function(values): return len(values) + 2 -def other_function(st): - return(st * 2) - def totally_new_function(value): return value +def other_function(st): + return(st * 2) + print("Salut monde") """ @@ -3619,4 +3636,110 @@ async def task(): await asyncio.sleep(1) return "done" ''' - assert is_zero_diff(original_code, optimized_code) \ No newline at end of file + assert is_zero_diff(original_code, optimized_code) + + + +def test_code_replacement_with_new_helper_class() -> None: + optim_code = """from __future__ import annotations + +import itertools +import re +from dataclasses import dataclass +from typing import Any, Callable, Iterator, Sequence + +from bokeh.models import HoverTool, Plot, Tool + + +# Move the Item dataclass to module-level to avoid redefining it on every function call +@dataclass(frozen=True) +class _RepeatedToolItem: + obj: Tool + properties: dict[str, Any] + +def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]: + key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__ + # Pre-collect properties for all objects by group to avoid repeated calls + for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key): + grouped = list(group) + n = len(grouped) + if n > 1: + # Precompute all properties once for this group + props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped] + i = 0 + while i < len(props) - 1: + head = props[i] + for j in range(i+1, len(props)): + item = props[j] + if item.properties == head.properties: + yield item.obj + i += 1 +""" + + original_code = """from __future__ import annotations +import itertools +import re +from bokeh.models import HoverTool, Plot, Tool +from dataclasses import dataclass +from typing import Any, Callable, Iterator, Sequence + +def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]: + @dataclass(frozen=True) + class Item: + obj: Tool + properties: dict[str, Any] + + key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__ + + for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key): + rest = [ Item(obj, obj.properties_with_values()) for obj in group ] + while len(rest) > 1: + head, *rest = rest + for item in rest: + if item.properties == head.properties: + yield item.obj +""" + + expected = """from __future__ import annotations +import itertools +from bokeh.models import Tool +from dataclasses import dataclass +from typing import Any, Callable, Iterator + + +# Move the Item dataclass to module-level to avoid redefining it on every function call +@dataclass(frozen=True) +class _RepeatedToolItem: + obj: Tool + properties: dict[str, Any] + +def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]: + key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__ + # Pre-collect properties for all objects by group to avoid repeated calls + for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key): + grouped = list(group) + n = len(grouped) + if n > 1: + # Precompute all properties once for this group + props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped] + i = 0 + while i < len(props) - 1: + head = props[i] + for j in range(i+1, len(props)): + item = props[j] + if item.properties == head.properties: + yield item.obj + i += 1 +""" + + function_names: list[str] = ["_collect_repeated_tools"] + preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) + new_code: str = replace_functions_and_add_imports( + source_code=original_code, + function_names=function_names, + optimized_code=optim_code, + module_abspath=Path(__file__).resolve(), + preexisting_objects=preexisting_objects, + project_root_path=Path(__file__).resolve().parent.resolve(), + ) + assert new_code == expected