127 changes: 86 additions & 41 deletions codeflash/code_utils/edit_generated_tests.py
@@ -1,10 +1,14 @@
import os
import re
from pathlib import Path

import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
from codeflash.result.critic import performance_gain
from codeflash.verification.verification_utils import TestConfig


def remove_functions_from_generated_tests(
@@ -33,40 +37,46 @@ def remove_functions_from_generated_tests(


def add_runtime_comments_to_generated_tests(
generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
test_cfg: TestConfig,
generated_tests: GeneratedTestsList,
original_runtimes: dict[InvocationId, list[int]],
optimized_runtimes: dict[InvocationId, list[int]],
) -> GeneratedTestsList:
"""Add runtime performance comments to function calls in generated tests."""
# Create dictionaries for fast lookup of runtime data
original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
tests_root = test_cfg.tests_root
module_root = test_cfg.project_root_path
rel_tests_root = tests_root.relative_to(module_root)

# TODO: reduce the two for loops below to one
class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(self) -> None:
self.in_test_function = False
self.current_test_name: str | None = None
def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
self.test = test
self.context_stack: list[str] = []
self.tests_root = tests_root
self.rel_tests_root = rel_tests_root

def visit_ClassDef(self, node: cst.ClassDef) -> None:
# Track when we enter a class
self.context_stack.append(node.name.value)

def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: # noqa: ARG002
# Pop the context when we leave a class
self.context_stack.pop()
return updated_node

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
if node.name.value.startswith("test_"):
self.in_test_function = True
self.current_test_name = node.name.value
else:
self.in_test_function = False
self.current_test_name = None

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
if original_node.name.value.startswith("test_"):
self.in_test_function = False
self.current_test_name = None
self.context_stack.append(node.name.value)

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
# Pop the context when we leave a function
self.context_stack.pop()
return updated_node

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine:
if not self.in_test_function or not self.current_test_name:
return updated_node

# Look for assignment statements that assign to codeflash_output
# Handle both single statements and multiple statements on one line
codeflash_assignment_found = False
@@ -83,30 +93,65 @@ def leave_SimpleStatementLine(
# Find matching test cases by looking for this test function name in the test results
matching_original_times = []
matching_optimized_times = []

for invocation_id, runtimes in original_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
# TODO: will not work if there are multiple test cases with the same name; match filename + test class + test function name
for invocation_id, runtimes in original_runtimes.items():
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
else invocation_id.test_function_name
)
rel_path = (
Path(invocation_id.test_module_path.replace(".", os.sep))
.with_suffix(".py")
.relative_to(self.rel_tests_root)
)
if qualified_name == ".".join(self.context_stack) and rel_path in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]:
matching_original_times.extend(runtimes)

for invocation_id, runtimes in optimized_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
for invocation_id, runtimes in optimized_runtimes.items():
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
else invocation_id.test_function_name
)
rel_path = (
Path(invocation_id.test_module_path.replace(".", os.sep))
.with_suffix(".py")
.relative_to(self.rel_tests_root)
)
if qualified_name == ".".join(self.context_stack) and rel_path in [
self.test.behavior_file_path.relative_to(self.tests_root),
self.test.perf_file_path.relative_to(self.tests_root),
]:
matching_optimized_times.extend(runtimes)

if matching_original_times and matching_optimized_times:
original_time = min(matching_original_times)
optimized_time = min(matching_optimized_times)
Review comment (Contributor Author): check for optimized_time or original_time = 0

# Create the runtime comment
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"

# Add comment to the trailing whitespace
new_trailing_whitespace = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
)

return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
if original_time != 0 and optimized_time != 0:
perf_gain = format_perf(
abs(
performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
* 100
)
)
status = "slower" if optimized_time > original_time else "faster"
# Create the runtime comment
comment_text = (
f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
)

# Add comment to the trailing whitespace
new_trailing_whitespace = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
)

return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)

return updated_node

@@ -118,7 +163,7 @@ def leave_SimpleStatementLine(
tree = cst.parse_module(test.generated_original_test_source)

# Transform the tree to add runtime comments
transformer = RuntimeCommentTransformer()
transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root)
modified_tree = tree.visit(transformer)

# Convert back to source code
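
To illustrate the new comment format, here is a minimal sketch of the string the transformer appends to a matched codeflash_output assignment line (the timings are hypothetical; the exact rendering depends on format_time, and performance_gain is assumed here to be (original - optimized) / optimized):

from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.result.critic import performance_gain

# Hypothetical best-of-run timings, in nanoseconds, for one matched test invocation
original_time, optimized_time = 1_250_000, 500_000
perf_gain = format_perf(
    abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
)
status = "slower" if optimized_time > original_time else "faster"
# Same f-string the transformer uses above
print(f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})")
# Expected to print something like: # 1.25ms -> 500μs (150% faster)

The comment is attached as trailing whitespace on the codeflash_output assignment, so it appears inline next to the call under test in the generated test file.
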
12 changes: 12 additions & 0 deletions codeflash/code_utils/time_utils.py
@@ -85,3 +85,15 @@ def format_time(nanoseconds: int) -> str:

# This should never be reached, but included for completeness
return f"{nanoseconds}ns"


def format_perf(percentage: float) -> str:
"""Format percentage into a human-readable string with 3 significant digits when needed."""
percentage_abs = abs(percentage)
if percentage_abs >= 100:
return f"{percentage:.0f}"
if percentage_abs >= 10:
return f"{percentage:.1f}"
if percentage_abs >= 1:
return f"{percentage:.2f}"
return f"{percentage:.3f}"
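
For reference, a minimal sketch of format_perf's behavior at each threshold (sample values are hypothetical, chosen only to exercise each branch):

from codeflash.code_utils.time_utils import format_perf

# >= 100 -> no decimals, >= 10 -> one decimal, >= 1 -> two decimals, < 1 -> three decimals
assert format_perf(215.4) == "215"
assert format_perf(12.5) == "12.5"
assert format_perf(3.14159) == "3.14"
assert format_perf(0.25) == "0.250"

Note that only the threshold check uses abs(); the sign of the input is preserved in the formatted string.
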
1 change: 1 addition & 0 deletions codeflash/models/models.py
@@ -557,6 +557,7 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:

def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
# Efficient single traversal, directly accumulating into a dict.
# could track mins here so that total_passed_runtime only needs to return sums
by_id: dict[InvocationId, list[int]] = {}
for result in self.test_results:
if result.did_pass:
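
For context, both consumers added in this PR (the comment transformer and the PR table) reduce each invocation's runtime list with min(), which is what the note above about tracking mins refers to. A minimal sketch of that downstream reduction, using a plain string key in place of a real InvocationId:

# Hypothetical shape of the mapping returned by usable_runtime_data_by_test_case()
runtimes_by_test: dict[str, list[int]] = {
    "tests.test_module::test_example": [1_310_000, 1_250_000, 1_275_000],  # ns per run
}
best_runtime = {invocation: min(times) for invocation, times in runtimes_by_test.items()}
assert best_runtime["tests.test_module::test_example"] == 1_250_000
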
23 changes: 14 additions & 9 deletions codeflash/optimization/function_optimizer.py
@@ -341,12 +341,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
optimized_function=best_optimization.candidate.source_code,
)

existing_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
function_to_all_tests,
tests_root=self.test_cfg.tests_root,
)

original_code_combined = original_helper_code.copy()
original_code_combined[explanation.file_path] = self.function_to_optimize_source_code
new_code_combined = new_helper_code.copy()
@@ -360,15 +354,26 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
)
original_runtime_by_test = (
original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
)
optimized_runtime_by_test = (
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
)
# Add runtime comments to generated tests before creating the PR
generated_tests = add_runtime_comments_to_generated_tests(
generated_tests,
original_code_baseline.benchmarking_test_results,
best_optimization.winning_benchmarking_test_results,
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
)
generated_tests_str = "\n\n".join(
[test.generated_original_test_source for test in generated_tests.generated_tests]
)
existing_tests = existing_tests_source_for(
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
function_to_all_tests,
test_cfg=self.test_cfg,
original_runtimes_all=original_runtime_by_test,
optimized_runtimes_all=optimized_runtime_by_test,
)
if concolic_test_str:
generated_tests_str += "\n\n" + concolic_test_str

104 changes: 97 additions & 7 deletions codeflash/result/create_pr.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional

@@ -16,24 +17,113 @@
git_root_dir,
)
from codeflash.code_utils.github_utils import github_pr_url
from codeflash.code_utils.tabulate import tabulate
from codeflash.code_utils.time_utils import format_perf, format_time
from codeflash.github.PrComment import FileDiffContent, PrComment
from codeflash.result.critic import performance_gain

if TYPE_CHECKING:
from codeflash.models.models import FunctionCalledInTest
from codeflash.models.models import FunctionCalledInTest, InvocationId
from codeflash.result.explanation import Explanation
from codeflash.verification.verification_utils import TestConfig


def existing_tests_source_for(
function_qualified_name_with_modules_from_root: str,
function_to_tests: dict[str, set[FunctionCalledInTest]],
tests_root: Path,
test_cfg: TestConfig,
original_runtimes_all: dict[InvocationId, list[int]],
optimized_runtimes_all: dict[InvocationId, list[int]],
) -> str:
test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
existing_tests_unique = set()
if test_files:
for test_file in test_files:
existing_tests_unique.add("- " + str(Path(test_file.tests_in_file.test_file).relative_to(tests_root)))
return "\n".join(sorted(existing_tests_unique))
if not test_files:
return ""
output: str = ""
rows = []
headers = ["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"]
tests_root = test_cfg.tests_root
module_root = test_cfg.project_root_path
rel_tests_root = tests_root.relative_to(module_root)
original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
optimized_tests_to_runtimes: dict[Path, dict[str, int]] = {}
non_generated_tests = set()
for test_file in test_files:
non_generated_tests.add(Path(test_file.tests_in_file.test_file).relative_to(tests_root))
# TODO confirm that original and optimized have the same keys
all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys()
for invocation_id in all_invocation_ids:
rel_path = (
Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to(rel_tests_root)
)
if rel_path not in non_generated_tests:
continue
if rel_path not in original_tests_to_runtimes:
original_tests_to_runtimes[rel_path] = {}
if rel_path not in optimized_tests_to_runtimes:
optimized_tests_to_runtimes[rel_path] = {}
qualified_name = (
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
if invocation_id.test_class_name
else invocation_id.test_function_name
)
if qualified_name not in original_tests_to_runtimes[rel_path]:
original_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
if qualified_name not in optimized_tests_to_runtimes[rel_path]:
optimized_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
if invocation_id in original_runtimes_all:
original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id]) # type: ignore[index]
if invocation_id in optimized_runtimes_all:
optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) # type: ignore[index]
# render the accumulated runtimes into a markdown table string
all_rel_paths = (
original_tests_to_runtimes.keys()
) # both will have the same keys as some default values are assigned in the previous loop
for filename in sorted(all_rel_paths):
all_qualified_names = original_tests_to_runtimes[
filename
].keys() # both will have the same keys as some default values are assigned in the previous loop
for qualified_name in sorted(all_qualified_names):
# skip rows where either the original or optimized runtime is missing (recorded as 0)
if (
original_tests_to_runtimes[filename][qualified_name] != 0
and optimized_tests_to_runtimes[filename][qualified_name] != 0
):
print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
greater = (
optimized_tests_to_runtimes[filename][qualified_name]
> original_tests_to_runtimes[filename][qualified_name]
)
perf_gain = format_perf(
performance_gain(
original_runtime_ns=original_tests_to_runtimes[filename][qualified_name],
optimized_runtime_ns=optimized_tests_to_runtimes[filename][qualified_name],
)
* 100
)
if greater:
rows.append(
[
f"`{filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"⚠️{perf_gain}%",
]
)
else:
rows.append(
[
f"`{filename}::{qualified_name}`",
f"{print_original_runtime}",
f"{print_optimized_runtime}",
f"✅{perf_gain}%",
]
)
output += tabulate( # type: ignore[no-untyped-call]
headers=headers, tabular_data=rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True
)
output += "\n"
return output
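
For illustration, with tablefmt="pipe" the returned string is a GitHub-style table roughly like the one below (file names and timings are hypothetical, and performance_gain is assumed to be (original - optimized) / optimized; since the percentage is not passed through abs() here, a regression shows up as a negative value next to the warning icon):

| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|:-------------------------|:------------|:-------------|:--------|
| `test_module.py::TestFoo.test_fast_path` | 1.25ms | 500μs | ✅150% |
| `test_module.py::test_slow_path` | 800μs | 1.10ms | ⚠️-27.3% |
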


def check_create_pr(