From 6a920c74faa61e7c281b89b52d05a9ab229996ac Mon Sep 17 00:00:00 2001 From: mohammed Date: Sun, 1 Jun 2025 14:48:18 +0300 Subject: [PATCH 1/2] partially format optimized functions (not the whole file)--WIP --- code_to_optimize/bad_formatting.py | 43 ++++++++++ codeflash/discovery/functions_to_optimize.py | 21 +++++ codeflash/optimization/function_optimizer.py | 76 ++++++++++++++---- tests/test_function_optimizer.py | 83 ++++++++++++++++++++ 4 files changed, 207 insertions(+), 16 deletions(-) create mode 100644 code_to_optimize/bad_formatting.py create mode 100644 tests/test_function_optimizer.py diff --git a/code_to_optimize/bad_formatting.py b/code_to_optimize/bad_formatting.py new file mode 100644 index 000000000..00e3c5070 --- /dev/null +++ b/code_to_optimize/bad_formatting.py @@ -0,0 +1,43 @@ +import sys + + +def lol(): + print( "lol" ) + + + + + + + + + +class BubbleSorter: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + + + + + + + + def sorter (self, arr): + + + print ("codeflash stdout : BubbleSorter.sorter() called") + n = len(arr) + for i in range(n): + swapped = False + for j in range(0, n - i - 1): + if arr[j] > arr[j + 1]: + arr[j], arr[j + 1] = arr[j + 1], arr[j] # Faster swap + swapped = True + if not swapped: + break + print ("stderr test", file=sys.stderr) + return arr \ No newline at end of file diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 8aa052ab0..1dd321fed 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -81,6 +81,27 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: ) ) +class CodeRangeFunctionVisitor(cst.CSTVisitor): + METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.QualifiedNameProvider) + + def __init__(self, target_function_name: str) -> None: + super().__init__() + self.target_func = target_function_name + self.current_path = [] + self.start_line: Optional[int] = None + self.end_line: Optional[int] = None + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + qualified_names = { + str(qn.name) for qn in + self.get_metadata(cst.metadata.QualifiedNameProvider, node) + } + + if self.target_func in qualified_names: + position = self.get_metadata(cst.metadata.PositionProvider, node) + self.start_line = position.start.line + self.end_line = position.end.line + class FunctionWithReturnStatement(ast.NodeVisitor): def __init__(self, file_path: Path) -> None: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index fe4357839..4abaed5b3 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -9,7 +9,7 @@ from collections import defaultdict, deque from pathlib import Path from typing import TYPE_CHECKING - +import tempfile import isort import libcst as cst from rich.console import Group @@ -72,6 +72,8 @@ from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests +from codeflash.discovery.functions_to_optimize import CodeRangeFunctionVisitor + if TYPE_CHECKING: from argparse import Namespace @@ -300,8 +302,16 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 code_context=code_context, optimized_code=best_optimization.candidate.source_code ) + print("file_path_to_helper_classes\n", file_path_to_helper_classes) + filepaths_to_inspect = [ + self.function_to_optimize.file_path, + *list({helper.file_path for helper in code_context.helper_functions}), + ] + print("filepaths_to_inspect\n", filepaths_to_inspect) + new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code + code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code, + opt_func_name=explanation.function_name ) existing_tests = existing_tests_source_for( @@ -590,25 +600,59 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(helper_code) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str + self, helper_functions: list[FunctionSource], + path: Path, + original_code: str, + opt_func_name: str ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - new_code = format_code(self.args.formatter_cmds, path) - if should_sort_imports: - new_code = sort_imports(new_code) - - new_helper_code: dict[Path, str] = {} - helper_functions_paths = {hf.file_path for hf in helper_functions} - for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) - if should_sort_imports: - formatted_helper_code = sort_imports(formatted_helper_code) - new_helper_code[module_abspath] = formatted_helper_code - - return new_code, new_helper_code + whole_file_content = path.read_text(encoding="utf8") + wrapper = cst.metadata.MetadataWrapper(cst.parse_module(whole_file_content)) + visitor = CodeRangeFunctionVisitor(target_function_name=opt_func_name) + wrapper.visit(visitor) + + lines = whole_file_content.splitlines(keepends=True) + if visitor.start_line == None: + logger.error(f"Could not find function {opt_func_name} in {path}, aborting reformatting.") + else: + opt_func_source_lines = lines[visitor.start_line-1:visitor.end_line] + + # fix opt func identation + first_line = opt_func_source_lines[0] + first_line_indent = len(first_line) - len(first_line.lstrip()) # number of spaces before the first character + opt_func_source_lines[0] = opt_func_source_lines[0][first_line_indent:] # remove first line ident, so when we save the function code into a temp file, we don't get syntax errors + + with tempfile.NamedTemporaryFile(mode='w+', delete=True) as f: + f.write("".join(opt_func_source_lines)) + f.flush() + tmp_file = Path(f.name) + formatted_func = format_code(self.args.formatter_cmds, tmp_file) + # apply the identation back to all lines of the formatted function + formatted_lines = formatted_func.splitlines(keepends=True) + for i in range(len(formatted_lines)): + formatted_lines[i] = (" " * first_line_indent) + formatted_lines[i] + + # replace the unformatted code with formatted ones + new_code = ( + "".join(lines[:visitor.start_line-1]) + + "".join(formatted_lines) + + "".join(lines[visitor.end_line:]) + ) + if should_sort_imports: + new_code = sort_imports(new_code) + + new_helper_code: dict[Path, str] = {} + helper_functions_paths = {hf.file_path for hf in helper_functions} + for module_abspath in helper_functions_paths: + formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) + if should_sort_imports: + formatted_helper_code = sort_imports(formatted_helper_code) + new_helper_code[module_abspath] = formatted_helper_code + + return new_code, new_helper_code def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, optimized_code: str diff --git a/tests/test_function_optimizer.py b/tests/test_function_optimizer.py new file mode 100644 index 000000000..fc25ae691 --- /dev/null +++ b/tests/test_function_optimizer.py @@ -0,0 +1,83 @@ +import argparse +from pathlib import Path +import tempfile + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig + + +def test_reformat_code_and_helpers(): + """ + reformat_code_and_helpers should only format the code that is optimized not the whole file, to avoid large diffing + """ + with tempfile.TemporaryDirectory() as test_dir_str: + test_dir = Path(test_dir_str) + target_path = test_dir / "target.py" + unformatted_code = """import sys + + +def lol(): + print( "lol" ) + + + + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + def lol2 (self): + print( " lol2" )""" + expected_code = """import sys + + +def lol(): + print( "lol" ) + + + + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + def lol2(self): + print(" lol2") +""" + target_path.write_text(unformatted_code, encoding="utf-8") + function_to_optimize = FunctionToOptimize(function_name="MyClass.lol2", parents=[], file_path=target_path) + + test_cfg = TestConfig( + tests_root=test_dir, + project_root_path=test_dir, + test_framework="pytest", + tests_project_rootdir=test_dir, + ) + args = argparse.Namespace( + disable_imports_sorting=False, + formatter_cmds=[ + "ruff check --exit-zero --fix $file", + "ruff format $file" + ], + ) + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + args=args, + ) + + + formatted_code,_ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + assert formatted_code == expected_code \ No newline at end of file From cd9f0f48db648d7f643c0f72ebe91427204a3755 Mon Sep 17 00:00:00 2001 From: mohammed Date: Mon, 2 Jun 2025 12:17:43 +0300 Subject: [PATCH 2/2] tests: edge cases for cst formatting --- codeflash/discovery/functions_to_optimize.py | 24 +- codeflash/optimization/function_optimizer.py | 16 +- tests/test_formatter.py | 256 +++++++++++++++++++ tests/test_function_optimizer.py | 83 ------ 4 files changed, 277 insertions(+), 102 deletions(-) delete mode 100644 tests/test_function_optimizer.py diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 76ce8afff..37bd298cb 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -80,27 +80,29 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: ending_line=pos.end.line, ) ) - class CodeRangeFunctionVisitor(cst.CSTVisitor): - METADATA_DEPENDENCIES = (cst.metadata.PositionProvider, cst.metadata.QualifiedNameProvider) - + METADATA_DEPENDENCIES = ( + cst.metadata.PositionProvider, + cst.metadata.QualifiedNameProvider, + ) + def __init__(self, target_function_name: str) -> None: super().__init__() self.target_func = target_function_name - self.current_path = [] self.start_line: Optional[int] = None self.end_line: Optional[int] = None def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: - qualified_names = { - str(qn.name) for qn in + qualified_names = [ + str(qn.name).replace(".", "") for qn in self.get_metadata(cst.metadata.QualifiedNameProvider, node) - } - + ] if self.target_func in qualified_names: - position = self.get_metadata(cst.metadata.PositionProvider, node) - self.start_line = position.start.line - self.end_line = position.end.line + func_position = self.get_metadata(cst.metadata.PositionProvider, node) + decorators_count = len(node.decorators) + self.start_line = func_position.start.line - decorators_count + self.end_line = func_position.end.line + return False class FunctionWithReturnStatement(ast.NodeVisitor): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 4abaed5b3..1d7448400 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -302,13 +302,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 code_context=code_context, optimized_code=best_optimization.candidate.source_code ) - print("file_path_to_helper_classes\n", file_path_to_helper_classes) - filepaths_to_inspect = [ - self.function_to_optimize.file_path, - *list({helper.file_path for helper in code_context.helper_functions}), - ] - print("filepaths_to_inspect\n", filepaths_to_inspect) - new_code, new_helper_code = self.reformat_code_and_helpers( code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code, opt_func_name=explanation.function_name @@ -610,13 +603,20 @@ def reformat_code_and_helpers( should_sort_imports = False whole_file_content = path.read_text(encoding="utf8") - wrapper = cst.metadata.MetadataWrapper(cst.parse_module(whole_file_content)) + wrapper: cst.metadata.MetadataWrapper | None = None + try: + wrapper = cst.metadata.MetadataWrapper(cst.parse_module(whole_file_content)) + except cst.ParserSyntaxError as e: + logger.error(f"Syntax error detected, aborting reformatting.") + return original_code, {} + visitor = CodeRangeFunctionVisitor(target_function_name=opt_func_name) wrapper.visit(visitor) lines = whole_file_content.splitlines(keepends=True) if visitor.start_line == None: logger.error(f"Could not find function {opt_func_name} in {path}, aborting reformatting.") + return original_code, {} else: opt_func_source_lines = lines[visitor.start_line-1:visitor.end_line] diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..261ca4233 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,3 +1,4 @@ +import argparse import os import tempfile from pathlib import Path @@ -7,6 +8,9 @@ from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig def test_remove_duplicate_imports(): """Test that duplicate imports are removed when should_sort_imports is True.""" @@ -209,3 +213,255 @@ def foo(): tmp_path = tmp.name with pytest.raises(FileNotFoundError): format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + +############################################################ +################ CST based formatting tests ################ +############################################################ +@pytest.fixture +def setup_cst_formatter_args(): + """Common setup for reformat_code_and_helpers tests.""" + def _setup(unformatted_code, function_name): + test_dir = Path(tempfile.mkdtemp()) + target_path = test_dir / "target.py" + target_path.write_text(unformatted_code, encoding="utf-8") + + function_to_optimize = FunctionToOptimize( + function_name=function_name, + parents=[], + file_path=target_path + ) + + test_cfg = TestConfig( + tests_root=test_dir, + project_root_path=test_dir, + test_framework="pytest", + tests_project_rootdir=test_dir, + ) + + args = argparse.Namespace( + disable_imports_sorting=False, + formatter_cmds=[ + "ruff check --exit-zero --fix $file", + "ruff format $file" + ], + ) + + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + args=args, + ) + + return optimizer, target_path, function_to_optimize + + yield _setup + + +def test_reformat_code_and_helpers(setup_cst_formatter_args): + """ + reformat_code_and_helpers should only format the code that is optimized not the whole file, to avoid large diffing + """ + unformatted_code = """import sys + + +def lol(): + print( "lol" ) + + + + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + def lol2 (self): + print( " lol2" )""" + + expected_code = """import sys + + +def lol(): + print( "lol" ) + + + + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + def lol2(self): + print(" lol2") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "MyClass.lol2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_reformat_code_and_helpers_with_duplicated_target_function_names(setup_cst_formatter_args): + unformatted_code = """import sys +def lol(): + print( "lol" ) + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" )""" + + expected_code = """import sys +def lol(): + print( "lol" ) + +class MyClass: + def __init__(self, x=0): + self.x = x + + def lol(self): + print("lol") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "MyClass.lol" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + + +def test_formatting_nested_functions(setup_cst_formatter_args): + unformatted_code = """def hello(): + print("Hello") + def nested_function() : + print ("This is a nested function") + def another_nested_function(): + print ("This is another nested function")""" + + expected_code = """def hello(): + print("Hello") + def nested_function(): + print("This is a nested function") + def another_nested_function(): + print ("This is another nested function")""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "hello.nested_function" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_formatting_standalone_functions(setup_cst_formatter_args): + unformatted_code = """def func1 (): + print( "This is a function with bad formatting") +def func2() : + print ( "This is another function with bad formatting" ) +""" + + expected_code = """def func1 (): + print( "This is a function with bad formatting") +def func2(): + print("This is another function with bad formatting") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "func2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_formatting_function_with_decorators(setup_cst_formatter_args): + unformatted_code = """@decorator1 +@decorator2( arg1 , arg2 ) +def func1 (): + print( "This is a function with bad formatting") + +@another_decorator( arg) +def func2 ( x,y ): + print ( "This is another function with bad formatting" )""" + + expected_code = """@decorator1 +@decorator2( arg1 , arg2 ) +def func1 (): + print( "This is a function with bad formatting") + +@another_decorator(arg) +def func2(x, y): + print("This is another function with bad formatting") +""" + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "func2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code + + +def test_formatting_function_with_syntax_error(setup_cst_formatter_args): + """shouldn't happen anyway, but just in case""" + unformatted_code = """def func1(): + print("This is a function with a syntax error" +def func2(): + print("This is another function with a syntax error") +""" + + expected_code = unformatted_code # No formatting should be applied due to syntax error + + optimizer, target_path, function_to_optimize = setup_cst_formatter_args( + unformatted_code, "func2" + ) + + formatted_code, _ = optimizer.reformat_code_and_helpers( + helper_functions=[], + path=target_path, + original_code=optimizer.function_to_optimize_source_code, + opt_func_name=function_to_optimize.function_name + ) + + assert formatted_code == expected_code diff --git a/tests/test_function_optimizer.py b/tests/test_function_optimizer.py deleted file mode 100644 index fc25ae691..000000000 --- a/tests/test_function_optimizer.py +++ /dev/null @@ -1,83 +0,0 @@ -import argparse -from pathlib import Path -import tempfile - -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.optimization.function_optimizer import FunctionOptimizer -from codeflash.verification.verification_utils import TestConfig - - -def test_reformat_code_and_helpers(): - """ - reformat_code_and_helpers should only format the code that is optimized not the whole file, to avoid large diffing - """ - with tempfile.TemporaryDirectory() as test_dir_str: - test_dir = Path(test_dir_str) - target_path = test_dir / "target.py" - unformatted_code = """import sys - - -def lol(): - print( "lol" ) - - - - -class MyClass: - def __init__(self, x=0): - self.x = x - - def lol(self): - print( "lol" ) - - def lol2 (self): - print( " lol2" )""" - expected_code = """import sys - - -def lol(): - print( "lol" ) - - - - -class MyClass: - def __init__(self, x=0): - self.x = x - - def lol(self): - print( "lol" ) - - def lol2(self): - print(" lol2") -""" - target_path.write_text(unformatted_code, encoding="utf-8") - function_to_optimize = FunctionToOptimize(function_name="MyClass.lol2", parents=[], file_path=target_path) - - test_cfg = TestConfig( - tests_root=test_dir, - project_root_path=test_dir, - test_framework="pytest", - tests_project_rootdir=test_dir, - ) - args = argparse.Namespace( - disable_imports_sorting=False, - formatter_cmds=[ - "ruff check --exit-zero --fix $file", - "ruff format $file" - ], - ) - optimizer = FunctionOptimizer( - function_to_optimize=function_to_optimize, - test_cfg=test_cfg, - args=args, - ) - - - formatted_code,_ = optimizer.reformat_code_and_helpers( - helper_functions=[], - path=target_path, - original_code=optimizer.function_to_optimize_source_code, - opt_func_name=function_to_optimize.function_name - ) - assert formatted_code == expected_code \ No newline at end of file