diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py index abbcb68c1..7e8983b3b 100644 --- a/codeflash/code_utils/edit_generated_tests.py +++ b/codeflash/code_utils/edit_generated_tests.py @@ -4,7 +4,7 @@ import os import re from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import libcst as cst from libcst import MetadataWrapper @@ -149,7 +149,7 @@ def leave_SimpleStatementSuite( return updated_node -def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]: +def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]: unique_inv_ids: dict[str, int] = {} for inv_id, runtimes in inv_id_runtimes.items(): test_qualified_name = ( @@ -157,10 +157,11 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, i if inv_id.test_class_name else inv_id.test_function_name ) - abs_path = str(Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve().with_suffix("")) - if "__unit_test_" not in abs_path: + abs_path = tests_project_rootdir / Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py") + abs_path_str = str(abs_path.resolve().with_suffix("")) + if "__unit_test_" not in abs_path_str or not test_qualified_name: continue - key = test_qualified_name + "#" + abs_path # type: ignore[operator] + key = test_qualified_name + "#" + abs_path_str parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr] cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr] match_key = key + "#" + cur_invid @@ -174,10 +175,11 @@ def add_runtime_comments_to_generated_tests( generated_tests: GeneratedTestsList, original_runtimes: dict[InvocationId, list[int]], optimized_runtimes: dict[InvocationId, list[int]], + tests_project_rootdir: Optional[Path] = None, ) -> GeneratedTestsList: """Add runtime performance comments to function calls in generated tests.""" - original_runtimes_dict = unique_inv_id(original_runtimes) - optimized_runtimes_dict = unique_inv_id(optimized_runtimes) + original_runtimes_dict = unique_inv_id(original_runtimes, tests_project_rootdir or Path()) + optimized_runtimes_dict = unique_inv_id(optimized_runtimes, tests_project_rootdir or Path()) # Process each generated test modified_tests = [] for test in generated_tests.generated_tests: diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index c54ceaafd..10744c631 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -338,21 +338,19 @@ def initialize_function_optimization( ) -> dict[str, str]: document_uri = params.textDocument.uri document = server.workspace.get_text_document(document_uri) + file_path = Path(document.path) server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info") if server.optimizer is None: _initialize_optimizer_if_api_key_is_valid(server) - server.optimizer.worktree_mode() - - original_args, _ = server.optimizer.original_args_and_test_cfg - + server.optimizer.args.file = file_path server.optimizer.args.function = params.functionName - original_relative_file_path = Path(document.path).relative_to(original_args.project_root) - server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path server.optimizer.args.previous_checkpoint_functions = False + server.optimizer.worktree_mode() + server.show_message_log( f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info" ) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 86e9bf33f..1825798bd 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1376,7 +1376,7 @@ def process_review( ) generated_tests = add_runtime_comments_to_generated_tests( - generated_tests, original_runtime_by_test, optimized_runtime_by_test + generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir ) generated_tests_str = "\n#------------------------------------------------\n".join(