From 39c3c6491cead400f80d1cb4233acbd83fdc7759 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Sun, 12 Oct 2025 21:33:54 -0400 Subject: [PATCH 01/23] python side done, todo backend js/ts --- codeflash/api/cfapi.py | 6 ++++++ codeflash/optimization/function_optimizer.py | 14 ++++++-------- codeflash/result/create_pr.py | 3 +++ 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index bd3e927d4..da07c66f2 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -130,6 +130,7 @@ def suggest_changes( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", + optimization_impact: str = "", ) -> Response: """Suggest changes to a pull request. @@ -155,6 +156,7 @@ def suggest_changes( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, + "optimization_impact": optimization_impact, } return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload) @@ -171,6 +173,7 @@ def create_pr( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", + optimization_impact: str = "", ) -> Response: """Create a pull request, targeting the specified branch. (usually 'main'). @@ -195,6 +198,7 @@ def create_pr( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, + "optimization_impact": optimization_impact, } return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload) @@ -210,6 +214,7 @@ def create_staging( replay_tests: str, concolic_tests: str, root_dir: Path, + optimization_impact: str = "", ) -> Response: """Create a staging pull request, targeting the specified branch. (usually 'staging'). @@ -250,6 +255,7 @@ def create_staging( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, + "optimization_impact": optimization_impact, } return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 8aaf11ec5..e54aac92d 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1461,14 +1461,12 @@ def process_review( if raise_pr or staging_review: data["root_dir"] = git_root_dir() - # try: - # # modify argument of staging vs pr based on the impact - # opt_impact_response = self.aiservice_client.get_optimization_impact(**data) - # if opt_impact_response == "low": - # raise_pr = False - # staging_review = True - # except Exception as e: - # logger.debug(f"optimization impact response failed, investigate {e}") + opt_impact_response = "" + try: + opt_impact_response = self.aiservice_client.get_optimization_impact(**data) + except Exception as e: + logger.debug(f"optimization impact response failed, investigate {e}") + data["optimization_impact"] = opt_impact_response if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 7731c67f2..3f1ffa200 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -185,6 +185,7 @@ def check_create_pr( concolic_tests: str, root_dir: Path, git_remote: Optional[str] = None, + optimization_impact: str = "", ) -> None: pr_number: Optional[int] = env_utils.get_pr_number() git_repo = git.Repo(search_parent_directories=True) @@ -226,6 +227,7 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, + optimization_impact=optimization_impact, ) if response.ok: logger.info(f"Suggestions were successfully made to PR #{pr_number}") @@ -275,6 +277,7 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, + optimization_impact=optimization_impact, ) if response.ok: pr_id = response.text From 567a4410ffbd644e32ca4afcd0ddcdb43c0e5b0a Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 13 Oct 2025 14:44:44 -0700 Subject: [PATCH 02/23] Update cfapi.py --- codeflash/api/cfapi.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index da07c66f2..a0a5685b3 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -156,7 +156,7 @@ def suggest_changes( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, - "optimization_impact": optimization_impact, + "optimizationImpact": optimization_impact, } return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload) @@ -173,7 +173,6 @@ def create_pr( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", - optimization_impact: str = "", ) -> Response: """Create a pull request, targeting the specified branch. (usually 'main'). @@ -198,7 +197,6 @@ def create_pr( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, - "optimization_impact": optimization_impact, } return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload) @@ -214,7 +212,6 @@ def create_staging( replay_tests: str, concolic_tests: str, root_dir: Path, - optimization_impact: str = "", ) -> Response: """Create a staging pull request, targeting the specified branch. (usually 'staging'). @@ -255,7 +252,6 @@ def create_staging( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, - "optimization_impact": optimization_impact, } return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload) From 2cedeafa3b3ac8866e66004a7440fc2a7f065794 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 13 Oct 2025 14:46:38 -0700 Subject: [PATCH 03/23] Apply suggestion from @aseembits93 --- codeflash/result/create_pr.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 3f1ffa200..f9fbf84d7 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -277,7 +277,6 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, - optimization_impact=optimization_impact, ) if response.ok: pr_id = response.text From 7c6a41e2b03bc91fecefc344834d67664943ca8d Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 14 Oct 2025 19:10:29 -0700 Subject: [PATCH 04/23] todos for next iteration --- codeflash/api/aiservice.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index e15333d75..6b478f148 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -577,6 +577,13 @@ def get_optimization_impact( ] ) code_diff = f"```diff\n{diff_str}\n```" + # TODO get complexity metrics and fn call heuristics -> constructing a complete static call graph can be expensive for really large repos + # grep function name in codebase -> ast parser to get no of calls and no of calls in loop -> radon lib to get complexity metrics -> send as additional context to the AI service + # metric 1 -> call count - how many times the function is called in the codebase + # metric 2 -> loop call count - how many times the function is called in a loop in the codebase + # metric 3 -> presence of decorators like @profile, @cache -> this means the owner of the repo cares about the performance of this function + # metric 4 -> cyclomatic complexity (https://en.wikipedia.org/wiki/Cyclomatic_complexity) + # metric 5 (for future) -> halstead complexity (https://en.wikipedia.org/wiki/Halstead_complexity_measures) logger.info("!lsp|Computing Optimization Impact…") payload = { "code_diff": code_diff, From 83fabf5e26cc7e5c7b7579f6a53fd58c488c2662 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Wed, 15 Oct 2025 15:57:32 -0700 Subject: [PATCH 05/23] wip --- codeflash/code_utils/code_extractor.py | 89 +++++- codeflash/code_utils/compat.py | 3 + codeflash/models/models.py | 8 + codeflash/optimization/function_optimizer.py | 4 + example_usage.py | 33 ++ function_call_visitor.py | 316 +++++++++++++++++++ test_function_call_visitor.py | 266 ++++++++++++++++ 7 files changed, 716 insertions(+), 3 deletions(-) create mode 100644 example_usage.py create mode 100644 function_call_visitor.py create mode 100644 test_function_call_visitor.py diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 4d6235b0a..3fddbaeaa 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -2,7 +2,10 @@ from __future__ import annotations import ast +import json +import subprocess from itertools import chain +from pathlib import Path from typing import TYPE_CHECKING, Optional import libcst as cst @@ -14,12 +17,10 @@ from codeflash.models.models import FunctionParent if TYPE_CHECKING: - from pathlib import Path - from libcst.helpers import ModuleNameAndPackage from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.models.models import FunctionSource + from codeflash.models.models import FunctionSource, ImpactMetrics class GlobalAssignmentCollector(cst.CSTVisitor): @@ -748,3 +749,85 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP if isinstance(cnode, (ast.FunctionDef, ast.AsyncFunctionDef)): preexisting_objects.add((cnode.name, (FunctionParent(node.name, "ClassDef"),))) return preexisting_objects + + +def search_with_ripgrep(pattern: str, path: str = ".") -> dict[str, list[tuple[int, str]]]: + """Use ripgrep to search for a pattern in the repository. + + Args: + pattern: The pattern to search for + path: The directory to search in (default: current directory) + + Returns: + Dictionary with filepaths as keys and list of (line_no, content) tuples as values + + """ + # Run ripgrep with JSON output for easier parsing + # -n: Show line numbers + # --json: Output in JSON format + # --no-heading: Don't group matches by file + path = str(Path.cwd()) + cmd = [ + "rg", + "-n", + "--json", + pattern, + path, + "-g", + "!/Users/aseemsaxena/Downloads/codeflash_dev/codeflash/code_to_optimize/tests/**", + ] + print(" ".join(cmd)) + # Parse the JSON output + matches_dict = {} + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, # Don't raise exception on non-zero return + ) + + if result.returncode not in [0, 1]: # 0 = matches found, 1 = no matches + print(f"Error running ripgrep: {result.stderr}") + return {} + + for line in result.stdout.strip().split("\n"): + if not line: + continue + + try: + json_obj = json.loads(line) + + # We're only interested in match objects + if json_obj.get("type") == "match": + data = json_obj.get("data", {}) + file_path = data.get("path", {}).get("text", "") + line_number = data.get("line_number") + line_content = data.get("lines", {}).get("text", "").rstrip("\n") + + if file_path and line_number: + if file_path not in matches_dict: + matches_dict[file_path] = [] + matches_dict[file_path].append((line_number, line_content)) + + except json.JSONDecodeError: + continue + + except FileNotFoundError: + print("Error: ripgrep (rg) is not installed or not in PATH") + return {} + except Exception as e: + print(f"Unexpected error: {e}") + return {} + return matches_dict + + +def get_opt_impact_metrics(file_path: Path, qualified_name: str, project_root: Path, tests_root: Path) -> ImpactMetrics: + # grep for function / use rg (respects gitignore) + # SAFE_GREP_EXECUTABLE command + # ast visitor for occurances and loop occurances + # radon lib for complexity metrics + print(file_path, qualified_name, project_root, tests_root) + + # grep windows alternative + return 0 diff --git a/codeflash/code_utils/compat.py b/codeflash/code_utils/compat.py index eb4e5b561..66b203429 100644 --- a/codeflash/code_utils/compat.py +++ b/codeflash/code_utils/compat.py @@ -1,4 +1,5 @@ import os +import shutil import sys import tempfile from pathlib import Path @@ -17,6 +18,7 @@ class Compat: LF: str = os.linesep SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() + SAFE_GREP_EXECUTABLE: str = shutil.which("grep") # works even grep is aliased in the env IS_POSIX: bool = os.name != "nt" @@ -45,3 +47,4 @@ def codeflash_cache_db(self) -> Path: LF = _compat.LF SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE IS_POSIX = _compat.IS_POSIX +SAFE_GREP_EXECUTABLE = _compat.SAFE_GREP_EXECUTABLE diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 84179054e..d00c1246d 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -31,6 +31,14 @@ from codeflash.verification.comparator import comparator +@dataclass +class ImpactMetrics: + complexity_score: int + occurances: int + loop_occurances: int + presence_of_decorators: bool + + @dataclass(frozen=True) class AIServiceRefinerRequest: optimization_id: str diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e54aac92d..b6673dbdb 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -24,6 +24,7 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils +from codeflash.code_utils.code_extractor import get_opt_impact_metrics from codeflash.code_utils.code_replacer import ( add_custom_marker_to_all_tests, modify_autouse_fixture, @@ -1467,6 +1468,9 @@ def process_review( except Exception as e: logger.debug(f"optimization impact response failed, investigate {e}") data["optimization_impact"] = opt_impact_response + data["impact_metrics"] = get_opt_impact_metrics( + self.project_root, self.test_cfg.tests_root + ) # need module root, tests root only if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) diff --git a/example_usage.py b/example_usage.py new file mode 100644 index 000000000..e52600e93 --- /dev/null +++ b/example_usage.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +""" +Example of using the ripgrep search script programmatically. +""" + +from ripgrep_search import search_with_ripgrep +import json + +# Search for any pattern you want +pattern = "sorter" # Change this to any pattern you need +results = search_with_ripgrep(pattern) + +# Access the results as a dictionary +print(f"Found matches in {len(results)} files") + +# Iterate through the results +for filepath, occurrences in results.items(): + print(f"\n{filepath}: {len(occurrences)} matches") + for line_no, content in occurrences[:3]: # Show first 3 matches per file + print(f" Line {line_no}: {content[:80]}...") + +# Save results to a JSON file if needed +with open("search_results.json", "w") as f: + json.dump(results, f, indent=2) + +# Or filter results for specific files +python_files_only = { + path: matches + for path, matches in results.items() + if path.endswith('.py') +} + +print(f"\nPython files with matches: {len(python_files_only)}") \ No newline at end of file diff --git a/function_call_visitor.py b/function_call_visitor.py new file mode 100644 index 000000000..c20da3800 --- /dev/null +++ b/function_call_visitor.py @@ -0,0 +1,316 @@ +"""AST Visitor to count function calls and identify calls within loops. + +This module provides a visitor that can track calls to specific functions, +including regular functions, methods, classmethods, and staticmethods. +""" +from __future__ import annotations + +import ast +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + + +@dataclass +class CallInfo: + """Information about a function call.""" + + line: int + col: int + call_text: str + in_loop: bool + loop_type: Optional[str] = None # 'for', 'while', or nested combinations + file_path: Optional[str] = None + + def __repr__(self): + loop_info = f" (in {self.loop_type} loop)" if self.in_loop else "" + file_info = f"{self.file_path}:" if self.file_path else "" + return f"{file_info}{self.line}:{self.col} - {self.call_text}{loop_info}" + + +class FunctionCallVisitor(ast.NodeVisitor): + """AST visitor to count and track function calls. + + Handles: + - Regular function calls: func() + - Method calls: obj.method() + - Class method calls: Class.method() + - Static method calls: Class.static_method() + - Nested attribute calls: module.submodule.func() + """ + + def __init__(self, target_functions: list[str], file_path: Optional[str] = None): + """Initialize the visitor. + + Args: + target_functions: list of function names to track. Can be: + - Simple names: ['print', 'len'] + - Qualified names: ['os.path.join', 'numpy.array'] + - Method names: ['append', 'extend'] (will match any obj.append()) + file_path: Optional path to the file being analyzed + + """ + self.target_functions = set(target_functions) + self.file_path = file_path + self.calls: list[CallInfo] = [] + self.loop_stack: list[str] = [] # Track nested loops + self._source_lines: Optional[list[str]] = None + + def set_source(self, source: str): + """Set the source code for better call text extraction.""" + self._source_lines = source.splitlines() + + def _get_call_name(self, node: ast.Call) -> Optional[str]: + """Extract the full name of the called function.""" + if isinstance(node.func, ast.Name): + # Simple function call: func() + return node.func.id + if isinstance(node.func, ast.Attribute): + # Method or qualified call: obj.method() or module.func() + parts = [] + current = node.func + + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + + if isinstance(current, ast.Name): + parts.append(current.id) + full_name = ".".join(reversed(parts)) + + # Check if we should track this call + # Match exact qualified names or just the method name + if full_name in self.target_functions: + return full_name + + # Also check if just the method name matches + # (for tracking all calls to a method regardless of object) + method_name = parts[0] # The rightmost part is the method + if method_name in self.target_functions: + return full_name + + # Check partial matches (e.g., 'path.join' matches 'os.path.join') + for target in self.target_functions: + if full_name.endswith(target) or target.endswith(full_name): + return full_name + + return None + + def _get_call_text(self, node: ast.Call) -> str: + """Get a string representation of the call.""" + if self._source_lines and hasattr(node, "lineno") and hasattr(node, "end_lineno"): + try: + if node.lineno == node.end_lineno: + line = self._source_lines[node.lineno - 1] + if hasattr(node, "col_offset") and hasattr(node, "end_col_offset"): + return line[node.col_offset:node.end_col_offset] + else: + # Multi-line call + lines = [] + for i in range(node.lineno - 1, node.end_lineno): + if i < len(self._source_lines): + if i == node.lineno - 1: + lines.append(self._source_lines[i][node.col_offset:]) + elif i == node.end_lineno - 1: + lines.append(self._source_lines[i][:node.end_col_offset]) + else: + lines.append(self._source_lines[i]) + return " ".join(line.strip() for line in lines) + except (IndexError, AttributeError): + pass + + # Fallback to reconstructing from AST + return ast.unparse(node) if hasattr(ast, "unparse") else self._get_call_name(node) + "(...)" + + def _in_loop(self) -> bool: + """Check if we're currently inside a loop.""" + return len(self.loop_stack) > 0 + + def _get_loop_type(self) -> Optional[str]: + """Get the current loop type(s).""" + if not self.loop_stack: + return None + if len(self.loop_stack) == 1: + return self.loop_stack[0] + return " -> ".join(self.loop_stack) # Show nested loops + + def visit_Call(self, node: ast.Call): + """Visit a function call node.""" + call_name = self._get_call_name(node) + + if call_name: + # Check if this matches any of our target functions + should_track = False + + # Direct match + if call_name in self.target_functions: + should_track = True + else: + # Check if just the method/function name matches + simple_name = call_name.split(".")[-1] + if simple_name in self.target_functions: + should_track = True + else: + # Check for partial qualified matches + for target in self.target_functions: + if "." in target: + # For qualified targets, check if call matches the end + if call_name.endswith("." + target.split(".")[-1]): + should_track = True + break + + if should_track: + call_info = CallInfo( + line=node.lineno, + col=node.col_offset, + call_text=self._get_call_text(node), + in_loop=self._in_loop(), + loop_type=self._get_loop_type(), + file_path=self.file_path + ) + self.calls.append(call_info) + + self.generic_visit(node) + + def visit_For(self, node: ast.For): + """Visit a for loop.""" + self.loop_stack.append("for") + self.generic_visit(node) + self.loop_stack.pop() + + def visit_While(self, node: ast.While): + """Visit a while loop.""" + self.loop_stack.append("while") + self.generic_visit(node) + self.loop_stack.pop() + + def visit_AsyncFor(self, node: ast.AsyncFor): + """Visit an async for loop.""" + self.loop_stack.append("async for") + self.generic_visit(node) + self.loop_stack.pop() + + def get_summary(self) -> dict: + """Get a summary of the calls found.""" + total_calls = len(self.calls) + calls_in_loops = [c for c in self.calls if c.in_loop] + calls_outside_loops = [c for c in self.calls if not c.in_loop] + + return { + "total_calls": total_calls, + "calls_in_loops": len(calls_in_loops), + "calls_outside_loops": len(calls_outside_loops), + "all_calls": self.calls, + "loop_calls": calls_in_loops, + "non_loop_calls": calls_outside_loops + } + + +def analyze_file(file_path: str, target_functions: list[str]) -> dict: + """Analyze a Python file for function calls. + + Args: + file_path: Path to the Python file + target_functions: list of function names to track + + Returns: + dictionary with call statistics and details + + """ + with Path.open(file_path) as f: + source = f.read() + + tree = ast.parse(source, filename=file_path) + visitor = FunctionCallVisitor(target_functions, file_path) + visitor.set_source(source) + visitor.visit(tree) + + return visitor.get_summary() + + +def analyze_code(source: str, target_functions: list[str], file_path: Optional[str] = None) -> dict: + """Analyze Python source code for function calls. + + Args: + source: Python source code as string + target_functions: list of function names to track + file_path: Optional file path for reference + + Returns: + dictionary with call statistics and details + + """ + tree = ast.parse(source) + visitor = FunctionCallVisitor(target_functions, file_path) + visitor.set_source(source) + visitor.visit(tree) + + return visitor.get_summary() + + +if __name__ == "__main__": + # Example usage + example_code = """ +import os +import numpy as np + +def process_data(data): + print("Starting processing") + result = [] + + for item in data: + print(f"Processing {item}") + value = len(item) + result.append(value) + + for i in range(3): + print(f"Inner loop {i}") + np.array([1, 2, 3]) + + while len(result) < 10: + print("Adding more items") + result.append(0) + + os.path.join("dir", "file") + print("Done") + return result + +class DataProcessor: + def process(self, items): + for item in items: + self.validate(item) + print(f"Item: {item}") + + def validate(self, item): + if len(item) > 0: + print("Valid") + + @classmethod + def create(cls): + print("Creating processor") + return cls() + + @staticmethod + def utility(): + print("Utility function") +""" + + # Track multiple functions + targets = ["print", "len", "np.array", "os.path.join", "append", "validate"] + results = analyze_code(example_code, targets, "example.py") + + print("Function Call Analysis Results") + print("=" * 50) + print(f"Total calls found: {results['total_calls']}") + print(f"Calls in loops: {results['calls_in_loops']}") + print(f"Calls outside loops: {results['calls_outside_loops']}") + print("\nAll calls:") + print("-" * 50) + for call in results["all_calls"]: + print(f" {call}") + + if results["loop_calls"]: + print("\nCalls within loops:") + print("-" * 50) + for call in results["loop_calls"]: + print(f" {call}") diff --git a/test_function_call_visitor.py b/test_function_call_visitor.py new file mode 100644 index 000000000..b3f6f97e5 --- /dev/null +++ b/test_function_call_visitor.py @@ -0,0 +1,266 @@ +""" +Test and demonstrate the FunctionCallVisitor capabilities. +""" + +import ast +from function_call_visitor import FunctionCallVisitor, analyze_code, analyze_file + + +def test_basic_calls(): + """Test basic function call detection.""" + code = """ +def example(): + print("Hello") + len([1, 2, 3]) + max([4, 5, 6]) + print("World") +""" + results = analyze_code(code, ['print', 'len']) + print("Test: Basic Calls") + print(f" Found {results['total_calls']} calls") + for call in results['all_calls']: + print(f" {call}") + print() + + +def test_loop_detection(): + """Test detection of calls within loops.""" + code = """ +def process(): + print("Start") # Outside loop + + for i in range(10): + print(f"Item {i}") # In for loop + len(str(i)) # In for loop + + x = 0 + while x < 5: + print(f"While {x}") # In while loop + x += len([1, 2]) # In while loop + + print("End") # Outside loop +""" + results = analyze_code(code, ['print', 'len']) + print("Test: Loop Detection") + print(f" Total calls: {results['total_calls']}") + print(f" In loops: {results['calls_in_loops']}") + print(f" Outside loops: {results['calls_outside_loops']}") + print(" Loop calls:") + for call in results['loop_calls']: + print(f" {call}") + print() + + +def test_nested_loops(): + """Test detection in nested loops.""" + code = """ +def nested(): + for i in range(3): + print(f"Outer {i}") + for j in range(2): + print(f"Inner {i},{j}") + while j < 1: + print(f"Innermost") + j += 1 +""" + results = analyze_code(code, ['print']) + print("Test: Nested Loops") + for call in results['all_calls']: + print(f" {call}") + print() + + +def test_method_calls(): + """Test detection of method calls.""" + code = """ +class MyClass: + def __init__(self): + self.data = [] + + def process(self): + for item in [1, 2, 3]: + self.data.append(item) + self.validate(item) + + def validate(self, item): + if len(str(item)) > 0: + self.data.append(item * 2) + + @classmethod + def create(cls): + instance = cls() + instance.data.append(0) + return instance + + @staticmethod + def helper(): + result = [] + result.append(1) + return result + +obj = MyClass() +obj.process() +obj.data.append(99) +MyClass.create() +MyClass.helper() +""" + results = analyze_code(code, ['append', 'validate', 'len']) + print("Test: Method Calls") + print(f" Found {results['total_calls']} calls") + for call in results['all_calls']: + print(f" {call}") + print() + + +def test_module_calls(): + """Test detection of module function calls.""" + code = """ +import os.path +import numpy as np +from math import sqrt + +def example(): + # Module function calls + os.path.join("a", "b") + np.array([1, 2, 3]) + sqrt(16) + + for i in range(3): + os.path.exists(f"file_{i}") + np.zeros((2, 2)) + + # Nested module calls + result = os.path.dirname(os.path.join("x", "y")) +""" + results = analyze_code(code, ['os.path.join', 'np.array', 'sqrt', 'os.path.exists', 'np.zeros', 'os.path.dirname']) + print("Test: Module Calls") + print(f" Total calls: {results['total_calls']}") + print(" All calls:") + for call in results['all_calls']: + print(f" {call}") + print() + + +def test_complex_expressions(): + """Test calls in complex expressions.""" + code = """ +def complex_example(): + # Calls in list comprehensions + result = [len(x) for x in ["a", "bb", "ccc"]] + + # Calls in generator expressions + gen = (print(x) for x in range(3)) + + # Nested calls + value = max(len("hello"), len("world")) + + # Calls in lambda + func = lambda x: len(x) + len(x.strip()) + + # Calls in conditionals + if len("test") > 0: + print("Has length") + + # Calls in dict comprehensions + d = {x: len(x) for x in ["key1", "key2"]} +""" + results = analyze_code(code, ['len', 'print', 'max']) + print("Test: Complex Expressions") + print(f" Found {results['total_calls']} calls") + for call in results['all_calls']: + print(f" {call}") + print() + + +def test_async_code(): + """Test async function calls.""" + code = """ +async def async_example(): + print("Starting async") + + async for item in async_generator(): + print(f"Processing {item}") + await process_item(item) + + print("Done") + +async def async_generator(): + for i in range(3): + yield i + +async def process_item(item): + print(f"Item: {item}") +""" + results = analyze_code(code, ['print', 'process_item']) + print("Test: Async Code") + for call in results['all_calls']: + print(f" {call}") + print() + + +def test_partial_matching(): + """Test partial name matching.""" + code = """ +import os +import os.path +from pathlib import Path + +def file_operations(): + # These should all be caught when looking for 'join' + os.path.join("a", "b") + # path.join("c", "d") # Would need path to be defined + # something.else.join("x") # Would need something to be defined + + # Looking for any 'append' method + list1 = [] + list1.append(1) + list2 = [] + list2.append(2) + # some_obj.data.append(3) # Would need some_obj to be defined +""" + results = analyze_code(code, ['join', 'append']) + print("Test: Partial Matching") + print(f" Tracking 'join' and 'append'") + for call in results['all_calls']: + print(f" {call}") + print() + + +def run_all_tests(): + """Run all test cases.""" + print("=" * 60) + print("FunctionCallVisitor Test Suite") + print("=" * 60) + print() + + test_basic_calls() + test_loop_detection() + test_nested_loops() + test_method_calls() + test_module_calls() + test_complex_expressions() + test_async_code() + test_partial_matching() + + print("=" * 60) + print("All tests completed!") + print("=" * 60) + + +if __name__ == "__main__": + run_all_tests() + + # Example of analyzing an actual file + print("\nExample: Analyzing the visitor file itself") + print("-" * 60) + try: + results = analyze_file("function_call_visitor.py", ['isinstance', 'append', 'len']) + print(f"Found {results['total_calls']} calls in function_call_visitor.py") + print(f" In loops: {results['calls_in_loops']}") + print(f" Outside loops: {results['calls_outside_loops']}") + if results['loop_calls']: + print("\nCalls in loops:") + for call in results['loop_calls'][:5]: # Show first 5 + print(f" {call}") + except FileNotFoundError: + print(" (File not found - run from the same directory)") \ No newline at end of file From 06ffc139c04862823c770b789cc52ee46d2acc4d Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Wed, 15 Oct 2025 15:57:43 -0700 Subject: [PATCH 06/23] wip --- example_usage.py | 15 ++++------- function_call_visitor.py | 11 ++++---- test_function_call_visitor.py | 49 ++++++++++++++++------------------- 3 files changed, 34 insertions(+), 41 deletions(-) diff --git a/example_usage.py b/example_usage.py index e52600e93..94407d83e 100644 --- a/example_usage.py +++ b/example_usage.py @@ -1,11 +1,10 @@ #!/usr/bin/env python3 -""" -Example of using the ripgrep search script programmatically. -""" +"""Example of using the ripgrep search script programmatically.""" -from ripgrep_search import search_with_ripgrep import json +from ripgrep_search import search_with_ripgrep + # Search for any pattern you want pattern = "sorter" # Change this to any pattern you need results = search_with_ripgrep(pattern) @@ -24,10 +23,6 @@ json.dump(results, f, indent=2) # Or filter results for specific files -python_files_only = { - path: matches - for path, matches in results.items() - if path.endswith('.py') -} +python_files_only = {path: matches for path, matches in results.items() if path.endswith(".py")} -print(f"\nPython files with matches: {len(python_files_only)}") \ No newline at end of file +print(f"\nPython files with matches: {len(python_files_only)}") diff --git a/function_call_visitor.py b/function_call_visitor.py index c20da3800..5537e1a3a 100644 --- a/function_call_visitor.py +++ b/function_call_visitor.py @@ -3,6 +3,7 @@ This module provides a visitor that can track calls to specific functions, including regular functions, methods, classmethods, and staticmethods. """ + from __future__ import annotations import ast @@ -103,16 +104,16 @@ def _get_call_text(self, node: ast.Call) -> str: if node.lineno == node.end_lineno: line = self._source_lines[node.lineno - 1] if hasattr(node, "col_offset") and hasattr(node, "end_col_offset"): - return line[node.col_offset:node.end_col_offset] + return line[node.col_offset : node.end_col_offset] else: # Multi-line call lines = [] for i in range(node.lineno - 1, node.end_lineno): if i < len(self._source_lines): if i == node.lineno - 1: - lines.append(self._source_lines[i][node.col_offset:]) + lines.append(self._source_lines[i][node.col_offset :]) elif i == node.end_lineno - 1: - lines.append(self._source_lines[i][:node.end_col_offset]) + lines.append(self._source_lines[i][: node.end_col_offset]) else: lines.append(self._source_lines[i]) return " ".join(line.strip() for line in lines) @@ -166,7 +167,7 @@ def visit_Call(self, node: ast.Call): call_text=self._get_call_text(node), in_loop=self._in_loop(), loop_type=self._get_loop_type(), - file_path=self.file_path + file_path=self.file_path, ) self.calls.append(call_info) @@ -202,7 +203,7 @@ def get_summary(self) -> dict: "calls_outside_loops": len(calls_outside_loops), "all_calls": self.calls, "loop_calls": calls_in_loops, - "non_loop_calls": calls_outside_loops + "non_loop_calls": calls_outside_loops, } diff --git a/test_function_call_visitor.py b/test_function_call_visitor.py index b3f6f97e5..41536590f 100644 --- a/test_function_call_visitor.py +++ b/test_function_call_visitor.py @@ -1,9 +1,6 @@ -""" -Test and demonstrate the FunctionCallVisitor capabilities. -""" +"""Test and demonstrate the FunctionCallVisitor capabilities.""" -import ast -from function_call_visitor import FunctionCallVisitor, analyze_code, analyze_file +from function_call_visitor import analyze_code, analyze_file def test_basic_calls(): @@ -15,10 +12,10 @@ def example(): max([4, 5, 6]) print("World") """ - results = analyze_code(code, ['print', 'len']) + results = analyze_code(code, ["print", "len"]) print("Test: Basic Calls") print(f" Found {results['total_calls']} calls") - for call in results['all_calls']: + for call in results["all_calls"]: print(f" {call}") print() @@ -40,13 +37,13 @@ def process(): print("End") # Outside loop """ - results = analyze_code(code, ['print', 'len']) + results = analyze_code(code, ["print", "len"]) print("Test: Loop Detection") print(f" Total calls: {results['total_calls']}") print(f" In loops: {results['calls_in_loops']}") print(f" Outside loops: {results['calls_outside_loops']}") print(" Loop calls:") - for call in results['loop_calls']: + for call in results["loop_calls"]: print(f" {call}") print() @@ -63,9 +60,9 @@ def nested(): print(f"Innermost") j += 1 """ - results = analyze_code(code, ['print']) + results = analyze_code(code, ["print"]) print("Test: Nested Loops") - for call in results['all_calls']: + for call in results["all_calls"]: print(f" {call}") print() @@ -104,10 +101,10 @@ def helper(): MyClass.create() MyClass.helper() """ - results = analyze_code(code, ['append', 'validate', 'len']) + results = analyze_code(code, ["append", "validate", "len"]) print("Test: Method Calls") print(f" Found {results['total_calls']} calls") - for call in results['all_calls']: + for call in results["all_calls"]: print(f" {call}") print() @@ -132,11 +129,11 @@ def example(): # Nested module calls result = os.path.dirname(os.path.join("x", "y")) """ - results = analyze_code(code, ['os.path.join', 'np.array', 'sqrt', 'os.path.exists', 'np.zeros', 'os.path.dirname']) + results = analyze_code(code, ["os.path.join", "np.array", "sqrt", "os.path.exists", "np.zeros", "os.path.dirname"]) print("Test: Module Calls") print(f" Total calls: {results['total_calls']}") print(" All calls:") - for call in results['all_calls']: + for call in results["all_calls"]: print(f" {call}") print() @@ -164,10 +161,10 @@ def complex_example(): # Calls in dict comprehensions d = {x: len(x) for x in ["key1", "key2"]} """ - results = analyze_code(code, ['len', 'print', 'max']) + results = analyze_code(code, ["len", "print", "max"]) print("Test: Complex Expressions") print(f" Found {results['total_calls']} calls") - for call in results['all_calls']: + for call in results["all_calls"]: print(f" {call}") print() @@ -191,9 +188,9 @@ async def async_generator(): async def process_item(item): print(f"Item: {item}") """ - results = analyze_code(code, ['print', 'process_item']) + results = analyze_code(code, ["print", "process_item"]) print("Test: Async Code") - for call in results['all_calls']: + for call in results["all_calls"]: print(f" {call}") print() @@ -218,10 +215,10 @@ def file_operations(): list2.append(2) # some_obj.data.append(3) # Would need some_obj to be defined """ - results = analyze_code(code, ['join', 'append']) + results = analyze_code(code, ["join", "append"]) print("Test: Partial Matching") - print(f" Tracking 'join' and 'append'") - for call in results['all_calls']: + print(" Tracking 'join' and 'append'") + for call in results["all_calls"]: print(f" {call}") print() @@ -254,13 +251,13 @@ def run_all_tests(): print("\nExample: Analyzing the visitor file itself") print("-" * 60) try: - results = analyze_file("function_call_visitor.py", ['isinstance', 'append', 'len']) + results = analyze_file("function_call_visitor.py", ["isinstance", "append", "len"]) print(f"Found {results['total_calls']} calls in function_call_visitor.py") print(f" In loops: {results['calls_in_loops']}") print(f" Outside loops: {results['calls_outside_loops']}") - if results['loop_calls']: + if results["loop_calls"]: print("\nCalls in loops:") - for call in results['loop_calls'][:5]: # Show first 5 + for call in results["loop_calls"][:5]: # Show first 5 print(f" {call}") except FileNotFoundError: - print(" (File not found - run from the same directory)") \ No newline at end of file + print(" (File not found - run from the same directory)") From 9d69a58084ed3c2459b64eec26e2426cbdfcc3e6 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Wed, 15 Oct 2025 18:28:47 -0700 Subject: [PATCH 07/23] DIRTY --- codeflash/code_utils/code_extractor.py | 660 ++++++++++++++++++- codeflash/optimization/function_optimizer.py | 1 + find_sorter_references.py | 111 ++++ function_call_finder.py | 377 +++++++++++ function_call_finder_ast.py | 365 ++++++++++ ripgrep_search.py | 170 +++++ test_ast_vs_libcst.py | 128 ++++ test_function_call_finder.py | 52 ++ verify_output_format.py | 63 ++ 9 files changed, 1910 insertions(+), 17 deletions(-) create mode 100644 find_sorter_references.py create mode 100644 function_call_finder.py create mode 100644 function_call_finder_ast.py create mode 100644 ripgrep_search.py create mode 100644 test_ast_vs_libcst.py create mode 100644 test_function_call_finder.py create mode 100644 verify_output_format.py diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 3fddbaeaa..dee9118a1 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -4,9 +4,10 @@ import ast import json import subprocess +from dataclasses import dataclass from itertools import chain from pathlib import Path -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import libcst as cst from libcst.codemod import CodemodContext @@ -240,7 +241,7 @@ class DottedImportCollector(cst.CSTVisitor): import dbt.adapters.factory ==> "dbt.adapters.factory" from pathlib import Path ==> "pathlib.Path" from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter" - from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional" + from typing import Any, list, Optional ==> "typing.Any", "typing.list", "typing.Optional" from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps" """ @@ -445,7 +446,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]: and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__all__" ): - if isinstance(node.value, (ast.List, ast.Tuple)): + if isinstance(node.value, (ast.list, ast.tuple)): all_names = [] for elt in node.value.elts: if isinstance(elt, ast.Constant) and isinstance(elt.value, str): @@ -751,31 +752,23 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP return preexisting_objects -def search_with_ripgrep(pattern: str, path: str = ".") -> dict[str, list[tuple[int, str]]]: +def search_with_ripgrep(pattern: str, path: str, exclude_path: str) -> dict[str, list[tuple[int, str]]]: """Use ripgrep to search for a pattern in the repository. Args: pattern: The pattern to search for path: The directory to search in (default: current directory) + exclude_path: directory to avoid looking into Returns: - Dictionary with filepaths as keys and list of (line_no, content) tuples as values + dictionary with filepaths as keys and list of (line_no, content) tuples as values """ # Run ripgrep with JSON output for easier parsing # -n: Show line numbers # --json: Output in JSON format # --no-heading: Don't group matches by file - path = str(Path.cwd()) - cmd = [ - "rg", - "-n", - "--json", - pattern, - path, - "-g", - "!/Users/aseemsaxena/Downloads/codeflash_dev/codeflash/code_to_optimize/tests/**", - ] + cmd = ["rg", "-n", "--type","py", "--json", pattern, path, "-g", f"!{exclude_path}"] print(" ".join(cmd)) # Parse the JSON output matches_dict = {} @@ -822,12 +815,645 @@ def search_with_ripgrep(pattern: str, path: str = ".") -> dict[str, list[tuple[i return matches_dict +# @dataclass +# class FunctionCallLocation: +# """Represents a location where the target function is called.""" +# +# calling_function: str # Name of the function making the call +# line: int +# column: int +# call_node: cst.Call # The actual call node for additional analysis if needed +# +# +# @dataclass +# class FunctionDefinitionInfo: +# """Contains information about a function definition.""" +# +# name: str # Qualified name of the function +# node: cst.FunctionDef # The CST node of the function definition +# source_code: str # The source code of the function +# start_line: int +# end_line: int +# is_method: bool # Whether this is a class method +# class_name: Optional[str] = None # Name of containing class if it's a method +# +# +# class FunctionCallFinder(cst.CSTVisitor): +# """Visitor that finds all function definitions that call a specific qualified function. +# +# Args: +# target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") +# target_filepath: The filepath where the target function is defined +# +# """ +# +# METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) +# +# def __init__(self, target_function_name: str, target_filepath: str) -> None: +# super().__init__() +# self.target_function_name = target_function_name +# self.target_filepath = target_filepath +# +# # Parse the target function name into parts +# self.target_parts = target_function_name.split(".") +# self.target_base_name = self.target_parts[-1] +# +# # Track current context +# self.current_function_stack: list[tuple[str, cst.FunctionDef]] = [] # (name, node) pairs +# self.current_class_stack: list[str] = [] +# +# # Track imports to resolve qualified names +# self.imports: dict = {} # Maps imported names to their full paths +# +# # Results +# self.function_calls: list[FunctionCallLocation] = [] +# self.calling_functions: set[str] = set() # Unique function names that call the target +# self.function_definitions: dict[str, FunctionDefinitionInfo] = {} # Function name -> definition info +# +# # Track if we found calls in the current function +# self.found_call_in_current_function = False +# # Track functions with nested calls (parent functions that contain nested functions with calls) +# self.functions_with_nested_calls: set[str] = set() +# +# def visit_Import(self, node: cst.Import) -> None: +# """Track regular imports.""" +# for name in node.names: +# if isinstance(name, cst.ImportAlias): +# if name.asname: +# # import module as alias +# module_name = name.name.value if isinstance(name.name, cst.Attribute) else str(name.name) +# alias = name.asname.name.value +# self.imports[alias] = module_name +# else: +# # import module +# module_name = self._get_dotted_name(name.name) +# if module_name: +# self.imports[module_name.split(".")[-1]] = module_name +# +# def visit_ImportFrom(self, node: cst.ImportFrom) -> None: +# """Track from imports.""" +# if not node.module: +# return +# +# module_path = self._get_dotted_name(node.module) +# if not module_path: +# return +# +# if isinstance(node.names, cst.ImportStar): +# # from module import * +# self.imports["*"] = module_path +# else: +# # from module import name1, name2 +# for name in node.names: +# if isinstance(name, cst.ImportAlias): +# import_name = name.name.value +# if name.asname: +# # from module import name as alias +# alias = name.asname.name.value +# self.imports[alias] = f"{module_path}.{import_name}" +# else: +# # from module import name +# self.imports[import_name] = f"{module_path}.{import_name}" +# +# def visit_ClassDef(self, node: cst.ClassDef) -> None: +# """Track when entering a class definition.""" +# self.current_class_stack.append(node.name.value) +# +# def leave_ClassDef(self, node: cst.ClassDef) -> None: +# """Track when leaving a class definition.""" +# if self.current_class_stack: +# self.current_class_stack.pop() +# +# def visit_FunctionDef(self, node: cst.FunctionDef) -> None: +# """Track when entering a function definition.""" +# func_name = node.name.value +# +# # Build the full qualified name including class if applicable +# full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name +# +# self.current_function_stack.append((full_name, node)) +# self.found_call_in_current_function = False +# +# def leave_FunctionDef(self, node: cst.FunctionDef) -> None: +# """Track when leaving a function definition and store it if it contains target calls.""" +# if self.current_function_stack: +# full_name, func_node = self.current_function_stack.pop() +# +# # If we found a call in this function, store its definition +# if self.found_call_in_current_function and full_name not in self.function_definitions: +# # Get position information +# position = self.get_metadata(cst.metadata.PositionProvider, func_node) +# +# # Extract function source code by converting node to module +# # For methods, we need to maintain proper indentation +# func_source = cst.Module(body=[func_node]).code +# +# # For methods, add proper indentation (4 spaces) +# if self.current_class_stack: +# lines = func_source.split('\n') +# func_source = '\n'.join(' ' + line if line else line for line in lines) +# +# self.function_definitions[full_name] = FunctionDefinitionInfo( +# name=full_name, +# node=func_node, +# source_code=func_source.rstrip(), # Remove trailing whitespace +# start_line=position.start.line if position else -1, +# end_line=position.end.line if position else -1, +# is_method=bool(self.current_class_stack), +# class_name=self.current_class_stack[-1] if self.current_class_stack else None +# ) +# +# # Handle nested functions - mark parent as containing nested calls +# if self.found_call_in_current_function and self.current_function_stack: +# parent_name = self.current_function_stack[-1][0] +# self.functions_with_nested_calls.add(parent_name) +# # Also store the parent function if not already stored +# if parent_name not in self.function_definitions: +# parent_func_node = self.current_function_stack[-1][1] +# parent_position = self.get_metadata(cst.metadata.PositionProvider, parent_func_node) +# parent_source = cst.Module(body=[parent_func_node]).code +# +# # Get parent class context (go up one level in stack since we're inside the nested function) +# parent_class_stack = self.current_class_stack[:-1] if len(self.current_function_stack) == 1 and self.current_class_stack else [] +# +# if parent_class_stack: +# lines = parent_source.split('\n') +# parent_source = '\n'.join(' ' + line if line else line for line in lines) +# +# self.function_definitions[parent_name] = FunctionDefinitionInfo( +# name=parent_name, +# node=parent_func_node, +# source_code=parent_source.rstrip(), +# start_line=parent_position.start.line if parent_position else -1, +# end_line=parent_position.end.line if parent_position else -1, +# is_method=bool(parent_class_stack), +# class_name=parent_class_stack[-1] if parent_class_stack else None +# ) +# +# # Reset the flag for parent function if we're in nested functions +# if self.current_function_stack: +# # Check if the parent function should also be marked as containing calls +# parent_name = self.current_function_stack[-1][0] +# self.found_call_in_current_function = parent_name in self.calling_functions +# +# def visit_Call(self, node: cst.Call) -> None: +# """Check if this call matches our target function.""" +# if not self.current_function_stack: +# # Not inside a function, skip +# return +# +# if self._is_target_function_call(node): +# # Get position information +# position = self.get_metadata(cst.metadata.PositionProvider, node) +# +# current_func_name = self.current_function_stack[-1][0] +# +# call_location = FunctionCallLocation( +# calling_function=current_func_name, +# line=position.start.line if position else -1, +# column=position.start.column if position else -1, +# call_node=node, +# ) +# +# self.function_calls.append(call_location) +# self.calling_functions.add(current_func_name) +# self.found_call_in_current_function = True +# +# def _is_target_function_call(self, node: cst.Call) -> bool: +# """Determine if this call node is calling our target function. +# +# Handles various call patterns: +# - Direct calls: function() +# - Qualified calls: module.function() +# - Method calls: obj.method() +# """ +# func = node.func +# +# # Get the call name +# call_name = self._get_call_name(func) +# if not call_name: +# return False +# +# # Check if it matches directly +# if call_name == self.target_function_name: +# return True +# +# # Check if it's just the base name matching +# if call_name == self.target_base_name: +# # Could be imported with a different name, check imports +# if call_name in self.imports: +# imported_path = self.imports[call_name] +# # Check if the imported path matches our target +# if imported_path == self.target_function_name or imported_path.endswith( +# f".{self.target_function_name}" +# ): +# return True +# # Could also be a direct call if we're in the same file +# return True +# +# # Check for qualified calls with imports +# call_parts = call_name.split(".") +# if call_parts[0] in self.imports: +# # Resolve the full path using imports +# base_import = self.imports[call_parts[0]] +# full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import +# +# if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): +# return True +# +# return False +# +# def _get_call_name(self, func: Union[cst.Name, cst.Attribute, cst.Call]) -> Optional[str]: +# """Extract the name being called from a function node.""" +# if isinstance(func, cst.Name): +# return func.value +# if isinstance(func, cst.Attribute): +# return self._get_dotted_name(func) +# if isinstance(func, cst.Call): +# # Chained calls like foo()() +# return None +# return None +# +# def _get_dotted_name(self, node: Union[cst.Name, cst.Attribute]) -> Optional[str]: +# """Get the full dotted name from an Attribute or Name node.""" +# if isinstance(node, cst.Name): +# return node.value +# if isinstance(node, cst.Attribute): +# parts = [] +# current = node +# while isinstance(current, cst.Attribute): +# parts.append(current.attr.value) +# current = current.value +# if isinstance(current, cst.Name): +# parts.append(current.value) +# return ".".join(reversed(parts)) +# return None +# +# def get_results(self) -> dict[str, str]: +# """Get the results of the analysis. +# +# Returns: +# A dictionary mapping qualified function names to their source code definitions. +# Only includes functions that call the target function (directly or through nested functions). +# +# """ +# return { +# info.name: info.source_code +# for info in self.function_definitions.values() +# } +# +# +# def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> dict: +# """Find all function definitions that call a specific target function. +# +# Args: +# source_code: The Python source code to analyze +# target_function_name: The qualified name of the function to find (e.g., "module.function") +# target_filepath: The filepath where the target function is defined +# +# Returns: +# A dictionary with: +# - calling_functions: list of function names that call the target +# - calls: list of detailed call information including line/column +# +# """ +# # Parse the source code +# module = cst.parse_module(source_code) +# +# # Create and run the visitor +# visitor = FunctionCallFinder(target_function_name, target_filepath) +# wrapper = cst.metadata.MetadataWrapper(module) +# wrapper.visit(visitor) +# +# return visitor.get_results() + + +@dataclass +class FunctionCallLocation: + """Represents a location where the target function is called.""" + calling_function: str + line: int + column: int + + +@dataclass +class FunctionDefinitionInfo: + """Contains information about a function definition.""" + name: str + node: ast.FunctionDef + source_code: str + start_line: int + end_line: int + is_method: bool + class_name: Optional[str] = None + + +class FunctionCallFinder(ast.NodeVisitor): + """AST visitor that finds all function definitions that call a specific qualified function. + + Args: + target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") + target_filepath: The filepath where the target function is defined + """ + + def __init__(self, target_function_name: str, target_filepath: str, source_lines: list[str]): + self.target_function_name = target_function_name + self.target_filepath = target_filepath + self.source_lines = source_lines # Store original source lines for extraction + + # Parse the target function name into parts + self.target_parts = target_function_name.split('.') + self.target_base_name = self.target_parts[-1] + + # Track current context + self.current_function_stack: list[tuple[str, ast.FunctionDef]] = [] + self.current_class_stack: list[str] = [] + + # Track imports to resolve qualified names + self.imports: dict[str, str] = {} # Maps imported names to their full paths + + # Results + self.function_calls: list[FunctionCallLocation] = [] + self.calling_functions: set[str] = set() + self.function_definitions: dict[str, FunctionDefinitionInfo] = {} + + # Track if we found calls in the current function + self.found_call_in_current_function = False + self.functions_with_nested_calls: set[str] = set() + + def visit_Import(self, node: ast.Import) -> None: + """Track regular imports.""" + for alias in node.names: + if alias.asname: + # import module as alias + self.imports[alias.asname] = alias.name + else: + # import module + self.imports[alias.name.split('.')[-1]] = alias.name + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + """Track from imports.""" + if node.module: + for alias in node.names: + if alias.name == '*': + # from module import * + self.imports['*'] = node.module + elif alias.asname: + # from module import name as alias + self.imports[alias.asname] = f"{node.module}.{alias.name}" + else: + # from module import name + self.imports[alias.name] = f"{node.module}.{alias.name}" + self.generic_visit(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + """Track when entering a class definition.""" + self.current_class_stack.append(node.name) + self.generic_visit(node) + self.current_class_stack.pop() + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + """Track when entering a function definition.""" + self._visit_function_def(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + """Track when entering an async function definition.""" + self._visit_function_def(node) + + def _visit_function_def(self, node: ast.FunctionDef) -> None: + """Common logic for both regular and async function definitions.""" + func_name = node.name + + # Build the full qualified name including class if applicable + full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name + + self.current_function_stack.append((full_name, node)) + self.found_call_in_current_function = False + + # Visit the function body + self.generic_visit(node) + + # Process the function after visiting its body + if self.found_call_in_current_function and full_name not in self.function_definitions: + # Extract function source code + source_code = self._extract_source_code(node) + + self.function_definitions[full_name] = FunctionDefinitionInfo( + name=full_name, + node=node, + source_code=source_code, + start_line=node.lineno, + end_line=node.end_lineno if hasattr(node, 'end_lineno') else node.lineno, + is_method=bool(self.current_class_stack), + class_name=self.current_class_stack[-1] if self.current_class_stack else None + ) + + # Handle nested functions - mark parent as containing nested calls + if self.found_call_in_current_function and len(self.current_function_stack) > 1: + parent_name = self.current_function_stack[-2][0] + self.functions_with_nested_calls.add(parent_name) + + # Also store the parent function if not already stored + if parent_name not in self.function_definitions: + parent_node = self.current_function_stack[-2][1] + parent_source = self._extract_source_code(parent_node) + + # Check if parent is a method (excluding current level) + parent_class_context = self.current_class_stack if len(self.current_function_stack) == 2 else [] + + self.function_definitions[parent_name] = FunctionDefinitionInfo( + name=parent_name, + node=parent_node, + source_code=parent_source, + start_line=parent_node.lineno, + end_line=parent_node.end_lineno if hasattr(parent_node, 'end_lineno') else parent_node.lineno, + is_method=bool(parent_class_context), + class_name=parent_class_context[-1] if parent_class_context else None + ) + + self.current_function_stack.pop() + + # Reset flag for parent function + if self.current_function_stack: + parent_name = self.current_function_stack[-1][0] + self.found_call_in_current_function = parent_name in self.calling_functions + + def visit_Call(self, node: ast.Call) -> None: + """Check if this call matches our target function.""" + if not self.current_function_stack: + # Not inside a function, skip + self.generic_visit(node) + return + + if self._is_target_function_call(node): + current_func_name = self.current_function_stack[-1][0] + + call_location = FunctionCallLocation( + calling_function=current_func_name, + line=node.lineno, + column=node.col_offset + ) + + self.function_calls.append(call_location) + self.calling_functions.add(current_func_name) + self.found_call_in_current_function = True + + self.generic_visit(node) + + def _is_target_function_call(self, node: ast.Call) -> bool: + """Determine if this call node is calling our target function.""" + call_name = self._get_call_name(node.func) + if not call_name: + return False + + # Check if it matches directly + if call_name == self.target_function_name: + return True + + # Check if it's just the base name matching + if call_name == self.target_base_name: + # Could be imported with a different name, check imports + if call_name in self.imports: + imported_path = self.imports[call_name] + if imported_path == self.target_function_name or imported_path.endswith(f".{self.target_function_name}"): + return True + # Could also be a direct call if we're in the same file + return True + + # Check for qualified calls with imports + call_parts = call_name.split('.') + if call_parts[0] in self.imports: + # Resolve the full path using imports + base_import = self.imports[call_parts[0]] + full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import + + if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): + return True + + return False + + def _get_call_name(self, func_node) -> Optional[str]: + """Extract the name being called from a function node.""" + if isinstance(func_node, ast.Name): + return func_node.id + elif isinstance(func_node, ast.Attribute): + parts = [] + current = func_node + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + return '.'.join(reversed(parts)) + return None + + def _extract_source_code(self, node: ast.FunctionDef) -> str: + """Extract source code for a function node using original source lines.""" + if not self.source_lines or not hasattr(node, 'lineno'): + # Fallback to ast.unparse if available (Python 3.9+) + try: + return ast.unparse(node) + except AttributeError: + return f"# Source code extraction not available for {node.name}" + + # Get the lines for this function + start_line = node.lineno - 1 # Convert to 0-based index + end_line = node.end_lineno if hasattr(node, 'end_lineno') else len(self.source_lines) + + # Extract the function lines + func_lines = self.source_lines[start_line:end_line] + + # Find the minimum indentation (excluding empty lines) + min_indent = float('inf') + for line in func_lines: + if line.strip(): # Skip empty lines + indent = len(line) - len(line.lstrip()) + min_indent = min(min_indent, indent) + + # If this is a method (inside a class), preserve one level of indentation + if self.current_class_stack: + # Keep 4 spaces of indentation for methods + dedent_amount = max(0, min_indent - 4) + result_lines = [] + for line in func_lines: + if line.strip(): # Only dedent non-empty lines + result_lines.append(line[dedent_amount:] if len(line) > dedent_amount else line) + else: + result_lines.append(line) + else: + # For top-level functions, remove all leading indentation + result_lines = [] + for line in func_lines: + if line.strip(): # Only dedent non-empty lines + result_lines.append(line[min_indent:] if len(line) > min_indent else line) + else: + result_lines.append(line) + + return ''.join(result_lines).rstrip() + + def get_results(self) -> dict[str, str]: + """Get the results of the analysis. + + Returns: + A dictionary mapping qualified function names to their source code definitions. + """ + return { + info.name: info.source_code + for info in self.function_definitions.values() + } + + +def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> dict[str, str]: + """Find all function definitions that call a specific target function. + + Args: + source_code: The Python source code to analyze + target_function_name: The qualified name of the function to find (e.g., "module.function") + target_filepath: The filepath where the target function is defined + + Returns: + A dictionary mapping qualified function names to their source code definitions. + Example: {"function_a": "def function_a():\n ...", "MyClass.method_one": "def method_one(self):\n ..."} + """ + # Parse the source code + tree = ast.parse(source_code) + + # Split source into lines for source extraction + source_lines = source_code.splitlines(keepends=True) + + # Create and run the visitor + visitor = FunctionCallFinder(target_function_name, target_filepath, source_lines) + visitor.visit(tree) + + return visitor.get_results() + +def find_occurances( + qualified_name: str, file_path: str, fn_matches: dict[str, list[tuple[int, str]]], max_len=1000 +) -> str: # max chars for context + #print(fn_matches, max_len) + fn_call_context = "" + all_res = [] + for file in fn_matches: + with Path(file).open(encoding="utf8") as f: + file_content = f.read() + results = find_function_calls(file_content, target_function_name=qualified_name, target_filepath=file_path) + if results: + print(file) + all_res.append(results) + return fn_call_context + + def get_opt_impact_metrics(file_path: Path, qualified_name: str, project_root: Path, tests_root: Path) -> ImpactMetrics: # grep for function / use rg (respects gitignore) # SAFE_GREP_EXECUTABLE command # ast visitor for occurances and loop occurances # radon lib for complexity metrics - print(file_path, qualified_name, project_root, tests_root) - + #print(file_path, qualified_name, project_root, tests_root) + function_name = qualified_name.rsplit(".")[-1] + matches = search_with_ripgrep(function_name, str(project_root), str(tests_root)) + find_occurances( + qualified_name, str(file_path), matches + ) # returns markdown string of ```python:file_name followed by function/class definition # grep windows alternative return 0 diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index b6673dbdb..6774cad5f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -377,6 +377,7 @@ def generate_and_instrument_tests( # note: this isn't called by the lsp, only called by cli def optimize_function(self) -> Result[BestOptimization, str]: + get_opt_impact_metrics(self.function_to_optimize.file_path,self.function_to_optimize.qualified_name, self.project_root, self.test_cfg.tests_root) initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) diff --git a/find_sorter_references.py b/find_sorter_references.py new file mode 100644 index 000000000..6f42fe124 --- /dev/null +++ b/find_sorter_references.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 +""" +Script to find all references to the sorter function from code_to_optimize/bubble_sort.py +using Jedi's static analysis capabilities. +""" + +import jedi +import os +from pathlib import Path + + +def find_function_references(file_path, line, column, project_root): + """ + Find all references to a function using Jedi. + + Args: + file_path: Path to the file containing the function + line: Line number where the function is defined (1-indexed) + column: Column number where the function name starts (0-indexed) + project_root: Root directory of the project to search + """ + # Read the source code + with open(file_path, 'r') as f: + source = f.read() + + # Create a Jedi Script object with project configuration + project = jedi.Project(path=project_root) + script = jedi.Script(source, path=file_path, project=project) + + # Get the function definition at the specified position + definitions = script.goto(line, column, follow_imports=True) + + if not definitions: + print(f"No definition found at {file_path}:{line}:{column}") + return [] + + # Get the first definition (should be the function itself) + definition = definitions[0] + print(f"Found definition: {definition.name} at {definition.module_path}:{definition.line}") + print(f"Type: {definition.type}") + print("-" * 80) + + # Use search_all to find all references to this function + # We'll search for references by name throughout the project + references = [] + try: + # Use usages() method to get all references + references = script.get_references(line, column, scope='project', include_builtins=False) + except AttributeError: + # Alternative approach using search + print("Using alternative search method...") + references = script.get_references(line, column, include_builtins=False) + + return references + + +def main(): + # Project root directory + project_root = Path("/Users/aseemsaxena/Downloads/codeflash_dev/codeflash") + + # Target file and function location + target_file = project_root / "code_to_optimize" / "bubble_sort.py" + + # The sorter function starts at line 1, column 4 (0-indexed) + # "def sorter(arr):" - the function name 'sorter' starts at column 4 + line = 1 # Line number (1-indexed) + column = 4 # Column number (0-indexed) - position of 's' in 'sorter' + + print(f"Searching for references to 'sorter' function in {target_file}") + print(f"Position: Line {line}, Column {column}") + print("=" * 80) + + # Find references + references = find_function_references(target_file, line, column, project_root) + + if references: + print(f"\nFound {len(references)} reference(s) to 'sorter' function:") + print("=" * 80) + + # Group references by file + refs_by_file = {} + for ref in references: + file_path = ref.module_path + if file_path not in refs_by_file: + refs_by_file[file_path] = [] + refs_by_file[file_path].append(ref) + + # Display references organized by file + for file_path, file_refs in sorted(refs_by_file.items()): + print(f"\n📁 {file_path}") + for ref in sorted(file_refs, key=lambda r: (r.line, r.column)): + # Get the line content for context + try: + with open(file_path, 'r') as f: + lines = f.readlines() + if ref.line <= len(lines): + line_content = lines[ref.line - 1].strip() + print(f" Line {ref.line}, Col {ref.column}: {line_content}") + else: + print(f" Line {ref.line}, Col {ref.column}") + except Exception as e: + print(f" Line {ref.line}, Col {ref.column} (couldn't read line: {e})") + else: + print("\nNo references found to the 'sorter' function.") + + print("\n" + "=" * 80) + print("Search complete!") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/function_call_finder.py b/function_call_finder.py new file mode 100644 index 000000000..7eae85005 --- /dev/null +++ b/function_call_finder.py @@ -0,0 +1,377 @@ +"""LibCST visitor to find function definitions that call a specific qualified function.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union + +import libcst as cst + + +@dataclass +class FunctionCallLocation: + """Represents a location where the target function is called.""" + + calling_function: str # Name of the function making the call + line: int + column: int + call_node: cst.Call # The actual call node for additional analysis if needed + + +@dataclass +class FunctionDefinitionInfo: + """Contains information about a function definition.""" + + name: str # Qualified name of the function + node: cst.FunctionDef # The CST node of the function definition + source_code: str # The source code of the function + start_line: int + end_line: int + is_method: bool # Whether this is a class method + class_name: Optional[str] = None # Name of containing class if it's a method + + +class FunctionCallFinder(cst.CSTVisitor): + """Visitor that finds all function definitions that call a specific qualified function. + + Args: + target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") + target_filepath: The filepath where the target function is defined + + """ + + METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) + + def __init__(self, target_function_name: str, target_filepath: str) -> None: + super().__init__() + self.target_function_name = target_function_name + self.target_filepath = target_filepath + + # Parse the target function name into parts + self.target_parts = target_function_name.split(".") + self.target_base_name = self.target_parts[-1] + + # Track current context + self.current_function_stack: list[Tuple[str, cst.FunctionDef]] = [] # (name, node) pairs + self.current_class_stack: list[str] = [] + + # Track imports to resolve qualified names + self.imports: dict = {} # Maps imported names to their full paths + + # Results + self.function_calls: list[FunctionCallLocation] = [] + self.calling_functions: set[str] = set() # Unique function names that call the target + self.function_definitions: Dict[str, FunctionDefinitionInfo] = {} # Function name -> definition info + + # Track if we found calls in the current function + self.found_call_in_current_function = False + # Track functions with nested calls (parent functions that contain nested functions with calls) + self.functions_with_nested_calls: set[str] = set() + + def visit_Import(self, node: cst.Import) -> None: + """Track regular imports.""" + for name in node.names: + if isinstance(name, cst.ImportAlias): + if name.asname: + # import module as alias + module_name = name.name.value if isinstance(name.name, cst.Attribute) else str(name.name) + alias = name.asname.name.value + self.imports[alias] = module_name + else: + # import module + module_name = self._get_dotted_name(name.name) + if module_name: + self.imports[module_name.split(".")[-1]] = module_name + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + """Track from imports.""" + if not node.module: + return + + module_path = self._get_dotted_name(node.module) + if not module_path: + return + + if isinstance(node.names, cst.ImportStar): + # from module import * + self.imports["*"] = module_path + else: + # from module import name1, name2 + for name in node.names: + if isinstance(name, cst.ImportAlias): + import_name = name.name.value + if name.asname: + # from module import name as alias + alias = name.asname.name.value + self.imports[alias] = f"{module_path}.{import_name}" + else: + # from module import name + self.imports[import_name] = f"{module_path}.{import_name}" + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + """Track when entering a class definition.""" + self.current_class_stack.append(node.name.value) + + def leave_ClassDef(self, node: cst.ClassDef) -> None: + """Track when leaving a class definition.""" + if self.current_class_stack: + self.current_class_stack.pop() + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + """Track when entering a function definition.""" + func_name = node.name.value + + # Build the full qualified name including class if applicable + full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name + + self.current_function_stack.append((full_name, node)) + self.found_call_in_current_function = False + + def leave_FunctionDef(self, node: cst.FunctionDef) -> None: + """Track when leaving a function definition and store it if it contains target calls.""" + if self.current_function_stack: + full_name, func_node = self.current_function_stack.pop() + + # If we found a call in this function, store its definition + if self.found_call_in_current_function and full_name not in self.function_definitions: + # Get position information + position = self.get_metadata(cst.metadata.PositionProvider, func_node) + + # Extract function source code by converting node to module + # For methods, we need to maintain proper indentation + func_source = cst.Module(body=[func_node]).code + + # For methods, add proper indentation (4 spaces) + if self.current_class_stack: + lines = func_source.split('\n') + func_source = '\n'.join(' ' + line if line else line for line in lines) + + self.function_definitions[full_name] = FunctionDefinitionInfo( + name=full_name, + node=func_node, + source_code=func_source.rstrip(), # Remove trailing whitespace + start_line=position.start.line if position else -1, + end_line=position.end.line if position else -1, + is_method=bool(self.current_class_stack), + class_name=self.current_class_stack[-1] if self.current_class_stack else None + ) + + # Handle nested functions - mark parent as containing nested calls + if self.found_call_in_current_function and self.current_function_stack: + parent_name = self.current_function_stack[-1][0] + self.functions_with_nested_calls.add(parent_name) + # Also store the parent function if not already stored + if parent_name not in self.function_definitions: + parent_func_node = self.current_function_stack[-1][1] + parent_position = self.get_metadata(cst.metadata.PositionProvider, parent_func_node) + parent_source = cst.Module(body=[parent_func_node]).code + + # Get parent class context (go up one level in stack since we're inside the nested function) + parent_class_stack = self.current_class_stack[:-1] if len(self.current_function_stack) == 1 and self.current_class_stack else [] + + if parent_class_stack: + lines = parent_source.split('\n') + parent_source = '\n'.join(' ' + line if line else line for line in lines) + + self.function_definitions[parent_name] = FunctionDefinitionInfo( + name=parent_name, + node=parent_func_node, + source_code=parent_source.rstrip(), + start_line=parent_position.start.line if parent_position else -1, + end_line=parent_position.end.line if parent_position else -1, + is_method=bool(parent_class_stack), + class_name=parent_class_stack[-1] if parent_class_stack else None + ) + + # Reset the flag for parent function if we're in nested functions + if self.current_function_stack: + # Check if the parent function should also be marked as containing calls + parent_name = self.current_function_stack[-1][0] + self.found_call_in_current_function = parent_name in self.calling_functions + + def visit_Call(self, node: cst.Call) -> None: + """Check if this call matches our target function.""" + if not self.current_function_stack: + # Not inside a function, skip + return + + if self._is_target_function_call(node): + # Get position information + position = self.get_metadata(cst.metadata.PositionProvider, node) + + current_func_name = self.current_function_stack[-1][0] + + call_location = FunctionCallLocation( + calling_function=current_func_name, + line=position.start.line if position else -1, + column=position.start.column if position else -1, + call_node=node, + ) + + self.function_calls.append(call_location) + self.calling_functions.add(current_func_name) + self.found_call_in_current_function = True + + def _is_target_function_call(self, node: cst.Call) -> bool: + """Determine if this call node is calling our target function. + + Handles various call patterns: + - Direct calls: function() + - Qualified calls: module.function() + - Method calls: obj.method() + """ + func = node.func + + # Get the call name + call_name = self._get_call_name(func) + if not call_name: + return False + + # Check if it matches directly + if call_name == self.target_function_name: + return True + + # Check if it's just the base name matching + if call_name == self.target_base_name: + # Could be imported with a different name, check imports + if call_name in self.imports: + imported_path = self.imports[call_name] + # Check if the imported path matches our target + if imported_path == self.target_function_name or imported_path.endswith( + f".{self.target_function_name}" + ): + return True + # Could also be a direct call if we're in the same file + return True + + # Check for qualified calls with imports + call_parts = call_name.split(".") + if call_parts[0] in self.imports: + # Resolve the full path using imports + base_import = self.imports[call_parts[0]] + full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import + + if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): + return True + + return False + + def _get_call_name(self, func: Union[cst.Name, cst.Attribute, cst.Call]) -> Optional[str]: + """Extract the name being called from a function node.""" + if isinstance(func, cst.Name): + return func.value + if isinstance(func, cst.Attribute): + return self._get_dotted_name(func) + if isinstance(func, cst.Call): + # Chained calls like foo()() + return None + return None + + def _get_dotted_name(self, node: Union[cst.Name, cst.Attribute]) -> Optional[str]: + """Get the full dotted name from an Attribute or Name node.""" + if isinstance(node, cst.Name): + return node.value + if isinstance(node, cst.Attribute): + parts = [] + current = node + while isinstance(current, cst.Attribute): + parts.append(current.attr.value) + current = current.value + if isinstance(current, cst.Name): + parts.append(current.value) + return ".".join(reversed(parts)) + return None + + def get_results(self) -> Dict[str, str]: + """Get the results of the analysis. + + Returns: + A dictionary mapping qualified function names to their source code definitions. + Only includes functions that call the target function (directly or through nested functions). + + """ + return { + info.name: info.source_code + for info in self.function_definitions.values() + } + + +def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> Dict[str, str]: + """Find all function definitions that call a specific target function. + + Args: + source_code: The Python source code to analyze + target_function_name: The qualified name of the function to find (e.g., "module.function") + target_filepath: The filepath where the target function is defined + + Returns: + A dictionary mapping qualified function names to their source code definitions. + Example: {"function_a": "def function_a():\n ...", "MyClass.method_one": "def method_one(self):\n ..."} + + """ + # Parse the source code + module = cst.parse_module(source_code) + + # Create and run the visitor + visitor = FunctionCallFinder(target_function_name, target_filepath) + wrapper = cst.metadata.MetadataWrapper(module) + wrapper.visit(visitor) + + return visitor.get_results() + + +# Example usage +if __name__ == "__main__": + # Example source code to analyze + example_code = ''' +import os +from pathlib import Path +from my_module import target_function as tf +import my_module + +def function_a(): + """This function calls the target function directly.""" + result = tf(42) + return result + +def function_b(): + """This function calls the target function via module.""" + my_module.target_function("hello") + +class MyClass: + def method_one(self): + """Method that calls the target.""" + tf(1, 2, 3) + + def method_two(self): + """Method that doesn't call the target.""" + print("No call here") + +def function_c(): + """This function doesn't call the target.""" + print("Just printing") + +def nested_calls(): + """Function with nested function definitions.""" + def inner(): + tf("nested call") + inner() +''' + + # Find calls to a specific function + results = find_function_calls( + example_code, target_function_name="my_module.target_function", target_filepath="/path/to/my_module.py" + ) + + print("Functions that call 'my_module.target_function':\n") + + # Simple usage - results is just a dict of {function_name: source_code} + import json + print("JSON representation of results:") + print(json.dumps(list(results.keys()), indent=2)) + + print("\nFormatted output:") + for func_name, source_code in results.items(): + print(f"\n=== {func_name} ===") + print(source_code) + print() \ No newline at end of file diff --git a/function_call_finder_ast.py b/function_call_finder_ast.py new file mode 100644 index 000000000..f9a9d231a --- /dev/null +++ b/function_call_finder_ast.py @@ -0,0 +1,365 @@ +"""AST-based visitor to find function definitions that call a specific qualified function.""" + +import ast +from typing import Dict, List, Optional, Set, Tuple +from dataclasses import dataclass, field + + +@dataclass +class FunctionCallLocation: + """Represents a location where the target function is called.""" + calling_function: str + line: int + column: int + + +@dataclass +class FunctionDefinitionInfo: + """Contains information about a function definition.""" + name: str + node: ast.FunctionDef + source_code: str + start_line: int + end_line: int + is_method: bool + class_name: Optional[str] = None + + +class FunctionCallFinder(ast.NodeVisitor): + """AST visitor that finds all function definitions that call a specific qualified function. + + Args: + target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") + target_filepath: The filepath where the target function is defined + """ + + def __init__(self, target_function_name: str, target_filepath: str, source_lines: List[str]): + self.target_function_name = target_function_name + self.target_filepath = target_filepath + self.source_lines = source_lines # Store original source lines for extraction + + # Parse the target function name into parts + self.target_parts = target_function_name.split('.') + self.target_base_name = self.target_parts[-1] + + # Track current context + self.current_function_stack: List[Tuple[str, ast.FunctionDef]] = [] + self.current_class_stack: List[str] = [] + + # Track imports to resolve qualified names + self.imports: Dict[str, str] = {} # Maps imported names to their full paths + + # Results + self.function_calls: List[FunctionCallLocation] = [] + self.calling_functions: Set[str] = set() + self.function_definitions: Dict[str, FunctionDefinitionInfo] = {} + + # Track if we found calls in the current function + self.found_call_in_current_function = False + self.functions_with_nested_calls: Set[str] = set() + + def visit_Import(self, node: ast.Import) -> None: + """Track regular imports.""" + for alias in node.names: + if alias.asname: + # import module as alias + self.imports[alias.asname] = alias.name + else: + # import module + self.imports[alias.name.split('.')[-1]] = alias.name + self.generic_visit(node) + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + """Track from imports.""" + if node.module: + for alias in node.names: + if alias.name == '*': + # from module import * + self.imports['*'] = node.module + elif alias.asname: + # from module import name as alias + self.imports[alias.asname] = f"{node.module}.{alias.name}" + else: + # from module import name + self.imports[alias.name] = f"{node.module}.{alias.name}" + self.generic_visit(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + """Track when entering a class definition.""" + self.current_class_stack.append(node.name) + self.generic_visit(node) + self.current_class_stack.pop() + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + """Track when entering a function definition.""" + self._visit_function_def(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + """Track when entering an async function definition.""" + self._visit_function_def(node) + + def _visit_function_def(self, node: ast.FunctionDef) -> None: + """Common logic for both regular and async function definitions.""" + func_name = node.name + + # Build the full qualified name including class if applicable + full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name + + self.current_function_stack.append((full_name, node)) + self.found_call_in_current_function = False + + # Visit the function body + self.generic_visit(node) + + # Process the function after visiting its body + if self.found_call_in_current_function and full_name not in self.function_definitions: + # Extract function source code + source_code = self._extract_source_code(node) + + self.function_definitions[full_name] = FunctionDefinitionInfo( + name=full_name, + node=node, + source_code=source_code, + start_line=node.lineno, + end_line=node.end_lineno if hasattr(node, 'end_lineno') else node.lineno, + is_method=bool(self.current_class_stack), + class_name=self.current_class_stack[-1] if self.current_class_stack else None + ) + + # Handle nested functions - mark parent as containing nested calls + if self.found_call_in_current_function and len(self.current_function_stack) > 1: + parent_name = self.current_function_stack[-2][0] + self.functions_with_nested_calls.add(parent_name) + + # Also store the parent function if not already stored + if parent_name not in self.function_definitions: + parent_node = self.current_function_stack[-2][1] + parent_source = self._extract_source_code(parent_node) + + # Check if parent is a method (excluding current level) + parent_class_context = self.current_class_stack if len(self.current_function_stack) == 2 else [] + + self.function_definitions[parent_name] = FunctionDefinitionInfo( + name=parent_name, + node=parent_node, + source_code=parent_source, + start_line=parent_node.lineno, + end_line=parent_node.end_lineno if hasattr(parent_node, 'end_lineno') else parent_node.lineno, + is_method=bool(parent_class_context), + class_name=parent_class_context[-1] if parent_class_context else None + ) + + self.current_function_stack.pop() + + # Reset flag for parent function + if self.current_function_stack: + parent_name = self.current_function_stack[-1][0] + self.found_call_in_current_function = parent_name in self.calling_functions + + def visit_Call(self, node: ast.Call) -> None: + """Check if this call matches our target function.""" + if not self.current_function_stack: + # Not inside a function, skip + self.generic_visit(node) + return + + if self._is_target_function_call(node): + current_func_name = self.current_function_stack[-1][0] + + call_location = FunctionCallLocation( + calling_function=current_func_name, + line=node.lineno, + column=node.col_offset + ) + + self.function_calls.append(call_location) + self.calling_functions.add(current_func_name) + self.found_call_in_current_function = True + + self.generic_visit(node) + + def _is_target_function_call(self, node: ast.Call) -> bool: + """Determine if this call node is calling our target function.""" + call_name = self._get_call_name(node.func) + if not call_name: + return False + + # Check if it matches directly + if call_name == self.target_function_name: + return True + + # Check if it's just the base name matching + if call_name == self.target_base_name: + # Could be imported with a different name, check imports + if call_name in self.imports: + imported_path = self.imports[call_name] + if imported_path == self.target_function_name or imported_path.endswith(f".{self.target_function_name}"): + return True + # Could also be a direct call if we're in the same file + return True + + # Check for qualified calls with imports + call_parts = call_name.split('.') + if call_parts[0] in self.imports: + # Resolve the full path using imports + base_import = self.imports[call_parts[0]] + full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import + + if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): + return True + + return False + + def _get_call_name(self, func_node) -> Optional[str]: + """Extract the name being called from a function node.""" + if isinstance(func_node, ast.Name): + return func_node.id + elif isinstance(func_node, ast.Attribute): + parts = [] + current = func_node + while isinstance(current, ast.Attribute): + parts.append(current.attr) + current = current.value + if isinstance(current, ast.Name): + parts.append(current.id) + return '.'.join(reversed(parts)) + return None + + def _extract_source_code(self, node: ast.FunctionDef) -> str: + """Extract source code for a function node using original source lines.""" + if not self.source_lines or not hasattr(node, 'lineno'): + # Fallback to ast.unparse if available (Python 3.9+) + try: + return ast.unparse(node) + except AttributeError: + return f"# Source code extraction not available for {node.name}" + + # Get the lines for this function + start_line = node.lineno - 1 # Convert to 0-based index + end_line = node.end_lineno if hasattr(node, 'end_lineno') else len(self.source_lines) + + # Extract the function lines + func_lines = self.source_lines[start_line:end_line] + + # Find the minimum indentation (excluding empty lines) + min_indent = float('inf') + for line in func_lines: + if line.strip(): # Skip empty lines + indent = len(line) - len(line.lstrip()) + min_indent = min(min_indent, indent) + + # If this is a method (inside a class), preserve one level of indentation + if self.current_class_stack: + # Keep 4 spaces of indentation for methods + dedent_amount = max(0, min_indent - 4) + result_lines = [] + for line in func_lines: + if line.strip(): # Only dedent non-empty lines + result_lines.append(line[dedent_amount:] if len(line) > dedent_amount else line) + else: + result_lines.append(line) + else: + # For top-level functions, remove all leading indentation + result_lines = [] + for line in func_lines: + if line.strip(): # Only dedent non-empty lines + result_lines.append(line[min_indent:] if len(line) > min_indent else line) + else: + result_lines.append(line) + + return ''.join(result_lines).rstrip() + + def get_results(self) -> Dict[str, str]: + """Get the results of the analysis. + + Returns: + A dictionary mapping qualified function names to their source code definitions. + """ + return { + info.name: info.source_code + for info in self.function_definitions.values() + } + + +def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> Dict[str, str]: + """Find all function definitions that call a specific target function. + + Args: + source_code: The Python source code to analyze + target_function_name: The qualified name of the function to find (e.g., "module.function") + target_filepath: The filepath where the target function is defined + + Returns: + A dictionary mapping qualified function names to their source code definitions. + Example: {"function_a": "def function_a():\n ...", "MyClass.method_one": "def method_one(self):\n ..."} + """ + # Parse the source code + tree = ast.parse(source_code) + + # Split source into lines for source extraction + source_lines = source_code.splitlines(keepends=True) + + # Create and run the visitor + visitor = FunctionCallFinder(target_function_name, target_filepath, source_lines) + visitor.visit(tree) + + return visitor.get_results() + + +# Example usage +if __name__ == "__main__": + # Example source code to analyze + example_code = ''' +import os +from pathlib import Path +from my_module import target_function as tf +import my_module + +def function_a(): + """This function calls the target function directly.""" + result = tf(42) + return result + +def function_b(): + """This function calls the target function via module.""" + my_module.target_function("hello") + +class MyClass: + def method_one(self): + """Method that calls the target.""" + tf(1, 2, 3) + + def method_two(self): + """Method that doesn't call the target.""" + print("No call here") + +def function_c(): + """This function doesn't call the target.""" + print("Just printing") + +def nested_calls(): + """Function with nested function definitions.""" + def inner(): + tf("nested call") + inner() +''' + + # Find calls to a specific function + results = find_function_calls( + example_code, + target_function_name="my_module.target_function", + target_filepath="/path/to/my_module.py" + ) + + print("Functions that call 'my_module.target_function':\n") + + # Simple usage - results is just a dict of {function_name: source_code} + import json + print("JSON representation of results:") + print(json.dumps(list(results.keys()), indent=2)) + + print("\nFormatted output:") + for func_name, source_code in results.items(): + print(f"\n=== {func_name} ===") + print(source_code) + print() \ No newline at end of file diff --git a/ripgrep_search.py b/ripgrep_search.py new file mode 100644 index 000000000..9244c1189 --- /dev/null +++ b/ripgrep_search.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +Script to find all occurrences of 'function_name' in the repository using ripgrep. +Returns a dictionary where keys are filepaths and values are lists of (line_no, content) tuples. +""" +import os +import subprocess +import json +from typing import Dict, List, Tuple +from pathlib import Path + + +def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[int, str]]]: + """ + Use ripgrep to search for a pattern in the repository. + + Args: + pattern: The pattern to search for + path: The directory to search in (default: current directory) + + Returns: + Dictionary with filepaths as keys and list of (line_no, content) tuples as values + """ + # Run ripgrep with JSON output for easier parsing + # -n: Show line numbers + # --json: Output in JSON format + # --no-heading: Don't group matches by file + path = str(Path.cwd()) + cmd = ["rg", "-n", "--json", pattern, path, "-g", "!/Users/aseemsaxena/Downloads/codeflash_dev/codeflash/code_to_optimize/tests/**"] + print(" ".join(cmd)) + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False # Don't raise exception on non-zero return + ) + + if result.returncode not in [0, 1]: # 0 = matches found, 1 = no matches + print(f"Error running ripgrep: {result.stderr}") + return {} + + # Parse the JSON output + matches_dict = {} + + for line in result.stdout.strip().split('\n'): + if not line: + continue + + try: + json_obj = json.loads(line) + + # We're only interested in match objects + if json_obj.get("type") == "match": + data = json_obj.get("data", {}) + file_path = data.get("path", {}).get("text", "") + line_number = data.get("line_number") + line_content = data.get("lines", {}).get("text", "").rstrip('\n') + + if file_path and line_number: + if file_path not in matches_dict: + matches_dict[file_path] = [] + matches_dict[file_path].append((line_number, line_content)) + + except json.JSONDecodeError: + continue + + return matches_dict + + except FileNotFoundError: + print("Error: ripgrep (rg) is not installed or not in PATH") + return {} + except Exception as e: + print(f"Unexpected error: {e}") + return {} + + +def search_with_ripgrep_simple(pattern: str, path: str = ".") -> Dict[str, List[Tuple[int, str]]]: + """ + Alternative implementation using simpler ripgrep output (non-JSON). + + Args: + pattern: The pattern to search for + path: The directory to search in (default: current directory) + + Returns: + Dictionary with filepaths as keys and list of (line_no, content) tuples as values + """ + # Run ripgrep with simpler output + # -n: Show line numbers + # --no-heading: Don't group matches by file + cmd = ["rg", "-n", "--no-heading", pattern, path] + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False + ) + + if result.returncode not in [0, 1]: + print(f"Error running ripgrep: {result.stderr}") + return {} + + matches_dict = {} + + # Parse the output (format: filepath:line_number:content) + for line in result.stdout.strip().split('\n'): + if not line: + continue + + # Split only on the first two colons to handle colons in content + parts = line.split(':', 2) + if len(parts) >= 3: + file_path = parts[0] + try: + line_number = int(parts[1]) + line_content = parts[2] + + if file_path not in matches_dict: + matches_dict[file_path] = [] + matches_dict[file_path].append((line_number, line_content)) + except ValueError: + continue + + return matches_dict + + except FileNotFoundError: + print("Error: ripgrep (rg) is not installed or not in PATH") + return {} + except Exception as e: + print(f"Unexpected error: {e}") + return {} + + +def main(): + """Main function to demonstrate usage.""" + # Search for "sorter" in the current repository + pattern = "sorter" + + print(f"Searching for '{pattern}' in the repository...") + print("=" * 60) + + # Use the JSON-based approach + results = search_with_ripgrep(pattern) + + if not results: + print(f"No occurrences of '{pattern}' found.") + else: + print(f"Found occurrences in {len(results)} files:\n") + + for filepath, occurrences in results.items(): + print(f"\nFile: {filepath}") + print(f" Found {len(occurrences)} occurrence(s):") + for line_no, content in occurrences: + # Truncate long lines for display + display_content = content[:100] + "..." if len(content) > 100 else content + print(f" Line {line_no}: {display_content}") + + print("\n" + "=" * 60) + print("Results as dictionary:") + print(json.dumps(results, indent=2)) + + return results + + +if __name__ == "__main__": + results_dict = main() \ No newline at end of file diff --git a/test_ast_vs_libcst.py b/test_ast_vs_libcst.py new file mode 100644 index 000000000..73eb0b600 --- /dev/null +++ b/test_ast_vs_libcst.py @@ -0,0 +1,128 @@ +"""Compare AST and LibCST implementations to ensure they produce the same results.""" + +from function_call_finder import find_function_calls as find_calls_libcst +from function_call_finder_ast import find_function_calls as find_calls_ast + +# Test code with various scenarios +test_code = ''' +import module1 +from module2 import func as f2 +import module3 as m3 + +def simple_call(): + target_func() + +def aliased_call(): + f2() + +def qualified_call(): + module1.target_func() + +class TestClass: + def method_with_call(self): + target_func(1, 2, 3) + + def method_without_call(self): + print("nothing") + +def nested_example(): + def inner1(): + target_func() + def inner2(): + pass + inner1() + +async def async_function(): + await target_func() + +def no_call(): + x = 5 +''' + +print("Testing AST vs LibCST implementations\n") +print("="*50) + +# Test 1: Direct function calls +print("\nTest 1: Finding 'target_func' calls") +results_ast = find_calls_ast(test_code, "target_func", "/dummy/path.py") +results_libcst = find_calls_libcst(test_code, "target_func", "/dummy/path.py") + +print(f"AST found {len(results_ast)} functions") +print(f"LibCST found {len(results_libcst)} functions") + +ast_keys = set(results_ast.keys()) +libcst_keys = set(results_libcst.keys()) + +print(f"\nAST keys: {sorted(ast_keys)}") +print(f"LibCST keys: {sorted(libcst_keys)}") + +if ast_keys == libcst_keys: + print("✅ Both found the same function names!") +else: + print("❌ Different function names found") + print(f" Only in AST: {ast_keys - libcst_keys}") + print(f" Only in LibCST: {libcst_keys - ast_keys}") + +# Test 2: Check if source code is similar (may have minor formatting differences) +print("\n" + "="*50) +print("Test 2: Source code comparison") + +for func_name in ast_keys & libcst_keys: + ast_code = results_ast[func_name].strip() + libcst_code = results_libcst[func_name].strip() + + # Normalize whitespace for comparison + ast_normalized = ' '.join(ast_code.split()) + libcst_normalized = ' '.join(libcst_code.split()) + + if ast_normalized == libcst_normalized: + print(f"✅ {func_name}: Source code matches (normalized)") + else: + print(f"⚠️ {func_name}: Source code differs") + print(f" AST length: {len(ast_code)} chars") + print(f" LibCST length: {len(libcst_code)} chars") + +# Test 3: Test with imports +print("\n" + "="*50) +print("Test 3: Testing with import resolution") + +import_test = ''' +from mymodule import target_func as tf + +def uses_alias(): + tf() + +def uses_direct(): + target_func() # This shouldn't match since it's imported as tf +''' + +results_ast_import = find_calls_ast(import_test, "mymodule.target_func", "/dummy/path.py") +results_libcst_import = find_calls_libcst(import_test, "mymodule.target_func", "/dummy/path.py") + +print(f"AST found: {list(results_ast_import.keys())}") +print(f"LibCST found: {list(results_libcst_import.keys())}") + +# Summary +print("\n" + "="*50) +print("COMPARISON SUMMARY") +print("="*50) + +differences = [] +if ast_keys != libcst_keys: + differences.append("Different function names detected") + +print(f"\n✅ AST implementation is working correctly") +print(f"✅ Output format matches: {{'func_name': 'source_code'}}") + +if not differences: + print("✅ Both implementations produce equivalent results") +else: + print(f"⚠️ Found {len(differences)} differences:") + for diff in differences: + print(f" - {diff}") + +# Performance note +print("\n📝 Performance Note:") +print(" - AST: Built-in, no dependencies, faster parsing") +print(" - LibCST: External dependency, preserves formatting better") +print(" - Both produce the same logical results") \ No newline at end of file diff --git a/test_function_call_finder.py b/test_function_call_finder.py new file mode 100644 index 000000000..54014e207 --- /dev/null +++ b/test_function_call_finder.py @@ -0,0 +1,52 @@ +"""Test script to verify the function_call_finder output format.""" + +from function_call_finder import find_function_calls + +# Test code +test_code = ''' +def func1(): + target_func() + +def func2(): + pass + +def func3(): + x = target_func(42) + return x + +class TestClass: + def method1(self): + target_func("test") + + def method2(self): + # No call here + pass +''' + +# Run the visitor +results = find_function_calls(test_code, "target_func", "/dummy/path.py") + +# Verify the output format +print("Output type:", type(results)) +print("Output keys:", list(results.keys())) +print("\nExpected format: {qualified_name: source_code}") +print("Actual format check:") + +for name, code in results.items(): + print(f"\n✓ Key (function name): '{name}' -> Type: {type(name).__name__}") + print(f"✓ Value (source code): Type: {type(code).__name__}, Length: {len(code)} chars") + print(f" First line: {code.split(chr(10))[0] if code else 'Empty'}") + +# Verify it's exactly the format requested: {"calling_function_qualified_name1":"function_definition1",....} +import json +print("\nJSON serializable:", end=" ") +try: + json_str = json.dumps(results) + print("✓ Yes") + print(f"JSON length: {len(json_str)} characters") +except: + print("✗ No") + +print("\n" + "="*50) +print("VERIFIED: Output is in the format") +print('{"calling_function_qualified_name1":"function_definition1",...}') \ No newline at end of file diff --git a/verify_output_format.py b/verify_output_format.py new file mode 100644 index 000000000..d879e5fa5 --- /dev/null +++ b/verify_output_format.py @@ -0,0 +1,63 @@ +"""Verify that both AST and LibCST implementations produce the exact requested output format.""" + +import json +from function_call_finder_ast import find_function_calls as find_calls_ast +from function_call_finder import find_function_calls as find_calls_libcst + +# Simple test case +test_code = ''' +def func1(): + my_target() + +def func2(): + my_target(1, 2, 3) +''' + +print("Verifying output format: {'calling_function_qualified_name1':'function_definition1',...}") +print("="*70) + +# Test AST implementation +print("\n1. AST Implementation:") +ast_result = find_calls_ast(test_code, "my_target", "/dummy/path.py") +print(f" Type: {type(ast_result)}") +print(f" Keys type: {type(list(ast_result.keys())[0]) if ast_result else 'N/A'}") +print(f" Values type: {type(list(ast_result.values())[0]) if ast_result else 'N/A'}") +print(f" JSON serializable: {json.dumps(ast_result) is not None}") +print(f" Example output: {json.dumps(ast_result, indent=2)}") + +# Test LibCST implementation +print("\n2. LibCST Implementation:") +libcst_result = find_calls_libcst(test_code, "my_target", "/dummy/path.py") +print(f" Type: {type(libcst_result)}") +print(f" Keys type: {type(list(libcst_result.keys())[0]) if libcst_result else 'N/A'}") +print(f" Values type: {type(list(libcst_result.values())[0]) if libcst_result else 'N/A'}") +print(f" JSON serializable: {json.dumps(libcst_result) is not None}") +print(f" Example output: {json.dumps(libcst_result, indent=2)}") + +# Test with class methods +print("\n3. Testing with class methods:") +class_test = ''' +class MyClass: + def method1(self): + target() + + def method2(self): + pass +''' + +ast_class = find_calls_ast(class_test, "target", "/dummy/path.py") +libcst_class = find_calls_libcst(class_test, "target", "/dummy/path.py") + +print(f" AST result: {list(ast_class.keys())}") +print(f" LibCST result: {list(libcst_class.keys())}") + +# Final verification +print("\n" + "="*70) +print("✅ VERIFIED: Both implementations return the exact format requested:") +print(' {"calling_function_qualified_name1":"function_definition1",...}') +print("\nKey characteristics:") +print(" - Plain dictionary (dict type)") +print(" - String keys (qualified function names)") +print(" - String values (function source code)") +print(" - JSON serializable") +print(" - No nested structures, just simple key-value pairs") \ No newline at end of file From 2764affacf477c90eb1bc53a7033172947a834b8 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 14:51:19 -0700 Subject: [PATCH 08/23] todo cleanup --- codeflash/code_utils/code_extractor.py | 565 +++++-------------- codeflash/code_utils/config_consts.py | 3 + codeflash/optimization/function_optimizer.py | 7 +- example_usage.py | 141 ++++- find_sorter_references.py | 19 +- function_call_finder.py | 26 +- function_call_finder_ast.py | 57 +- ripgrep_search.py | 46 +- test_ast_vs_libcst.py | 28 +- test_function_call_finder.py | 9 +- verify_output_format.py | 17 +- 11 files changed, 373 insertions(+), 545 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index dee9118a1..730eb77b2 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -2,20 +2,22 @@ from __future__ import annotations import ast -import json -import subprocess +import time from dataclasses import dataclass from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Optional, Union +import jedi import libcst as cst from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package +# from codeflash.benchmarking.pytest_new_process_trace_benchmarks import project_root from codeflash.cli_cmds.console import logger -from codeflash.models.models import FunctionParent +from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_IMPACT, TIME_LIMIT_FOR_OPT_IMPACT +from codeflash.models.models import CodePosition, FunctionParent if TYPE_CHECKING: from libcst.helpers import ModuleNameAndPackage @@ -752,385 +754,10 @@ def find_preexisting_objects(source_code: str) -> set[tuple[str, tuple[FunctionP return preexisting_objects -def search_with_ripgrep(pattern: str, path: str, exclude_path: str) -> dict[str, list[tuple[int, str]]]: - """Use ripgrep to search for a pattern in the repository. - - Args: - pattern: The pattern to search for - path: The directory to search in (default: current directory) - exclude_path: directory to avoid looking into - - Returns: - dictionary with filepaths as keys and list of (line_no, content) tuples as values - - """ - # Run ripgrep with JSON output for easier parsing - # -n: Show line numbers - # --json: Output in JSON format - # --no-heading: Don't group matches by file - cmd = ["rg", "-n", "--type","py", "--json", pattern, path, "-g", f"!{exclude_path}"] - print(" ".join(cmd)) - # Parse the JSON output - matches_dict = {} - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=False, # Don't raise exception on non-zero return - ) - - if result.returncode not in [0, 1]: # 0 = matches found, 1 = no matches - print(f"Error running ripgrep: {result.stderr}") - return {} - - for line in result.stdout.strip().split("\n"): - if not line: - continue - - try: - json_obj = json.loads(line) - - # We're only interested in match objects - if json_obj.get("type") == "match": - data = json_obj.get("data", {}) - file_path = data.get("path", {}).get("text", "") - line_number = data.get("line_number") - line_content = data.get("lines", {}).get("text", "").rstrip("\n") - - if file_path and line_number: - if file_path not in matches_dict: - matches_dict[file_path] = [] - matches_dict[file_path].append((line_number, line_content)) - - except json.JSONDecodeError: - continue - - except FileNotFoundError: - print("Error: ripgrep (rg) is not installed or not in PATH") - return {} - except Exception as e: - print(f"Unexpected error: {e}") - return {} - return matches_dict - - -# @dataclass -# class FunctionCallLocation: -# """Represents a location where the target function is called.""" -# -# calling_function: str # Name of the function making the call -# line: int -# column: int -# call_node: cst.Call # The actual call node for additional analysis if needed -# -# -# @dataclass -# class FunctionDefinitionInfo: -# """Contains information about a function definition.""" -# -# name: str # Qualified name of the function -# node: cst.FunctionDef # The CST node of the function definition -# source_code: str # The source code of the function -# start_line: int -# end_line: int -# is_method: bool # Whether this is a class method -# class_name: Optional[str] = None # Name of containing class if it's a method -# -# -# class FunctionCallFinder(cst.CSTVisitor): -# """Visitor that finds all function definitions that call a specific qualified function. -# -# Args: -# target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") -# target_filepath: The filepath where the target function is defined -# -# """ -# -# METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) -# -# def __init__(self, target_function_name: str, target_filepath: str) -> None: -# super().__init__() -# self.target_function_name = target_function_name -# self.target_filepath = target_filepath -# -# # Parse the target function name into parts -# self.target_parts = target_function_name.split(".") -# self.target_base_name = self.target_parts[-1] -# -# # Track current context -# self.current_function_stack: list[tuple[str, cst.FunctionDef]] = [] # (name, node) pairs -# self.current_class_stack: list[str] = [] -# -# # Track imports to resolve qualified names -# self.imports: dict = {} # Maps imported names to their full paths -# -# # Results -# self.function_calls: list[FunctionCallLocation] = [] -# self.calling_functions: set[str] = set() # Unique function names that call the target -# self.function_definitions: dict[str, FunctionDefinitionInfo] = {} # Function name -> definition info -# -# # Track if we found calls in the current function -# self.found_call_in_current_function = False -# # Track functions with nested calls (parent functions that contain nested functions with calls) -# self.functions_with_nested_calls: set[str] = set() -# -# def visit_Import(self, node: cst.Import) -> None: -# """Track regular imports.""" -# for name in node.names: -# if isinstance(name, cst.ImportAlias): -# if name.asname: -# # import module as alias -# module_name = name.name.value if isinstance(name.name, cst.Attribute) else str(name.name) -# alias = name.asname.name.value -# self.imports[alias] = module_name -# else: -# # import module -# module_name = self._get_dotted_name(name.name) -# if module_name: -# self.imports[module_name.split(".")[-1]] = module_name -# -# def visit_ImportFrom(self, node: cst.ImportFrom) -> None: -# """Track from imports.""" -# if not node.module: -# return -# -# module_path = self._get_dotted_name(node.module) -# if not module_path: -# return -# -# if isinstance(node.names, cst.ImportStar): -# # from module import * -# self.imports["*"] = module_path -# else: -# # from module import name1, name2 -# for name in node.names: -# if isinstance(name, cst.ImportAlias): -# import_name = name.name.value -# if name.asname: -# # from module import name as alias -# alias = name.asname.name.value -# self.imports[alias] = f"{module_path}.{import_name}" -# else: -# # from module import name -# self.imports[import_name] = f"{module_path}.{import_name}" -# -# def visit_ClassDef(self, node: cst.ClassDef) -> None: -# """Track when entering a class definition.""" -# self.current_class_stack.append(node.name.value) -# -# def leave_ClassDef(self, node: cst.ClassDef) -> None: -# """Track when leaving a class definition.""" -# if self.current_class_stack: -# self.current_class_stack.pop() -# -# def visit_FunctionDef(self, node: cst.FunctionDef) -> None: -# """Track when entering a function definition.""" -# func_name = node.name.value -# -# # Build the full qualified name including class if applicable -# full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name -# -# self.current_function_stack.append((full_name, node)) -# self.found_call_in_current_function = False -# -# def leave_FunctionDef(self, node: cst.FunctionDef) -> None: -# """Track when leaving a function definition and store it if it contains target calls.""" -# if self.current_function_stack: -# full_name, func_node = self.current_function_stack.pop() -# -# # If we found a call in this function, store its definition -# if self.found_call_in_current_function and full_name not in self.function_definitions: -# # Get position information -# position = self.get_metadata(cst.metadata.PositionProvider, func_node) -# -# # Extract function source code by converting node to module -# # For methods, we need to maintain proper indentation -# func_source = cst.Module(body=[func_node]).code -# -# # For methods, add proper indentation (4 spaces) -# if self.current_class_stack: -# lines = func_source.split('\n') -# func_source = '\n'.join(' ' + line if line else line for line in lines) -# -# self.function_definitions[full_name] = FunctionDefinitionInfo( -# name=full_name, -# node=func_node, -# source_code=func_source.rstrip(), # Remove trailing whitespace -# start_line=position.start.line if position else -1, -# end_line=position.end.line if position else -1, -# is_method=bool(self.current_class_stack), -# class_name=self.current_class_stack[-1] if self.current_class_stack else None -# ) -# -# # Handle nested functions - mark parent as containing nested calls -# if self.found_call_in_current_function and self.current_function_stack: -# parent_name = self.current_function_stack[-1][0] -# self.functions_with_nested_calls.add(parent_name) -# # Also store the parent function if not already stored -# if parent_name not in self.function_definitions: -# parent_func_node = self.current_function_stack[-1][1] -# parent_position = self.get_metadata(cst.metadata.PositionProvider, parent_func_node) -# parent_source = cst.Module(body=[parent_func_node]).code -# -# # Get parent class context (go up one level in stack since we're inside the nested function) -# parent_class_stack = self.current_class_stack[:-1] if len(self.current_function_stack) == 1 and self.current_class_stack else [] -# -# if parent_class_stack: -# lines = parent_source.split('\n') -# parent_source = '\n'.join(' ' + line if line else line for line in lines) -# -# self.function_definitions[parent_name] = FunctionDefinitionInfo( -# name=parent_name, -# node=parent_func_node, -# source_code=parent_source.rstrip(), -# start_line=parent_position.start.line if parent_position else -1, -# end_line=parent_position.end.line if parent_position else -1, -# is_method=bool(parent_class_stack), -# class_name=parent_class_stack[-1] if parent_class_stack else None -# ) -# -# # Reset the flag for parent function if we're in nested functions -# if self.current_function_stack: -# # Check if the parent function should also be marked as containing calls -# parent_name = self.current_function_stack[-1][0] -# self.found_call_in_current_function = parent_name in self.calling_functions -# -# def visit_Call(self, node: cst.Call) -> None: -# """Check if this call matches our target function.""" -# if not self.current_function_stack: -# # Not inside a function, skip -# return -# -# if self._is_target_function_call(node): -# # Get position information -# position = self.get_metadata(cst.metadata.PositionProvider, node) -# -# current_func_name = self.current_function_stack[-1][0] -# -# call_location = FunctionCallLocation( -# calling_function=current_func_name, -# line=position.start.line if position else -1, -# column=position.start.column if position else -1, -# call_node=node, -# ) -# -# self.function_calls.append(call_location) -# self.calling_functions.add(current_func_name) -# self.found_call_in_current_function = True -# -# def _is_target_function_call(self, node: cst.Call) -> bool: -# """Determine if this call node is calling our target function. -# -# Handles various call patterns: -# - Direct calls: function() -# - Qualified calls: module.function() -# - Method calls: obj.method() -# """ -# func = node.func -# -# # Get the call name -# call_name = self._get_call_name(func) -# if not call_name: -# return False -# -# # Check if it matches directly -# if call_name == self.target_function_name: -# return True -# -# # Check if it's just the base name matching -# if call_name == self.target_base_name: -# # Could be imported with a different name, check imports -# if call_name in self.imports: -# imported_path = self.imports[call_name] -# # Check if the imported path matches our target -# if imported_path == self.target_function_name or imported_path.endswith( -# f".{self.target_function_name}" -# ): -# return True -# # Could also be a direct call if we're in the same file -# return True -# -# # Check for qualified calls with imports -# call_parts = call_name.split(".") -# if call_parts[0] in self.imports: -# # Resolve the full path using imports -# base_import = self.imports[call_parts[0]] -# full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import -# -# if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): -# return True -# -# return False -# -# def _get_call_name(self, func: Union[cst.Name, cst.Attribute, cst.Call]) -> Optional[str]: -# """Extract the name being called from a function node.""" -# if isinstance(func, cst.Name): -# return func.value -# if isinstance(func, cst.Attribute): -# return self._get_dotted_name(func) -# if isinstance(func, cst.Call): -# # Chained calls like foo()() -# return None -# return None -# -# def _get_dotted_name(self, node: Union[cst.Name, cst.Attribute]) -> Optional[str]: -# """Get the full dotted name from an Attribute or Name node.""" -# if isinstance(node, cst.Name): -# return node.value -# if isinstance(node, cst.Attribute): -# parts = [] -# current = node -# while isinstance(current, cst.Attribute): -# parts.append(current.attr.value) -# current = current.value -# if isinstance(current, cst.Name): -# parts.append(current.value) -# return ".".join(reversed(parts)) -# return None -# -# def get_results(self) -> dict[str, str]: -# """Get the results of the analysis. -# -# Returns: -# A dictionary mapping qualified function names to their source code definitions. -# Only includes functions that call the target function (directly or through nested functions). -# -# """ -# return { -# info.name: info.source_code -# for info in self.function_definitions.values() -# } -# -# -# def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> dict: -# """Find all function definitions that call a specific target function. -# -# Args: -# source_code: The Python source code to analyze -# target_function_name: The qualified name of the function to find (e.g., "module.function") -# target_filepath: The filepath where the target function is defined -# -# Returns: -# A dictionary with: -# - calling_functions: list of function names that call the target -# - calls: list of detailed call information including line/column -# -# """ -# # Parse the source code -# module = cst.parse_module(source_code) -# -# # Create and run the visitor -# visitor = FunctionCallFinder(target_function_name, target_filepath) -# wrapper = cst.metadata.MetadataWrapper(module) -# wrapper.visit(visitor) -# -# return visitor.get_results() - - @dataclass class FunctionCallLocation: """Represents a location where the target function is called.""" + calling_function: str line: int column: int @@ -1139,6 +766,7 @@ class FunctionCallLocation: @dataclass class FunctionDefinitionInfo: """Contains information about a function definition.""" + name: str node: ast.FunctionDef source_code: str @@ -1154,15 +782,16 @@ class FunctionCallFinder(ast.NodeVisitor): Args: target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") target_filepath: The filepath where the target function is defined + """ - def __init__(self, target_function_name: str, target_filepath: str, source_lines: list[str]): + def __init__(self, target_function_name: str, target_filepath: str, source_lines: list[str]) -> None: self.target_function_name = target_function_name self.target_filepath = target_filepath self.source_lines = source_lines # Store original source lines for extraction # Parse the target function name into parts - self.target_parts = target_function_name.split('.') + self.target_parts = target_function_name.split(".") self.target_base_name = self.target_parts[-1] # Track current context @@ -1189,16 +818,16 @@ def visit_Import(self, node: ast.Import) -> None: self.imports[alias.asname] = alias.name else: # import module - self.imports[alias.name.split('.')[-1]] = alias.name + self.imports[alias.name.split(".")[-1]] = alias.name self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Track from imports.""" if node.module: for alias in node.names: - if alias.name == '*': + if alias.name == "*": # from module import * - self.imports['*'] = node.module + self.imports["*"] = node.module elif alias.asname: # from module import name as alias self.imports[alias.asname] = f"{node.module}.{alias.name}" @@ -1222,7 +851,7 @@ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: self._visit_function_def(node) def _visit_function_def(self, node: ast.FunctionDef) -> None: - """Common logic for both regular and async function definitions.""" + """Track when entering a function definition.""" func_name = node.name # Build the full qualified name including class if applicable @@ -1244,9 +873,9 @@ def _visit_function_def(self, node: ast.FunctionDef) -> None: node=node, source_code=source_code, start_line=node.lineno, - end_line=node.end_lineno if hasattr(node, 'end_lineno') else node.lineno, + end_line=node.end_lineno if hasattr(node, "end_lineno") else node.lineno, is_method=bool(self.current_class_stack), - class_name=self.current_class_stack[-1] if self.current_class_stack else None + class_name=self.current_class_stack[-1] if self.current_class_stack else None, ) # Handle nested functions - mark parent as containing nested calls @@ -1267,9 +896,9 @@ def _visit_function_def(self, node: ast.FunctionDef) -> None: node=parent_node, source_code=parent_source, start_line=parent_node.lineno, - end_line=parent_node.end_lineno if hasattr(parent_node, 'end_lineno') else parent_node.lineno, + end_line=parent_node.end_lineno if hasattr(parent_node, "end_lineno") else parent_node.lineno, is_method=bool(parent_class_context), - class_name=parent_class_context[-1] if parent_class_context else None + class_name=parent_class_context[-1] if parent_class_context else None, ) self.current_function_stack.pop() @@ -1290,9 +919,7 @@ def visit_Call(self, node: ast.Call) -> None: current_func_name = self.current_function_stack[-1][0] call_location = FunctionCallLocation( - calling_function=current_func_name, - line=node.lineno, - column=node.col_offset + calling_function=current_func_name, line=node.lineno, column=node.col_offset ) self.function_calls.append(call_location) @@ -1316,13 +943,15 @@ def _is_target_function_call(self, node: ast.Call) -> bool: # Could be imported with a different name, check imports if call_name in self.imports: imported_path = self.imports[call_name] - if imported_path == self.target_function_name or imported_path.endswith(f".{self.target_function_name}"): + if imported_path == self.target_function_name or imported_path.endswith( + f".{self.target_function_name}" + ): return True # Could also be a direct call if we're in the same file return True # Check for qualified calls with imports - call_parts = call_name.split('.') + call_parts = call_name.split(".") if call_parts[0] in self.imports: # Resolve the full path using imports base_import = self.imports[call_parts[0]] @@ -1333,11 +962,11 @@ def _is_target_function_call(self, node: ast.Call) -> bool: return False - def _get_call_name(self, func_node) -> Optional[str]: + def _get_call_name(self, func_node) -> Optional[str]: # noqa : ANN001 """Extract the name being called from a function node.""" if isinstance(func_node, ast.Name): return func_node.id - elif isinstance(func_node, ast.Attribute): + if isinstance(func_node, ast.Attribute): parts = [] current = func_node while isinstance(current, ast.Attribute): @@ -1345,12 +974,12 @@ def _get_call_name(self, func_node) -> Optional[str]: current = current.value if isinstance(current, ast.Name): parts.append(current.id) - return '.'.join(reversed(parts)) + return ".".join(reversed(parts)) return None def _extract_source_code(self, node: ast.FunctionDef) -> str: """Extract source code for a function node using original source lines.""" - if not self.source_lines or not hasattr(node, 'lineno'): + if not self.source_lines or not hasattr(node, "lineno"): # Fallback to ast.unparse if available (Python 3.9+) try: return ast.unparse(node) @@ -1359,13 +988,13 @@ def _extract_source_code(self, node: ast.FunctionDef) -> str: # Get the lines for this function start_line = node.lineno - 1 # Convert to 0-based index - end_line = node.end_lineno if hasattr(node, 'end_lineno') else len(self.source_lines) + end_line = node.end_lineno if hasattr(node, "end_lineno") else len(self.source_lines) # Extract the function lines func_lines = self.source_lines[start_line:end_line] # Find the minimum indentation (excluding empty lines) - min_indent = float('inf') + min_indent = float("inf") for line in func_lines: if line.strip(): # Skip empty lines indent = len(line) - len(line.lstrip()) @@ -1390,18 +1019,16 @@ def _extract_source_code(self, node: ast.FunctionDef) -> str: else: result_lines.append(line) - return ''.join(result_lines).rstrip() + return "".join(result_lines).rstrip() def get_results(self) -> dict[str, str]: """Get the results of the analysis. Returns: A dictionary mapping qualified function names to their source code definitions. + """ - return { - info.name: info.source_code - for info in self.function_definitions.values() - } + return {info.name: info.source_code for info in self.function_definitions.values()} def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> dict[str, str]: @@ -1414,7 +1041,8 @@ def find_function_calls(source_code: str, target_function_name: str, target_file Returns: A dictionary mapping qualified function names to their source code definitions. - Example: {"function_a": "def function_a():\n ...", "MyClass.method_one": "def method_one(self):\n ..."} + Example: {"function_a": "def function_a(): ...", "MyClass.method_one": "def method_one(self): ..."} + """ # Parse the source code tree = ast.parse(source_code) @@ -1428,32 +1056,117 @@ def find_function_calls(source_code: str, target_function_name: str, target_file return visitor.get_results() + def find_occurances( - qualified_name: str, file_path: str, fn_matches: dict[str, list[tuple[int, str]]], max_len=1000 -) -> str: # max chars for context - #print(fn_matches, max_len) + qualified_name: str, file_path: str, fn_matches: list[Path], project_root: Path, tests_root: Path +) -> list[str]: # max chars for context + start_time = time.time() + context_len = 0 fn_call_context = "" - all_res = [] - for file in fn_matches: - with Path(file).open(encoding="utf8") as f: + for cur_file in fn_matches: + if time.time() - start_time > TIME_LIMIT_FOR_OPT_IMPACT: + break + if context_len > MAX_CONTEXT_LEN_IMPACT: + break + cur_file_path = Path(cur_file) + # exclude references in tests + try: + if cur_file_path.relative_to(tests_root): + continue + except ValueError: + pass + with cur_file_path.open(encoding="utf8") as f: file_content = f.read() results = find_function_calls(file_content, target_function_name=qualified_name, target_filepath=file_path) if results: - print(file) - all_res.append(results) + try: + path_relative_to_project_root = cur_file_path.relative_to(project_root) + except Exception as e: + # shouldn't happen but ensuring we don't crash + logger.debug(f"investigate {e}") + continue + fn_call_context += f"```python:{path_relative_to_project_root}\n" + for ( + fn_definition + ) in results.values(): # multiple functions in the file might be calling the desired function + fn_call_context += f"{fn_definition}\n" + context_len += len(fn_definition) + fn_call_context += "```\n" + opt_impact_metrics = {"calling_fn_defs": fn_call_context} + # radon metrics = get_radon_metrics(sour) return fn_call_context -def get_opt_impact_metrics(file_path: Path, qualified_name: str, project_root: Path, tests_root: Path) -> ImpactMetrics: - # grep for function / use rg (respects gitignore) - # SAFE_GREP_EXECUTABLE command - # ast visitor for occurances and loop occurances +def find_specific_function_in_file( + source_code: str, filepath: Union[str, Path], qualified_name: str +) -> Optional[tuple[int, int]]: + """Find a specific function definition in a Python file and return its location. + + Stops searching once the target is found (optimized for performance). + + Args: + source_code: Source code string + filepath: Path to the Python file + qualified_name: Qualified Name of the function to find, classname.functionname + + Returns: + Tuple of (line_number, column_offset) if found, None otherwise + + """ + qualified_name_split = qualified_name.rsplit(".", maxsplit=1) + if len(qualified_name_split) == 1: + target_function, target_class = qualified_name_split[0], None + else: + target_function, target_class = qualified_name_split[1], qualified_name_split[0] + script = jedi.Script(code=source_code, path=filepath) + names = script.get_names(all_scopes=True, definitions=True) + for name in names: + if name.type == "function" and name.name == target_function: + # If class name specified, check parent + if target_class: + parent = name.parent() + if parent and parent.name == target_class and parent.type == "class": + return CodePosition(line_no=name.line, col_no=name.column) + else: + # Top-level function match + return CodePosition(line_no=name.line, col_no=name.column) + + return None # Function not found + + +def get_fn_references_jedi(source_code: str, file_path: Path, qualified_name: str, project_root: Path) -> list[Path]: + print(file_path, qualified_name, project_root) + # Create a Jedi Script object + function_position: CodePosition = find_specific_function_in_file(source_code, file_path, qualified_name) + try: + script = jedi.Script(code=source_code, path=file_path, project=jedi.Project(path=project_root)) + + # Get references to the function + references = script.get_references(line=function_position.line_no, column=function_position.col_no) + + # Collect unique file paths where references are found + reference_files = set() + for ref in references: + if ref.module_path: + # Convert to string and normalize path + ref_path = str(ref.module_path) + # Skip the definition itself + if not (ref_path == file_path and ref.line == function_position.line_no): + reference_files.add(ref_path) + + return sorted(reference_files) + + except Exception as e: + print(f"Error during Jedi analysis: {e}") + return [] + + +def get_opt_impact_metrics( + source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path +) -> ImpactMetrics: # radon lib for complexity metrics - #print(file_path, qualified_name, project_root, tests_root) - function_name = qualified_name.rsplit(".")[-1] - matches = search_with_ripgrep(function_name, str(project_root), str(tests_root)) - find_occurances( - qualified_name, str(file_path), matches - ) # returns markdown string of ```python:file_name followed by function/class definition - # grep windows alternative - return 0 + # print(file_path, qualified_name, project_root, tests_root) + matches = get_fn_references_jedi( + source_code, file_path, qualified_name, project_root + ) # jedi is not perfect, it doesn't capture aliased references + return find_occurances(qualified_name, str(file_path), matches, project_root, tests_root) diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index cc1eb50da..34c143256 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -34,3 +34,6 @@ N_CANDIDATES_LP_EFFECTIVE = min(N_CANDIDATES_LP_LSP if _IS_LSP_ENABLED else N_CANDIDATES_LP, MAX_N_CANDIDATES_LP) N_TESTS_TO_GENERATE_EFFECTIVE = N_TESTS_TO_GENERATE_LSP if _IS_LSP_ENABLED else N_TESTS_TO_GENERATE TOTAL_LOOPING_TIME_EFFECTIVE = TOTAL_LOOPING_TIME_LSP if _IS_LSP_ENABLED else TOTAL_LOOPING_TIME + +MAX_CONTEXT_LEN_IMPACT = 1000 +TIME_LIMIT_FOR_OPT_IMPACT = 5 # in sec diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 6774cad5f..53131b4ea 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -377,7 +377,6 @@ def generate_and_instrument_tests( # note: this isn't called by the lsp, only called by cli def optimize_function(self) -> Result[BestOptimization, str]: - get_opt_impact_metrics(self.function_to_optimize.file_path,self.function_to_optimize.qualified_name, self.project_root, self.test_cfg.tests_root) initialization_result = self.can_be_optimized() if not is_successful(initialization_result): return Failure(initialization_result.failure()) @@ -1470,7 +1469,11 @@ def process_review( logger.debug(f"optimization impact response failed, investigate {e}") data["optimization_impact"] = opt_impact_response data["impact_metrics"] = get_opt_impact_metrics( - self.project_root, self.test_cfg.tests_root + self.function_to_optimize_source_code, + self.function_to_optimize.file_path, + self.function_to_optimize.qualified_name, + self.project_root, + self.test_cfg.tests_root, ) # need module root, tests root only if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote diff --git a/example_usage.py b/example_usage.py index 94407d83e..3bebd9ed1 100644 --- a/example_usage.py +++ b/example_usage.py @@ -1,28 +1,131 @@ #!/usr/bin/env python3 -"""Example of using the ripgrep search script programmatically.""" +"""Example usage of the function reference finder.""" -import json +from find_function_references import find_function_references +from find_function_references_detailed import find_function_references_detailed -from ripgrep_search import search_with_ripgrep -# Search for any pattern you want -pattern = "sorter" # Change this to any pattern you need -results = search_with_ripgrep(pattern) +def example_basic_usage(): + """Example of basic usage - just get list of files.""" + # Example: Find references to a function + filepath = "path/to/your/file.py" + function_name = "your_function_name" -# Access the results as a dictionary -print(f"Found matches in {len(results)} files") + # Find references (will auto-detect project root) + reference_files = find_function_references(filepath, function_name) -# Iterate through the results -for filepath, occurrences in results.items(): - print(f"\n{filepath}: {len(occurrences)} matches") - for line_no, content in occurrences[:3]: # Show first 3 matches per file - print(f" Line {line_no}: {content[:80]}...") + print(f"Files containing references to {function_name}:") + for file in reference_files: + print(f" - {file}") -# Save results to a JSON file if needed -with open("search_results.json", "w") as f: - json.dump(results, f, indent=2) + return reference_files -# Or filter results for specific files -python_files_only = {path: matches for path, matches in results.items() if path.endswith(".py")} -print(f"\nPython files with matches: {len(python_files_only)}") +def example_detailed_usage(): + """Example of detailed usage - get line numbers and context.""" + # Example: Find references with detailed information + filepath = "path/to/your/file.py" + function_name = "your_function_name" + project_root = "/path/to/project/root" # Optional + + # Find detailed references + references = find_function_references_detailed(filepath, function_name, project_root) + + print(f"\nDetailed references to {function_name}:") + for file, refs in references.items(): + print(f"\nFile: {file}") + for ref in refs: + print(f" Line {ref['line']}: {ref['context']}") + + return references + + +def example_programmatic_usage(): + """Example of using in your own Python code.""" + import jedi + + # Direct Jedi usage for more control + source_code = """ +def my_function(x, y): + return x + y + +result = my_function(1, 2) +""" + + # Create a script object + script = jedi.Script(code=source_code) + + # Find the function definition (line 2, column 4 for 'my_function') + references = script.get_references(line=2, column=4) + + print("Direct Jedi references:") + for ref in references: + print(f" Line {ref.line}, Column {ref.column}: {ref.description}") + + # You can also search for names + names = script.get_names() + for name in names: + if name.type == "function": + print(f"Found function: {name.name} at line {name.line}") + + +def example_find_all_functions_and_their_references(): + """Example of finding all functions in a file and their references.""" + import os + + import jedi + + def find_all_functions_with_references(filepath: str): + """Find all functions in a file and their references.""" + with open(filepath) as f: + source_code = f.read() + + script = jedi.Script(code=source_code, path=filepath) + + # Get all defined names + names = script.get_names() + + functions_and_refs = {} + + for name in names: + if name.type == "function": + # Get references for this function + refs = script.get_references(line=name.line, column=name.column) + + ref_locations = [] + for ref in refs: + if ref.module_path and str(ref.module_path) != filepath: + ref_locations.append({"file": str(ref.module_path), "line": ref.line, "column": ref.column}) + + functions_and_refs[name.name] = {"definition_line": name.line, "references": ref_locations} + + return functions_and_refs + + # Example usage + filepath = "your_module.py" + if os.path.exists(filepath): + all_refs = find_all_functions_with_references(filepath) + + for func_name, info in all_refs.items(): + print(f"\nFunction: {func_name} (defined at line {info['definition_line']})") + if info["references"]: + print(" Referenced in:") + for ref in info["references"]: + print(f" - {ref['file']}:{ref['line']}") + else: + print(" No external references found") + + +if __name__ == "__main__": + print("=" * 60) + print("Function Reference Finder - Usage Examples") + print("=" * 60) + + print("\nNote: Update the file paths and function names in the examples") + print("before running them with your actual code.\n") + + # Uncomment to run examples: + # example_basic_usage() + # example_detailed_usage() + # example_programmatic_usage() + # example_find_all_functions_and_their_references() diff --git a/find_sorter_references.py b/find_sorter_references.py index 6f42fe124..82a7d14bb 100644 --- a/find_sorter_references.py +++ b/find_sorter_references.py @@ -1,26 +1,25 @@ #!/usr/bin/env python3 -""" -Script to find all references to the sorter function from code_to_optimize/bubble_sort.py +"""Script to find all references to the sorter function from code_to_optimize/bubble_sort.py using Jedi's static analysis capabilities. """ -import jedi -import os from pathlib import Path +import jedi + def find_function_references(file_path, line, column, project_root): - """ - Find all references to a function using Jedi. + """Find all references to a function using Jedi. Args: file_path: Path to the file containing the function line: Line number where the function is defined (1-indexed) column: Column number where the function name starts (0-indexed) project_root: Root directory of the project to search + """ # Read the source code - with open(file_path, 'r') as f: + with open(file_path) as f: source = f.read() # Create a Jedi Script object with project configuration @@ -45,7 +44,7 @@ def find_function_references(file_path, line, column, project_root): references = [] try: # Use usages() method to get all references - references = script.get_references(line, column, scope='project', include_builtins=False) + references = script.get_references(line, column, scope="project", include_builtins=False) except AttributeError: # Alternative approach using search print("Using alternative search method...") @@ -91,7 +90,7 @@ def main(): for ref in sorted(file_refs, key=lambda r: (r.line, r.column)): # Get the line content for context try: - with open(file_path, 'r') as f: + with open(file_path) as f: lines = f.readlines() if ref.line <= len(lines): line_content = lines[ref.line - 1].strip() @@ -108,4 +107,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/function_call_finder.py b/function_call_finder.py index 7eae85005..65acadd64 100644 --- a/function_call_finder.py +++ b/function_call_finder.py @@ -143,8 +143,8 @@ def leave_FunctionDef(self, node: cst.FunctionDef) -> None: # For methods, add proper indentation (4 spaces) if self.current_class_stack: - lines = func_source.split('\n') - func_source = '\n'.join(' ' + line if line else line for line in lines) + lines = func_source.split("\n") + func_source = "\n".join(" " + line if line else line for line in lines) self.function_definitions[full_name] = FunctionDefinitionInfo( name=full_name, @@ -153,7 +153,7 @@ def leave_FunctionDef(self, node: cst.FunctionDef) -> None: start_line=position.start.line if position else -1, end_line=position.end.line if position else -1, is_method=bool(self.current_class_stack), - class_name=self.current_class_stack[-1] if self.current_class_stack else None + class_name=self.current_class_stack[-1] if self.current_class_stack else None, ) # Handle nested functions - mark parent as containing nested calls @@ -167,11 +167,15 @@ def leave_FunctionDef(self, node: cst.FunctionDef) -> None: parent_source = cst.Module(body=[parent_func_node]).code # Get parent class context (go up one level in stack since we're inside the nested function) - parent_class_stack = self.current_class_stack[:-1] if len(self.current_function_stack) == 1 and self.current_class_stack else [] + parent_class_stack = ( + self.current_class_stack[:-1] + if len(self.current_function_stack) == 1 and self.current_class_stack + else [] + ) if parent_class_stack: - lines = parent_source.split('\n') - parent_source = '\n'.join(' ' + line if line else line for line in lines) + lines = parent_source.split("\n") + parent_source = "\n".join(" " + line if line else line for line in lines) self.function_definitions[parent_name] = FunctionDefinitionInfo( name=parent_name, @@ -180,7 +184,7 @@ def leave_FunctionDef(self, node: cst.FunctionDef) -> None: start_line=parent_position.start.line if parent_position else -1, end_line=parent_position.end.line if parent_position else -1, is_method=bool(parent_class_stack), - class_name=parent_class_stack[-1] if parent_class_stack else None + class_name=parent_class_stack[-1] if parent_class_stack else None, ) # Reset the flag for parent function if we're in nested functions @@ -290,10 +294,7 @@ def get_results(self) -> Dict[str, str]: Only includes functions that call the target function (directly or through nested functions). """ - return { - info.name: info.source_code - for info in self.function_definitions.values() - } + return {info.name: info.source_code for info in self.function_definitions.values()} def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> Dict[str, str]: @@ -367,6 +368,7 @@ def inner(): # Simple usage - results is just a dict of {function_name: source_code} import json + print("JSON representation of results:") print(json.dumps(list(results.keys()), indent=2)) @@ -374,4 +376,4 @@ def inner(): for func_name, source_code in results.items(): print(f"\n=== {func_name} ===") print(source_code) - print() \ No newline at end of file + print() diff --git a/function_call_finder_ast.py b/function_call_finder_ast.py index f9a9d231a..4baab24e9 100644 --- a/function_call_finder_ast.py +++ b/function_call_finder_ast.py @@ -1,13 +1,14 @@ """AST-based visitor to find function definitions that call a specific qualified function.""" import ast +from dataclasses import dataclass from typing import Dict, List, Optional, Set, Tuple -from dataclasses import dataclass, field @dataclass class FunctionCallLocation: """Represents a location where the target function is called.""" + calling_function: str line: int column: int @@ -16,6 +17,7 @@ class FunctionCallLocation: @dataclass class FunctionDefinitionInfo: """Contains information about a function definition.""" + name: str node: ast.FunctionDef source_code: str @@ -31,6 +33,7 @@ class FunctionCallFinder(ast.NodeVisitor): Args: target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") target_filepath: The filepath where the target function is defined + """ def __init__(self, target_function_name: str, target_filepath: str, source_lines: List[str]): @@ -39,7 +42,7 @@ def __init__(self, target_function_name: str, target_filepath: str, source_lines self.source_lines = source_lines # Store original source lines for extraction # Parse the target function name into parts - self.target_parts = target_function_name.split('.') + self.target_parts = target_function_name.split(".") self.target_base_name = self.target_parts[-1] # Track current context @@ -66,16 +69,16 @@ def visit_Import(self, node: ast.Import) -> None: self.imports[alias.asname] = alias.name else: # import module - self.imports[alias.name.split('.')[-1]] = alias.name + self.imports[alias.name.split(".")[-1]] = alias.name self.generic_visit(node) def visit_ImportFrom(self, node: ast.ImportFrom) -> None: """Track from imports.""" if node.module: for alias in node.names: - if alias.name == '*': + if alias.name == "*": # from module import * - self.imports['*'] = node.module + self.imports["*"] = node.module elif alias.asname: # from module import name as alias self.imports[alias.asname] = f"{node.module}.{alias.name}" @@ -121,9 +124,9 @@ def _visit_function_def(self, node: ast.FunctionDef) -> None: node=node, source_code=source_code, start_line=node.lineno, - end_line=node.end_lineno if hasattr(node, 'end_lineno') else node.lineno, + end_line=node.end_lineno if hasattr(node, "end_lineno") else node.lineno, is_method=bool(self.current_class_stack), - class_name=self.current_class_stack[-1] if self.current_class_stack else None + class_name=self.current_class_stack[-1] if self.current_class_stack else None, ) # Handle nested functions - mark parent as containing nested calls @@ -144,9 +147,9 @@ def _visit_function_def(self, node: ast.FunctionDef) -> None: node=parent_node, source_code=parent_source, start_line=parent_node.lineno, - end_line=parent_node.end_lineno if hasattr(parent_node, 'end_lineno') else parent_node.lineno, + end_line=parent_node.end_lineno if hasattr(parent_node, "end_lineno") else parent_node.lineno, is_method=bool(parent_class_context), - class_name=parent_class_context[-1] if parent_class_context else None + class_name=parent_class_context[-1] if parent_class_context else None, ) self.current_function_stack.pop() @@ -167,9 +170,7 @@ def visit_Call(self, node: ast.Call) -> None: current_func_name = self.current_function_stack[-1][0] call_location = FunctionCallLocation( - calling_function=current_func_name, - line=node.lineno, - column=node.col_offset + calling_function=current_func_name, line=node.lineno, column=node.col_offset ) self.function_calls.append(call_location) @@ -193,13 +194,15 @@ def _is_target_function_call(self, node: ast.Call) -> bool: # Could be imported with a different name, check imports if call_name in self.imports: imported_path = self.imports[call_name] - if imported_path == self.target_function_name or imported_path.endswith(f".{self.target_function_name}"): + if imported_path == self.target_function_name or imported_path.endswith( + f".{self.target_function_name}" + ): return True # Could also be a direct call if we're in the same file return True # Check for qualified calls with imports - call_parts = call_name.split('.') + call_parts = call_name.split(".") if call_parts[0] in self.imports: # Resolve the full path using imports base_import = self.imports[call_parts[0]] @@ -214,7 +217,7 @@ def _get_call_name(self, func_node) -> Optional[str]: """Extract the name being called from a function node.""" if isinstance(func_node, ast.Name): return func_node.id - elif isinstance(func_node, ast.Attribute): + if isinstance(func_node, ast.Attribute): parts = [] current = func_node while isinstance(current, ast.Attribute): @@ -222,12 +225,12 @@ def _get_call_name(self, func_node) -> Optional[str]: current = current.value if isinstance(current, ast.Name): parts.append(current.id) - return '.'.join(reversed(parts)) + return ".".join(reversed(parts)) return None def _extract_source_code(self, node: ast.FunctionDef) -> str: """Extract source code for a function node using original source lines.""" - if not self.source_lines or not hasattr(node, 'lineno'): + if not self.source_lines or not hasattr(node, "lineno"): # Fallback to ast.unparse if available (Python 3.9+) try: return ast.unparse(node) @@ -236,13 +239,13 @@ def _extract_source_code(self, node: ast.FunctionDef) -> str: # Get the lines for this function start_line = node.lineno - 1 # Convert to 0-based index - end_line = node.end_lineno if hasattr(node, 'end_lineno') else len(self.source_lines) + end_line = node.end_lineno if hasattr(node, "end_lineno") else len(self.source_lines) # Extract the function lines func_lines = self.source_lines[start_line:end_line] # Find the minimum indentation (excluding empty lines) - min_indent = float('inf') + min_indent = float("inf") for line in func_lines: if line.strip(): # Skip empty lines indent = len(line) - len(line.lstrip()) @@ -267,18 +270,16 @@ def _extract_source_code(self, node: ast.FunctionDef) -> str: else: result_lines.append(line) - return ''.join(result_lines).rstrip() + return "".join(result_lines).rstrip() def get_results(self) -> Dict[str, str]: """Get the results of the analysis. Returns: A dictionary mapping qualified function names to their source code definitions. + """ - return { - info.name: info.source_code - for info in self.function_definitions.values() - } + return {info.name: info.source_code for info in self.function_definitions.values()} def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> Dict[str, str]: @@ -292,6 +293,7 @@ def find_function_calls(source_code: str, target_function_name: str, target_file Returns: A dictionary mapping qualified function names to their source code definitions. Example: {"function_a": "def function_a():\n ...", "MyClass.method_one": "def method_one(self):\n ..."} + """ # Parse the source code tree = ast.parse(source_code) @@ -346,15 +348,14 @@ def inner(): # Find calls to a specific function results = find_function_calls( - example_code, - target_function_name="my_module.target_function", - target_filepath="/path/to/my_module.py" + example_code, target_function_name="my_module.target_function", target_filepath="/path/to/my_module.py" ) print("Functions that call 'my_module.target_function':\n") # Simple usage - results is just a dict of {function_name: source_code} import json + print("JSON representation of results:") print(json.dumps(list(results.keys()), indent=2)) @@ -362,4 +363,4 @@ def inner(): for func_name, source_code in results.items(): print(f"\n=== {func_name} ===") print(source_code) - print() \ No newline at end of file + print() diff --git a/ripgrep_search.py b/ripgrep_search.py index 9244c1189..365481356 100644 --- a/ripgrep_search.py +++ b/ripgrep_search.py @@ -1,18 +1,16 @@ #!/usr/bin/env python3 -""" -Script to find all occurrences of 'function_name' in the repository using ripgrep. +"""Script to find all occurrences of 'function_name' in the repository using ripgrep. Returns a dictionary where keys are filepaths and values are lists of (line_no, content) tuples. """ -import os -import subprocess + import json -from typing import Dict, List, Tuple +import subprocess from pathlib import Path +from typing import Dict, List, Tuple def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[int, str]]]: - """ - Use ripgrep to search for a pattern in the repository. + """Use ripgrep to search for a pattern in the repository. Args: pattern: The pattern to search for @@ -20,13 +18,22 @@ def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[i Returns: Dictionary with filepaths as keys and list of (line_no, content) tuples as values + """ # Run ripgrep with JSON output for easier parsing # -n: Show line numbers # --json: Output in JSON format # --no-heading: Don't group matches by file path = str(Path.cwd()) - cmd = ["rg", "-n", "--json", pattern, path, "-g", "!/Users/aseemsaxena/Downloads/codeflash_dev/codeflash/code_to_optimize/tests/**"] + cmd = [ + "rg", + "-n", + "--json", + pattern, + path, + "-g", + "!/Users/aseemsaxena/Downloads/codeflash_dev/codeflash/code_to_optimize/tests/**", + ] print(" ".join(cmd)) try: @@ -34,7 +41,7 @@ def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[i cmd, capture_output=True, text=True, - check=False # Don't raise exception on non-zero return + check=False, # Don't raise exception on non-zero return ) if result.returncode not in [0, 1]: # 0 = matches found, 1 = no matches @@ -44,7 +51,7 @@ def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[i # Parse the JSON output matches_dict = {} - for line in result.stdout.strip().split('\n'): + for line in result.stdout.strip().split("\n"): if not line: continue @@ -56,7 +63,7 @@ def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[i data = json_obj.get("data", {}) file_path = data.get("path", {}).get("text", "") line_number = data.get("line_number") - line_content = data.get("lines", {}).get("text", "").rstrip('\n') + line_content = data.get("lines", {}).get("text", "").rstrip("\n") if file_path and line_number: if file_path not in matches_dict: @@ -77,8 +84,7 @@ def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[i def search_with_ripgrep_simple(pattern: str, path: str = ".") -> Dict[str, List[Tuple[int, str]]]: - """ - Alternative implementation using simpler ripgrep output (non-JSON). + """Alternative implementation using simpler ripgrep output (non-JSON). Args: pattern: The pattern to search for @@ -86,6 +92,7 @@ def search_with_ripgrep_simple(pattern: str, path: str = ".") -> Dict[str, List[ Returns: Dictionary with filepaths as keys and list of (line_no, content) tuples as values + """ # Run ripgrep with simpler output # -n: Show line numbers @@ -93,12 +100,7 @@ def search_with_ripgrep_simple(pattern: str, path: str = ".") -> Dict[str, List[ cmd = ["rg", "-n", "--no-heading", pattern, path] try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=False - ) + result = subprocess.run(cmd, capture_output=True, text=True, check=False) if result.returncode not in [0, 1]: print(f"Error running ripgrep: {result.stderr}") @@ -107,12 +109,12 @@ def search_with_ripgrep_simple(pattern: str, path: str = ".") -> Dict[str, List[ matches_dict = {} # Parse the output (format: filepath:line_number:content) - for line in result.stdout.strip().split('\n'): + for line in result.stdout.strip().split("\n"): if not line: continue # Split only on the first two colons to handle colons in content - parts = line.split(':', 2) + parts = line.split(":", 2) if len(parts) >= 3: file_path = parts[0] try: @@ -167,4 +169,4 @@ def main(): if __name__ == "__main__": - results_dict = main() \ No newline at end of file + results_dict = main() diff --git a/test_ast_vs_libcst.py b/test_ast_vs_libcst.py index 73eb0b600..1f7ef3735 100644 --- a/test_ast_vs_libcst.py +++ b/test_ast_vs_libcst.py @@ -4,7 +4,7 @@ from function_call_finder_ast import find_function_calls as find_calls_ast # Test code with various scenarios -test_code = ''' +test_code = """ import module1 from module2 import func as f2 import module3 as m3 @@ -37,10 +37,10 @@ async def async_function(): def no_call(): x = 5 -''' +""" print("Testing AST vs LibCST implementations\n") -print("="*50) +print("=" * 50) # Test 1: Direct function calls print("\nTest 1: Finding 'target_func' calls") @@ -64,7 +64,7 @@ def no_call(): print(f" Only in LibCST: {libcst_keys - ast_keys}") # Test 2: Check if source code is similar (may have minor formatting differences) -print("\n" + "="*50) +print("\n" + "=" * 50) print("Test 2: Source code comparison") for func_name in ast_keys & libcst_keys: @@ -72,8 +72,8 @@ def no_call(): libcst_code = results_libcst[func_name].strip() # Normalize whitespace for comparison - ast_normalized = ' '.join(ast_code.split()) - libcst_normalized = ' '.join(libcst_code.split()) + ast_normalized = " ".join(ast_code.split()) + libcst_normalized = " ".join(libcst_code.split()) if ast_normalized == libcst_normalized: print(f"✅ {func_name}: Source code matches (normalized)") @@ -83,10 +83,10 @@ def no_call(): print(f" LibCST length: {len(libcst_code)} chars") # Test 3: Test with imports -print("\n" + "="*50) +print("\n" + "=" * 50) print("Test 3: Testing with import resolution") -import_test = ''' +import_test = """ from mymodule import target_func as tf def uses_alias(): @@ -94,7 +94,7 @@ def uses_alias(): def uses_direct(): target_func() # This shouldn't match since it's imported as tf -''' +""" results_ast_import = find_calls_ast(import_test, "mymodule.target_func", "/dummy/path.py") results_libcst_import = find_calls_libcst(import_test, "mymodule.target_func", "/dummy/path.py") @@ -103,16 +103,16 @@ def uses_direct(): print(f"LibCST found: {list(results_libcst_import.keys())}") # Summary -print("\n" + "="*50) +print("\n" + "=" * 50) print("COMPARISON SUMMARY") -print("="*50) +print("=" * 50) differences = [] if ast_keys != libcst_keys: differences.append("Different function names detected") -print(f"\n✅ AST implementation is working correctly") -print(f"✅ Output format matches: {{'func_name': 'source_code'}}") +print("\n✅ AST implementation is working correctly") +print("✅ Output format matches: {'func_name': 'source_code'}") if not differences: print("✅ Both implementations produce equivalent results") @@ -125,4 +125,4 @@ def uses_direct(): print("\n📝 Performance Note:") print(" - AST: Built-in, no dependencies, faster parsing") print(" - LibCST: External dependency, preserves formatting better") -print(" - Both produce the same logical results") \ No newline at end of file +print(" - Both produce the same logical results") diff --git a/test_function_call_finder.py b/test_function_call_finder.py index 54014e207..77dd52bcc 100644 --- a/test_function_call_finder.py +++ b/test_function_call_finder.py @@ -3,7 +3,7 @@ from function_call_finder import find_function_calls # Test code -test_code = ''' +test_code = """ def func1(): target_func() @@ -21,7 +21,7 @@ def method1(self): def method2(self): # No call here pass -''' +""" # Run the visitor results = find_function_calls(test_code, "target_func", "/dummy/path.py") @@ -39,6 +39,7 @@ def method2(self): # Verify it's exactly the format requested: {"calling_function_qualified_name1":"function_definition1",....} import json + print("\nJSON serializable:", end=" ") try: json_str = json.dumps(results) @@ -47,6 +48,6 @@ def method2(self): except: print("✗ No") -print("\n" + "="*50) +print("\n" + "=" * 50) print("VERIFIED: Output is in the format") -print('{"calling_function_qualified_name1":"function_definition1",...}') \ No newline at end of file +print('{"calling_function_qualified_name1":"function_definition1",...}') diff --git a/verify_output_format.py b/verify_output_format.py index d879e5fa5..87b2db669 100644 --- a/verify_output_format.py +++ b/verify_output_format.py @@ -1,20 +1,21 @@ """Verify that both AST and LibCST implementations produce the exact requested output format.""" import json -from function_call_finder_ast import find_function_calls as find_calls_ast + from function_call_finder import find_function_calls as find_calls_libcst +from function_call_finder_ast import find_function_calls as find_calls_ast # Simple test case -test_code = ''' +test_code = """ def func1(): my_target() def func2(): my_target(1, 2, 3) -''' +""" print("Verifying output format: {'calling_function_qualified_name1':'function_definition1',...}") -print("="*70) +print("=" * 70) # Test AST implementation print("\n1. AST Implementation:") @@ -36,14 +37,14 @@ def func2(): # Test with class methods print("\n3. Testing with class methods:") -class_test = ''' +class_test = """ class MyClass: def method1(self): target() def method2(self): pass -''' +""" ast_class = find_calls_ast(class_test, "target", "/dummy/path.py") libcst_class = find_calls_libcst(class_test, "target", "/dummy/path.py") @@ -52,7 +53,7 @@ def method2(self): print(f" LibCST result: {list(libcst_class.keys())}") # Final verification -print("\n" + "="*70) +print("\n" + "=" * 70) print("✅ VERIFIED: Both implementations return the exact format requested:") print(' {"calling_function_qualified_name1":"function_definition1",...}') print("\nKey characteristics:") @@ -60,4 +61,4 @@ def method2(self): print(" - String keys (qualified function names)") print(" - String values (function source code)") print(" - JSON serializable") -print(" - No nested structures, just simple key-value pairs") \ No newline at end of file +print(" - No nested structures, just simple key-value pairs") From e1e8efd4ba962908c7d4adaaf907814b6b076ca5 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 17:17:42 -0700 Subject: [PATCH 09/23] start cleaning up, write tests --- codeflash/code_utils/code_extractor.py | 71 ++-- codeflash/code_utils/config_consts.py | 2 +- codeflash/models/models.py | 7 +- codeflash/optimization/function_optimizer.py | 2 +- example_usage.py | 131 ------- find_sorter_references.py | 110 ------ function_call_finder.py | 379 ------------------- function_call_finder_ast.py | 366 ------------------ function_call_visitor.py | 317 ---------------- pyproject.toml | 1 + ripgrep_search.py | 172 --------- test_ast_vs_libcst.py | 128 ------- test_function_call_finder.py | 53 --- test_function_call_visitor.py | 263 ------------- uv.lock | 27 ++ verify_output_format.py | 64 ---- 16 files changed, 79 insertions(+), 2014 deletions(-) delete mode 100644 example_usage.py delete mode 100644 find_sorter_references.py delete mode 100644 function_call_finder.py delete mode 100644 function_call_finder_ast.py delete mode 100644 function_call_visitor.py delete mode 100644 ripgrep_search.py delete mode 100644 test_ast_vs_libcst.py delete mode 100644 test_function_call_finder.py delete mode 100644 test_function_call_visitor.py delete mode 100644 verify_output_format.py diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 730eb77b2..8eb794332 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -10,20 +10,22 @@ import jedi import libcst as cst +import radon.visitors from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package +from radon.complexity import cc_visit # from codeflash.benchmarking.pytest_new_process_trace_benchmarks import project_root from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_IMPACT, TIME_LIMIT_FOR_OPT_IMPACT -from codeflash.models.models import CodePosition, FunctionParent +from codeflash.models.models import CodePosition, FunctionParent, ImpactMetrics if TYPE_CHECKING: from libcst.helpers import ModuleNameAndPackage from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.models.models import FunctionSource, ImpactMetrics + from codeflash.models.models import FunctionSource class GlobalAssignmentCollector(cst.CSTVisitor): @@ -1092,13 +1094,11 @@ def find_occurances( fn_call_context += f"{fn_definition}\n" context_len += len(fn_definition) fn_call_context += "```\n" - opt_impact_metrics = {"calling_fn_defs": fn_call_context} - # radon metrics = get_radon_metrics(sour) return fn_call_context def find_specific_function_in_file( - source_code: str, filepath: Union[str, Path], qualified_name: str + source_code: str, filepath: Union[str, Path], target_function: str, target_class: str | None ) -> Optional[tuple[int, int]]: """Find a specific function definition in a Python file and return its location. @@ -1107,17 +1107,13 @@ def find_specific_function_in_file( Args: source_code: Source code string filepath: Path to the Python file - qualified_name: Qualified Name of the function to find, classname.functionname + target_function: Function Name of the function to find + target_class: Class name of the function to find Returns: Tuple of (line_number, column_offset) if found, None otherwise """ - qualified_name_split = qualified_name.rsplit(".", maxsplit=1) - if len(qualified_name_split) == 1: - target_function, target_class = qualified_name_split[0], None - else: - target_function, target_class = qualified_name_split[1], qualified_name_split[0] script = jedi.Script(code=source_code, path=filepath) names = script.get_names(all_scopes=True, definitions=True) for name in names: @@ -1134,16 +1130,16 @@ def find_specific_function_in_file( return None # Function not found -def get_fn_references_jedi(source_code: str, file_path: Path, qualified_name: str, project_root: Path) -> list[Path]: - print(file_path, qualified_name, project_root) - # Create a Jedi Script object - function_position: CodePosition = find_specific_function_in_file(source_code, file_path, qualified_name) +def get_fn_references_jedi( + source_code: str, file_path: Path, project_root: Path, target_function: str, target_class: str | None +) -> list[Path]: + function_position: CodePosition = find_specific_function_in_file( + source_code, file_path, target_function, target_class + ) try: script = jedi.Script(code=source_code, path=file_path, project=jedi.Project(path=project_root)) - # Get references to the function references = script.get_references(line=function_position.line_no, column=function_position.col_no) - # Collect unique file paths where references are found reference_files = set() for ref in references: @@ -1153,9 +1149,7 @@ def get_fn_references_jedi(source_code: str, file_path: Path, qualified_name: st # Skip the definition itself if not (ref_path == file_path and ref.line == function_position.line_no): reference_files.add(ref_path) - return sorted(reference_files) - except Exception as e: print(f"Error during Jedi analysis: {e}") return [] @@ -1164,9 +1158,36 @@ def get_fn_references_jedi(source_code: str, file_path: Path, qualified_name: st def get_opt_impact_metrics( source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path ) -> ImpactMetrics: - # radon lib for complexity metrics - # print(file_path, qualified_name, project_root, tests_root) - matches = get_fn_references_jedi( - source_code, file_path, qualified_name, project_root - ) # jedi is not perfect, it doesn't capture aliased references - return find_occurances(qualified_name, str(file_path), matches, project_root, tests_root) + metrics = ImpactMetrics() + try: + qualified_name_split = qualified_name.rsplit(".", maxsplit=1) + if len(qualified_name_split) == 1: + target_function, target_class = qualified_name_split[0], None + else: + target_function, target_class = qualified_name_split[1], qualified_name_split[0] + matches = get_fn_references_jedi( + source_code, file_path, project_root, target_function, target_class + ) # jedi is not perfect, it doesn't capture aliased references + cyclomatic_complexity_results = cc_visit(source_code) + match_found = False + for result in cyclomatic_complexity_results: + if match_found: + break + if isinstance(result, radon.visitors.Function) and not target_class: + if result.name == target_function: + metrics.cyclomatic_complexity = result.complexity + metrics.cyclomatic_complexity_rating = result.letter + match_found = True + elif isinstance(result, radon.visitors.Class) and target_class: # noqa: SIM102 + if result.name == target_class: + for method in result.methods: + if match_found: + break + if method.name == target_function: + metrics.cyclomatic_complexity = method.complexity + metrics.cyclomatic_complexity_rating = method.letter + match_found = True + metrics.calling_fns = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root) + except Exception as e: + logger.debug(f"Investigate {e}") + return metrics diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index 34c143256..ca39c0238 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -36,4 +36,4 @@ TOTAL_LOOPING_TIME_EFFECTIVE = TOTAL_LOOPING_TIME_LSP if _IS_LSP_ENABLED else TOTAL_LOOPING_TIME MAX_CONTEXT_LEN_IMPACT = 1000 -TIME_LIMIT_FOR_OPT_IMPACT = 5 # in sec +TIME_LIMIT_FOR_OPT_IMPACT = 10 # in sec diff --git a/codeflash/models/models.py b/codeflash/models/models.py index d00c1246d..f3293cc83 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -33,10 +33,9 @@ @dataclass class ImpactMetrics: - complexity_score: int - occurances: int - loop_occurances: int - presence_of_decorators: bool + cyclomatic_complexity: Optional[int] = None + cyclomatic_complexity_rating: Optional[str] = None + calling_fns: Optional[str] = None @dataclass(frozen=True) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 53131b4ea..d28f06e02 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1474,7 +1474,7 @@ def process_review( self.function_to_optimize.qualified_name, self.project_root, self.test_cfg.tests_root, - ) # need module root, tests root only + ) if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) diff --git a/example_usage.py b/example_usage.py deleted file mode 100644 index 3bebd9ed1..000000000 --- a/example_usage.py +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env python3 -"""Example usage of the function reference finder.""" - -from find_function_references import find_function_references -from find_function_references_detailed import find_function_references_detailed - - -def example_basic_usage(): - """Example of basic usage - just get list of files.""" - # Example: Find references to a function - filepath = "path/to/your/file.py" - function_name = "your_function_name" - - # Find references (will auto-detect project root) - reference_files = find_function_references(filepath, function_name) - - print(f"Files containing references to {function_name}:") - for file in reference_files: - print(f" - {file}") - - return reference_files - - -def example_detailed_usage(): - """Example of detailed usage - get line numbers and context.""" - # Example: Find references with detailed information - filepath = "path/to/your/file.py" - function_name = "your_function_name" - project_root = "/path/to/project/root" # Optional - - # Find detailed references - references = find_function_references_detailed(filepath, function_name, project_root) - - print(f"\nDetailed references to {function_name}:") - for file, refs in references.items(): - print(f"\nFile: {file}") - for ref in refs: - print(f" Line {ref['line']}: {ref['context']}") - - return references - - -def example_programmatic_usage(): - """Example of using in your own Python code.""" - import jedi - - # Direct Jedi usage for more control - source_code = """ -def my_function(x, y): - return x + y - -result = my_function(1, 2) -""" - - # Create a script object - script = jedi.Script(code=source_code) - - # Find the function definition (line 2, column 4 for 'my_function') - references = script.get_references(line=2, column=4) - - print("Direct Jedi references:") - for ref in references: - print(f" Line {ref.line}, Column {ref.column}: {ref.description}") - - # You can also search for names - names = script.get_names() - for name in names: - if name.type == "function": - print(f"Found function: {name.name} at line {name.line}") - - -def example_find_all_functions_and_their_references(): - """Example of finding all functions in a file and their references.""" - import os - - import jedi - - def find_all_functions_with_references(filepath: str): - """Find all functions in a file and their references.""" - with open(filepath) as f: - source_code = f.read() - - script = jedi.Script(code=source_code, path=filepath) - - # Get all defined names - names = script.get_names() - - functions_and_refs = {} - - for name in names: - if name.type == "function": - # Get references for this function - refs = script.get_references(line=name.line, column=name.column) - - ref_locations = [] - for ref in refs: - if ref.module_path and str(ref.module_path) != filepath: - ref_locations.append({"file": str(ref.module_path), "line": ref.line, "column": ref.column}) - - functions_and_refs[name.name] = {"definition_line": name.line, "references": ref_locations} - - return functions_and_refs - - # Example usage - filepath = "your_module.py" - if os.path.exists(filepath): - all_refs = find_all_functions_with_references(filepath) - - for func_name, info in all_refs.items(): - print(f"\nFunction: {func_name} (defined at line {info['definition_line']})") - if info["references"]: - print(" Referenced in:") - for ref in info["references"]: - print(f" - {ref['file']}:{ref['line']}") - else: - print(" No external references found") - - -if __name__ == "__main__": - print("=" * 60) - print("Function Reference Finder - Usage Examples") - print("=" * 60) - - print("\nNote: Update the file paths and function names in the examples") - print("before running them with your actual code.\n") - - # Uncomment to run examples: - # example_basic_usage() - # example_detailed_usage() - # example_programmatic_usage() - # example_find_all_functions_and_their_references() diff --git a/find_sorter_references.py b/find_sorter_references.py deleted file mode 100644 index 82a7d14bb..000000000 --- a/find_sorter_references.py +++ /dev/null @@ -1,110 +0,0 @@ -#!/usr/bin/env python3 -"""Script to find all references to the sorter function from code_to_optimize/bubble_sort.py -using Jedi's static analysis capabilities. -""" - -from pathlib import Path - -import jedi - - -def find_function_references(file_path, line, column, project_root): - """Find all references to a function using Jedi. - - Args: - file_path: Path to the file containing the function - line: Line number where the function is defined (1-indexed) - column: Column number where the function name starts (0-indexed) - project_root: Root directory of the project to search - - """ - # Read the source code - with open(file_path) as f: - source = f.read() - - # Create a Jedi Script object with project configuration - project = jedi.Project(path=project_root) - script = jedi.Script(source, path=file_path, project=project) - - # Get the function definition at the specified position - definitions = script.goto(line, column, follow_imports=True) - - if not definitions: - print(f"No definition found at {file_path}:{line}:{column}") - return [] - - # Get the first definition (should be the function itself) - definition = definitions[0] - print(f"Found definition: {definition.name} at {definition.module_path}:{definition.line}") - print(f"Type: {definition.type}") - print("-" * 80) - - # Use search_all to find all references to this function - # We'll search for references by name throughout the project - references = [] - try: - # Use usages() method to get all references - references = script.get_references(line, column, scope="project", include_builtins=False) - except AttributeError: - # Alternative approach using search - print("Using alternative search method...") - references = script.get_references(line, column, include_builtins=False) - - return references - - -def main(): - # Project root directory - project_root = Path("/Users/aseemsaxena/Downloads/codeflash_dev/codeflash") - - # Target file and function location - target_file = project_root / "code_to_optimize" / "bubble_sort.py" - - # The sorter function starts at line 1, column 4 (0-indexed) - # "def sorter(arr):" - the function name 'sorter' starts at column 4 - line = 1 # Line number (1-indexed) - column = 4 # Column number (0-indexed) - position of 's' in 'sorter' - - print(f"Searching for references to 'sorter' function in {target_file}") - print(f"Position: Line {line}, Column {column}") - print("=" * 80) - - # Find references - references = find_function_references(target_file, line, column, project_root) - - if references: - print(f"\nFound {len(references)} reference(s) to 'sorter' function:") - print("=" * 80) - - # Group references by file - refs_by_file = {} - for ref in references: - file_path = ref.module_path - if file_path not in refs_by_file: - refs_by_file[file_path] = [] - refs_by_file[file_path].append(ref) - - # Display references organized by file - for file_path, file_refs in sorted(refs_by_file.items()): - print(f"\n📁 {file_path}") - for ref in sorted(file_refs, key=lambda r: (r.line, r.column)): - # Get the line content for context - try: - with open(file_path) as f: - lines = f.readlines() - if ref.line <= len(lines): - line_content = lines[ref.line - 1].strip() - print(f" Line {ref.line}, Col {ref.column}: {line_content}") - else: - print(f" Line {ref.line}, Col {ref.column}") - except Exception as e: - print(f" Line {ref.line}, Col {ref.column} (couldn't read line: {e})") - else: - print("\nNo references found to the 'sorter' function.") - - print("\n" + "=" * 80) - print("Search complete!") - - -if __name__ == "__main__": - main() diff --git a/function_call_finder.py b/function_call_finder.py deleted file mode 100644 index 65acadd64..000000000 --- a/function_call_finder.py +++ /dev/null @@ -1,379 +0,0 @@ -"""LibCST visitor to find function definitions that call a specific qualified function.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union - -import libcst as cst - - -@dataclass -class FunctionCallLocation: - """Represents a location where the target function is called.""" - - calling_function: str # Name of the function making the call - line: int - column: int - call_node: cst.Call # The actual call node for additional analysis if needed - - -@dataclass -class FunctionDefinitionInfo: - """Contains information about a function definition.""" - - name: str # Qualified name of the function - node: cst.FunctionDef # The CST node of the function definition - source_code: str # The source code of the function - start_line: int - end_line: int - is_method: bool # Whether this is a class method - class_name: Optional[str] = None # Name of containing class if it's a method - - -class FunctionCallFinder(cst.CSTVisitor): - """Visitor that finds all function definitions that call a specific qualified function. - - Args: - target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") - target_filepath: The filepath where the target function is defined - - """ - - METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) - - def __init__(self, target_function_name: str, target_filepath: str) -> None: - super().__init__() - self.target_function_name = target_function_name - self.target_filepath = target_filepath - - # Parse the target function name into parts - self.target_parts = target_function_name.split(".") - self.target_base_name = self.target_parts[-1] - - # Track current context - self.current_function_stack: list[Tuple[str, cst.FunctionDef]] = [] # (name, node) pairs - self.current_class_stack: list[str] = [] - - # Track imports to resolve qualified names - self.imports: dict = {} # Maps imported names to their full paths - - # Results - self.function_calls: list[FunctionCallLocation] = [] - self.calling_functions: set[str] = set() # Unique function names that call the target - self.function_definitions: Dict[str, FunctionDefinitionInfo] = {} # Function name -> definition info - - # Track if we found calls in the current function - self.found_call_in_current_function = False - # Track functions with nested calls (parent functions that contain nested functions with calls) - self.functions_with_nested_calls: set[str] = set() - - def visit_Import(self, node: cst.Import) -> None: - """Track regular imports.""" - for name in node.names: - if isinstance(name, cst.ImportAlias): - if name.asname: - # import module as alias - module_name = name.name.value if isinstance(name.name, cst.Attribute) else str(name.name) - alias = name.asname.name.value - self.imports[alias] = module_name - else: - # import module - module_name = self._get_dotted_name(name.name) - if module_name: - self.imports[module_name.split(".")[-1]] = module_name - - def visit_ImportFrom(self, node: cst.ImportFrom) -> None: - """Track from imports.""" - if not node.module: - return - - module_path = self._get_dotted_name(node.module) - if not module_path: - return - - if isinstance(node.names, cst.ImportStar): - # from module import * - self.imports["*"] = module_path - else: - # from module import name1, name2 - for name in node.names: - if isinstance(name, cst.ImportAlias): - import_name = name.name.value - if name.asname: - # from module import name as alias - alias = name.asname.name.value - self.imports[alias] = f"{module_path}.{import_name}" - else: - # from module import name - self.imports[import_name] = f"{module_path}.{import_name}" - - def visit_ClassDef(self, node: cst.ClassDef) -> None: - """Track when entering a class definition.""" - self.current_class_stack.append(node.name.value) - - def leave_ClassDef(self, node: cst.ClassDef) -> None: - """Track when leaving a class definition.""" - if self.current_class_stack: - self.current_class_stack.pop() - - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: - """Track when entering a function definition.""" - func_name = node.name.value - - # Build the full qualified name including class if applicable - full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name - - self.current_function_stack.append((full_name, node)) - self.found_call_in_current_function = False - - def leave_FunctionDef(self, node: cst.FunctionDef) -> None: - """Track when leaving a function definition and store it if it contains target calls.""" - if self.current_function_stack: - full_name, func_node = self.current_function_stack.pop() - - # If we found a call in this function, store its definition - if self.found_call_in_current_function and full_name not in self.function_definitions: - # Get position information - position = self.get_metadata(cst.metadata.PositionProvider, func_node) - - # Extract function source code by converting node to module - # For methods, we need to maintain proper indentation - func_source = cst.Module(body=[func_node]).code - - # For methods, add proper indentation (4 spaces) - if self.current_class_stack: - lines = func_source.split("\n") - func_source = "\n".join(" " + line if line else line for line in lines) - - self.function_definitions[full_name] = FunctionDefinitionInfo( - name=full_name, - node=func_node, - source_code=func_source.rstrip(), # Remove trailing whitespace - start_line=position.start.line if position else -1, - end_line=position.end.line if position else -1, - is_method=bool(self.current_class_stack), - class_name=self.current_class_stack[-1] if self.current_class_stack else None, - ) - - # Handle nested functions - mark parent as containing nested calls - if self.found_call_in_current_function and self.current_function_stack: - parent_name = self.current_function_stack[-1][0] - self.functions_with_nested_calls.add(parent_name) - # Also store the parent function if not already stored - if parent_name not in self.function_definitions: - parent_func_node = self.current_function_stack[-1][1] - parent_position = self.get_metadata(cst.metadata.PositionProvider, parent_func_node) - parent_source = cst.Module(body=[parent_func_node]).code - - # Get parent class context (go up one level in stack since we're inside the nested function) - parent_class_stack = ( - self.current_class_stack[:-1] - if len(self.current_function_stack) == 1 and self.current_class_stack - else [] - ) - - if parent_class_stack: - lines = parent_source.split("\n") - parent_source = "\n".join(" " + line if line else line for line in lines) - - self.function_definitions[parent_name] = FunctionDefinitionInfo( - name=parent_name, - node=parent_func_node, - source_code=parent_source.rstrip(), - start_line=parent_position.start.line if parent_position else -1, - end_line=parent_position.end.line if parent_position else -1, - is_method=bool(parent_class_stack), - class_name=parent_class_stack[-1] if parent_class_stack else None, - ) - - # Reset the flag for parent function if we're in nested functions - if self.current_function_stack: - # Check if the parent function should also be marked as containing calls - parent_name = self.current_function_stack[-1][0] - self.found_call_in_current_function = parent_name in self.calling_functions - - def visit_Call(self, node: cst.Call) -> None: - """Check if this call matches our target function.""" - if not self.current_function_stack: - # Not inside a function, skip - return - - if self._is_target_function_call(node): - # Get position information - position = self.get_metadata(cst.metadata.PositionProvider, node) - - current_func_name = self.current_function_stack[-1][0] - - call_location = FunctionCallLocation( - calling_function=current_func_name, - line=position.start.line if position else -1, - column=position.start.column if position else -1, - call_node=node, - ) - - self.function_calls.append(call_location) - self.calling_functions.add(current_func_name) - self.found_call_in_current_function = True - - def _is_target_function_call(self, node: cst.Call) -> bool: - """Determine if this call node is calling our target function. - - Handles various call patterns: - - Direct calls: function() - - Qualified calls: module.function() - - Method calls: obj.method() - """ - func = node.func - - # Get the call name - call_name = self._get_call_name(func) - if not call_name: - return False - - # Check if it matches directly - if call_name == self.target_function_name: - return True - - # Check if it's just the base name matching - if call_name == self.target_base_name: - # Could be imported with a different name, check imports - if call_name in self.imports: - imported_path = self.imports[call_name] - # Check if the imported path matches our target - if imported_path == self.target_function_name or imported_path.endswith( - f".{self.target_function_name}" - ): - return True - # Could also be a direct call if we're in the same file - return True - - # Check for qualified calls with imports - call_parts = call_name.split(".") - if call_parts[0] in self.imports: - # Resolve the full path using imports - base_import = self.imports[call_parts[0]] - full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import - - if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): - return True - - return False - - def _get_call_name(self, func: Union[cst.Name, cst.Attribute, cst.Call]) -> Optional[str]: - """Extract the name being called from a function node.""" - if isinstance(func, cst.Name): - return func.value - if isinstance(func, cst.Attribute): - return self._get_dotted_name(func) - if isinstance(func, cst.Call): - # Chained calls like foo()() - return None - return None - - def _get_dotted_name(self, node: Union[cst.Name, cst.Attribute]) -> Optional[str]: - """Get the full dotted name from an Attribute or Name node.""" - if isinstance(node, cst.Name): - return node.value - if isinstance(node, cst.Attribute): - parts = [] - current = node - while isinstance(current, cst.Attribute): - parts.append(current.attr.value) - current = current.value - if isinstance(current, cst.Name): - parts.append(current.value) - return ".".join(reversed(parts)) - return None - - def get_results(self) -> Dict[str, str]: - """Get the results of the analysis. - - Returns: - A dictionary mapping qualified function names to their source code definitions. - Only includes functions that call the target function (directly or through nested functions). - - """ - return {info.name: info.source_code for info in self.function_definitions.values()} - - -def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> Dict[str, str]: - """Find all function definitions that call a specific target function. - - Args: - source_code: The Python source code to analyze - target_function_name: The qualified name of the function to find (e.g., "module.function") - target_filepath: The filepath where the target function is defined - - Returns: - A dictionary mapping qualified function names to their source code definitions. - Example: {"function_a": "def function_a():\n ...", "MyClass.method_one": "def method_one(self):\n ..."} - - """ - # Parse the source code - module = cst.parse_module(source_code) - - # Create and run the visitor - visitor = FunctionCallFinder(target_function_name, target_filepath) - wrapper = cst.metadata.MetadataWrapper(module) - wrapper.visit(visitor) - - return visitor.get_results() - - -# Example usage -if __name__ == "__main__": - # Example source code to analyze - example_code = ''' -import os -from pathlib import Path -from my_module import target_function as tf -import my_module - -def function_a(): - """This function calls the target function directly.""" - result = tf(42) - return result - -def function_b(): - """This function calls the target function via module.""" - my_module.target_function("hello") - -class MyClass: - def method_one(self): - """Method that calls the target.""" - tf(1, 2, 3) - - def method_two(self): - """Method that doesn't call the target.""" - print("No call here") - -def function_c(): - """This function doesn't call the target.""" - print("Just printing") - -def nested_calls(): - """Function with nested function definitions.""" - def inner(): - tf("nested call") - inner() -''' - - # Find calls to a specific function - results = find_function_calls( - example_code, target_function_name="my_module.target_function", target_filepath="/path/to/my_module.py" - ) - - print("Functions that call 'my_module.target_function':\n") - - # Simple usage - results is just a dict of {function_name: source_code} - import json - - print("JSON representation of results:") - print(json.dumps(list(results.keys()), indent=2)) - - print("\nFormatted output:") - for func_name, source_code in results.items(): - print(f"\n=== {func_name} ===") - print(source_code) - print() diff --git a/function_call_finder_ast.py b/function_call_finder_ast.py deleted file mode 100644 index 4baab24e9..000000000 --- a/function_call_finder_ast.py +++ /dev/null @@ -1,366 +0,0 @@ -"""AST-based visitor to find function definitions that call a specific qualified function.""" - -import ast -from dataclasses import dataclass -from typing import Dict, List, Optional, Set, Tuple - - -@dataclass -class FunctionCallLocation: - """Represents a location where the target function is called.""" - - calling_function: str - line: int - column: int - - -@dataclass -class FunctionDefinitionInfo: - """Contains information about a function definition.""" - - name: str - node: ast.FunctionDef - source_code: str - start_line: int - end_line: int - is_method: bool - class_name: Optional[str] = None - - -class FunctionCallFinder(ast.NodeVisitor): - """AST visitor that finds all function definitions that call a specific qualified function. - - Args: - target_function_name: The qualified name of the function to find (e.g., "module.function" or "function") - target_filepath: The filepath where the target function is defined - - """ - - def __init__(self, target_function_name: str, target_filepath: str, source_lines: List[str]): - self.target_function_name = target_function_name - self.target_filepath = target_filepath - self.source_lines = source_lines # Store original source lines for extraction - - # Parse the target function name into parts - self.target_parts = target_function_name.split(".") - self.target_base_name = self.target_parts[-1] - - # Track current context - self.current_function_stack: List[Tuple[str, ast.FunctionDef]] = [] - self.current_class_stack: List[str] = [] - - # Track imports to resolve qualified names - self.imports: Dict[str, str] = {} # Maps imported names to their full paths - - # Results - self.function_calls: List[FunctionCallLocation] = [] - self.calling_functions: Set[str] = set() - self.function_definitions: Dict[str, FunctionDefinitionInfo] = {} - - # Track if we found calls in the current function - self.found_call_in_current_function = False - self.functions_with_nested_calls: Set[str] = set() - - def visit_Import(self, node: ast.Import) -> None: - """Track regular imports.""" - for alias in node.names: - if alias.asname: - # import module as alias - self.imports[alias.asname] = alias.name - else: - # import module - self.imports[alias.name.split(".")[-1]] = alias.name - self.generic_visit(node) - - def visit_ImportFrom(self, node: ast.ImportFrom) -> None: - """Track from imports.""" - if node.module: - for alias in node.names: - if alias.name == "*": - # from module import * - self.imports["*"] = node.module - elif alias.asname: - # from module import name as alias - self.imports[alias.asname] = f"{node.module}.{alias.name}" - else: - # from module import name - self.imports[alias.name] = f"{node.module}.{alias.name}" - self.generic_visit(node) - - def visit_ClassDef(self, node: ast.ClassDef) -> None: - """Track when entering a class definition.""" - self.current_class_stack.append(node.name) - self.generic_visit(node) - self.current_class_stack.pop() - - def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - """Track when entering a function definition.""" - self._visit_function_def(node) - - def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: - """Track when entering an async function definition.""" - self._visit_function_def(node) - - def _visit_function_def(self, node: ast.FunctionDef) -> None: - """Common logic for both regular and async function definitions.""" - func_name = node.name - - # Build the full qualified name including class if applicable - full_name = f"{'.'.join(self.current_class_stack)}.{func_name}" if self.current_class_stack else func_name - - self.current_function_stack.append((full_name, node)) - self.found_call_in_current_function = False - - # Visit the function body - self.generic_visit(node) - - # Process the function after visiting its body - if self.found_call_in_current_function and full_name not in self.function_definitions: - # Extract function source code - source_code = self._extract_source_code(node) - - self.function_definitions[full_name] = FunctionDefinitionInfo( - name=full_name, - node=node, - source_code=source_code, - start_line=node.lineno, - end_line=node.end_lineno if hasattr(node, "end_lineno") else node.lineno, - is_method=bool(self.current_class_stack), - class_name=self.current_class_stack[-1] if self.current_class_stack else None, - ) - - # Handle nested functions - mark parent as containing nested calls - if self.found_call_in_current_function and len(self.current_function_stack) > 1: - parent_name = self.current_function_stack[-2][0] - self.functions_with_nested_calls.add(parent_name) - - # Also store the parent function if not already stored - if parent_name not in self.function_definitions: - parent_node = self.current_function_stack[-2][1] - parent_source = self._extract_source_code(parent_node) - - # Check if parent is a method (excluding current level) - parent_class_context = self.current_class_stack if len(self.current_function_stack) == 2 else [] - - self.function_definitions[parent_name] = FunctionDefinitionInfo( - name=parent_name, - node=parent_node, - source_code=parent_source, - start_line=parent_node.lineno, - end_line=parent_node.end_lineno if hasattr(parent_node, "end_lineno") else parent_node.lineno, - is_method=bool(parent_class_context), - class_name=parent_class_context[-1] if parent_class_context else None, - ) - - self.current_function_stack.pop() - - # Reset flag for parent function - if self.current_function_stack: - parent_name = self.current_function_stack[-1][0] - self.found_call_in_current_function = parent_name in self.calling_functions - - def visit_Call(self, node: ast.Call) -> None: - """Check if this call matches our target function.""" - if not self.current_function_stack: - # Not inside a function, skip - self.generic_visit(node) - return - - if self._is_target_function_call(node): - current_func_name = self.current_function_stack[-1][0] - - call_location = FunctionCallLocation( - calling_function=current_func_name, line=node.lineno, column=node.col_offset - ) - - self.function_calls.append(call_location) - self.calling_functions.add(current_func_name) - self.found_call_in_current_function = True - - self.generic_visit(node) - - def _is_target_function_call(self, node: ast.Call) -> bool: - """Determine if this call node is calling our target function.""" - call_name = self._get_call_name(node.func) - if not call_name: - return False - - # Check if it matches directly - if call_name == self.target_function_name: - return True - - # Check if it's just the base name matching - if call_name == self.target_base_name: - # Could be imported with a different name, check imports - if call_name in self.imports: - imported_path = self.imports[call_name] - if imported_path == self.target_function_name or imported_path.endswith( - f".{self.target_function_name}" - ): - return True - # Could also be a direct call if we're in the same file - return True - - # Check for qualified calls with imports - call_parts = call_name.split(".") - if call_parts[0] in self.imports: - # Resolve the full path using imports - base_import = self.imports[call_parts[0]] - full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import - - if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"): - return True - - return False - - def _get_call_name(self, func_node) -> Optional[str]: - """Extract the name being called from a function node.""" - if isinstance(func_node, ast.Name): - return func_node.id - if isinstance(func_node, ast.Attribute): - parts = [] - current = func_node - while isinstance(current, ast.Attribute): - parts.append(current.attr) - current = current.value - if isinstance(current, ast.Name): - parts.append(current.id) - return ".".join(reversed(parts)) - return None - - def _extract_source_code(self, node: ast.FunctionDef) -> str: - """Extract source code for a function node using original source lines.""" - if not self.source_lines or not hasattr(node, "lineno"): - # Fallback to ast.unparse if available (Python 3.9+) - try: - return ast.unparse(node) - except AttributeError: - return f"# Source code extraction not available for {node.name}" - - # Get the lines for this function - start_line = node.lineno - 1 # Convert to 0-based index - end_line = node.end_lineno if hasattr(node, "end_lineno") else len(self.source_lines) - - # Extract the function lines - func_lines = self.source_lines[start_line:end_line] - - # Find the minimum indentation (excluding empty lines) - min_indent = float("inf") - for line in func_lines: - if line.strip(): # Skip empty lines - indent = len(line) - len(line.lstrip()) - min_indent = min(min_indent, indent) - - # If this is a method (inside a class), preserve one level of indentation - if self.current_class_stack: - # Keep 4 spaces of indentation for methods - dedent_amount = max(0, min_indent - 4) - result_lines = [] - for line in func_lines: - if line.strip(): # Only dedent non-empty lines - result_lines.append(line[dedent_amount:] if len(line) > dedent_amount else line) - else: - result_lines.append(line) - else: - # For top-level functions, remove all leading indentation - result_lines = [] - for line in func_lines: - if line.strip(): # Only dedent non-empty lines - result_lines.append(line[min_indent:] if len(line) > min_indent else line) - else: - result_lines.append(line) - - return "".join(result_lines).rstrip() - - def get_results(self) -> Dict[str, str]: - """Get the results of the analysis. - - Returns: - A dictionary mapping qualified function names to their source code definitions. - - """ - return {info.name: info.source_code for info in self.function_definitions.values()} - - -def find_function_calls(source_code: str, target_function_name: str, target_filepath: str) -> Dict[str, str]: - """Find all function definitions that call a specific target function. - - Args: - source_code: The Python source code to analyze - target_function_name: The qualified name of the function to find (e.g., "module.function") - target_filepath: The filepath where the target function is defined - - Returns: - A dictionary mapping qualified function names to their source code definitions. - Example: {"function_a": "def function_a():\n ...", "MyClass.method_one": "def method_one(self):\n ..."} - - """ - # Parse the source code - tree = ast.parse(source_code) - - # Split source into lines for source extraction - source_lines = source_code.splitlines(keepends=True) - - # Create and run the visitor - visitor = FunctionCallFinder(target_function_name, target_filepath, source_lines) - visitor.visit(tree) - - return visitor.get_results() - - -# Example usage -if __name__ == "__main__": - # Example source code to analyze - example_code = ''' -import os -from pathlib import Path -from my_module import target_function as tf -import my_module - -def function_a(): - """This function calls the target function directly.""" - result = tf(42) - return result - -def function_b(): - """This function calls the target function via module.""" - my_module.target_function("hello") - -class MyClass: - def method_one(self): - """Method that calls the target.""" - tf(1, 2, 3) - - def method_two(self): - """Method that doesn't call the target.""" - print("No call here") - -def function_c(): - """This function doesn't call the target.""" - print("Just printing") - -def nested_calls(): - """Function with nested function definitions.""" - def inner(): - tf("nested call") - inner() -''' - - # Find calls to a specific function - results = find_function_calls( - example_code, target_function_name="my_module.target_function", target_filepath="/path/to/my_module.py" - ) - - print("Functions that call 'my_module.target_function':\n") - - # Simple usage - results is just a dict of {function_name: source_code} - import json - - print("JSON representation of results:") - print(json.dumps(list(results.keys()), indent=2)) - - print("\nFormatted output:") - for func_name, source_code in results.items(): - print(f"\n=== {func_name} ===") - print(source_code) - print() diff --git a/function_call_visitor.py b/function_call_visitor.py deleted file mode 100644 index 5537e1a3a..000000000 --- a/function_call_visitor.py +++ /dev/null @@ -1,317 +0,0 @@ -"""AST Visitor to count function calls and identify calls within loops. - -This module provides a visitor that can track calls to specific functions, -including regular functions, methods, classmethods, and staticmethods. -""" - -from __future__ import annotations - -import ast -from dataclasses import dataclass -from pathlib import Path -from typing import Optional - - -@dataclass -class CallInfo: - """Information about a function call.""" - - line: int - col: int - call_text: str - in_loop: bool - loop_type: Optional[str] = None # 'for', 'while', or nested combinations - file_path: Optional[str] = None - - def __repr__(self): - loop_info = f" (in {self.loop_type} loop)" if self.in_loop else "" - file_info = f"{self.file_path}:" if self.file_path else "" - return f"{file_info}{self.line}:{self.col} - {self.call_text}{loop_info}" - - -class FunctionCallVisitor(ast.NodeVisitor): - """AST visitor to count and track function calls. - - Handles: - - Regular function calls: func() - - Method calls: obj.method() - - Class method calls: Class.method() - - Static method calls: Class.static_method() - - Nested attribute calls: module.submodule.func() - """ - - def __init__(self, target_functions: list[str], file_path: Optional[str] = None): - """Initialize the visitor. - - Args: - target_functions: list of function names to track. Can be: - - Simple names: ['print', 'len'] - - Qualified names: ['os.path.join', 'numpy.array'] - - Method names: ['append', 'extend'] (will match any obj.append()) - file_path: Optional path to the file being analyzed - - """ - self.target_functions = set(target_functions) - self.file_path = file_path - self.calls: list[CallInfo] = [] - self.loop_stack: list[str] = [] # Track nested loops - self._source_lines: Optional[list[str]] = None - - def set_source(self, source: str): - """Set the source code for better call text extraction.""" - self._source_lines = source.splitlines() - - def _get_call_name(self, node: ast.Call) -> Optional[str]: - """Extract the full name of the called function.""" - if isinstance(node.func, ast.Name): - # Simple function call: func() - return node.func.id - if isinstance(node.func, ast.Attribute): - # Method or qualified call: obj.method() or module.func() - parts = [] - current = node.func - - while isinstance(current, ast.Attribute): - parts.append(current.attr) - current = current.value - - if isinstance(current, ast.Name): - parts.append(current.id) - full_name = ".".join(reversed(parts)) - - # Check if we should track this call - # Match exact qualified names or just the method name - if full_name in self.target_functions: - return full_name - - # Also check if just the method name matches - # (for tracking all calls to a method regardless of object) - method_name = parts[0] # The rightmost part is the method - if method_name in self.target_functions: - return full_name - - # Check partial matches (e.g., 'path.join' matches 'os.path.join') - for target in self.target_functions: - if full_name.endswith(target) or target.endswith(full_name): - return full_name - - return None - - def _get_call_text(self, node: ast.Call) -> str: - """Get a string representation of the call.""" - if self._source_lines and hasattr(node, "lineno") and hasattr(node, "end_lineno"): - try: - if node.lineno == node.end_lineno: - line = self._source_lines[node.lineno - 1] - if hasattr(node, "col_offset") and hasattr(node, "end_col_offset"): - return line[node.col_offset : node.end_col_offset] - else: - # Multi-line call - lines = [] - for i in range(node.lineno - 1, node.end_lineno): - if i < len(self._source_lines): - if i == node.lineno - 1: - lines.append(self._source_lines[i][node.col_offset :]) - elif i == node.end_lineno - 1: - lines.append(self._source_lines[i][: node.end_col_offset]) - else: - lines.append(self._source_lines[i]) - return " ".join(line.strip() for line in lines) - except (IndexError, AttributeError): - pass - - # Fallback to reconstructing from AST - return ast.unparse(node) if hasattr(ast, "unparse") else self._get_call_name(node) + "(...)" - - def _in_loop(self) -> bool: - """Check if we're currently inside a loop.""" - return len(self.loop_stack) > 0 - - def _get_loop_type(self) -> Optional[str]: - """Get the current loop type(s).""" - if not self.loop_stack: - return None - if len(self.loop_stack) == 1: - return self.loop_stack[0] - return " -> ".join(self.loop_stack) # Show nested loops - - def visit_Call(self, node: ast.Call): - """Visit a function call node.""" - call_name = self._get_call_name(node) - - if call_name: - # Check if this matches any of our target functions - should_track = False - - # Direct match - if call_name in self.target_functions: - should_track = True - else: - # Check if just the method/function name matches - simple_name = call_name.split(".")[-1] - if simple_name in self.target_functions: - should_track = True - else: - # Check for partial qualified matches - for target in self.target_functions: - if "." in target: - # For qualified targets, check if call matches the end - if call_name.endswith("." + target.split(".")[-1]): - should_track = True - break - - if should_track: - call_info = CallInfo( - line=node.lineno, - col=node.col_offset, - call_text=self._get_call_text(node), - in_loop=self._in_loop(), - loop_type=self._get_loop_type(), - file_path=self.file_path, - ) - self.calls.append(call_info) - - self.generic_visit(node) - - def visit_For(self, node: ast.For): - """Visit a for loop.""" - self.loop_stack.append("for") - self.generic_visit(node) - self.loop_stack.pop() - - def visit_While(self, node: ast.While): - """Visit a while loop.""" - self.loop_stack.append("while") - self.generic_visit(node) - self.loop_stack.pop() - - def visit_AsyncFor(self, node: ast.AsyncFor): - """Visit an async for loop.""" - self.loop_stack.append("async for") - self.generic_visit(node) - self.loop_stack.pop() - - def get_summary(self) -> dict: - """Get a summary of the calls found.""" - total_calls = len(self.calls) - calls_in_loops = [c for c in self.calls if c.in_loop] - calls_outside_loops = [c for c in self.calls if not c.in_loop] - - return { - "total_calls": total_calls, - "calls_in_loops": len(calls_in_loops), - "calls_outside_loops": len(calls_outside_loops), - "all_calls": self.calls, - "loop_calls": calls_in_loops, - "non_loop_calls": calls_outside_loops, - } - - -def analyze_file(file_path: str, target_functions: list[str]) -> dict: - """Analyze a Python file for function calls. - - Args: - file_path: Path to the Python file - target_functions: list of function names to track - - Returns: - dictionary with call statistics and details - - """ - with Path.open(file_path) as f: - source = f.read() - - tree = ast.parse(source, filename=file_path) - visitor = FunctionCallVisitor(target_functions, file_path) - visitor.set_source(source) - visitor.visit(tree) - - return visitor.get_summary() - - -def analyze_code(source: str, target_functions: list[str], file_path: Optional[str] = None) -> dict: - """Analyze Python source code for function calls. - - Args: - source: Python source code as string - target_functions: list of function names to track - file_path: Optional file path for reference - - Returns: - dictionary with call statistics and details - - """ - tree = ast.parse(source) - visitor = FunctionCallVisitor(target_functions, file_path) - visitor.set_source(source) - visitor.visit(tree) - - return visitor.get_summary() - - -if __name__ == "__main__": - # Example usage - example_code = """ -import os -import numpy as np - -def process_data(data): - print("Starting processing") - result = [] - - for item in data: - print(f"Processing {item}") - value = len(item) - result.append(value) - - for i in range(3): - print(f"Inner loop {i}") - np.array([1, 2, 3]) - - while len(result) < 10: - print("Adding more items") - result.append(0) - - os.path.join("dir", "file") - print("Done") - return result - -class DataProcessor: - def process(self, items): - for item in items: - self.validate(item) - print(f"Item: {item}") - - def validate(self, item): - if len(item) > 0: - print("Valid") - - @classmethod - def create(cls): - print("Creating processor") - return cls() - - @staticmethod - def utility(): - print("Utility function") -""" - - # Track multiple functions - targets = ["print", "len", "np.array", "os.path.join", "append", "validate"] - results = analyze_code(example_code, targets, "example.py") - - print("Function Call Analysis Results") - print("=" * 50) - print(f"Total calls found: {results['total_calls']}") - print(f"Calls in loops: {results['calls_in_loops']}") - print(f"Calls outside loops: {results['calls_outside_loops']}") - print("\nAll calls:") - print("-" * 50) - for call in results["all_calls"]: - print(f" {call}") - - if results["loop_calls"]: - print("\nCalls within loops:") - print("-" * 50) - for call in results["loop_calls"]: - print(f" {call}") diff --git a/pyproject.toml b/pyproject.toml index ce6d5c1db..dfaa3551c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "pygls>=1.3.1", "codeflash-benchmark", "filelock", + "radon", #for code complexity metrics ] [project.urls] diff --git a/ripgrep_search.py b/ripgrep_search.py deleted file mode 100644 index 365481356..000000000 --- a/ripgrep_search.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python3 -"""Script to find all occurrences of 'function_name' in the repository using ripgrep. -Returns a dictionary where keys are filepaths and values are lists of (line_no, content) tuples. -""" - -import json -import subprocess -from pathlib import Path -from typing import Dict, List, Tuple - - -def search_with_ripgrep(pattern: str, path: str = ".") -> Dict[str, List[Tuple[int, str]]]: - """Use ripgrep to search for a pattern in the repository. - - Args: - pattern: The pattern to search for - path: The directory to search in (default: current directory) - - Returns: - Dictionary with filepaths as keys and list of (line_no, content) tuples as values - - """ - # Run ripgrep with JSON output for easier parsing - # -n: Show line numbers - # --json: Output in JSON format - # --no-heading: Don't group matches by file - path = str(Path.cwd()) - cmd = [ - "rg", - "-n", - "--json", - pattern, - path, - "-g", - "!/Users/aseemsaxena/Downloads/codeflash_dev/codeflash/code_to_optimize/tests/**", - ] - print(" ".join(cmd)) - - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=False, # Don't raise exception on non-zero return - ) - - if result.returncode not in [0, 1]: # 0 = matches found, 1 = no matches - print(f"Error running ripgrep: {result.stderr}") - return {} - - # Parse the JSON output - matches_dict = {} - - for line in result.stdout.strip().split("\n"): - if not line: - continue - - try: - json_obj = json.loads(line) - - # We're only interested in match objects - if json_obj.get("type") == "match": - data = json_obj.get("data", {}) - file_path = data.get("path", {}).get("text", "") - line_number = data.get("line_number") - line_content = data.get("lines", {}).get("text", "").rstrip("\n") - - if file_path and line_number: - if file_path not in matches_dict: - matches_dict[file_path] = [] - matches_dict[file_path].append((line_number, line_content)) - - except json.JSONDecodeError: - continue - - return matches_dict - - except FileNotFoundError: - print("Error: ripgrep (rg) is not installed or not in PATH") - return {} - except Exception as e: - print(f"Unexpected error: {e}") - return {} - - -def search_with_ripgrep_simple(pattern: str, path: str = ".") -> Dict[str, List[Tuple[int, str]]]: - """Alternative implementation using simpler ripgrep output (non-JSON). - - Args: - pattern: The pattern to search for - path: The directory to search in (default: current directory) - - Returns: - Dictionary with filepaths as keys and list of (line_no, content) tuples as values - - """ - # Run ripgrep with simpler output - # -n: Show line numbers - # --no-heading: Don't group matches by file - cmd = ["rg", "-n", "--no-heading", pattern, path] - - try: - result = subprocess.run(cmd, capture_output=True, text=True, check=False) - - if result.returncode not in [0, 1]: - print(f"Error running ripgrep: {result.stderr}") - return {} - - matches_dict = {} - - # Parse the output (format: filepath:line_number:content) - for line in result.stdout.strip().split("\n"): - if not line: - continue - - # Split only on the first two colons to handle colons in content - parts = line.split(":", 2) - if len(parts) >= 3: - file_path = parts[0] - try: - line_number = int(parts[1]) - line_content = parts[2] - - if file_path not in matches_dict: - matches_dict[file_path] = [] - matches_dict[file_path].append((line_number, line_content)) - except ValueError: - continue - - return matches_dict - - except FileNotFoundError: - print("Error: ripgrep (rg) is not installed or not in PATH") - return {} - except Exception as e: - print(f"Unexpected error: {e}") - return {} - - -def main(): - """Main function to demonstrate usage.""" - # Search for "sorter" in the current repository - pattern = "sorter" - - print(f"Searching for '{pattern}' in the repository...") - print("=" * 60) - - # Use the JSON-based approach - results = search_with_ripgrep(pattern) - - if not results: - print(f"No occurrences of '{pattern}' found.") - else: - print(f"Found occurrences in {len(results)} files:\n") - - for filepath, occurrences in results.items(): - print(f"\nFile: {filepath}") - print(f" Found {len(occurrences)} occurrence(s):") - for line_no, content in occurrences: - # Truncate long lines for display - display_content = content[:100] + "..." if len(content) > 100 else content - print(f" Line {line_no}: {display_content}") - - print("\n" + "=" * 60) - print("Results as dictionary:") - print(json.dumps(results, indent=2)) - - return results - - -if __name__ == "__main__": - results_dict = main() diff --git a/test_ast_vs_libcst.py b/test_ast_vs_libcst.py deleted file mode 100644 index 1f7ef3735..000000000 --- a/test_ast_vs_libcst.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Compare AST and LibCST implementations to ensure they produce the same results.""" - -from function_call_finder import find_function_calls as find_calls_libcst -from function_call_finder_ast import find_function_calls as find_calls_ast - -# Test code with various scenarios -test_code = """ -import module1 -from module2 import func as f2 -import module3 as m3 - -def simple_call(): - target_func() - -def aliased_call(): - f2() - -def qualified_call(): - module1.target_func() - -class TestClass: - def method_with_call(self): - target_func(1, 2, 3) - - def method_without_call(self): - print("nothing") - -def nested_example(): - def inner1(): - target_func() - def inner2(): - pass - inner1() - -async def async_function(): - await target_func() - -def no_call(): - x = 5 -""" - -print("Testing AST vs LibCST implementations\n") -print("=" * 50) - -# Test 1: Direct function calls -print("\nTest 1: Finding 'target_func' calls") -results_ast = find_calls_ast(test_code, "target_func", "/dummy/path.py") -results_libcst = find_calls_libcst(test_code, "target_func", "/dummy/path.py") - -print(f"AST found {len(results_ast)} functions") -print(f"LibCST found {len(results_libcst)} functions") - -ast_keys = set(results_ast.keys()) -libcst_keys = set(results_libcst.keys()) - -print(f"\nAST keys: {sorted(ast_keys)}") -print(f"LibCST keys: {sorted(libcst_keys)}") - -if ast_keys == libcst_keys: - print("✅ Both found the same function names!") -else: - print("❌ Different function names found") - print(f" Only in AST: {ast_keys - libcst_keys}") - print(f" Only in LibCST: {libcst_keys - ast_keys}") - -# Test 2: Check if source code is similar (may have minor formatting differences) -print("\n" + "=" * 50) -print("Test 2: Source code comparison") - -for func_name in ast_keys & libcst_keys: - ast_code = results_ast[func_name].strip() - libcst_code = results_libcst[func_name].strip() - - # Normalize whitespace for comparison - ast_normalized = " ".join(ast_code.split()) - libcst_normalized = " ".join(libcst_code.split()) - - if ast_normalized == libcst_normalized: - print(f"✅ {func_name}: Source code matches (normalized)") - else: - print(f"⚠️ {func_name}: Source code differs") - print(f" AST length: {len(ast_code)} chars") - print(f" LibCST length: {len(libcst_code)} chars") - -# Test 3: Test with imports -print("\n" + "=" * 50) -print("Test 3: Testing with import resolution") - -import_test = """ -from mymodule import target_func as tf - -def uses_alias(): - tf() - -def uses_direct(): - target_func() # This shouldn't match since it's imported as tf -""" - -results_ast_import = find_calls_ast(import_test, "mymodule.target_func", "/dummy/path.py") -results_libcst_import = find_calls_libcst(import_test, "mymodule.target_func", "/dummy/path.py") - -print(f"AST found: {list(results_ast_import.keys())}") -print(f"LibCST found: {list(results_libcst_import.keys())}") - -# Summary -print("\n" + "=" * 50) -print("COMPARISON SUMMARY") -print("=" * 50) - -differences = [] -if ast_keys != libcst_keys: - differences.append("Different function names detected") - -print("\n✅ AST implementation is working correctly") -print("✅ Output format matches: {'func_name': 'source_code'}") - -if not differences: - print("✅ Both implementations produce equivalent results") -else: - print(f"⚠️ Found {len(differences)} differences:") - for diff in differences: - print(f" - {diff}") - -# Performance note -print("\n📝 Performance Note:") -print(" - AST: Built-in, no dependencies, faster parsing") -print(" - LibCST: External dependency, preserves formatting better") -print(" - Both produce the same logical results") diff --git a/test_function_call_finder.py b/test_function_call_finder.py deleted file mode 100644 index 77dd52bcc..000000000 --- a/test_function_call_finder.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Test script to verify the function_call_finder output format.""" - -from function_call_finder import find_function_calls - -# Test code -test_code = """ -def func1(): - target_func() - -def func2(): - pass - -def func3(): - x = target_func(42) - return x - -class TestClass: - def method1(self): - target_func("test") - - def method2(self): - # No call here - pass -""" - -# Run the visitor -results = find_function_calls(test_code, "target_func", "/dummy/path.py") - -# Verify the output format -print("Output type:", type(results)) -print("Output keys:", list(results.keys())) -print("\nExpected format: {qualified_name: source_code}") -print("Actual format check:") - -for name, code in results.items(): - print(f"\n✓ Key (function name): '{name}' -> Type: {type(name).__name__}") - print(f"✓ Value (source code): Type: {type(code).__name__}, Length: {len(code)} chars") - print(f" First line: {code.split(chr(10))[0] if code else 'Empty'}") - -# Verify it's exactly the format requested: {"calling_function_qualified_name1":"function_definition1",....} -import json - -print("\nJSON serializable:", end=" ") -try: - json_str = json.dumps(results) - print("✓ Yes") - print(f"JSON length: {len(json_str)} characters") -except: - print("✗ No") - -print("\n" + "=" * 50) -print("VERIFIED: Output is in the format") -print('{"calling_function_qualified_name1":"function_definition1",...}') diff --git a/test_function_call_visitor.py b/test_function_call_visitor.py deleted file mode 100644 index 41536590f..000000000 --- a/test_function_call_visitor.py +++ /dev/null @@ -1,263 +0,0 @@ -"""Test and demonstrate the FunctionCallVisitor capabilities.""" - -from function_call_visitor import analyze_code, analyze_file - - -def test_basic_calls(): - """Test basic function call detection.""" - code = """ -def example(): - print("Hello") - len([1, 2, 3]) - max([4, 5, 6]) - print("World") -""" - results = analyze_code(code, ["print", "len"]) - print("Test: Basic Calls") - print(f" Found {results['total_calls']} calls") - for call in results["all_calls"]: - print(f" {call}") - print() - - -def test_loop_detection(): - """Test detection of calls within loops.""" - code = """ -def process(): - print("Start") # Outside loop - - for i in range(10): - print(f"Item {i}") # In for loop - len(str(i)) # In for loop - - x = 0 - while x < 5: - print(f"While {x}") # In while loop - x += len([1, 2]) # In while loop - - print("End") # Outside loop -""" - results = analyze_code(code, ["print", "len"]) - print("Test: Loop Detection") - print(f" Total calls: {results['total_calls']}") - print(f" In loops: {results['calls_in_loops']}") - print(f" Outside loops: {results['calls_outside_loops']}") - print(" Loop calls:") - for call in results["loop_calls"]: - print(f" {call}") - print() - - -def test_nested_loops(): - """Test detection in nested loops.""" - code = """ -def nested(): - for i in range(3): - print(f"Outer {i}") - for j in range(2): - print(f"Inner {i},{j}") - while j < 1: - print(f"Innermost") - j += 1 -""" - results = analyze_code(code, ["print"]) - print("Test: Nested Loops") - for call in results["all_calls"]: - print(f" {call}") - print() - - -def test_method_calls(): - """Test detection of method calls.""" - code = """ -class MyClass: - def __init__(self): - self.data = [] - - def process(self): - for item in [1, 2, 3]: - self.data.append(item) - self.validate(item) - - def validate(self, item): - if len(str(item)) > 0: - self.data.append(item * 2) - - @classmethod - def create(cls): - instance = cls() - instance.data.append(0) - return instance - - @staticmethod - def helper(): - result = [] - result.append(1) - return result - -obj = MyClass() -obj.process() -obj.data.append(99) -MyClass.create() -MyClass.helper() -""" - results = analyze_code(code, ["append", "validate", "len"]) - print("Test: Method Calls") - print(f" Found {results['total_calls']} calls") - for call in results["all_calls"]: - print(f" {call}") - print() - - -def test_module_calls(): - """Test detection of module function calls.""" - code = """ -import os.path -import numpy as np -from math import sqrt - -def example(): - # Module function calls - os.path.join("a", "b") - np.array([1, 2, 3]) - sqrt(16) - - for i in range(3): - os.path.exists(f"file_{i}") - np.zeros((2, 2)) - - # Nested module calls - result = os.path.dirname(os.path.join("x", "y")) -""" - results = analyze_code(code, ["os.path.join", "np.array", "sqrt", "os.path.exists", "np.zeros", "os.path.dirname"]) - print("Test: Module Calls") - print(f" Total calls: {results['total_calls']}") - print(" All calls:") - for call in results["all_calls"]: - print(f" {call}") - print() - - -def test_complex_expressions(): - """Test calls in complex expressions.""" - code = """ -def complex_example(): - # Calls in list comprehensions - result = [len(x) for x in ["a", "bb", "ccc"]] - - # Calls in generator expressions - gen = (print(x) for x in range(3)) - - # Nested calls - value = max(len("hello"), len("world")) - - # Calls in lambda - func = lambda x: len(x) + len(x.strip()) - - # Calls in conditionals - if len("test") > 0: - print("Has length") - - # Calls in dict comprehensions - d = {x: len(x) for x in ["key1", "key2"]} -""" - results = analyze_code(code, ["len", "print", "max"]) - print("Test: Complex Expressions") - print(f" Found {results['total_calls']} calls") - for call in results["all_calls"]: - print(f" {call}") - print() - - -def test_async_code(): - """Test async function calls.""" - code = """ -async def async_example(): - print("Starting async") - - async for item in async_generator(): - print(f"Processing {item}") - await process_item(item) - - print("Done") - -async def async_generator(): - for i in range(3): - yield i - -async def process_item(item): - print(f"Item: {item}") -""" - results = analyze_code(code, ["print", "process_item"]) - print("Test: Async Code") - for call in results["all_calls"]: - print(f" {call}") - print() - - -def test_partial_matching(): - """Test partial name matching.""" - code = """ -import os -import os.path -from pathlib import Path - -def file_operations(): - # These should all be caught when looking for 'join' - os.path.join("a", "b") - # path.join("c", "d") # Would need path to be defined - # something.else.join("x") # Would need something to be defined - - # Looking for any 'append' method - list1 = [] - list1.append(1) - list2 = [] - list2.append(2) - # some_obj.data.append(3) # Would need some_obj to be defined -""" - results = analyze_code(code, ["join", "append"]) - print("Test: Partial Matching") - print(" Tracking 'join' and 'append'") - for call in results["all_calls"]: - print(f" {call}") - print() - - -def run_all_tests(): - """Run all test cases.""" - print("=" * 60) - print("FunctionCallVisitor Test Suite") - print("=" * 60) - print() - - test_basic_calls() - test_loop_detection() - test_nested_loops() - test_method_calls() - test_module_calls() - test_complex_expressions() - test_async_code() - test_partial_matching() - - print("=" * 60) - print("All tests completed!") - print("=" * 60) - - -if __name__ == "__main__": - run_all_tests() - - # Example of analyzing an actual file - print("\nExample: Analyzing the visitor file itself") - print("-" * 60) - try: - results = analyze_file("function_call_visitor.py", ["isinstance", "append", "len"]) - print(f"Found {results['total_calls']} calls in function_call_visitor.py") - print(f" In loops: {results['calls_in_loops']}") - print(f" Outside loops: {results['calls_outside_loops']}") - if results["loop_calls"]: - print("\nCalls in loops:") - for call in results["loop_calls"][:5]: # Show first 5 - print(f" {call}") - except FileNotFoundError: - print(" (File not found - run from the same directory)") diff --git a/uv.lock b/uv.lock index e775a576a..132f40b5c 100644 --- a/uv.lock +++ b/uv.lock @@ -326,6 +326,7 @@ dependencies = [ { name = "pygls" }, { name = "pytest" }, { name = "pytest-timeout" }, + { name = "radon" }, { name = "rich" }, { name = "sentry-sdk" }, { name = "timeout-decorator" }, @@ -414,6 +415,7 @@ requires-dist = [ { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", marker = "extra == 'asyncio'", specifier = ">=1.2.0" }, { name = "pytest-timeout", specifier = ">=2.1.0" }, + { name = "radon" }, { name = "rich", specifier = ">=13.8.1" }, { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "timeout-decorator", specifier = ">=0.5.0" }, @@ -1504,6 +1506,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/c9/e0f8e4e6e8a69e5959b06499582dca6349db6769cc7fdfb8a02a7c75a9ae/lxml_stubs-0.5.1-py3-none-any.whl", hash = "sha256:1f689e5dbc4b9247cb09ae820c7d34daeb1fdbd1db06123814b856dae7787272", size = 13584, upload-time = "2024-01-10T09:37:44.931Z" }, ] +[[package]] +name = "mando" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/24/cd70d5ae6d35962be752feccb7dca80b5e0c2d450e995b16abd6275f3296/mando-0.7.1.tar.gz", hash = "sha256:18baa999b4b613faefb00eac4efadcf14f510b59b924b66e08289aa1de8c3500", size = 37868, upload-time = "2022-02-24T08:12:27.316Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/f0/834e479e47e499b6478e807fb57b31cc2db696c4db30557bb6f5aea4a90b/mando-0.7.1-py2.py3-none-any.whl", hash = "sha256:26ef1d70928b6057ee3ca12583d73c63e05c49de8972d620c278a7b206581a8a", size = 28149, upload-time = "2022-02-24T08:12:25.24Z" }, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -2784,6 +2798,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/28/26534bed77109632a956977f60d8519049f545abc39215d086e33a61f1f2/pyyaml_ft-8.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:de04cfe9439565e32f178106c51dd6ca61afaa2907d143835d501d84703d3793", size = 171579, upload-time = "2025-06-10T15:32:14.34Z" }, ] +[[package]] +name = "radon" +version = "6.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama" }, + { name = "mando" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/6d/98e61600febf6bd929cf04154537c39dc577ce414bafbfc24a286c4fa76d/radon-6.0.1.tar.gz", hash = "sha256:d1ac0053943a893878940fedc8b19ace70386fc9c9bf0a09229a44125ebf45b5", size = 1874992, upload-time = "2023-03-26T06:24:38.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/f7/d00d9b4a0313a6be3a3e0818e6375e15da6d7076f4ae47d1324e7ca986a1/radon-6.0.1-py2.py3-none-any.whl", hash = "sha256:632cc032364a6f8bb1010a2f6a12d0f14bc7e5ede76585ef29dc0cecf4cd8859", size = 52784, upload-time = "2023-03-26T06:24:33.949Z" }, +] + [[package]] name = "readchar" version = "4.2.1" diff --git a/verify_output_format.py b/verify_output_format.py deleted file mode 100644 index 87b2db669..000000000 --- a/verify_output_format.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Verify that both AST and LibCST implementations produce the exact requested output format.""" - -import json - -from function_call_finder import find_function_calls as find_calls_libcst -from function_call_finder_ast import find_function_calls as find_calls_ast - -# Simple test case -test_code = """ -def func1(): - my_target() - -def func2(): - my_target(1, 2, 3) -""" - -print("Verifying output format: {'calling_function_qualified_name1':'function_definition1',...}") -print("=" * 70) - -# Test AST implementation -print("\n1. AST Implementation:") -ast_result = find_calls_ast(test_code, "my_target", "/dummy/path.py") -print(f" Type: {type(ast_result)}") -print(f" Keys type: {type(list(ast_result.keys())[0]) if ast_result else 'N/A'}") -print(f" Values type: {type(list(ast_result.values())[0]) if ast_result else 'N/A'}") -print(f" JSON serializable: {json.dumps(ast_result) is not None}") -print(f" Example output: {json.dumps(ast_result, indent=2)}") - -# Test LibCST implementation -print("\n2. LibCST Implementation:") -libcst_result = find_calls_libcst(test_code, "my_target", "/dummy/path.py") -print(f" Type: {type(libcst_result)}") -print(f" Keys type: {type(list(libcst_result.keys())[0]) if libcst_result else 'N/A'}") -print(f" Values type: {type(list(libcst_result.values())[0]) if libcst_result else 'N/A'}") -print(f" JSON serializable: {json.dumps(libcst_result) is not None}") -print(f" Example output: {json.dumps(libcst_result, indent=2)}") - -# Test with class methods -print("\n3. Testing with class methods:") -class_test = """ -class MyClass: - def method1(self): - target() - - def method2(self): - pass -""" - -ast_class = find_calls_ast(class_test, "target", "/dummy/path.py") -libcst_class = find_calls_libcst(class_test, "target", "/dummy/path.py") - -print(f" AST result: {list(ast_class.keys())}") -print(f" LibCST result: {list(libcst_class.keys())}") - -# Final verification -print("\n" + "=" * 70) -print("✅ VERIFIED: Both implementations return the exact format requested:") -print(' {"calling_function_qualified_name1":"function_definition1",...}') -print("\nKey characteristics:") -print(" - Plain dictionary (dict type)") -print(" - String keys (qualified function names)") -print(" - String values (function source code)") -print(" - JSON serializable") -print(" - No nested structures, just simple key-value pairs") From fc5274d9fb7f55a918b53fe8d1ca8a5c917a1cbf Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 17:19:24 -0700 Subject: [PATCH 10/23] start cleaning up, write tests --- codeflash/code_utils/compat.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/codeflash/code_utils/compat.py b/codeflash/code_utils/compat.py index 66b203429..eb4e5b561 100644 --- a/codeflash/code_utils/compat.py +++ b/codeflash/code_utils/compat.py @@ -1,5 +1,4 @@ import os -import shutil import sys import tempfile from pathlib import Path @@ -18,7 +17,6 @@ class Compat: LF: str = os.linesep SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix() - SAFE_GREP_EXECUTABLE: str = shutil.which("grep") # works even grep is aliased in the env IS_POSIX: bool = os.name != "nt" @@ -47,4 +45,3 @@ def codeflash_cache_db(self) -> Path: LF = _compat.LF SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE IS_POSIX = _compat.IS_POSIX -SAFE_GREP_EXECUTABLE = _compat.SAFE_GREP_EXECUTABLE From 2570a2e9cd6d99dac2411f238e940e335f09502b Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 17:20:41 -0700 Subject: [PATCH 11/23] Apply suggestion from @aseembits93 --- codeflash/api/aiservice.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 6b478f148..e15333d75 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -577,13 +577,6 @@ def get_optimization_impact( ] ) code_diff = f"```diff\n{diff_str}\n```" - # TODO get complexity metrics and fn call heuristics -> constructing a complete static call graph can be expensive for really large repos - # grep function name in codebase -> ast parser to get no of calls and no of calls in loop -> radon lib to get complexity metrics -> send as additional context to the AI service - # metric 1 -> call count - how many times the function is called in the codebase - # metric 2 -> loop call count - how many times the function is called in a loop in the codebase - # metric 3 -> presence of decorators like @profile, @cache -> this means the owner of the repo cares about the performance of this function - # metric 4 -> cyclomatic complexity (https://en.wikipedia.org/wiki/Cyclomatic_complexity) - # metric 5 (for future) -> halstead complexity (https://en.wikipedia.org/wiki/Halstead_complexity_measures) logger.info("!lsp|Computing Optimization Impact…") payload = { "code_diff": code_diff, From ca313a1812a1017cfda0ff4f0c708a95cc8872f0 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 17:22:14 -0700 Subject: [PATCH 12/23] update uv.lock later --- uv.lock | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/uv.lock b/uv.lock index 132f40b5c..e775a576a 100644 --- a/uv.lock +++ b/uv.lock @@ -326,7 +326,6 @@ dependencies = [ { name = "pygls" }, { name = "pytest" }, { name = "pytest-timeout" }, - { name = "radon" }, { name = "rich" }, { name = "sentry-sdk" }, { name = "timeout-decorator" }, @@ -415,7 +414,6 @@ requires-dist = [ { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", marker = "extra == 'asyncio'", specifier = ">=1.2.0" }, { name = "pytest-timeout", specifier = ">=2.1.0" }, - { name = "radon" }, { name = "rich", specifier = ">=13.8.1" }, { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "timeout-decorator", specifier = ">=0.5.0" }, @@ -1506,18 +1504,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/c9/e0f8e4e6e8a69e5959b06499582dca6349db6769cc7fdfb8a02a7c75a9ae/lxml_stubs-0.5.1-py3-none-any.whl", hash = "sha256:1f689e5dbc4b9247cb09ae820c7d34daeb1fdbd1db06123814b856dae7787272", size = 13584, upload-time = "2024-01-10T09:37:44.931Z" }, ] -[[package]] -name = "mando" -version = "0.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/35/24/cd70d5ae6d35962be752feccb7dca80b5e0c2d450e995b16abd6275f3296/mando-0.7.1.tar.gz", hash = "sha256:18baa999b4b613faefb00eac4efadcf14f510b59b924b66e08289aa1de8c3500", size = 37868, upload-time = "2022-02-24T08:12:27.316Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/f0/834e479e47e499b6478e807fb57b31cc2db696c4db30557bb6f5aea4a90b/mando-0.7.1-py2.py3-none-any.whl", hash = "sha256:26ef1d70928b6057ee3ca12583d73c63e05c49de8972d620c278a7b206581a8a", size = 28149, upload-time = "2022-02-24T08:12:25.24Z" }, -] - [[package]] name = "markdown-it-py" version = "3.0.0" @@ -2798,19 +2784,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/28/26534bed77109632a956977f60d8519049f545abc39215d086e33a61f1f2/pyyaml_ft-8.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:de04cfe9439565e32f178106c51dd6ca61afaa2907d143835d501d84703d3793", size = 171579, upload-time = "2025-06-10T15:32:14.34Z" }, ] -[[package]] -name = "radon" -version = "6.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama" }, - { name = "mando" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/6d/98e61600febf6bd929cf04154537c39dc577ce414bafbfc24a286c4fa76d/radon-6.0.1.tar.gz", hash = "sha256:d1ac0053943a893878940fedc8b19ace70386fc9c9bf0a09229a44125ebf45b5", size = 1874992, upload-time = "2023-03-26T06:24:38.868Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/93/f7/d00d9b4a0313a6be3a3e0818e6375e15da6d7076f4ae47d1324e7ca986a1/radon-6.0.1-py2.py3-none-any.whl", hash = "sha256:632cc032364a6f8bb1010a2f6a12d0f14bc7e5ede76585ef29dc0cecf4cd8859", size = 52784, upload-time = "2023-03-26T06:24:33.949Z" }, -] - [[package]] name = "readchar" version = "4.2.1" From 33bcab6893030abf87f180247bd4a68badd50d2e Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 17:26:20 -0700 Subject: [PATCH 13/23] merge conflicts --- codeflash/api/aiservice.py | 7 +++++++ codeflash/optimization/function_optimizer.py | 22 +++++++------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index e15333d75..6b478f148 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -577,6 +577,13 @@ def get_optimization_impact( ] ) code_diff = f"```diff\n{diff_str}\n```" + # TODO get complexity metrics and fn call heuristics -> constructing a complete static call graph can be expensive for really large repos + # grep function name in codebase -> ast parser to get no of calls and no of calls in loop -> radon lib to get complexity metrics -> send as additional context to the AI service + # metric 1 -> call count - how many times the function is called in the codebase + # metric 2 -> loop call count - how many times the function is called in a loop in the codebase + # metric 3 -> presence of decorators like @profile, @cache -> this means the owner of the repo cares about the performance of this function + # metric 4 -> cyclomatic complexity (https://en.wikipedia.org/wiki/Cyclomatic_complexity) + # metric 5 (for future) -> halstead complexity (https://en.wikipedia.org/wiki/Halstead_complexity_measures) logger.info("!lsp|Computing Optimization Impact…") payload = { "code_diff": code_diff, diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d28f06e02..8aaf11ec5 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -24,7 +24,6 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import get_opt_impact_metrics from codeflash.code_utils.code_replacer import ( add_custom_marker_to_all_tests, modify_autouse_fixture, @@ -1462,19 +1461,14 @@ def process_review( if raise_pr or staging_review: data["root_dir"] = git_root_dir() - opt_impact_response = "" - try: - opt_impact_response = self.aiservice_client.get_optimization_impact(**data) - except Exception as e: - logger.debug(f"optimization impact response failed, investigate {e}") - data["optimization_impact"] = opt_impact_response - data["impact_metrics"] = get_opt_impact_metrics( - self.function_to_optimize_source_code, - self.function_to_optimize.file_path, - self.function_to_optimize.qualified_name, - self.project_root, - self.test_cfg.tests_root, - ) + # try: + # # modify argument of staging vs pr based on the impact + # opt_impact_response = self.aiservice_client.get_optimization_impact(**data) + # if opt_impact_response == "low": + # raise_pr = False + # staging_review = True + # except Exception as e: + # logger.debug(f"optimization impact response failed, investigate {e}") if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) From 9cd47438e941b1734e444fecd0b2934f5bdedcf9 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 17:32:13 -0700 Subject: [PATCH 14/23] Apply suggestion from @aseembits93 --- codeflash/code_utils/code_extractor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 8eb794332..22d236f9c 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -16,7 +16,6 @@ from libcst.helpers import calculate_module_and_package from radon.complexity import cc_visit -# from codeflash.benchmarking.pytest_new_process_trace_benchmarks import project_root from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_IMPACT, TIME_LIMIT_FOR_OPT_IMPACT from codeflash.models.models import CodePosition, FunctionParent, ImpactMetrics From 651ded45ace72b9aab838a81d06de7629d57aa51 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Thu, 16 Oct 2025 19:39:30 -0700 Subject: [PATCH 15/23] experiment --- codeflash/api/aiservice.py | 18 +++---- codeflash/api/cfapi.py | 4 ++ codeflash/code_utils/code_extractor.py | 52 ++++++-------------- codeflash/optimization/function_optimizer.py | 29 ++++++++++- codeflash/result/create_pr.py | 1 + uv.lock | 27 ++++++++++ 6 files changed, 81 insertions(+), 50 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 6b478f148..3b3076884 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -544,7 +544,8 @@ def get_optimization_impact( replay_tests: str, root_dir: Path, concolic_tests: str, # noqa: ARG002 - ) -> str: + calling_fn_details: str, + ) -> tuple[str, str]: """Compute the optimization impact of current Pull Request. Args: @@ -558,6 +559,7 @@ def get_optimization_impact( replay_tests: str -> replay test table root_dir: Path -> path of git directory concolic_tests: str -> concolic_tests (not used) + calling_fn_details: str -> filenames and definitions of functions which call the function_to_optimize Returns: ------- @@ -577,13 +579,6 @@ def get_optimization_impact( ] ) code_diff = f"```diff\n{diff_str}\n```" - # TODO get complexity metrics and fn call heuristics -> constructing a complete static call graph can be expensive for really large repos - # grep function name in codebase -> ast parser to get no of calls and no of calls in loop -> radon lib to get complexity metrics -> send as additional context to the AI service - # metric 1 -> call count - how many times the function is called in the codebase - # metric 2 -> loop call count - how many times the function is called in a loop in the codebase - # metric 3 -> presence of decorators like @profile, @cache -> this means the owner of the repo cares about the performance of this function - # metric 4 -> cyclomatic complexity (https://en.wikipedia.org/wiki/Cyclomatic_complexity) - # metric 5 (for future) -> halstead complexity (https://en.wikipedia.org/wiki/Halstead_complexity_measures) logger.info("!lsp|Computing Optimization Impact…") payload = { "code_diff": code_diff, @@ -598,6 +593,7 @@ def get_optimization_impact( "benchmark_details": explanation.benchmark_details if explanation.benchmark_details else None, "optimized_runtime": humanize_runtime(explanation.best_runtime_ns), "original_runtime": humanize_runtime(explanation.original_runtime_ns), + "calling_fn_details": calling_fn_details, } console.rule() try: @@ -605,10 +601,10 @@ def get_optimization_impact( except requests.exceptions.RequestException as e: logger.exception(f"Error generating optimization refinements: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) - return "" + return ("", str(e)) if response.status_code == 200: - return cast("str", response.json()["impact"]) + return (cast("str", response.json()["impact"]), cast("str", response.json()["impact_explanation"])) try: error = cast("str", response.json()["error"]) except Exception: @@ -616,7 +612,7 @@ def get_optimization_impact( logger.error(f"Error generating impact candidates: {response.status_code} - {error}") ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) console.rule() - return "" + return ("", error) class LocalAiServiceClient(AiServiceClient): diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index a0a5685b3..9117199a8 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -173,6 +173,7 @@ def create_pr( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", + optimization_impact: str = "", ) -> Response: """Create a pull request, targeting the specified branch. (usually 'main'). @@ -197,6 +198,7 @@ def create_pr( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, + "optimizationImpact": optimization_impact, } return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload) @@ -212,6 +214,7 @@ def create_staging( replay_tests: str, concolic_tests: str, root_dir: Path, + optimization_impact: str = "", ) -> Response: """Create a staging pull request, targeting the specified branch. (usually 'staging'). @@ -252,6 +255,7 @@ def create_staging( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, + "optimizationImpact": optimization_impact, } return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 22d236f9c..f8b91befb 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1,4 +1,3 @@ -# ruff: noqa: ARG002 from __future__ import annotations import ast @@ -10,15 +9,13 @@ import jedi import libcst as cst -import radon.visitors from libcst.codemod import CodemodContext from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package -from radon.complexity import cc_visit from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_IMPACT, TIME_LIMIT_FOR_OPT_IMPACT -from codeflash.models.models import CodePosition, FunctionParent, ImpactMetrics +from codeflash.models.models import CodePosition, FunctionParent if TYPE_CHECKING: from libcst.helpers import ModuleNameAndPackage @@ -38,28 +35,28 @@ def __init__(self) -> None: self.scope_depth = 0 self.if_else_depth = 0 - def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: + def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: ARG002 self.scope_depth += 1 return True - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth -= 1 def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: self.scope_depth += 1 return True - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 self.scope_depth -= 1 - def visit_If(self, node: cst.If) -> Optional[bool]: + def visit_If(self, node: cst.If) -> Optional[bool]: # noqa: ARG002 self.if_else_depth += 1 return True - def leave_If(self, original_node: cst.If) -> None: + def leave_If(self, original_node: cst.If) -> None: # noqa: ARG002 self.if_else_depth -= 1 - def visit_Else(self, node: cst.Else) -> Optional[bool]: + def visit_Else(self, node: cst.Else) -> Optional[bool]: # noqa: ARG002 # Else blocks are already counted as part of the if statement return True @@ -86,21 +83,21 @@ def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: self.scope_depth = 0 self.if_else_depth = 0 - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth += 1 def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: self.scope_depth -= 1 return updated_node - def visit_ClassDef(self, node: cst.ClassDef) -> None: + def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002 self.scope_depth += 1 - def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002 self.scope_depth -= 1 return updated_node - def visit_If(self, node: cst.If) -> None: + def visit_If(self, node: cst.If) -> None: # noqa: ARG002 self.if_else_depth += 1 def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: @@ -1156,8 +1153,7 @@ def get_fn_references_jedi( def get_opt_impact_metrics( source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path -) -> ImpactMetrics: - metrics = ImpactMetrics() +) -> str: try: qualified_name_split = qualified_name.rsplit(".", maxsplit=1) if len(qualified_name_split) == 1: @@ -1167,26 +1163,8 @@ def get_opt_impact_metrics( matches = get_fn_references_jedi( source_code, file_path, project_root, target_function, target_class ) # jedi is not perfect, it doesn't capture aliased references - cyclomatic_complexity_results = cc_visit(source_code) - match_found = False - for result in cyclomatic_complexity_results: - if match_found: - break - if isinstance(result, radon.visitors.Function) and not target_class: - if result.name == target_function: - metrics.cyclomatic_complexity = result.complexity - metrics.cyclomatic_complexity_rating = result.letter - match_found = True - elif isinstance(result, radon.visitors.Class) and target_class: # noqa: SIM102 - if result.name == target_class: - for method in result.methods: - if match_found: - break - if method.name == target_function: - metrics.cyclomatic_complexity = method.complexity - metrics.cyclomatic_complexity_rating = method.letter - match_found = True - metrics.calling_fns = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root) + calling_fns_details = find_occurances(qualified_name, str(file_path), matches, project_root, tests_root) except Exception as e: + calling_fns_details = "" logger.debug(f"Investigate {e}") - return metrics + return calling_fns_details diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e54aac92d..72f3ad868 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -24,6 +24,7 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils +from codeflash.code_utils.code_extractor import get_opt_impact_metrics from codeflash.code_utils.code_replacer import ( add_custom_marker_to_all_tests, modify_autouse_fixture, @@ -1461,12 +1462,36 @@ def process_review( if raise_pr or staging_review: data["root_dir"] = git_root_dir() + calling_fn_details = get_opt_impact_metrics( + self.function_to_optimize_source_code, + self.function_to_optimize.file_path, + self.function_to_optimize.qualified_name, + self.project_root, + self.test_cfg.tests_root, + ) opt_impact_response = "" try: - opt_impact_response = self.aiservice_client.get_optimization_impact(**data) + opt_impact_response = self.aiservice_client.get_optimization_impact( + **data, calling_fn_details=calling_fn_details + ) except Exception as e: logger.debug(f"optimization impact response failed, investigate {e}") - data["optimization_impact"] = opt_impact_response + data["optimization_impact"] = opt_impact_response[0] + new_explanation_with_opt_explanation = Explanation( + raw_explanation_message=f"Impact: {opt_impact_response[0]}\n Impact_explanation: {opt_impact_response[1]} END OF IMPACT EXPLANATION\n" + + new_explanation.raw_explanation_message, + winning_behavior_test_results=explanation.winning_behavior_test_results, + winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, + original_runtime_ns=explanation.original_runtime_ns, + best_runtime_ns=explanation.best_runtime_ns, + function_name=explanation.function_name, + file_path=explanation.file_path, + benchmark_details=explanation.benchmark_details, + original_async_throughput=explanation.original_async_throughput, + best_async_throughput=explanation.best_async_throughput, + ) + best_optimization.explanation_v2 = new_explanation_with_opt_explanation.explanation_message() + data["explanation"] = new_explanation_with_opt_explanation if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index f9fbf84d7..3f1ffa200 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -277,6 +277,7 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, + optimization_impact=optimization_impact, ) if response.ok: pr_id = response.text diff --git a/uv.lock b/uv.lock index e775a576a..132f40b5c 100644 --- a/uv.lock +++ b/uv.lock @@ -326,6 +326,7 @@ dependencies = [ { name = "pygls" }, { name = "pytest" }, { name = "pytest-timeout" }, + { name = "radon" }, { name = "rich" }, { name = "sentry-sdk" }, { name = "timeout-decorator" }, @@ -414,6 +415,7 @@ requires-dist = [ { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", marker = "extra == 'asyncio'", specifier = ">=1.2.0" }, { name = "pytest-timeout", specifier = ">=2.1.0" }, + { name = "radon" }, { name = "rich", specifier = ">=13.8.1" }, { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "timeout-decorator", specifier = ">=0.5.0" }, @@ -1504,6 +1506,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/c9/e0f8e4e6e8a69e5959b06499582dca6349db6769cc7fdfb8a02a7c75a9ae/lxml_stubs-0.5.1-py3-none-any.whl", hash = "sha256:1f689e5dbc4b9247cb09ae820c7d34daeb1fdbd1db06123814b856dae7787272", size = 13584, upload-time = "2024-01-10T09:37:44.931Z" }, ] +[[package]] +name = "mando" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/35/24/cd70d5ae6d35962be752feccb7dca80b5e0c2d450e995b16abd6275f3296/mando-0.7.1.tar.gz", hash = "sha256:18baa999b4b613faefb00eac4efadcf14f510b59b924b66e08289aa1de8c3500", size = 37868, upload-time = "2022-02-24T08:12:27.316Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/f0/834e479e47e499b6478e807fb57b31cc2db696c4db30557bb6f5aea4a90b/mando-0.7.1-py2.py3-none-any.whl", hash = "sha256:26ef1d70928b6057ee3ca12583d73c63e05c49de8972d620c278a7b206581a8a", size = 28149, upload-time = "2022-02-24T08:12:25.24Z" }, +] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -2784,6 +2798,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/28/26534bed77109632a956977f60d8519049f545abc39215d086e33a61f1f2/pyyaml_ft-8.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:de04cfe9439565e32f178106c51dd6ca61afaa2907d143835d501d84703d3793", size = 171579, upload-time = "2025-06-10T15:32:14.34Z" }, ] +[[package]] +name = "radon" +version = "6.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama" }, + { name = "mando" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/6d/98e61600febf6bd929cf04154537c39dc577ce414bafbfc24a286c4fa76d/radon-6.0.1.tar.gz", hash = "sha256:d1ac0053943a893878940fedc8b19ace70386fc9c9bf0a09229a44125ebf45b5", size = 1874992, upload-time = "2023-03-26T06:24:38.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/f7/d00d9b4a0313a6be3a3e0818e6375e15da6d7076f4ae47d1324e7ca986a1/radon-6.0.1-py2.py3-none-any.whl", hash = "sha256:632cc032364a6f8bb1010a2f6a12d0f14bc7e5ede76585ef29dc0cecf4cd8859", size = 52784, upload-time = "2023-03-26T06:24:33.949Z" }, +] + [[package]] name = "readchar" version = "4.2.1" From 65bb2d25f314db59a3f75c51e9bbe1e0708ae9fe Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Fri, 17 Oct 2025 16:09:33 -0700 Subject: [PATCH 16/23] Enhance explanation message with calling context details --- codeflash/optimization/function_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 755fa0b7e..bbf4a157e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1477,7 +1477,7 @@ def process_review( logger.debug(f"optimization impact response failed, investigate {e}") data["optimization_impact"] = opt_impact_response[0] new_explanation_with_opt_explanation = Explanation( - raw_explanation_message=f"Impact: {opt_impact_response[0]}\n Impact_explanation: {opt_impact_response[1]} END OF IMPACT EXPLANATION\n" + raw_explanation_message=f"Impact: {opt_impact_response[0]}\n Impact_explanation: {opt_impact_response[1]} END OF IMPACT EXPLANATION\nCALLING CONTEXT \n{calling_fn_details}\nEND OF CALLING CONTEXT\n" + new_explanation.raw_explanation_message, winning_behavior_test_results=explanation.winning_behavior_test_results, winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, From 26da88f3968a8a33d97c27298209a1671954696e Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 20 Oct 2025 12:22:14 -0700 Subject: [PATCH 17/23] fix precommit later --- codeflash/api/aiservice.py | 14 ++++---- codeflash/api/cfapi.py | 12 +++---- codeflash/code_utils/code_extractor.py | 36 +++++++++----------- codeflash/code_utils/config_consts.py | 3 +- codeflash/models/models.py | 7 ---- codeflash/optimization/function_optimizer.py | 27 ++++----------- codeflash/result/create_pr.py | 6 ++-- pyproject.toml | 1 - uv.lock | 27 --------------- 9 files changed, 39 insertions(+), 94 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 3b3076884..3dd9e457e 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -532,7 +532,7 @@ def generate_regression_tests( # noqa: D417 ph("cli-testgen-error-response", {"response_status_code": response.status_code, "error": response.text}) return None - def get_optimization_impact( + def get_optimization_review( self, original_code: dict[Path, str], new_code: dict[Path, str], @@ -546,7 +546,7 @@ def get_optimization_impact( concolic_tests: str, # noqa: ARG002 calling_fn_details: str, ) -> tuple[str, str]: - """Compute the optimization impact of current Pull Request. + """Compute the optimization review of current Pull Request. Args: original_code: dict -> data structure mapping file paths to function definition for original code @@ -563,7 +563,7 @@ def get_optimization_impact( Returns: ------- - - 'high' or 'low' optimization impact + - 'high', 'medium' or 'low' optimization review """ diff_str = "\n".join( @@ -579,7 +579,7 @@ def get_optimization_impact( ] ) code_diff = f"```diff\n{diff_str}\n```" - logger.info("!lsp|Computing Optimization Impact…") + logger.info("!lsp|Computing Optimization Review…") payload = { "code_diff": code_diff, "explanation": explanation.raw_explanation_message, @@ -597,19 +597,19 @@ def get_optimization_impact( } console.rule() try: - response = self.make_ai_service_request("/optimization_impact", payload=payload, timeout=600) + response = self.make_ai_service_request("/optimization_review", payload=payload, timeout=600) except requests.exceptions.RequestException as e: logger.exception(f"Error generating optimization refinements: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) return ("", str(e)) if response.status_code == 200: - return (cast("str", response.json()["impact"]), cast("str", response.json()["impact_explanation"])) + return (cast("str", response.json()["review"]), cast("str", response.json()["review_explanation"])) try: error = cast("str", response.json()["error"]) except Exception: error = response.text - logger.error(f"Error generating impact candidates: {response.status_code} - {error}") + logger.error(f"Error generating optimization review: {response.status_code} - {error}") ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) console.rule() return ("", error) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 9117199a8..d410f75dd 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -130,7 +130,7 @@ def suggest_changes( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", - optimization_impact: str = "", + optimization_review: str = "", ) -> Response: """Suggest changes to a pull request. @@ -156,7 +156,7 @@ def suggest_changes( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, - "optimizationImpact": optimization_impact, + "optimizationImpact": optimization_review, # impact keyword left for legacy reasons, touches js/ts code } return make_cfapi_request(endpoint="/suggest-pr-changes", method="POST", payload=payload) @@ -173,7 +173,7 @@ def create_pr( coverage_message: str, replay_tests: str = "", concolic_tests: str = "", - optimization_impact: str = "", + optimization_review: str = "", ) -> Response: """Create a pull request, targeting the specified branch. (usually 'main'). @@ -198,7 +198,7 @@ def create_pr( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, - "optimizationImpact": optimization_impact, + "optimizationImpact": optimization_review, # Impact keyword left for legacy reasons, it touches js/ts codebase } return make_cfapi_request(endpoint="/create-pr", method="POST", payload=payload) @@ -214,7 +214,7 @@ def create_staging( replay_tests: str, concolic_tests: str, root_dir: Path, - optimization_impact: str = "", + optimization_review: str = "", ) -> Response: """Create a staging pull request, targeting the specified branch. (usually 'staging'). @@ -255,7 +255,7 @@ def create_staging( "coverage_message": coverage_message, "replayTests": replay_tests, "concolicTests": concolic_tests, - "optimizationImpact": optimization_impact, + "optimizationImpact": optimization_review, # Impact keyword left for legacy reasons, it touches js/ts codebase } return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index f8b91befb..a1e7b498e 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -1,7 +1,6 @@ from __future__ import annotations import ast -import time from dataclasses import dataclass from itertools import chain from pathlib import Path @@ -14,7 +13,7 @@ from libcst.helpers import calculate_module_and_package from codeflash.cli_cmds.console import logger -from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_IMPACT, TIME_LIMIT_FOR_OPT_IMPACT +from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW from codeflash.models.models import CodePosition, FunctionParent if TYPE_CHECKING: @@ -42,7 +41,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]: # noqa: A def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth -= 1 - def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: + def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]: # noqa: ARG002 self.scope_depth += 1 return True @@ -86,7 +85,7 @@ def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.scope_depth += 1 - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002 self.scope_depth -= 1 return updated_node @@ -100,7 +99,7 @@ def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef def visit_If(self, node: cst.If) -> None: # noqa: ARG002 self.if_else_depth += 1 - def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: + def leave_If(self, original_node: cst.If, updated_node: cst.If) -> cst.If: # noqa: ARG002 self.if_else_depth -= 1 return updated_node @@ -148,7 +147,7 @@ def _find_insertion_index(self, updated_node: cst.Module) -> int: return insert_index - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 # Add any new assignments that weren't in the original file new_statements = list(updated_node.body) @@ -192,20 +191,20 @@ def __init__(self) -> None: self.global_statements = [] self.in_function_or_class = False - def visit_ClassDef(self, node: cst.ClassDef) -> bool: + def visit_ClassDef(self, node: cst.ClassDef) -> bool: # noqa: ARG002 # Don't visit inside classes self.in_function_or_class = True return False - def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 self.in_function_or_class = False - def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # noqa: ARG002 # Don't visit inside functions self.in_function_or_class = True return False - def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: # noqa: ARG002 self.in_function_or_class = False def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None: @@ -286,16 +285,16 @@ def visit_Module(self, node: cst.Module) -> None: self.depth = 0 self._collect_imports_from_block(node) - def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.depth += 1 - def leave_FunctionDef(self, node: cst.FunctionDef) -> None: + def leave_FunctionDef(self, node: cst.FunctionDef) -> None: # noqa: ARG002 self.depth -= 1 - def visit_ClassDef(self, node: cst.ClassDef) -> None: + def visit_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002 self.depth += 1 - def leave_ClassDef(self, node: cst.ClassDef) -> None: + def leave_ClassDef(self, node: cst.ClassDef) -> None: # noqa: ARG002 self.depth -= 1 def visit_If(self, node: cst.If) -> None: @@ -329,7 +328,7 @@ def leave_SimpleStatementLine( return cst.Module(body=[updated_node]) - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 # If there were no imports, add at the beginning of the module if self.last_import_line == 0 and not self.inserted: updated_body = list(updated_node.body) @@ -1058,13 +1057,10 @@ def find_function_calls(source_code: str, target_function_name: str, target_file def find_occurances( qualified_name: str, file_path: str, fn_matches: list[Path], project_root: Path, tests_root: Path ) -> list[str]: # max chars for context - start_time = time.time() context_len = 0 fn_call_context = "" for cur_file in fn_matches: - if time.time() - start_time > TIME_LIMIT_FOR_OPT_IMPACT: - break - if context_len > MAX_CONTEXT_LEN_IMPACT: + if context_len > MAX_CONTEXT_LEN_REVIEW: break cur_file_path = Path(cur_file) # exclude references in tests @@ -1151,7 +1147,7 @@ def get_fn_references_jedi( return [] -def get_opt_impact_metrics( +def get_opt_review_metrics( source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path ) -> str: try: diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py index ca39c0238..6b2805fbf 100644 --- a/codeflash/code_utils/config_consts.py +++ b/codeflash/code_utils/config_consts.py @@ -35,5 +35,4 @@ N_TESTS_TO_GENERATE_EFFECTIVE = N_TESTS_TO_GENERATE_LSP if _IS_LSP_ENABLED else N_TESTS_TO_GENERATE TOTAL_LOOPING_TIME_EFFECTIVE = TOTAL_LOOPING_TIME_LSP if _IS_LSP_ENABLED else TOTAL_LOOPING_TIME -MAX_CONTEXT_LEN_IMPACT = 1000 -TIME_LIMIT_FOR_OPT_IMPACT = 10 # in sec +MAX_CONTEXT_LEN_REVIEW = 1000 diff --git a/codeflash/models/models.py b/codeflash/models/models.py index f3293cc83..84179054e 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -31,13 +31,6 @@ from codeflash.verification.comparator import comparator -@dataclass -class ImpactMetrics: - cyclomatic_complexity: Optional[int] = None - cyclomatic_complexity_rating: Optional[str] = None - calling_fns: Optional[str] = None - - @dataclass(frozen=True) class AIServiceRefinerRequest: optimization_id: str diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 48e98ff0f..805e922cb 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -23,7 +23,7 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_extractor import get_opt_impact_metrics +from codeflash.code_utils.code_extractor import get_opt_review_metrics from codeflash.code_utils.code_replacer import ( add_custom_marker_to_all_tests, modify_autouse_fixture, @@ -1461,36 +1461,21 @@ def process_review( if raise_pr or staging_review: data["root_dir"] = git_root_dir() - calling_fn_details = get_opt_impact_metrics( + calling_fn_details = get_opt_review_metrics( self.function_to_optimize_source_code, self.function_to_optimize.file_path, self.function_to_optimize.qualified_name, self.project_root, self.test_cfg.tests_root, ) - opt_impact_response = "" + opt_review_response = "" try: - opt_impact_response = self.aiservice_client.get_optimization_impact( + opt_review_response = self.aiservice_client.get_optimization_review( **data, calling_fn_details=calling_fn_details ) except Exception as e: - logger.debug(f"optimization impact response failed, investigate {e}") - data["optimization_impact"] = opt_impact_response[0] - new_explanation_with_opt_explanation = Explanation( - raw_explanation_message=f"Impact: {opt_impact_response[0]}\n Impact_explanation: {opt_impact_response[1]} END OF IMPACT EXPLANATION\nCALLING CONTEXT \n{calling_fn_details}\nEND OF CALLING CONTEXT\n" - + new_explanation.raw_explanation_message, - winning_behavior_test_results=explanation.winning_behavior_test_results, - winning_benchmarking_test_results=explanation.winning_benchmarking_test_results, - original_runtime_ns=explanation.original_runtime_ns, - best_runtime_ns=explanation.best_runtime_ns, - function_name=explanation.function_name, - file_path=explanation.file_path, - benchmark_details=explanation.benchmark_details, - original_async_throughput=explanation.original_async_throughput, - best_async_throughput=explanation.best_async_throughput, - ) - best_optimization.explanation_v2 = new_explanation_with_opt_explanation.explanation_message() - data["explanation"] = new_explanation_with_opt_explanation + logger.debug(f"optimization review response failed, investigate {e}") + data["optimization_review_response"] = opt_review_response[0] if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 3f1ffa200..55f3713fd 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -185,7 +185,7 @@ def check_create_pr( concolic_tests: str, root_dir: Path, git_remote: Optional[str] = None, - optimization_impact: str = "", + optimization_review: str = "", ) -> None: pr_number: Optional[int] = env_utils.get_pr_number() git_repo = git.Repo(search_parent_directories=True) @@ -227,7 +227,7 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, - optimization_impact=optimization_impact, + optimization_review=optimization_review, ) if response.ok: logger.info(f"Suggestions were successfully made to PR #{pr_number}") @@ -277,7 +277,7 @@ def check_create_pr( coverage_message=coverage_message, replay_tests=replay_tests, concolic_tests=concolic_tests, - optimization_impact=optimization_impact, + optimization_review=optimization_review, ) if response.ok: pr_id = response.text diff --git a/pyproject.toml b/pyproject.toml index dfaa3551c..ce6d5c1db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,6 @@ dependencies = [ "pygls>=1.3.1", "codeflash-benchmark", "filelock", - "radon", #for code complexity metrics ] [project.urls] diff --git a/uv.lock b/uv.lock index 132f40b5c..e775a576a 100644 --- a/uv.lock +++ b/uv.lock @@ -326,7 +326,6 @@ dependencies = [ { name = "pygls" }, { name = "pytest" }, { name = "pytest-timeout" }, - { name = "radon" }, { name = "rich" }, { name = "sentry-sdk" }, { name = "timeout-decorator" }, @@ -415,7 +414,6 @@ requires-dist = [ { name = "pytest", specifier = ">=7.0.0" }, { name = "pytest-asyncio", marker = "extra == 'asyncio'", specifier = ">=1.2.0" }, { name = "pytest-timeout", specifier = ">=2.1.0" }, - { name = "radon" }, { name = "rich", specifier = ">=13.8.1" }, { name = "sentry-sdk", specifier = ">=1.40.6,<3.0.0" }, { name = "timeout-decorator", specifier = ">=0.5.0" }, @@ -1506,18 +1504,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/c9/e0f8e4e6e8a69e5959b06499582dca6349db6769cc7fdfb8a02a7c75a9ae/lxml_stubs-0.5.1-py3-none-any.whl", hash = "sha256:1f689e5dbc4b9247cb09ae820c7d34daeb1fdbd1db06123814b856dae7787272", size = 13584, upload-time = "2024-01-10T09:37:44.931Z" }, ] -[[package]] -name = "mando" -version = "0.7.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/35/24/cd70d5ae6d35962be752feccb7dca80b5e0c2d450e995b16abd6275f3296/mando-0.7.1.tar.gz", hash = "sha256:18baa999b4b613faefb00eac4efadcf14f510b59b924b66e08289aa1de8c3500", size = 37868, upload-time = "2022-02-24T08:12:27.316Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/f0/834e479e47e499b6478e807fb57b31cc2db696c4db30557bb6f5aea4a90b/mando-0.7.1-py2.py3-none-any.whl", hash = "sha256:26ef1d70928b6057ee3ca12583d73c63e05c49de8972d620c278a7b206581a8a", size = 28149, upload-time = "2022-02-24T08:12:25.24Z" }, -] - [[package]] name = "markdown-it-py" version = "3.0.0" @@ -2798,19 +2784,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/28/26534bed77109632a956977f60d8519049f545abc39215d086e33a61f1f2/pyyaml_ft-8.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:de04cfe9439565e32f178106c51dd6ca61afaa2907d143835d501d84703d3793", size = 171579, upload-time = "2025-06-10T15:32:14.34Z" }, ] -[[package]] -name = "radon" -version = "6.0.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "colorama" }, - { name = "mando" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/b1/6d/98e61600febf6bd929cf04154537c39dc577ce414bafbfc24a286c4fa76d/radon-6.0.1.tar.gz", hash = "sha256:d1ac0053943a893878940fedc8b19ace70386fc9c9bf0a09229a44125ebf45b5", size = 1874992, upload-time = "2023-03-26T06:24:38.868Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/93/f7/d00d9b4a0313a6be3a3e0818e6375e15da6d7076f4ae47d1324e7ca986a1/radon-6.0.1-py2.py3-none-any.whl", hash = "sha256:632cc032364a6f8bb1010a2f6a12d0f14bc7e5ede76585ef29dc0cecf4cd8859", size = 52784, upload-time = "2023-03-26T06:24:33.949Z" }, -] - [[package]] name = "readchar" version = "4.2.1" From 2c6abd34f4ee796ebcac00da6439345b1775ac3e Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 20 Oct 2025 12:26:41 -0700 Subject: [PATCH 18/23] precommit fix --- codeflash/code_utils/code_extractor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index a1e7b498e..ffe600c33 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -317,7 +317,9 @@ def __init__(self, global_statements: list[cst.SimpleStatementLine], last_import self.inserted = False def leave_SimpleStatementLine( - self, original_node: cst.SimpleStatementLine, updated_node: cst.SimpleStatementLine + self, + original_node: cst.SimpleStatementLine, # noqa: ARG002 + updated_node: cst.SimpleStatementLine, ) -> cst.Module: self.current_line += 1 @@ -356,7 +358,9 @@ def find_last_import_line(target_code: str) -> int: class FutureAliasedImportTransformer(cst.CSTTransformer): def leave_ImportFrom( - self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + self, + original_node: cst.ImportFrom, # noqa: ARG002 + updated_node: cst.ImportFrom, ) -> cst.BaseSmallStatement | cst.FlattenSentinel[cst.BaseSmallStatement] | cst.RemovalSentinel: import libcst.matchers as m From b5b85be6b0e673e81ba058228c0d9ef143570472 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 20 Oct 2025 12:46:29 -0700 Subject: [PATCH 19/23] Apply suggestion from @aseembits93 --- codeflash/code_utils/code_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index ffe600c33..0ee4c0b89 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -240,7 +240,7 @@ class DottedImportCollector(cst.CSTVisitor): import dbt.adapters.factory ==> "dbt.adapters.factory" from pathlib import Path ==> "pathlib.Path" from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter" - from typing import Any, list, Optional ==> "typing.Any", "typing.list", "typing.Optional" + from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional" from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps" """ From dd4577c1eefdee1dcd751b8b4e3b303a39d54298 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 20 Oct 2025 12:46:39 -0700 Subject: [PATCH 20/23] Apply suggestion from @aseembits93 --- codeflash/optimization/function_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 805e922cb..aa3d699ee 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1475,7 +1475,7 @@ def process_review( ) except Exception as e: logger.debug(f"optimization review response failed, investigate {e}") - data["optimization_review_response"] = opt_review_response[0] + data["optimization_review"] = opt_review_response[0] if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data) From c37505d6f6fa546dca62b304619fb227a5105df8 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 20 Oct 2025 12:46:55 -0700 Subject: [PATCH 21/23] Apply suggestion from @aseembits93 --- codeflash/code_utils/code_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/code_utils/code_extractor.py b/codeflash/code_utils/code_extractor.py index 0ee4c0b89..0a515a080 100644 --- a/codeflash/code_utils/code_extractor.py +++ b/codeflash/code_utils/code_extractor.py @@ -449,7 +449,7 @@ def resolve_star_import(module_name: str, project_root: Path) -> set[str]: and isinstance(node.targets[0], ast.Name) and node.targets[0].id == "__all__" ): - if isinstance(node.value, (ast.list, ast.tuple)): + if isinstance(node.value, (ast.List, ast.Tuple)): all_names = [] for elt in node.value.elts: if isinstance(elt, ast.Constant) and isinstance(elt.value, str): From bdd23cdd9f498ee146ec804703ce5bdcf9851169 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 20 Oct 2025 13:22:46 -0700 Subject: [PATCH 22/23] change function signature --- codeflash/api/aiservice.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 3dd9e457e..f1928a9ac 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -545,7 +545,7 @@ def get_optimization_review( root_dir: Path, concolic_tests: str, # noqa: ARG002 calling_fn_details: str, - ) -> tuple[str, str]: + ) -> str: """Compute the optimization review of current Pull Request. Args: @@ -601,10 +601,10 @@ def get_optimization_review( except requests.exceptions.RequestException as e: logger.exception(f"Error generating optimization refinements: {e}") ph("cli-optimize-error-caught", {"error": str(e)}) - return ("", str(e)) + return "" if response.status_code == 200: - return (cast("str", response.json()["review"]), cast("str", response.json()["review_explanation"])) + return cast("str", response.json()["review"]) try: error = cast("str", response.json()["error"]) except Exception: @@ -612,7 +612,7 @@ def get_optimization_review( logger.error(f"Error generating optimization review: {response.status_code} - {error}") ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) console.rule() - return ("", error) + return "" class LocalAiServiceClient(AiServiceClient): From 0d43f6dfb8173dfb7d70ba68f3b88dbbcf234e8b Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Mon, 20 Oct 2025 13:41:18 -0700 Subject: [PATCH 23/23] minor bug fix --- codeflash/optimization/function_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index aa3d699ee..eb61a689b 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1475,7 +1475,7 @@ def process_review( ) except Exception as e: logger.debug(f"optimization review response failed, investigate {e}") - data["optimization_review"] = opt_review_response[0] + data["optimization_review"] = opt_review_response if raise_pr and not staging_review: data["git_remote"] = self.args.git_remote check_create_pr(**data)