Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def file_name_from_test_module_name(test_module_name: str, base_dir: Path) -> Pa
def get_imports_from_file(
file_path: Path | None = None, file_string: str | None = None, file_ast: ast.AST | None = None
) -> list[ast.Import | ast.ImportFrom]:
assert (
sum([file_path is not None, file_string is not None, file_ast is not None]) == 1
), "Must provide exactly one of file_path, file_string, or file_ast"
assert sum([file_path is not None, file_string is not None, file_ast is not None]) == 1, (
"Must provide exactly one of file_path, file_string, or file_ast"
)
if file_path:
with file_path.open(encoding="utf8") as file:
file_string = file.read()
Expand Down Expand Up @@ -107,6 +107,14 @@ def validate_python_code(code: str) -> str:
return code


def has_any_async_functions(code: str) -> bool:
try:
module = ast.parse(code)
except SyntaxError:
return False
return any(isinstance(node, ast.AsyncFunctionDef) for node in ast.walk(module))


def cleanup_paths(paths: list[Path]) -> None:
for path in paths:
path.unlink(missing_ok=True)
5 changes: 3 additions & 2 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
file_name_from_test_module_name,
get_run_tmp_file,
module_name_from_file_path,
has_any_async_functions,
)
from codeflash.code_utils.config_consts import (
INDIVIDUAL_TESTCASE_TIMEOUT,
Expand Down Expand Up @@ -136,8 +137,8 @@ def optimize_function(self) -> Result[BestOptimization, str]:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code

logger.info("Code to be optimized:")
if has_any_async_functions(code_context.code_to_optimize_with_helpers):
return Failure("Codeflash does not support async functions in the code to optimize.")
code_print(code_context.read_writable_code)

for module_abspath, helper_code_source in original_helper_code.items():
Expand Down
11 changes: 6 additions & 5 deletions codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def create_function_optimizer(
function_to_optimize_ast=function_to_optimize_ast,
aiservice_client=self.aiservice_client,
args=self.args,

)

def run(self) -> None:
Expand Down Expand Up @@ -140,6 +139,7 @@ def run(self) -> None:
validated_original_code[analysis.file_path] = ValidCode(
source_code=callee_original_code, normalized_code=normalized_callee_original_code
)

if has_syntax_error:
continue

Expand All @@ -149,7 +149,7 @@ def run(self) -> None:
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
f"{function_to_optimize.qualified_name}"
)

console.rule()
if not (
function_to_optimize_ast := get_first_top_level_function_or_method_ast(
function_to_optimize.function_name, function_to_optimize.parents, original_module_ast
Expand All @@ -160,9 +160,11 @@ def run(self) -> None:
f"Skipping optimization."
)
continue

function_optimizer = self.create_function_optimizer(
function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code
function_to_optimize,
function_to_optimize_ast,
function_to_tests,
validated_original_code[original_module_path].source_code,
)
best_optimization = function_optimizer.optimize_function()
if is_successful(best_optimization):
Expand Down Expand Up @@ -192,7 +194,6 @@ def run(self) -> None:
get_run_tmp_file.tmpdir.cleanup()



def run_with_args(args: Namespace) -> None:
optimizer = Optimizer(args)
optimizer.run()
25 changes: 25 additions & 0 deletions tests/test_code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
is_class_defined_in_file,
module_name_from_file_path,
path_belongs_to_site_packages,
has_any_async_functions,
)
from codeflash.code_utils.concolic_utils import clean_concolic_tests
from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
Expand Down Expand Up @@ -441,3 +442,27 @@ def test_Grammar_copy():
Grammar.copy(Grammar())
"""
assert cleaned_code == expected_cleaned_code.strip()


def test_has_any_async_functions_with_async_code() -> None:
code = """
def normal_function():
pass

async def async_function():
pass
"""
result = has_any_async_functions(code)
assert result is True


def test_has_any_async_functions_without_async_code() -> None:
code = """
def normal_function():
pass

def another_function():
pass
"""
result = has_any_async_functions(code)
assert result is False
Loading