Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 254 additions & 1 deletion codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import ast
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Dict, Optional, Set

import libcst as cst
import libcst.matchers as m
Expand All @@ -18,6 +18,227 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize

from typing import List, Union

class GlobalAssignmentCollector(cst.CSTVisitor):
"""Collects all global assignment statements."""

def __init__(self):
super().__init__()
self.assignments: Dict[str, cst.Assign] = {}
self.assignment_order: List[str] = []
# Track scope depth to identify global assignments
self.scope_depth = 0
self.if_else_depth = 0

def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
self.scope_depth += 1
return True

def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.scope_depth -= 1

def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
self.scope_depth += 1
return True

def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.scope_depth -= 1

def visit_If(self, node: cst.If) -> Optional[bool]:
self.if_else_depth += 1
return True

def leave_If(self, original_node: cst.If) -> None:
self.if_else_depth -= 1

def visit_Else(self, node: cst.Else) -> Optional[bool]:
# Else blocks are already counted as part of the if statement
return True

def visit_Assign(self, node: cst.Assign) -> Optional[bool]:
# Only process global assignments (not inside functions, classes, etc.)
if self.scope_depth == 0 and self.if_else_depth == 0: # We're at module level
for target in node.targets:
if isinstance(target.target, cst.Name):
name = target.target.value
self.assignments[name] = node
if name not in self.assignment_order:
self.assignment_order.append(name)
return True


class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file."""

def __init__(self, new_assignments: Dict[str, cst.Assign], new_assignment_order: List[str]):
super().__init__()
self.new_assignments = new_assignments
self.new_assignment_order = new_assignment_order
self.processed_assignments: Set[str] = set()
self.scope_depth = 0
self.if_else_depth = 0

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
self.scope_depth += 1

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
self.scope_depth -= 1
return updated_node

def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.scope_depth += 1

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
self.scope_depth -= 1
return updated_node

def visit_If(self, node: cst.If) -> None:
self.if_else_depth += 1

def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If:
self.if_else_depth -= 1
return updated_node

def visit_Else(self, node: cst.Else) -> None:
# Else blocks are already counted as part of the if statement
pass

def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> cst.CSTNode:
if self.scope_depth > 0 or self.if_else_depth > 0:
return updated_node

# Check if this is a global assignment we need to replace
for target in original_node.targets:
if isinstance(target.target, cst.Name):
name = target.target.value
if name in self.new_assignments:
self.processed_assignments.add(name)
return self.new_assignments[name]

return updated_node

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)

# Find assignments to append
assignments_to_append = []
for name in self.new_assignment_order:
if name not in self.processed_assignments and name in self.new_assignments:
assignments_to_append.append(self.new_assignments[name])

if assignments_to_append:
# Add a blank line before appending new assignments if needed
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
new_statements.pop() # Remove the Pass statement but keep the empty line

# Add the new assignments
for assignment in assignments_to_append:
new_statements.append(
cst.SimpleStatementLine(
[assignment],
leading_lines=[cst.EmptyLine()]
)
)

return updated_node.with_changes(body=new_statements)

class GlobalStatementCollector(cst.CSTVisitor):
"""Visitor that collects all global statements (excluding imports and functions/classes)."""

def __init__(self):
super().__init__()
self.global_statements = []
self.in_function_or_class = False

def visit_ClassDef(self, node: cst.ClassDef) -> bool:
# Don't visit inside classes
self.in_function_or_class = True
return False

def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self.in_function_or_class = False

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
# Don't visit inside functions
self.in_function_or_class = True
return False

def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self.in_function_or_class = False

def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
if not self.in_function_or_class:
for statement in node.body:
# Skip imports
if not isinstance(statement, (cst.Import, cst.ImportFrom, cst.Assign)):
self.global_statements.append(node)
break


class LastImportFinder(cst.CSTVisitor):
"""Finds the position of the last import statement in the module."""

def __init__(self):
super().__init__()
self.last_import_line = 0
self.current_line = 0

def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
self.current_line += 1
for statement in node.body:
if isinstance(statement, (cst.Import, cst.ImportFrom)):
self.last_import_line = self.current_line


class ImportInserter(cst.CSTTransformer):
"""Transformer that inserts global statements after the last import."""

def __init__(self, global_statements: List[cst.SimpleStatementLine], last_import_line: int):
super().__init__()
self.global_statements = global_statements
self.last_import_line = last_import_line
self.current_line = 0
self.inserted = False

def leave_SimpleStatementLine(
self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine
) -> cst.Module:
self.current_line += 1

# If we're right after the last import and haven't inserted yet
if self.current_line == self.last_import_line and not self.inserted:
self.inserted = True
return cst.Module(body=[updated_node] + self.global_statements)

return cst.Module(body=[updated_node])

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# If there were no imports, add at the beginning of the module
if self.last_import_line == 0 and not self.inserted:
updated_body = list(updated_node.body)
for stmt in reversed(self.global_statements):
updated_body.insert(0, stmt)
return updated_node.with_changes(body=updated_body)
return updated_node


def extract_global_statements(source_code: str) -> List[cst.SimpleStatementLine]:
"""Extract global statements from source code."""
module = cst.parse_module(source_code)
collector = GlobalStatementCollector()
module.visit(collector)
return collector.global_statements


def find_last_import_line(target_code: str) -> int:
"""Find the line number of the last import statement."""
module = cst.parse_module(target_code)
finder = LastImportFinder()
module.visit(finder)
return finder.last_import_line

class FutureAliasedImportTransformer(cst.CSTTransformer):
def leave_ImportFrom(
Expand All @@ -38,6 +259,38 @@ def delete___future___aliased_imports(module_code: str) -> str:
return cst.parse_module(module_code).visit(FutureAliasedImportTransformer()).code


def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
non_assignment_global_statements = extract_global_statements(src_module_code)

# Find the last import line in target
last_import_line = find_last_import_line(dst_module_code)

# Parse the target code
target_module = cst.parse_module(dst_module_code)

# Create transformer to insert non_assignment_global_statements
transformer = ImportInserter(non_assignment_global_statements, last_import_line)
#
# # Apply transformation
modified_module = target_module.visit(transformer)
dst_module_code = modified_module.code

# Parse the code
original_module = cst.parse_module(dst_module_code)
new_module = cst.parse_module(src_module_code)

# Collect assignments from the new file
new_collector = GlobalAssignmentCollector()
new_module.visit(new_collector)

# Transform the original file
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
transformed_module = original_module.visit(transformer)

dst_module_code = transformed_module.code
return dst_module_code


def add_needed_imports_from_module(
src_module_code: str,
dst_module_code: str,
Expand Down
5 changes: 3 additions & 2 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, add_global_assignments
from codeflash.models.models import FunctionParent

if TYPE_CHECKING:
Expand Down Expand Up @@ -220,7 +220,8 @@ def replace_function_definitions_in_module(
)
if is_zero_diff(source_code, new_code):
return False
module_abspath.write_text(new_code, encoding="utf8")
code_with_global_assignments = add_global_assignments(optimized_code, new_code)
module_abspath.write_text(code_with_global_assignments, encoding="utf8")
return True


Expand Down
37 changes: 20 additions & 17 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,23 +360,26 @@ def get_function_to_optimize_as_function_source(

# Find the name that matches our function
for name in names:
if (
name.type == "function"
and name.full_name
and name.name == function_to_optimize.function_name
and name.full_name.startswith(name.module_name)
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
):
function_source = FunctionSource(
file_path=function_to_optimize.file_path,
qualified_name=function_to_optimize.qualified_name,
fully_qualified_name=name.full_name,
only_function_name=name.name,
source_code=name.get_line_code(),
jedi_definition=name,
)
return function_source

try:
if (
name.type == "function"
and name.full_name
and name.name == function_to_optimize.function_name
and name.full_name.startswith(name.module_name)
and get_qualified_name(name.module_name, name.full_name) == function_to_optimize.qualified_name
):
function_source = FunctionSource(
file_path=function_to_optimize.file_path,
qualified_name=function_to_optimize.qualified_name,
fully_qualified_name=name.full_name,
only_function_name=name.name,
source_code=name.get_line_code(),
jedi_definition=name,
)
return function_source
except Exception as e:
logger.exception(f"Error while getting function source: {e}")
continue
raise ValueError(
f"Could not find function {function_to_optimize.function_name} in {function_to_optimize.file_path}"
)
Expand Down
Loading
Loading