diff --git a/codeflash/code_utils/edit_generated_tests.py b/codeflash/code_utils/edit_generated_tests.py
index 4e6e31072..547dbc92b 100644
--- a/codeflash/code_utils/edit_generated_tests.py
+++ b/codeflash/code_utils/edit_generated_tests.py
@@ -1,10 +1,14 @@
+import os
 import re
+from pathlib import Path
 
 import libcst as cst
 
 from codeflash.cli_cmds.console import logger
-from codeflash.code_utils.time_utils import format_time
-from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
+from codeflash.code_utils.time_utils import format_perf, format_time
+from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId
+from codeflash.result.critic import performance_gain
+from codeflash.verification.verification_utils import TestConfig
 
 
 def remove_functions_from_generated_tests(
@@ -33,30 +37,39 @@ def remove_functions_from_generated_tests(
 
 
 def add_runtime_comments_to_generated_tests(
-    generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
+    test_cfg: TestConfig,
+    generated_tests: GeneratedTestsList,
+    original_runtimes: dict[InvocationId, list[int]],
+    optimized_runtimes: dict[InvocationId, list[int]],
 ) -> GeneratedTestsList:
     """Add runtime performance comments to function calls in generated tests."""
-    # Create dictionaries for fast lookup of runtime data
-    original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
-    optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
+    tests_root = test_cfg.tests_root
+    module_root = test_cfg.project_root_path
+    rel_tests_root = tests_root.relative_to(module_root)
 
+    # TODO: reduce for loops to one
     class RuntimeCommentTransformer(cst.CSTTransformer):
-        def __init__(self) -> None:
-            self.in_test_function = False
-            self.current_test_name: str | None = None
+        def __init__(self, test: GeneratedTests, tests_root: Path, rel_tests_root: Path) -> None:
+            self.test = test
+            self.context_stack: list[str] = []
+            self.tests_root = tests_root
+            self.rel_tests_root = rel_tests_root
+
+        def visit_ClassDef(self, node: cst.ClassDef) -> None:
+            # Track when we enter a class
+            self.context_stack.append(node.name.value)
+
+        def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
+            # Pop the context when we leave a class
+            self.context_stack.pop()
+            return updated_node
 
         def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
-            if node.name.value.startswith("test_"):
-                self.in_test_function = True
-                self.current_test_name = node.name.value
-            else:
-                self.in_test_function = False
-                self.current_test_name = None
-
-        def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
-            if original_node.name.value.startswith("test_"):
-                self.in_test_function = False
-                self.current_test_name = None
+            self.context_stack.append(node.name.value)
+
+        def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:  # noqa: ARG002
+            # Pop the context when we leave a function
+            self.context_stack.pop()
             return updated_node
 
         def leave_SimpleStatementLine(
@@ -64,9 +77,6 @@ def leave_SimpleStatementLine(
             original_node: cst.SimpleStatementLine,  # noqa: ARG002
             updated_node: cst.SimpleStatementLine,
         ) -> cst.SimpleStatementLine:
-            if not self.in_test_function or not self.current_test_name:
-                return updated_node
-
             # Look for assignment statements that assign to codeflash_output
             # Handle both
single statements and multiple statements on one line codeflash_assignment_found = False @@ -83,30 +93,65 @@ def leave_SimpleStatementLine( # Find matching test cases by looking for this test function name in the test results matching_original_times = [] matching_optimized_times = [] - - for invocation_id, runtimes in original_runtime_by_test.items(): - if invocation_id.test_function_name == self.current_test_name: + # TODO : will not work if there are multiple test cases with the same name, match filename + test class + test function name + for invocation_id, runtimes in original_runtimes.items(): + qualified_name = ( + invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator] + if invocation_id.test_class_name + else invocation_id.test_function_name + ) + rel_path = ( + Path(invocation_id.test_module_path.replace(".", os.sep)) + .with_suffix(".py") + .relative_to(self.rel_tests_root) + ) + if qualified_name == ".".join(self.context_stack) and rel_path in [ + self.test.behavior_file_path.relative_to(self.tests_root), + self.test.perf_file_path.relative_to(self.tests_root), + ]: matching_original_times.extend(runtimes) - for invocation_id, runtimes in optimized_runtime_by_test.items(): - if invocation_id.test_function_name == self.current_test_name: + for invocation_id, runtimes in optimized_runtimes.items(): + qualified_name = ( + invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator] + if invocation_id.test_class_name + else invocation_id.test_function_name + ) + rel_path = ( + Path(invocation_id.test_module_path.replace(".", os.sep)) + .with_suffix(".py") + .relative_to(self.rel_tests_root) + ) + if qualified_name == ".".join(self.context_stack) and rel_path in [ + self.test.behavior_file_path.relative_to(self.tests_root), + self.test.perf_file_path.relative_to(self.tests_root), + ]: matching_optimized_times.extend(runtimes) if matching_original_times and matching_optimized_times: original_time = min(matching_original_times) optimized_time = min(matching_optimized_times) - - # Create the runtime comment - comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}" - - # Add comment to the trailing whitespace - new_trailing_whitespace = cst.TrailingWhitespace( - whitespace=cst.SimpleWhitespace(" "), - comment=cst.Comment(comment_text), - newline=updated_node.trailing_whitespace.newline, - ) - - return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace) + if original_time != 0 and optimized_time != 0: + perf_gain = format_perf( + abs( + performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) + * 100 + ) + ) + status = "slower" if optimized_time > original_time else "faster" + # Create the runtime comment + comment_text = ( + f"# {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})" + ) + + # Add comment to the trailing whitespace + new_trailing_whitespace = cst.TrailingWhitespace( + whitespace=cst.SimpleWhitespace(" "), + comment=cst.Comment(comment_text), + newline=updated_node.trailing_whitespace.newline, + ) + + return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace) return updated_node @@ -118,7 +163,7 @@ def leave_SimpleStatementLine( tree = cst.parse_module(test.generated_original_test_source) # Transform the tree to add runtime comments - transformer = RuntimeCommentTransformer() + transformer = RuntimeCommentTransformer(test, tests_root, rel_tests_root) modified_tree = 
tree.visit(transformer) # Convert back to source code diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index 89273fe2d..4e32eedab 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -85,3 +85,15 @@ def format_time(nanoseconds: int) -> str: # This should never be reached, but included for completeness return f"{nanoseconds}ns" + + +def format_perf(percentage: float) -> str: + """Format percentage into a human-readable string with 3 significant digits when needed.""" + percentage_abs = abs(percentage) + if percentage_abs >= 100: + return f"{percentage:.0f}" + if percentage_abs >= 10: + return f"{percentage:.1f}" + if percentage_abs >= 1: + return f"{percentage:.2f}" + return f"{percentage:.3f}" diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 02db2d0b6..bd4556965 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -557,6 +557,7 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree: def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: # Efficient single traversal, directly accumulating into a dict. + # can track mins here and only sums can be return in total_passed_runtime by_id: dict[InvocationId, list[int]] = {} for result in self.test_results: if result.did_pass: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index f6c7661b4..c94759369 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -341,12 +341,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 optimized_function=best_optimization.candidate.source_code, ) - existing_tests = existing_tests_source_for( - self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), - function_to_all_tests, - tests_root=self.test_cfg.tests_root, - ) - original_code_combined = original_helper_code.copy() original_code_combined[explanation.file_path] = self.function_to_optimize_source_code new_code_combined = new_helper_code.copy() @@ -360,15 +354,26 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 generated_tests = remove_functions_from_generated_tests( generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove ) + original_runtime_by_test = ( + original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case() + ) + optimized_runtime_by_test = ( + best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case() + ) # Add runtime comments to generated tests before creating the PR generated_tests = add_runtime_comments_to_generated_tests( - generated_tests, - original_code_baseline.benchmarking_test_results, - best_optimization.winning_benchmarking_test_results, + self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test ) generated_tests_str = "\n\n".join( [test.generated_original_test_source for test in generated_tests.generated_tests] ) + existing_tests = existing_tests_source_for( + self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), + function_to_all_tests, + test_cfg=self.test_cfg, + original_runtimes_all=original_runtime_by_test, + optimized_runtimes_all=optimized_runtime_by_test, + ) if concolic_test_str: generated_tests_str += "\n\n" + concolic_test_str diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index b9e05e660..a08875a4f 
100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -16,24 +17,113 @@ git_root_dir, ) from codeflash.code_utils.github_utils import github_pr_url +from codeflash.code_utils.tabulate import tabulate +from codeflash.code_utils.time_utils import format_perf, format_time from codeflash.github.PrComment import FileDiffContent, PrComment +from codeflash.result.critic import performance_gain if TYPE_CHECKING: - from codeflash.models.models import FunctionCalledInTest + from codeflash.models.models import FunctionCalledInTest, InvocationId from codeflash.result.explanation import Explanation + from codeflash.verification.verification_utils import TestConfig def existing_tests_source_for( function_qualified_name_with_modules_from_root: str, function_to_tests: dict[str, set[FunctionCalledInTest]], - tests_root: Path, + test_cfg: TestConfig, + original_runtimes_all: dict[InvocationId, list[int]], + optimized_runtimes_all: dict[InvocationId, list[int]], ) -> str: test_files = function_to_tests.get(function_qualified_name_with_modules_from_root) - existing_tests_unique = set() - if test_files: - for test_file in test_files: - existing_tests_unique.add("- " + str(Path(test_file.tests_in_file.test_file).relative_to(tests_root))) - return "\n".join(sorted(existing_tests_unique)) + if not test_files: + return "" + output: str = "" + rows = [] + headers = ["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"] + tests_root = test_cfg.tests_root + module_root = test_cfg.project_root_path + rel_tests_root = tests_root.relative_to(module_root) + original_tests_to_runtimes: dict[Path, dict[str, int]] = {} + optimized_tests_to_runtimes: dict[Path, dict[str, int]] = {} + non_generated_tests = set() + for test_file in test_files: + non_generated_tests.add(Path(test_file.tests_in_file.test_file).relative_to(tests_root)) + # TODO confirm that original and optimized have the same keys + all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys() + for invocation_id in all_invocation_ids: + rel_path = ( + Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to(rel_tests_root) + ) + if rel_path not in non_generated_tests: + continue + if rel_path not in original_tests_to_runtimes: + original_tests_to_runtimes[rel_path] = {} + if rel_path not in optimized_tests_to_runtimes: + optimized_tests_to_runtimes[rel_path] = {} + qualified_name = ( + invocation_id.test_class_name + "." 
+ invocation_id.test_function_name # type: ignore[operator] + if invocation_id.test_class_name + else invocation_id.test_function_name + ) + if qualified_name not in original_tests_to_runtimes[rel_path]: + original_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index] + if qualified_name not in optimized_tests_to_runtimes[rel_path]: + optimized_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index] + if invocation_id in original_runtimes_all: + original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id]) # type: ignore[index] + if invocation_id in optimized_runtimes_all: + optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) # type: ignore[index] + # parse into string + all_rel_paths = ( + original_tests_to_runtimes.keys() + ) # both will have the same keys as some default values are assigned in the previous loop + for filename in sorted(all_rel_paths): + all_qualified_names = original_tests_to_runtimes[ + filename + ].keys() # both will have the same keys as some default values are assigned in the previous loop + for qualified_name in sorted(all_qualified_names): + # if not present in optimized output nan + if ( + original_tests_to_runtimes[filename][qualified_name] != 0 + and optimized_tests_to_runtimes[filename][qualified_name] != 0 + ): + print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name]) + print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name]) + greater = ( + optimized_tests_to_runtimes[filename][qualified_name] + > original_tests_to_runtimes[filename][qualified_name] + ) + perf_gain = format_perf( + performance_gain( + original_runtime_ns=original_tests_to_runtimes[filename][qualified_name], + optimized_runtime_ns=optimized_tests_to_runtimes[filename][qualified_name], + ) + * 100 + ) + if greater: + rows.append( + [ + f"`{filename}::{qualified_name}`", + f"{print_original_runtime}", + f"{print_optimized_runtime}", + f"⚠️{perf_gain}%", + ] + ) + else: + rows.append( + [ + f"`{filename}::{qualified_name}`", + f"{print_original_runtime}", + f"{print_optimized_runtime}", + f"✅{perf_gain}%", + ] + ) + output += tabulate( # type: ignore[no-untyped-call] + headers=headers, tabular_data=rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True + ) + output += "\n" + return output def check_create_pr( diff --git a/tests/test_add_runtime_comments.py b/tests/test_add_runtime_comments.py index 51c1ef052..66a77b0d0 100644 --- a/tests/test_add_runtime_comments.py +++ b/tests/test_add_runtime_comments.py @@ -1,18 +1,23 @@ -"""Tests for the add_runtime_comments_to_generated_tests functionality.""" - +import os from pathlib import Path +from unittest.mock import Mock -from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests -from codeflash.models.models import ( - FunctionTestInvocation, - GeneratedTests, - GeneratedTestsList, - InvocationId, - TestResults, - TestType, - VerificationType, -) +import pytest +from codeflash.code_utils.edit_generated_tests import add_runtime_comments_to_generated_tests +from codeflash.models.models import GeneratedTests, GeneratedTestsList, InvocationId, FunctionTestInvocation, TestType, \ + VerificationType, TestResults +from codeflash.verification.verification_utils import TestConfig + +@pytest.fixture +def test_config(): + """Create a mock TestConfig for testing.""" + config = Mock(spec=TestConfig) + config.project_root_path = 
Path("/project") + config.test_framework= "pytest" + config.tests_project_rootdir = Path("/project/tests") + config.tests_root = Path("/project/tests") + return config class TestAddRuntimeComments: """Test cases for add_runtime_comments_to_generated_tests method.""" @@ -24,13 +29,13 @@ def create_test_invocation( return FunctionTestInvocation( loop_index=loop_index, id=InvocationId( - test_module_path="test_module", + test_module_path="tests.test_module", test_class_name=None, test_function_name=test_function_name, function_getting_tested="test_function", iteration_id=iteration_id, ), - file_name=Path("test.py"), + file_name=Path("tests/test.py"), did_pass=did_pass, runtime=runtime, test_framework="pytest", @@ -40,7 +45,7 @@ def create_test_invocation( verification_type=VerificationType.FUNCTION_CALL, ) - def test_basic_runtime_comment_addition(self): + def test_basic_runtime_comment_addition(self, test_config): """Test basic functionality of adding runtime comments.""" # Create test source code test_source = """def test_bubble_sort(): @@ -52,10 +57,12 @@ def test_basic_runtime_comment_addition(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py"), ) - + """add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + )""" generated_tests = GeneratedTestsList(generated_tests=[generated_test]) # Create test results @@ -68,16 +75,17 @@ def test_basic_runtime_comment_addition(self): original_test_results.add(original_invocation) optimized_test_results.add(optimized_invocation) - + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that comments were added modified_source = result.generated_tests[0].generated_original_test_source assert "# 500μs -> 300μs" in modified_source assert "codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source - def test_multiple_test_functions(self): + def test_multiple_test_functions(self, test_config): """Test handling multiple test functions in the same file.""" test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) @@ -95,8 +103,8 @@ def helper_function(): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -112,8 +120,11 @@ def helper_function(): optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() + # Test the 
functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) modified_source = result.generated_tests[0].generated_original_test_source @@ -126,7 +137,7 @@ def helper_function(): and "# " not in modified_source.split("helper_function():")[1].split("\n")[0] ) - def test_different_time_formats(self): + def test_different_time_formats(self, test_config): """Test that different time ranges are formatted correctly with new precision rules.""" test_cases = [ (999, 500, "999ns -> 500ns"), # nanoseconds @@ -147,8 +158,8 @@ def test_different_time_formats(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -160,15 +171,17 @@ def test_different_time_formats(self): original_test_results.add(self.create_test_invocation("test_function", original_time)) optimized_test_results.add(self.create_test_invocation("test_function", optimized_time)) + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality result = add_runtime_comments_to_generated_tests( - generated_tests, original_test_results, optimized_test_results + test_config, generated_tests, original_runtimes, optimized_runtimes ) modified_source = result.generated_tests[0].generated_original_test_source assert f"# {expected_comment}" in modified_source - def test_missing_test_results(self): + def test_missing_test_results(self, test_config): """Test behavior when test results are missing for a test function.""" test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) @@ -179,8 +192,8 @@ def test_missing_test_results(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -189,14 +202,17 @@ def test_missing_test_results(self): original_test_results = TestResults() optimized_test_results = TestResults() + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() + # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that no comments were added modified_source = result.generated_tests[0].generated_original_test_source assert modified_source == test_source # Should be unchanged - def test_partial_test_results(self): + def test_partial_test_results(self, test_config): """Test behavior when only one set of test results is available.""" test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) @@ -207,8 
+223,8 @@ def test_partial_test_results(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -219,15 +235,16 @@ def test_partial_test_results(self): original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) # No optimized results - + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that no comments were added modified_source = result.generated_tests[0].generated_original_test_source assert modified_source == test_source # Should be unchanged - def test_multiple_runtimes_uses_minimum(self): + def test_multiple_runtimes_uses_minimum(self, test_config): """Test that when multiple runtimes exist, the minimum is used.""" test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) @@ -238,8 +255,8 @@ def test_multiple_runtimes_uses_minimum(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -257,14 +274,16 @@ def test_multiple_runtimes_uses_minimum(self): optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000, loop_index=2)) optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 320_000, loop_index=3)) + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that minimum times were used (500μs -> 300μs) modified_source = result.generated_tests[0].generated_original_test_source assert "# 500μs -> 300μs" in modified_source - def test_no_codeflash_output_assignment(self): + def test_no_codeflash_output_assignment(self, test_config): """Test behavior when test doesn't have codeflash_output assignment.""" test_source = """def test_bubble_sort(): result = bubble_sort([3, 1, 2]) @@ -275,8 +294,8 @@ def test_no_codeflash_output_assignment(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -288,14 +307,17 @@ def 
test_no_codeflash_output_assignment(self): original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() + # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that no comments were added (no codeflash_output assignment) modified_source = result.generated_tests[0].generated_original_test_source assert modified_source == test_source # Should be unchanged - def test_invalid_python_code_handling(self): + def test_invalid_python_code_handling(self, test_config): """Test behavior when test source code is invalid Python.""" test_source = """def test_bubble_sort(: codeflash_output = bubble_sort([3, 1, 2]) @@ -306,8 +328,8 @@ def test_invalid_python_code_handling(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -319,14 +341,17 @@ def test_invalid_python_code_handling(self): original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() + # Test the functionality - should handle parse error gracefully - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that original test is preserved when parsing fails modified_source = result.generated_tests[0].generated_original_test_source assert modified_source == test_source # Should be unchanged due to parse error - def test_multiple_generated_tests(self): + def test_multiple_generated_tests(self, test_config): """Test handling multiple generated test objects.""" test_source_1 = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) @@ -342,16 +367,16 @@ def test_multiple_generated_tests(self): generated_original_test_source=test_source_1, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior_1.py"), - perf_file_path=Path("test_perf_1.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_test_2 = GeneratedTests( generated_original_test_source=test_source_2, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior_2.py"), - perf_file_path=Path("test_perf_2.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test_1, generated_test_2]) @@ -366,8 
+391,11 @@ def test_multiple_generated_tests(self): optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) optimized_test_results.add(self.create_test_invocation("test_quick_sort", 600_000)) + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() + # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that comments were added to both test files modified_source_1 = result.generated_tests[0].generated_original_test_source @@ -376,7 +404,7 @@ def test_multiple_generated_tests(self): assert "# 500μs -> 300μs" in modified_source_1 assert "# 800μs -> 600μs" in modified_source_2 - def test_preserved_test_attributes(self): + def test_preserved_test_attributes(self, test_config): """Test that other test attributes are preserved during modification.""" test_source = """def test_bubble_sort(): codeflash_output = bubble_sort([3, 1, 2]) @@ -385,15 +413,15 @@ def test_preserved_test_attributes(self): original_behavior_source = "behavior test source" original_perf_source = "perf test source" - original_behavior_path = Path("test_behavior.py") - original_perf_path = Path("test_perf.py") + original_behavior_path = Path("/project/tests/test_module.py") + original_perf_path = Path("/project/tests/test_module_perf.py") generated_test = GeneratedTests( generated_original_test_source=test_source, instrumented_behavior_test_source=original_behavior_source, instrumented_perf_test_source=original_perf_source, behavior_file_path=original_behavior_path, - perf_file_path=original_perf_path, + perf_file_path=original_perf_path ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -405,8 +433,10 @@ def test_preserved_test_attributes(self): original_test_results.add(self.create_test_invocation("test_bubble_sort", 500_000)) optimized_test_results.add(self.create_test_invocation("test_bubble_sort", 300_000)) + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that other attributes are preserved modified_test = result.generated_tests[0] @@ -418,7 +448,7 @@ def test_preserved_test_attributes(self): # Check that only the generated_original_test_source was modified assert "# 500μs -> 300μs" in modified_test.generated_original_test_source - def test_multistatement_line_handling(self): + def test_multistatement_line_handling(self, test_config): """Test that runtime comments work correctly with multiple statements on one line.""" test_source = """def test_mutation_of_input(): # Test that the input list is mutated in-place and returned @@ -432,8 +462,8 @@ def test_multistatement_line_handling(self): generated_original_test_source=test_source, instrumented_behavior_test_source="", instrumented_perf_test_source="", - behavior_file_path=Path("test_behavior.py"), - perf_file_path=Path("test_perf.py"), + behavior_file_path=Path("/project/tests/test_module.py"), + 
perf_file_path=Path("/project/tests/test_module_perf.py") ) generated_tests = GeneratedTestsList(generated_tests=[generated_test]) @@ -445,8 +475,11 @@ def test_multistatement_line_handling(self): original_test_results.add(self.create_test_invocation("test_mutation_of_input", 19_000)) # 19μs optimized_test_results.add(self.create_test_invocation("test_mutation_of_input", 14_000)) # 14μs + original_runtimes = original_test_results.usable_runtime_data_by_test_case() + optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case() + # Test the functionality - result = add_runtime_comments_to_generated_tests(generated_tests, original_test_results, optimized_test_results) + result = add_runtime_comments_to_generated_tests(test_config, generated_tests, original_runtimes, optimized_runtimes) # Check that comments were added to the correct line modified_source = result.generated_tests[0].generated_original_test_source @@ -462,3 +495,319 @@ def test_multistatement_line_handling(self): assert codeflash_line is not None, "Could not find codeflash_output assignment line" assert "# 19.0μs -> 14.0μs" in codeflash_line, f"Comment not found in the correct line: {codeflash_line}" + + + def test_add_runtime_comments_simple_function(self, test_config): + """Test adding runtime comments to a simple test function.""" + test_source = '''def test_function(): + codeflash_output = some_function() + assert codeflash_output == expected +''' + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + invocation_id = InvocationId( + test_module_path="tests.test_module", + test_class_name=None, + test_function_name="test_function", + function_getting_tested="some_function", + iteration_id="0", + ) + + original_runtimes = {invocation_id: [1000000000, 1200000000]} # 1s, 1.2s in nanoseconds + optimized_runtimes = {invocation_id: [500000000, 600000000]} # 0.5s, 0.6s in nanoseconds + + result = add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + ) + + expected_source = '''def test_function(): + codeflash_output = some_function() # 1.00s -> 500ms (100% faster) + assert codeflash_output == expected +''' + + assert len(result.generated_tests) == 1 + assert result.generated_tests[0].generated_original_test_source == expected_source + + def test_add_runtime_comments_class_method(self, test_config): + """Test adding runtime comments to a test method within a class.""" + test_source = '''class TestClass: + def test_function(self): + codeflash_output = some_function() + assert codeflash_output == expected +''' + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + invocation_id = InvocationId( + test_module_path="tests.test_module", + test_class_name="TestClass", + test_function_name="test_function", + function_getting_tested="some_function", + iteration_id="0", + + ) + + original_runtimes = {invocation_id: [2000000000]} # 2s in nanoseconds + 
optimized_runtimes = {invocation_id: [1000000000]} # 1s in nanoseconds + + result = add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + ) + + expected_source = '''class TestClass: + def test_function(self): + codeflash_output = some_function() # 2.00s -> 1.00s (100% faster) + assert codeflash_output == expected +''' + + assert len(result.generated_tests) == 1 + assert result.generated_tests[0].generated_original_test_source == expected_source + + def test_add_runtime_comments_multiple_assignments(self, test_config): + """Test adding runtime comments when there are multiple codeflash_output assignments.""" + test_source = '''def test_function(): + setup_data = prepare_test() + codeflash_output = some_function() + assert codeflash_output == expected + codeflash_output = another_function() + assert codeflash_output == expected2 +''' + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + invocation_id = InvocationId( + test_module_path="tests.test_module", + test_class_name=None, + test_function_name="test_function", + function_getting_tested="some_function", + iteration_id="0", + ) + + original_runtimes = {invocation_id: [1500000000]} # 1.5s in nanoseconds + optimized_runtimes = {invocation_id: [750000000]} # 0.75s in nanoseconds + + result = add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + ) + + expected_source = '''def test_function(): + setup_data = prepare_test() + codeflash_output = some_function() # 1.50s -> 750ms (100% faster) + assert codeflash_output == expected + codeflash_output = another_function() # 1.50s -> 750ms (100% faster) + assert codeflash_output == expected2 +''' + + assert len(result.generated_tests) == 1 + assert result.generated_tests[0].generated_original_test_source == expected_source + + def test_add_runtime_comments_no_matching_runtimes(self, test_config): + """Test that source remains unchanged when no matching runtimes are found.""" + test_source = '''def test_function(): + codeflash_output = some_function() + assert codeflash_output == expected +''' + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + # Different invocation ID that won't match + invocation_id = InvocationId( + test_module_path="tests.other_module", + test_class_name=None, + test_function_name="other_function", + function_getting_tested="some_other_function", + iteration_id="0", + ) + + original_runtimes = {invocation_id: [1000000000]} + optimized_runtimes = {invocation_id: [500000000]} + + result = add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + ) + + # Source should remain unchanged + assert len(result.generated_tests) == 1 + assert result.generated_tests[0].generated_original_test_source == test_source + + def test_add_runtime_comments_no_codeflash_output(self, test_config): + """Test that source 
remains unchanged when there's no codeflash_output assignment.""" + test_source = '''def test_function(): + result = some_function() + assert result == expected +''' + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + invocation_id = InvocationId( + test_module_path="tests.test_module", + test_class_name=None, + test_function_name="test_function", + function_getting_tested="some_function", + iteration_id="0", + ) + + original_runtimes = {invocation_id: [1000000000]} + optimized_runtimes = {invocation_id: [500000000]} + + result = add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + ) + + # Source should remain unchanged + assert len(result.generated_tests) == 1 + assert result.generated_tests[0].generated_original_test_source == test_source + + def test_add_runtime_comments_multiple_tests(self, test_config): + """Test adding runtime comments to multiple generated tests.""" + test_source1 = '''def test_function1(): + codeflash_output = some_function() + assert codeflash_output == expected +''' + + test_source2 = '''def test_function2(): + codeflash_output = another_function() + assert codeflash_output == expected +''' + + generated_test1 = GeneratedTests( + generated_original_test_source=test_source1, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module1.py"), + perf_file_path=Path("/project/tests/test_module1_perf.py"), + ) + + generated_test2 = GeneratedTests( + generated_original_test_source=test_source2, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module2.py"), + perf_file_path=Path("/project/tests/test_module2_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test1, generated_test2]) + + invocation_id1 = InvocationId( + test_module_path="tests.test_module1", + test_class_name=None, + test_function_name="test_function1", + function_getting_tested="some_function", + iteration_id="0", + ) + + invocation_id2 = InvocationId( + test_module_path="tests.test_module2", + test_class_name=None, + test_function_name="test_function2", + function_getting_tested="another_function", + iteration_id = "0", + ) + + original_runtimes = { + invocation_id1: [1000000000], # 1s + invocation_id2: [2000000000], # 2s + } + optimized_runtimes = { + invocation_id1: [500000000], # 0.5s + invocation_id2: [800000000], # 0.8s + } + + result = add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + ) + + expected_source1 = '''def test_function1(): + codeflash_output = some_function() # 1.00s -> 500ms (100% faster) + assert codeflash_output == expected +''' + + expected_source2 = '''def test_function2(): + codeflash_output = another_function() # 2.00s -> 800ms (150% faster) + assert codeflash_output == expected +''' + + assert len(result.generated_tests) == 2 + assert result.generated_tests[0].generated_original_test_source == expected_source1 + assert result.generated_tests[1].generated_original_test_source == expected_source2 + + def test_add_runtime_comments_performance_regression(self, test_config): + """Test 
adding runtime comments when optimized version is slower (negative performance gain).""" + test_source = '''def test_function(): + codeflash_output = some_function() + assert codeflash_output == expected +''' + + generated_test = GeneratedTests( + generated_original_test_source=test_source, + instrumented_behavior_test_source="", + instrumented_perf_test_source="", + behavior_file_path=Path("/project/tests/test_module.py"), + perf_file_path=Path("/project/tests/test_module_perf.py"), + ) + + generated_tests = GeneratedTestsList(generated_tests=[generated_test]) + + invocation_id = InvocationId( + test_module_path="tests.test_module", + test_class_name=None, + test_function_name="test_function", + function_getting_tested="some_function", + iteration_id="0", + ) + + original_runtimes = {invocation_id: [1000000000]} # 1s + optimized_runtimes = {invocation_id: [1500000000]} # 1.5s (slower!) + + result = add_runtime_comments_to_generated_tests( + test_config, generated_tests, original_runtimes, optimized_runtimes + ) + + expected_source = '''def test_function(): + codeflash_output = some_function() # 1.00s -> 1.50s (33.3% slower) + assert codeflash_output == expected +''' + + assert len(result.generated_tests) == 1 + assert result.generated_tests[0].generated_original_test_source == expected_source diff --git a/tests/test_existing_tests_source_for.py b/tests/test_existing_tests_source_for.py new file mode 100644 index 000000000..8940b20d2 --- /dev/null +++ b/tests/test_existing_tests_source_for.py @@ -0,0 +1,350 @@ +import os +from pathlib import Path +from unittest.mock import Mock + +import pytest + +from codeflash.result.create_pr import existing_tests_source_for + + +class TestExistingTestsSourceFor: + """Test cases for existing_tests_source_for function.""" + + def setup_method(self): + """Set up test fixtures.""" + # Mock test config + self.test_cfg = Mock() + self.test_cfg.tests_root = Path("/project/tests") + self.test_cfg.project_root_path = Path("/project") + + # Mock invocation ID + self.mock_invocation_id = Mock() + self.mock_invocation_id.test_module_path = "tests.test_module" + self.mock_invocation_id.test_class_name = "TestClass" + self.mock_invocation_id.test_function_name = "test_function" + + # Mock function called in test + self.mock_function_called_in_test = Mock() + self.mock_function_called_in_test.tests_in_file = Mock() + self.mock_function_called_in_test.tests_in_file.test_file = "/project/tests/test_module.py" + + def test_no_test_files_returns_empty_string(self): + """Test that function returns empty string when no test files exist.""" + function_to_tests = {} + original_runtimes = {} + optimized_runtimes = {} + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + assert result == "" + + def test_single_test_with_improvement(self): + """Test single test showing performance improvement.""" + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = { + self.mock_invocation_id: [1000000] # 1ms in nanoseconds + } + optimized_runtimes = { + self.mock_invocation_id: [500000] # 0.5ms in nanoseconds + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|:------------------------------------------|:--------------|:---------------|:----------| +| 
`test_module.py::TestClass.test_function` | 1.00ms | 500μs | ✅100% | +""" + + assert result == expected + + def test_single_test_with_regression(self): + """Test single test showing performance regression.""" + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = { + self.mock_invocation_id: [500000] # 0.5ms in nanoseconds + } + optimized_runtimes = { + self.mock_invocation_id: [1000000] # 1ms in nanoseconds + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|:------------------------------------------|:--------------|:---------------|:----------| +| `test_module.py::TestClass.test_function` | 500μs | 1.00ms | ⚠️-50.0% | +""" + + assert result == expected + + def test_test_without_class_name(self): + """Test function without class name (standalone test function).""" + mock_invocation_no_class = Mock() + mock_invocation_no_class.test_module_path = "tests.test_module" + mock_invocation_no_class.test_class_name = None + mock_invocation_no_class.test_function_name = "test_standalone" + + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = { + mock_invocation_no_class: [1000000] + } + optimized_runtimes = { + mock_invocation_no_class: [800000] + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|:----------------------------------|:--------------|:---------------|:----------| +| `test_module.py::test_standalone` | 1.00ms | 800μs | ✅25.0% | +""" + + assert result == expected + + def test_missing_original_runtime(self): + """Test when original runtime is missing (shows NaN).""" + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = {} + optimized_runtimes = { + self.mock_invocation_id: [500000] + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|----------------------------|---------------|----------------|-----------| +""" + + assert result == expected + + def test_missing_optimized_runtime(self): + """Test when optimized runtime is missing (shows NaN).""" + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = { + self.mock_invocation_id: [1000000] + } + optimized_runtimes = {} + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|----------------------------|---------------|----------------|-----------| +""" + + assert result == expected + + def test_multiple_tests_sorted_output(self): + """Test multiple tests with sorted output by filename and function name.""" + # Create second test file + mock_function_called_2 = Mock() + mock_function_called_2.tests_in_file = Mock() + mock_function_called_2.tests_in_file.test_file = "/project/tests/test_another.py" + + mock_invocation_2 = Mock() + mock_invocation_2.test_module_path = "tests.test_another" + 
mock_invocation_2.test_class_name = "TestAnother" + mock_invocation_2.test_function_name = "test_another_function" + + function_to_tests = { + "module.function": {self.mock_function_called_in_test, mock_function_called_2} + } + original_runtimes = { + self.mock_invocation_id: [1000000], + mock_invocation_2: [2000000] + } + optimized_runtimes = { + self.mock_invocation_id: [800000], + mock_invocation_2: [1500000] + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|:-----------------------------------------------------|:--------------|:---------------|:----------| +| `test_another.py::TestAnother.test_another_function` | 2.00ms | 1.50ms | ✅33.3% | +| `test_module.py::TestClass.test_function` | 1.00ms | 800μs | ✅25.0% | +""" + + assert result == expected + + def test_multiple_runtimes_uses_minimum(self): + """Test that function uses minimum runtime when multiple measurements exist.""" + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = { + self.mock_invocation_id: [1000000, 1200000, 800000] # min: 800000 + } + optimized_runtimes = { + self.mock_invocation_id: [600000, 700000, 500000] # min: 500000 + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|:------------------------------------------|:--------------|:---------------|:----------| +| `test_module.py::TestClass.test_function` | 800μs | 500μs | ✅60.0% | +""" + + assert result == expected + + def test_complex_module_path_conversion(self): + """Test conversion of complex module paths to file paths.""" + mock_invocation_complex = Mock() + mock_invocation_complex.test_module_path = "tests.integration.test_complex_module" + mock_invocation_complex.test_class_name = "TestComplex" + mock_invocation_complex.test_function_name = "test_complex_function" + + mock_function_complex = Mock() + mock_function_complex.tests_in_file = Mock() + mock_function_complex.tests_in_file.test_file = f"/project/tests/integration/test_complex_module.py" + + function_to_tests = { + "module.function": {mock_function_complex} + } + original_runtimes = { + mock_invocation_complex: [1000000] + } + optimized_runtimes = { + mock_invocation_complex: [750000] + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|:------------------------------------------------------------------------|:--------------|:---------------|:----------| +| `integration/test_complex_module.py::TestComplex.test_complex_function` | 1.00ms | 750μs | ✅33.3% | +""" + + assert result == expected + + def test_zero_runtime_values(self): + """Test handling of zero runtime values.""" + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = { + self.mock_invocation_id: [0] + } + optimized_runtimes = { + self.mock_invocation_id: [0] + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | 
+|----------------------------|---------------|----------------|-----------| +""" + + assert result == expected + + def test_filters_out_generated_tests(self): + """Test that generated tests are filtered out and only non-generated tests are included.""" + # Create a test that would be filtered out (not in non_generated_tests) + mock_generated_test = Mock() + mock_generated_test.tests_in_file = Mock() + mock_generated_test.tests_in_file.test_file = "/project/tests/generated_test.py" + + mock_generated_invocation = Mock() + mock_generated_invocation.test_module_path = "tests.generated_test" + mock_generated_invocation.test_class_name = "TestGenerated" + mock_generated_invocation.test_function_name = "test_generated" + + function_to_tests = { + "module.function": {self.mock_function_called_in_test} + } + original_runtimes = { + self.mock_invocation_id: [1000000], + mock_generated_invocation: [500000] # This should be filtered out + } + optimized_runtimes = { + self.mock_invocation_id: [800000], + mock_generated_invocation: [400000] # This should be filtered out + } + + result = existing_tests_source_for( + "module.function", + function_to_tests, + self.test_cfg, + original_runtimes, + optimized_runtimes + ) + + # Should only include the non-generated test + expected = """| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup | +|:------------------------------------------|:--------------|:---------------|:----------| +| `test_module.py::TestClass.test_function` | 1.00ms | 800μs | ✅25.0% | +""" + + assert result == expected + + diff --git a/tests/test_humanize_time.py b/tests/test_humanize_time.py index 4021b077e..ecc5e16d7 100644 --- a/tests/test_humanize_time.py +++ b/tests/test_humanize_time.py @@ -1,4 +1,5 @@ from codeflash.code_utils.time_utils import humanize_runtime, format_time +from codeflash.code_utils.time_utils import format_perf import pytest @@ -172,4 +173,103 @@ def test_negative_values(self): # This test depends on whether your function should handle negative values # You might want to modify based on expected behavior with pytest.raises((ValueError, TypeError)) or pytest.warns(): - format_time(-1000) \ No newline at end of file + format_time(-1000) + + +class TestFormatPerf: + """Test cases for the format_perf function.""" + + def test_format_perf_large_values_above_100(self): + """Test formatting for values above 100 (no decimal places).""" + assert format_perf(150.789) == "151" + assert format_perf(999.999) == "1000" + assert format_perf(100.1) == "100" + assert format_perf(500) == "500" + assert format_perf(1000.5) == "1000" + + def test_format_perf_medium_values_10_to_100(self): + """Test formatting for values between 10 and 100 (1 decimal place).""" + assert format_perf(99.99) == "100.0" + assert format_perf(50.789) == "50.8" + assert format_perf(10.1) == "10.1" + assert format_perf(25.0) == "25.0" + assert format_perf(33.333) == "33.3" + + def test_format_perf_small_values_1_to_10(self): + """Test formatting for values between 1 and 10 (2 decimal places).""" + assert format_perf(9.999) == "10.00" + assert format_perf(5.789) == "5.79" + assert format_perf(1.1) == "1.10" + assert format_perf(2.0) == "2.00" + assert format_perf(7.123) == "7.12" + + def test_format_perf_very_small_values_below_1(self): + """Test formatting for values below 1 (3 decimal places).""" + assert format_perf(0.999) == "0.999" + assert format_perf(0.5) == "0.500" + assert format_perf(0.123) == "0.123" + assert format_perf(0.001) == "0.001" + assert format_perf(0.0) == "0.000" + + def 
test_format_perf_negative_values(self):
+        """Test formatting for negative values (uses absolute value for comparison)."""
+        assert format_perf(-150.789) == "-151"
+        assert format_perf(-50.789) == "-50.8"
+        assert format_perf(-5.789) == "-5.79"
+        assert format_perf(-0.999) == "-0.999"
+        assert format_perf(-0.0) == "-0.000"
+
+    def test_format_perf_boundary_values(self):
+        """Test formatting for exact boundary values."""
+        assert format_perf(100.0) == "100"
+        assert format_perf(10.0) == "10.0"
+        assert format_perf(1.0) == "1.00"
+        assert format_perf(-100.0) == "-100"
+        assert format_perf(-10.0) == "-10.0"
+        assert format_perf(-1.0) == "-1.00"
+
+    def test_format_perf_integer_inputs(self):
+        """Test formatting with integer inputs."""
+        assert format_perf(150) == "150"
+        assert format_perf(50) == "50.0"
+        assert format_perf(5) == "5.00"
+        assert format_perf(0) == "0.000"
+        assert format_perf(-150) == "-150"
+        assert format_perf(-50) == "-50.0"
+        assert format_perf(-5) == "-5.00"
+
+    def test_format_perf_float_inputs(self):
+        """Test formatting with float inputs."""
+        assert format_perf(123.456) == "123"
+        assert format_perf(12.3456) == "12.3"
+        assert format_perf(1.23456) == "1.23"
+        assert format_perf(0.123456) == "0.123"
+
+    def test_format_perf_edge_cases(self):
+        """Test formatting for edge cases and special values."""
+        # Very large numbers
+        assert format_perf(999999.99) == "1000000"
+        assert format_perf(1000000) == "1000000"
+
+        # Very small positive numbers
+        assert format_perf(0.0001) == "0.000"
+        assert format_perf(0.00001) == "0.000"
+
+        # Numbers very close to boundaries
+        assert format_perf(99.9999) == "100.0"
+        assert format_perf(9.9999) == "10.00"
+        assert format_perf(0.9999) == "1.000"
+
+    def test_format_perf_rounding_behavior(self):
+        """Test that rounding behavior is consistent."""
+        # Test rounding up
+        assert format_perf(100.5) == "100"
+        assert format_perf(10.55) == "10.6"
+        assert format_perf(1.555) == "1.55"
+        assert format_perf(0.1555) == "0.155"
+
+        # Test rounding down
+        assert format_perf(100.4) == "100"
+        assert format_perf(10.54) == "10.5"
+        assert format_perf(1.554) == "1.55"
+        assert format_perf(0.1554) == "0.155"
\ No newline at end of file
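
For reference, a minimal standalone sketch of how the runtime comment text introduced by this patch is assembled. `format_perf` is copied from the `time_utils.py` hunk above; the `performance_gain` formula and the `runtime_comment` helper shown here are assumptions inferred from the expected strings in the tests (e.g. 1.00s -> 500ms reported as "100% faster"), not necessarily the exact implementation in `codeflash.result.critic` or `edit_generated_tests.py`.

# Illustrative sketch only; performance_gain below is an assumed formula
# inferred from the test expectations, not the actual codeflash helper.

def format_perf(percentage: float) -> str:
    """Format a percentage with 3 significant digits when needed (as in the diff)."""
    percentage_abs = abs(percentage)
    if percentage_abs >= 100:
        return f"{percentage:.0f}"
    if percentage_abs >= 10:
        return f"{percentage:.1f}"
    if percentage_abs >= 1:
        return f"{percentage:.2f}"
    return f"{percentage:.3f}"


def performance_gain(original_runtime_ns: int, optimized_runtime_ns: int) -> float:
    """Assumed: relative speedup of the optimized run versus the original run."""
    return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns


def runtime_comment(original_ns: int, optimized_ns: int) -> str:
    """Build the '(X% faster/slower)' suffix appended after the format_time arrow."""
    gain = format_perf(abs(performance_gain(original_ns, optimized_ns) * 100))
    status = "slower" if optimized_ns > original_ns else "faster"
    return f"({gain}% {status})"


# Example: 1.00s -> 500ms yields "(100% faster)" and 1.00s -> 1.50s yields
# "(33.3% slower)", matching the expected comments in the tests above.
print(runtime_comment(1_000_000_000, 500_000_000))    # (100% faster)
print(runtime_comment(1_000_000_000, 1_500_000_000))  # (33.3% slower)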