Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
621 changes: 360 additions & 261 deletions codeflash/code_utils/instrument_existing_tests.py

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions codeflash/code_utils/static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,19 @@ def get_first_top_level_object_def_ast(

def get_first_top_level_function_or_method_ast(
function_name: str, parents: list[FunctionParent], node: ast.AST
) -> ast.FunctionDef | None:
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
if not parents:
return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
if result is None:
result = get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
return result
if parents[0].type == "ClassDef" and (
class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
):
return get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
if result is None:
result = get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
return result
return None


Expand Down
28 changes: 26 additions & 2 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def visit_FunctionDef(self, node: FunctionDef) -> None:
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
# Check if the async function has a return statement and add it to the list
if function_has_return_statement(node) and not function_is_a_property(node):
self.functions.append(
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
)

def generic_visit(self, node: ast.AST) -> None:
if isinstance(node, (FunctionDef, AsyncFunctionDef, ClassDef)):
self.ast_path.append(FunctionParent(node.name, node.__class__.__name__))
Expand Down Expand Up @@ -221,6 +228,7 @@ def get_functions_to_optimize(
f"It might take about {humanize_runtime(functions_count * three_min_in_ns)} to fully optimize this project. Codeflash "
f"will keep opening pull requests as it finds optimizations."
)
console.rule()
return filtered_modified_functions, functions_count, trace_file_path


Expand Down Expand Up @@ -396,11 +404,27 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
)
)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
if self.class_name is None and node.name == self.function_name:
self.is_top_level = True
self.function_has_args = any(
(
bool(node.args.args),
bool(node.args.kwonlyargs),
bool(node.args.kwarg),
bool(node.args.posonlyargs),
bool(node.args.vararg),
)
)

def visit_ClassDef(self, node: ast.ClassDef) -> None:
# iterate over the class methods
if node.name == self.class_name:
for body_node in node.body:
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
if (
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
):
self.is_top_level = True
if any(
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
Expand All @@ -418,7 +442,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
and body_node.name == self.function_name
and body_node.lineno in {self.line_no, self.line_no + 1}
and any(
Expand Down
5 changes: 1 addition & 4 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
diff_length,
file_name_from_test_module_name,
get_run_tmp_file,
has_any_async_functions,
module_name_from_file_path,
restore_conftest,
)
Expand Down Expand Up @@ -110,7 +109,7 @@ def __init__(
test_cfg: TestConfig,
function_to_optimize_source_code: str = "",
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_ast: ast.FunctionDef | None = None,
function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
aiservice_client: AiServiceClient | None = None,
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
Expand Down Expand Up @@ -169,8 +168,6 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code

if has_any_async_functions(code_context.read_writable_code):
return Failure("Codeflash does not support async functions in the code to optimize.")
# Random here means that we still attempt optimization with a fractional chance to see if
# last time we could not find an optimization, maybe this time we do.
# Random is before as a performance optimization, swapping the two 'and' statements has the same effect
Expand Down
3 changes: 2 additions & 1 deletion codeflash/optimization/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_optimizable_functions(self) -> tuple[dict[Path, list[FunctionToOptimize]
def create_function_optimizer(
self,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.FunctionDef | None = None,
function_to_optimize_ast: ast.FunctionDef | ast.AsyncFunctionDef | None = None,
function_to_tests: dict[str, set[FunctionCalledInTest]] | None = None,
function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
Expand Down Expand Up @@ -269,6 +269,7 @@ def run(self) -> None:
ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions})
if num_optimizable_functions == 0:
logger.info("No functions found to optimize. Exiting…")
console.rule()
return

function_to_tests, _ = self.discover_tests(file_to_funcs_to_optimize)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/telemetry/sentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def init_sentry(enabled: bool = False, exclude_errors: bool = False) -> None: #
)

sentry_sdk.init(
dsn="https://4b9a1902f9361b48c04376df6483bc96@o4506833230561280.ingest.sentry.io/4506833262477312",
dsn="https://4b9a1902f9361b48c04376df6483bc96@o4506833230561280.ingest.us.sentry.io/4506833262477312",
integrations=[sentry_logging],
# Set traces_sample_rate to 1.0 to capture 100%
# of transactions for performance monitoring.
Expand Down
4 changes: 2 additions & 2 deletions codeflash/verification/coverage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def load_from_sqlite_database(
if not database_path.stat().st_size or not database_path.exists():
logger.debug(f"Coverage database {database_path} is empty or does not exist")
sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
return CoverageData.create_empty(source_code_path, function_name, code_context)
cov.load()

reporter = JsonReporter(cov)
Expand All @@ -51,7 +51,7 @@ def load_from_sqlite_database(
reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
except NoDataError:
sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
return CoverageUtils.create_empty(source_code_path, function_name, code_context)
return CoverageData.create_empty(source_code_path, function_name, code_context)
with temp_json_file.open() as f:
original_coverage_data = json.load(f)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ dev = [
"types-unidiff>=0.7.0.20240505,<0.8",
"uv>=0.6.2",
"pre-commit>=4.2.0,<5",
"pytest-asyncio>=1.1.0",
]

[tool.hatch.build.targets.sdist]
Expand Down
Loading
Loading