Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,33 @@ def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
return True


def _is_import_only_suite(suite: cst.BaseSuite) -> bool:
    """Return True when every statement inside ``suite`` is an import statement."""
    if isinstance(suite, cst.SimpleStatementSuite):
        # One-line form (`if TYPE_CHECKING: import x`): the small statements sit directly in the
        # suite, with no SimpleStatementLine wrapper around them.
        return all(isinstance(small, (cst.Import, cst.ImportFrom)) for small in suite.body)
    # Indented form: every line of the block must consist solely of import statements.
    return all(
        isinstance(line, cst.SimpleStatementLine)
        and all(isinstance(small, (cst.Import, cst.ImportFrom)) for small in line.body)
        for line in suite.body
    )


def find_insertion_index_after_imports(node: cst.Module) -> int:
    """Return the index in ``node.body`` just past the last leading top-level import.

    A statement counts as an import when it is a simple statement line containing at
    least one ``import`` / ``from ... import``, or an ``if`` statement whose body holds
    nothing but imports (either as an indented block or in the one-line form).

    Args:
        node: The module whose top-level statements are scanned.

    Returns:
        ``0`` when no import precedes the first class or function definition, otherwise
        the index of the last such import plus one.

    """
    insert_index = 0
    for i, stmt in enumerate(node.body):
        # Stop scanning once we reach a class or function definition.
        # Imports are supposed to be at the top of the file, but they can technically appear
        # anywhere, even at the bottom of the file. Without this check, a stray import later
        # in the file would incorrectly shift the insertion index below actual code definitions.
        if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
            break

        is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
            isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
        )
        is_conditional_import = isinstance(stmt, cst.If) and _is_import_only_suite(stmt.body)

        if is_top_level_import or is_conditional_import:
            insert_index = i + 1

    return insert_index


class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file."""

Expand Down Expand Up @@ -122,32 +149,6 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c

return updated_node

def _find_insertion_index(self, updated_node: cst.Module) -> int:
    """Return the index in ``updated_node.body`` just past the last leading top-level import.

    The scan itself lives in the module-level ``find_insertion_index_after_imports``;
    this method only delegates to it so the two cannot drift apart.

    Args:
        updated_node: The module whose top-level statements are scanned.

    Returns:
        The index at which new top-level statements should be inserted.

    """
    return find_insertion_index_after_imports(updated_node)

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)
Expand All @@ -161,7 +162,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c

if assignments_to_append:
# after last top-level imports
insert_index = self._find_insertion_index(updated_node)
insert_index = find_insertion_index_after_imports(updated_node)

assignment_lines = [
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
Expand Down
44 changes: 36 additions & 8 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import ast
from collections import defaultdict
from functools import lru_cache
from itertools import chain
from typing import TYPE_CHECKING, Optional, TypeVar

import libcst as cst
from libcst.metadata import PositionProvider

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
from codeflash.code_utils.code_extractor import (
add_global_assignments,
add_needed_imports_from_module,
find_insertion_index_after_imports,
)
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.formatter import sort_imports
from codeflash.code_utils.line_profile_utils import ImportAdder
Expand Down Expand Up @@ -249,6 +254,7 @@ def __init__(
] = {} # keys are (class_name, function_name)
self.new_functions: list[cst.FunctionDef] = []
self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list)
self.new_classes: list[cst.ClassDef] = []
self.current_class = None
self.modified_init_functions: dict[str, cst.FunctionDef] = {}

Expand All @@ -271,6 +277,10 @@ def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.current_class = node.name.value

parents = (FunctionParent(name=node.name.value, type="ClassDef"),)

if (node.name.value, ()) not in self.preexisting_objects:
self.new_classes.append(node)

for child_node in node.body.body:
if (
self.preexisting_objects
Expand All @@ -290,13 +300,15 @@ class OptimFunctionReplacer(cst.CSTTransformer):
def __init__(
self,
modified_functions: Optional[dict[tuple[str | None, str], cst.FunctionDef]] = None,
new_classes: Optional[list[cst.ClassDef]] = None,
new_functions: Optional[list[cst.FunctionDef]] = None,
new_class_functions: Optional[dict[str, list[cst.FunctionDef]]] = None,
modified_init_functions: Optional[dict[str, cst.FunctionDef]] = None,
) -> None:
super().__init__()
self.modified_functions = modified_functions if modified_functions is not None else {}
self.new_functions = new_functions if new_functions is not None else []
self.new_classes = new_classes if new_classes is not None else []
self.new_class_functions = new_class_functions if new_class_functions is not None else defaultdict(list)
self.modified_init_functions: dict[str, cst.FunctionDef] = (
modified_init_functions if modified_init_functions is not None else {}
Expand Down Expand Up @@ -335,19 +347,33 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
node = updated_node
max_function_index = None
class_index = None
max_class_index = None
for index, _node in enumerate(node.body):
if isinstance(_node, cst.FunctionDef):
max_function_index = index
if isinstance(_node, cst.ClassDef):
class_index = index
max_class_index = index

if self.new_classes:
existing_class_names = {_node.name.value for _node in node.body if isinstance(_node, cst.ClassDef)}

unique_classes = [
new_class for new_class in self.new_classes if new_class.name.value not in existing_class_names
]
if unique_classes:
new_classes_insertion_idx = max_class_index or find_insertion_index_after_imports(node)
new_body = list(
chain(node.body[:new_classes_insertion_idx], unique_classes, node.body[new_classes_insertion_idx:])
)
node = node.with_changes(body=new_body)

if max_function_index is not None:
node = node.with_changes(
body=(*node.body[: max_function_index + 1], *self.new_functions, *node.body[max_function_index + 1 :])
)
elif class_index is not None:
elif max_class_index is not None:
node = node.with_changes(
body=(*node.body[: class_index + 1], *self.new_functions, *node.body[class_index + 1 :])
body=(*node.body[: max_class_index + 1], *self.new_functions, *node.body[max_class_index + 1 :])
)
else:
node = node.with_changes(body=(*self.new_functions, *node.body))
Expand All @@ -373,18 +399,20 @@ def replace_functions_in_file(
parsed_function_names.append((class_name, function_name))

# Collect functions we want to modify from the optimized code
module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
optimized_module = cst.metadata.MetadataWrapper(cst.parse_module(optimized_code))
original_module = cst.parse_module(source_code)

visitor = OptimFunctionCollector(preexisting_objects, set(parsed_function_names))
module.visit(visitor)
optimized_module.visit(visitor)

# Replace these functions in the original code
transformer = OptimFunctionReplacer(
modified_functions=visitor.modified_functions,
new_classes=visitor.new_classes,
new_functions=visitor.new_functions,
new_class_functions=visitor.new_class_functions,
modified_init_functions=visitor.modified_init_functions,
)
original_module = cst.parse_module(source_code)
modified_tree = original_module.visit(transformer)
return modified_tree.code

Expand Down
143 changes: 133 additions & 10 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def new_function(self, value: cst.Name):
return other_function(self.name)
def new_function2(value):
return value
"""
"""

original_code = """import libcst as cst
from typing import Mandatory
Expand All @@ -230,19 +230,28 @@ def other_function(st):

print("Salut monde")
"""
expected = """from typing import Mandatory
expected = """import libcst as cst
from typing import Mandatory

class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value: cst.Name):
return other_function(self.name)
def new_function2(value):
return value

print("Au revoir")

def yet_another_function(values):
return len(values)

def other_function(st):
return(st * 2)

def totally_new_function(value):
return value

def other_function(st):
return(st * 2)

print("Salut monde")
"""

Expand Down Expand Up @@ -279,7 +288,7 @@ def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value
"""
"""

original_code = """import libcst as cst
from typing import Mandatory
Expand All @@ -296,17 +305,25 @@ def other_function(st):
"""
expected = """from typing import Mandatory

class NewClass:
def __init__(self, name):
self.name = name
def new_function(self, value):
return other_function(self.name)
def new_function2(value):
return value

print("Au revoir")

def yet_another_function(values):
return len(values) + 2

def other_function(st):
return(st * 2)

def totally_new_function(value):
return value

def other_function(st):
return(st * 2)

print("Salut monde")
"""

Expand Down Expand Up @@ -3619,4 +3636,110 @@ async def task():
await asyncio.sleep(1)
return "done"
'''
assert is_zero_diff(original_code, optimized_code)
assert is_zero_diff(original_code, optimized_code)



def test_code_replacement_with_new_helper_class() -> None:
"""Check that a helper class introduced by the optimized code is carried into the result.

The optimized code defines a new module-level dataclass (`_RepeatedToolItem`) that the
original module does not contain, and the rewritten `_collect_repeated_tools` uses it.
"""
optim_code = """from __future__ import annotations

import itertools
import re
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence

from bokeh.models import HoverTool, Plot, Tool


# Move the Item dataclass to module-level to avoid redefining it on every function call
@dataclass(frozen=True)
class _RepeatedToolItem:
obj: Tool
properties: dict[str, Any]

def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
# Pre-collect properties for all objects by group to avoid repeated calls
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
grouped = list(group)
n = len(grouped)
if n > 1:
# Precompute all properties once for this group
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
i = 0
while i < len(props) - 1:
head = props[i]
for j in range(i+1, len(props)):
item = props[j]
if item.properties == head.properties:
yield item.obj
i += 1
"""

# Original module: the only top-level definition is `_collect_repeated_tools`, which
# declares its own `Item` dataclass.
original_code = """from __future__ import annotations
import itertools
import re
from bokeh.models import HoverTool, Plot, Tool
from dataclasses import dataclass
from typing import Any, Callable, Iterator, Sequence

def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
@dataclass(frozen=True)
class Item:
obj: Tool
properties: dict[str, Any]

key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__

for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
rest = [ Item(obj, obj.properties_with_values()) for obj in group ]
while len(rest) > 1:
head, *rest = rest
for item in rest:
if item.properties == head.properties:
yield item.obj
"""

# Expected result: `_RepeatedToolItem` sits between the imports and the replaced function.
# The import lines no longer mention `re`, `HoverTool`, `Plot` or `Sequence`.
expected = """from __future__ import annotations
import itertools
from bokeh.models import Tool
from dataclasses import dataclass
from typing import Any, Callable, Iterator


# Move the Item dataclass to module-level to avoid redefining it on every function call
@dataclass(frozen=True)
class _RepeatedToolItem:
obj: Tool
properties: dict[str, Any]

def _collect_repeated_tools(tool_objs: list[Tool]) -> Iterator[Tool]:
key: Callable[[Tool], str] = lambda obj: obj.__class__.__name__
# Pre-collect properties for all objects by group to avoid repeated calls
for _, group in itertools.groupby(sorted(tool_objs, key=key), key=key):
grouped = list(group)
n = len(grouped)
if n > 1:
# Precompute all properties once for this group
props = [_RepeatedToolItem(obj, obj.properties_with_values()) for obj in grouped]
i = 0
while i < len(props) - 1:
head = props[i]
for j in range(i+1, len(props)):
item = props[j]
if item.properties == head.properties:
yield item.obj
i += 1
"""

# Only `_collect_repeated_tools` is requested for replacement; the pre-existing objects
# are computed from the original source.
function_names: list[str] = ["_collect_repeated_tools"]
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code)
new_code: str = replace_functions_and_add_imports(
source_code=original_code,
function_names=function_names,
optimized_code=optim_code,
module_abspath=Path(__file__).resolve(),
preexisting_objects=preexisting_objects,
project_root_path=Path(__file__).resolve().parent.resolve(),
)
assert new_code == expected
Loading