diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index 92f2be8a1..c324b2b05 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -23,12 +23,18 @@ from git import Repo -def get_git_diff(repo_directory: Path | None = None, *, uncommitted_changes: bool = False) -> dict[str, list[int]]: +def get_git_diff( + repo_directory: Path | None = None, *, only_this_commit: Optional[str] = None, uncommitted_changes: bool = False +) -> dict[str, list[int]]: if repo_directory is None: repo_directory = Path.cwd() repository = git.Repo(repo_directory, search_parent_directories=True) commit = repository.head.commit - if uncommitted_changes: + if only_this_commit: + uni_diff_text = repository.git.diff( + only_this_commit + "^1", only_this_commit, ignore_blank_lines=True, ignore_space_at_eol=True + ) + elif uncommitted_changes: uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True) else: uni_diff_text = repository.git.diff( diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 27a46af0a..9f4db7b5e 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -232,7 +232,16 @@ def get_functions_to_optimize( def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001 modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes) - modified_functions: dict[str, list[FunctionToOptimize]] = {} + return get_functions_within_lines(modified_lines) + + +def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]: + modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash) + return get_functions_within_lines(modified_lines) + + +def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str, list[FunctionToOptimize]]: + functions: dict[str, list[FunctionToOptimize]] = {} for path_str, lines_in_file in modified_lines.items(): path = Path(path_str) if not path.exists(): @@ -246,14 +255,14 @@ def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[F continue function_lines = FunctionVisitor(file_path=str(path)) wrapper.visit(function_lines) - modified_functions[str(path)] = [ + functions[str(path)] = [ function_to_optimize for function_to_optimize in function_lines.functions if (start_line := function_to_optimize.starting_line) is not None and (end_line := function_to_optimize.ending_line) is not None and any(start_line <= line <= end_line for line in lines_in_file) ] - return modified_functions + return functions def get_all_files_and_functions(module_root_path: Path) -> dict[str, list[FunctionToOptimize]]: diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 433603d36..a68688ed6 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -13,7 +13,11 @@ from codeflash.cli_cmds.cli import process_pyproject_config from codeflash.code_utils.git_utils import create_diff_patch_from_worktree from codeflash.code_utils.shell_utils import save_api_key_to_rc -from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff +from codeflash.discovery.functions_to_optimize import ( + filter_functions, + get_functions_inside_a_commit, + get_functions_within_git_diff, +) from codeflash.either import is_successful from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol @@ -22,6 +26,8 @@ from lsprotocol import types + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + @dataclass class OptimizableFunctionsParams: @@ -39,6 +45,11 @@ class ProvideApiKeyParams: api_key: str +@dataclass +class OptimizableFunctionsInCommitParams: + commit_hash: str + + server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol) @@ -47,6 +58,22 @@ def get_functions_in_current_git_diff( server: CodeflashLanguageServer, _params: OptimizableFunctionsParams ) -> dict[str, str | dict[str, list[str]]]: functions = get_functions_within_git_diff(uncommitted_changes=True) + file_to_qualified_names = _group_functions_by_file(server, functions) + return {"functions": file_to_qualified_names, "status": "success"} + + +@server.feature("getOptimizableFunctionsInCommit") +def get_functions_in_commit( + server: CodeflashLanguageServer, params: OptimizableFunctionsInCommitParams +) -> dict[str, str | dict[str, list[str]]]: + functions = get_functions_inside_a_commit(params.commit_hash) + file_to_qualified_names = _group_functions_by_file(server, functions) + return {"functions": file_to_qualified_names, "status": "success"} + + +def _group_functions_by_file( + server: CodeflashLanguageServer, functions: dict[str, list[FunctionToOptimize]] +) -> dict[str, list[str]]: file_to_funcs_to_optimize, _ = filter_functions( modified_functions=functions, tests_root=server.optimizer.test_cfg.tests_root, @@ -58,7 +85,7 @@ def get_functions_in_current_git_diff( file_to_qualified_names: dict[str, list[str]] = { str(path): [f.qualified_name for f in funcs] for path, funcs in file_to_funcs_to_optimize.items() } - return {"functions": file_to_qualified_names, "status": "success"} + return file_to_qualified_names @server.feature("getOptimizableFunctions")