Skip to content

Commit ac50ae2

Browse files
authored
Merge branch 'main' into cf-842
2 parents c01acc3 + daa7627 commit ac50ae2

File tree

5 files changed

+514
-11
lines changed

5 files changed

+514
-11
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,8 @@ def get_formatter_cmds(formatter: str) -> list[str]:
10051005
return ["your-formatter $file"]
10061006
if formatter in {"don't use a formatter", "disabled"}:
10071007
return ["disabled"]
1008+
if " && " in formatter:
1009+
return formatter.split(" && ")
10081010
return [formatter]
10091011

10101012

codeflash/lsp/beta.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,25 @@ def get_config_suggestions(_params: any) -> dict[str, any]:
205205
test_framework_suggestions, default_test_framework = get_suggestions(CommonSections.test_framework)
206206
formatter_suggestions, default_formatter = get_suggestions(CommonSections.formatter_cmds)
207207
get_valid_subdirs.cache_clear()
208+
209+
configured_module_root = Path(server.args.module_root).relative_to(Path.cwd()) if server.args.module_root else None
210+
configured_tests_root = Path(server.args.tests_root).relative_to(Path.cwd()) if server.args.tests_root else None
211+
configured_test_framework = server.args.test_framework if server.args.test_framework else None
212+
213+
configured_formatter = ""
214+
if isinstance(server.args.formatter_cmds, list):
215+
configured_formatter = " && ".join([cmd.strip() for cmd in server.args.formatter_cmds])
216+
elif isinstance(server.args.formatter_cmds, str):
217+
configured_formatter = server.args.formatter_cmds.strip()
218+
208219
return {
209-
"module_root": {"choices": module_root_suggestions, "default": default_module_root},
210-
"tests_root": {"choices": tests_root_suggestions, "default": default_tests_root},
211-
"test_framework": {"choices": test_framework_suggestions, "default": default_test_framework},
212-
"formatter_cmds": {"choices": formatter_suggestions, "default": default_formatter},
220+
"module_root": {"choices": module_root_suggestions, "default": configured_module_root or default_module_root},
221+
"tests_root": {"choices": tests_root_suggestions, "default": configured_tests_root or default_tests_root},
222+
"test_framework": {
223+
"choices": test_framework_suggestions,
224+
"default": configured_test_framework or default_test_framework,
225+
},
226+
"formatter_cmds": {"choices": formatter_suggestions, "default": configured_formatter or default_formatter},
213227
}
214228

215229

codeflash/models/models.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
3+
from collections import Counter, defaultdict
44
from typing import TYPE_CHECKING
55

66
from rich.tree import Tree
@@ -675,6 +675,16 @@ def total_passed_runtime(self) -> int:
675675
[min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
676676
)
677677

678+
def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
679+
map_gen_test_file_to_no_of_tests = Counter()
680+
for gen_test_result in self.test_results:
681+
if (
682+
gen_test_result.test_type == TestType.GENERATED_REGRESSION
683+
and gen_test_result.id.test_function_name not in test_functions_to_remove
684+
):
685+
map_gen_test_file_to_no_of_tests[gen_test_result.file_name] += 1
686+
return map_gen_test_file_to_no_of_tests
687+
678688
def __iter__(self) -> Iterator[FunctionTestInvocation]:
679689
return iter(self.test_results)
680690

codeflash/optimization/function_optimizer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,9 @@ def process_review(
14021402
generated_tests = remove_functions_from_generated_tests(
14031403
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
14041404
)
1405+
map_gen_test_file_to_no_of_tests = original_code_baseline.behavior_test_results.file_to_no_of_tests(
1406+
test_functions_to_remove
1407+
)
14051408

14061409
original_runtime_by_test = original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
14071410
optimized_runtime_by_test = (
@@ -1414,11 +1417,12 @@ def process_review(
14141417

14151418
generated_tests_str = ""
14161419
for test in generated_tests.generated_tests:
1417-
formatted_generated_test = format_generated_code(
1418-
test.generated_original_test_source, self.args.formatter_cmds
1419-
)
1420-
generated_tests_str += f"```python\n{formatted_generated_test}\n```"
1421-
generated_tests_str += "\n\n"
1420+
if map_gen_test_file_to_no_of_tests[test.behavior_file_path] > 0:
1421+
formatted_generated_test = format_generated_code(
1422+
test.generated_original_test_source, self.args.formatter_cmds
1423+
)
1424+
generated_tests_str += f"```python\n{formatted_generated_test}\n```"
1425+
generated_tests_str += "\n\n"
14221426

14231427
if concolic_test_str:
14241428
formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds)
@@ -1540,7 +1544,7 @@ def process_review(
15401544
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
15411545
)
15421546

1543-
# If worktree mode, do not revert code and helpers,, otherwise we would have an empty diff when writing the patch in the lsp
1547+
# If worktree mode, do not revert code and helpers, otherwise we would have an empty diff when writing the patch in the lsp
15441548
if self.args.worktree:
15451549
return
15461550

0 commit comments

Comments
 (0)