diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index b2818c695..42c9ead9d 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -1,9 +1,10 @@ from __future__ import annotations import ast +import re from collections import defaultdict from functools import lru_cache -from typing import TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar import libcst as cst diff --git a/codeflash/code_utils/concolic_utils.py b/codeflash/code_utils/concolic_utils.py new file mode 100644 index 000000000..bad02f49e --- /dev/null +++ b/codeflash/code_utils/concolic_utils.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import ast +import re +from typing import Optional + + +class AssertCleanup: + def transform_asserts(self, code: str) -> str: + lines = code.splitlines() + result_lines = [] + + for line in lines: + transformed = self._transform_assert_line(line) + result_lines.append(transformed if transformed is not None else line) + + return "\n".join(result_lines) + + def _transform_assert_line(self, line: str) -> Optional[str]: + indent = line[: len(line) - len(line.lstrip())] + + assert_match = self.assert_re.match(line) + if assert_match: + expression = assert_match.group(1).strip() + if expression.startswith("not "): + return f"{indent}{expression}" + + expression = expression.rstrip(",;") + return f"{indent}{expression}" + + unittest_match = self.unittest_re.match(line) + if unittest_match: + indent, assert_method, args = unittest_match.groups() + + if args: + arg_parts = self._split_top_level_args(args) + if arg_parts and arg_parts[0]: + return f"{indent}{arg_parts[0]}" + + return None + + def _split_top_level_args(self, args_str: str) -> list[str]: + result = [] + current = [] + depth = 0 + + for char in args_str: + if char in "([{": + depth += 1 + current.append(char) + elif char in ")]}": + depth -= 1 + current.append(char) + elif char == "," and depth == 0: + result.append("".join(current).strip()) + current = [] + else: + current.append(char) + + if current: + result.append("".join(current).strip()) + + return result + + def __init__(self): + # Pre-compiling regular expressions for faster execution + self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$") + self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$") + + +def clean_concolic_tests(test_suite_code: str) -> str: + try: + can_parse = True + tree = ast.parse(test_suite_code) + except SyntaxError: + can_parse = False + + if not can_parse: + return AssertCleanup().transform_asserts(test_suite_code) + + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"): + new_body = [] + for stmt in node.body: + if isinstance(stmt, ast.Assert): + if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call): + new_body.append(ast.Expr(value=stmt.test.left)) + else: + new_body.append(stmt) + + else: + new_body.append(stmt) + node.body = new_body + + return ast.unparse(tree).strip() diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 11eebab16..6e4dd0571 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -1,6 +1,7 @@ from __future__ import annotations import ast +import json import os import random import warnings @@ -156,9 +157,9 @@ def get_functions_to_optimize( project_root: Path, module_root: Path, ) -> tuple[dict[Path, list[FunctionToOptimize]], int]: - assert ( - sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1 - ), "Only one of optimize_all, replay_test, or file should be provided" + assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, ( + "Only one of optimize_all, replay_test, or file should be provided" + ) functions: dict[str, list[FunctionToOptimize]] with warnings.catch_warnings(): warnings.simplefilter(action="ignore", category=SyntaxWarning) @@ -434,9 +435,7 @@ def filter_functions( test_functions_removed_count += len(functions) continue if file_path in ignore_paths or any( - # file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths if ignore_path - file_path.startswith(str(ignore_path) + os.sep) - for ignore_path in ignore_paths + file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths ): ignore_paths_removed_count += 1 continue @@ -457,15 +456,17 @@ def filter_functions( malformed_paths_count += 1 continue if blocklist_funcs: - for function in functions.copy(): - path = Path(function.file_path).name - if path in blocklist_funcs and function.function_name in blocklist_funcs[path]: - functions.remove(function) - logger.debug(f"Skipping {function.function_name} in {path} as it has already been optimized") - continue - + functions = [ + function + for function in functions + if not ( + function.file_path.name in blocklist_funcs + and function.qualified_name in blocklist_funcs[function.file_path.name] + ) + ] filtered_modified_functions[file_path] = functions functions_count += len(functions) + if not disable_logs: log_info = { f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count, @@ -475,10 +476,11 @@ def filter_functions( f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count, f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count, } - log_string: str - if log_string := "\n".join([k for k, v in log_info.items() if v > 0]): + log_string = "\n".join([k for k, v in log_info.items() if v > 0]) + if log_string: logger.info(f"Ignoring: {log_string}") console.rule() + return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count diff --git a/codeflash/verification/concolic_testing.py b/codeflash/verification/concolic_testing.py index 1746c296d..c8e032ede 100644 --- a/codeflash/verification/concolic_testing.py +++ b/codeflash/verification/concolic_testing.py @@ -7,6 +7,7 @@ from pathlib import Path from codeflash.cli_cmds.console import console, logger +from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE from codeflash.code_utils.static_analysis import has_typed_parameters from codeflash.discovery.discover_unit_tests import discover_unit_tests @@ -21,7 +22,11 @@ def generate_concolic_tests( ) -> tuple[dict[str, list[FunctionCalledInTest]], str]: function_to_concolic_tests = {} concolic_test_suite_code = "" - if test_cfg.concolic_test_root_dir and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents): + if ( + test_cfg.concolic_test_root_dir + and isinstance(function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef)) + and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents) + ): logger.info("Generating concolic opcode coverage tests for the original code…") console.rule() try: @@ -54,7 +59,8 @@ def generate_concolic_tests( return function_to_concolic_tests, concolic_test_suite_code if cover_result.returncode == 0: - concolic_test_suite_code: str = cover_result.stdout + generated_concolic_test: str = cover_result.stdout + concolic_test_suite_code: str = clean_concolic_tests(generated_concolic_test) concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir)) concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py" concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8") diff --git a/tests/test_code_utils.py b/tests/test_code_utils.py index 3a7c48dab..85719f4f9 100644 --- a/tests/test_code_utils.py +++ b/tests/test_code_utils.py @@ -18,6 +18,7 @@ module_name_from_file_path, path_belongs_to_site_packages, ) +from codeflash.code_utils.concolic_utils import clean_concolic_tests from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files @@ -378,3 +379,65 @@ def test_prepare_coverage_files(mock_get_run_tmp_file: MagicMock) -> None: assert coverage_database_file == mock_coverage_file assert coveragercfile == mock_coveragerc_file mock_coveragerc_file.write_text.assert_called_once_with(f"[run]\n branch = True\ndata_file={mock_coverage_file}\n") + + +def test_clean_concolic_tests() -> None: + original_code = """ +def test_add_numbers(x: int, y: int) -> None: + assert add_numbers(1, 2) == 3 + + +def test_concatenate_strings(s1: str, s2: str) -> None: + assert concatenate_strings("hello", "world") == "helloworld" + + +def test_append_to_list(my_list: list[int], element: int) -> None: + assert append_to_list([1, 2, 3], 4) == [1, 2, 3, 4] + + +def test_get_dict_value(my_dict: dict[str, int], key: str) -> None: + assert get_dict_value({"a": 1, "b": 2}, "a") == 1 + + +def test_union_sets(set1: set[int], set2: set[int]) -> None: + assert union_sets({1, 2, 3}, {3, 4, 5}) == {1, 2, 3, 4, 5} + +def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None: + assert calculate_tuple_sum((1, 2, 3)) == 6 +""" + + cleaned_code = clean_concolic_tests(original_code) + expected_cleaned_code = """ +def test_add_numbers(x: int, y: int) -> None: + add_numbers(1, 2) + +def test_concatenate_strings(s1: str, s2: str) -> None: + concatenate_strings('hello', 'world') + +def test_append_to_list(my_list: list[int], element: int) -> None: + append_to_list([1, 2, 3], 4) + +def test_get_dict_value(my_dict: dict[str, int], key: str) -> None: + get_dict_value({'a': 1, 'b': 2}, 'a') + +def test_union_sets(set1: set[int], set2: set[int]) -> None: + union_sets({1, 2, 3}, {3, 4, 5}) + +def test_calculate_tuple_sum(my_tuple: tuple[int, int, int]) -> None: + calculate_tuple_sum((1, 2, 3)) +""" + assert cleaned_code == expected_cleaned_code.strip() + + concolic_generated_repr_code = """from src.blib2to3.pgen2.grammar import Grammar + +def test_Grammar_copy(): + assert Grammar.copy(Grammar()) == +""" + cleaned_code = clean_concolic_tests(concolic_generated_repr_code) + expected_cleaned_code = """ +from src.blib2to3.pgen2.grammar import Grammar + +def test_Grammar_copy(): + Grammar.copy(Grammar()) +""" + assert cleaned_code == expected_cleaned_code.strip()