From 5830a70c3c29d4e57eea55a4f8ae536c3c1b4f12 Mon Sep 17 00:00:00 2001
From: ali
Date: Thu, 27 Nov 2025 16:18:55 +0200
Subject: [PATCH 01/35] quick and dirty

---
 codeflash/discovery/functions_to_optimize.py |  42 +++-
 codeflash/models/models.py                   |  29 +++
 codeflash/optimization/function_optimizer.py |  25 +-
 codeflash/verification/equivalence.py        |  85 +++++--
 codeflash/verification/parse_test_output.py  |  42 ++++
 tests/test_codeflash_capture.py              | 233 +++++++++++++++++++
 6 files changed, 416 insertions(+), 40 deletions(-)

diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py
index b8cf895e1..3bee5fbf9 100644
--- a/codeflash/discovery/functions_to_optimize.py
+++ b/codeflash/discovery/functions_to_optimize.py
@@ -306,25 +306,43 @@ def levenshtein_distance(s1: str, s2: str) -> int:
     len1 = len(s1)
     len2 = len(s2)
     # Use a preallocated list instead of creating a new list every iteration
+
+    # Early exit for empty string cases
+    if len1 == 0:
+        return len2
+    if len2 == 0:
+        return len1
+
+    # Convert strings to lists for fast indexed access
+    s1_list = list(s1)
+    s2_list = list(s2)
+
+    # Preallocate and reuse arrays; avoid creating new ones every iteration
     previous = list(range(len1 + 1))
     current = [0] * (len1 + 1)

     for index2 in range(len2):
-        char2 = s2[index2]
+        char2 = s2_list[index2]
         current[0] = index2 + 1
+
+        # Remove redundant intermediate assignments for better cache locality
+        prev = previous
+        curr = current
+        s1_chars = s1_list
+
+        # Use local variables for frequently accessed values
         for index1 in range(len1):
-            char1 = s1[index1]
-            if char1 == char2:
-                current[index1 + 1] = previous[index1]
+            # Unrolling char1 assignment and equality check
+            if s1_chars[index1] == char2:
+                curr[index1 + 1] = prev[index1]
             else:
-                # Fast min calculation without tuple construct
-                a = previous[index1]
-                b = previous[index1 + 1]
-                c = current[index1]
-                min_val = min(b, a)
-                min_val = min(c, min_val)
-                current[index1 + 1] = 1 + min_val
-        # Swap references instead of copying
+                x = prev[index1]
+                y = prev[index1 + 1]
+                z = curr[index1]
+                min_xy = min(x, y)
+                min_xyz = min(z, min_xy)
+                curr[index1 + 1] = 1 + min_xyz
+
+        # Swap references rather than copying data
         previous, current = current, previous

     return previous[len1]

diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 744f76087..647aa2a3d 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -3,6 +3,7 @@
 from collections import Counter, defaultdict
 from typing import TYPE_CHECKING

+import libcst as cst
 from rich.tree import Tree

 from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log
@@ -505,6 +506,31 @@ def id(self) -> str:
             f"{self.function_getting_tested}:{self.iteration_id}"
         )

+    def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
+        for stmt in class_node.body.body:
+            if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name:
+                return stmt
+        return None
+
+    def get_src_code(self, test_path: Path) -> Optional[str]:
+        test_src = test_path.read_text(encoding="utf-8")
+        module_node = cst.parse_module(test_src)
+
+        if self.test_class_name:
+            for stmt in module_node.body:
+                if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
+                    func_node = self.find_func_in_class(stmt, self.test_function_name)
+                    if func_node:
+                        return module_node.code_for_node(func_node).strip()
+            # class not found
+            return None
+
+        # Otherwise, look for a top level function
+        for stmt in module_node.body:
+            if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name:
+                return module_node.code_for_node(stmt).strip()
+        return None
+
     @staticmethod
     def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId:
         components = string_id.split(":")
@@ -549,7 +575,10 @@ class TestResults(BaseModel):  # noqa: PLW1641
     # also we don't support deletion of test results elements - caution is advised
     test_results: list[FunctionTestInvocation] = []
     test_result_idx: dict[str, int] = {}
+    perf_stdout: Optional[str] = None
+    # mapping between test function name and stdout failure message
+    test_failures: Optional[dict[str, str]] = None

     def add(self, function_test_invocation: FunctionTestInvocation) -> None:
         unique_id = function_test_invocation.unique_invocation_loop_id

diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 2eef51f0f..d948da052 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -1752,6 +1752,11 @@ def establish_original_code_baseline(
             )
         )

+    def get_results_not_matched_error(self) -> Failure:
+        logger.info("h4|Test results did not match the test results of the original code ❌")
+        console.rule()
+        return Failure("Test results did not match the test results of the original code.")
+
     def run_optimized_candidate(
         self,
         *,
@@ -1808,13 +1813,25 @@
                 )
             )
             console.rule()
-            if compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results):
+            match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results)
+            if match:
                 logger.info("h3|Test results matched ✅")
                 console.rule()
             else:
-                logger.info("h4|Test results did not match the test results of the original code ❌")
-                console.rule()
-                return Failure("Test results did not match the test results of the original code.")
+                result_unmatched_perc = len(diffs) / len(candidate_behavior_results)
+                if result_unmatched_perc > 0.5:
+                    # if the test unmatched percentage is greater than 50%, we can't fix it
+                    return self.get_results_not_matched_error()
+
+                # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again
+                # self.run_optimized_candidate(
+                #     optimization_candidate_index=optimization_candidate_index,
+                #     baseline_results=baseline_results,
+                #     original_helper_code=original_helper_code,
+                #     file_path_to_helper_classes=file_path_to_helper_classes,
+                # )
+                print(f"should try to fix it, diffs: {diffs}")
+                return self.get_results_not_matched_error()

             logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")

diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index 9d7f5ba2c..614fc66b4 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -1,4 +1,6 @@
 import sys
+from dataclasses import dataclass
+from enum import Enum

 from codeflash.cli_cmds.console import logger
 from codeflash.models.models import TestResults, TestType, VerificationType
@@ -7,21 +9,38 @@
 INCREASED_RECURSION_LIMIT = 5000


-def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> bool:
+class TestDiffScope(Enum):
+    RETURN_VALUE = "return_value"
+    STDOUT = "stdout"
+    TIMED_OUT = "timed_out"
+    DID_PASS = "did_pass"  # noqa: S105
+
+
+@dataclass
+class TestDiff:
+    scope: TestDiffScope
+    test_src_code: str
+    pytest_error: str
+    original_value: any
+    candidate_value: any
+
+
+def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]:
     # This is meant to be only called with test results for the first loop index
     if len(original_results) == 0 or len(candidate_results) == 0:
-        return False  # empty test results are not equal
+        return False, []  # empty test results are not equal
     original_recursion_limit = sys.getrecursionlimit()
     if original_recursion_limit < INCREASED_RECURSION_LIMIT:
         sys.setrecursionlimit(INCREASED_RECURSION_LIMIT)  # Increase recursion limit to avoid RecursionError
     test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union(
         set(candidate_results.get_all_unique_invocation_loop_ids())
     )
-    are_equal: bool = True
+    test_diffs: list[TestDiff] = []
     did_all_timeout: bool = True
     for test_id in test_ids_superset:
         original_test_result = original_results.get_by_unique_invocation_loop_id(test_id)
         cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
+        candidate_pytest_error = candidate_results.test_failures.get(original_test_result.id.test_function_name)
         if cdd_test_result is not None and original_test_result is None:
             continue
         # If helper function instance_state verification is not present, that's ok.
             continue
@@ -32,8 +51,7 @@
         ):
             continue
         if original_test_result is None or cdd_test_result is None:
-            are_equal = False
-            break
+            return False, []
         did_all_timeout = did_all_timeout and original_test_result.timed_out
         if original_test_result.timed_out:
             continue
@@ -43,23 +61,26 @@
             in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
         ):
             superset_obj = True
+        test_src_code = original_test_result.id.get_src_code(original_test_result.file_name)
         if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
-            are_equal = False
+            test_diffs.append(
+                TestDiff(
+                    scope=TestDiffScope.RETURN_VALUE,
+                    test_src_code=test_src_code,
+                    original_value=original_test_result.return_value,
+                    candidate_value=cdd_test_result.return_value,
+                    pytest_error=candidate_pytest_error,
+                )
+            )
+
             try:
-                logger.debug(
-                    "File Name: %s\n"
-                    "Test Type: %s\n"
-                    "Verification Type: %s\n"
-                    "Invocation ID: %s\n"
-                    "Original return value: %s\n"
-                    "Candidate return value: %s\n"
-                    "-------------------",
-                    original_test_result.file_name,
-                    original_test_result.test_type,
-                    original_test_result.verification_type,
-                    original_test_result.id,
-                    original_test_result.return_value,
-                    cdd_test_result.return_value,
+                print(
+                    f"File Name: {original_test_result.file_name}\n"
+                    f"Test Type: {original_test_result.test_type}\n"
+                    f"Verification Type: {original_test_result.verification_type}\n"
+                    f"Invocation ID: {original_test_result.id}\n"
+                    f"Original return value: {original_test_result.return_value}\n"
+                    f"Candidate return value: {cdd_test_result.return_value}\n"
                 )
             except Exception as e:
                 logger.error(e)
@@ -67,7 +88,15 @@
         if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
             original_test_result.stdout, cdd_test_result.stdout
         ):
-            are_equal = False
+            test_diffs.append(
+                TestDiff(
+                    scope=TestDiffScope.STDOUT,
+                    test_src_code=test_src_code,
+                    original_value=original_test_result.stdout,
+                    candidate_value=cdd_test_result.stdout,
+                    pytest_error=candidate_pytest_error,
+                )
+            )
             break

         if original_test_result.test_type in {
@@ -76,9 +105,17 @@
             TestType.CONCOLIC_COVERAGE_TEST,
             TestType.GENERATED_REGRESSION,
             TestType.REPLAY_TEST,
         } and (cdd_test_result.did_pass != original_test_result.did_pass):
-            are_equal = False
+            test_diffs.append(
+                TestDiff(
+                    scope=TestDiffScope.DID_PASS,
+                    test_src_code=test_src_code,
+                    original_value=original_test_result.did_pass,
+                    candidate_value=cdd_test_result.did_pass,
+                    pytest_error=candidate_pytest_error,
+                )
+            )
             break
     sys.setrecursionlimit(original_recursion_limit)
     if did_all_timeout:
-        return False
-    return are_equal
+        return False, test_diffs
+    return len(test_diffs) == 0, test_diffs

diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py
index ef513a0a3..bbcf21adc 100644
--- a/codeflash/verification/parse_test_output.py
+++ b/codeflash/verification/parse_test_output.py
@@ -512,6 +512,43 @@ def merge_test_results(
     return merged_test_results


+def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults:
+    stdout_lines = stdout.splitlines()
+    start_line = -1
+    end_line = -1
+    for i, line in enumerate(stdout_lines):
+        if start_line != -1 and end_line != -1:
+            break
+        if "FAILURES" in line:
+            start_line = i
+        elif "short test summary info" in line:
+            end_line = i
+    if start_line == -1 or end_line == -1:
+        return test_results
+
+    complete_failure_output_lines = stdout_lines[start_line:end_line]  # exclude last summary line
+
+    test_case_to_failure: dict[str, str] = {}
+
+    current_test_case: str | None = None
+    current_failure_lines: list[str] = []
+
+    for line in complete_failure_output_lines:
+        if line.startswith("_______"):
+            if current_test_case:
+                test_case_to_failure[current_test_case] = "".join(current_failure_lines)
+            current_test_case = line.strip("_ ").strip()
+            current_failure_lines = []
+        elif current_test_case:
+            current_failure_lines.append(line + "\n")
+
+    if current_test_case:
+        test_case_to_failure[current_test_case] = "".join(current_failure_lines)
+
+    test_results.test_failures = test_case_to_failure
+    return test_results
+
+
 def parse_test_results(
     test_xml_path: Path,
     test_files: TestFiles,
@@ -572,4 +609,9 @@
             function_name=function_name,
         )
         coverage.log_coverage()
+    try:
+        parse_test_failures_from_stdout(results, run_result.stdout)
+    except Exception as e:
+        logger.exception(e)
+
     return results, coverage if all_args else None

diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py
index c326cecc4..5f1bcb858 100644
--- a/tests/test_codeflash_capture.py
+++ b/tests/test_codeflash_capture.py
@@ -1177,3 +1177,236 @@ def target_function(self):
         fto_file_path.unlink(missing_ok=True)
         helper_path_1.unlink(missing_ok=True)
         helper_path_2.unlink(missing_ok=True)
+
+def test_instrument_codeflash_capture_and_run_tests_2() -> None:
+    # End to end run that instruments code and runs tests. Made to be similar to the code used in optimizer.py
+    test_code = """import math
+import pytest
+from typing import List, Tuple, Optional
+from code_to_optimize.tests.pytest.fto_file import calculate_portfolio_metrics
+
+def test_calculate_portfolio_metrics():
+    # Test case 1: Basic portfolio
+    investments = [
+        ('Stocks', 0.6, 0.12),
+        ('Bonds', 0.3, 0.04),
+        ('Cash', 0.1, 0.01)
+    ]
+
+    result = calculate_portfolio_metrics(investments)
+
+    # Check weighted return calculation
+    expected_return = 0.6*0.12 + 0.3*0.04 + 0.1*0.01
+    assert abs(result['weighted_return'] - expected_return) < 1e-10
+
+    # Check volatility calculation
+    expected_vol = math.sqrt((0.6*0.12)**2 + (0.3*0.04)**2 + (0.1*0.01)**2)
+    assert abs(result['volatility'] - expected_vol) < 1e-10
+
+    # Check Sharpe ratio
+    expected_sharpe = (expected_return - 0.02) / expected_vol
+    assert abs(result['sharpe_ratio'] - expected_sharpe) < 1e-10
+
+    # Check best/worst performers
+    assert result['best_performing'][0] == 'Stocks'
+    assert result['worst_performing'][0] == 'Cash'
+    assert result['total_assets'] == 3
+
+def test_empty_investments():
+    with pytest.raises(ValueError, match="Investments list cannot be empty"):
+        calculate_portfolio_metrics([])
+
+def test_weights_not_sum_to_one():
+    investments = [('Stock', 0.5, 0.1), ('Bond', 0.4, 0.05)]
+    with pytest.raises(ValueError, match="Portfolio weights must sum to 1.0"):
+        calculate_portfolio_metrics(investments)
+
+def test_zero_volatility():
+    investments = [('Cash', 1.0, 0.0)]
+    result = calculate_portfolio_metrics(investments, risk_free_rate=0.0)
+    assert result['sharpe_ratio'] == 0.0
+    assert result['volatility'] == 0.0
+"""
+
+    original_code = """import math
+from typing import List, Tuple, Optional
+
+def calculate_portfolio_metrics(
+    investments: List[Tuple[str, float, float]],
+    risk_free_rate: float = 0.02
+) -> dict:
+    if not investments:
+        raise ValueError("Investments list cannot be empty")
+
+    if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10:
+        raise ValueError("Portfolio weights must sum to 1.0")
+
+    # Calculate weighted return
+    weighted_return = sum(weight * ret for _, weight, ret in investments)
+
+    # Calculate portfolio volatility (simplified)
+    volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments))
+
+    # Calculate Sharpe ratio
+    if volatility == 0:
+        sharpe_ratio = 0.0
+    else:
+        sharpe_ratio = (weighted_return - risk_free_rate) / volatility
+
+    # Find best and worst performing assets
+    best_asset = max(investments, key=lambda x: x[2])
+    worst_asset = min(investments, key=lambda x: x[2])
+
+    return {
+        'weighted_return': round(weighted_return, 6),
+        'volatility': round(volatility, 6),
+        'sharpe_ratio': round(sharpe_ratio, 6),
+        'best_performing': (best_asset[0], round(best_asset[2], 6)),
+        'worst_performing': (worst_asset[0], round(worst_asset[2], 6)),
+        'total_assets': len(investments)
+    }
+"""
+    test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve()
+    test_file_name = "test_multiple_helpers.py"
+
+    fto_file_name = "fto_file.py"
+
+    test_path = test_dir / test_file_name
+    test_path_perf = test_dir / "test_multiple_helpers_perf.py"
+    fto_file_path = test_dir / fto_file_name
+
+    tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/"
+    project_root_path = (Path(__file__).parent / "..").resolve()
+
+    try:
+        with fto_file_path.open("w") as f:
+            f.write(original_code)
+        with test_path.open("w") as f:
+            f.write(test_code)
+
+        fto = FunctionToOptimize("calculate_portfolio_metrics", fto_file_path, parents=[])
+        file_path_to_helper_class = {
+        }
+        instrument_codeflash_capture(fto, file_path_to_helper_class, tests_root)
+        test_env = os.environ.copy()
+        test_env["CODEFLASH_TEST_ITERATION"] = "0"
+        test_env["CODEFLASH_LOOP_INDEX"] = "1"
+
+        test_type = TestType.EXISTING_UNIT_TEST
+        test_config = TestConfig(
+            tests_root=tests_root,
+            tests_project_rootdir=project_root_path,
+            project_root_path=project_root_path,
+            test_framework="pytest",
+            pytest_cmd="pytest",
+        )
+        func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config)
+        func_optimizer.test_files = TestFiles(
+            test_files=[
+                TestFile(
+                    instrumented_behavior_file_path=test_path,
+                    test_type=test_type,
+                    original_file_path=test_path,
+                    benchmarking_file_path=test_path_perf,
+                )
+            ]
+        )
+        # Code in optimizer.py
+        # Instrument codeflash capture
+        candidate_fto_code = Path(fto.file_path).read_text("utf-8")
+        candidate_helper_code = {}
+        for file_path in file_path_to_helper_class:
+            candidate_helper_code[file_path] = Path(file_path).read_text("utf-8")
+        file_path_to_helper_classes = {
+        }
+        instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root)
+
+        test_results, coverage_data = func_optimizer.run_and_parse_tests(
+            testing_type=TestingMode.BEHAVIOR,
+            test_env=test_env,
+            test_files=func_optimizer.test_files,
+            optimization_iteration=0,
+            pytest_min_loops=1,
+            pytest_max_loops=1,
+            testing_time=0.1,
+        )
+
+        # Remove instrumentation
+        FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
+
+        # Now, let's say we optimize the code and make changes.
+        new_fto_code = """import math
+from typing import List, Tuple, Optional
+
+def calculate_portfolio_metrics(
+    investments: List[Tuple[str, float, float]],
+    risk_free_rate: float = 0.02
+) -> dict:
+    if not investments:
+        raise ValueError("Investments list cannot be empty")
+
+    total_weight = sum(w for _, w, _ in investments)
+    if total_weight != 1.0:  # Should use tolerance check
+        raise ValueError("Portfolio weights must sum to 1.0")
+
+    weighted_return = 1.0
+    for _, weight, ret in investments:
+        weighted_return *= (1 + ret) ** weight
+    weighted_return = weighted_return - 1.0  # Convert back from geometric
+
+    returns = [r for _, _, r in investments]
+    mean_return = sum(returns) / len(returns)
+    volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns))
+
+    # BUG 4: Sharpe ratio calculation is correct but uses wrong inputs
+    if volatility == 0:
+        sharpe_ratio = 0.0
+    else:
+        sharpe_ratio = (weighted_return - risk_free_rate) / volatility
+
+    def risk_adjusted_return(return_val, weight):
+        return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val
+
+    best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
+    worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1]))
+
+    return {
+        "weighted_return": round(weighted_return, 6),
+        "volatility": 2,
+        "sharpe_ratio": round(sharpe_ratio, 6),
+        "best_performing": (best_asset[0], round(best_asset[2], 6)),
+        "worst_performing": (worst_asset[0], round(worst_asset[2], 6)),
+        "total_assets": len(investments),
+    }
+"""
+        with fto_file_path.open("w") as f:
+            f.write(new_fto_code)
+        # Instrument codeflash capture
+        candidate_fto_code = Path(fto.file_path).read_text("utf-8")
+        candidate_helper_code = {}
+        for file_path in file_path_to_helper_class:
+            candidate_helper_code[file_path] = Path(file_path).read_text("utf-8")
+        file_path_to_helper_classes = {
+            # Path(helper_path_1): {"HelperClass1"},
+            # Path(helper_path_2): {"HelperClass2", "AnotherHelperClass"},
+        }
+        instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root)
+        modified_test_results, coverage_data = func_optimizer.run_and_parse_tests(
+            testing_type=TestingMode.BEHAVIOR,
+            test_env=test_env,
+            test_files=func_optimizer.test_files,
+            optimization_iteration=0,
+            pytest_min_loops=1,
+            pytest_max_loops=1,
+            testing_time=0.1,
+        )
+        # Remove instrumentation
+        FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
+        matched, diffs = compare_test_results(test_results, modified_test_results)
+        print(diffs)
+
+        assert not matched
+
+    finally:
+        test_path.unlink(missing_ok=True)
+        fto_file_path.unlink(missing_ok=True)
\ No newline at end of file

From 3e0440bf471a20b60b1c48d9862fb78f9902d113 Mon Sep 17 00:00:00 2001
From: ali
Date: Thu, 27 Nov 2025 16:25:33 +0200
Subject: [PATCH 02/35] safer

---
 codeflash/verification/equivalence.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index 614fc66b4..1e46f2ebc 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -40,7 +40,17 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
     for test_id in test_ids_superset:
         original_test_result = original_results.get_by_unique_invocation_loop_id(test_id)
         cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
-        candidate_pytest_error = candidate_results.test_failures.get(original_test_result.id.test_function_name)
+        candidate_test_failures = candidate_results.test_failures
+        # original_test_failures = original_results.test_failures
+        cdd_pytest_error = (
+            candidate_test_failures.get(original_test_result.id.test_function_name, "")
+            if candidate_test_failures
+            else ""
+        )
+        # original_pytest_error = (
+        #     original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else ""
+        # )
+
         if cdd_test_result is not None and original_test_result is None:
             continue
         # If helper function instance_state verification is not present, that's ok.
             continue
@@ -79,7 +89,7 @@
                     test_src_code=test_src_code,
                     original_value=original_test_result.return_value,
                     candidate_value=cdd_test_result.return_value,
-                    pytest_error=candidate_pytest_error,
+                    pytest_error=cdd_pytest_error,
                 )
             )
@@ -94,7 +104,7 @@
                     test_src_code=test_src_code,
                    original_value=original_test_result.stdout,
                     candidate_value=cdd_test_result.stdout,
-                    pytest_error=candidate_pytest_error,
+                    pytest_error=cdd_pytest_error,
                 )
             )
             break
@@ -111,7 +121,7 @@
                     test_src_code=test_src_code,
                     original_value=original_test_result.did_pass,
                     candidate_value=cdd_test_result.did_pass,
-                    pytest_error=candidate_pytest_error,
+                    pytest_error=cdd_pytest_error,
                 )
             )
             break

From eb16cb2a0b1aa3be640b6ce42cb28c5930c3fbeb Mon Sep 17 00:00:00 2001
From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com>
Date: Thu, 27 Nov 2025 14:49:05 +0000
Subject: [PATCH 03/35] Optimize parse_test_failures_from_stdout

The optimized code achieves a **15% speedup** through several targeted
micro-optimizations that reduce computational overhead in the parsing loop:

**Key Optimizations:**

1. **Single-pass boundary search**: Instead of checking both conditions
   (`start_line != -1 and end_line != -1`) on every iteration, the optimized
   version uses `None` values and breaks immediately when both markers are
   found, eliminating redundant condition checks.

2. **Fast-path string matching**: Before calling the expensive
   `.startswith("_______")` method, it first checks if `line[0] == "_"`,
   avoiding the method call for most lines that don't start with underscores.

3. **Method lookup optimization**: Pulls `current_failure_lines.append` into a
   local variable to avoid repeated attribute lookups in the hot loop where
   failure lines are processed.

4. **Memory-efficient list management**: Uses `current_failure_lines.clear()`
   instead of creating new list objects (`current_failure_lines = []`),
   reducing object allocation pressure.

**Performance Impact:**

The optimizations show the most significant gains in large-scale scenarios:
- **Large failure sets**: 14.2% faster with 500 failures, 14.0% faster with 999 failures
- **Large output**: 29.2% faster for single failures with 1000 lines of output
- **Complex scenarios**: 22.3% faster with 50 cases having 10 lines each

**Hot Path Context:**

Based on the function reference, `parse_test_failures_from_stdout` is called
from `parse_test_results`, which appears to be part of a test optimization
pipeline. The function processes pytest stdout to extract failure information,
making it performance-critical when dealing with large test suites or verbose
test outputs. The 15% improvement becomes meaningful when processing hundreds
of test failures in CI/CD environments or during iterative code optimization
workflows.
---
 codeflash/verification/parse_test_output.py | 31 ++++++++++++++-------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py
index bbcf21adc..47ad5738a 100644
--- a/codeflash/verification/parse_test_output.py
+++ b/codeflash/verification/parse_test_output.py
@@ -514,16 +514,17 @@ def merge_test_results(

 def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults:
     stdout_lines = stdout.splitlines()
-    start_line = -1
-    end_line = -1
+    start_line = end_line = None
+
+    # optimize search for start/end by scanning once
     for i, line in enumerate(stdout_lines):
-        if start_line != -1 and end_line != -1:
-            break
-        if "FAILURES" in line:
+        if start_line is None and "FAILURES" in line:
             start_line = i
-        elif "short test summary info" in line:
+        elif start_line is not None and end_line is None and "short test summary info" in line:
             end_line = i
-    if start_line == -1 or end_line == -1:
+            break
+
+    if start_line is None or end_line is None:
         return test_results

     complete_failure_output_lines = stdout_lines[start_line:end_line]  # exclude last summary line
@@ -533,14 +534,24 @@

     current_test_case: str | None = None
     current_failure_lines: list[str] = []

+    # Avoid per-line string concatenation by tracking indices and performing join once per section
+    # Precompute the boundary check value
+    underline_prefix = "_______"
+
+    # Minor: Pull into local variable to avoid attribute lookup inside loop
+    join_nl = "\n".join
+    append = current_failure_lines.append
+
     for line in complete_failure_output_lines:
-        if line.startswith("_______"):
+        # Fast-path: avoid .startswith() unless it can possibly match
+        if line and line[0] == "_" and line.startswith(underline_prefix):
             if current_test_case:
                 test_case_to_failure[current_test_case] = "".join(current_failure_lines)
             current_test_case = line.strip("_ ").strip()
-            current_failure_lines = []
+            # Start new collection
+            current_failure_lines.clear()
         elif current_test_case:
-            current_failure_lines.append(line + "\n")
+            append(line + "\n")

     if current_test_case:
         test_case_to_failure[current_test_case] = "".join(current_failure_lines)

From a7f8816f5ed8f4035b8c427b99737c46c06fa8f4 Mon Sep 17 00:00:00 2001
From: ali
Date: Thu, 27 Nov 2025 19:51:55 +0200
Subject: [PATCH 04/35] fix tests

---
 codeflash/models/models.py                    |  2 ++
 codeflash/verification/equivalence.py         |  8 ++++++-
 tests/test_codeflash_capture.py               | 18 +++++++++-----
 tests/test_comparator.py                      | 21 ++++++++++------
 tests/test_instrument_all_and_run.py          | 24 ++++++++++++-------
 ...t_instrumentation_run_results_aiservice.py | 10 ++++----
 tests/test_pickle_patcher.py                  |  8 +++----
 7 files changed, 61 insertions(+), 30 deletions(-)

diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 647aa2a3d..48ecf396a 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -513,6 +513,8 @@ def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Option
         return None

     def get_src_code(self, test_path: Path) -> Optional[str]:
+        if not test_path.exists():
+            return None
         test_src = test_path.read_text(encoding="utf-8")
         module_node = cst.parse_module(test_src)

diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index 1e46f2ebc..77798d88f 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -1,11 +1,17 @@
+from __future__ import annotations
+
 import sys
 from dataclasses import dataclass
 from enum import Enum
+from typing import TYPE_CHECKING, Optional

 from codeflash.cli_cmds.console import logger
 from codeflash.models.models import TestResults, TestType, VerificationType
 from codeflash.verification.comparator import comparator

+if TYPE_CHECKING:
+    from codeflash.models.models import TestResults
+
 INCREASED_RECURSION_LIMIT = 5000

@@ -19,10 +25,10 @@ class TestDiffScope(Enum):
 @dataclass
 class TestDiff:
     scope: TestDiffScope
-    test_src_code: str
     pytest_error: str
     original_value: any
     candidate_value: any
+    test_src_code: Optional[str] = None

diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py
index 5f1bcb858..4ab52dd80 100644
--- a/tests/test_codeflash_capture.py
+++ b/tests/test_codeflash_capture.py
@@ -502,7 +502,8 @@ def __init__(self, x=2):
             pytest_max_loops=1,
             testing_time=0.1,
         )
-        assert compare_test_results(test_results, test_results2)
+        match, _ = compare_test_results(test_results, test_results2)
+        assert match

     finally:
         test_path.unlink(missing_ok=True)
@@ -626,7 +627,8 @@ def __init__(self, *args, **kwargs):
             testing_time=0.1,
         )

-        assert compare_test_results(test_results, results2)
+        match, _ = compare_test_results(test_results, results2)
+        assert match

     finally:
         test_path.unlink(missing_ok=True)
@@ -754,7 +756,8 @@ def __init__(self, x=2):
             testing_time=0.1,
         )

-        assert compare_test_results(test_results, test_results2)
+        match, _ = compare_test_results(test_results, test_results2)
+        assert match
     finally:
         test_path.unlink(missing_ok=True)
         sample_code_path.unlink(missing_ok=True)
@@ -902,7 +905,8 @@ def another_helper(self):
             testing_time=0.1,
         )

-        assert compare_test_results(test_results, results2)
+        match, _ = compare_test_results(test_results, results2)
+        assert match

     finally:
         test_path.unlink(missing_ok=True)
@@ -1132,7 +1136,8 @@ def target_function(self):
         )
         # Remove instrumentation
         FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
-        assert not compare_test_results(test_results, mutated_test_results)
+        match, _ = compare_test_results(test_results, mutated_test_results)
+        assert not match

         # This fto code stopped using a helper class. it should still pass
         no_helper1_fto_code = """
@@ -1170,7 +1175,8 @@ def target_function(self):
         )
         # Remove instrumentation
         FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path)
-        assert compare_test_results(test_results, no_helper1_test_results)
+        match, _ = compare_test_results(test_results, no_helper1_test_results)
+        assert match

     finally:
         test_path.unlink(missing_ok=True)

diff --git a/tests/test_comparator.py b/tests/test_comparator.py
index 06d178f95..6c2781229 100644
--- a/tests/test_comparator.py
+++ b/tests/test_comparator.py
@@ -1176,7 +1176,8 @@ def test_compare_results_fn():
         )
     )

-    assert compare_test_results(original_results, new_results_1)
+    match, _ = compare_test_results(original_results, new_results_1)
+    assert match

     new_results_2 = TestResults()
     new_results_2.add(
@@ -1199,7 +1200,8 @@ def test_compare_results_fn():
         )
     )

-    assert not compare_test_results(original_results, new_results_2)
+    match, _ = compare_test_results(original_results, new_results_2)
+    assert not match

     new_results_3 = TestResults()
     new_results_3.add(
@@ -1241,7 +1243,8 @@ def test_compare_results_fn():
         )
     )

-    assert compare_test_results(original_results, new_results_3)
+    match, _ = compare_test_results(original_results, new_results_3)
+    assert match

     new_results_4 = TestResults()
     new_results_4.add(
@@ -1264,7 +1267,8 @@ def test_compare_results_fn():
         )
     )

-    assert not compare_test_results(original_results, new_results_4)
+    match, _ = compare_test_results(original_results, new_results_4)
+    assert not match

     new_results_5_baseline = TestResults()
     new_results_5_baseline.add(
@@ -1308,7 +1312,8 @@ def test_compare_results_fn():
         )
     )

-    assert not compare_test_results(new_results_5_baseline, new_results_5_opt)
+    match, _ = compare_test_results(new_results_5_baseline, new_results_5_opt)
+    assert not match

     new_results_6_baseline = TestResults()
     new_results_6_baseline.add(
@@ -1352,9 +1357,11 @@ def test_compare_results_fn():
         )
     )

-    assert not compare_test_results(new_results_6_baseline, new_results_6_opt)
+    match, _ = compare_test_results(new_results_6_baseline, new_results_6_opt)
+    assert not match

-    assert not compare_test_results(TestResults(), TestResults())
+    match, _ = compare_test_results(TestResults(), TestResults())
+    assert not match


 def test_exceptions():

diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py
index ece7d38b0..7bdfa364b 100644
--- a/tests/test_instrument_all_and_run.py
+++ b/tests/test_instrument_all_and_run.py
@@ -223,7 +223,8 @@ def test_sort():
     result: [0, 1, 2, 3, 4, 5]
 """
         assert out_str == results2[0].stdout
-        assert compare_test_results(test_results, results2)
+        match, _ = compare_test_results(test_results, results2)
+        assert match
     finally:
         fto_path.write_text(original_code, "utf-8")
         test_path.unlink(missing_ok=True)
@@ -368,7 +369,8 @@ def test_sort():
         assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],)
         out_str = """codeflash stdout : BubbleSorter.sorter() called\n"""
         assert test_results[1].stdout == out_str
-        assert compare_test_results(test_results, test_results)
+        match, _ = compare_test_results(test_results, test_results)
+        assert match
         assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__"
         assert test_results[2].id.test_function_name == "test_sort"
         assert test_results[2].did_pass
@@ -396,7 +398,8 @@ def test_sort():
             testing_time=0.1,
         )

-        assert compare_test_results(test_results, results2)
+        match, _ = compare_test_results(test_results, results2)
+        assert match

         # Replace with optimized code that mutated instance attribute
         optimized_code = """
@@ -491,7 +494,8 @@ def sorter(self, arr):
         )
         assert new_test_results[3].runtime > 0
         assert new_test_results[3].did_pass
-        assert not compare_test_results(test_results, new_test_results)
+        match, _ = compare_test_results(test_results, new_test_results)
+        assert not match

     finally:
         fto_path.write_text(original_code, "utf-8")
@@ -630,7 +634,8 @@ def test_sort():
         out_str = """codeflash stdout : BubbleSorter.sorter_classmethod() called
 """
         assert test_results[0].stdout == out_str
-        assert compare_test_results(test_results, test_results)
+        match, _ = compare_test_results(test_results, test_results)
+        assert match

         assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_classmethod"
         assert test_results[1].id.iteration_id == "4_0"
@@ -655,7 +660,8 @@ def test_sort():
             testing_time=0.1,
         )

-        assert compare_test_results(test_results, results2)
+        match, _ = compare_test_results(test_results, results2)
+        assert match

     finally:
         fto_path.write_text(original_code, "utf-8")
@@ -794,7 +800,8 @@ def test_sort():
         out_str = """codeflash stdout : BubbleSorter.sorter_staticmethod() called
 """
         assert test_results[0].stdout == out_str
-        assert compare_test_results(test_results, test_results)
+        match, _ = compare_test_results(test_results, test_results)
+        assert match

         assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_staticmethod"
         assert test_results[1].id.iteration_id == "4_0"
@@ -819,7 +826,8 @@ def test_sort():
             testing_time=0.1,
         )

-        assert compare_test_results(test_results, results2)
+        match, _ = compare_test_results(test_results, results2)
+        assert match

     finally:
         fto_path.write_text(original_code, "utf-8")

diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py
index cae2c76f1..03556718d 100644
--- a/tests/test_instrumentation_run_results_aiservice.py
+++ b/tests/test_instrumentation_run_results_aiservice.py
@@ -221,10 +221,10 @@ def sorter(self, arr):
             testing_time=0.1,
         )
         # assert test_results_mutated_attr[0].return_value[1]["self"].x == 1 TODO: add self as input to function
-        assert compare_test_results(
+        match, _ = compare_test_results(
             test_results, test_results_mutated_attr
         )  # Without codeflash capture, the init state was not verified, and the results are verified as correct even with the attribute mutated
-
+        assert match
         assert test_results_mutated_attr[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
     finally:
         fto_path.write_text(original_code, "utf-8")
@@ -403,9 +403,10 @@ def sorter(self, arr):
         assert test_results_mutated_attr[0].return_value[0] == {"x": 1}
         assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO
         assert test_results_mutated_attr[0].stdout == ""
-        assert not compare_test_results(
+        match, _ = compare_test_results(
             test_results, test_results_mutated_attr
         )  # The test should fail because the instance attribute was mutated
+        assert not match
         # Replace with optimized code that did not mutate existing instance attribute, but added a new one
         optimized_code_new_attr = """
 import sys
@@ -457,9 +458,10 @@ def sorter(self, arr):
         assert test_results_new_attr[0].stdout == ""
         # assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input
         # assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input
-        assert compare_test_results(
+        match, _ = compare_test_results(
             test_results, test_results_new_attr
         )  # The test should pass because the instance attribute was not mutated, only a new one was added
+        assert match
     finally:
         fto_path.write_text(original_code, "utf-8")
         test_path.unlink(missing_ok=True)

diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py
index c67883c12..c05384d03 100644
--- a/tests/test_pickle_patcher.py
+++ b/tests/test_pickle_patcher.py
@@ -427,8 +427,8 @@ def bubble_sort_with_unused_socket(data_container):
             testing_time=1.0,
         )
         assert len(optimized_test_results_unused_socket) == 1
-        verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
-        assert verification_result is True
+        match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
+        assert match

         # Remove the previous instrumentation
         replay_test_path.write_text(original_replay_test_code)
@@ -517,8 +517,8 @@ def bubble_sort_with_used_socket(data_container):
         assert test_results_used_socket.test_results[0].did_pass is False

         # Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
-        assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False
-
+        match, _ = compare_test_results(test_results_used_socket, optimized_test_results_used_socket)
+        assert not match
     finally:
         # cleanup
         output_file.unlink(missing_ok=True)

From 4e9f8941d4b9b91a4575271c17ad8070e8f36549 Mon Sep 17 00:00:00 2001
From: ali
Date: Thu, 27 Nov 2025 19:56:02 +0200
Subject: [PATCH 05/35] linting

---
 codeflash/verification/parse_test_output.py | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)

diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py
index 47ad5738a..873f9f341 100644
--- a/codeflash/verification/parse_test_output.py
+++ b/codeflash/verification/parse_test_output.py
@@ -516,7 +516,6 @@ def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> T
     stdout_lines = stdout.splitlines()
     start_line = end_line = None

-    # optimize search for start/end by scanning once
     for i, line in enumerate(stdout_lines):
         if start_line is None and "FAILURES" in line:
             start_line = i
@@ -534,24 +533,16 @@

     current_test_case: str | None = None
     current_failure_lines: list[str] = []

-    # Avoid per-line string concatenation by tracking indices and performing join once per section
-    # Precompute the boundary check value
     underline_prefix = "_______"

-    # Minor: Pull into local variable to avoid attribute lookup inside loop
-    join_nl = "\n".join
-    append = current_failure_lines.append
-
     for line in complete_failure_output_lines:
-        # Fast-path: avoid .startswith() unless it can possibly match
         if line and line[0] == "_" and line.startswith(underline_prefix):
             if current_test_case:
                 test_case_to_failure[current_test_case] = "".join(current_failure_lines)
             current_test_case = line.strip("_ ").strip()
-            # Start new collection
             current_failure_lines.clear()
         elif current_test_case:
-            append(line + "\n")
+            current_failure_lines.append(line + "\n")

     if current_test_case:
         test_case_to_failure[current_test_case] = "".join(current_failure_lines)

From 1c9abaf0ff298f0a772c93c3d9facdd299cdfadb Mon Sep 17 00:00:00 2001
From: ali
Date: Fri, 28 Nov 2025 10:49:07 +0200
Subject: [PATCH 06/35] did it pass?

--- codeflash/optimization/function_optimizer.py | 2 +- codeflash/verification/equivalence.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d948da052..9f3d1cfd0 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1823,6 +1823,7 @@ def run_optimized_candidate( # if the test unmatched percentage is greater than 50%, we can't fix it return self.get_results_not_matched_error() + print(f"should try to fix it, diffs: {diffs}") # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again # self.run_optimized_candidate( # optimization_candidate_index=optimization_candidate_index, @@ -1830,7 +1831,6 @@ def run_optimized_candidate( # original_helper_code=original_helper_code, # file_path_to_helper_classes=file_path_to_helper_classes, # ) - print(f"should try to fix it, diffs: {diffs}") return self.get_results_not_matched_error() logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 77798d88f..d39433416 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -18,8 +18,8 @@ class TestDiffScope(Enum): RETURN_VALUE = "return_value" STDOUT = "stdout" - TIMED_OUT = "timed_out" DID_PASS = "did_pass" # noqa: S105 + TIMED_OUT = "timed_out" @dataclass @@ -28,6 +28,8 @@ class TestDiff: pytest_error: str original_value: any candidate_value: any + original_pass: bool + candidate_pass: bool test_src_code: Optional[str] = None @@ -86,6 +88,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_value=original_test_result.return_value, candidate_value=cdd_test_result.return_value, pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, ) ) @@ -111,6 +115,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_value=original_test_result.stdout, candidate_value=cdd_test_result.stdout, pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, ) ) break @@ -128,6 +134,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_value=original_test_result.did_pass, candidate_value=cdd_test_result.did_pass, pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, ) ) break From 0b2d894ce2ed7242ee7fc76ff676f6eef88ecbee Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 28 Nov 2025 11:03:59 +0200 Subject: [PATCH 07/35] revert test optimization --- codeflash/discovery/functions_to_optimize.py | 42 ++++++-------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 3bee5fbf9..b8cf895e1 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -306,43 +306,25 @@ def levenshtein_distance(s1: str, s2: str) -> int: len1 = len(s1) len2 = len(s2) # Use a preallocated list instead of creating a new list every iteration - - # Early exit for empty string cases - if len1 == 0: - return len2 - if len2 == 0: - return len1 - - # Convert 
strings to lists for fast indexed access - s1_list = list(s1) - s2_list = list(s2) - - # Preallocate and reuse arrays; avoid creating new ones every iteration previous = list(range(len1 + 1)) current = [0] * (len1 + 1) for index2 in range(len2): - char2 = s2_list[index2] + char2 = s2[index2] current[0] = index2 + 1 - - # Remove redundant intermediate assignments for better cache locality - prev = previous - curr = current - s1_chars = s1_list - # Use local variables for frequently accessed values for index1 in range(len1): - # Unrolling char1 assignment and equality check - if s1_chars[index1] == char2: - curr[index1 + 1] = prev[index1] + char1 = s1[index1] + if char1 == char2: + current[index1 + 1] = previous[index1] else: - x = prev[index1] - y = prev[index1 + 1] - z = curr[index1] - min_xy = min(x, y) - min_xyz = min(z, min_xy) - curr[index1 + 1] = 1 + min_xyz - - # Swap references rather than copying data + # Fast min calculation without tuple construct + a = previous[index1] + b = previous[index1 + 1] + c = current[index1] + min_val = min(b, a) + min_val = min(c, min_val) + current[index1 + 1] = 1 + min_val + # Swap references instead of copying previous, current = current, previous return previous[len1] From ecfa89fb344b3654df7b75b3d83bcf7929f7726c Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 28 Nov 2025 11:22:16 +0200 Subject: [PATCH 08/35] cleaner --- codeflash/verification/equivalence.py | 70 ++++++++++++--------------- 1 file changed, 31 insertions(+), 39 deletions(-) diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index d39433416..27f020a9f 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -25,12 +25,14 @@ class TestDiffScope(Enum): @dataclass class TestDiff: scope: TestDiffScope - pytest_error: str original_value: any candidate_value: any original_pass: bool candidate_pass: bool + test_src_code: Optional[str] = None + candidate_pytest_error: Optional[str] = None + original_pytest_error: Optional[str] = None def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: @@ -49,15 +51,15 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) candidate_test_failures = candidate_results.test_failures - # original_test_failures = original_results.test_failures + original_test_failures = original_results.test_failures cdd_pytest_error = ( candidate_test_failures.get(original_test_result.id.test_function_name, "") if candidate_test_failures else "" ) - # original_pytest_error = ( - # original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else "" - # ) + original_pytest_error = ( + original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else "" + ) if cdd_test_result is not None and original_test_result is None: continue @@ -79,22 +81,26 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} ): superset_obj = True + test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) + test_diff = TestDiff( + scope=TestDiffScope.RETURN_VALUE, + original_value=original_test_result.return_value, + candidate_value=cdd_test_result.return_value, + 
test_src_code=test_src_code, + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): - test_diffs.append( - TestDiff( - scope=TestDiffScope.RETURN_VALUE, - test_src_code=test_src_code, - original_value=original_test_result.return_value, - candidate_value=cdd_test_result.return_value, - pytest_error=cdd_pytest_error, - original_pass=original_test_result.did_pass, - candidate_pass=cdd_test_result.did_pass, - ) - ) + test_diff.scope = TestDiffScope.RETURN_VALUE + test_diff.original_value = original_test_result.return_value + test_diff.candidate_value = cdd_test_result.return_value + test_diffs.append(test_diff) try: - print( + logger.debug( f"File Name: {original_test_result.file_name}\n" f"Test Type: {original_test_result.test_type}\n" f"Verification Type: {original_test_result.verification_type}\n" @@ -108,17 +114,10 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR if (original_test_result.stdout and cdd_test_result.stdout) and not comparator( original_test_result.stdout, cdd_test_result.stdout ): - test_diffs.append( - TestDiff( - scope=TestDiffScope.STDOUT, - test_src_code=test_src_code, - original_value=original_test_result.stdout, - candidate_value=cdd_test_result.stdout, - pytest_error=cdd_pytest_error, - original_pass=original_test_result.did_pass, - candidate_pass=cdd_test_result.did_pass, - ) - ) + test_diff.scope = TestDiffScope.STDOUT + test_diff.original_value = original_test_result.stdout + test_diff.candidate_value = cdd_test_result.stdout + test_diffs.append(test_diff) break if original_test_result.test_type in { @@ -127,17 +126,10 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST, } and (cdd_test_result.did_pass != original_test_result.did_pass): - test_diffs.append( - TestDiff( - scope=TestDiffScope.DID_PASS, - test_src_code=test_src_code, - original_value=original_test_result.did_pass, - candidate_value=cdd_test_result.did_pass, - pytest_error=cdd_pytest_error, - original_pass=original_test_result.did_pass, - candidate_pass=cdd_test_result.did_pass, - ) - ) + test_diff.scope = TestDiffScope.DID_PASS + test_diff.original_value = original_test_result.did_pass + test_diff.candidate_value = cdd_test_result.did_pass + test_diffs.append(test_diff) break sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: From 6ea2545aeed0a077f1669375444dfdfffd4f7e25 Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 28 Nov 2025 13:39:00 +0200 Subject: [PATCH 09/35] test: try to fix the candidate and see if the diff is empty --- tests/test_codeflash_capture.py | 69 +++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 4 deletions(-) diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index 4ab52dd80..bd6518b24 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -1392,10 +1392,7 @@ def risk_adjusted_return(return_val, weight): candidate_helper_code = {} for file_path in file_path_to_helper_class: candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") - file_path_to_helper_classes = { - # Path(helper_path_1): {"HelperClass1"}, - # Path(helper_path_2): {"HelperClass2", "AnotherHelperClass"}, - } + file_path_to_helper_classes = {} 
instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) modified_test_results, coverage_data = func_optimizer.run_and_parse_tests( testing_type=TestingMode.BEHAVIOR, @@ -1413,6 +1410,70 @@ def risk_adjusted_return(return_val, weight): assert not matched + new_fixed_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + # Tolerant weight check (matches original) + total_weight = sum(weight for _, weight, _ in investments) + if abs(total_weight - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Same weighted return as original + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Same volatility formula as original + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Same Sharpe ratio logic + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Same best/worst logic (based on return only) + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": round(volatility, 6), + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fixed_code) + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results_2, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results_2) + # now the test should match and no diffs should be found + assert len(diffs) == 0 + assert matched + finally: test_path.unlink(missing_ok=True) fto_file_path.unlink(missing_ok=True) \ No newline at end of file From fe68772bd421659f27ae3b2589394ed594128c33 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 17:11:36 -0500 Subject: [PATCH 10/35] capture all test discrepancies --- codeflash/verification/equivalence.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index d39433416..dc0dda952 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -119,7 +119,6 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR candidate_pass=cdd_test_result.did_pass, ) ) - break if original_test_result.test_type in { TestType.EXISTING_UNIT_TEST, @@ -138,7 +137,6 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR candidate_pass=cdd_test_result.did_pass, ) ) - break 
sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: return False, test_diffs From ed39ec8bb2978fe4ae1200342a7081dc2087b9c2 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 18:28:10 -0500 Subject: [PATCH 11/35] do the repair in main loop --- codeflash/optimization/function_optimizer.py | 53 ++++++++++++++------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 9f3d1cfd0..a13e3058e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING import libcst as cst +import sentry_sdk from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -583,6 +584,7 @@ def determine_best_candidate( baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, ) console.rule() if not is_successful(run_results): @@ -672,21 +674,21 @@ def determine_best_candidate( async_throughput=candidate_result.async_throughput, ) valid_optimizations.append(best_optimization) - # queue corresponding refined optimization for best optimization - if not candidate.optimization_id.endswith("refi"): - future_all_refinements.append( - self.refine_optimizations( - valid_optimizations=[best_optimization], - original_code_baseline=original_code_baseline, - code_context=code_context, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - executor=self.executor, - function_references=function_references, - ) - ) + # # queue corresponding refined optimization for best optimization + # if not candidate.optimization_id.endswith("refi"): + # future_all_refinements.append( + # self.refine_optimizations( + # valid_optimizations=[best_optimization], + # original_code_baseline=original_code_baseline, + # code_context=code_context, + # trace_id=self.function_trace_id[:-4] + exp_type + # if self.experiment_id + # else self.function_trace_id, + # ai_service_client=ai_service_client, + # executor=self.executor, + # function_references=function_references, + # ) + # ) else: # For async functions, prioritize throughput metrics over runtime even for slow candidates is_async = ( @@ -1813,6 +1815,7 @@ def run_optimized_candidate( ) ) console.rule() + # print(type(code_context), type(candidate)) match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) if match: logger.info("h3|Test results matched ✅") @@ -1823,7 +1826,25 @@ def run_optimized_candidate( # if the test unmatched percentage is greater than 50%, we can't fix it return self.get_results_not_matched_error() - print(f"should try to fix it, diffs: {diffs}") + logger.info("running code repair...") + # not sure if all return types will be convertible to string + diff_per_test_fn = {} + for diff in diffs: + try: + diff_per_test_fn.setdefault(diff.test_src_code, []).append( + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}" + ) + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + try: + test_issues = "\n".join( + f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items() + ) + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + 
print(type(diff_per_test_fn), type(test_issues)) # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again # self.run_optimized_candidate( # optimization_candidate_index=optimization_candidate_index, From 142da4c748bf0734a5792d5abe5c1ff22bd5f86d Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 19:45:49 -0500 Subject: [PATCH 12/35] todo write backend endpoint --- codeflash/optimization/function_optimizer.py | 41 ++++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index a13e3058e..2371917b7 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -584,13 +584,34 @@ def determine_best_candidate( baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, - code_context=code_context, ) console.rule() if not is_successful(run_results): optimized_runtimes[candidate.optimization_id] = None is_correct[candidate.optimization_id] = False speedup_ratios[candidate.optimization_id] = None + fail_value = run_results.value + if ( + fail_value != "Test results did not match the test results of the original code." + and len(future_all_refinements) <= 3 + and not candidate.optimization_id.endswith("cdrp") + ): + # # queue corresponding code repair optimization for best optimization + future_all_refinements.append( + self.code_repair_optimizations( + original_source_code=candidate, + modified_source_code=code_context, + original_code_baseline=original_code_baseline, + test_details="test_details", + code_context=code_context, + trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + ai_service_client=ai_service_client, + executor=self.executor, + function_references=function_references, + ) + ) else: candidate_result: OptimizedCandidateResult = run_results.unwrap() best_test_runtime = candidate_result.best_test_runtime @@ -1831,12 +1852,15 @@ def run_optimized_candidate( diff_per_test_fn = {} for diff in diffs: try: - diff_per_test_fn.setdefault(diff.test_src_code, []).append( - f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}" + diff_per_test_fn[diff.test_src_code] = ( + diff_per_test_fn.setdefault(diff.test_src_code, "") + + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}\n" ) + except Exception as e: sentry_sdk.capture_exception(e) logger.exception(e) + return self.get_results_not_matched_error() try: test_issues = "\n".join( f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items() @@ -1844,15 +1868,8 @@ def run_optimized_candidate( except Exception as e: sentry_sdk.capture_exception(e) logger.exception(e) - print(type(diff_per_test_fn), type(test_issues)) - # with the parsed test results diff ask the llm to fix the candidate to match the test results of the original code, and run again - # self.run_optimized_candidate( - # optimization_candidate_index=optimization_candidate_index, - # baseline_results=baseline_results, - # original_helper_code=original_helper_code, - # file_path_to_helper_classes=file_path_to_helper_classes, - # ) - return self.get_results_not_matched_error() + return self.get_results_not_matched_error() + return Failure(test_issues) 
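+            # note: the Failure payload above is the joined per-test diff summary,
+            # so the caller can see exactly which tests mismatched and why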
logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") From 5a7c3563fba0b93fd74e05916b9bff9a62293259 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 19:56:33 -0500 Subject: [PATCH 13/35] need to test now --- codeflash/api/aiservice.py | 55 +++++++++++++++++++- codeflash/models/models.py | 9 ++++ codeflash/optimization/function_optimizer.py | 21 ++++++++ 3 files changed, 84 insertions(+), 1 deletion(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 20c478eb4..eac035a13 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -27,7 +27,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.ExperimentMetadata import ExperimentMetadata - from codeflash.models.models import AIServiceRefinerRequest + from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest from codeflash.result.explanation import Explanation @@ -294,6 +294,59 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [] + def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) -> list[OptimizedCandidate]: + """Optimize the given python code for performance by making a request to the Django endpoint. + + Args: + request: A list of optimization candidate details for refinement + + Returns: + ------- + - List[OptimizationCandidate]: A list of Optimization Candidates. + + """ + payload = [ + { + "optimization_id": opt.optimization_id, + "original_source_code": opt.original_source_code, + "modified_source_code": opt.modified_source_code, + "trace_id": opt.trace_id, + } + for opt in request + ] + # logger.debug(f"Repair {len(request)} optimizations…") + console.rule() + try: + response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating optimization repair: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return [] + + if response.status_code == 200: + refined_optimizations = response.json()["code_repairs"] + logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") + console.rule() + + refinements = self._get_valid_candidates(refined_optimizations) + return [ + OptimizedCandidate( + source_code=c.source_code, + explanation=c.explanation, + optimization_id=c.optimization_id[:-4] + "cdrp", + ) + for c in refinements + ] + + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return [] + def get_new_explanation( # noqa: D417 self, source_code: str, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 48ecf396a..66d952d01 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -48,6 +48,15 @@ class AIServiceRefinerRequest: function_references: str | None = None +@dataclass(frozen=True) +class AIServiceCodeRepairRequest: + optimization_id: str + original_source_code: str + modified_source_code: str + test_details: str + trace_id: str + + # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. 
The full name # of the module is foo.eggs. diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2371917b7..78b6305c9 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -70,6 +70,7 @@ from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( + AIServiceCodeRepairRequest, BestOptimization, CodeOptimizationContext, GeneratedTests, @@ -862,6 +863,26 @@ def refine_optimizations( ] return executor.submit(ai_service_client.optimize_python_code_refinement, request=request) + def code_repair_optimizations( + self, + original_source_code: str, + modified_source_code: str, + test_details: str, + trace_id: str, + ai_service_client: AiServiceClient, + executor: concurrent.futures.ThreadPoolExecutor, + ) -> concurrent.futures.Future: + request = [ + AIServiceCodeRepairRequest( + optimization_id="", + original_source_code=original_source_code, + modified_source_code=modified_source_code, + test_details=test_details, + trace_id=trace_id, + ) + ] + return executor.submit(ai_service_client.optimize_python_code_repair, request=request) + def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str ) -> None: From 5ed5dfcfb542e8e82331f0794e5ea9caef39ab13 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 22:18:09 -0500 Subject: [PATCH 14/35] works, figure out logging --- codeflash/api/aiservice.py | 3 ++- codeflash/optimization/function_optimizer.py | 15 +++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index eac035a13..afb529534 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -310,6 +310,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) "optimization_id": opt.optimization_id, "original_source_code": opt.original_source_code, "modified_source_code": opt.modified_source_code, + "test_details": opt.test_details, "trace_id": opt.trace_id, } for opt in request @@ -325,7 +326,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) if response.status_code == 200: refined_optimizations = response.json()["code_repairs"] - logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") + # logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") console.rule() refinements = self._get_valid_candidates(refined_optimizations) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 78b6305c9..51e788f96 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -593,24 +593,22 @@ def determine_best_candidate( speedup_ratios[candidate.optimization_id] = None fail_value = run_results.value if ( - fail_value != "Test results did not match the test results of the original code." + fail_value.strip() != "Test results did not match the test results of the original code." 
and len(future_all_refinements) <= 3 and not candidate.optimization_id.endswith("cdrp") ): # # queue corresponding code repair optimization for best optimization future_all_refinements.append( self.code_repair_optimizations( - original_source_code=candidate, - modified_source_code=code_context, - original_code_baseline=original_code_baseline, - test_details="test_details", - code_context=code_context, + original_source_code=code_context.read_writable_code.markdown, + modified_source_code=candidate.source_code.markdown, + test_details=fail_value, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, ai_service_client=ai_service_client, executor=self.executor, - function_references=function_references, + optimization_id=candidate.optimization_id, ) ) else: @@ -869,12 +867,13 @@ def code_repair_optimizations( modified_source_code: str, test_details: str, trace_id: str, + optimization_id: str, ai_service_client: AiServiceClient, executor: concurrent.futures.ThreadPoolExecutor, ) -> concurrent.futures.Future: request = [ AIServiceCodeRepairRequest( - optimization_id="", + optimization_id=optimization_id, original_source_code=original_source_code, modified_source_code=modified_source_code, test_details=test_details, From fe33c8244036c5b4c119a31330977c19f3e1ec71 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Sun, 30 Nov 2025 23:53:36 -0500 Subject: [PATCH 15/35] local db logging --- codeflash/optimization/function_optimizer.py | 108 ++++++++++++++++++- 1 file changed, 107 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 51e788f96..b100dd144 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -5,6 +5,7 @@ import os import queue import random +import sqlite3 import subprocess import time import uuid @@ -119,6 +120,61 @@ from codeflash.verification.verification_utils import TestConfig +def log_code_repair_to_db( + code_repair_log_db: Path, + optimization_id: str, + trace_id: str | None = None, + passed: str | None = None, + faster: str | None = None, +) -> None: + """Log code repair data to SQLite database. + + Uses upsert pattern to allow incremental logging with different columns at different places. + Only non-None values will be updated; existing values are preserved. 
+ """ + try: + conn = sqlite3.connect(code_repair_log_db) + cursor = conn.cursor() + + # Build dynamic upsert query based on provided columns + columns = ["optimization_id"] + values = [optimization_id] + update_parts = ["updated_at = CURRENT_TIMESTAMP"] + + if trace_id is not None: + columns.append("trace_id") + values.append(trace_id) + update_parts.append("trace_id = excluded.trace_id") + + if passed is not None: + columns.append("passed") + values.append(passed) + update_parts.append("passed = excluded.passed") + + if faster is not None: + columns.append("faster") + values.append(faster) + update_parts.append("faster = excluded.faster") + + placeholders = ", ".join(["?"] * len(values)) + columns_str = ", ".join(columns) + update_str = ", ".join(update_parts) + + cursor.execute( + f""" + INSERT INTO code_repair_logs_cf ({columns_str}) + VALUES ({placeholders}) + ON CONFLICT(optimization_id) DO UPDATE SET {update_str} + """, # noqa: S608 + values, + ) + conn.commit() + conn.close() + except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + + class CandidateProcessor: """Handles candidate processing using a queue-based approach.""" @@ -249,6 +305,8 @@ def __init__( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 ) self.optimization_review = "" + # SQLite database setup for logging + self.code_repair_log_db = Path(__file__).parent / "code_repair_log_cf.db" def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -389,7 +447,19 @@ def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) - + conn = sqlite3.connect(self.code_repair_log_db) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS code_repair_logs ( + optimization_id TEXT PRIMARY KEY, + trace_id TEXT, + passed TEXT, + faster TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + """) + conn.commit() + conn.close() should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() code_print( @@ -540,6 +610,14 @@ def determine_best_candidate( logger.warning( "force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate." 
) + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="no", + faster="no", # this also may or may not pass + ) console.rule() continue except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: @@ -547,6 +625,14 @@ def determine_best_candidate( self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="no", + faster="no", # this also may or may not pass + ) continue # check if this code has been evaluated before by checking the ast normalized code string normalized_code = normalize_code(candidate.source_code.flat.strip()) @@ -574,6 +660,16 @@ def determine_best_candidate( ): # new candidate has a shorter diff than the previously encountered one ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="yes" if is_correct[candidate.optimization_id] else "no", + faster="yes" + if speedup_ratios[candidate.optimization_id] > 0 + else "no", # this also may or may not pass + ) continue ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, @@ -743,6 +839,16 @@ def determine_best_candidate( if self.args.benchmark and benchmark_tree: console.print(benchmark_tree) console.rule() + if candidate.optimization_id.endswith("cdrp"): + log_code_repair_to_db( + code_repair_log_db=self.code_repair_log_db, + trace_id=self.function_trace_id[:-4] + exp_type, + optimization_id=candidate.optimization_id, + passed="yes" if is_correct[candidate.optimization_id] else "no", + faster="yes" + if speedup_ratios[candidate.optimization_id] > 0 + else "no", # this also may or may not pass + ) except KeyboardInterrupt as e: logger.exception(f"Optimization interrupted: {e}") raise From 83814bee599af0696b20790b0617b2bb0acab0d0 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Mon, 1 Dec 2025 00:09:49 -0500 Subject: [PATCH 16/35] ready to run experiments --- codeflash/optimization/function_optimizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index b100dd144..4c6322fdf 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -306,7 +306,7 @@ def __init__( ) self.optimization_review = "" # SQLite database setup for logging - self.code_repair_log_db = Path(__file__).parent / "code_repair_log_cf.db" + self.code_repair_log_db = Path(__file__).parent / "code_repair_logs_cf.db" def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -450,13 +450,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: conn = sqlite3.connect(self.code_repair_log_db) cursor = conn.cursor() cursor.execute(""" - CREATE TABLE IF NOT EXISTS code_repair_logs ( + 
CREATE TABLE IF NOT EXISTS code_repair_logs_cf ( optimization_id TEXT PRIMARY KEY, trace_id TEXT, passed TEXT, faster TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) """) conn.commit() conn.close() @@ -1980,7 +1981,7 @@ def run_optimized_candidate( try: diff_per_test_fn[diff.test_src_code] = ( diff_per_test_fn.setdefault(diff.test_src_code, "") - + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.pytest_error}\n" + + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.candidate_pytest_error}\n" ) except Exception as e: From 0325444ce4626aea31ccaafe4650ddd1c9c3772e Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Mon, 1 Dec 2025 00:20:09 -0500 Subject: [PATCH 17/35] logging fix --- codeflash/optimization/function_optimizer.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 4c6322fdf..2129da430 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -617,7 +617,7 @@ def determine_best_candidate( trace_id=self.function_trace_id[:-4] + exp_type, optimization_id=candidate.optimization_id, passed="no", - faster="no", # this also may or may not pass + faster="no", ) console.rule() continue @@ -632,7 +632,7 @@ def determine_best_candidate( trace_id=self.function_trace_id[:-4] + exp_type, optimization_id=candidate.optimization_id, passed="no", - faster="no", # this also may or may not pass + faster="no", ) continue # check if this code has been evaluated before by checking the ast normalized code string @@ -668,8 +668,11 @@ def determine_best_candidate( optimization_id=candidate.optimization_id, passed="yes" if is_correct[candidate.optimization_id] else "no", faster="yes" - if speedup_ratios[candidate.optimization_id] > 0 - else "no", # this also may or may not pass + if ( + speedup_ratios[candidate.optimization_id] is not None + and speedup_ratios[candidate.optimization_id] > 0 + ) + else "no", ) continue ast_code_to_id[normalized_code] = { @@ -847,8 +850,11 @@ def determine_best_candidate( optimization_id=candidate.optimization_id, passed="yes" if is_correct[candidate.optimization_id] else "no", faster="yes" - if speedup_ratios[candidate.optimization_id] > 0 - else "no", # this also may or may not pass + if ( + speedup_ratios[candidate.optimization_id] is not None + and speedup_ratios[candidate.optimization_id] > 0 + ) + else "no", ) except KeyboardInterrupt as e: logger.exception(f"Optimization interrupted: {e}") From 9f7ed9030c16f3c4e3bfcfb57e946d6b82419634 Mon Sep 17 00:00:00 2001 From: ali Date: Mon, 1 Dec 2025 16:48:23 +0200 Subject: [PATCH 18/35] handle test class methods for the test diff --- codeflash/models/models.py | 5 +++++ codeflash/verification/equivalence.py | 6 ++++-- tests/test_codeflash_capture.py | 1 - 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 66d952d01..39388630a 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -515,6 +515,11 @@ def id(self) -> str: f"{self.function_getting_tested}:{self.iteration_id}" ) + # TestSuiteClass.test_function_name + def test_fn_qualified_name(self) -> str: + class_prefix = f"{self.test_class_name}." 
if self.test_class_name else "" + return f"{class_prefix}{self.test_function_name}" + def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: for stmt in class_node.body.body: if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name: diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 642da9194..6eff438e4 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -53,12 +53,14 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR candidate_test_failures = candidate_results.test_failures original_test_failures = original_results.test_failures cdd_pytest_error = ( - candidate_test_failures.get(original_test_result.id.test_function_name, "") + candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") if candidate_test_failures else "" ) original_pytest_error = ( - original_test_failures.get(original_test_result.id.test_function_name, "") if original_test_failures else "" + original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if original_test_failures + else "" ) if cdd_test_result is not None and original_test_result is None: diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index bd6518b24..79133bc15 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -1406,7 +1406,6 @@ def risk_adjusted_return(return_val, weight): # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) matched, diffs = compare_test_results(test_results, modified_test_results) - print(diffs) assert not matched From 6060ffbfe82d46864d2d9753fb62dfcd98e291ad Mon Sep 17 00:00:00 2001 From: mohammed ahmed <64513301+mohammedahmed18@users.noreply.github.com> Date: Tue, 2 Dec 2025 13:48:06 +0200 Subject: [PATCH 19/35] codeflash suggestion Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> --- codeflash/models/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 39388630a..e4aa623d8 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -517,8 +517,12 @@ def id(self) -> str: # TestSuiteClass.test_function_name def test_fn_qualified_name(self) -> str: - class_prefix = f"{self.test_class_name}." 
if self.test_class_name else "" - return f"{class_prefix}{self.test_function_name}" + # Use f-string with inline conditional to reduce string concatenation operations + return ( + f"{self.test_class_name}.{self.test_function_name}" + if self.test_class_name + else str(self.test_function_name) + ) def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: for stmt in class_node.body.body: From 1120d6416c378060aaa2077e28bc112925f17df0 Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 2 Dec 2025 13:59:08 +0200 Subject: [PATCH 20/35] safer parsing --- codeflash/verification/parse_test_output.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 873f9f341..25fcf4e63 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -517,16 +517,18 @@ def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> T start_line = end_line = None for i, line in enumerate(stdout_lines): - if start_line is None and "FAILURES" in line: + stripped_line = line.strip() + if start_line is None and stripped_line[0] == "=" and "FAILURES" in stripped_line: start_line = i - elif start_line is not None and end_line is None and "short test summary info" in line: + # exclude last summary line + elif start_line is not None and end_line is None and "short test summary info" in stripped_line: end_line = i break if start_line is None or end_line is None: return test_results - complete_failure_output_lines = stdout_lines[start_line:end_line] # exclude last summary line + complete_failure_output_lines = stdout_lines[start_line:end_line] test_case_to_failure: dict[str, str] = {} From c2e037aa4321f9b6d808a1880acb74be435767ee Mon Sep 17 00:00:00 2001 From: ali Date: Tue, 2 Dec 2025 14:09:00 +0200 Subject: [PATCH 21/35] better parsing for pytest stdout --- codeflash/verification/parse_test_output.py | 68 +++++++++++++-------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 25fcf4e63..f5cdad9d1 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -512,44 +512,58 @@ def merge_test_results( return merged_test_results +FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+") +TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$") + + def parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults: - stdout_lines = stdout.splitlines() - start_line = end_line = None - - for i, line in enumerate(stdout_lines): - stripped_line = line.strip() - if start_line is None and stripped_line[0] == "=" and "FAILURES" in stripped_line: - start_line = i - # exclude last summary line - elif start_line is not None and end_line is None and "short test summary info" in stripped_line: - end_line = i + """Extract individual pytest test failures from stdout grouped by test case qualified name, and add them to the test results.""" + lines = stdout.splitlines() + start = end = None + + for i, line in enumerate(lines): + if FAILURES_HEADER_RE.search(line.strip()): + start = i break - if start_line is None or end_line is None: + if start is None: return test_results - complete_failure_output_lines = stdout_lines[start_line:end_line] + for j in range(start + 1, len(lines)): + stripped = lines[j].strip() + if "short test summary info" in stripped: + end = j 
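+            # pytest's "short test summary info" section is only a recap, so the
+            # detailed failure block ends right before it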
+            break
+        # stop at any new === section === header as well
+        if stripped.startswith("=") and stripped.count("=") > 3:
+            end = j
+            break
+
+    # If there is no clear "end", just grab the rest of the output
+    if end is None:
+        end = len(lines)
 
-    complete_failure_output_lines = stdout_lines[start_line:end_line]
+    failure_block = lines[start:end]
 
-    test_case_to_failure: dict[str, str] = {}
+    failures: dict[str, str] = {}
+    current_name = None
+    current_lines: list[str] = []
 
-    current_test_case: str | None = None
-    current_failure_lines: list[str] = []
+    for line in failure_block:
+        m = TEST_HEADER_RE.match(line.strip())
+        if m:
+            if current_name is not None:
+                failures[current_name] = "".join(current_lines)
 
-    underline_prefix = "_______"
+            current_name = m.group(1)
+            current_lines = []
+        elif current_name:
+            current_lines.append(line + "\n")
 
-    for line in complete_failure_output_lines:
-        if line and line[0] == "_" and line.startswith(underline_prefix):
-            if current_test_case:
-                test_case_to_failure[current_test_case] = "".join(current_failure_lines)
-            current_test_case = line.strip("_ ").strip()
-            current_failure_lines.clear()
-        elif current_test_case:
-            current_failure_lines.append(line + "\n")
+    if current_name:
+        failures[current_name] = "".join(current_lines)
 
-    if current_test_case:
-        test_case_to_failure[current_test_case] = "".join(current_failure_lines)
+    test_results.test_failures = failures
 
-    test_results.test_failures = test_case_to_failure
     return test_results
 
 
From bd1ebf4b76dc149919b2e35877f4adb1274c7287 Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Tue, 2 Dec 2025 22:52:34 -0500
Subject: [PATCH 22/35] temp logging

---
 codeflash/optimization/function_optimizer.py | 84 +++++----------------
 1 file changed, 22 insertions(+), 62 deletions(-)

diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 2129da430..560c54fd1 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -121,58 +121,32 @@
 def log_code_repair_to_db(
-    code_repair_log_db: Path,
-    optimization_id: str,
-    trace_id: str | None = None,
-    passed: str | None = None,
-    faster: str | None = None,
+    code_repair_log_db: Path, optimization_id: str, trace_id: str, passed: str, faster: str
 ) -> None:
-    """Log code repair data to SQLite database.
-
-    Uses upsert pattern to allow incremental logging with different columns at different places.
-    Only non-None values will be updated; existing values are preserved.
- """ + """Log code repair data to SQLite database.""" try: - conn = sqlite3.connect(code_repair_log_db) - cursor = conn.cursor() - - # Build dynamic upsert query based on provided columns - columns = ["optimization_id"] - values = [optimization_id] - update_parts = ["updated_at = CURRENT_TIMESTAMP"] - - if trace_id is not None: - columns.append("trace_id") - values.append(trace_id) - update_parts.append("trace_id = excluded.trace_id") - - if passed is not None: - columns.append("passed") - values.append(passed) - update_parts.append("passed = excluded.passed") - - if faster is not None: - columns.append("faster") - values.append(faster) - update_parts.append("faster = excluded.faster") - - placeholders = ", ".join(["?"] * len(values)) - columns_str = ", ".join(columns) - update_str = ", ".join(update_parts) - - cursor.execute( - f""" - INSERT INTO code_repair_logs_cf ({columns_str}) - VALUES ({placeholders}) - ON CONFLICT(optimization_id) DO UPDATE SET {update_str} - """, # noqa: S608 - values, - ) - conn.commit() - conn.close() + with sqlite3.connect(code_repair_log_db) as conn: + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS code_repair_logs_cf ( + optimization_id TEXT PRIMARY KEY, + trace_id TEXT, + passed TEXT, + faster TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + cursor.execute( + """ + INSERT INTO code_repair_logs_cf (optimization_id, trace_id, passed, faster) + VALUES (?, ?, ?, ?) + """, + (optimization_id, trace_id, passed, faster), + ) + conn.commit() except Exception as e: sentry_sdk.capture_exception(e) - logger.exception(e) + logger.exception("Error logging code repair to db") class CandidateProcessor: @@ -447,20 +421,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) - conn = sqlite3.connect(self.code_repair_log_db) - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS code_repair_logs_cf ( - optimization_id TEXT PRIMARY KEY, - trace_id TEXT, - passed TEXT, - faster TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - conn.commit() - conn.close() should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() code_print( From c1ae81eec9ca585917d2ab99fec8776e6380cd40 Mon Sep 17 00:00:00 2001 From: ali Date: Wed, 3 Dec 2025 19:59:58 +0200 Subject: [PATCH 23/35] working version --- codeflash/api/aiservice.py | 43 +++---- codeflash/models/models.py | 22 +++- codeflash/optimization/function_optimizer.py | 111 ++++++++++--------- codeflash/verification/equivalence.py | 40 ++----- 4 files changed, 103 insertions(+), 113 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index afb529534..88f0bd887 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -294,50 +294,39 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [] - def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) -> list[OptimizedCandidate]: + def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None: """Optimize the given python code for performance by making a request to the Django endpoint. 
Args: - request: A list of optimization candidate details for refinement + request: optimization candidate details for refinement Returns: ------- - - List[OptimizationCandidate]: A list of Optimization Candidates. + - OptimizationCandidate: new fixed candidate. """ - payload = [ - { - "optimization_id": opt.optimization_id, - "original_source_code": opt.original_source_code, - "modified_source_code": opt.modified_source_code, - "test_details": opt.test_details, - "trace_id": opt.trace_id, - } - for opt in request - ] - # logger.debug(f"Repair {len(request)} optimizations…") console.rule() try: - response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) + response = self.make_ai_service_request("/code_repair", payload=request, timeout=120) except requests.exceptions.RequestException as e: logger.exception(f"Error generating optimization repair: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) return [] if response.status_code == 200: - refined_optimizations = response.json()["code_repairs"] - # logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") + refined_optimization = response.json() console.rule() - refinements = self._get_valid_candidates(refined_optimizations) - return [ - OptimizedCandidate( - source_code=c.source_code, - explanation=c.explanation, - optimization_id=c.optimization_id[:-4] + "cdrp", - ) - for c in refinements - ] + refinements = self._get_valid_candidates([refined_optimization]) + if not refinements: + logger.error("Code repair failed to generate a valid candidate.") + return None + + return OptimizedCandidate( + source_code=refinements[0].source_code, + explanation=refinements[0].explanation, + optimization_id=refinements[0].optimization_id[:-4] + "cdrp", + ) try: error = response.json()["error"] @@ -346,7 +335,7 @@ def optimize_python_code_repair(self, request: list[AIServiceCodeRepairRequest]) logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) console.rule() - return [] + return None def get_new_explanation( # noqa: D417 self, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index e4aa623d8..bdee8a3e9 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -48,13 +48,33 @@ class AIServiceRefinerRequest: function_references: str | None = None +class TestDiffScope(str, Enum): + RETURN_VALUE = "return_value" + STDOUT = "stdout" + DID_PASS = "did_pass" # noqa: S105 + TIMED_OUT = "timed_out" + + +@dataclass +class TestDiff: + scope: TestDiffScope + original_pass: bool + candidate_pass: bool + + original_value: str | None = None + candidate_value: str | None = None + test_src_code: Optional[str] = None + candidate_pytest_error: Optional[str] = None + original_pytest_error: Optional[str] = None + + @dataclass(frozen=True) class AIServiceCodeRepairRequest: optimization_id: str original_source_code: str modified_source_code: str - test_details: str trace_id: str + test_diffs: list[TestDiff] # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2129da430..fcab65c72 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -116,6 +116,7 @@ CoverageData, FunctionCalledInTest, FunctionSource, + TestDiff, ) from 
codeflash.verification.verification_utils import TestConfig @@ -685,32 +686,15 @@ def determine_best_candidate( baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, + candidate=candidate, + exp_type=exp_type, ) console.rule() if not is_successful(run_results): optimized_runtimes[candidate.optimization_id] = None is_correct[candidate.optimization_id] = False speedup_ratios[candidate.optimization_id] = None - fail_value = run_results.value - if ( - fail_value.strip() != "Test results did not match the test results of the original code." - and len(future_all_refinements) <= 3 - and not candidate.optimization_id.endswith("cdrp") - ): - # # queue corresponding code repair optimization for best optimization - future_all_refinements.append( - self.code_repair_optimizations( - original_source_code=code_context.read_writable_code.markdown, - modified_source_code=candidate.source_code.markdown, - test_details=fail_value, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - executor=self.executor, - optimization_id=candidate.optimization_id, - ) - ) else: candidate_result: OptimizedCandidateResult = run_results.unwrap() best_test_runtime = candidate_result.best_test_runtime @@ -978,22 +962,19 @@ def code_repair_optimizations( self, original_source_code: str, modified_source_code: str, - test_details: str, + test_diffs: list[TestDiff], trace_id: str, optimization_id: str, ai_service_client: AiServiceClient, - executor: concurrent.futures.ThreadPoolExecutor, - ) -> concurrent.futures.Future: - request = [ - AIServiceCodeRepairRequest( - optimization_id=optimization_id, - original_source_code=original_source_code, - modified_source_code=modified_source_code, - test_details=test_details, - trace_id=trace_id, - ) - ] - return executor.submit(ai_service_client.optimize_python_code_repair, request=request) + ) -> OptimizedCandidate | None: + request = AIServiceCodeRepairRequest( + optimization_id=optimization_id, + original_source_code=original_source_code, + modified_source_code=modified_source_code, + test_diffs=test_diffs, + trace_id=trace_id, + ) + return ai_service_client.optimize_python_code_repair(request=request) def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str @@ -1920,6 +1901,9 @@ def run_optimized_candidate( baseline_results: OriginalCodeBaseline, original_helper_code: dict[Path, str], file_path_to_helper_classes: dict[Path, set[str]], + code_context: CodeOptimizationContext, + candidate: OptimizedCandidate, + exp_type: str, ) -> Result[OptimizedCandidateResult, str]: assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 @@ -1980,29 +1964,50 @@ def run_optimized_candidate( # if the test unmatched percentage is greater than 50%, we can't fix it return self.get_results_not_matched_error() - logger.info("running code repair...") - # not sure if all return types will be convertible to string - diff_per_test_fn = {} - for diff in diffs: - try: - diff_per_test_fn[diff.test_src_code] = ( - diff_per_test_fn.setdefault(diff.test_src_code, "") - + f"Expected Value: {diff.original_value!s}\nActual Value: {diff.candidate_value!s}\nError String:{diff.candidate_pytest_error}\n" - ) + if candidate.optimization_id.endswith("cdrp"): + # prevent looping for now + return 
self.get_results_not_matched_error() + + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + + with progress_bar("The test results are not matching, let me see if I can fix this"): + new_candidate = self.code_repair_optimizations( + original_source_code=code_context.read_writable_code.markdown, + modified_source_code=candidate.source_code.markdown, + test_diffs=diffs, + trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + ai_service_client=ai_service_client, + optimization_id=candidate.optimization_id, + ) + if not new_candidate: + return Failure("Code repair failed to generate a valid candidate.") + + code_print(new_candidate.source_code.flat) - except Exception as e: - sentry_sdk.capture_exception(e) - logger.exception(e) - return self.get_results_not_matched_error() try: - test_issues = "\n".join( - f"{test_fn_def}\n{value}" for test_fn_def, value in diff_per_test_fn.items() + did_update = self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=new_candidate.source_code, + original_helper_code=original_helper_code, ) - except Exception as e: - sentry_sdk.capture_exception(e) - logger.exception(e) - return self.get_results_not_matched_error() - return Failure(test_issues) + if did_update: + return self.run_optimized_candidate( + optimization_candidate_index=optimization_candidate_index, + baseline_results=baseline_results, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, + candidate=new_candidate, + exp_type=exp_type, + ) + except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + logger.error(e) + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + return Failure("Code repair failed to generate a valid candidate.") logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 6eff438e4..2e7e6e4fe 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,12 +1,10 @@ from __future__ import annotations import sys -from dataclasses import dataclass -from enum import Enum -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger -from codeflash.models.models import TestResults, TestType, VerificationType +from codeflash.models.models import TestDiff, TestDiffScope, TestResults, TestType, VerificationType from codeflash.verification.comparator import comparator if TYPE_CHECKING: @@ -15,26 +13,6 @@ INCREASED_RECURSION_LIMIT = 5000 -class TestDiffScope(Enum): - RETURN_VALUE = "return_value" - STDOUT = "stdout" - DID_PASS = "did_pass" # noqa: S105 - TIMED_OUT = "timed_out" - - -@dataclass -class TestDiff: - scope: TestDiffScope - original_value: any - candidate_value: any - original_pass: bool - candidate_pass: bool - - test_src_code: Optional[str] = None - candidate_pytest_error: Optional[str] = None - original_pytest_error: Optional[str] = None - - def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: # This is meant to be only called with test results for the first loop index if len(original_results) == 0 or len(candidate_results) == 0: @@ -87,8 +65,8 @@ def 
compare_test_results(original_results: TestResults, candidate_results: TestR test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) test_diff = TestDiff( scope=TestDiffScope.RETURN_VALUE, - original_value=original_test_result.return_value, - candidate_value=cdd_test_result.return_value, + original_value=f"{original_test_result.return_value!r}", + candidate_value=f"{cdd_test_result.return_value!r}", test_src_code=test_src_code, candidate_pytest_error=cdd_pytest_error, original_pass=original_test_result.did_pass, @@ -97,8 +75,6 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR ) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): test_diff.scope = TestDiffScope.RETURN_VALUE - test_diff.original_value = original_test_result.return_value - test_diff.candidate_value = cdd_test_result.return_value test_diffs.append(test_diff) try: @@ -117,8 +93,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR original_test_result.stdout, cdd_test_result.stdout ): test_diff.scope = TestDiffScope.STDOUT - test_diff.original_value = original_test_result.stdout - test_diff.candidate_value = cdd_test_result.stdout + test_diff.original_value = str(original_test_result.stdout) + test_diff.candidate_value = str(cdd_test_result.stdout) test_diffs.append(test_diff) if original_test_result.test_type in { @@ -128,8 +104,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR TestType.REPLAY_TEST, } and (cdd_test_result.did_pass != original_test_result.did_pass): test_diff.scope = TestDiffScope.DID_PASS - test_diff.original_value = original_test_result.did_pass - test_diff.candidate_value = cdd_test_result.did_pass + test_diff.original_value = str(original_test_result.did_pass) + test_diff.candidate_value = str(cdd_test_result.did_pass) test_diffs.append(test_diff) sys.setrecursionlimit(original_recursion_limit) From 97f2426694847f5a81219f0c0fca6108393c2490 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 4 Dec 2025 10:16:20 +0200 Subject: [PATCH 24/35] fix override candidate after the code repair --- codeflash/models/models.py | 1 + codeflash/optimization/function_optimizer.py | 38 ++++++++++++++------ 2 files changed, 29 insertions(+), 10 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index bdee8a3e9..8bec44cd1 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -302,6 +302,7 @@ class CodeContextType(str, Enum): class OptimizedCandidateResult(BaseModel): + optimized_candidate: OptimizedCandidate max_loop_count: int best_test_runtime: int behavior_test_results: TestResults diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 49a935b95..2955a45c0 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -280,6 +280,7 @@ def __init__( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 ) self.optimization_review = "" + self.ast_code_to_id = {} # SQLite database setup for logging self.code_repair_log_db = Path(__file__).parent / "code_repair_logs_cf.db" @@ -519,7 +520,7 @@ def determine_best_candidate( console.rule() future_all_refinements: list[concurrent.futures.Future] = [] - ast_code_to_id = {} + self.ast_code_to_id.clear() valid_optimizations = [] optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer 
evaluated, instead their shorter/longer versions might be evaluated @@ -598,11 +599,11 @@ def determine_best_candidate( continue # check if this code has been evaluated before by checking the ast normalized code string normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in ast_code_to_id: + if normalized_code in self.ast_code_to_id: logger.info( "Current candidate has been encountered before in testing, Skipping optimization candidate." ) - past_opt_id = ast_code_to_id[normalized_code]["optimization_id"] + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] is_correct[candidate.optimization_id] = is_correct[past_opt_id] @@ -612,16 +613,18 @@ def determine_best_candidate( optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ past_opt_id ] - optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][ + optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ + "shorter_source_code" + ].markdown + optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code][ "shorter_source_code" ].markdown - optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) if ( - new_diff_len < ast_code_to_id[normalized_code]["diff_len"] + new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"] ): # new candidate has a shorter diff than the previously encountered one - ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code - ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len if candidate.optimization_id.endswith("cdrp"): log_code_repair_to_db( code_repair_log_db=self.code_repair_log_db, @@ -636,7 +639,7 @@ def determine_best_candidate( else "no", ) continue - ast_code_to_id[normalized_code] = { + self.ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, "shorter_source_code": candidate.source_code, "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), @@ -657,6 +660,9 @@ def determine_best_candidate( speedup_ratios[candidate.optimization_id] = None else: candidate_result: OptimizedCandidateResult = run_results.unwrap() + # override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair + if candidate.optimization_id != candidate_result.optimized_candidate.optimization_id: + candidate = candidate_result.optimized_candidate best_test_runtime = candidate_result.best_test_runtime optimized_runtimes[candidate.optimization_id] = best_test_runtime is_correct[candidate.optimization_id] = True @@ -821,7 +827,7 @@ def determine_best_candidate( for valid_opt in valid_optimizations: valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip()) new_candidate_with_shorter_code = OptimizedCandidate( - source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], + source_code=self.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], optimization_id=valid_opt.candidate.optimization_id, 
explanation=valid_opt.candidate.explanation, ) @@ -1946,7 +1952,18 @@ def run_optimized_candidate( code_print(new_candidate.source_code.flat) + normalized_code = normalize_code(candidate.source_code.flat.strip()) + self.ast_code_to_id[normalized_code] = { + "optimization_id": candidate.optimization_id, + "shorter_source_code": candidate.source_code, + "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), + } + try: + # revert first to original code then replace with new repaired code, so we don't get any weird behavior + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) did_update = self.replace_function_and_helpers_with_optimized_code( code_context=code_context, optimized_code=new_candidate.source_code, @@ -2048,6 +2065,7 @@ def run_optimized_candidate( ) return Success( OptimizedCandidateResult( + optimized_candidate=candidate, max_loop_count=loop_count, best_test_runtime=total_candidate_timing, behavior_test_results=candidate_behavior_results, From 6a9390c28abd451bae6c9ade3b8c4cfbd58f7df8 Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 4 Dec 2025 10:30:10 +0200 Subject: [PATCH 25/35] typo --- codeflash/optimization/function_optimizer.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 2955a45c0..7678221c9 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1951,12 +1951,11 @@ def run_optimized_candidate( return Failure("Code repair failed to generate a valid candidate.") code_print(new_candidate.source_code.flat) - - normalized_code = normalize_code(candidate.source_code.flat.strip()) + normalized_code = normalize_code(new_candidate.source_code.flat.strip()) self.ast_code_to_id[normalized_code] = { - "optimization_id": candidate.optimization_id, - "shorter_source_code": candidate.source_code, - "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), + "optimization_id": new_candidate.optimization_id, + "shorter_source_code": new_candidate.source_code, + "diff_len": diff_length(new_candidate.source_code.flat, code_context.read_writable_code.flat), } try: From b93fd34c067205ea740ad43d6641beb3105c9eca Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 4 Dec 2025 11:38:05 +0200 Subject: [PATCH 26/35] enhancements and cleanups --- codeflash/api/aiservice.py | 4 +- codeflash/models/models.py | 1 - codeflash/optimization/function_optimizer.py | 195 +++++++------------ 3 files changed, 72 insertions(+), 128 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 88f0bd887..29b3142e8 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -308,10 +308,10 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op console.rule() try: response = self.make_ai_service_request("/code_repair", payload=request, timeout=120) - except requests.exceptions.RequestException as e: + except (requests.exceptions.RequestException, TypeError) as e: logger.exception(f"Error generating optimization repair: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) - return [] + return None if response.status_code == 200: refined_optimization = response.json() diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 8bec44cd1..bdee8a3e9 100644 --- a/codeflash/models/models.py +++ 
b/codeflash/models/models.py @@ -302,7 +302,6 @@ class CodeContextType(str, Enum): class OptimizedCandidateResult(BaseModel): - optimized_candidate: OptimizedCandidate max_loop_count: int best_test_runtime: int behavior_test_results: TestResults diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7678221c9..043914cd1 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -5,7 +5,6 @@ import os import queue import random -import sqlite3 import subprocess import time import uuid @@ -14,7 +13,6 @@ from typing import TYPE_CHECKING import libcst as cst -import sentry_sdk from rich.console import Group from rich.panel import Panel from rich.syntax import Syntax @@ -121,35 +119,6 @@ from codeflash.verification.verification_utils import TestConfig -def log_code_repair_to_db( - code_repair_log_db: Path, optimization_id: str, trace_id: str, passed: str, faster: str -) -> None: - """Log code repair data to SQLite database.""" - try: - with sqlite3.connect(code_repair_log_db) as conn: - cursor = conn.cursor() - cursor.execute(""" - CREATE TABLE IF NOT EXISTS code_repair_logs_cf ( - optimization_id TEXT PRIMARY KEY, - trace_id TEXT, - passed TEXT, - faster TEXT, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - """) - cursor.execute( - """ - INSERT INTO code_repair_logs_cf (optimization_id, trace_id, passed, faster) - VALUES (?, ?, ?, ?) - """, - (optimization_id, trace_id, passed, faster), - ) - conn.commit() - except Exception as e: - sentry_sdk.capture_exception(e) - logger.exception("Error logging code repair to db") - - class CandidateProcessor: """Handles candidate processing using a queue-based approach.""" @@ -281,8 +250,6 @@ def __init__( ) self.optimization_review = "" self.ast_code_to_id = {} - # SQLite database setup for logging - self.code_repair_log_db = Path(__file__).parent / "code_repair_logs_cf.db" def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -494,6 +461,41 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) + def was_candidate_tested_before(self, normalized_code: str) -> bool: + # check if this code has been evaluated before by checking the ast normalized code string + return normalized_code in self.ast_code_to_id + + def update_results_for_duplicate_candidate( + self, + candidate: OptimizedCandidate, + code_context: CodeOptimizationContext, + normalized_code: str, + speedup_ratios: dict, + is_correct: dict, + optimized_runtimes: dict, + optimized_line_profiler_results: dict, + optimizations_post: dict, + ) -> None: + logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.") + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] + # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes + speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] + is_correct[candidate.optimization_id] = is_correct[past_opt_id] + optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] + # line profiler results only available for successful runs + if past_opt_id in optimized_line_profiler_results: + optimized_line_profiler_results[candidate.optimization_id] = 
optimized_line_profiler_results[past_opt_id] + optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ + "shorter_source_code" + ].markdown + optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) + if ( + new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"] + ): # new candidate has a shorter diff than the previously encountered one + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + def determine_best_candidate( self, *, @@ -573,14 +575,6 @@ def determine_best_candidate( logger.warning( "force_lsp|No functions were replaced in the optimized code. Skipping optimization candidate." ) - if candidate.optimization_id.endswith("cdrp"): - log_code_repair_to_db( - code_repair_log_db=self.code_repair_log_db, - trace_id=self.function_trace_id[:-4] + exp_type, - optimization_id=candidate.optimization_id, - passed="no", - faster="no", - ) console.rule() continue except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: @@ -588,63 +582,27 @@ def determine_best_candidate( self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) - if candidate.optimization_id.endswith("cdrp"): - log_code_repair_to_db( - code_repair_log_db=self.code_repair_log_db, - trace_id=self.function_trace_id[:-4] + exp_type, - optimization_id=candidate.optimization_id, - passed="no", - faster="no", - ) continue # check if this code has been evaluated before by checking the ast normalized code string normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in self.ast_code_to_id: - logger.info( - "Current candidate has been encountered before in testing, Skipping optimization candidate." 
+ if self.was_candidate_tested_before(normalized_code): + self.update_results_for_duplicate_candidate( + candidate=candidate, + code_context=code_context, + normalized_code=normalized_code, + speedup_ratios=speedup_ratios, + is_correct=is_correct, + optimized_runtimes=optimized_runtimes, + optimized_line_profiler_results=optimized_line_profiler_results, + optimizations_post=optimizations_post, ) - past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] - # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes - speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] - is_correct[candidate.optimization_id] = is_correct[past_opt_id] - optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] - # line profiler results only available for successful runs - if past_opt_id in optimized_line_profiler_results: - optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ - past_opt_id - ] - optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ - "shorter_source_code" - ].markdown - optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code][ - "shorter_source_code" - ].markdown - new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) - if ( - new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"] - ): # new candidate has a shorter diff than the previously encountered one - self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code - self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len - if candidate.optimization_id.endswith("cdrp"): - log_code_repair_to_db( - code_repair_log_db=self.code_repair_log_db, - trace_id=self.function_trace_id[:-4] + exp_type, - optimization_id=candidate.optimization_id, - passed="yes" if is_correct[candidate.optimization_id] else "no", - faster="yes" - if ( - speedup_ratios[candidate.optimization_id] is not None - and speedup_ratios[candidate.optimization_id] > 0 - ) - else "no", - ) continue self.ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, "shorter_source_code": candidate.source_code, "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), } - run_results = self.run_optimized_candidate( + run_results, new_candidate = self.run_optimized_candidate( optimization_candidate_index=candidate_index, baseline_results=original_code_baseline, original_helper_code=original_helper_code, @@ -653,6 +611,10 @@ def determine_best_candidate( candidate=candidate, exp_type=exp_type, ) + if candidate.optimization_id != new_candidate.optimization_id: + # override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair + candidate = new_candidate + console.rule() if not is_successful(run_results): optimized_runtimes[candidate.optimization_id] = None @@ -660,9 +622,6 @@ def determine_best_candidate( speedup_ratios[candidate.optimization_id] = None else: candidate_result: OptimizedCandidateResult = run_results.unwrap() - # override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair - if candidate.optimization_id != candidate_result.optimized_candidate.optimization_id: - candidate = candidate_result.optimized_candidate best_test_runtime = candidate_result.best_test_runtime optimized_runtimes[candidate.optimization_id] = best_test_runtime 
is_correct[candidate.optimization_id] = True @@ -745,20 +704,20 @@ def determine_best_candidate( ) valid_optimizations.append(best_optimization) # # queue corresponding refined optimization for best optimization - # if not candidate.optimization_id.endswith("refi"): - # future_all_refinements.append( - # self.refine_optimizations( - # valid_optimizations=[best_optimization], - # original_code_baseline=original_code_baseline, - # code_context=code_context, - # trace_id=self.function_trace_id[:-4] + exp_type - # if self.experiment_id - # else self.function_trace_id, - # ai_service_client=ai_service_client, - # executor=self.executor, - # function_references=function_references, - # ) - # ) + if not candidate.optimization_id.endswith("refi"): + future_all_refinements.append( + self.refine_optimizations( + valid_optimizations=[best_optimization], + original_code_baseline=original_code_baseline, + code_context=code_context, + trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + ai_service_client=ai_service_client, + executor=self.executor, + function_references=function_references, + ) + ) else: # For async functions, prioritize throughput metrics over runtime even for slow candidates is_async = ( @@ -793,19 +752,6 @@ def determine_best_candidate( if self.args.benchmark and benchmark_tree: console.print(benchmark_tree) console.rule() - if candidate.optimization_id.endswith("cdrp"): - log_code_repair_to_db( - code_repair_log_db=self.code_repair_log_db, - trace_id=self.function_trace_id[:-4] + exp_type, - optimization_id=candidate.optimization_id, - passed="yes" if is_correct[candidate.optimization_id] else "no", - faster="yes" - if ( - speedup_ratios[candidate.optimization_id] is not None - and speedup_ratios[candidate.optimization_id] > 0 - ) - else "no", - ) except KeyboardInterrupt as e: logger.exception(f"Optimization interrupted: {e}") raise @@ -1870,7 +1816,7 @@ def run_optimized_candidate( code_context: CodeOptimizationContext, candidate: OptimizedCandidate, exp_type: str, - ) -> Result[OptimizedCandidateResult, str]: + ) -> tuple[Result[OptimizedCandidateResult, str], OptimizedCandidate]: assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 with progress_bar("Testing optimization candidate"): @@ -1928,11 +1874,11 @@ def run_optimized_candidate( result_unmatched_perc = len(diffs) / len(candidate_behavior_results) if result_unmatched_perc > 0.5: # if the test unmatched percentage is greater than 50%, we can't fix it - return self.get_results_not_matched_error() + return self.get_results_not_matched_error(), candidate if candidate.optimization_id.endswith("cdrp"): # prevent looping for now - return self.get_results_not_matched_error() + return self.get_results_not_matched_error(), candidate ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client @@ -1948,7 +1894,7 @@ def run_optimized_candidate( optimization_id=candidate.optimization_id, ) if not new_candidate: - return Failure("Code repair failed to generate a valid candidate.") + return Failure("Code repair failed to generate a valid candidate."), candidate code_print(new_candidate.source_code.flat) normalized_code = normalize_code(new_candidate.source_code.flat.strip()) @@ -1983,7 +1929,7 @@ def run_optimized_candidate( self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) - return Failure("Code repair failed to generate a valid 
candidate.") + return Failure("Code repair failed to generate a valid candidate."), candidate logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") @@ -2064,7 +2010,6 @@ def run_optimized_candidate( ) return Success( OptimizedCandidateResult( - optimized_candidate=candidate, max_loop_count=loop_count, best_test_runtime=total_candidate_timing, behavior_test_results=candidate_behavior_results, @@ -2076,7 +2021,7 @@ def run_optimized_candidate( total_candidate_timing=total_candidate_timing, async_throughput=candidate_async_throughput, ) - ) + ), candidate def run_and_parse_tests( self, From bcc19f76009d28d112d417c648521a038b0c081b Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 4 Dec 2025 12:32:17 +0200 Subject: [PATCH 27/35] handle repaired code is exact same as the original code --- codeflash/optimization/function_optimizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 043914cd1..0f84a671e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1806,7 +1806,7 @@ def get_results_not_matched_error(self) -> Failure: console.rule() return Failure("Test results did not match the test results of the original code.") - def run_optimized_candidate( + def run_optimized_candidate( # noqa: PLR0911 self, *, optimization_candidate_index: int, @@ -1924,6 +1924,9 @@ def run_optimized_candidate( candidate=new_candidate, exp_type=exp_type, ) + msg = "No functions were replaced in the optimized code. Skipping optimization candidate." + logger.warning(f"force_lsp|{msg}") + return Failure(msg), candidate except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: logger.error(e) self.write_code_and_helpers( From 79387c37cc73066b36ad8259a25bd199c55e657c Mon Sep 17 00:00:00 2001 From: ali Date: Thu, 4 Dec 2025 13:09:27 +0200 Subject: [PATCH 28/35] linting issue and handle file name in code_print for repaired candidate --- codeflash/api/aiservice.py | 9 ++++++++- codeflash/optimization/function_optimizer.py | 8 ++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 29b3142e8..44c286ddb 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -307,7 +307,14 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op """ console.rule() try: - response = self.make_ai_service_request("/code_repair", payload=request, timeout=120) + payload = { + "optimization_id": request.optimization_id, + "original_source_code": request.original_source_code, + "modified_source_code": request.modified_source_code, + "trace_id": request.trace_id, + "test_diffs": request.test_diffs, + } + response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) except (requests.exceptions.RequestException, TypeError) as e: logger.exception(f"Error generating optimization repair: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0f84a671e..fce94f7ee 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -703,7 +703,7 @@ def determine_best_candidate( async_throughput=candidate_result.async_throughput, ) valid_optimizations.append(best_optimization) - # # queue corresponding refined optimization 
From 79387c37cc73066b36ad8259a25bd199c55e657c Mon Sep 17 00:00:00 2001
From: ali
Date: Thu, 4 Dec 2025 13:09:27 +0200
Subject: [PATCH 28/35] fix linting issue and handle file name in code_print
 for repaired candidate

---
 codeflash/api/aiservice.py                   | 9 ++++++++-
 codeflash/optimization/function_optimizer.py | 8 ++++++--
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index 29b3142e8..44c286ddb 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -307,7 +307,14 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
         """
         console.rule()
         try:
-            response = self.make_ai_service_request("/code_repair", payload=request, timeout=120)
+            payload = {
+                "optimization_id": request.optimization_id,
+                "original_source_code": request.original_source_code,
+                "modified_source_code": request.modified_source_code,
+                "trace_id": request.trace_id,
+                "test_diffs": request.test_diffs,
+            }
+            response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120)
         except (requests.exceptions.RequestException, TypeError) as e:
             logger.exception(f"Error generating optimization repair: {e}")
             ph("cli-optimize-error-caught", {"error": str(e)})
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 0f84a671e..fce94f7ee 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -703,7 +703,7 @@ def determine_best_candidate(
                     valid_optimizations.append(best_optimization)
-                    # # queue corresponding refined optimization for best optimization
+                    # queue corresponding refined optimization for best optimization
                     if not candidate.optimization_id.endswith("refi"):
                         future_all_refinements.append(
                             self.refine_optimizations(
@@ -1896,7 +1896,11 @@ def run_optimized_candidate(  # noqa: PLR0911
                 if not new_candidate:
                     return Failure("Code repair failed to generate a valid candidate."), candidate
 
-                code_print(new_candidate.source_code.flat)
+                code_print(
+                    new_candidate.source_code.flat,
+                    file_name=f"candidate_{optimization_candidate_index}.py",
+                    function_name=self.function_to_optimize.function_name,
+                )
                 normalized_code = normalize_code(new_candidate.source_code.flat.strip())
                 self.ast_code_to_id[normalized_code] = {
                     "optimization_id": new_candidate.optimization_id,
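The hand-built payload above exists because the request object is a dataclass, which requests cannot serialize directly (hence the TypeError added to the except clause in patch 26). A hedged alternative sketch using dataclasses.asdict, which also converts nested dataclasses such as the TestDiff entries (field names here mirror AIServiceCodeRepairRequest; enum-valued fields like TestDiff.scope would still need a custom encoder):

import dataclasses
import json

@dataclasses.dataclass
class RepairRequest:  # stand-in for AIServiceCodeRepairRequest
    optimization_id: str
    original_source_code: str
    modified_source_code: str
    trace_id: str
    test_diffs: list

request = RepairRequest("opt-1", "def f(): ...", "def f(): ...", "trace-1", [])
payload = dataclasses.asdict(request)  # nested dataclasses become plain dicts
body = json.dumps(payload)             # what the HTTP layer would actually send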
From 4c13bb9626fa0c448bdb5c8786365bdace79fcd3 Mon Sep 17 00:00:00 2001
From: ali
Date: Thu, 4 Dec 2025 15:48:42 +0200
Subject: [PATCH 29/35] fixes

---
 codeflash/models/models.py                   |  1 -
 codeflash/optimization/function_optimizer.py | 16 +++++++++++++---
 codeflash/verification/equivalence.py        |  9 ++++-----
 3 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index bdee8a3e9..52f8ba4c8 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -52,7 +52,6 @@ class TestDiffScope(str, Enum):
     RETURN_VALUE = "return_value"
     STDOUT = "stdout"
     DID_PASS = "did_pass"  # noqa: S105
-    TIMED_OUT = "timed_out"
 
 
 @dataclass
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index fce94f7ee..c054a51ff 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -461,6 +461,13 @@ def optimize_function(self) -> Result[BestOptimization, str]:
         return Success(best_optimization)
 
+    def reset_optimization_metrics_for_candidate(
+        self, opt_id: str, speedup_ratios: dict, is_correct: dict, optimized_runtimes: dict
+    ) -> None:
+        speedup_ratios[opt_id] = None
+        is_correct[opt_id] = False
+        optimized_runtimes[opt_id] = None
+
     def was_candidate_tested_before(self, normalized_code: str) -> bool:
         # check if this code has been evaluated before by checking the ast normalized code string
         return normalized_code in self.ast_code_to_id
@@ -602,6 +609,9 @@ def determine_best_candidate(
                     "shorter_source_code": candidate.source_code,
                     "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
                 }
+                self.reset_optimization_metrics_for_candidate(
+                    candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes
+                )
                 run_results, new_candidate = self.run_optimized_candidate(
                     optimization_candidate_index=candidate_index,
                     baseline_results=original_code_baseline,
@@ -617,9 +627,9 @@ def determine_best_candidate(
                 console.rule()
 
                 if not is_successful(run_results):
-                    optimized_runtimes[candidate.optimization_id] = None
-                    is_correct[candidate.optimization_id] = False
-                    speedup_ratios[candidate.optimization_id] = None
+                    self.reset_optimization_metrics_for_candidate(
+                        candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes
+                    )
                 else:
                     candidate_result: OptimizedCandidateResult = run_results.unwrap()
                     best_test_runtime = candidate_result.best_test_runtime
diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index 2e7e6e4fe..545616bae 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -65,8 +65,8 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
             test_src_code = original_test_result.id.get_src_code(original_test_result.file_name)
             test_diff = TestDiff(
                 scope=TestDiffScope.RETURN_VALUE,
-                original_value=f"{original_test_result.return_value!r}",
-                candidate_value=f"{cdd_test_result.return_value!r}",
+                original_value=repr(original_test_result.return_value),
+                candidate_value=repr(cdd_test_result.return_value),
                 test_src_code=test_src_code,
                 candidate_pytest_error=cdd_pytest_error,
                 original_pass=original_test_result.did_pass,
@@ -88,8 +88,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
             )
         except Exception as e:
             logger.error(e)
-            break
-        if (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
+        elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator(
             original_test_result.stdout, cdd_test_result.stdout
         ):
             test_diff.scope = TestDiffScope.STDOUT
@@ -97,7 +96,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
             test_diff.candidate_value = str(cdd_test_result.stdout)
             test_diffs.append(test_diff)
 
-        if original_test_result.test_type in {
+        elif original_test_result.test_type in {
             TestType.EXISTING_UNIT_TEST,
             TestType.CONCOLIC_COVERAGE_TEST,
             TestType.GENERATED_REGRESSION,

From d66d2ced93651924b73514c2da7495b23133c2ac Mon Sep 17 00:00:00 2001
From: ali
Date: Thu, 4 Dec 2025 23:21:26 +0200
Subject: [PATCH 30/35] small changes

---
 codeflash/api/aiservice.py                   | 11 +++++------
 codeflash/optimization/function_optimizer.py |  2 +-
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index 44c286ddb..8b59fab9a 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -321,18 +321,17 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
             return None
 
         if response.status_code == 200:
-            refined_optimization = response.json()
+            fixed_optimization = response.json()
             console.rule()
 
-            refinements = self._get_valid_candidates([refined_optimization])
-            if not refinements:
+            if not self._get_valid_candidates([fixed_optimization]):
                 logger.error("Code repair failed to generate a valid candidate.")
                 return None
 
             return OptimizedCandidate(
-                source_code=refinements[0].source_code,
-                explanation=refinements[0].explanation,
-                optimization_id=refinements[0].optimization_id[:-4] + "cdrp",
+                source_code=fixed_optimization["source_code"],
+                explanation=fixed_optimization["explanation"],
+                optimization_id=fixed_optimization["optimization_id"],
             )
 
         try:
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index c054a51ff..11139bc53 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -1892,7 +1892,7 @@ def run_optimized_candidate(  # noqa: PLR0911
 
                 ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
 
-                with progress_bar("The test results are not matching, let me see if I can fix this"):
+                with progress_bar("Some of the test results are not matching, let me see if I can fix this"):
                     new_candidate = self.code_repair_optimizations(
                         original_source_code=code_context.read_writable_code.markdown,
                         modified_source_code=candidate.source_code.markdown,
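Patch 31 below stops calling the repair endpoint inline and instead queues each repair as a future that CandidateProcessor drains once the main candidate queue is empty. The submit-and-drain pattern it builds on, in miniature (illustrative names only):

import concurrent.futures

def repair(optimization_id: str) -> str | None:
    # Stand-in for the blocking /code_repair call.
    return optimization_id + "-repaired"

executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
pending = [executor.submit(repair, oid) for oid in ("opt-1", "opt-2")]

# Later, when the candidate queue runs dry, block on all repairs at
# once and enqueue whichever ones came back usable.
concurrent.futures.wait(pending)
repaired = [f.result() for f in pending if f.result() is not None]
pending.clear()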
From b4474f377cfc64abf9a34835c8db11fc4af09e7a Mon Sep 17 00:00:00 2001
From: ali
Date: Fri, 5 Dec 2025 15:34:26 +0200
Subject: [PATCH 31/35] add code repairs to the queue

---
 codeflash/api/aiservice.py                   |   9 +-
 codeflash/optimization/function_optimizer.py | 114 ++++++++-----------
 2 files changed, 53 insertions(+), 70 deletions(-)

diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index 8b59fab9a..4fce2006d 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -324,15 +324,12 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
             fixed_optimization = response.json()
             console.rule()
 
-            if not self._get_valid_candidates([fixed_optimization]):
+            valid_candidates = self._get_valid_candidates([fixed_optimization])
+            if not valid_candidates:
                 logger.error("Code repair failed to generate a valid candidate.")
                 return None
 
-            return OptimizedCandidate(
-                source_code=fixed_optimization["source_code"],
-                explanation=fixed_optimization["explanation"],
-                optimization_id=fixed_optimization["optimization_id"],
-            )
+            return valid_candidates[0]
 
         try:
             error = response.json()["error"]
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 11139bc53..bbebc2b27 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -126,7 +126,8 @@ def __init__(
         self,
         initial_candidates: list,
         future_line_profile_results: concurrent.futures.Future,
-        future_all_refinements: list,
+        future_all_refinements: list[concurrent.futures.Future],
+        future_all_code_repair: list[concurrent.futures.Future],
     ) -> None:
         self.candidate_queue = queue.Queue()
         self.line_profiler_done = False
@@ -139,6 +140,7 @@ def __init__(
 
         self.future_line_profile_results = future_line_profile_results
         self.future_all_refinements = future_all_refinements
+        self.future_all_code_repair = future_all_code_repair
 
     def get_next_candidate(self) -> OptimizedCandidate | None:
         """Get the next candidate from the queue, handling async results as needed."""
@@ -151,6 +153,8 @@ def _handle_empty_queue(self) -> OptimizedCandidate | None:
         """Handle empty queue by checking for pending async results."""
         if not self.line_profiler_done:
             return self._process_line_profiler_results()
+        if len(self.future_all_code_repair) > 0:
+            return self._process_code_repair()
         if self.line_profiler_done and not self.refinement_done:
             return self._process_refinement_results()
         return None  # All done
@@ -190,10 +194,30 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
             logger.info(
                 f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
             )
+        self.future_all_refinements = []
         self.refinement_done = True
         return self.get_next_candidate()
 
+    def _process_code_repair(self) -> OptimizedCandidate | None:
+        logger.info(f"loading|Repairing {len(self.future_all_code_repair)} candidates")
+        concurrent.futures.wait(self.future_all_code_repair)
+        candidates_added = 0
+        for future_code_repair in self.future_all_code_repair:
+            possible_code_repair = future_code_repair.result()
+            if possible_code_repair:
+                self.candidate_queue.put(possible_code_repair)
+                self.candidate_len += 1
+                candidates_added += 1
+
+        if candidates_added > 0:
+            logger.info(
                f"Added {candidates_added} candidates from code repair, total candidates now: {self.candidate_len}"
+            )
+        self.future_all_code_repair = []
+
+        return self.get_next_candidate()
+
     def is_done(self) -> bool:
         """Check if processing is complete."""
         return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty()
@@ -250,6 +274,8 @@ def __init__(
         )
         self.optimization_review = ""
self.ast_code_to_id = {} + self.future_all_refinements: list[concurrent.futures.Future] = [] + self.future_all_code_repair: list[concurrent.futures.Future] = [] def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -528,8 +554,10 @@ def determine_best_candidate( ) console.rule() - future_all_refinements: list[concurrent.futures.Future] = [] self.ast_code_to_id.clear() + self.future_all_refinements.clear() + self.future_all_code_repair.clear() + valid_optimizations = [] optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated @@ -550,7 +578,9 @@ def determine_best_candidate( ) # Initialize candidate processor - processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements) + processor = CandidateProcessor( + candidates, future_line_profile_results, self.future_all_refinements, self.future_all_code_repair + ) candidate_index = 0 # Process candidates using queue-based approach @@ -609,10 +639,8 @@ def determine_best_candidate( "shorter_source_code": candidate.source_code, "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), } - self.reset_optimization_metrics_for_candidate( - candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes - ) - run_results, new_candidate = self.run_optimized_candidate( + + run_results = self.run_optimized_candidate( optimization_candidate_index=candidate_index, baseline_results=original_code_baseline, original_helper_code=original_helper_code, @@ -621,9 +649,6 @@ def determine_best_candidate( candidate=candidate, exp_type=exp_type, ) - if candidate.optimization_id != new_candidate.optimization_id: - # override the candidate if the optimization_id has changed, this may happen if the candidate was modified by the code-repair - candidate = new_candidate console.rule() if not is_successful(run_results): @@ -715,7 +740,7 @@ def determine_best_candidate( valid_optimizations.append(best_optimization) # queue corresponding refined optimization for best optimization if not candidate.optimization_id.endswith("refi"): - future_all_refinements.append( + self.future_all_refinements.append( self.refine_optimizations( valid_optimizations=[best_optimization], original_code_baseline=original_code_baseline, @@ -880,7 +905,7 @@ def refine_optimizations( ] return executor.submit(ai_service_client.optimize_python_code_refinement, request=request) - def code_repair_optimizations( + def repair_optimization( self, original_source_code: str, modified_source_code: str, @@ -888,7 +913,8 @@ def code_repair_optimizations( trace_id: str, optimization_id: str, ai_service_client: AiServiceClient, - ) -> OptimizedCandidate | None: + executor: concurrent.futures.ThreadPoolExecutor, + ) -> concurrent.futures.Future[OptimizedCandidate | None]: request = AIServiceCodeRepairRequest( optimization_id=optimization_id, original_source_code=original_source_code, @@ -896,7 +922,7 @@ def code_repair_optimizations( test_diffs=test_diffs, trace_id=trace_id, ) - return ai_service_client.optimize_python_code_repair(request=request) + return executor.submit(ai_service_client.optimize_python_code_repair, request=request) def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str @@ -1816,7 +1842,7 @@ def get_results_not_matched_error(self) -> Failure: 
console.rule() return Failure("Test results did not match the test results of the original code.") - def run_optimized_candidate( # noqa: PLR0911 + def run_optimized_candidate( self, *, optimization_candidate_index: int, @@ -1826,7 +1852,7 @@ def run_optimized_candidate( # noqa: PLR0911 code_context: CodeOptimizationContext, candidate: OptimizedCandidate, exp_type: str, - ) -> tuple[Result[OptimizedCandidateResult, str], OptimizedCandidate]: + ) -> Result[OptimizedCandidateResult, str]: assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 with progress_bar("Testing optimization candidate"): @@ -1884,16 +1910,16 @@ def run_optimized_candidate( # noqa: PLR0911 result_unmatched_perc = len(diffs) / len(candidate_behavior_results) if result_unmatched_perc > 0.5: # if the test unmatched percentage is greater than 50%, we can't fix it - return self.get_results_not_matched_error(), candidate + return self.get_results_not_matched_error() if candidate.optimization_id.endswith("cdrp"): # prevent looping for now - return self.get_results_not_matched_error(), candidate + return self.get_results_not_matched_error() ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client - - with progress_bar("Some of the test results are not matching, let me see if I can fix this"): - new_candidate = self.code_repair_optimizations( + logger.info("Adding this to the repair queue") + self.future_all_code_repair.append( + self.repair_optimization( original_source_code=code_context.read_writable_code.markdown, modified_source_code=candidate.source_code.markdown, test_diffs=diffs, @@ -1902,51 +1928,11 @@ def run_optimized_candidate( # noqa: PLR0911 else self.function_trace_id, ai_service_client=ai_service_client, optimization_id=candidate.optimization_id, + executor=self.executor, ) - if not new_candidate: - return Failure("Code repair failed to generate a valid candidate."), candidate - - code_print( - new_candidate.source_code.flat, - file_name=f"candidate_{optimization_candidate_index}.py", - function_name=self.function_to_optimize.function_name, ) - normalized_code = normalize_code(new_candidate.source_code.flat.strip()) - self.ast_code_to_id[normalized_code] = { - "optimization_id": new_candidate.optimization_id, - "shorter_source_code": new_candidate.source_code, - "diff_len": diff_length(new_candidate.source_code.flat, code_context.read_writable_code.flat), - } - try: - # revert first to original code then replace with new repaired code, so we don't get any weird behavior - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - did_update = self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, - optimized_code=new_candidate.source_code, - original_helper_code=original_helper_code, - ) - if did_update: - return self.run_optimized_candidate( - optimization_candidate_index=optimization_candidate_index, - baseline_results=baseline_results, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - code_context=code_context, - candidate=new_candidate, - exp_type=exp_type, - ) - msg = "No functions were replaced in the optimized code. Skipping optimization candidate." 
- logger.warning(f"force_lsp|{msg}") - return Failure(msg), candidate - except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: - logger.error(e) - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - return Failure("Code repair failed to generate a valid candidate."), candidate + return self.get_results_not_matched_error() logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") @@ -2038,7 +2024,7 @@ def run_optimized_candidate( # noqa: PLR0911 total_candidate_timing=total_candidate_timing, async_throughput=candidate_async_throughput, ) - ), candidate + ) def run_and_parse_tests( self, From 726405bf43cfc2f73255bf830f25eecbf5a7b1b1 Mon Sep 17 00:00:00 2001 From: ali Date: Fri, 5 Dec 2025 15:48:26 +0200 Subject: [PATCH 32/35] optimization source --- codeflash/api/aiservice.py | 32 +++++++++++--------- codeflash/models/models.py | 8 +++++ codeflash/optimization/function_optimizer.py | 9 ++---- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 4fce2006d..f147b4852 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -18,7 +18,12 @@ from codeflash.code_utils.time_utils import humanize_runtime from codeflash.lsp.helpers import is_LSP_enabled from codeflash.models.ExperimentMetadata import ExperimentMetadata -from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate +from codeflash.models.models import ( + AIServiceRefinerRequest, + CodeStringsMarkdown, + OptimizedCandidate, + OptimizedCandidateSource, +) from codeflash.telemetry.posthog_cf import ph from codeflash.version import __version__ as codeflash_version @@ -86,7 +91,9 @@ def make_ai_service_request( # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code return response - def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]: + def _get_valid_candidates( + self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource + ) -> list[OptimizedCandidate]: candidates: list[OptimizedCandidate] = [] for opt in optimizations_json: code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"]) @@ -94,7 +101,10 @@ def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> lis continue candidates.append( OptimizedCandidate( - source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"] + source_code=code, + explanation=opt["explanation"], + optimization_id=opt["optimization_id"], + source=source, ) ) return candidates @@ -157,7 +167,7 @@ def optimize_python_code( # noqa: D417 console.rule() end_time = time.perf_counter() logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.") - return self._get_valid_candidates(optimizations_json) + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE) try: error = response.json()["error"] except Exception: @@ -222,7 +232,7 @@ def optimize_python_code_line_profiler( # noqa: D417 f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information." 
            )
             console.rule()
-            return self._get_valid_candidates(optimizations_json)
+            return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP)
         try:
             error = response.json()["error"]
         except Exception:
@@ -275,15 +285,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
             logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
             console.rule()
 
-            refinements = self._get_valid_candidates(refined_optimizations)
-            return [
-                OptimizedCandidate(
-                    source_code=c.source_code,
-                    explanation=c.explanation,
-                    optimization_id=c.optimization_id[:-4] + "refi",
-                )
-                for c in refinements
-            ]
+            return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE)
 
         try:
             error = response.json()["error"]
@@ -324,7 +326,7 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
             fixed_optimization = response.json()
             console.rule()
 
-            valid_candidates = self._get_valid_candidates([fixed_optimization])
+            valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.REPAIR)
             if not valid_candidates:
                 logger.error("Code repair failed to generate a valid candidate.")
                 return None
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 52f8ba4c8..e7116ffc5 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -383,11 +383,19 @@ class TestsInFile:
     test_type: TestType
 
 
+class OptimizedCandidateSource(enum.Enum, str):
+    OPTIMIZE = "OPTIMIZE"
+    OPTIMIZE_LP = "OPTIMIZE_LP"
+    REFINE = "REFINE"
+    REPAIR = "REPAIR"
+
+
 @dataclass(frozen=True)
 class OptimizedCandidate:
     source_code: CodeStringsMarkdown
     explanation: str
     optimization_id: str
+    source: OptimizedCandidateSource
 
 
 @dataclass(frozen=True)
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index bbebc2b27..b0a60efe9 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -77,6 +77,7 @@
     OptimizationSet,
     OptimizedCandidate,
     OptimizedCandidateResult,
+    OptimizedCandidateSource,
     OriginalCodeBaseline,
     TestFile,
     TestFiles,
@@ -739,7 +740,7 @@ def determine_best_candidate(
                     )
                     valid_optimizations.append(best_optimization)
                     # queue corresponding refined optimization for best optimization
-                    if not candidate.optimization_id.endswith("refi"):
+                    if candidate.source != OptimizedCandidateSource.REFINE:
                         self.future_all_refinements.append(
                             self.refine_optimizations(
                                 valid_optimizations=[best_optimization],
@@ -1908,14 +1909,10 @@ def run_optimized_candidate(
                 console.rule()
             else:
                 result_unmatched_perc = len(diffs) / len(candidate_behavior_results)
-                if result_unmatched_perc > 0.5:
+                if candidate.source == OptimizedCandidateSource.REPAIR or result_unmatched_perc > 0.5:
                     # if the test unmatched percentage is greater than 50%, we can't fix it
                     return self.get_results_not_matched_error()
 
-                if candidate.optimization_id.endswith("cdrp"):
-                    # prevent looping for now
-                    return self.get_results_not_matched_error()
-
                 ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
                 logger.info("Adding this to the repair queue")
                 self.future_all_code_repair.append(
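One detail worth flagging in this patch: OptimizedCandidateSource is declared as class OptimizedCandidateSource(enum.Enum, str), but Python requires mixin types to come before Enum in the base list, and the reversed order fails at class-creation time; patch 35 later flips it to (str, Enum). A small sketch of the working pattern:

from enum import Enum

class Source(str, Enum):
    # The str mixin must precede Enum; the enum machinery always goes
    # last. Reversing the order raises a TypeError when the class is built.
    OPTIMIZE = "OPTIMIZE"
    REPAIR = "REPAIR"

assert Source.REPAIR == "REPAIR"          # str mixin: members compare equal to their values
assert Source("REPAIR") is Source.REPAIR  # lookup by value still works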
From cfb52c0486158bc568b793156f6ab38c1efbbc9c Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Fri, 5 Dec 2025 10:22:51 -0500
Subject: [PATCH 33/35] fix for equivalence

---
 codeflash/verification/equivalence.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index 545616bae..fc8f26445 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -51,7 +51,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
         ):
             continue
         if original_test_result is None or cdd_test_result is None:
-            return False, []
+            continue
         did_all_timeout = did_all_timeout and original_test_result.timed_out
         if original_test_result.timed_out:
             continue
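This one-line fix completes a broader shift in compare_test_results: rather than aborting on the first unpairable invocation, the loop now skips it and keeps collecting structured diffs so the repair prompt sees the full picture. Schematically (simplified; the real function also tracks timeouts, stdout, and test types):

def compare_results(original: dict, candidate: dict) -> tuple[bool, list[str]]:
    diffs: list[str] = []
    for invocation_id, expected in original.items():
        got = candidate.get(invocation_id)
        if got is None:
            continue  # unpairable invocation: skip it, don't fail the whole comparison
        if got != expected:
            diffs.append(f"{invocation_id}: expected {expected!r}, got {got!r}")
    return (not diffs, diffs)

match, diffs = compare_results({"t1": 3, "t2": 5}, {"t1": 3, "t2": 7})
assert not match and diffs == ["t2: expected 5, got 7"]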
From a2464d4d690695f62c86e3c4cb2092750b71ccd3 Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Fri, 5 Dec 2025 11:18:42 -0500
Subject: [PATCH 34/35] quick and dirty

---
 codeflash/api/aiservice.py                   |  1 +
 codeflash/models/models.py                   |  1 +
 codeflash/optimization/function_optimizer.py | 39 +++++++++++---------
 3 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index f147b4852..df935dcf6 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -315,6 +315,7 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
                 "modified_source_code": request.modified_source_code,
                 "trace_id": request.trace_id,
                 "test_diffs": request.test_diffs,
+                "past_trials": request.past_trials,
             }
             response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120)
         except (requests.exceptions.RequestException, TypeError) as e:
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index e7116ffc5..5b46c9eb6 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -74,6 +74,7 @@ class AIServiceCodeRepairRequest:
     modified_source_code: str
     trace_id: str
     test_diffs: list[TestDiff]
+    past_trials: str
 
 
 # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index b0a60efe9..8918dc1ae 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -739,21 +739,21 @@ def determine_best_candidate(
                         async_throughput=candidate_result.async_throughput,
                     )
                     valid_optimizations.append(best_optimization)
-                    # queue corresponding refined optimization for best optimization
-                    if candidate.source != OptimizedCandidateSource.REFINE:
-                        self.future_all_refinements.append(
-                            self.refine_optimizations(
-                                valid_optimizations=[best_optimization],
-                                original_code_baseline=original_code_baseline,
-                                code_context=code_context,
-                                trace_id=self.function_trace_id[:-4] + exp_type
-                                if self.experiment_id
-                                else self.function_trace_id,
-                                ai_service_client=ai_service_client,
-                                executor=self.executor,
-                                function_references=function_references,
-                            )
-                        )
+                    # # queue corresponding refined optimization for best optimization
+                    # if candidate.source != OptimizedCandidateSource.REFINE:
+                    #     self.future_all_refinements.append(
+                    #         self.refine_optimizations(
+                    #             valid_optimizations=[best_optimization],
+                    #             original_code_baseline=original_code_baseline,
+                    #             code_context=code_context,
+                    #             trace_id=self.function_trace_id[:-4] + exp_type
+                    #             if self.experiment_id
+                    #             else self.function_trace_id,
+                    #             ai_service_client=ai_service_client,
+                    #             executor=self.executor,
+                    #             function_references=function_references,
+                    #         )
+                    #     )
                 else:
                     # For async functions, prioritize throughput metrics over runtime even for slow candidates
                     is_async = (
@@ -913,6 +913,7 @@ def repair_optimization(
         test_diffs: list[TestDiff],
         trace_id: str,
         optimization_id: str,
+        past_trials: str,
         ai_service_client: AiServiceClient,
         executor: concurrent.futures.ThreadPoolExecutor,
     ) -> concurrent.futures.Future[OptimizedCandidate | None]:
@@ -922,6 +923,7 @@ def repair_optimization(
             modified_source_code=modified_source_code,
             test_diffs=test_diffs,
             trace_id=trace_id,
+            past_trials=past_trials,
         )
         return executor.submit(ai_service_client.optimize_python_code_repair, request=request)
 
@@ -1915,7 +1917,8 @@ def run_optimized_candidate(
 
                 ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
                 logger.info("Adding this to the repair queue")
-                self.future_all_code_repair.append(
+                for _ in range(3):
+                    # first candidate
                     self.repair_optimization(
                         original_source_code=code_context.read_writable_code.markdown,
                         modified_source_code=candidate.source_code.markdown,
@@ -1926,9 +1929,9 @@ def run_optimized_candidate(
                         ai_service_client=ai_service_client,
                         optimization_id=candidate.optimization_id,
                         executor=self.executor,
+                        past_trials="",
                    )
-                )
-
+                    # behavior to test, if pass break
                 return self.get_results_not_matched_error()
 
             logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...")

From 3cf23870b5fe01bd30bf220a73070437238f16fd Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Tue, 9 Dec 2025 14:31:24 -0500
Subject: [PATCH 35/35] working dirty implementation

---
 codeflash/api/aiservice.py                   |   1 +
 codeflash/models/models.py                   |   3 +-
 codeflash/optimization/function_optimizer.py | 512 +++++++++++++------
 codeflash/verification/equivalence.py        |  26 +-
 4 files changed, 382 insertions(+), 160 deletions(-)

diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index df935dcf6..8b0367c0d 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -316,6 +316,7 @@ def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> Op
             "trace_id": request.trace_id,
             "test_diffs": request.test_diffs,
             "past_trials": request.past_trials,
+            "trial_no": request.trial_no
         }
         response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120)
     except (requests.exceptions.RequestException, TypeError) as e:
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 5b46c9eb6..d972552d3 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -75,6 +75,7 @@ class AIServiceCodeRepairRequest:
     trace_id: str
     test_diffs: list[TestDiff]
     past_trials: str
+    trial_no: str
 
 
 # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
@@ -384,7 +385,7 @@ class TestsInFile:
     test_type: TestType
 
 
-class OptimizedCandidateSource(enum.Enum, str):
+class OptimizedCandidateSource(str, Enum):
     OPTIMIZE = "OPTIMIZE"
     OPTIMIZE_LP = "OPTIMIZE_LP"
     REFINE = "REFINE"
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 8918dc1ae..ad35a6c61 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -5,6 +5,7 @@
 import os
 import queue
 import random
+import sqlite3
 import subprocess
 import time
 import uuid
@@ -48,6 +49,8 @@
     N_TESTS_TO_GENERATE_EFFECTIVE,
     REPEAT_OPTIMIZATION_PROBABILITY,
     TOTAL_LOOPING_TIME_EFFECTIVE,
+    MIN_IMPROVEMENT_THRESHOLD,
+    MIN_TESTCASE_PASSED_THRESHOLD,
 )
 from codeflash.code_utils.deduplicate_code import normalize_code
 from codeflash.code_utils.edit_generated_tests import (
@@ -67,6 +70,7 @@
 from codeflash.either import Failure, Success, is_successful
 from codeflash.lsp.helpers import
is_LSP_enabled, report_to_markdown_table, tree_to_markdown from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId +from codeflash.models import models from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( AIServiceCodeRepairRequest, @@ -83,7 +87,7 @@ TestFiles, TestingMode, TestResults, - TestType, + TestType, TestDiffScope, ) from codeflash.result.create_pr import check_create_pr, existing_tests_source_for from codeflash.result.critic import ( @@ -119,6 +123,145 @@ ) from codeflash.verification.verification_utils import TestConfig +CODE_REPAIR_LOG_DB = Path("/Users/aseemsaxena/Downloads/codeflash_dev/codeflash-internal/django/aiservice/code_repair/code_repair_log.db") + + +SCOPE_DESCRIPTIONS = { + TestDiffScope.RETURN_VALUE: ( + "The function returned a different value in the optimized code compared to the original." + ), + TestDiffScope.STDOUT: ("The output printed to stdout is different in the optimized code compared to the original."), + TestDiffScope.DID_PASS: ( + "The test passed in one version but failed in the other (a change in pass/fail behavior)." + ), +} + +def build_test_details(test_diffs: list[TestDiff]) -> str: + sections = [] + for test_no, diff in enumerate(test_diffs, 1): + test_src_code = "```python\n" + diff.test_src_code + "\n```" if diff.test_src_code else "" + section = [ + f"#### Test #{test_no}", + f"{SCOPE_DESCRIPTIONS.get(diff.scope, diff.scope.value)}", + f"Expected: {diff.original_value!r}. Got: {diff.candidate_value!r}" + if diff.scope != TestDiffScope.DID_PASS + else "", + f"Original code test status: {'Passed' if diff.original_pass else 'Failed'}. Optimized code test status: {'Passed' if diff.candidate_pass else 'Failed'}", + f"Pytest error (original code): {diff.original_pytest_error}" if diff.original_pytest_error else "", + f"Pytest error (optimized code): {diff.candidate_pytest_error}" if diff.candidate_pytest_error else "", + "Test Source:", + test_src_code, + "---", + ] + sections.append("\n".join(filter(None, section))) + + return "\n".join(sections) + +def _init_code_repair_log_db() -> None: + """Initialize the SQLite database for code repair logging.""" + conn = sqlite3.connect(CODE_REPAIR_LOG_DB) + cursor = conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS code_repair_logs ( + optimization_id TEXT PRIMARY KEY, + trace_id TEXT, + user_prompt TEXT, + explanation TEXT, + refined_optimization TEXT, + trial_no TEXT, + past_trials TEXT, + passed TEXT, + faster TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() + conn.close() + + +def log_code_repair_to_db( + optimization_id: str, + trace_id: str | None = None, + user_prompt: str | None = None, + explanation: str | None = None, + refined_optimization: str | None = None, + trial_no: str | None = None, + past_trials: str | None = None, + passed: str | None = None, + faster: str | None = None, +) -> None: + """Log code repair data to SQLite database. + + Uses upsert pattern to allow incremental logging with different columns at different places. + Only non-None values will be updated; existing values are preserved. 
+ """ + try: + _init_code_repair_log_db() + conn = sqlite3.connect(CODE_REPAIR_LOG_DB) + cursor = conn.cursor() + + # Build dynamic upsert query based on provided columns + columns = ["optimization_id"] + values = [optimization_id] + update_parts = ["updated_at = CURRENT_TIMESTAMP"] + + if trace_id is not None: + columns.append("trace_id") + values.append(trace_id) + update_parts.append("trace_id = excluded.trace_id") + + if user_prompt is not None: + columns.append("user_prompt") + values.append(user_prompt) + update_parts.append("user_prompt = excluded.user_prompt") + + if explanation is not None: + columns.append("explanation") + values.append(explanation) + update_parts.append("explanation = excluded.explanation") + + if refined_optimization is not None: + columns.append("refined_optimization") + values.append(refined_optimization) + update_parts.append("refined_optimization = excluded.refined_optimization") + + if trial_no is not None: + columns.append("trial_no") + values.append(trial_no) + update_parts.append("trial_no = excluded.trial_no") + + if past_trials is not None: + columns.append("past_trials") + values.append(past_trials) + update_parts.append("past_trials = excluded.past_trials") + + if passed is not None: + columns.append("passed") + values.append(passed) + update_parts.append("passed = excluded.passed") + + if faster is not None: + columns.append("faster") + values.append(faster) + update_parts.append("faster = excluded.faster") + + placeholders = ", ".join(["?"] * len(values)) + columns_str = ", ".join(columns) + update_str = ", ".join(update_parts) + + cursor.execute( + f""" + INSERT INTO code_repair_logs ({columns_str}) + VALUES ({placeholders}) + ON CONFLICT(optimization_id) DO UPDATE SET {update_str} + """, # noqa: S608 + values, + ) + conn.commit() + conn.close() + except Exception: + logger.exception("Failed to log code repair data to SQLite") class CandidateProcessor: """Handles candidate processing using a queue-based approach.""" @@ -649,6 +792,7 @@ def determine_best_candidate( code_context=code_context, candidate=candidate, exp_type=exp_type, + original_code_baseline=original_code_baseline ) console.rule() @@ -812,6 +956,7 @@ def determine_best_candidate( source_code=self.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], optimization_id=valid_opt.candidate.optimization_id, explanation=valid_opt.candidate.explanation, + source=valid_opt.candidate.source, ) new_best_opt = BestOptimization( candidate=new_candidate_with_shorter_code, @@ -914,9 +1059,9 @@ def repair_optimization( trace_id: str, optimization_id: str, past_trials: str, + trial_no: str, ai_service_client: AiServiceClient, - executor: concurrent.futures.ThreadPoolExecutor, - ) -> concurrent.futures.Future[OptimizedCandidate | None]: + ) -> OptimizedCandidate | None: request = AIServiceCodeRepairRequest( optimization_id=optimization_id, original_source_code=original_source_code, @@ -924,8 +1069,9 @@ def repair_optimization( test_diffs=test_diffs, trace_id=trace_id, past_trials=past_trials, + trial_no=trial_no ) - return executor.submit(ai_service_client.optimize_python_code_repair, request=request) + return ai_service_client.optimize_python_code_repair(request=request) def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str @@ -1855,177 +2001,245 @@ def run_optimized_candidate( code_context: CodeOptimizationContext, candidate: OptimizedCandidate, exp_type: str, + original_code_baseline, # noqa: ANN001 ) -> 
Result[OptimizedCandidateResult, str]: - assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 - - with progress_bar("Testing optimization candidate"): - test_env = self.get_test_env( - codeflash_loop_index=0, - codeflash_test_iteration=optimization_candidate_index, - codeflash_tracer_disable=1, - ) - - get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True) - # Instrument codeflash capture - candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8") - candidate_helper_code = {} - for module_abspath in original_helper_code: - candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") - if self.function_to_optimize.is_async: - from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function - - add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR - ) - - try: - instrument_codeflash_capture( - self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root - ) - - total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE - candidate_behavior_results, _ = self.run_and_parse_tests( - testing_type=TestingMode.BEHAVIOR, - test_env=test_env, - test_files=self.test_files, - optimization_iteration=optimization_candidate_index, - testing_time=total_looping_time, - enable_coverage=False, + current_candidate = candidate + current_candidate_index = optimization_candidate_index + past_trials = "" + for trial_no in range(4): + print("Trial no: ", trial_no) + assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 + with progress_bar("Testing optimization candidate"): + test_env = self.get_test_env( + codeflash_loop_index=0, + codeflash_test_iteration=current_candidate_index, + codeflash_tracer_disable=1, ) - # Remove instrumentation - finally: - self.write_code_and_helpers( - candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path - ) - console.print( - TestResults.report_to_tree( - candidate_behavior_results.get_test_pass_fail_report_by_type(), - title=f"Behavioral Test Results for candidate {optimization_candidate_index}", - ) - ) - console.rule() - # print(type(code_context), type(candidate)) - match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) - if match: - logger.info("h3|Test results matched ✅") - console.rule() - else: - result_unmatched_perc = len(diffs) / len(candidate_behavior_results) - if candidate.source == OptimizedCandidateSource.REPAIR or result_unmatched_perc > 0.5: - # if the test unmatched percentage is greater than 50%, we can't fix it - return self.get_results_not_matched_error() - - ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client - logger.info("Adding this to the repair queue") - for _ in range(3): - # first candidate - self.repair_optimization( - original_source_code=code_context.read_writable_code.markdown, - modified_source_code=candidate.source_code.markdown, - test_diffs=diffs, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - optimization_id=candidate.optimization_id, - executor=self.executor, - past_trials="", - ) - # behavior to test, if pass break - return self.get_results_not_matched_error() - - logger.info(f"loading|Running performance tests for candidate 
{optimization_candidate_index}...") - - if test_framework == "pytest": - # For async functions, instrument at definition site for performance benchmarking + get_run_tmp_file(Path(f"test_return_values_{current_candidate_index}.sqlite")).unlink(missing_ok=True) + # Instrument codeflash capture + candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8") + candidate_helper_code = {} + for module_abspath in original_helper_code: + candidate_helper_code[module_abspath] = Path(module_abspath).read_text("utf-8") if self.function_to_optimize.is_async: from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function - add_async_decorator_to_function( - self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE + self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.BEHAVIOR ) - try: - candidate_benchmarking_results, _ = self.run_and_parse_tests( - testing_type=TestingMode.PERFORMANCE, + instrument_codeflash_capture( + self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root + ) + total_looping_time = TOTAL_LOOPING_TIME_EFFECTIVE + candidate_behavior_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, test_env=test_env, test_files=self.test_files, - optimization_iteration=optimization_candidate_index, + optimization_iteration=current_candidate_index, testing_time=total_looping_time, enable_coverage=False, ) + # Remove instrumentation finally: - # Restore original source if we instrumented it - if self.function_to_optimize.is_async: - self.write_code_and_helpers( - candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path - ) - loop_count = ( - max(all_loop_indices) - if ( - all_loop_indices := { - result.loop_index for result in candidate_benchmarking_results.test_results - } + self.write_code_and_helpers( + candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path + ) + console.print( + TestResults.report_to_tree( + candidate_behavior_results.get_test_pass_fail_report_by_type(), + title=f"Behavioral Test Results for candidate {current_candidate_index}", ) - else 0 ) - - else: - candidate_benchmarking_results = TestResults() - start_time: float = time.time() - loop_count = 0 - for i in range(100): - if i >= 5 and time.time() - start_time >= TOTAL_LOOPING_TIME_EFFECTIVE * 1.5: - # * 1.5 to give unittest a bit more time to run + console.rule() + match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) + if match: + logger.info("h3|Test results matched ✅") + console.rule() + if trial_no!=0: + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no), + passed="yes", + ) + break + if trial_no<=2: + # repair process + ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + # first candidate + repair_candidate = self.repair_optimization( + original_source_code=code_context.read_writable_code.markdown, + modified_source_code=current_candidate.source_code.markdown, + test_diffs=diffs, + trace_id=self.function_trace_id, + ai_service_client=ai_service_client, + optimization_id=candidate.optimization_id, + past_trials=past_trials, + trial_no=str(trial_no+1) + ) + if not repair_candidate: + logger.debug("llm call failed") + log_code_repair_to_db( + trace_id=self.function_trace_id, + optimization_id=candidate.optimization_id + "_" + str(trial_no+1), + passed="no", + 
+                            faster="no",
+                        )
+                        match = False
+                        if trial_no != 2:
+                            continue
                         break
-                test_env["CODEFLASH_LOOP_INDEX"] = str(i + 1)
-                unittest_loop_results, _cov = self.run_and_parse_tests(
-                    testing_type=TestingMode.PERFORMANCE,
-                    test_env=test_env,
-                    test_files=self.test_files,
-                    optimization_iteration=optimization_candidate_index,
-                    testing_time=TOTAL_LOOPING_TIME_EFFECTIVE,
-                    unittest_loop_index=i + 1,
+                    try:
+                        # Apply the repaired code to the source files before re-testing
+                        did_update = self.replace_function_and_helpers_with_optimized_code(
+                            code_context=code_context,
+                            optimized_code=repair_candidate.source_code,
+                            original_helper_code=original_helper_code,
+                        )
+                        if not did_update:
+                            log_code_repair_to_db(
+                                trace_id=self.function_trace_id,
+                                optimization_id=candidate.optimization_id + "_" + str(trial_no + 1),
+                                passed="no",
+                                faster="no",
+                            )
+                            match = False
+                            if trial_no != 2:
+                                continue
+                            break
+                    except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
+                        logger.error(e)
+                        self.write_code_and_helpers(
+                            self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
+                        )
+                        log_code_repair_to_db(
+                            trace_id=self.function_trace_id,
+                            optimization_id=candidate.optimization_id + "_" + str(trial_no + 1),
+                            passed="no",
+                            faster="no",
+                        )
+                        match = False
+                        if trial_no != 2:
+                            continue
+                        break
+                    # Seed the next repair with a transcript of this failed trial
+                    past_trials += f"Trial {trial_no + 1}\n"
+                    past_trials += f"Candidate Code\n{current_candidate.source_code.markdown}\n"
+                    past_trials += "Abridged test results\n"
+                    past_trials += build_test_details(diffs)[:2000]
+                    current_candidate = repair_candidate
+                    log_code_repair_to_db(
+                        trace_id=self.function_trace_id,
+                        optimization_id=candidate.optimization_id + "_" + str(trial_no + 1),
+                        passed="no",
+                        faster="no",
                     )
-                loop_count = i + 1
-                candidate_benchmarking_results.merge(unittest_loop_results)
-
-        if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0:
-            logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
-            console.rule()
-
-        logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
+        # After the loop, `match` reflects the last behavioral comparison
+        if not match:
+            logger.info("Candidate repair did not succeed after 3 trials, aborting")
+            if trial_no != 0:
+                log_code_repair_to_db(
+                    trace_id=self.function_trace_id,
+                    optimization_id=candidate.optimization_id + "_" + str(trial_no),
+                    passed="no",
+                    faster="no",
+                )
+            return self.get_results_not_matched_error()
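The retry flow this hunk builds up is easier to review in isolation. Below is a minimal, self-contained sketch of the same loop under simplified assumptions: run_tests and repair are hypothetical stand-ins for the behavioral test run and the LLM repair call, and the inline diff computation stands in for compare_test_results. It is an illustration of the control flow, not the patch's actual API.

from typing import Callable, Optional

def verify_with_repair(
    candidate: str,
    baseline: list,
    run_tests: Callable[[str], list],
    repair: Callable[[str, list, str], Optional[str]],
    max_repairs: int = 3,
) -> Optional[str]:
    # Trial 0 tests the candidate as generated; each later trial tests an
    # LLM-repaired version seeded with a transcript of earlier failures.
    past_trials = ""
    for trial_no in range(max_repairs + 1):
        results = run_tests(candidate)
        diffs = [i for i, (b, r) in enumerate(zip(baseline, results)) if b != r]
        if not diffs:
            return candidate  # behavior matches, possibly after repair
        if trial_no == max_repairs:
            break  # repair budget exhausted
        repaired = repair(candidate, diffs, past_trials)
        if repaired is None:
            break  # repair call failed; give up early
        past_trials += f"Trial {trial_no + 1} mismatches: {diffs}\n"
        candidate = repaired
    return None  # behavior never matched the baseline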
+        # Performance benchmarking
+        logger.info(f"loading|Running performance tests for candidate {current_candidate_index}...")
-        candidate_async_throughput = None
+        if test_framework == "pytest":
+            # For async functions, instrument at definition site for performance benchmarking
         if self.function_to_optimize.is_async:
-            candidate_async_throughput = calculate_function_throughput_from_test_results(
-                candidate_benchmarking_results, self.function_to_optimize.function_name
+            from codeflash.code_utils.instrument_existing_tests import add_async_decorator_to_function
+
+            add_async_decorator_to_function(
+                self.function_to_optimize.file_path, self.function_to_optimize, TestingMode.PERFORMANCE
             )
-            logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
-        if self.args.benchmark:
-            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
-                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+            try:
+                candidate_benchmarking_results, _ = self.run_and_parse_tests(
+                    testing_type=TestingMode.PERFORMANCE,
+                    test_env=test_env,
+                    test_files=self.test_files,
+                    optimization_iteration=current_candidate_index,
+                    testing_time=total_looping_time,
+                    enable_coverage=False,
                 )
-            for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
-                logger.debug(
-                    f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}"
+            finally:
+                # Restore original source if we instrumented it
+                if self.function_to_optimize.is_async:
+                    self.write_code_and_helpers(
+                        candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path
                     )
-            return Success(
-                OptimizedCandidateResult(
-                    max_loop_count=loop_count,
-                    best_test_runtime=total_candidate_timing,
-                    behavior_test_results=candidate_behavior_results,
-                    benchmarking_test_results=candidate_benchmarking_results,
-                    replay_benchmarking_test_results=candidate_replay_benchmarking_results
-                    if self.args.benchmark
-                    else None,
-                    optimization_candidate_index=optimization_candidate_index,
-                    total_candidate_timing=total_candidate_timing,
-                    async_throughput=candidate_async_throughput,
+            loop_count = (
+                max(all_loop_indices)
+                if (
+                    all_loop_indices := {
+                        result.loop_index for result in candidate_benchmarking_results.test_results
+                    }
                 )
+                else 0
             )
+
+        if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0:
+            logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.")
+            console.rule()
+
+        logger.debug(f"Total optimized code {current_candidate_index} runtime (ns): {total_candidate_timing}")
+
+        candidate_async_throughput = None
+        if self.function_to_optimize.is_async:
+            candidate_async_throughput = calculate_function_throughput_from_test_results(
+                candidate_benchmarking_results, self.function_to_optimize.function_name
+            )
+            logger.debug(f"Candidate async function throughput: {candidate_async_throughput} calls/second")
+
+        if self.args.benchmark:
+            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(
+                self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root
+            )
+            for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
+                logger.debug(
+                    f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}"
+                )
+
+        best_test_runtime = total_candidate_timing
+        perf_gain = performance_gain(
+            original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime
+        )
+        # Short baselines are noisy, so demand a larger improvement before trusting them
+        noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_baseline.runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
+
+        # Log the repair outcome together with the speedup verdict
+        report = candidate_behavior_results.get_test_pass_fail_report_by_type()
+        pass_count = sum(report[test_type]["passed"] for test_type in report)
+
+        # Only trust the measured speedup if enough test cases passed;
+        # initialize to False so the variable is always bound below
+        speedup_critic_val = False
+        if pass_count >= MIN_TESTCASE_PASSED_THRESHOLD:
+            speedup_critic_val = pass_count >= 6
+        faster = "yes" if (perf_gain > noise_floor and speedup_critic_val) else "no"
+        if trial_no != 0:
+            log_code_repair_to_db(
+                trace_id=self.function_trace_id,
+                optimization_id=candidate.optimization_id + "_" + str(trial_no),
+                faster=faster,
+            )
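Just above, faster is decided by a noise-floor gate on the measured gain. A worked sketch of that gate follows, assuming MIN_IMPROVEMENT_THRESHOLD = 0.05 and performance_gain = (original - optimized) / optimized; both are assumptions for illustration, not values taken from this patch.

MIN_IMPROVEMENT_THRESHOLD = 0.05  # assumed value, for illustration only

def is_faster(original_ns: int, optimized_ns: int, critic_ok: bool) -> bool:
    # Relative speedup of the candidate over the original
    gain = (original_ns - optimized_ns) / optimized_ns
    # Baselines under 10 µs are noisy, so demand a 3x larger gain there
    noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_ns < 10_000 else MIN_IMPROVEMENT_THRESHOLD
    return gain > noise_floor and critic_ok

# An 8 µs baseline must improve by more than 15% to count as faster:
assert is_faster(8_000, 6_500, True)       # ~23% gain clears the 15% floor
assert not is_faster(8_000, 7_500, True)   # ~6.7% gain does not
assert not is_faster(8_000, 6_500, False)  # gain alone is not enough; the critic must agree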
+        return Success(
+            OptimizedCandidateResult(
+                max_loop_count=loop_count,
+                best_test_runtime=total_candidate_timing,
+                behavior_test_results=candidate_behavior_results,
+                benchmarking_test_results=candidate_benchmarking_results,
+                replay_benchmarking_test_results=candidate_replay_benchmarking_results
+                if self.args.benchmark
+                else None,
+                optimization_candidate_index=optimization_candidate_index,
+                total_candidate_timing=total_candidate_timing,
+                async_throughput=candidate_async_throughput,
+            )
+        )
+
     def run_and_parse_tests(
         self,
         testing_type: TestingMode,
diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py
index fc8f26445..ef7fb910d 100644
--- a/codeflash/verification/equivalence.py
+++ b/codeflash/verification/equivalence.py
@@ -30,16 +30,22 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
         cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id)
         candidate_test_failures = candidate_results.test_failures
         original_test_failures = original_results.test_failures
-        cdd_pytest_error = (
-            candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "")
-            if candidate_test_failures
-            else ""
-        )
-        original_pytest_error = (
-            original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "")
-            if original_test_failures
-            else ""
-        )
+        try:
+            cdd_pytest_error = (
+                candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "")
+                if candidate_test_failures
+                else ""
+            )
+        except Exception:  # defensive: the invocation id may lack a qualified test name
+            cdd_pytest_error = ""
+        try:
+            original_pytest_error = (
+                original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "")
+                if original_test_failures
+                else ""
+            )
+        except Exception:  # defensive: same as above
+            original_pytest_error = ""
         if cdd_test_result is not None and original_test_result is None:
             continue
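For context on how the widened return contract is consumed: compare_test_results now yields (match, diffs), and the diffs feed the abridged failure transcript that seeds the repair prompt. Below is a rough sketch under simplified, assumed types; the dataclass fields and the abridge_diffs helper are illustrative stand-ins for the real TestDiff and build_test_details, not the patch's actual API.

from dataclasses import dataclass

@dataclass
class TestDiffSketch:
    scope: str          # e.g. "return_value", "stdout", "did_pass"
    test_src_code: str  # source of the mismatching test
    pytest_error: str   # captured pytest failure text, if any

def abridge_diffs(diffs: list[TestDiffSketch], limit: int = 2000) -> str:
    # Compact, truncated transcript, analogous to build_test_details(diffs)[:2000]
    lines = [f"[{d.scope}] {d.pytest_error or d.test_src_code}" for d in diffs]
    return "\n".join(lines)[:limit]

# A caller branches on the new tuple return:
match, diffs = False, [TestDiffSketch("return_value", "def test_add(): ...", "assert 3 == 4")]
if not match:
    prompt_snippet = abridge_diffs(diffs)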