diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 6b478f148..f1928a9ac 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -532,7 +532,7 @@ def generate_regression_tests( # noqa: D417 ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text}) return None - def get_optimization_impact( + def get_optimization_review( self, original_code: dict[Path, str], new_code: dict[Path, str], @@ -544,8 +544,9 @@ def get_optimization_impact( replay_tests: str, root_dir: Path, concolic_tests: str, # noqa: ARG002 + calling_fn_details: str, ) -> str: - """Compute the optimization impact of current Pull Request. + """Compute the optimization review of current Pull Request. Args: original_code: dict -> data structure mapping file paths to function definition for original code @@ -558,10 +559,11 @@ def get_optimization_impact( replay_tests: str -> replay test table root_dir: Path -> path of git directory concolic_tests: str -> concolic_tests (not used) + calling_fn_details: str -> filenames and definitions of functions which call the function_to_optimize Returns: ------- - - 'high' or 'low' optimization impact + - 'high', 'medium' or 'low' optimization review """ diff_str = "\n".join( @@ -577,14 +579,7 @@ def get_optimization_impact( ] ) code_diff = f"```diff\n{diff_str}\n```" - # TODO get complexity metrics and fn call heuristics -> constructing a complete static call graph can be expensive for really large repos - # grep function name in codebase -> ast parser to get no of calls and no of calls in loop -> radon lib to get complexity metrics -> send as additional context to the AI service - # metric 1 -> call count - how many times the function is called in the codebase - # metric 2 -> loop call count - how many times the function is called in a loop in the codebase - # metric 3 -> presence of decorators like @profile, @cache -> this means the owner of the repo cares about the performance of 
this function - # metric 4 -> cyclomatic complexity (https://en.wikipedia.org/wiki/Cyclomatic_complexity) - # metric 5 (for future) -> halstead complexity (https://en.wikipedia.org/wiki/Halstead_complexity_measures) - logger.info("!lsp|Computing Optimization Impact…") + logger.info("!lsp|Computing Optimization Review…") payload = { "code_diff": code_diff, "explanation": explanation.raw_explanation_message, @@ -598,22 +593,23 @@ def get_optimization_impact( "benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None, "optimized_runtime": humanize_runtime(explanation.best_runtime_ns), "original_runtime": humanize_runtime(explanation.original_runtime_ns), + "calling_fn_details": calling_fn_details, } console.rule() try: - response = self.make_ai_service_request("/optimization_impact", payload=payload, timeout=600) + response = self.make_ai_service_request("/optimization_review", payload=payload, timeout=600) except requests.exceptions.RequestException as e: logger.exception(f"Error generating optimization refinements: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) return "" if response.status_code == 200: - return cast("str", response.json()["impact"]) + return cast("str", response.json()["review"]) try: error = cast("str", response.json()["error"]) except Exception: error = response.text - logger.error(f"Error generating impact candidates: {response.status_code} - {error}") + logger.error(f"Error generating optimization review: {response.status_code} - {error}") ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) console.rule() return "" diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index a0a5685b3..d410f75dd 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -130,7 +130,7 @@ def suggest_changes( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", - optimization_impact: str = "", + optimization_review: str = "", ) -> Response: 
"""Suggest changes to a pull request. @@ -156,7 +156,7 @@ def suggest_changes( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, - "optimizationImpact": optimization_impact, + "optimizationImpact": optimization_review, # impact keyword left for legacy reasons, touches js/ts code } return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload) @@ -173,6 +173,7 @@ def create_pr( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", + optimization_review: str = "", ) -> Response: """Create a pull request, targeting the specified branch. (usually 'main'). @@ -197,6 +198,7 @@ def create_pr( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, + "optimizationImpact": optimization_review, # Impact keyword left for legacy reasons, it touches js/ts codebase } return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload) @@ -212,6 +214,7 @@ def create_staging( replay_tests: str, concolic_tests: str, root_dir: Path, + optimization_review: str = "", ) -> Response: """Create a staging pull request, targeting the specified branch. (usually 'staging'). 
@@ -252,6 +255,7 @@ def create_staging( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, + "optimizationImpact": optimization_review, # Impact keyword left for legacy reasons, it touches js/ts codebase } return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 4d6235b0a..0a515a080 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1,21 +1,22 @@ -# ruff: noqa: ARG002 from __future__ import annotations import ast +from dataclasses import dataclass from itertools import chain -from typing import TYPE_CHECKING, Optional +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union +import jedi import libcst as cst from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package from codeflash.cli_cmds.console import logger -from codeflash.models.models import FunctionParent +from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW +from codeflash.models.models import CodePosition, FunctionParent if TYPE_CHECKING: - from pathlib import Path - from libcst.helpers import ModuleNameAndPackage from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -33,28 +34,28 @@ def __init__(self) -> None: self.scope_depth = 0 self.if_else_depth = 0 - def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: + def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002 self.scope_depth += 1 return True - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth -= 1 - def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: + def 
visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002 self.scope_depth += 1 return True - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 self.scope_depth -= 1 - def visit_If(self, node: cst.If) -> Optional[bool]: + def visit_If(self, node: cst.If) -> Optional[bool]: # noqa: ARG002 self.if_else_depth += 1 return True - def leave_If(self, original_node: cst.If) -> None: + def leave_If(self, original_node: cst.If) -> None: # noqa: ARG002 self.if_else_depth -= 1 - def visit_Else(self, node: cst.Else) -> Optional[bool]: + def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002 # Else blocks are already counted as part of the if statement return True @@ -81,24 +82,24 @@ def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: self.scope_depth = 0 self.if_else_depth = 0 - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth += 1 - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002 self.scope_depth -= 1 return updated_node - def visit_ClassDef(self, node: cst.ClassDef) -> None: + def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002 self.scope_depth += 1 - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002 self.scope_depth -= 1 return updated_node - def visit_If(self, node: cst.If) -> None: + def visit_If(self, node: cst.If) -> None: # noqa: ARG002 self.if_else_depth += 1 - def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: + def 
leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: # noqa: ARG002 self.if_else_depth -= 1 return updated_node @@ -146,7 +147,7 @@ def _find_insertion_index(self, updated_node: cst.Module) -> int: return insert_index - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 # Add any new assignments that weren't in the original file new_statements = list(updated_node.body) @@ -190,20 +191,20 @@ def __init__(self) -> None: self.global_statements = [] self.in_function_or_class = False - def visit_ClassDef(self, node: cst.ClassDef) -> bool: + def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002 # Don't visit inside classes self.in_function_or_class = True return False - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 self.in_function_or_class = False - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002 # Don't visit inside functions self.in_function_or_class = True return False - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 self.in_function_or_class = False def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: @@ -284,16 +285,16 @@ def visit_Module(self, node: cst.Module) -> None: self.depth = 0 self._collect_imports_from_block(node) - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.depth += 1 - def leave_FunctionDef(self, node: cst.FunctionDef) -> None: + def leave_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.depth -= 1 - def visit_ClassDef(self, node: cst.ClassDef) -> 
@dataclass
class FunctionCallLocation:
    """Represents a location where the target function is called."""

    calling_function: str  # qualified name of the function that contains the call
    line: int  # `lineno` of the ast.Call node
    column: int  # `col_offset` of the ast.Call node


@dataclass
class FunctionDefinitionInfo:
    """Contains information about a function definition."""

    name: str  # qualified name, prefixed with the enclosing class names when present
    node: ast.FunctionDef  # may also be an ast.AsyncFunctionDef
    source_code: str
    start_line: int
    end_line: int
    is_method: bool
    class_name: Optional[str] = None


class FunctionCallFinder(ast.NodeVisitor):
    """AST visitor that collects every function definition that calls a given target function.

    The match is name based and intentionally generous: a call is attributed to the target when
    its dotted name equals the target name, when it is a bare call to the target's base name, or
    when resolving its first component through the module's imports yields a path ending in the
    target name.

    Args:
        target_function_name: The qualified name of the function to find (e.g., "module.function" or "function")
        target_filepath: The filepath where the target function is defined (stored, not used for matching)
        source_lines: The analysed source split into lines with line endings kept; used to copy
            the callers' source text verbatim

    """

    def __init__(self, target_function_name: str, target_filepath: str, source_lines: list[str]) -> None:
        self.target_function_name = target_function_name
        self.target_filepath = target_filepath
        self.source_lines = source_lines  # Store original source lines for extraction

        # Parse the target function name into parts
        self.target_parts = target_function_name.split(".")
        self.target_base_name = self.target_parts[-1]

        # Track current context
        self.current_function_stack: list[tuple[str, ast.FunctionDef]] = []
        self.current_class_stack: list[str] = []

        # Maps a name bound by an import statement to the dotted path it refers to
        self.imports: dict[str, str] = {}

        # Results
        self.function_calls: list[FunctionCallLocation] = []
        self.calling_functions: set[str] = set()
        self.function_definitions: dict[str, FunctionDefinitionInfo] = {}

        # Track if we found calls in the current function
        self.found_call_in_current_function = False
        self.functions_with_nested_calls: set[str] = set()

    def visit_Import(self, node: ast.Import) -> None:
        """Track regular imports."""
        for alias in node.names:
            if alias.asname:
                # `import a.b as alias` binds `alias` to the full module path
                self.imports[alias.asname] = alias.name
            else:
                # `import a.b.c` binds only the top-level package `a`; callers then write
                # `a.b.c.func()`, so the key must be `a` for the qualified lookup to resolve.
                top_level = alias.name.split(".")[0]
                self.imports[top_level] = top_level
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Track from imports (relative imports without a module name are ignored)."""
        if node.module:
            for alias in node.names:
                if alias.name == "*":
                    # from module import *
                    self.imports["*"] = node.module
                elif alias.asname:
                    # from module import name as alias
                    self.imports[alias.asname] = f"{node.module}.{alias.name}"
                else:
                    # from module import name
                    self.imports[alias.name] = f"{node.module}.{alias.name}"
        self.generic_visit(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Track when entering a class definition."""
        self.current_class_stack.append(node.name)
        self.generic_visit(node)
        self.current_class_stack.pop()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Track when entering a function definition."""
        self._visit_function_def(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """Track when entering an async function definition."""
        self._visit_function_def(node)

    def _visit_function_def(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> None:
        """Visit a (sync or async) function and record it when its body calls the target."""
        # Build the full qualified name including class if applicable
        if self.current_class_stack:
            full_name = f"{'.'.join(self.current_class_stack)}.{node.name}"
        else:
            full_name = node.name

        self.current_function_stack.append((full_name, node))
        self.found_call_in_current_function = False

        # Visit the function body; visit_Call flips the flag when it sees the target
        self.generic_visit(node)

        if self.found_call_in_current_function:
            self._record_definition(full_name, node, self.current_class_stack)

            # A call inside a nested function also makes the directly enclosing function relevant
            if len(self.current_function_stack) > 1:
                parent_name, parent_node = self.current_function_stack[-2]
                self.functions_with_nested_calls.add(parent_name)
                # The parent only counts as a method when it sits directly below the class
                parent_class_context = self.current_class_stack if len(self.current_function_stack) == 2 else []
                self._record_definition(parent_name, parent_node, parent_class_context)

        self.current_function_stack.pop()

        # Restore the flag for the enclosing function, which continues to be visited
        if self.current_function_stack:
            enclosing_name = self.current_function_stack[-1][0]
            self.found_call_in_current_function = enclosing_name in self.calling_functions

    def _record_definition(
        self, qualified_name: str, node: Union[ast.FunctionDef, ast.AsyncFunctionDef], class_context: list[str]
    ) -> None:
        """Store the definition of a calling function once; later sightings are ignored."""
        if qualified_name in self.function_definitions:
            return
        self.function_definitions[qualified_name] = FunctionDefinitionInfo(
            name=qualified_name,
            node=node,
            source_code=self._extract_source_code(node),
            start_line=node.lineno,
            end_line=getattr(node, "end_lineno", None) or node.lineno,
            is_method=bool(class_context),
            class_name=class_context[-1] if class_context else None,
        )

    def visit_Call(self, node: ast.Call) -> None:
        """Check if this call matches our target function."""
        # Calls at module or class level have no calling function to report
        if self.current_function_stack and self._is_target_function_call(node):
            current_func_name = self.current_function_stack[-1][0]
            self.function_calls.append(
                FunctionCallLocation(calling_function=current_func_name, line=node.lineno, column=node.col_offset)
            )
            self.calling_functions.add(current_func_name)
            self.found_call_in_current_function = True

        self.generic_visit(node)

    def _is_target_function_call(self, node: ast.Call) -> bool:
        """Determine if this call node is calling our target function."""
        call_name = self._get_call_name(node.func)
        if not call_name:
            return False

        # Exact dotted match, or a bare call to the base name. A bare call is always accepted
        # (same-file call or `from x import name`), whatever the import table says.
        if call_name in (self.target_function_name, self.target_base_name):
            return True

        # Qualified call: resolve the first component through the imports
        first, _, remainder = call_name.partition(".")
        if first not in self.imports:
            return False
        base_import = self.imports[first]
        full_path = f"{base_import}.{remainder}" if remainder else base_import
        return full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}")

    def _get_call_name(self, func_node) -> Optional[str]:  # noqa : ANN001
        """Return the dotted name being called, or None when it is not a plain name/attribute chain."""
        if isinstance(func_node, ast.Name):
            return func_node.id
        if isinstance(func_node, ast.Attribute):
            parts = []
            current = func_node
            while isinstance(current, ast.Attribute):
                parts.append(current.attr)
                current = current.value
            if isinstance(current, ast.Name):
                parts.append(current.id)
                return ".".join(reversed(parts))
        return None

    def _extract_source_code(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> str:
        """Extract source code for a function node using original source lines.

        The text starts at the ``def`` line (decorators are not included). Functions visited
        while inside a class keep four spaces of indentation; all others are fully dedented.
        """
        if not self.source_lines or not hasattr(node, "lineno"):
            # Fallback to ast.unparse if available (Python 3.9+)
            try:
                return ast.unparse(node)
            except AttributeError:
                return f"# Source code extraction not available for {node.name}"

        start_line = node.lineno - 1  # Convert to 0-based index
        end_line = getattr(node, "end_lineno", None) or len(self.source_lines)
        func_lines = self.source_lines[start_line:end_line]

        # Smallest indentation among non-blank lines; 0 keeps the slice below valid if none exist
        min_indent = min((len(line) - len(line.lstrip()) for line in func_lines if line.strip()), default=0)

        # Methods keep one indentation level (4 spaces); everything else is flushed left
        strip_width = max(0, min_indent - 4) if self.current_class_stack else min_indent

        result_lines = [
            line[strip_width:] if line.strip() and len(line) > strip_width else line for line in func_lines
        ]
        return "".join(result_lines).rstrip()

    def get_results(self) -> dict[str, str]:
        """Get the results of the analysis.

        Returns:
            A dictionary mapping qualified function names to their source code definitions.

        """
        return {info.name: info.source_code for info in self.function_definitions.values()}


def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> dict[str, str]:
    """Find all function definitions that call a specific target function.

    Args:
        source_code: The Python source code to analyze
        target_function_name: The qualified name of the function to find (e.g., "module.function")
        target_filepath: The filepath where the target function is defined

    Returns:
        A dictionary mapping qualified function names to their source code definitions.
        Example: {"function_a": "def function_a(): ...", "MyClass.method_one": "def method_one(self): ..."}

    Raises:
        SyntaxError: If ``source_code`` cannot be parsed by ``ast.parse``.

    """
    tree = ast.parse(source_code)

    # Lines keep their endings so the extracted definitions are byte-for-byte copies
    visitor = FunctionCallFinder(target_function_name, target_filepath, source_code.splitlines(keepends=True))
    visitor.visit(tree)

    return visitor.get_results()
def find_occurances(
    qualified_name: str, file_path: str, fn_matches: list[Path], project_root: Path, tests_root: Path
) -> str:
    """Render the definitions of the functions that call ``qualified_name`` as fenced code blocks.

    Each file in ``fn_matches`` that contains at least one caller contributes one fenced block
    labelled with the file's path relative to ``project_root``. Files located under
    ``tests_root`` are skipped, as are files that cannot be read or parsed.

    Args:
        qualified_name: Qualified name of the target function whose callers are wanted
        file_path: Path of the file that defines the target function
        fn_matches: Files that are known to reference the target function
        project_root: Root used to shorten the file paths shown in the output
        tests_root: Root of the test suite; references below it are ignored

    Returns:
        The concatenated fenced blocks, or an empty string when no caller was found.

    """
    context_len = 0  # characters of caller source collected so far
    chunks: list[str] = []
    for cur_file in fn_matches:
        # The budget is checked once per file, so the last file may overshoot it
        if context_len > MAX_CONTEXT_LEN_REVIEW:
            break
        cur_file_path = Path(cur_file)
        # exclude references in tests
        try:
            cur_file_path.relative_to(tests_root)
        except ValueError:
            pass  # not under tests_root, keep going
        else:
            continue
        try:
            file_content = cur_file_path.read_text(encoding="utf8")
            results = find_function_calls(file_content, target_function_name=qualified_name, target_filepath=file_path)
        except (OSError, ValueError, SyntaxError) as e:
            # One unreadable or unparseable file must not discard the context gathered so far
            logger.debug(f"investigate {e}")
            continue
        if not results:
            continue
        try:
            path_relative_to_project_root = cur_file_path.relative_to(project_root)
        except Exception as e:
            # shouldn't happen but ensuring we don't crash
            logger.debug(f"investigate {e}")
            continue
        chunks.append(f"```python:{path_relative_to_project_root}\n")
        # multiple functions in the file might be calling the desired function
        for fn_definition in results.values():
            chunks.append(f"{fn_definition}\n")
            context_len += len(fn_definition)
        chunks.append("```\n")
    return "".join(chunks)
def find_specific_function_in_file(
    source_code: str, filepath: Union[str, Path], target_function: str, target_class: str | None
) -> Optional[CodePosition]:
    """Find a specific function definition in a Python file and return its location.

    Stops searching once the target is found.

    Args:
        source_code: Source code string
        filepath: Path to the Python file
        target_function: Name of the function to find
        target_class: Name of the class that owns the function, or None. A dotted path such as
            "Outer.Inner" is matched on its last component.

    Returns:
        A CodePosition holding the line and column that jedi reports for the definition,
        or None when no matching definition exists.

    """
    script = jedi.Script(code=source_code, path=filepath)
    # jedi exposes the simple class name only, so a nested path can never equal parent.name
    class_name = target_class.rsplit(".", maxsplit=1)[-1] if target_class else None
    for name in script.get_names(all_scopes=True, definitions=True):
        if name.type != "function" or name.name != target_function:
            continue
        if class_name is None:
            # No class requested: the first function with this name wins.
            # NOTE(review): all_scopes=True means this may be a method or nested function - confirm intent
            return CodePosition(line_no=name.line, col_no=name.column)
        parent = name.parent()
        if parent and parent.name == class_name and parent.type == "class":
            return CodePosition(line_no=name.line, col_no=name.column)

    return None  # Function not found


def get_fn_references_jedi(
    source_code: str, file_path: Path, project_root: Path, target_function: str, target_class: str | None
) -> list[str]:
    """Return the sorted, de-duplicated paths of the files that reference the target function.

    The reference that is the definition itself (same file, same line) is not counted, so the
    defining file is only listed when it contains another reference. The lookup is best
    effort: any failure is logged and yields an empty list.

    Args:
        source_code: Source of the file that defines the function
        file_path: Path of that file
        project_root: Root of the project jedi should search
        target_function: Name of the function
        target_class: Name of the owning class, or None

    Returns:
        File paths as strings; empty when the function is not found or jedi fails.

    """
    try:
        function_position = find_specific_function_in_file(source_code, file_path, target_function, target_class)
        if function_position is None:
            logger.debug(f"Could not locate {target_function} in {file_path}, skipping reference lookup")
            return []
        script = jedi.Script(code=source_code, path=file_path, project=jedi.Project(path=project_root))
        references = script.get_references(line=function_position.line_no, column=function_position.col_no)
        # Compare as strings: a str never equals a Path, which made the old check a no-op
        definition_file = str(file_path)
        reference_files = set()
        for ref in references:
            if not ref.module_path:
                continue
            ref_path = str(ref.module_path)
            is_definition = ref_path == definition_file and ref.line == function_position.line_no
            if not is_definition:
                reference_files.add(ref_path)
        return sorted(reference_files)
    except Exception as e:
        logger.debug(f"Error during Jedi analysis: {e}")
        return []


def get_opt_review_metrics(
    source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
) -> str:
    """Collect the source of the functions that call the function being optimized.

    Args:
        source_code: Source of the file that defines the function
        file_path: Path of that file
        qualified_name: "function" or "Class.function"; split on the last dot
        project_root: Root of the project to search for references
        tests_root: Root of the test suite, whose references are ignored

    Returns:
        Fenced code blocks with the calling functions, or an empty string when nothing was
        found or the analysis failed (failures are logged, never raised).

    """
    calling_fns_details = ""
    try:
        class_part, _, function_part = qualified_name.rpartition(".")
        target_function, target_class = function_part, class_part or None
        matches = get_fn_references_jedi(
            source_code, file_path, project_root, target_function, target_class
        )  # jedi is not perfect, it doesn't capture aliased references
        calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root)
    except Exception as e:
        # Deliberately broad: review context is optional and must never break the optimization run
        calling_fns_details = ""
        logger.debug(f"Investigate {e}")
    return calling_fns_details
get_opt_review_metrics( + self.function_to_optimize_source_code, + self.function_to_optimize.file_path, + self.function_to_optimize.qualified_name, + self.project_root, + self.test_cfg.tests_root, + ) + opt_review_response = "" try: - opt_impact_response = self.aiservice_client.get_optimization_impact(**data) + opt_review_response = self.aiservice_client.get_optimization_review( + **data, calling_fn_details=calling_fn_details + ) except Exception as e: - logger.debug(f"optimization impact response failed, investigate {e}") - data["optimization_impact"] = opt_impact_response + logger.debug(f"optimization review response failed, investigate {e}") + data["optimization_review"] = opt_review_response if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index f9fbf84d7..55f3713fd 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -185,7 +185,7 @@ def check_create_pr( concolic_tests: str, root_dir: Path, git_remote: Optional[str] = None, - optimization_impact: str = "", + optimization_review: str = "", ) -> None: pr_number: Optional[int] = env_utils.get_pr_number() git_repo = git.Repo(search_parent_directories=True) @@ -227,7 +227,7 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, - optimization_impact=optimization_impact, + optimization_review=optimization_review, ) if response.ok: logger.info(f"Suggestions were successfully made to PR #{pr_number}") @@ -277,6 +277,7 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, + optimization_review=optimization_review, ) if response.ok: pr_id = response.text