Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion codeflash/models/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from collections import Counter, defaultdict
from typing import TYPE_CHECKING

from rich.tree import Tree
Expand Down Expand Up @@ -675,6 +675,16 @@ def total_passed_runtime(self) -> int:
[min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
)

def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
map_gen_test_file_to_no_of_tests = Counter()
for gen_test_result in self.test_results:
if (
gen_test_result.test_type == TestType.GENERATED_REGRESSION
and gen_test_result.id.test_function_name not in test_functions_to_remove
):
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
return map_gen_test_file_to_no_of_tests

def __iter__(self) -> Iterator[FunctionTestInvocation]:
return iter(self.test_results)

Expand Down
16 changes: 10 additions & 6 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,9 @@ def process_review(
generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
)
map_gen_test_file_to_no_of_tests = original_code_baseline.behavior_test_results.file_to_no_of_tests(
test_functions_to_remove
)

original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
optimized_runtime_by_test = (
Expand All @@ -1413,11 +1416,12 @@ def process_review(

generated_tests_str = ""
for test in generated_tests.generated_tests:
formatted_generated_test = format_generated_code(
test.generated_original_test_source, self.args.formatter_cmds
)
generated_tests_str += f"```python\n{formatted_generated_test}\n```"
generated_tests_str += "\n\n"
if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0:
formatted_generated_test = format_generated_code(
test.generated_original_test_source, self.args.formatter_cmds
)
generated_tests_str += f"```python\n{formatted_generated_test}\n```"
generated_tests_str += "\n\n"

if concolic_test_str:
formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
Expand Down Expand Up @@ -1537,7 +1541,7 @@ def process_review(
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
)

# If worktree mode, do not revert code and helpers,, otherwise we would have an empty diff when writing the patch in the lsp
# If worktree mode, do not revert code and helpers, otherwise we would have an empty diff when writing the patch in the lsp
if self.args.worktree:
return

Expand Down
Loading
Loading