diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 20c478eb4..df935dcf6 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -18,7 +18,12 @@ from codeflash.code_utils.time_utils import humanize_runtime from codeflash.lsp.helpers import is_LSP_enabled from codeflash.models.ExperimentMetadata import ExperimentMetadata -from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate +from codeflash.models.models import ( + AIServiceRefinerRequest, + CodeStringsMarkdown, + OptimizedCandidate, + OptimizedCandidateSource, +) from codeflash.telemetry.posthog_cf import ph from codeflash.version import __version__ as codeflash_version @@ -27,7 +32,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.ExperimentMetadata import ExperimentMetadata - from codeflash.models.models import AIServiceRefinerRequest + from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest from codeflash.result.explanation import Explanation @@ -86,7 +91,9 @@ def make_ai_service_request( # response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code return response - def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]: + def _get_valid_candidates( + self, optimizations_json: list[dict[str, Any]], source: OptimizedCandidateSource + ) -> list[OptimizedCandidate]: candidates: list[OptimizedCandidate] = [] for opt in optimizations_json: code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"]) @@ -94,7 +101,10 @@ def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> lis continue candidates.append( OptimizedCandidate( - source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"] + source_code=code, + explanation=opt["explanation"], + optimization_id=opt["optimization_id"], + source=source, ) ) return candidates @@ -157,7 +167,7 @@ def optimize_python_code( # noqa: D417 console.rule() end_time = time.perf_counter() logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.") - return self._get_valid_candidates(optimizations_json) + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE) try: error = response.json()["error"] except Exception: @@ -222,7 +232,7 @@ def optimize_python_code_line_profiler( # noqa: D417 f"!lsp|Generated {len(optimizations_json)} candidate optimizations using line profiler information." 
) console.rule() - return self._get_valid_candidates(optimizations_json) + return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE_LP) try: error = response.json()["error"] except Exception: @@ -275,15 +285,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.") console.rule() - refinements = self._get_valid_candidates(refined_optimizations) - return [ - OptimizedCandidate( - source_code=c.source_code, - explanation=c.explanation, - optimization_id=c.optimization_id[:-4] + "refi", - ) - for c in refinements - ] + return self._get_valid_candidates(refined_optimizations, OptimizedCandidateSource.REFINE) try: error = response.json()["error"] @@ -294,6 +296,53 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [] + def optimize_python_code_repair(self, request: AIServiceCodeRepairRequest) -> OptimizedCandidate | None: + """Optimize the given python code for performance by making a request to the Django endpoint. + + Args: + request: optimization candidate details for refinement + + Returns: + ------- + - OptimizationCandidate: new fixed candidate. + + """ + console.rule() + try: + payload = { + "optimization_id": request.optimization_id, + "original_source_code": request.original_source_code, + "modified_source_code": request.modified_source_code, + "trace_id": request.trace_id, + "test_diffs": request.test_diffs, + "past_trials": request.past_trials, + } + response = self.make_ai_service_request("/code_repair", payload=payload, timeout=120) + except (requests.exceptions.RequestException, TypeError) as e: + logger.exception(f"Error generating optimization repair: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return None + + if response.status_code == 200: + fixed_optimization = response.json() + console.rule() + + valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.REPAIR) + if not valid_candidates: + logger.error("Code repair failed to generate a valid candidate.") + return None + + return valid_candidates[0] + + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return None + def get_new_explanation( # noqa: D417 self, source_code: str, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 744f76087..5b46c9eb6 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -3,6 +3,7 @@ from collections import Counter, defaultdict from typing import TYPE_CHECKING +import libcst as cst from rich.tree import Tree from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log @@ -47,6 +48,35 @@ class AIServiceRefinerRequest: function_references: str | None = None +class TestDiffScope(str, Enum): + RETURN_VALUE = "return_value" + STDOUT = "stdout" + DID_PASS = "did_pass" # noqa: S105 + + +@dataclass +class TestDiff: + scope: TestDiffScope + original_pass: bool + candidate_pass: bool + + original_value: str | None = None + candidate_value: str | None = None + test_src_code: Optional[str] = None + candidate_pytest_error: Optional[str] = None + original_pytest_error: Optional[str] = None + + +@dataclass(frozen=True) +class AIServiceCodeRepairRequest: + optimization_id: str + 
original_source_code: str
+    modified_source_code: str
+    trace_id: str
+    test_diffs: list[TestDiff]
+    past_trials: str
+
+
 # If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully
 # qualified name of the method is foo.eggs.Ham.spam, its qualified name is Ham.spam, and its name is spam. The full name
 # of the module is foo.eggs.
@@ -354,11 +384,19 @@ class TestsInFile:
     test_type: TestType


+class OptimizedCandidateSource(str, enum.Enum):
+    OPTIMIZE = "OPTIMIZE"
+    OPTIMIZE_LP = "OPTIMIZE_LP"
+    REFINE = "REFINE"
+    REPAIR = "REPAIR"
+
+
 @dataclass(frozen=True)
 class OptimizedCandidate:
     source_code: CodeStringsMarkdown
     explanation: str
     optimization_id: str
+    source: OptimizedCandidateSource


 @dataclass(frozen=True)
@@ -505,6 +543,42 @@ def id(self) -> str:
             f"{self.function_getting_tested}:{self.iteration_id}"
         )

+    # TestSuiteClass.test_function_name
+    def test_fn_qualified_name(self) -> str:
+        # Use f-string with inline conditional to reduce string concatenation operations
+        return (
+            f"{self.test_class_name}.{self.test_function_name}"
+            if self.test_class_name
+            else str(self.test_function_name)
+        )
+
+    def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]:
+        for stmt in class_node.body.body:
+            if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name:
+                return stmt
+        return None
+
+    def get_src_code(self, test_path: Path) -> Optional[str]:
+        if not test_path.exists():
+            return None
+        test_src = test_path.read_text(encoding="utf-8")
+        module_node = cst.parse_module(test_src)
+
+        if self.test_class_name:
+            for stmt in module_node.body:
+                if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name:
+                    func_node = self.find_func_in_class(stmt, self.test_function_name)
+                    if func_node:
+                        return module_node.code_for_node(func_node).strip()
+            # class not found
+            return None
+
+        # Otherwise, look for a top level function
+        for stmt in module_node.body:
+            if isinstance(stmt, cst.FunctionDef) and stmt.name.value == self.test_function_name:
+                return module_node.code_for_node(stmt).strip()
+        return None
+
     @staticmethod
     def from_str_id(string_id: str, iteration_id: str | None = None) -> InvocationId:
         components = string_id.split(":")
@@ -549,7 +623,10 @@ class TestResults(BaseModel):  # noqa: PLW1641
     # also we don't support deletion of test results elements - caution is advised
     test_results: list[FunctionTestInvocation] = []
     test_result_idx: dict[str, int] = {}
+    perf_stdout: Optional[str] = None
+    # mapping between test function name and stdout failure message
+    test_failures: Optional[dict[str, str]] = None

     def add(self, function_test_invocation: FunctionTestInvocation) -> None:
         unique_id = function_test_invocation.unique_invocation_loop_id
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 2eef51f0f..8918dc1ae 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -69,6 +69,7 @@
 from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
 from codeflash.models.ExperimentMetadata import ExperimentMetadata
 from codeflash.models.models import (
+    AIServiceCodeRepairRequest,
     BestOptimization,
     CodeOptimizationContext,
     GeneratedTests,
@@ -76,6 +77,7 @@
     OptimizationSet,
     OptimizedCandidate,
     OptimizedCandidateResult,
+    OptimizedCandidateSource,
     OriginalCodeBaseline,
     TestFile,
     TestFiles,
@@ -113,6 +115,7 @@
     CoverageData,
     FunctionCalledInTest,
FunctionSource, + TestDiff, ) from codeflash.verification.verification_utils import TestConfig @@ -124,7 +127,8 @@ def __init__( self, initial_candidates: list, future_line_profile_results: concurrent.futures.Future, - future_all_refinements: list, + future_all_refinements: list[concurrent.futures.Future], + future_all_code_repair: list[concurrent.futures.Future], ) -> None: self.candidate_queue = queue.Queue() self.line_profiler_done = False @@ -137,6 +141,7 @@ def __init__( self.future_line_profile_results = future_line_profile_results self.future_all_refinements = future_all_refinements + self.future_all_code_repair = future_all_code_repair def get_next_candidate(self) -> OptimizedCandidate | None: """Get the next candidate from the queue, handling async results as needed.""" @@ -149,6 +154,8 @@ def _handle_empty_queue(self) -> OptimizedCandidate | None: """Handle empty queue by checking for pending async results.""" if not self.line_profiler_done: return self._process_line_profiler_results() + if len(self.future_all_code_repair) > 0: + return self._process_code_repair() if self.line_profiler_done and not self.refinement_done: return self._process_refinement_results() return None # All done @@ -188,10 +195,30 @@ def _process_refinement_results(self) -> OptimizedCandidate | None: logger.info( f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}" ) + self.future_all_refinements = [] self.refinement_done = True return self.get_next_candidate() + def _process_code_repair(self) -> OptimizedCandidate | None: + logger.info(f"loading|Repairing {len(self.future_all_code_repair)} candidates") + concurrent.futures.wait(self.future_all_code_repair) + candidates_added = 0 + for future_code_repair in self.future_all_code_repair: + possible_code_repair = future_code_repair.result() + if possible_code_repair: + self.candidate_queue.put(possible_code_repair) + self.candidate_len += 1 + candidates_added += 1 + + if candidates_added > 0: + logger.info( + f"Added {candidates_added} candidates from code repair, total candidates now: {self.candidate_len}" + ) + self.future_all_code_repair = [] + + return self.get_next_candidate() + def is_done(self) -> bool: """Check if processing is complete.""" return self.line_profiler_done and self.refinement_done and self.candidate_queue.empty() @@ -247,6 +274,9 @@ def __init__( max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4 ) self.optimization_review = "" + self.ast_code_to_id = {} + self.future_all_refinements: list[concurrent.futures.Future] = [] + self.future_all_code_repair: list[concurrent.futures.Future] = [] def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]: should_run_experiment = self.experiment_id is not None @@ -387,7 +417,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() code_print( @@ -459,6 +488,48 @@ def optimize_function(self) -> Result[BestOptimization, str]: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) + def reset_optimization_metrics_for_candidate( + self, opt_id: str, speedup_ratios: dict, is_correct: dict, optimized_runtimes: dict + ) -> None: + speedup_ratios[opt_id] = None + 
is_correct[opt_id] = False + optimized_runtimes[opt_id] = None + + def was_candidate_tested_before(self, normalized_code: str) -> bool: + # check if this code has been evaluated before by checking the ast normalized code string + return normalized_code in self.ast_code_to_id + + def update_results_for_duplicate_candidate( + self, + candidate: OptimizedCandidate, + code_context: CodeOptimizationContext, + normalized_code: str, + speedup_ratios: dict, + is_correct: dict, + optimized_runtimes: dict, + optimized_line_profiler_results: dict, + optimizations_post: dict, + ) -> None: + logger.info("Current candidate has been encountered before in testing, Skipping optimization candidate.") + past_opt_id = self.ast_code_to_id[normalized_code]["optimization_id"] + # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes + speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] + is_correct[candidate.optimization_id] = is_correct[past_opt_id] + optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] + # line profiler results only available for successful runs + if past_opt_id in optimized_line_profiler_results: + optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[past_opt_id] + optimizations_post[candidate.optimization_id] = self.ast_code_to_id[normalized_code][ + "shorter_source_code" + ].markdown + optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown + new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) + if ( + new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"] + ): # new candidate has a shorter diff than the previously encountered one + self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code + self.ast_code_to_id[normalized_code]["diff_len"] = new_diff_len + def determine_best_candidate( self, *, @@ -484,8 +555,10 @@ def determine_best_candidate( ) console.rule() - future_all_refinements: list[concurrent.futures.Future] = [] - ast_code_to_id = {} + self.ast_code_to_id.clear() + self.future_all_refinements.clear() + self.future_all_code_repair.clear() + valid_optimizations = [] optimizations_post = {} # we need to overwrite some opt candidates' code strings as they are no longer evaluated, instead their shorter/longer versions might be evaluated @@ -506,7 +579,9 @@ def determine_best_candidate( ) # Initialize candidate processor - processor = CandidateProcessor(candidates, future_line_profile_results, future_all_refinements) + processor = CandidateProcessor( + candidates, future_line_profile_results, self.future_all_refinements, self.future_all_code_repair + ) candidate_index = 0 # Process candidates using queue-based approach @@ -548,47 +623,39 @@ def determine_best_candidate( continue # check if this code has been evaluated before by checking the ast normalized code string normalized_code = normalize_code(candidate.source_code.flat.strip()) - if normalized_code in ast_code_to_id: - logger.info( - "Current candidate has been encountered before in testing, Skipping optimization candidate." 
+ if self.was_candidate_tested_before(normalized_code): + self.update_results_for_duplicate_candidate( + candidate=candidate, + code_context=code_context, + normalized_code=normalized_code, + speedup_ratios=speedup_ratios, + is_correct=is_correct, + optimized_runtimes=optimized_runtimes, + optimized_line_profiler_results=optimized_line_profiler_results, + optimizations_post=optimizations_post, ) - past_opt_id = ast_code_to_id[normalized_code]["optimization_id"] - # update speedup ratio, is_correct, optimizations_post, optimized_line_profiler_results, optimized_runtimes - speedup_ratios[candidate.optimization_id] = speedup_ratios[past_opt_id] - is_correct[candidate.optimization_id] = is_correct[past_opt_id] - optimized_runtimes[candidate.optimization_id] = optimized_runtimes[past_opt_id] - # line profiler results only available for successful runs - if past_opt_id in optimized_line_profiler_results: - optimized_line_profiler_results[candidate.optimization_id] = optimized_line_profiler_results[ - past_opt_id - ] - optimizations_post[candidate.optimization_id] = ast_code_to_id[normalized_code][ - "shorter_source_code" - ].markdown - optimizations_post[past_opt_id] = ast_code_to_id[normalized_code]["shorter_source_code"].markdown - new_diff_len = diff_length(candidate.source_code.flat, code_context.read_writable_code.flat) - if ( - new_diff_len < ast_code_to_id[normalized_code]["diff_len"] - ): # new candidate has a shorter diff than the previously encountered one - ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code - ast_code_to_id[normalized_code]["diff_len"] = new_diff_len continue - ast_code_to_id[normalized_code] = { + self.ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, "shorter_source_code": candidate.source_code, "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat), } + run_results = self.run_optimized_candidate( optimization_candidate_index=candidate_index, baseline_results=original_code_baseline, original_helper_code=original_helper_code, file_path_to_helper_classes=file_path_to_helper_classes, + code_context=code_context, + candidate=candidate, + exp_type=exp_type, ) + console.rule() if not is_successful(run_results): - optimized_runtimes[candidate.optimization_id] = None - is_correct[candidate.optimization_id] = False - speedup_ratios[candidate.optimization_id] = None + self.reset_optimization_metrics_for_candidate( + candidate.optimization_id, speedup_ratios, is_correct, optimized_runtimes + ) else: candidate_result: OptimizedCandidateResult = run_results.unwrap() best_test_runtime = candidate_result.best_test_runtime @@ -672,21 +739,21 @@ def determine_best_candidate( async_throughput=candidate_result.async_throughput, ) valid_optimizations.append(best_optimization) - # queue corresponding refined optimization for best optimization - if not candidate.optimization_id.endswith("refi"): - future_all_refinements.append( - self.refine_optimizations( - valid_optimizations=[best_optimization], - original_code_baseline=original_code_baseline, - code_context=code_context, - trace_id=self.function_trace_id[:-4] + exp_type - if self.experiment_id - else self.function_trace_id, - ai_service_client=ai_service_client, - executor=self.executor, - function_references=function_references, - ) - ) + # # queue corresponding refined optimization for best optimization + # if candidate.source != OptimizedCandidateSource.REFINE: + # self.future_all_refinements.append( + # self.refine_optimizations( + # 
valid_optimizations=[best_optimization], + # original_code_baseline=original_code_baseline, + # code_context=code_context, + # trace_id=self.function_trace_id[:-4] + exp_type + # if self.experiment_id + # else self.function_trace_id, + # ai_service_client=ai_service_client, + # executor=self.executor, + # function_references=function_references, + # ) + # ) else: # For async functions, prioritize throughput metrics over runtime even for slow candidates is_async = ( @@ -742,7 +809,7 @@ def determine_best_candidate( for valid_opt in valid_optimizations: valid_opt_normalized_code = normalize_code(valid_opt.candidate.source_code.flat.strip()) new_candidate_with_shorter_code = OptimizedCandidate( - source_code=ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], + source_code=self.ast_code_to_id[valid_opt_normalized_code]["shorter_source_code"], optimization_id=valid_opt.candidate.optimization_id, explanation=valid_opt.candidate.explanation, ) @@ -839,6 +906,27 @@ def refine_optimizations( ] return executor.submit(ai_service_client.optimize_python_code_refinement, request=request) + def repair_optimization( + self, + original_source_code: str, + modified_source_code: str, + test_diffs: list[TestDiff], + trace_id: str, + optimization_id: str, + past_trials: str, + ai_service_client: AiServiceClient, + executor: concurrent.futures.ThreadPoolExecutor, + ) -> concurrent.futures.Future[OptimizedCandidate | None]: + request = AIServiceCodeRepairRequest( + optimization_id=optimization_id, + original_source_code=original_source_code, + modified_source_code=modified_source_code, + test_diffs=test_diffs, + trace_id=trace_id, + past_trials=past_trials, + ) + return executor.submit(ai_service_client.optimize_python_code_repair, request=request) + def log_successful_optimization( self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str ) -> None: @@ -1752,6 +1840,11 @@ def establish_original_code_baseline( ) ) + def get_results_not_matched_error(self) -> Failure: + logger.info("h4|Test results did not match the test results of the original code ❌") + console.rule() + return Failure("Test results did not match the test results of the original code.") + def run_optimized_candidate( self, *, @@ -1759,6 +1852,9 @@ def run_optimized_candidate( baseline_results: OriginalCodeBaseline, original_helper_code: dict[Path, str], file_path_to_helper_classes: dict[Path, set[str]], + code_context: CodeOptimizationContext, + candidate: OptimizedCandidate, + exp_type: str, ) -> Result[OptimizedCandidateResult, str]: assert (test_framework := self.args.test_framework) in {"pytest", "unittest"} # noqa: RUF018 @@ -1808,13 +1904,35 @@ def run_optimized_candidate( ) ) console.rule() - if compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results): + # print(type(code_context), type(candidate)) + match, diffs = compare_test_results(baseline_results.behavior_test_results, candidate_behavior_results) + if match: logger.info("h3|Test results matched ✅") console.rule() else: - logger.info("h4|Test results did not match the test results of the original code ❌") - console.rule() - return Failure("Test results did not match the test results of the original code.") + result_unmatched_perc = len(diffs) / len(candidate_behavior_results) + if candidate.source == OptimizedCandidateSource.REPAIR or result_unmatched_perc > 0.5: + # if the test unmatched percentage is greater than 50%, we can't fix it + return self.get_results_not_matched_error() + + ai_service_client = 
self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client + logger.info("Adding this to the repair queue") + for _ in range(3): + # first candidate + self.repair_optimization( + original_source_code=code_context.read_writable_code.markdown, + modified_source_code=candidate.source_code.markdown, + test_diffs=diffs, + trace_id=self.function_trace_id[:-4] + exp_type + if self.experiment_id + else self.function_trace_id, + ai_service_client=ai_service_client, + optimization_id=candidate.optimization_id, + executor=self.executor, + past_trials="", + ) + # behavior to test, if pass break + return self.get_results_not_matched_error() logger.info(f"loading|Running performance tests for candidate {optimization_candidate_index}...") diff --git a/codeflash/verification/equivalence.py b/codeflash/verification/equivalence.py index 9d7f5ba2c..fc8f26445 100644 --- a/codeflash/verification/equivalence.py +++ b/codeflash/verification/equivalence.py @@ -1,27 +1,46 @@ +from __future__ import annotations + import sys +from typing import TYPE_CHECKING from codeflash.cli_cmds.console import logger -from codeflash.models.models import TestResults, TestType, VerificationType +from codeflash.models.models import TestDiff, TestDiffScope, TestResults, TestType, VerificationType from codeflash.verification.comparator import comparator +if TYPE_CHECKING: + from codeflash.models.models import TestResults + INCREASED_RECURSION_LIMIT = 5000 -def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> bool: +def compare_test_results(original_results: TestResults, candidate_results: TestResults) -> tuple[bool, list[TestDiff]]: # This is meant to be only called with test results for the first loop index if len(original_results) == 0 or len(candidate_results) == 0: - return False # empty test results are not equal + return False, [] # empty test results are not equal original_recursion_limit = sys.getrecursionlimit() if original_recursion_limit < INCREASED_RECURSION_LIMIT: sys.setrecursionlimit(INCREASED_RECURSION_LIMIT) # Increase recursion limit to avoid RecursionError test_ids_superset = original_results.get_all_unique_invocation_loop_ids().union( set(candidate_results.get_all_unique_invocation_loop_ids()) ) - are_equal: bool = True + test_diffs: list[TestDiff] = [] did_all_timeout: bool = True for test_id in test_ids_superset: original_test_result = original_results.get_by_unique_invocation_loop_id(test_id) cdd_test_result = candidate_results.get_by_unique_invocation_loop_id(test_id) + candidate_test_failures = candidate_results.test_failures + original_test_failures = original_results.test_failures + cdd_pytest_error = ( + candidate_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if candidate_test_failures + else "" + ) + original_pytest_error = ( + original_test_failures.get(original_test_result.id.test_fn_qualified_name(), "") + if original_test_failures + else "" + ) + if cdd_test_result is not None and original_test_result is None: continue # If helper function instance_state verification is not present, that's ok. 
continue @@ -32,8 +51,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR ): continue if original_test_result is None or cdd_test_result is None: - are_equal = False - break + continue did_all_timeout = did_all_timeout and original_test_result.timed_out if original_test_result.timed_out: continue @@ -43,42 +61,53 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO} ): superset_obj = True + + test_src_code = original_test_result.id.get_src_code(original_test_result.file_name) + test_diff = TestDiff( + scope=TestDiffScope.RETURN_VALUE, + original_value=repr(original_test_result.return_value), + candidate_value=repr(cdd_test_result.return_value), + test_src_code=test_src_code, + candidate_pytest_error=cdd_pytest_error, + original_pass=original_test_result.did_pass, + candidate_pass=cdd_test_result.did_pass, + original_pytest_error=original_pytest_error, + ) if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj): - are_equal = False + test_diff.scope = TestDiffScope.RETURN_VALUE + test_diffs.append(test_diff) + try: logger.debug( - "File Name: %s\n" - "Test Type: %s\n" - "Verification Type: %s\n" - "Invocation ID: %s\n" - "Original return value: %s\n" - "Candidate return value: %s\n" - "-------------------", - original_test_result.file_name, - original_test_result.test_type, - original_test_result.verification_type, - original_test_result.id, - original_test_result.return_value, - cdd_test_result.return_value, + f"File Name: {original_test_result.file_name}\n" + f"Test Type: {original_test_result.test_type}\n" + f"Verification Type: {original_test_result.verification_type}\n" + f"Invocation ID: {original_test_result.id}\n" + f"Original return value: {original_test_result.return_value}\n" + f"Candidate return value: {cdd_test_result.return_value}\n" ) except Exception as e: logger.error(e) - break - if (original_test_result.stdout and cdd_test_result.stdout) and not comparator( + elif (original_test_result.stdout and cdd_test_result.stdout) and not comparator( original_test_result.stdout, cdd_test_result.stdout ): - are_equal = False - break + test_diff.scope = TestDiffScope.STDOUT + test_diff.original_value = str(original_test_result.stdout) + test_diff.candidate_value = str(cdd_test_result.stdout) + test_diffs.append(test_diff) - if original_test_result.test_type in { + elif original_test_result.test_type in { TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST, TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST, } and (cdd_test_result.did_pass != original_test_result.did_pass): - are_equal = False - break + test_diff.scope = TestDiffScope.DID_PASS + test_diff.original_value = str(original_test_result.did_pass) + test_diff.candidate_value = str(cdd_test_result.did_pass) + test_diffs.append(test_diff) + sys.setrecursionlimit(original_recursion_limit) if did_all_timeout: - return False - return are_equal + return False, test_diffs + return len(test_diffs) == 0, test_diffs diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index ef513a0a3..f5cdad9d1 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -512,6 +512,61 @@ def merge_test_results( return merged_test_results +FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+") +TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$") + + +def 
parse_test_failures_from_stdout(test_results: TestResults, stdout: str) -> TestResults: + """Extract individual pytest test failures from stdout grouped by test case qualified name, and add them to the test results.""" + lines = stdout.splitlines() + start = end = None + + for i, line in enumerate(lines): + if FAILURES_HEADER_RE.search(line.strip()): + start = i + break + + if start is None: + return test_results + + for j in range(start + 1, len(lines)): + stripped = lines[j].strip() + if "short test summary info" in stripped: + end = j + break + # any new === section === block + if stripped.startswith("=") and stripped.count("=") > 3: + end = j + break + + # If no clear "end", just grap the rest of the string + if end is None: + end = len(lines) + + failure_block = lines[start:end] + + failures: dict[str, str] = {} + current_name = None + current_lines: list[str] = [] + + for line in failure_block: + m = TEST_HEADER_RE.match(line.strip()) + if m: + if current_name is not None: + failures[current_name] = "".join(current_lines) + + current_name = m.group(1) + current_lines = [] + elif current_name: + current_lines.append(line + "\n") + + if current_name: + failures[current_name] = "".join(current_lines) + + test_results.test_failures = failures + return test_results + + def parse_test_results( test_xml_path: Path, test_files: TestFiles, @@ -572,4 +627,9 @@ def parse_test_results( function_name=function_name, ) coverage.log_coverage() + try: + parse_test_failures_from_stdout(results, run_result.stdout) + except Exception as e: + logger.exception(e) + return results, coverage if all_args else None diff --git a/tests/test_codeflash_capture.py b/tests/test_codeflash_capture.py index c326cecc4..79133bc15 100644 --- a/tests/test_codeflash_capture.py +++ b/tests/test_codeflash_capture.py @@ -502,7 +502,8 @@ def __init__(self, x=2): pytest_max_loops=1, testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -626,7 +627,8 @@ def __init__(self, *args, **kwargs): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -754,7 +756,8 @@ def __init__(self, x=2): testing_time=0.1, ) - assert compare_test_results(test_results, test_results2) + match, _ = compare_test_results(test_results, test_results2) + assert match finally: test_path.unlink(missing_ok=True) sample_code_path.unlink(missing_ok=True) @@ -902,7 +905,8 @@ def another_helper(self): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: test_path.unlink(missing_ok=True) @@ -1132,7 +1136,8 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert not compare_test_results(test_results, mutated_test_results) + match, _ = compare_test_results(test_results, mutated_test_results) + assert not match # This fto code stopped using a helper class. 
it should still pass no_helper1_fto_code = """ @@ -1170,10 +1175,304 @@ def target_function(self): ) # Remove instrumentation FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) - assert compare_test_results(test_results, no_helper1_test_results) + match, _ = compare_test_results(test_results, no_helper1_test_results) + assert match finally: test_path.unlink(missing_ok=True) fto_file_path.unlink(missing_ok=True) helper_path_1.unlink(missing_ok=True) helper_path_2.unlink(missing_ok=True) + +def test_instrument_codeflash_capture_and_run_tests_2() -> None: + # End to end run that instruments code and runs tests. Made to be similar to code used in the optimizer.py + test_code = """import math +import pytest +from typing import List, Tuple, Optional +from code_to_optimize.tests.pytest.fto_file import calculate_portfolio_metrics + +def test_calculate_portfolio_metrics(): + # Test case 1: Basic portfolio + investments = [ + ('Stocks', 0.6, 0.12), + ('Bonds', 0.3, 0.04), + ('Cash', 0.1, 0.01) + ] + + result = calculate_portfolio_metrics(investments) + + # Check weighted return calculation + expected_return = 0.6*0.12 + 0.3*0.04 + 0.1*0.01 + assert abs(result['weighted_return'] - expected_return) < 1e-10 + + # Check volatility calculation + expected_vol = math.sqrt((0.6*0.12)**2 + (0.3*0.04)**2 + (0.1*0.01)**2) + assert abs(result['volatility'] - expected_vol) < 1e-10 + + # Check Sharpe ratio + expected_sharpe = (expected_return - 0.02) / expected_vol + assert abs(result['sharpe_ratio'] - expected_sharpe) < 1e-10 + + # Check best/worst performers + assert result['best_performing'][0] == 'Stocks' + assert result['worst_performing'][0] == 'Cash' + assert result['total_assets'] == 3 + +def test_empty_investments(): + with pytest.raises(ValueError, match="Investments list cannot be empty"): + calculate_portfolio_metrics([]) + +def test_weights_not_sum_to_one(): + investments = [('Stock', 0.5, 0.1), ('Bond', 0.4, 0.05)] + with pytest.raises(ValueError, match="Portfolio weights must sum to 1.0"): + calculate_portfolio_metrics(investments) + +def test_zero_volatility(): + investments = [('Cash', 1.0, 0.0)] + result = calculate_portfolio_metrics(investments, risk_free_rate=0.0) + assert result['sharpe_ratio'] == 0.0 + assert result['volatility'] == 0.0 +""" + + original_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + if abs(sum(weight for _, weight, _ in investments) - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Calculate weighted return + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Calculate portfolio volatility (simplified) + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Calculate Sharpe ratio + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Find best and worst performing assets + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + 'weighted_return': round(weighted_return, 6), + 'volatility': round(volatility, 6), + 'sharpe_ratio': round(sharpe_ratio, 6), + 'best_performing': (best_asset[0], round(best_asset[2], 6)), + 'worst_performing': (worst_asset[0], round(worst_asset[2], 6)), + 
'total_assets': len(investments) + } +""" + test_dir = (Path(__file__).parent.parent / "code_to_optimize" / "tests" / "pytest").resolve() + test_file_name = "test_multiple_helpers.py" + + fto_file_name = "fto_file.py" + + test_path = test_dir / test_file_name + test_path_perf = test_dir / "test_multiple_helpers_perf.py" + fto_file_path = test_dir / fto_file_name + + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + + try: + with fto_file_path.open("w") as f: + f.write(original_code) + with test_path.open("w") as f: + f.write(test_code) + + fto = FunctionToOptimize("calculate_portfolio_metrics", fto_file_path, parents=[]) + file_path_to_helper_class = { + } + instrument_codeflash_capture(fto, file_path_to_helper_class, tests_root) + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + + test_type = TestType.EXISTING_UNIT_TEST + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=fto, test_cfg=test_config) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=test_path, + test_type=test_type, + original_file_path=test_path, + benchmarking_file_path=test_path_perf, + ) + ] + ) + # Code in optimizer.py + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = { + } + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + + test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + + # Now, let's say we optimize the code and make changes. 
+ new_fto_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + total_weight = sum(w for _, w, _ in investments) + if total_weight != 1.0: # Should use tolerance check + raise ValueError("Portfolio weights must sum to 1.0") + + weighted_return = 1.0 + for _, weight, ret in investments: + weighted_return *= (1 + ret) ** weight + weighted_return = weighted_return - 1.0 # Convert back from geometric + + returns = [r for _, _, r in investments] + mean_return = sum(returns) / len(returns) + volatility = math.sqrt(sum((r - mean_return) ** 2 for r in returns) / len(returns)) + + # BUG 4: Sharpe ratio calculation is correct but uses wrong inputs + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + def risk_adjusted_return(return_val, weight): + return (return_val - risk_free_rate) / (weight * return_val) if weight * return_val != 0 else return_val + + best_asset = max(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + worst_asset = min(investments, key=lambda x: risk_adjusted_return(x[2], x[1])) + + return { + "weighted_return": round(weighted_return, 6), + "volatility": 2, + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fto_code) + # Instrument codeflash capture + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results) + + assert not matched + + new_fixed_code = """import math +from typing import List, Tuple, Optional + +def calculate_portfolio_metrics( + investments: List[Tuple[str, float, float]], + risk_free_rate: float = 0.02 +) -> dict: + if not investments: + raise ValueError("Investments list cannot be empty") + + # Tolerant weight check (matches original) + total_weight = sum(weight for _, weight, _ in investments) + if abs(total_weight - 1.0) > 1e-10: + raise ValueError("Portfolio weights must sum to 1.0") + + # Same weighted return as original + weighted_return = sum(weight * ret for _, weight, ret in investments) + + # Same volatility formula as original + volatility = math.sqrt(sum((weight * ret) ** 2 for _, weight, ret in investments)) + + # Same Sharpe ratio logic + if volatility == 0: + sharpe_ratio = 0.0 + else: + sharpe_ratio = (weighted_return - risk_free_rate) / volatility + + # Same best/worst logic (based on return only) + best_asset = max(investments, key=lambda x: x[2]) + worst_asset = min(investments, key=lambda x: x[2]) + + return { + 
"weighted_return": round(weighted_return, 6), + "volatility": round(volatility, 6), + "sharpe_ratio": round(sharpe_ratio, 6), + "best_performing": (best_asset[0], round(best_asset[2], 6)), + "worst_performing": (worst_asset[0], round(worst_asset[2], 6)), + "total_assets": len(investments), + } +""" + with fto_file_path.open("w") as f: + f.write(new_fixed_code) + candidate_fto_code = Path(fto.file_path).read_text("utf-8") + candidate_helper_code = {} + for file_path in file_path_to_helper_class: + candidate_helper_code[file_path] = Path(file_path).read_text("utf-8") + file_path_to_helper_classes = {} + instrument_codeflash_capture(fto, file_path_to_helper_classes, tests_root) + modified_test_results_2, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + ) + # Remove instrumentation + FunctionOptimizer.write_code_and_helpers(candidate_fto_code, candidate_helper_code, fto.file_path) + matched, diffs = compare_test_results(test_results, modified_test_results_2) + # now the test should match and no diffs should be found + assert len(diffs) == 0 + assert matched + + finally: + test_path.unlink(missing_ok=True) + fto_file_path.unlink(missing_ok=True) \ No newline at end of file diff --git a/tests/test_comparator.py b/tests/test_comparator.py index 06d178f95..6c2781229 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -1176,7 +1176,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_1) + match, _ = compare_test_results(original_results, new_results_1) + assert match new_results_2 = TestResults() new_results_2.add( @@ -1199,7 +1200,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_2) + match, _ = compare_test_results(original_results, new_results_2) + assert not match new_results_3 = TestResults() new_results_3.add( @@ -1241,7 +1243,8 @@ def test_compare_results_fn(): ) ) - assert compare_test_results(original_results, new_results_3) + match, _ = compare_test_results(original_results, new_results_3) + assert match new_results_4 = TestResults() new_results_4.add( @@ -1264,7 +1267,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(original_results, new_results_4) + match, _ = compare_test_results(original_results, new_results_4) + assert not match new_results_5_baseline = TestResults() new_results_5_baseline.add( @@ -1308,7 +1312,8 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_5_baseline, new_results_5_opt) + match, _ = compare_test_results(new_results_5_baseline, new_results_5_opt) + assert not match new_results_6_baseline = TestResults() new_results_6_baseline.add( @@ -1352,9 +1357,11 @@ def test_compare_results_fn(): ) ) - assert not compare_test_results(new_results_6_baseline, new_results_6_opt) + match, _ = compare_test_results(new_results_6_baseline, new_results_6_opt) + assert not match - assert not compare_test_results(TestResults(), TestResults()) + match, _ = compare_test_results(TestResults(), TestResults()) + assert not match def test_exceptions(): diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index ece7d38b0..7bdfa364b 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -223,7 +223,8 @@ def test_sort(): result: [0, 1, 2, 3, 4, 5] 
""" assert out_str == results2[0].stdout - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) @@ -368,7 +369,8 @@ def test_sort(): assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],) out_str = """codeflash stdout : BubbleSorter.sorter() called\n""" assert test_results[1].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[2].id.function_getting_tested == "BubbleSorter.__init__" assert test_results[2].id.test_function_name == "test_sort" assert test_results[2].did_pass @@ -396,7 +398,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match # Replace with optimized code that mutated instance attribute optimized_code = """ @@ -491,7 +494,8 @@ def sorter(self, arr): ) assert new_test_results[3].runtime > 0 assert new_test_results[3].did_pass - assert not compare_test_results(test_results, new_test_results) + match, _ = compare_test_results(test_results, new_test_results) + assert not match finally: fto_path.write_text(original_code, "utf-8") @@ -630,7 +634,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_classmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_classmethod" assert test_results[1].id.iteration_id == "4_0" @@ -655,7 +660,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") @@ -794,7 +800,8 @@ def test_sort(): out_str = """codeflash stdout : BubbleSorter.sorter_staticmethod() called """ assert test_results[0].stdout == out_str - assert compare_test_results(test_results, test_results) + match, _ = compare_test_results(test_results, test_results) + assert match assert test_results[1].id.function_getting_tested == "BubbleSorter.sorter_staticmethod" assert test_results[1].id.iteration_id == "4_0" @@ -819,7 +826,8 @@ def test_sort(): testing_time=0.1, ) - assert compare_test_results(test_results, results2) + match, _ = compare_test_results(test_results, results2) + assert match finally: fto_path.write_text(original_code, "utf-8") diff --git a/tests/test_instrumentation_run_results_aiservice.py b/tests/test_instrumentation_run_results_aiservice.py index cae2c76f1..03556718d 100644 --- a/tests/test_instrumentation_run_results_aiservice.py +++ b/tests/test_instrumentation_run_results_aiservice.py @@ -221,10 +221,10 @@ def sorter(self, arr): testing_time=0.1, ) # assert test_results_mutated_attr[0].return_value[1]["self"].x == 1 TODO: add self as input to function - assert compare_test_results( + match, _ = compare_test_results( test_results, test_results_mutated_attr ) # Without codeflash capture, the init state was not verified, and the results are verified as correct even with the attribute mutated - + assert match assert test_results_mutated_attr[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n" finally: fto_path.write_text(original_code, "utf-8") @@ -403,9 +403,10 @@ def 
sorter(self, arr): assert test_results_mutated_attr[0].return_value[0] == {"x": 1} assert test_results_mutated_attr[0].verification_type == VerificationType.INIT_STATE_FTO assert test_results_mutated_attr[0].stdout == "" - assert not compare_test_results( + match,_ = compare_test_results( test_results, test_results_mutated_attr ) # The test should fail because the instance attribute was mutated + assert not match # Replace with optimized code that did not mutate existing instance attribute, but added a new one optimized_code_new_attr = """ import sys @@ -457,9 +458,10 @@ def sorter(self, arr): assert test_results_new_attr[0].stdout == "" # assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input # assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input - assert compare_test_results( + match,_ = compare_test_results( test_results, test_results_new_attr ) # The test should pass because the instance attribute was not mutated, only a new one was added + assert match finally: fto_path.write_text(original_code, "utf-8") test_path.unlink(missing_ok=True) diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py index c67883c12..c05384d03 100644 --- a/tests/test_pickle_patcher.py +++ b/tests/test_pickle_patcher.py @@ -427,8 +427,8 @@ def bubble_sort_with_unused_socket(data_container): testing_time=1.0, ) assert len(optimized_test_results_unused_socket) == 1 - verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) - assert verification_result is True + match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) + assert match # Remove the previous instrumentation replay_test_path.write_text(original_replay_test_code) @@ -517,8 +517,8 @@ def bubble_sort_with_used_socket(data_container): assert test_results_used_socket.test_results[0].did_pass is False # Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined. - assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False - + match, _ = compare_test_results(test_results_used_socket, optimized_test_results_used_socket) + assert not match finally: # cleanup output_file.unlink(missing_ok=True)
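
The patch changes compare_test_results to return a (match, diffs) tuple and feeds the resulting TestDiff objects into a code-repair request when a candidate only partially disagrees with the baseline. The standalone sketch below is not part of the patch: compare_test_results_stub and the dict-based "results" are simplified stand-ins invented for illustration, and only the TestDiff/TestDiffScope/OptimizedCandidateSource shapes plus the 50%-mismatch gate mirror the diff above. Note the str mixin is listed before Enum, which is the ordering Python's enum machinery requires for a string-valued enum.

from dataclasses import dataclass
from enum import Enum
from typing import Optional


class TestDiffScope(str, Enum):
    RETURN_VALUE = "return_value"
    STDOUT = "stdout"
    DID_PASS = "did_pass"


# For a string-valued enum the str mixin must precede Enum; reversing the
# bases raises a TypeError when the class is created.
class OptimizedCandidateSource(str, Enum):
    OPTIMIZE = "OPTIMIZE"
    OPTIMIZE_LP = "OPTIMIZE_LP"
    REFINE = "REFINE"
    REPAIR = "REPAIR"


@dataclass
class TestDiff:
    scope: TestDiffScope
    original_pass: bool
    candidate_pass: bool
    original_value: Optional[str] = None
    candidate_value: Optional[str] = None


def compare_test_results_stub(original: dict, candidate: dict) -> tuple[bool, list[TestDiff]]:
    """Toy stand-in for the new compare_test_results signature.

    Instead of a single bool it returns (match, diffs), where each TestDiff
    records what disagreed between the original and the candidate run.
    """
    diffs: list[TestDiff] = []
    for test_id, original_value in original.items():
        candidate_value = candidate.get(test_id)
        if candidate_value != original_value:
            diffs.append(
                TestDiff(
                    scope=TestDiffScope.RETURN_VALUE,
                    original_pass=True,
                    candidate_pass=False,
                    original_value=repr(original_value),
                    candidate_value=repr(candidate_value),
                )
            )
    return len(diffs) == 0, diffs


if __name__ == "__main__":
    baseline = {"test_a": 1, "test_b": 2, "test_c": 3}
    candidate_results = {"test_a": 1, "test_b": 99, "test_c": 3}

    match, diffs = compare_test_results_stub(baseline, candidate_results)
    unmatched_ratio = len(diffs) / len(candidate_results)

    # Mirrors the gating in run_optimized_candidate: candidates that already
    # came from a repair, or that mismatch on more than half the tests, are
    # rejected outright; otherwise the diffs feed a repair request.
    source = OptimizedCandidateSource.OPTIMIZE
    if match:
        print("Test results matched")
    elif source == OptimizedCandidateSource.REPAIR or unmatched_ratio > 0.5:
        print("Rejected: repair attempt or too many mismatched tests")
    else:
        print(f"Queueing repair with {len(diffs)} TestDiff entries")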
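
parse_test_failures_from_stdout slices pytest's FAILURES section into per-test error messages using two regexes. The sketch below is a minimal, self-contained illustration of that parsing idea; SAMPLE_STDOUT and extract_failures are invented for this example and are deliberately simpler than the real function, which attaches the resulting mapping to TestResults.test_failures.

import re

# Same header patterns as parse_test_failures_from_stdout.
FAILURES_HEADER_RE = re.compile(r"=+ FAILURES =+")
TEST_HEADER_RE = re.compile(r"_{3,}\s*(.*?)\s*_{3,}$")

SAMPLE_STDOUT = """\
collected 2 items

test_portfolio.py .F                                                  [100%]

=================================== FAILURES ===================================
____________________________ test_zero_volatility _____________________________

    def test_zero_volatility():
>       assert result['volatility'] == 0.0
E       assert 2 == 0.0
=========================== short test summary info ============================
FAILED test_portfolio.py::test_zero_volatility - assert 2 == 0.0
"""


def extract_failures(stdout: str) -> dict[str, str]:
    """Group the FAILURES block of pytest stdout by test name (simplified)."""
    failures: dict[str, str] = {}
    current_name = None
    current_lines: list[str] = []
    in_failures = False

    for line in stdout.splitlines():
        stripped = line.strip()
        if not in_failures:
            # Skip everything until the === FAILURES === header.
            in_failures = bool(FAILURES_HEADER_RE.search(stripped))
            continue
        # Any new ===...=== section (e.g. the short test summary) ends the block.
        if stripped.startswith("=") and stripped.count("=") > 3:
            break
        m = TEST_HEADER_RE.match(stripped)
        if m:
            if current_name is not None:
                failures[current_name] = "\n".join(current_lines)
            current_name = m.group(1)
            current_lines = []
        elif current_name:
            current_lines.append(line)

    if current_name:
        failures[current_name] = "\n".join(current_lines)
    return failures


if __name__ == "__main__":
    for name, message in extract_failures(SAMPLE_STDOUT).items():
        print(name)
        print(message)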