From eb19db3a2dc8bb37e84cfca670ecfe925a06c45a Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Wed, 6 May 2026 21:59:09 -0500
Subject: [PATCH 1/5] perf: use set for O(1) membership check in TestFiles.add()

The previous implementation used `if test_file not in self.test_files`,
which performs an O(n) scan with full Pydantic equality comparisons
across all TestFile objects. Replace it with a `_seen_paths` set keyed
on `instrumented_behavior_file_path` for O(1) deduplication lookups.
---
 codeflash/models/models.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 640e5230a..1fd6fbd7d 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -426,9 +426,15 @@ class TestFile(BaseModel):
 
 class TestFiles(BaseModel):
     test_files: list[TestFile]
+    _seen_paths: set[Path] = PrivateAttr(default_factory=set)
+
+    def model_post_init(self, __context: Any, /) -> None:
+        self._seen_paths = {tf.instrumented_behavior_file_path for tf in self.test_files}
 
     def add(self, test_file: TestFile) -> None:
-        if test_file not in self.test_files:
+        key = test_file.instrumented_behavior_file_path
+        if key not in self._seen_paths:
+            self._seen_paths.add(key)
             self.test_files.append(test_file)
         else:
             msg = "Test file already exists in the list"

From 0dbdc1cfc923cc2c18daa73741190a1282d9f4de Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Wed, 6 May 2026 21:59:12 -0500
Subject: [PATCH 2/5] test: add unit tests for TestFiles.add() deduplication

---
 tests/test_test_files_add.py | 46 ++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 tests/test_test_files_add.py

diff --git a/tests/test_test_files_add.py b/tests/test_test_files_add.py
new file mode 100644
index 000000000..be4305b6a
--- /dev/null
+++ b/tests/test_test_files_add.py
@@ -0,0 +1,46 @@
+from pathlib import Path
+
+import pytest
+
+from codeflash.models.models import TestFile, TestFiles
+from codeflash.models.test_type import TestType
+
+
+class TestTestFilesAdd:
+    def test_add_unique_test_file(self) -> None:
+        tf = TestFiles(test_files=[])
+        test_file = TestFile(
+            instrumented_behavior_file_path=Path("/tmp/test_behavior.py"),
+            benchmarking_file_path=Path("/tmp/test_perf.py"),
+            test_type=TestType.GENERATED_REGRESSION,
+        )
+        tf.add(test_file)
+        assert len(tf.test_files) == 1
+        assert tf.test_files[0] is test_file
+
+    def test_add_duplicate_raises(self) -> None:
+        tf = TestFiles(test_files=[])
+        test_file = TestFile(
+            instrumented_behavior_file_path=Path("/tmp/test_behavior.py"),
+            benchmarking_file_path=Path("/tmp/test_perf.py"),
+            test_type=TestType.GENERATED_REGRESSION,
+        )
+        tf.add(test_file)
+        with pytest.raises(ValueError, match="Test file already exists"):
+            tf.add(test_file)
+
+    def test_add_many_files_performance(self) -> None:
+        tf = TestFiles(test_files=[])
+        for i in range(100):
+            test_file = TestFile(
+                instrumented_behavior_file_path=Path(f"/tmp/test_behavior_{i}.py"),
+                benchmarking_file_path=Path(f"/tmp/test_perf_{i}.py"),
+                test_type=TestType.GENERATED_REGRESSION,
+            )
+            tf.add(test_file)
+
+        assert len(tf.test_files) == 100
+        assert len(tf._seen_paths) == 100
+        # Verify all paths are unique in the set
+        expected_paths = {Path(f"/tmp/test_behavior_{i}.py") for i in range(100)}
+        assert tf._seen_paths == expected_paths

From 7364daf7d4c077e3cc3af45a342e07a75cabbc86 Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Wed, 6 May 2026 23:03:56 -0500
Subject: [PATCH 3/5] fix: silently skip duplicate TestFiles.add() instead of
 raising

Java test
discovery adds the same instrumented_behavior_file_path with different fields. The old Pydantic __eq__ treated them as different; the set-based dedup correctly identifies them as the same test file. Since get_test_type_by_instrumented_file_path() returns on first path match anyway, duplicates by path are dead weight. Silent skip (first write wins) is both correct and O(1). --- codeflash/models/models.py | 3 --- tests/test_test_files_add.py | 8 +++----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 1fd6fbd7d..09cdf1cfe 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -436,9 +436,6 @@ def add(self, test_file: TestFile) -> None: if key not in self._seen_paths: self._seen_paths.add(key) self.test_files.append(test_file) - else: - msg = "Test file already exists in the list" - raise ValueError(msg) def get_by_original_file_path(self, file_path: Path) -> TestFile | None: normalized = self._normalize_path_for_comparison(file_path) diff --git a/tests/test_test_files_add.py b/tests/test_test_files_add.py index be4305b6a..05e3710d6 100644 --- a/tests/test_test_files_add.py +++ b/tests/test_test_files_add.py @@ -1,7 +1,5 @@ from pathlib import Path -import pytest - from codeflash.models.models import TestFile, TestFiles from codeflash.models.test_type import TestType @@ -18,7 +16,7 @@ def test_add_unique_test_file(self) -> None: assert len(tf.test_files) == 1 assert tf.test_files[0] is test_file - def test_add_duplicate_raises(self) -> None: + def test_add_duplicate_is_noop(self) -> None: tf = TestFiles(test_files=[]) test_file = TestFile( instrumented_behavior_file_path=Path("/tmp/test_behavior.py"), @@ -26,8 +24,8 @@ def test_add_duplicate_raises(self) -> None: test_type=TestType.GENERATED_REGRESSION, ) tf.add(test_file) - with pytest.raises(ValueError, match="Test file already exists"): - tf.add(test_file) + tf.add(test_file) # silent skip — first write wins + assert len(tf.test_files) == 1 def test_add_many_files_performance(self) -> None: tf = TestFiles(test_files=[]) From a563e702fa413401576bbb71bdef60147564dcf5 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 6 May 2026 22:57:55 -0500 Subject: [PATCH 4/5] fix: resolve mypy strict errors in models.py --- codeflash/models/models.py | 39 ++++++++++++++++++++------------------ 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 09cdf1cfe..3338f3841 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -9,7 +9,7 @@ from functools import lru_cache from pathlib import Path from re import Pattern -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast +from typing import TYPE_CHECKING, Any, NamedTuple, Optional from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator from pydantic.dataclasses import dataclass @@ -17,7 +17,7 @@ from codeflash.models.test_type import TestType if TYPE_CHECKING: - from collections.abc import Iterator + from collections.abc import Generator import libcst as cst from rich.tree import Tree @@ -298,11 +298,13 @@ def flat(self) -> str: """ if self._cache.get("flat") is not None: - return self._cache["flat"] - self._cache["flat"] = "\n".join( + result: str = self._cache["flat"] + return result + flat: str = "\n".join( get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings ) - return self._cache["flat"] + self._cache["flat"] = flat + 
return flat @property def markdown(self) -> str: @@ -332,7 +334,8 @@ def file_to_path(self) -> dict[str, str]: """ try: - return self._cache["file_to_path"] + cached: dict[str, str] = self._cache["file_to_path"] + return cached except KeyError: mapping = {str(code_string.file_path): code_string.code for code_string in self.code_strings} self._cache["file_to_path"] = mapping @@ -497,8 +500,8 @@ def _normalize_path_for_comparison(path: Path) -> str: # Only lowercase on Windows where filesystem is case-insensitive return resolved.lower() if sys.platform == "win32" else resolved - def __iter__(self) -> Iterator[TestFile]: - return iter(self.test_files) + def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058 + yield from self.test_files def __len__(self) -> int: return len(self.test_files) @@ -517,9 +520,9 @@ class CandidateEvaluationContext: optimized_runtimes: dict[str, float | None] = Field(default_factory=dict) is_correct: dict[str, bool] = Field(default_factory=dict) optimized_line_profiler_results: dict[str, str] = Field(default_factory=dict) - ast_code_to_id: dict = Field(default_factory=dict) + ast_code_to_id: dict[str, Any] = Field(default_factory=dict) optimizations_post: dict[str, str] = Field(default_factory=dict) - valid_optimizations: list = Field(default_factory=list) + valid_optimizations: list[Any] = Field(default_factory=list) def record_failed_candidate(self, optimization_id: str) -> None: """Record results for a failed candidate.""" @@ -546,7 +549,7 @@ def handle_duplicate_candidate( # Copy results from the previous evaluation (use .get() in case past_opt_id was registered # but never benchmarked due to an unhandled exception in process_single_candidate) self.speedup_ratios[candidate.optimization_id] = self.speedup_ratios.get(past_opt_id) - self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id) + self.is_correct[candidate.optimization_id] = self.is_correct.get(past_opt_id, False) self.optimized_runtimes[candidate.optimization_id] = self.optimized_runtimes.get(past_opt_id) # Line profiler results only available for successful runs @@ -634,7 +637,7 @@ class OriginalCodeBaseline(BaseModel): behavior_test_results: TestResults benchmarking_test_results: TestResults replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None - line_profile_results: dict + line_profile_results: dict[str, Any] runtime: int coverage_results: Optional[CoverageData] async_throughput: Optional[int] = None @@ -796,7 +799,7 @@ def get_src_code(self, test_path: Path) -> Optional[str]: f"// Testing function: {self.function_getting_tested}" ) - if self.test_class_name: + if self.test_class_name and self.test_function_name: for stmt in module_node.body: if isinstance(stmt, cst.ClassDef) and stmt.name.value == self.test_class_name: func_node = self.find_func_in_class(stmt, self.test_function_name) @@ -887,7 +890,7 @@ def group_by_benchmarks( """Group TestResults by benchmark for calculating improvements for each benchmark.""" from codeflash.code_utils.code_utils import module_name_from_file_path - test_results_by_benchmark = defaultdict(TestResults) + test_results_by_benchmark: defaultdict[BenchmarkKey, TestResults] = defaultdict(TestResults) benchmark_module_path = {} for benchmark_key in benchmark_keys: benchmark_module_path[benchmark_key] = module_name_from_file_path( @@ -1018,7 +1021,7 @@ def effective_loop_count(self) -> int: return max(loop_indices) if loop_indices else 0 def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> 
Counter[Path]: - map_gen_test_file_to_no_of_tests = Counter() + map_gen_test_file_to_no_of_tests: Counter[Path] = Counter() for gen_test_result in self.test_results: if ( gen_test_result.test_type == TestType.GENERATED_REGRESSION @@ -1027,8 +1030,8 @@ def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Pa map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1 return map_gen_test_file_to_no_of_tests - def __iter__(self) -> Iterator[FunctionTestInvocation]: - return iter(self.test_results) + def __iter__(self) -> Generator[Any, None, None]: # noqa: PYI058 + yield from self.test_results def __len__(self) -> int: return len(self.test_results) @@ -1054,7 +1057,7 @@ def __eq__(self, other: object) -> bool: if len(self) != len(other): return False original_recursion_limit = sys.getrecursionlimit() - cast("TestResults", other) + assert isinstance(other, TestResults) for test_result in self: other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id) if other_test_result is None: From 9459c5a3a88dea485824b3c78d1cb4e610c2d0bc Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Wed, 6 May 2026 23:46:53 -0500 Subject: [PATCH 5/5] fix: surface subprocess errors when pytest XML is missing When the test subprocess exits non-zero and produces no JUnit XML, log the return code and stdout/stderr at WARNING level so the root cause is visible in CI logs. Previously this was a generic "No test results found" message that made Windows CI flakes impossible to diagnose. Also fixes pre-existing mypy strict errors in parse_xml.py: - Add return type to _parse_func - Type CompletedProcess[str] (subprocess uses text=True) - Parameterize generic types (tuple, re.Match) - Remove dead .decode() branches (stdout is already str) --- codeflash/languages/python/parse_xml.py | 34 ++++++++++++------------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/codeflash/languages/python/parse_xml.py b/codeflash/languages/python/parse_xml.py index 840fa2055..b080608b4 100644 --- a/codeflash/languages/python/parse_xml.py +++ b/codeflash/languages/python/parse_xml.py @@ -9,7 +9,7 @@ import os import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from junitparser.xunit2 import JUnitXml @@ -48,7 +48,7 @@ ) -def _parse_func(file_path: Path): +def _parse_func(file_path: Path) -> Any: from lxml.etree import XMLParser, parse xml_parser = XMLParser(huge_tree=True) @@ -59,13 +59,22 @@ def parse_python_test_xml( test_xml_file_path: Path, test_files: TestFiles, test_config: TestConfig, - run_result: subprocess.CompletedProcess | None = None, + run_result: subprocess.CompletedProcess[str] | None = None, ) -> TestResults: from codeflash.verification.parse_test_output import resolve_test_file_from_class_path test_results = TestResults() if not test_xml_file_path.exists(): - logger.warning(f"No test results for {test_xml_file_path} found.") + if run_result is not None and run_result.returncode != 0: + stderr_snippet = (run_result.stderr or "")[:500] + stdout_snippet = (run_result.stdout or "")[:500] + logger.warning( + f"No test results for {test_xml_file_path} found. 
" + f"Subprocess exited with code {run_result.returncode}.\n" + f"stdout: {stdout_snippet}\nstderr: {stderr_snippet}" + ) + else: + logger.warning(f"No test results for {test_xml_file_path} found.") console.rule() return test_results try: @@ -87,12 +96,7 @@ def parse_python_test_xml( ): logger.info("Test failed to load, skipping it.") if run_result is not None: - if isinstance(run_result.stdout, str) and isinstance(run_result.stderr, str): - logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}") - else: - logger.info( - f"Test log - STDOUT : {run_result.stdout.decode()} \n STDERR : {run_result.stderr.decode()}" - ) + logger.info(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}") return test_results test_class_path = testcase.classname @@ -159,7 +163,7 @@ def parse_python_test_xml( sys_stdout = testcase.system_out or "" begin_matches = list(matches_re_start.finditer(sys_stdout)) - end_matches: dict[tuple, re.Match] = {} + end_matches: dict[tuple[str, ...], re.Match[str]] = {} for match in matches_re_end.finditer(sys_stdout): groups = match.groups() if len(groups[5].split(":")) > 1: @@ -234,11 +238,5 @@ def parse_python_test_xml( f"Tests '{[test_file.original_file_path for test_file in test_files.test_files]}' failed to run, skipping" ) if run_result is not None: - stdout, stderr = "", "" - try: - stdout = run_result.stdout.decode() - stderr = run_result.stderr.decode() - except AttributeError: - stdout = run_result.stderr - logger.debug(f"Test log - STDOUT : {stdout} \n STDERR : {stderr}") + logger.debug(f"Test log - STDOUT : {run_result.stdout} \n STDERR : {run_result.stderr}") return test_results