Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import libcst as cst
from libcst import MetadataWrapper
Expand Down Expand Up @@ -149,15 +149,16 @@ def leave_SimpleStatementSuite(
return updated_node


def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
unique_inv_ids: dict[str, int] = {}
for inv_id, runtimes in inv_id_runtimes.items():
test_qualified_name = (
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
if inv_id.test_class_name
else inv_id.test_function_name
)
abs_path = str(Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve().with_suffix(""))
abs_path = tests_project_rootdir / Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
abs_path = str(abs_path.resolve().with_suffix(""))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to test with different test dirs, some which are a subdirectory of the module root some which are not. I remember i was using the test_dir early on but it didn't work for all scenarioes @mohammedahmed18

if "__unit_test_" not in abs_path:
continue
key = test_qualified_name + "#" + abs_path # type: ignore[operator]
Expand All @@ -174,10 +175,11 @@ def add_runtime_comments_to_generated_tests(
generated_tests: GeneratedTestsList,
original_runtimes: dict[InvocationId, list[int]],
optimized_runtimes: dict[InvocationId, list[int]],
tests_project_rootdir: Optional[Path] = None,
) -> GeneratedTestsList:
"""Add runtime performance comments to function calls in generated tests."""
original_runtimes_dict = unique_inv_id(original_runtimes)
optimized_runtimes_dict = unique_inv_id(optimized_runtimes)
original_runtimes_dict = unique_inv_id(original_runtimes, tests_project_rootdir or Path())
optimized_runtimes_dict = unique_inv_id(optimized_runtimes, tests_project_rootdir or Path())
# Process each generated test
modified_tests = []
for test in generated_tests.generated_tests:
Expand Down
9 changes: 3 additions & 6 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,12 @@ def initialize_function_optimization(
if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid(server)

server.optimizer.worktree_mode()

original_args, _ = server.optimizer.original_args_and_test_cfg

server.optimizer.args.file = file_path
server.optimizer.args.function = params.functionName
original_relative_file_path = file_path.relative_to(original_args.project_root)
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
server.optimizer.args.previous_checkpoint_functions = False

server.optimizer.worktree_mode()

server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,7 @@ def process_review(
)

generated_tests = add_runtime_comments_to_generated_tests(
generated_tests, original_runtime_by_test, optimized_runtime_by_test
generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir
)

generated_tests_str = "\n#------------------------------------------------\n".join(
Expand Down
Loading