Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion codeflash/code_utils/coverage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def extract_dependent_function(main_function: str, code_context: CodeOptimizatio
"""Extract the single dependent function from the code context excluding the main function."""
ast_tree = ast.parse(code_context.testgen_context_code)

dependent_functions = {node.name for node in ast_tree.body if isinstance(node, ast.FunctionDef)}
dependent_functions = {
node.name for node in ast_tree.body if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
}

if main_function in dependent_functions:
dependent_functions.discard(main_function)
Expand Down
5 changes: 4 additions & 1 deletion codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,10 @@ def detect_unused_helper_functions(
# Find the optimized entrypoint function
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
if (
isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
and node.name == function_to_optimize.function_name
):
entrypoint_function_ast = node
break

Expand Down
147 changes: 146 additions & 1 deletion tests/test_code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_extractor import add_global_assignments
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector


class HelperClass:
Expand Down Expand Up @@ -2482,3 +2482,148 @@ def test_circular_deps():
assert "import ApiClient" not in new_code, "Error: Circular dependency found"

assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
def test_global_assignment_collector_with_async_function():
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
import libcst as cst

source_code = """
# Global assignment
GLOBAL_VAR = "global_value"
OTHER_GLOBAL = 42

async def async_function():
# This should not be collected (inside async function)
local_var = "local_value"
INNER_ASSIGNMENT = "should_not_be_global"
return local_var

# Another global assignment
ANOTHER_GLOBAL = "another_global"
"""

tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)

# Should collect global assignments but not the ones inside async function
assert len(collector.assignments) == 3
assert "GLOBAL_VAR" in collector.assignments
assert "OTHER_GLOBAL" in collector.assignments
assert "ANOTHER_GLOBAL" in collector.assignments

# Should not collect assignments from inside async function
assert "local_var" not in collector.assignments
assert "INNER_ASSIGNMENT" not in collector.assignments

# Verify assignment order
expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
assert collector.assignment_order == expected_order


def test_global_assignment_collector_nested_async_functions():
"""Test GlobalAssignmentCollector handles nested async functions correctly."""
import libcst as cst

source_code = """
# Global assignment
CONFIG = {"key": "value"}

def sync_function():
# Inside sync function - should not be collected
sync_local = "sync"

async def nested_async():
# Inside nested async function - should not be collected
nested_var = "nested"
return nested_var

return sync_local

async def async_function():
# Inside async function - should not be collected
async_local = "async"

def nested_sync():
# Inside nested function - should not be collected
deeply_nested = "deep"
return deeply_nested

return async_local

# Another global assignment
FINAL_GLOBAL = "final"
"""

tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)

# Should only collect global-level assignments
assert len(collector.assignments) == 2
assert "CONFIG" in collector.assignments
assert "FINAL_GLOBAL" in collector.assignments

# Should not collect any assignments from inside functions
assert "sync_local" not in collector.assignments
assert "nested_var" not in collector.assignments
assert "async_local" not in collector.assignments
assert "deeply_nested" not in collector.assignments


def test_global_assignment_collector_mixed_async_sync_with_classes():
"""Test GlobalAssignmentCollector with async functions, sync functions, and classes."""
import libcst as cst

source_code = """
# Global assignments
GLOBAL_CONSTANT = "constant"

class TestClass:
# Class-level assignment - should not be collected
class_var = "class_value"

def sync_method(self):
# Method assignment - should not be collected
method_var = "method"
return method_var

async def async_method(self):
# Async method assignment - should not be collected
async_method_var = "async_method"
return async_method_var

def sync_function():
# Function assignment - should not be collected
func_var = "function"
return func_var

async def async_function():
# Async function assignment - should not be collected
async_func_var = "async_function"
return async_func_var

# More global assignments
ANOTHER_CONSTANT = 100
FINAL_ASSIGNMENT = {"data": "value"}
"""

tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)

# Should only collect global-level assignments
assert len(collector.assignments) == 3
assert "GLOBAL_CONSTANT" in collector.assignments
assert "ANOTHER_CONSTANT" in collector.assignments
assert "FINAL_ASSIGNMENT" in collector.assignments

# Should not collect assignments from inside any scoped blocks
assert "class_var" not in collector.assignments
assert "method_var" not in collector.assignments
assert "async_method_var" not in collector.assignments
assert "func_var" not in collector.assignments
assert "async_func_var" not in collector.assignments

# Verify correct order
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
assert collector.assignment_order == expected_order
135 changes: 135 additions & 0 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
is_zero_diff,
replace_functions_and_add_imports,
replace_functions_in_file,
OptimFunctionCollector,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
Expand Down Expand Up @@ -3453,3 +3454,137 @@ def hydrate_input_text_actions_with_field_names(
main_file.unlink(missing_ok=True)

assert new_code == expected


# OptimFunctionCollector async function tests
def test_optim_function_collector_with_async_functions():
"""Test OptimFunctionCollector correctly collects async functions."""
import libcst as cst

source_code = """
def sync_function():
return "sync"

async def async_function():
return "async"

class TestClass:
def sync_method(self):
return "sync_method"

async def async_method(self):
return "async_method"
"""

tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
preexisting_objects=None
)
tree.visit(collector)

# Should collect both sync and async functions
assert len(collector.modified_functions) == 4
assert (None, "sync_function") in collector.modified_functions
assert (None, "async_function") in collector.modified_functions
assert ("TestClass", "sync_method") in collector.modified_functions
assert ("TestClass", "async_method") in collector.modified_functions


def test_optim_function_collector_new_async_functions():
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
import libcst as cst

source_code = """
def existing_function():
return "existing"

async def new_async_function():
return "new_async"

def new_sync_function():
return "new_sync"

class ExistingClass:
async def new_class_async_method(self):
return "new_class_async"
"""

# Only existing_function is in preexisting objects
preexisting_objects = {("existing_function", ())}

tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names=set(), # Not looking for specific functions
preexisting_objects=preexisting_objects
)
tree.visit(collector)

# Should identify new functions (both sync and async)
assert len(collector.new_functions) == 2
function_names = [func.name.value for func in collector.new_functions]
assert "new_async_function" in function_names
assert "new_sync_function" in function_names

# Should identify new class methods
assert "ExistingClass" in collector.new_class_functions
assert len(collector.new_class_functions["ExistingClass"]) == 1
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"


def test_optim_function_collector_mixed_scenarios():
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
import libcst as cst

source_code = """
# Global functions
def global_sync():
pass

async def global_async():
pass

class ParentClass:
def __init__(self):
pass

def sync_method(self):
pass

async def async_method(self):
pass

class ChildClass:
async def child_async_method(self):
pass

def child_sync_method(self):
pass
"""

# Looking for specific functions
function_names = {
(None, "global_sync"),
(None, "global_async"),
("ParentClass", "sync_method"),
("ParentClass", "async_method"),
("ChildClass", "child_async_method")
}

tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names=function_names,
preexisting_objects=None
)
tree.visit(collector)

# Should collect all specified functions (mix of sync and async)
assert len(collector.modified_functions) == 5
assert (None, "global_sync") in collector.modified_functions
assert (None, "global_async") in collector.modified_functions
assert ("ParentClass", "sync_method") in collector.modified_functions
assert ("ParentClass", "async_method") in collector.modified_functions
assert ("ChildClass", "child_async_method") in collector.modified_functions

# Should collect __init__ method
assert "ParentClass" in collector.modified_init_functions
Loading
Loading