diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py index 35232f954..fbff70324 100644 --- a/codeflash/benchmarking/codeflash_trace.py +++ b/codeflash/benchmarking/codeflash_trace.py @@ -6,6 +6,7 @@ import time from typing import Callable +from codeflash.cli_cmds.cli import logger from codeflash.picklepatch.pickle_patcher import PicklePatcher @@ -42,10 +43,8 @@ def setup(self, trace_path: str) -> None: ) self._connection.commit() except Exception as e: - print(f"Database setup error: {e}") - if self._connection: - self._connection.close() - self._connection = None + logger.error(f"Database setup error: {e}") + self.close() raise def write_function_timings(self) -> None: @@ -63,18 +62,17 @@ def write_function_timings(self) -> None: try: cur = self._connection.cursor() - # Insert data into the benchmark_function_timings table cur.executemany( "INSERT INTO benchmark_function_timings" "(function_name, class_name, module_name, file_path, benchmark_function_name, " "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) " "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", - self.function_calls_data + self.function_calls_data, ) self._connection.commit() self.function_calls_data = [] except Exception as e: - print(f"Error writing to function timings database: {e}") + logger.error(f"Error writing to function timings database: {e}") if self._connection: self._connection.rollback() raise @@ -100,9 +98,10 @@ def __call__(self, func: Callable) -> Callable: The wrapped function """ - func_id = (func.__module__,func.__name__) + func_id = (func.__module__, func.__name__) + @functools.wraps(func) - def wrapper(*args, **kwargs): + def wrapper(*args: tuple, **kwargs: dict) -> object: # Initialize thread-local active functions set if it doesn't exist if not hasattr(self._thread_local, "active_functions"): self._thread_local.active_functions = set() @@ -123,25 +122,33 @@ def wrapper(*args, **kwargs): if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False": self._thread_local.active_functions.remove(func_id) return result - # Get benchmark info from environment + benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "") benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "") benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "") - # Get class name class_name = "" qualname = func.__qualname__ if "." 
in qualname: class_name = qualname.split(".")[0] - # Limit pickle count so memory does not explode if self.function_call_count > self.pickle_count_limit: - print("Pickle limit reached") + logger.debug("CodeflashTrace: Pickle limit reached") self._thread_local.active_functions.remove(func_id) overhead_time = time.thread_time_ns() - end_time self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, None, None) + ( + func.__name__, + class_name, + func.__module__, + func.__code__.co_filename, + benchmark_function_name, + benchmark_module_path, + benchmark_line_number, + execution_time, + overhead_time, + None, + None, + ) ) return result @@ -150,17 +157,26 @@ def wrapper(*args, **kwargs): pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL) pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL) except Exception as e: - print(f"Error pickling arguments for function {func.__name__}: {e}") + logger.debug(f"CodeflashTrace: Error pickling arguments for function {func.__name__}: {e}") # Add to the list of function calls without pickled args. Used for timing info only self._thread_local.active_functions.remove(func_id) overhead_time = time.thread_time_ns() - end_time self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, None, None) + ( + func.__name__, + class_name, + func.__module__, + func.__code__.co_filename, + benchmark_function_name, + benchmark_module_path, + benchmark_line_number, + execution_time, + overhead_time, + None, + None, + ) ) return result - # Flush to database every 100 calls if len(self.function_calls_data) > 100: self.write_function_timings() @@ -168,12 +184,23 @@ def wrapper(*args, **kwargs): self._thread_local.active_functions.remove(func_id) overhead_time = time.thread_time_ns() - end_time self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, pickled_args, pickled_kwargs) + ( + func.__name__, + class_name, + func.__module__, + func.__code__.co_filename, + benchmark_function_name, + benchmark_module_path, + benchmark_line_number, + execution_time, + overhead_time, + pickled_args, + pickled_kwargs, + ) ) return result + return wrapper -# Create a singleton instance + codeflash_trace = CodeflashTrace() diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py index 044b0b0a4..603d6405b 100644 --- a/codeflash/benchmarking/instrument_codeflash_trace.py +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -13,39 +13,35 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None: self.added_codeflash_trace = False self.class_name = "" self.function_name = "" - self.decorator = cst.Decorator( - decorator=cst.Name(value="codeflash_trace") - ) + self.decorator = cst.Decorator(decorator=cst.Name(value="codeflash_trace")) - def leave_ClassDef(self, original_node, updated_node): + def leave_ClassDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802 if self.class_name == original_node.name.value: - self.class_name = "" # Even if nested classes are not visited, this function is still 
called on them + self.class_name = "" # Even if nested classes are not visited, this function is still called on them return updated_node - def visit_ClassDef(self, node): - if self.class_name: # Don't go into nested class + def visit_ClassDef(self, node): # noqa: ANN001, ANN201, N802 + if self.class_name: # Don't go into nested class return False - self.class_name = node.name.value + self.class_name = node.name.value # noqa: RET503 - def visit_FunctionDef(self, node): - if self.function_name: # Don't go into nested function + def visit_FunctionDef(self, node): # noqa: ANN001, ANN201, N802 + if self.function_name: # Don't go into nested function return False - self.function_name = node.name.value + self.function_name = node.name.value # noqa: RET503 - def leave_FunctionDef(self, original_node, updated_node): + def leave_FunctionDef(self, original_node, updated_node): # noqa: ANN001, ANN201, N802 if self.function_name == original_node.name.value: self.function_name = "" if (self.class_name, original_node.name.value) in self.target_functions: # Add the new decorator after any existing decorators, so it gets executed first - updated_decorators = list(updated_node.decorators) + [self.decorator] + updated_decorators = [*list(updated_node.decorators), self.decorator] self.added_codeflash_trace = True - return updated_node.with_changes( - decorators=updated_decorators - ) + return updated_node.with_changes(decorators=updated_decorators) return updated_node - def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002, N802 # Create import statement for codeflash_trace if not self.added_codeflash_trace: return updated_node @@ -53,17 +49,10 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c body=[ cst.ImportFrom( module=cst.Attribute( - value=cst.Attribute( - value=cst.Name(value="codeflash"), - attr=cst.Name(value="benchmarking") - ), - attr=cst.Name(value="codeflash_trace") + value=cst.Attribute(value=cst.Name(value="codeflash"), attr=cst.Name(value="benchmarking")), + attr=cst.Name(value="codeflash_trace"), ), - names=[ - cst.ImportAlias( - name=cst.Name(value="codeflash_trace") - ) - ] + names=[cst.ImportAlias(name=cst.Name(value="codeflash_trace"))], ) ] ) @@ -73,12 +62,13 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return updated_node.with_changes(body=new_body) + def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str: """Add codeflash_trace to a function. 
Args: code: The source code as a string - function_to_optimize: The FunctionToOptimize instance containing function details + functions_to_optimize: List of FunctionToOptimize instances containing function details Returns: The modified source code as a string @@ -91,27 +81,17 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct class_name = function_to_optimize.parents[0].name target_functions.add((class_name, function_to_optimize.function_name)) - transformer = AddDecoratorTransformer( - target_functions = target_functions, - ) + transformer = AddDecoratorTransformer(target_functions=target_functions) module = cst.parse_module(code) modified_module = module.visit(transformer) return modified_module.code -def instrument_codeflash_trace_decorator( - file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] -) -> None: +def instrument_codeflash_trace_decorator(file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]) -> None: """Instrument codeflash_trace decorator to functions to optimize.""" for file_path, functions_to_optimize in file_to_funcs_to_optimize.items(): original_code = file_path.read_text(encoding="utf-8") - new_code = add_codeflash_decorator_to_code( - original_code, - functions_to_optimize - ) - # Modify the code + new_code = add_codeflash_decorator_to_code(original_code, functions_to_optimize) modified_code = isort.code(code=new_code, float_to_top=True) - - # Write the modified code back to the file file_path.write_text(modified_code, encoding="utf-8") diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index 313817041..f29d93fab 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -5,10 +5,12 @@ import sys import time from pathlib import Path +from typing import Any, Callable import pytest from codeflash.benchmarking.codeflash_trace import codeflash_trace +from codeflash.cli_cmds.cli import logger from codeflash.code_utils.code_utils import module_name_from_file_path from codeflash.models.models import BenchmarkKey @@ -20,9 +22,8 @@ def __init__(self) -> None: self.project_root = None self.benchmark_timings = [] - def setup(self, trace_path:str, project_root:str) -> None: + def setup(self, trace_path: str, project_root: str) -> None: try: - # Open connection self.project_root = project_root self._trace_path = trace_path self._connection = sqlite3.connect(self._trace_path) @@ -35,12 +36,10 @@ def setup(self, trace_path:str, project_root:str) -> None: "benchmark_time_ns INTEGER)" ) self._connection.commit() - self.close() # Reopen only at the end of pytest session + self.close() except Exception as e: - print(f"Database setup error: {e}") - if self._connection: - self._connection.close() - self._connection = None + logger.error(f"Database setup error: {e}") + self.close() raise def write_benchmark_timings(self) -> None: @@ -52,17 +51,17 @@ def write_benchmark_timings(self) -> None: try: cur = self._connection.cursor() - # Insert data into the benchmark_timings table cur.executemany( "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)", - self.benchmark_timings + self.benchmark_timings, ) self._connection.commit() - self.benchmark_timings = [] # Clear the benchmark timings list + self.benchmark_timings.clear() except Exception as e: - print(f"Error writing to benchmark timings database: {e}") + logger.error(f"Error writing to benchmark timings database: 
{e}") self._connection.rollback() raise + def close(self) -> None: if self._connection: self._connection.close() @@ -82,22 +81,18 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark - Values are function timing in milliseconds """ - # Initialize the result dictionary result = {} - # Connect to the SQLite database connection = sqlite3.connect(trace_path) cursor = connection.cursor() try: - # Query the function_calls table for all function calls cursor.execute( "SELECT module_name, class_name, function_name, " "benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns " "FROM benchmark_function_timings" ) - # Process each row for row in cursor.fetchall(): module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row @@ -109,7 +104,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark # Create the benchmark key (file::function::line) benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) - # Initialize the inner dictionary if needed if qualified_name not in result: result[qualified_name] = {} @@ -121,7 +115,6 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark result[qualified_name][benchmark_key] = time_ns finally: - # Close the connection connection.close() return result @@ -139,11 +132,9 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: - Values are total benchmark timing in milliseconds (with overhead subtracted) """ - # Initialize the result dictionary result = {} overhead_by_benchmark = {} - # Connect to the SQLite database connection = sqlite3.connect(trace_path) cursor = connection.cursor() @@ -155,7 +146,6 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: "GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number" ) - # Process overhead information for row in cursor.fetchall(): benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) @@ -167,57 +157,48 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: "FROM benchmark_timings" ) - # Process each row and subtract overhead for row in cursor.fetchall(): benchmark_file, benchmark_func, benchmark_line, time_ns = row - # Create the benchmark key (file::function::line) - benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) + benchmark_key = BenchmarkKey( + module_path=benchmark_file, function_name=benchmark_func + ) # (file::function::line) # Subtract overhead from total time overhead = overhead_by_benchmark.get(benchmark_key, 0) result[benchmark_key] = time_ns - overhead finally: - # Close the connection connection.close() return result - # Pytest hooks @pytest.hookimpl - def pytest_sessionfinish(self, session, exitstatus): + def pytest_sessionfinish(self, session: pytest.Session, exitstatus: int) -> None: # noqa: ARG002 """Execute after whole test run is completed.""" - # Write any remaining benchmark timings to the database codeflash_trace.close() if self.benchmark_timings: self.write_benchmark_timings() - # Close the database connection self.close() @staticmethod - def pytest_addoption(parser): - parser.addoption( - "--codeflash-trace", - action="store_true", - default=False, - help="Enable CodeFlash tracing" - ) + def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption("--codeflash-trace", 
action="store_true", default=False, help="Enable CodeFlash tracing") @staticmethod - def pytest_plugin_registered(plugin, manager): + def pytest_plugin_registered(plugin: Any, manager: Any) -> None: # noqa: ANN401 # Not necessary since run with -p no:benchmark, but just in case if hasattr(plugin, "name") and plugin.name == "pytest-benchmark": manager.unregister(plugin) @staticmethod - def pytest_configure(config): + def pytest_configure(config: pytest.Config) -> None: """Register the benchmark marker.""" config.addinivalue_line( - "markers", - "benchmark: mark test as a benchmark that should be run with codeflash tracing" + "markers", "benchmark: mark test as a benchmark that should be run with codeflash tracing" ) + @staticmethod - def pytest_collection_modifyitems(config, items): + def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: # Skip tests that don't have the benchmark fixture if not config.getoption("--codeflash-trace"): return @@ -240,54 +221,45 @@ def pytest_collection_modifyitems(config, items): # Benchmark fixture class Benchmark: - def __init__(self, request): + """Benchmark fixture class for running and timing benchmarked functions.""" + + def __init__(self, request: pytest.FixtureRequest) -> None: self.request = request + self._call_count = 0 - def __call__(self, func, *args, **kwargs): - """Handle both direct function calls and decorator usage.""" - if args or kwargs: - # Used as benchmark(func, *args, **kwargs) - return self._run_benchmark(func, *args, **kwargs) - # Used as @benchmark decorator - def wrapped_func(*args, **kwargs): - return func(*args, **kwargs) - result = self._run_benchmark(func) - return wrapped_func - - def _run_benchmark(self, func, *args, **kwargs): - """Actual benchmark implementation.""" - benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), - Path(codeflash_benchmark_plugin.project_root)) - benchmark_function_name = self.request.node.name + def __call__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + benchmark_module_path = module_name_from_file_path( + Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root) + ) + node_name = self.request.node.name + benchmark_function_name = node_name.split("[", 1)[0] if "[" in node_name else node_name line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack - # Set env vars + + os.environ["CODEFLASH_BENCHMARKING"] = "True" os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number) os.environ["CODEFLASH_BENCHMARKING"] = "True" - # Run the function - start = time.time_ns() + start = time.perf_counter_ns() result = func(*args, **kwargs) - end = time.time_ns() - # Reset the environment variable + end = time.perf_counter_ns() os.environ["CODEFLASH_BENCHMARKING"] = "False" - # Write function calls codeflash_trace.write_function_timings() - # Reset function call count codeflash_trace.function_call_count = 0 - # Add to the benchmark timings buffer codeflash_benchmark_plugin.benchmark_timings.append( - (benchmark_module_path, benchmark_function_name, line_number, end - start)) + (benchmark_module_path, benchmark_function_name, line_number, end - start) + ) return result @staticmethod @pytest.fixture - def benchmark(request): + def benchmark(request: pytest.FixtureRequest) -> CodeFlashBenchmarkPlugin.Benchmark | 
None:
         if not request.config.getoption("--codeflash-trace"):
             return None
         return CodeFlashBenchmarkPlugin.Benchmark(request)
+
 codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
diff --git a/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py
index 232c39fa7..39e35909f 100644
--- a/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py
+++ b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py
@@ -16,8 +16,20 @@
     codeflash_benchmark_plugin.setup(trace_file, project_root)
     codeflash_trace.setup(trace_file)
     exitcode = pytest.main(
-        [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark","-p", "no:codspeed","-p", "no:cov-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
-    )  # Errors will be printed to stdout, not stderr
+        [
+            benchmarks_root,
+            "--codeflash-trace",
+            "-p",
+            "no:benchmark",
+            "-p",
+            "no:codspeed",
+            "-p",
+            "no:cov-s",
+            "-o",
+            "addopts=",
+        ],
+        plugins=[codeflash_benchmark_plugin],
+    )  # Errors will be printed to stdout, not stderr
 except Exception as e:
     print(f"Failed to collect tests: {e!s}", file=sys.stderr)
diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py
index ee1107241..ba2b47277 100644
--- a/codeflash/benchmarking/replay_test.py
+++ b/codeflash/benchmarking/replay_test.py
@@ -16,20 +16,23 @@
 def get_next_arg_and_return(
-    trace_file: str, benchmark_function_name:str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
+    trace_file: str,
+    benchmark_function_name: str,
+    function_name: str,
+    file_path: str,
+    class_name: str | None = None,
+    num_to_get: int = 256,
 ) -> Generator[Any]:
     db = sqlite3.connect(trace_file)
     cur = db.cursor()
     limit = num_to_get
 
     if class_name is not None:
-        cursor = cur.execute(
-            "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
-            (benchmark_function_name, function_name, file_path, class_name, limit),
-        )
+        query = "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?"  # noqa: E501
+        cursor = cur.execute(query, (benchmark_function_name, function_name, file_path, class_name, limit))
     else:
         cursor = cur.execute(
-            "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
+            "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",  # noqa: E501
             (benchmark_function_name, function_name, file_path, limit),
         )
 
@@ -42,10 +45,7 @@ def get_function_alias(module: str, function_name: str) -> str:
 
 
 def create_trace_replay_test_code(
-    trace_file: str,
-    functions_data: list[dict[str, Any]],
-    test_framework: str = "pytest",
-    max_run_count=256
+    trace_file: str, functions_data: list[dict[str, Any]], test_framework: str = "pytest", max_run_count=256
 ) -> str:
     """Create a replay test for functions based on trace data.
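# --- Illustrative sketch (not part of the patch): how the parameterized queries in
# get_next_arg_and_return are driven. sqlite3's Cursor.execute() takes the SQL text and a
# tuple of bound parameters as two separate arguments; the "?" placeholders are filled
# positionally from that tuple. Table and column names below come from the
# benchmark_function_timings schema created in codeflash_trace.py; the helper name
# fetch_rows is hypothetical and exists only for this example.
import sqlite3


def fetch_rows(trace_file: str, benchmark_function_name: str, function_name: str, file_path: str, limit: int = 256):
    con = sqlite3.connect(trace_file)
    try:
        cur = con.cursor()
        # SQL string and parameter tuple are passed separately; never call the string itself.
        rows = cur.execute(
            "SELECT args, kwargs FROM benchmark_function_timings "
            "WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? LIMIT ?",
            (benchmark_function_name, function_name, file_path, limit),
        ).fetchall()
        return rows
    finally:
        con.close()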
@@ -83,8 +83,9 @@ def create_trace_replay_test_code( imports += "\n".join(function_imports) - functions_to_optimize = sorted({func.get("function_name") for func in functions_data - if func.get("function_name") != "__init__"}) + functions_to_optimize = sorted( + {func.get("function_name") for func in functions_data if func.get("function_name") != "__init__"} + ) metadata = f"""functions = {functions_to_optimize} trace_file_path = r"{trace_file}" """ @@ -95,7 +96,7 @@ def create_trace_replay_test_code( args = pickle.loads(args_pkl) kwargs = pickle.loads(kwargs_pkl) ret = {function_name}(*args, **kwargs) - """ + """ # noqa: E501 ) test_method_body = textwrap.dedent( @@ -111,7 +112,8 @@ def create_trace_replay_test_code( else: instance = args[0] # self ret = instance{method_name}(*args[1:], **kwargs) - """) + """ # noqa: E501 + ) test_class_method_body = textwrap.dedent( """\ @@ -121,7 +123,7 @@ def create_trace_replay_test_code( if not args: raise ValueError("No arguments provided for the method.") ret = {class_name_alias}{method_name}(*args[1:], **kwargs) - """ + """ # noqa: E501 ) test_static_method_body = textwrap.dedent( """\ @@ -129,7 +131,7 @@ def create_trace_replay_test_code( args = pickle.loads(args_pkl) kwargs = pickle.loads(kwargs_pkl){filter_variables} ret = {class_name_alias}{method_name}(*args, **kwargs) - """ + """ # noqa: E501 ) # Create main body @@ -142,7 +144,6 @@ def create_trace_replay_test_code( self = "" for func in functions_data: - module_name = func.get("module_name") function_name = func.get("function_name") class_name = func.get("class_name") @@ -163,7 +164,7 @@ def create_trace_replay_test_code( alias = get_function_alias(module_name, class_name + "_" + function_name) filter_variables = "" - # filter_variables = '\n args.pop("cls", None)' + # filter_variables = '\n args.pop("cls", None)' # noqa: ERA001 method_name = "." + function_name if function_name != "__init__" else "" if function_properties.is_classmethod: test_body = test_class_method_body.format( @@ -206,7 +207,10 @@ def create_trace_replay_test_code( return imports + "\n" + metadata + "\n" + test_template -def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100) -> int: + +def generate_replay_test( + trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100 +) -> int: """Generate multiple replay tests from the traced function calls, grouped by benchmark. 
Args: @@ -221,14 +225,10 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework """ count = 0 try: - # Connect to the database conn = sqlite3.connect(trace_file_path.as_posix()) cursor = conn.cursor() - # Get distinct benchmark file paths - cursor.execute( - "SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings" - ) + cursor.execute("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings") benchmark_files = cursor.fetchall() # Generate a test for each benchmark file @@ -236,29 +236,28 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework benchmark_module_path = benchmark_file[0] # Get all benchmarks and functions associated with this file path cursor.execute( - "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings " + "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings " # noqa: E501 "WHERE benchmark_module_path = ?", - (benchmark_module_path,) + (benchmark_module_path,), ) functions_data = [] for row in cursor.fetchall(): benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row - # Add this function to our list - functions_data.append({ - "function_name": function_name, - "class_name": class_name, - "file_path": file_path, - "module_name": module_name, - "benchmark_function_name": benchmark_function_name, - "benchmark_module_path": benchmark_module_path, - "benchmark_line_number": benchmark_line_number, - "function_properties": inspect_top_level_functions_or_methods( - file_name=Path(file_path), - function_or_method_name=function_name, - class_name=class_name, - ) - }) + functions_data.append( + { + "function_name": function_name, + "class_name": class_name, + "file_path": file_path, + "module_name": module_name, + "benchmark_function_name": benchmark_function_name, + "benchmark_module_path": benchmark_module_path, + "benchmark_line_number": benchmark_line_number, + "function_properties": inspect_top_level_functions_or_methods( + file_name=Path(file_path), function_or_method_name=function_name, class_name=class_name + ), + } + ) if not functions_data: logger.info(f"No benchmark test functions found in {benchmark_module_path}") @@ -274,13 +273,12 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework output_file = get_test_file_path( test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay" ) - # Write test code to file, parents = true output_dir.mkdir(parents=True, exist_ok=True) output_file.write_text(test_code, "utf-8") count += 1 conn.close() - except Exception as e: - logger.info(f"Error generating replay tests: {e}") + except Exception as e: # noqa: BLE001 + logger.error(f"Error generating replay test: {e}") return count diff --git a/codeflash/benchmarking/trace_benchmarks.py b/codeflash/benchmarking/trace_benchmarks.py index 8d14068e7..e59b06656 100644 --- a/codeflash/benchmarking/trace_benchmarks.py +++ b/codeflash/benchmarking/trace_benchmarks.py @@ -9,7 +9,9 @@ from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE -def trace_benchmarks_pytest(benchmarks_root: Path, tests_root:Path, project_root: Path, trace_file: Path, timeout:int = 300) -> None: +def trace_benchmarks_pytest( + benchmarks_root: Path, tests_root: Path, project_root: Path, trace_file: Path, timeout: int = 300 +) -> None: benchmark_env = 
os.environ.copy() if "PYTHONPATH" not in benchmark_env: benchmark_env["PYTHONPATH"] = str(project_root) @@ -43,6 +45,4 @@ def trace_benchmarks_pytest(benchmarks_root: Path, tests_root:Path, project_root error_section = match.group(1) if match else result.stdout else: error_section = result.stdout - logger.warning( - f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}" - ) + logger.warning(f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}") diff --git a/codeflash/benchmarking/utils.py b/codeflash/benchmarking/utils.py index da09cd57a..bf41b6da3 100644 --- a/codeflash/benchmarking/utils.py +++ b/codeflash/benchmarking/utils.py @@ -1,12 +1,10 @@ from __future__ import annotations -import shutil from typing import TYPE_CHECKING, Optional -from rich.console import Console from rich.table import Table -from codeflash.cli_cmds.console import logger +from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.time_utils import humanize_runtime from codeflash.models.models import BenchmarkDetail, ProcessedBenchmarkInfo from codeflash.result.critic import performance_gain @@ -15,8 +13,9 @@ from codeflash.models.models import BenchmarkKey -def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], - total_benchmark_timings: dict[BenchmarkKey, int]) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: +def validate_and_format_benchmark_table( + function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int] +) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: function_to_result = {} # Process each function's benchmark data for func_path, test_times in function_benchmark_timings.items(): @@ -41,56 +40,58 @@ def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, di def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None: - - try: - terminal_width = int(shutil.get_terminal_size().columns * 0.9) - except Exception: - terminal_width = 120 # Fallback width - console = Console(width = terminal_width) for func_path, sorted_tests in function_to_results.items(): console.print() function_name = func_path.split(":")[-1] - # Create a table for this function - table = Table(title=f"Function: {function_name}", width=terminal_width, border_style="blue", show_lines=True) - benchmark_col_width = max(int(terminal_width * 0.4), 40) - # Add columns - split the benchmark test into two columns - table.add_column("Benchmark Module Path", width=benchmark_col_width, style="cyan", overflow="fold") + table = Table(title=f"Function: {function_name}", border_style="blue", show_lines=True) + table.add_column("Benchmark Module Path", style="cyan", overflow="fold") table.add_column("Test Function", style="magenta", overflow="fold") table.add_column("Total Time (ms)", justify="right", style="green") table.add_column("Function Time (ms)", justify="right", style="yellow") table.add_column("Percentage (%)", justify="right", style="red") - for benchmark_key, total_time, func_time, percentage in sorted_tests: - # Split the benchmark test into module path and function name - module_path = benchmark_key.module_path + multi_call_bases = set() + call_1_tests = [] + + for i, (benchmark_key, _, _, _) in enumerate(sorted_tests): test_function = benchmark_key.function_name + module_path = benchmark_key.module_path + if "::call_" in test_function: + try: + 
base_name, call_part = test_function.rsplit("::call_", 1) + call_num = int(call_part) + if call_num == 1: + call_1_tests.append((i, base_name, module_path)) + elif call_num > 1: + multi_call_bases.add((base_name, module_path)) + except ValueError: + pass + + tests_to_modify = { + index: base_name + for index, base_name, module_path in call_1_tests + if (base_name, module_path) not in multi_call_bases + } + + for i, (benchmark_key, total_time, func_time, percentage) in enumerate(sorted_tests): + module_path = benchmark_key.module_path + test_function_display = tests_to_modify.get(i, benchmark_key.function_name) if total_time == 0.0: - table.add_row( - module_path, - test_function, - "N/A", - "N/A", - "N/A" - ) + table.add_row(module_path, test_function_display, "N/A", "N/A", "N/A") else: table.add_row( - module_path, - test_function, - f"{total_time:.3f}", - f"{func_time:.3f}", - f"{percentage:.2f}" + module_path, test_function_display, f"{total_time:.3f}", f"{func_time:.3f}", f"{percentage:.2f}" ) - # Print the table console.print(table) def process_benchmark_data( - replay_performance_gain: dict[BenchmarkKey, float], - fto_benchmark_timings: dict[BenchmarkKey, int], - total_benchmark_timings: dict[BenchmarkKey, int] + replay_performance_gain: dict[BenchmarkKey, float], + fto_benchmark_timings: dict[BenchmarkKey, int], + total_benchmark_timings: dict[BenchmarkKey, int], ) -> Optional[ProcessedBenchmarkInfo]: """Process benchmark data and generate detailed benchmark information. @@ -109,19 +110,25 @@ def process_benchmark_data( benchmark_details = [] for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items(): - total_benchmark_timing = total_benchmark_timings.get(benchmark_key, 0) if total_benchmark_timing == 0: continue # Skip benchmarks with zero timing # Calculate expected new benchmark timing - expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + ( - 1 / (replay_performance_gain[benchmark_key] + 1) - ) * og_benchmark_timing + expected_new_benchmark_timing = ( + total_benchmark_timing + - og_benchmark_timing + + (1 / (replay_performance_gain[benchmark_key] + 1)) * og_benchmark_timing + ) # Calculate speedup - benchmark_speedup_percent = performance_gain(original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing)) * 100 + benchmark_speedup_percent = ( + performance_gain( + original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing) + ) + * 100 + ) benchmark_details.append( BenchmarkDetail( @@ -129,7 +136,7 @@ def process_benchmark_data( test_function=benchmark_key.function_name, original_timing=humanize_runtime(int(total_benchmark_timing)), expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)), - speedup_percent=benchmark_speedup_percent + speedup_percent=benchmark_speedup_percent, ) ) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 3ae28c65b..e28ad915c 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -2,6 +2,7 @@ import ast import os +import shutil import site from functools import lru_cache from pathlib import Path @@ -118,4 +119,9 @@ def has_any_async_functions(code: str) -> bool: def cleanup_paths(paths: list[Path]) -> None: for path in paths: - path.unlink(missing_ok=True) + if not path or not path.exists(): + continue + if path.is_dir(): + shutil.rmtree(path, ignore_errors=True) + else: + path.unlink(missing_ok=True) diff --git 
a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 8adfb4e00..d7578f9b6 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -201,6 +201,7 @@ def get_functions_to_optimize( functions, test_cfg.tests_root, ignore_paths, project_root, module_root ) logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize") + console.rule() return filtered_modified_functions, functions_count diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 946e3e822..2b9130a24 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -2,7 +2,6 @@ import ast import os -import shutil import tempfile import time from collections import defaultdict @@ -18,7 +17,7 @@ from codeflash.cli_cmds.console import console, logger, progress_bar from codeflash.code_utils import env_utils from codeflash.code_utils.code_replacer import normalize_code, normalize_node -from codeflash.code_utils.code_utils import get_run_tmp_file +from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast from codeflash.discovery.discover_unit_tests import discover_unit_tests from codeflash.discovery.functions_to_optimize import get_functions_to_optimize @@ -52,6 +51,11 @@ def __init__(self, args: Namespace) -> None: self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None) self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None self.replay_tests_dir = None + + self.test_cfg.concolic_test_root_dir = Path( + tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_") + ) + def create_function_optimizer( self, function_to_optimize: FunctionToOptimize, @@ -71,7 +75,7 @@ def create_function_optimizer( args=self.args, function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None, total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None, - replay_tests_dir = self.replay_tests_dir + replay_tests_dir=self.replay_tests_dir, ) def run(self) -> None: @@ -81,6 +85,7 @@ def run(self) -> None: if not env_utils.ensure_codeflash_api_key(): return function_optimizer = None + trace_file = None file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]] num_optimizable_functions: int @@ -98,10 +103,7 @@ def run(self) -> None: function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {} total_benchmark_timings: dict[BenchmarkKey, int] = {} if self.args.benchmark and num_optimizable_functions > 0: - with progress_bar( - f"Running benchmarks in {self.args.benchmarks_root}", - transient=True, - ): + with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True): # Insert decorator file_path_to_source_code = defaultdict(str) for file in file_to_funcs_to_optimize: @@ -113,15 +115,23 @@ def run(self) -> None: if trace_file.exists(): trace_file.unlink() - self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)) - trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file) # Run all tests that use pytest-benchmark + self.replay_tests_dir = Path( + tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root) + ) + trace_benchmarks_pytest( + self.args.benchmarks_root, 
self.args.tests_root, self.args.project_root, trace_file + ) # Run all tests that use pytest-benchmark replay_count = generate_replay_test(trace_file, self.replay_tests_dir) if replay_count == 0: - logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization") + logger.info( + f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization" + ) else: function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file) total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file) - function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) + function_to_results = validate_and_format_benchmark_table( + function_benchmark_timings, total_benchmark_timings + ) print_benchmark_table(function_to_results) except Exception as e: logger.info(f"Error while tracing existing benchmarks: {e}") @@ -131,12 +141,9 @@ def run(self) -> None: for file in file_path_to_source_code: with file.open("w", encoding="utf8") as f: f.write(file_path_to_source_code[file]) + self.cleanup() optimizations_found: int = 0 function_iterator_count: int = 0 - if self.args.test_framework == "pytest": - self.test_cfg.concolic_test_root_dir = Path( - tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_") - ) try: ph("cli-optimize-functions-to-optimize", {"num_functions": num_optimizable_functions}) if num_optimizable_functions == 0: @@ -148,11 +155,12 @@ def run(self) -> None: function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg) num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()]) console.rule() - logger.info(f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}") + logger.info( + f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}" + ) console.rule() ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests}) - for original_module_path in file_to_funcs_to_optimize: logger.info(f"Examining file {original_module_path!s}…") console.rule() @@ -212,14 +220,26 @@ def run(self) -> None: qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root( self.args.project_root ) - if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings: + if ( + self.args.benchmark + and function_benchmark_timings + and qualified_name_w_module in function_benchmark_timings + and total_benchmark_timings + ): function_optimizer = self.create_function_optimizer( - function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings + function_to_optimize, + function_to_optimize_ast, + function_to_tests, + validated_original_code[original_module_path].source_code, + function_benchmark_timings[qualified_name_w_module], + total_benchmark_timings, ) else: function_optimizer = self.create_function_optimizer( - function_to_optimize, function_to_optimize_ast, function_to_tests, - validated_original_code[original_module_path].source_code + function_to_optimize, + function_to_optimize_ast, + function_to_tests, + validated_original_code[original_module_path].source_code, ) 
best_optimization = function_optimizer.optimize_function() @@ -235,23 +255,44 @@ def run(self) -> None: elif self.args.all: logger.info("✨ All functions have been optimized! ✨") finally: - if function_optimizer: - for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files: - test_file.instrumented_behavior_file_path.unlink(missing_ok=True) - test_file.benchmarking_file_path.unlink(missing_ok=True) - for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files: - test_file.instrumented_behavior_file_path.unlink(missing_ok=True) - test_file.benchmarking_file_path.unlink(missing_ok=True) - for test_file in function_optimizer.test_files.get_by_type(TestType.CONCOLIC_COVERAGE_TEST).test_files: - test_file.instrumented_behavior_file_path.unlink(missing_ok=True) - if function_optimizer.test_cfg.concolic_test_root_dir: - shutil.rmtree(function_optimizer.test_cfg.concolic_test_root_dir, ignore_errors=True) - if self.args.benchmark: - if self.replay_tests_dir.exists(): - shutil.rmtree(self.replay_tests_dir, ignore_errors=True) - trace_file.unlink(missing_ok=True) - if hasattr(get_run_tmp_file, "tmpdir"): - get_run_tmp_file.tmpdir.cleanup() + self.cleanup(function_optimizer=function_optimizer) + + def cleanup(self, function_optimizer: FunctionOptimizer | None = None) -> None: + paths_to_cleanup: list[Path] = [] + if function_optimizer: + paths_to_cleanup.extend( + test_file.instrumented_behavior_file_path + for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files + ) + paths_to_cleanup.extend( + test_file.benchmarking_file_path + for test_file in function_optimizer.test_files.get_by_type(TestType.GENERATED_REGRESSION).test_files + ) + paths_to_cleanup.extend( + test_file.instrumented_behavior_file_path + for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files + ) + paths_to_cleanup.extend( + test_file.benchmarking_file_path + for test_file in function_optimizer.test_files.get_by_type(TestType.EXISTING_UNIT_TEST).test_files + ) + paths_to_cleanup.extend( + test_file.instrumented_behavior_file_path + for test_file in function_optimizer.test_files.get_by_type(TestType.CONCOLIC_COVERAGE_TEST).test_files + ) + paths_to_cleanup.extend( + test_file.benchmarking_file_path + for test_file in function_optimizer.test_files.get_by_type(TestType.REPLAY_TEST).test_files + ) + + paths_to_cleanup.extend( + path for path in {self.replay_tests_dir, self.test_cfg.concolic_test_root_dir} if path and path.exists() + ) + + cleanup_paths(paths_to_cleanup) + + if hasattr(get_run_tmp_file, "tmpdir"): + get_run_tmp_file.tmpdir.cleanup() def run_with_args(args: Namespace) -> None: diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 2228559f9..2d5701694 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -66,7 +66,11 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes len_next_bytes = file.read(4) len_next = int.from_bytes(len_next_bytes, byteorder="big") invocation_id_bytes = file.read(len_next) - invocation_id = invocation_id_bytes.decode("ascii") + try: + invocation_id = invocation_id_bytes.decode("ascii") + except UnicodeDecodeError as e: + logger.warning(f"Failed to decode invocation_id_bytes as ASCII in {file_location}: {e}. 
Skipping entry.") + continue invocation_id_object = InvocationId.from_str_id(encoded_test_name, invocation_id) test_file_path = file_path_from_module_name(