Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 45 additions & 49 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
from _ast import AsyncFunctionDef, ClassDef, FunctionDef
from collections import defaultdict
from functools import cache
from itertools import islice
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

import git
import libcst as cst
from pydantic.dataclasses import dataclass

from codeflash.api.cfapi import get_blocklisted_functions, make_cfapi_request, is_function_being_optimized_again
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
Expand Down Expand Up @@ -153,38 +154,37 @@ def get_code_context_hash(self) -> str:
to uniquely identify the function for optimization tracking.
"""
try:
with open(self.file_path, 'r', encoding='utf-8') as f:
file_content = f.read()

# Extract the function's code content
lines = file_content.splitlines()
# Read only the necessary lines if possible, otherwise fallback to full file.
if self.starting_line is not None and self.ending_line is not None:
# Use line numbers if available (1-indexed to 0-indexed)
function_content = '\n'.join(lines[self.starting_line - 1:self.ending_line])
# Efficiently read only relevant function lines
start = self.starting_line - 1 # convert to 0-indexed
end = self.ending_line # exclusive
with open(self.file_path, encoding="utf-8") as f:
function_lines = list(islice(f, start, end))
function_content = "".join(function_lines).strip()
else:
# Fallback: use the entire file content if line numbers aren't available
function_content = file_content
with open(self.file_path, encoding="utf-8") as f:
function_content = f.read().strip()

# Create a context string that includes:
# - File path (relative to make it portable)
# - Qualified function name
# - Function code content
# Create a context string that includes filename (for portability),
# qualified function name, and function code content.
context_parts = [
str(self.file_path.name), # Just filename for portability
self.qualified_name,
function_content.strip()
function_content,
]

context_string = '\n---\n'.join(context_parts)
context_string = "\n---\n".join(context_parts)

# Generate SHA-256 hash
return hashlib.sha256(context_string.encode('utf-8')).hexdigest()
return hashlib.sha256(context_string.encode("utf-8")).hexdigest()

except (OSError, IOError) as e:
except OSError as e:
logger.warning(f"Could not read file {self.file_path} for hashing: {e}")
# Fallback hash using available metadata
fallback_string = f"{self.file_path.name}:{self.qualified_name}"
return hashlib.sha256(fallback_string.encode('utf-8')).hexdigest()
return hashlib.sha256(fallback_string.encode("utf-8")).hexdigest()


def get_functions_to_optimize(
optimize_all: str | None,
Expand Down Expand Up @@ -228,7 +228,7 @@ def get_functions_to_optimize(
found_function = None
for fn in functions.get(file, []):
if only_function_name == fn.function_name and (
class_name is None or class_name == fn.top_level_parent_name
class_name is None or class_name == fn.top_level_parent_name
):
found_function = fn
if found_function is None:
Expand Down Expand Up @@ -307,7 +307,7 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt


def get_all_replay_test_functions(
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
replay_test: Path, test_cfg: TestConfig, project_root_path: Path
) -> dict[Path, list[FunctionToOptimize]]:
function_tests = discover_unit_tests(test_cfg, discover_only_these_tests=[replay_test])
# Get the absolute file paths for each function, excluding class name if present
Expand All @@ -322,7 +322,7 @@ def get_all_replay_test_functions(
class_name = (
module_path_parts[-1]
if module_path_parts
and is_class_defined_in_file(
and is_class_defined_in_file(
module_path_parts[-1], Path(project_root_path, *module_path_parts[:-1]).with_suffix(".py")
)
else None
Expand Down Expand Up @@ -374,8 +374,7 @@ def ignored_submodule_paths(module_root: str) -> list[str]:

class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
def __init__(
self, file_name: Path, function_or_method_name: str, class_name: str | None = None,
line_no: int | None = None
self, file_name: Path, function_or_method_name: str, class_name: str | None = None, line_no: int | None = None
) -> None:
self.file_name = file_name
self.class_name = class_name
Expand Down Expand Up @@ -406,13 +405,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
self.is_top_level = True
if any(
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
for decorator in body_node.decorator_list
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
for decorator in body_node.decorator_list
):
self.is_classmethod = True
elif any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
):
self.is_staticmethod = True
return
Expand All @@ -421,13 +420,13 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
)
isinstance(body_node, ast.FunctionDef)
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
)
):
self.is_staticmethod = True
self.is_top_level = True
Expand Down Expand Up @@ -460,10 +459,7 @@ def inspect_top_level_functions_or_methods(


def check_optimization_status(
functions_by_file: dict[Path, list[FunctionToOptimize]],
owner: str,
repo: str,
pr_number: int
functions_by_file: dict[Path, list[FunctionToOptimize]], owner: str, repo: str, pr_number: int
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
"""Check which functions have already been optimized and filter them out.

Expand All @@ -480,6 +476,7 @@ def check_optimization_status(

Returns:
Tuple of (filtered_functions_dict, remaining_count)

"""
# Build the code_contexts dictionary for the API call
code_contexts = {}
Expand All @@ -500,7 +497,6 @@ def check_optimization_status(
result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
already_optimized_paths = set(result.get("already_optimized_paths", []))


# Filter out already optimized functions
filtered_functions = defaultdict(list)
remaining_count = 0
Expand Down Expand Up @@ -556,12 +552,12 @@ def filter_functions(
test_functions_removed_count += len(_functions)
continue
if file_path in ignore_paths or any(
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
):
ignore_paths_removed_count += 1
continue
if file_path in submodule_paths or any(
file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths
file_path.startswith(str(submodule_path) + os.sep) for submodule_path in submodule_paths
):
submodule_ignored_paths_count += 1
continue
Expand All @@ -579,12 +575,14 @@ def filter_functions(
if blocklist_funcs:
functions_tmp = []
for function in _functions:
if not (
if (
function.file_path.name in blocklist_funcs
and function.qualified_name in blocklist_funcs[function.file_path.name]
):
# This function is in blocklist, we can skip it
blocklist_funcs_removed_count += 1
continue
# This function is NOT in blocklist. we can keep it
functions_tmp.append(function)
_functions = functions_tmp

Expand All @@ -609,9 +607,7 @@ def filter_functions(
owner, repo = get_repo_owner_and_name(repository)
pr_number = get_pr_number()
if owner and repo and pr_number is not None:
path_based_functions, functions_count = check_optimization_status(
path_based_functions, owner, repo, pr_number
)
path_based_functions, functions_count = check_optimization_status(path_based_functions, owner, repo, pr_number)
initial_count = sum(len(funcs) for funcs in filtered_modified_functions.values())
already_optimized_count = initial_count - functions_count

Expand Down Expand Up @@ -652,8 +648,8 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
if submodule_paths is None:
submodule_paths = ignored_submodule_paths(module_root)
return not (
file_path in submodule_paths
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
file_path in submodule_paths
or any(file_path.is_relative_to(submodule_path) for submodule_path in submodule_paths)
)


Expand All @@ -662,4 +658,4 @@ def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef)


def function_is_a_property(function_node: FunctionDef | AsyncFunctionDef) -> bool:
return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list)
return any(isinstance(node, ast.Name) and node.id == "property" for node in function_node.decorator_list)
Loading