diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py index 20743fd56..249acdeb3 100644 --- a/codeflash/benchmarking/codeflash_trace.py +++ b/codeflash/benchmarking/codeflash_trace.py @@ -25,6 +25,7 @@ def setup(self, trace_path: str) -> None: """Set up the database connection for direct writing. Args: + ---- trace_path: Path to the trace database file """ @@ -52,6 +53,7 @@ def write_function_timings(self) -> None: """Write function call data directly to the database. Args: + ---- data: List of function call data tuples to write """ @@ -94,9 +96,11 @@ def __call__(self, func: Callable) -> Callable: """Use as a decorator to trace function execution. Args: + ---- func: The function to be decorated Returns: + ------- The wrapped function """ diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py index 761e91f71..04b12018a 100644 --- a/codeflash/benchmarking/instrument_codeflash_trace.py +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct """Add codeflash_trace to a function. Args: + ---- code: The source code as a string functions_to_optimize: List of FunctionToOptimize instances containing function details Returns: + ------- The modified source code as a string """ diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index 45fabef14..6516fba38 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -74,9 +74,11 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark """Process the trace file and extract timing data for all functions. Args: + ---- trace_path: Path to the trace file Returns: + ------- A nested dictionary where: - Outer keys are module_name.qualified_name (module.class.function) - Inner keys are of type BenchmarkKey @@ -132,9 +134,11 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]: """Extract total benchmark timings from trace files. Args: + ---- trace_path: Path to the trace file Returns: + ------- A dictionary mapping where: - Keys are of type BenchmarkKey - Values are total benchmark timing in milliseconds (with overhead subtracted) diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index c44649632..c2e1889db 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -55,12 +55,14 @@ def create_trace_replay_test_code( """Create a replay test for functions based on trace data. Args: + ---- trace_file: Path to the SQLite database file functions_data: List of dictionaries with function info extracted from DB test_framework: 'pytest' or 'unittest' max_run_count: Maximum number of runs to include in the test Returns: + ------- A string containing the test code """ @@ -218,12 +220,14 @@ def generate_replay_test( """Generate multiple replay tests from the traced function calls, grouped by benchmark. 
Args: + ---- trace_file_path: Path to the SQLite database file output_dir: Directory to write the generated tests (if None, only returns the code) test_framework: 'pytest' or 'unittest' max_run_count: Maximum number of runs to include per function Returns: + ------- Dictionary mapping benchmark names to generated test code """ diff --git a/codeflash/benchmarking/utils.py b/codeflash/benchmarking/utils.py index 5dae99444..db89c4c33 100644 --- a/codeflash/benchmarking/utils.py +++ b/codeflash/benchmarking/utils.py @@ -83,11 +83,13 @@ def process_benchmark_data( """Process benchmark data and generate detailed benchmark information. Args: + ---- replay_performance_gain: The performance gain from replay fto_benchmark_timings: Function to optimize benchmark timings total_benchmark_timings: Total benchmark timings Returns: + ------- ProcessedBenchmarkInfo containing processed benchmark details """ diff --git a/codeflash/cli_cmds/logging_config.py b/codeflash/cli_cmds/logging_config.py index 8bd4a48d9..e546836fc 100644 --- a/codeflash/cli_cmds/logging_config.py +++ b/codeflash/cli_cmds/logging_config.py @@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None: ], force=True, ) - logging.info("Verbose DEBUG logging enabled") # noqa: LOG015 + logging.info("Verbose DEBUG logging enabled") else: - logging.info("Logging level set to INFO") # noqa: LOG015 + logging.info("Logging level set to INFO") console.rule() diff --git a/codeflash/code_utils/checkpoint.py b/codeflash/code_utils/checkpoint.py index 8a333c3fe..4c69ecc58 100644 --- a/codeflash/code_utils/checkpoint.py +++ b/codeflash/code_utils/checkpoint.py @@ -47,6 +47,7 @@ def add_function_to_checkpoint( """Add a function to the checkpoint after it has been processed. Args: + ---- function_fully_qualified_name: The fully qualified name of the function status: Status of optimization (e.g., "optimized", "failed", "skipped") additional_info: Any additional information to store about the function @@ -104,7 +105,8 @@ def cleanup(self) -> None: def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]: """Get information about all processed functions, regardless of status. 
- Returns: + Returns + ------- Dictionary mapping function names to their processing information """ diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py new file mode 100644 index 000000000..4e6e31072 --- /dev/null +++ b/codeflash/code_utils/edit_generated_tests.py @@ -0,0 +1,141 @@ +import re + +import libcst as cst + +from codeflash.cli_cmds.console import logger +from codeflash.code_utils.time_utils import format_time +from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults + + +def remove_functions_from_generated_tests( + generated_tests: GeneratedTestsList, test_functions_to_remove: list[str] +) -> GeneratedTestsList: + new_generated_tests = [] + for generated_test in generated_tests.generated_tests: + for test_function in test_functions_to_remove: + function_pattern = re.compile( + rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)", + re.DOTALL, + ) + + match = function_pattern.search(generated_test.generated_original_test_source) + + if match is None or "@pytest.mark.parametrize" in match.group(0): + continue + + generated_test.generated_original_test_source = function_pattern.sub( + "", generated_test.generated_original_test_source + ) + + new_generated_tests.append(generated_test) + + return GeneratedTestsList(generated_tests=new_generated_tests) + + +def add_runtime_comments_to_generated_tests( + generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults +) -> GeneratedTestsList: + """Add runtime performance comments to function calls in generated tests.""" + # Create dictionaries for fast lookup of runtime data + original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case() + optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case() + + class RuntimeCommentTransformer(cst.CSTTransformer): + def __init__(self) -> None: + self.in_test_function = False + self.current_test_name: str | None = None + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + if node.name.value.startswith("test_"): + self.in_test_function = True + self.current_test_name = node.name.value + else: + self.in_test_function = False + self.current_test_name = None + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + if original_node.name.value.startswith("test_"): + self.in_test_function = False + self.current_test_name = None + return updated_node + + def leave_SimpleStatementLine( + self, + original_node: cst.SimpleStatementLine, # noqa: ARG002 + updated_node: cst.SimpleStatementLine, + ) -> cst.SimpleStatementLine: + if not self.in_test_function or not self.current_test_name: + return updated_node + + # Look for assignment statements that assign to codeflash_output + # Handle both single statements and multiple statements on one line + codeflash_assignment_found = False + for stmt in updated_node.body: + if isinstance(stmt, cst.Assign) and ( + len(stmt.targets) == 1 + and isinstance(stmt.targets[0].target, cst.Name) + and stmt.targets[0].target.value == "codeflash_output" + ): + codeflash_assignment_found = True + break + + if codeflash_assignment_found: + # Find matching test cases by looking for this test function name in the test results + matching_original_times = [] + matching_optimized_times = [] + + for invocation_id, runtimes in original_runtime_by_test.items(): + if invocation_id.test_function_name == 
self.current_test_name: + matching_original_times.extend(runtimes) + + for invocation_id, runtimes in optimized_runtime_by_test.items(): + if invocation_id.test_function_name == self.current_test_name: + matching_optimized_times.extend(runtimes) + + if matching_original_times and matching_optimized_times: + original_time = min(matching_original_times) + optimized_time = min(matching_optimized_times) + + # Create the runtime comment + comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}" + + # Add comment to the trailing whitespace + new_trailing_whitespace = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment(comment_text), + newline=updated_node.trailing_whitespace.newline, + ) + + return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace) + + return updated_node + + # Process each generated test + modified_tests = [] + for test in generated_tests.generated_tests: + try: + # Parse the test source code + tree = cst.parse_module(test.generated_original_test_source) + + # Transform the tree to add runtime comments + transformer = RuntimeCommentTransformer() + modified_tree = tree.visit(transformer) + + # Convert back to source code + modified_source = modified_tree.code + + # Create a new GeneratedTests object with the modified source + modified_test = GeneratedTests( + generated_original_test_source=modified_source, + instrumented_behavior_test_source=test.instrumented_behavior_test_source, + instrumented_perf_test_source=test.instrumented_perf_test_source, + behavior_file_path=test.behavior_file_path, + perf_file_path=test.perf_file_path, + ) + modified_tests.append(modified_test) + except Exception as e: + # If parsing fails, keep the original test + logger.debug(f"Failed to add runtime comments to test: {e}") + modified_tests.append(test) + + return GeneratedTestsList(generated_tests=modified_tests) diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/code_utils/line_profile_utils.py index 935e30356..498571578 100644 --- a/codeflash/code_utils/line_profile_utils.py +++ b/codeflash/code_utils/line_profile_utils.py @@ -24,6 +24,7 @@ def __init__(self, qualified_name: str, decorator_name: str) -> None: """Initialize the transformer. Args: + ---- qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). decorator_name: The name of the decorator to add. @@ -144,11 +145,13 @@ def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str, """Add a decorator to a function with the exact qualified name in the source code. Args: + ---- module: The Python source code as a CST module. qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). decorator_name: The name of the decorator to add. Returns: + ------- The modified CST module. 
""" diff --git a/codeflash/code_utils/remove_generated_tests.py b/codeflash/code_utils/remove_generated_tests.py deleted file mode 100644 index 25eb58965..000000000 --- a/codeflash/code_utils/remove_generated_tests.py +++ /dev/null @@ -1,28 +0,0 @@ -import re - -from codeflash.models.models import GeneratedTestsList - - -def remove_functions_from_generated_tests( - generated_tests: GeneratedTestsList, test_functions_to_remove: list[str] -) -> GeneratedTestsList: - new_generated_tests = [] - for generated_test in generated_tests.generated_tests: - for test_function in test_functions_to_remove: - function_pattern = re.compile( - rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)", - re.DOTALL, - ) - - match = function_pattern.search(generated_test.generated_original_test_source) - - if match is None or "@pytest.mark.parametrize" in match.group(0): - continue - - generated_test.generated_original_test_source = function_pattern.sub( - "", generated_test.generated_original_test_source - ) - - new_generated_tests.append(generated_test) - - return GeneratedTestsList(generated_tests=new_generated_tests) diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index aaf74fc93..4e43e7239 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -49,3 +49,40 @@ def humanize_runtime(time_in_ns: int) -> str: runtime_human = runtime_human_parts[0] return f"{runtime_human} {units}" + + +def format_time(nanoseconds: int) -> str: + """Format nanoseconds into a human-readable string with 3 significant digits when needed.""" + # Inlined significant digit check: >= 3 digits if value >= 100 + if nanoseconds < 1_000: + return f"{nanoseconds}ns" + if nanoseconds < 1_000_000: + microseconds_int = nanoseconds // 1_000 + if microseconds_int >= 100: + return f"{microseconds_int}μs" + microseconds = nanoseconds / 1_000 + # Format with precision: 3 significant digits + if microseconds >= 100: + return f"{microseconds:.0f}μs" + if microseconds >= 10: + return f"{microseconds:.1f}μs" + return f"{microseconds:.2f}μs" + if nanoseconds < 1_000_000_000: + milliseconds_int = nanoseconds // 1_000_000 + if milliseconds_int >= 100: + return f"{milliseconds_int}ms" + milliseconds = nanoseconds / 1_000_000 + if milliseconds >= 100: + return f"{milliseconds:.0f}ms" + if milliseconds >= 10: + return f"{milliseconds:.1f}ms" + return f"{milliseconds:.2f}ms" + seconds_int = nanoseconds // 1_000_000_000 + if seconds_int >= 100: + return f"{seconds_int}s" + seconds = nanoseconds / 1_000_000_000 + if seconds >= 100: + return f"{seconds:.0f}s" + if seconds >= 10: + return f"{seconds:.1f}s" + return f"{seconds:.2f}s" diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index ba8929343..934d3053b 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -3,11 +3,9 @@ import os from collections import defaultdict from itertools import chain -from pathlib import Path # noqa: TC003 from typing import TYPE_CHECKING import libcst as cst -from libcst import CSTNode # noqa: TC002 from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects @@ -24,7 +22,10 @@ from codeflash.optimization.function_context import belongs_to_function_qualified if TYPE_CHECKING: + from pathlib import Path + from jedi.api.classes import Name + from libcst import CSTNode def 
get_code_optimization_context( @@ -150,6 +151,7 @@ def extract_code_string_context_from_files( imports, and combines them. Args: + ---- helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions project_root_path: Root path of the project @@ -157,6 +159,7 @@ def extract_code_string_context_from_files( code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) Returns: + ------- CodeString containing the extracted code context with necessary imports """ # noqa: D205 @@ -257,6 +260,7 @@ def extract_code_markdown_context_from_files( imports, and combines them into a structured markdown format. Args: + ---- helpers_of_fto: Dictionary mapping file paths to sets of Function Sources of function to optimize and its helpers helpers_of_helpers: Dictionary mapping file paths to sets of Function Sources of helpers of helper functions project_root_path: Root path of the project @@ -264,6 +268,7 @@ def extract_code_markdown_context_from_files( code_context_type: Type of code context to extract (READ_ONLY, READ_WRITABLE, or TESTGEN) Returns: + ------- CodeStringsMarkdown containing the extracted code context with necessary imports, formatted for inclusion in markdown @@ -382,7 +387,7 @@ def get_function_to_optimize_as_function_source( source_code=name.get_line_code(), jedi_definition=name, ) - except Exception as e: # noqa: PERF203 + except Exception as e: logger.exception(f"Error while getting function source: {e}") continue raise ValueError( @@ -502,7 +507,8 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911 ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions. - Returns: + Returns + ------- (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. @@ -586,7 +592,8 @@ def prune_cst_for_read_only_code( # noqa: PLR0911 ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node for read-only context. - Returns: + Returns + ------- (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. @@ -690,7 +697,8 @@ def prune_cst_for_testgen_code( # noqa: PLR0911 ) -> tuple[cst.CSTNode | None, bool]: """Recursively filter the node for testgen context. - Returns: + Returns + ------- (filtered_node, found_target): filtered_node: The modified CST node or None if it should be removed. found_target: True if a target function was found in this node's subtree. diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 86835e128..53e249495 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -311,10 +311,12 @@ def remove_unused_definitions_recursively( # noqa: PLR0911 """Recursively filter the node to remove unused definitions. 
Args: + ---- node: The CST node to process definitions: Dictionary of definition info Returns: + ------- (filtered_node, used_by_function): filtered_node: The modified CST node or None if it should be removed used_by_function: True if this node or any child is used by qualified functions @@ -450,6 +452,7 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na If a class is referenced by a qualified function, we keep the entire class. Args: + ---- code: The code to process qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname' diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 12aeff3fa..a3cd11b3e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -36,10 +36,13 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) +from codeflash.code_utils.edit_generated_tests import ( + add_runtime_comments_to_generated_tests, + remove_functions_from_generated_tests, +) from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports -from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime from codeflash.context import code_context_extractor @@ -265,10 +268,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 }, ) - generated_tests = remove_functions_from_generated_tests( - generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove - ) - if best_optimization: logger.info("Best candidate:") code_print(best_optimization.candidate.source_code) @@ -295,8 +294,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None, ) - self.log_successful_optimization(explanation, generated_tests, exp_type) - self.replace_function_and_helpers_with_optimized_code( code_context=code_context, optimized_code=best_optimization.candidate.source_code ) @@ -321,6 +318,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 if original_code_baseline.coverage_results else "Coverage data not available" ) + generated_tests = remove_functions_from_generated_tests( + generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove + ) + # Add runtime comments to generated tests before creating the PR + generated_tests = add_runtime_comments_to_generated_tests( + generated_tests, + original_code_baseline.benchmarking_test_results, + best_optimization.winning_benchmarking_test_results, + ) generated_tests_str = "\n\n".join( [test.generated_original_test_source for test in generated_tests.generated_tests] ) @@ -345,6 +351,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 original_helper_code, self.function_to_optimize.file_path, ) + self.log_successful_optimization(explanation, generated_tests, exp_type) if not best_optimization: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") diff --git a/codeflash/picklepatch/pickle_patcher.py b/codeflash/picklepatch/pickle_patcher.py index 
0e08756ab..3f3236f76 100644 --- a/codeflash/picklepatch/pickle_patcher.py +++ b/codeflash/picklepatch/pickle_patcher.py @@ -30,12 +30,14 @@ def dumps(obj: object, protocol: int | None = None, max_depth: int = 100, **kwar """Safely pickle an object, replacing unpicklable parts with placeholders. Args: + ---- obj: The object to pickle protocol: The pickle protocol version to use max_depth: Maximum recursion depth **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -46,9 +48,11 @@ def loads(pickled_data: bytes) -> object: """Unpickle data that may contain placeholders. Args: + ---- pickled_data: Pickled data with possible placeholders Returns: + ------- The unpickled object with placeholders for unpicklable parts """ @@ -59,11 +63,13 @@ def _create_placeholder(obj: object, error_msg: str, path: list[str]) -> PickleP """Create a placeholder for an unpicklable object. Args: + ---- obj: The original unpicklable object error_msg: Error message explaining why it couldn't be pickled path: Path to this object in the object graph Returns: + ------- PicklePlaceholder: A placeholder object """ @@ -91,12 +97,14 @@ def _pickle( """Try to pickle an object using pickle first, then dill. If both fail, create a placeholder. Args: + ---- obj: The object to pickle path: Path to this object in the object graph protocol: The pickle protocol version to use **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- tuple: (success, result) where success is a boolean and result is either: - Pickled bytes if successful - Error message if not successful @@ -123,6 +131,7 @@ def _recursive_pickle( # noqa: PLR0911 """Recursively try to pickle an object, replacing unpicklable parts with placeholders. Args: + ---- obj: The object to pickle max_depth: Maximum recursion depth path: Current path in the object graph @@ -130,6 +139,7 @@ def _recursive_pickle( # noqa: PLR0911 **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -185,6 +195,7 @@ def _handle_dict( """Handle pickling for dictionary objects. Args: + ---- obj_dict: The dictionary to pickle max_depth: Maximum recursion depth error_msg: Error message from the original pickling attempt @@ -193,6 +204,7 @@ def _handle_dict( **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -249,6 +261,7 @@ def _handle_sequence( """Handle pickling for sequence types (list, tuple, set). Args: + ---- obj_seq: The sequence to pickle max_depth: Maximum recursion depth error_msg: Error message from the original pickling attempt @@ -257,6 +270,7 @@ def _handle_sequence( **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ @@ -305,6 +319,7 @@ def _handle_object( """Handle pickling for custom objects with __dict__. 
Args: + ---- obj: The object to pickle max_depth: Maximum recursion depth error_msg: Error message from the original pickling attempt @@ -313,6 +328,7 @@ def _handle_object( **kwargs: Additional arguments for pickle/dill.dumps Returns: + ------- bytes: Pickled data with placeholders for unpicklable objects """ diff --git a/codeflash/picklepatch/pickle_placeholder.py b/codeflash/picklepatch/pickle_placeholder.py index 50e9c5aa3..4268a9146 100644 --- a/codeflash/picklepatch/pickle_placeholder.py +++ b/codeflash/picklepatch/pickle_placeholder.py @@ -18,6 +18,7 @@ def __init__(self, obj_type: str, obj_str: str, error_msg: str, path: list[str] """Initialize a placeholder for an unpicklable object. Args: + ---- obj_type (str): The type name of the original object obj_str (str): String representation of the original object error_msg (str): The error message that occurred during pickling diff --git a/codeflash/tracing/profile_stats.py b/codeflash/tracing/profile_stats.py index c2ed7cb49..8e2fc5e28 100644 --- a/codeflash/tracing/profile_stats.py +++ b/codeflash/tracing/profile_stats.py @@ -55,7 +55,7 @@ def print_stats(self, *amount) -> pstats.Stats: # noqa: ANN002 print(indent, self.total_calls, "function calls", end=" ", file=self.stream) if self.total_calls != self.prim_calls: - print("(%d primitive calls)" % self.prim_calls, end=" ", file=self.stream) # noqa: UP031 + print(f"({self.prim_calls:d} primitive calls)", end=" ", file=self.stream) time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit] print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream) print(file=self.stream) diff --git a/mypy_allowlist.txt b/mypy_allowlist.txt index 6abaa8894..6a070b606 100644 --- a/mypy_allowlist.txt +++ b/mypy_allowlist.txt @@ -29,7 +29,7 @@ codeflash/code_utils/time_utils.py codeflash/code_utils/env_utils.py codeflash/code_utils/config_consts.py codeflash/code_utils/static_analysis.py -codeflash/code_utils/remove_generated_tests.py +codeflash/code_utils/edit_generated_tests.py codeflash/cli_cmds/console_constants.py codeflash/cli_cmds/logging_config.py codeflash/cli_cmds/__init__.py diff --git a/pyproject.toml b/pyproject.toml index c3e48f889..cb8f2c7d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -191,7 +191,9 @@ ignore = [ "T201", "PGH004", "S301", - "D104" + "D104", + "PERF203", + "LOG015" ] [tool.ruff.lint.flake8-type-checking] diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py new file mode 100644 index 000000000..51c1ef052 --- /dev/null +++ b/tests/test_add_runtime_comments.py @@ -0,0 +1,464 @@ +"""Tests for the add_runtime_comments_to_generated_tests functionality.""" + +from pathlib import Path + +from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests +from codeflash.models.models import ( + FunctionTestInvocation, + GeneratedTests, + GeneratedTestsList, + InvocationId, + TestResults, + TestType, + VerificationType, +) + + +class TestAddRuntimeComments: + """Test cases for add_runtime_comments_to_generated_tests method.""" + + def create_test_invocation( + self, test_function_name: str, runtime: int, loop_index: int = 1, iteration_id: str = "1", did_pass: bool = True + ) -> FunctionTestInvocation: + """Helper to create test invocation objects.""" + return FunctionTestInvocation( + loop_index=loop_index, + id=InvocationId( + test_module_path="test_module", + test_class_name=None, + test_function_name=test_function_name, + function_getting_tested="test_function", + 
iteration_id=iteration_id, + ), + file_name=Path("test.py"), + did_pass=did_pass, + runtime=runtime, + test_framework="pytest", + test_type=TestType.GENERATED_REGRESSION, + return_value=None, + timed_out=False, + verification_type=VerificationType.FUNCTION_CALL, + ) + + def test_basic_runtime_comment_addition(self): + """Test basic functionality of adding runtime comments.""" + # Create test source code + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Add test invocations with different runtimes + original_invocation = self.create_test_invocation("test_bubble_sort", 500_000) # 500μs + optimized_invocation = self.create_test_invocation("test_bubble_sort", 300_000) # 300μs + + original_test_results.add(original_invocation) + optimized_test_results.add(optimized_invocation) + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that comments were added + modified_source = result.generated_tests[0].generated_original_test_source + assert "# 500μs -> 300μs" in modified_source + assert "codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source + + def test_multiple_test_functions(self): + """Test handling multiple test functions in the same file.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] + +def test_quick_sort(): + codeflash_output = quick_sort([5, 2, 8]) + assert codeflash_output == [2, 5, 8] + +def helper_function(): + return "not a test" +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results for both functions + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Add test invocations for both test functions + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000)) + + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + modified_source = result.generated_tests[0].generated_original_test_source + + # Check that comments were added to both test functions + assert "# 500μs -> 300μs" in modified_source + assert "# 800μs -> 600μs" in modified_source + # Helper function should not have comments + assert ( + "helper_function():" in modified_source + and "# " not in modified_source.split("helper_function():")[1].split("\n")[0] + ) + + def test_different_time_formats(self): + """Test that 
different time ranges are formatted correctly with new precision rules.""" + test_cases = [ + (999, 500, "999ns -> 500ns"), # nanoseconds + (25_000, 18_000, "25.0μs -> 18.0μs"), # microseconds with precision + (500_000, 300_000, "500μs -> 300μs"), # microseconds full integers + (1_500_000, 800_000, "1.50ms -> 800μs"), # milliseconds with precision + (365_000_000, 290_000_000, "365ms -> 290ms"), # milliseconds full integers + (2_000_000_000, 1_500_000_000, "2.00s -> 1.50s"), # seconds with precision + ] + + for original_time, optimized_time, expected_comment in test_cases: + test_source = """def test_function(): + codeflash_output = some_function() + assert codeflash_output is not None +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_function", original_time)) + optimized_test_results.add(self.create_test_invocation("test_function", optimized_time)) + + # Test the functionality + result = add_runtime_comments_to_generated_tests( + generated_tests, original_test_results, optimized_test_results + ) + + modified_source = result.generated_tests[0].generated_original_test_source + assert f"# {expected_comment}" in modified_source + + def test_missing_test_results(self): + """Test behavior when test results are missing for a test function.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create empty test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that no comments were added + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged + + def test_partial_test_results(self): + """Test behavior when only one set of test results is available.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results with only original data + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + # No optimized results + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + 
+ # Check that no comments were added + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged + + def test_multiple_runtimes_uses_minimum(self): + """Test that when multiple runtimes exist, the minimum is used.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results with multiple loop iterations + original_test_results = TestResults() + optimized_test_results = TestResults() + + # Add multiple runs with different runtimes + original_test_results.add(self.create_test_invocation("test_bubble_sort", 600_000, loop_index=1)) + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000, loop_index=2)) + original_test_results.add(self.create_test_invocation("test_bubble_sort", 550_000, loop_index=3)) + + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 350_000, loop_index=1)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000, loop_index=2)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 320_000, loop_index=3)) + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that minimum times were used (500μs -> 300μs) + modified_source = result.generated_tests[0].generated_original_test_source + assert "# 500μs -> 300μs" in modified_source + + def test_no_codeflash_output_assignment(self): + """Test behavior when test doesn't have codeflash_output assignment.""" + test_source = """def test_bubble_sort(): + result = bubble_sort([3, 1, 2]) + assert result == [1, 2, 3] +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that no comments were added (no codeflash_output assignment) + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged + + def test_invalid_python_code_handling(self): + """Test behavior when test source code is invalid Python.""" + test_source = """def test_bubble_sort(: + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" # Invalid syntax: extra colon + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + 
perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + + # Test the functionality - should handle parse error gracefully + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that original test is preserved when parsing fails + modified_source = result.generated_tests[0].generated_original_test_source + assert modified_source == test_source # Should be unchanged due to parse error + + def test_multiple_generated_tests(self): + """Test handling multiple generated test objects.""" + test_source_1 = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + test_source_2 = """def test_quick_sort(): + codeflash_output = quick_sort([5, 2, 8]) + assert codeflash_output == [2, 5, 8] +""" + + generated_test_1 = GeneratedTests( + generated_original_test_source=test_source_1, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior_1.py"), + perf_file_path=Path("test_perf_1.py"), + ) + + generated_test_2 = GeneratedTests( + generated_original_test_source=test_source_2, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior_2.py"), + perf_file_path=Path("test_perf_2.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test_1, generated_test_2]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + original_test_results.add(self.create_test_invocation("test_quick_sort", 800_000)) + + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that comments were added to both test files + modified_source_1 = result.generated_tests[0].generated_original_test_source + modified_source_2 = result.generated_tests[1].generated_original_test_source + + assert "# 500μs -> 300μs" in modified_source_1 + assert "# 800μs -> 600μs" in modified_source_2 + + def test_preserved_test_attributes(self): + """Test that other test attributes are preserved during modification.""" + test_source = """def test_bubble_sort(): + codeflash_output = bubble_sort([3, 1, 2]) + assert codeflash_output == [1, 2, 3] +""" + + original_behavior_source = "behavior test source" + original_perf_source = "perf test source" + original_behavior_path = Path("test_behavior.py") + original_perf_path = Path("test_perf.py") + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source=original_behavior_source, + instrumented_perf_test_source=original_perf_source, + behavior_file_path=original_behavior_path, + perf_file_path=original_perf_path, + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = 
TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) + optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that other attributes are preserved + modified_test = result.generated_tests[0] + assert modified_test.instrumented_behavior_test_source == original_behavior_source + assert modified_test.instrumented_perf_test_source == original_perf_source + assert modified_test.behavior_file_path == original_behavior_path + assert modified_test.perf_file_path == original_perf_path + + # Check that only the generated_original_test_source was modified + assert "# 500μs -> 300μs" in modified_test.generated_original_test_source + + def test_multistatement_line_handling(self): + """Test that runtime comments work correctly with multiple statements on one line.""" + test_source = """def test_mutation_of_input(): + # Test that the input list is mutated in-place and returned + arr = [3, 1, 2] + codeflash_output = sorter(arr); result = codeflash_output + assert result == [1, 2, 3] + assert arr == [1, 2, 3] # Input should be mutated +""" + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("test_behavior.py"), + perf_file_path=Path("test_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Create test results + original_test_results = TestResults() + optimized_test_results = TestResults() + + original_test_results.add(self.create_test_invocation("test_mutation_of_input", 19_000)) # 19μs + optimized_test_results.add(self.create_test_invocation("test_mutation_of_input", 14_000)) # 14μs + + # Test the functionality + result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + + # Check that comments were added to the correct line + modified_source = result.generated_tests[0].generated_original_test_source + assert "# 19.0μs -> 14.0μs" in modified_source + + # Verify the comment is on the line with codeflash_output assignment + lines = modified_source.split("\n") + codeflash_line = None + for line in lines: + if "codeflash_output = sorter(arr)" in line: + codeflash_line = line + break + + assert codeflash_line is not None, "Could not find codeflash_output assignment line" + assert "# 19.0μs -> 14.0μs" in codeflash_line, f"Comment not found in the correct line: {codeflash_line}" diff --git a/tests/test_remove_functions_from_generated_tests.py b/tests/test_remove_functions_from_generated_tests.py index dc2a14468..c6fd9a7aa 100644 --- a/tests/test_remove_functions_from_generated_tests.py +++ b/tests/test_remove_functions_from_generated_tests.py @@ -1,8 +1,7 @@ from pathlib import Path import pytest - -from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests +from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests from codeflash.models.models import GeneratedTests, GeneratedTestsList
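
For reviewer reference, a minimal sketch (not applied by this patch) of how the new format_time helper added in codeflash/code_utils/time_utils.py renders values at each magnitude; the expected strings follow directly from the thresholds in the function body above and mirror the cases exercised in test_different_time_formats:

    from codeflash.code_utils.time_utils import format_time

    # Roughly three significant digits: integer rendering at >= 100 of a unit,
    # one decimal place between 10 and 100, two decimal places below 10.
    assert format_time(999) == "999ns"            # below 1_000 stays in nanoseconds
    assert format_time(25_000) == "25.0μs"        # < 100μs -> one decimal place
    assert format_time(500_000) == "500μs"        # >= 100μs -> integer microseconds
    assert format_time(1_500_000) == "1.50ms"     # < 10ms -> two decimal places
    assert format_time(365_000_000) == "365ms"    # >= 100ms -> integer milliseconds
    assert format_time(2_000_000_000) == "2.00s"  # < 10s -> two decimal places

These formatted pairs are what add_runtime_comments_to_generated_tests appends as trailing "# original -> optimized" comments on the codeflash_output assignment lines of generated tests.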