diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 3ff487517..82bbf0067 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -96,6 +96,19 @@ def is_diff_line(line: str) -> bool: return len(diff_lines) +def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str: + with tempfile.TemporaryDirectory() as test_dir_str: + # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines + original_temp = Path(test_dir_str) / "original_temp.py" + original_temp.write_text(generated_test_source, encoding="utf8") + _, formatted_code, changed = apply_formatter_cmds( + formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False + ) + if not changed: + return re.sub(r"\n{2,}", "\n\n", formatted_code) + return formatted_code + + def format_code( formatter_cmds: list[str], path: Union[str, Path], @@ -120,7 +133,7 @@ def format_code( original_code_lines = len(original_code.split("\n")) if check_diff and original_code_lines > 50: - # we dont' count the formatting diff for the optimized function as it should be well-formatted + # we don't count the formatting diff for the optimized function as it should be well-formatted original_code_without_opfunc = original_code.replace(optimized_code, "") original_temp = Path(test_dir_str) / "original_temp.py" diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5f4ab8767..387babf54 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -55,7 +55,7 @@ remove_functions_from_generated_tests, ) from codeflash.code_utils.env_utils import get_pr_number -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports from codeflash.code_utils.git_utils import git_root_dir from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports @@ -1413,11 +1413,15 @@ def process_review( generated_tests_str = "" for test in generated_tests.generated_tests: - generated_tests_str += f"```python\n{test.generated_original_test_source}\n```" + formatted_generated_test = format_generated_code( + test.generated_original_test_source, self.args.formatter_cmds + ) + generated_tests_str += f"```python\n{formatted_generated_test}\n```" generated_tests_str += "\n\n" if concolic_test_str: - generated_tests_str += f"```python\n{concolic_test_str}\n```\n\n" + formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds) + generated_tests_str += f"```python\n{formatted_generated_test}\n```\n\n" existing_tests, replay_tests, concolic_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 1703f572b..635a7d10b 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -7,7 +7,7 @@ import shutil from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeString, CodeStringsMarkdown @@ -805,3 +805,616 @@ def test_sort_imports_skip_file(): import sys, os, json # isort will ignore this file completely""" new_code = sort_imports(code) assert new_code == code + + +# ==================== Tests for format_generated_code ==================== + +def test_format_generated_code_disabled(): + """Test that format_generated_code returns code with normalized newlines when formatter is disabled.""" + test_code = """import os + + +def test_function(): + pass + + + +def another_function(): + return 42""" + + # Test with None formatter + result = format_generated_code(test_code, ["disabled"]) + # Multiple newlines (3+) are reduced to 2 + expected = """import os + +def test_function(): + pass + +def another_function(): + return 42""" + assert result == expected + + # Test with ["disabled"] formatter + result = format_generated_code(test_code, ["disabled"]) + assert result == expected + + +def test_format_generated_code_disabled_case_insensitive(): + """Test that format_generated_code handles 'Disabled', 'DISABLED' etc.""" + test_code = """def test(): + + + pass""" + + # Multiple newlines are reduced to at most 2 + expected = """def test(): + + pass""" + + # Test various cases + assert format_generated_code(test_code, ["Disabled"]) == expected + assert format_generated_code(test_code, ["DISABLED"]) == expected + assert format_generated_code(test_code, ["DiSaBlEd"]) == expected + + +def test_format_generated_code_empty_string(): + """Test format_generated_code with empty string.""" + result = format_generated_code("", ["disabled"]) + assert result == "" + + result = format_generated_code("", ["disabled"]) + assert result == "" + + +def test_format_generated_code_with_black(): + """Test format_generated_code with black formatter.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """import os,sys +def test_function(x,y,z): + result=x+y+z + return result""" + + expected = """import os, sys + + +def test_function(x, y, z): + result = x + y + z + return result +""" + + result = format_generated_code(test_code, ["black $file"]) + assert result == expected + +def test_format_generated_code_with_inference(): + """Test format_generated_code with ruff formatter.""" + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") + + test_code = '''from time import sleep +from typing import List, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Dummy classes to mimic the actual entities used in the function --- + +class InferenceRequest: + def __init__(self, image, visualize_predictions=False, id=None): + self.image = image + self.visualize_predictions = visualize_predictions + self.id = id + + def dict(self): + # Simulate the dict() method to unpack arguments for infer() + return { + "image": self.image, + "visualize_predictions": self.visualize_predictions, + "id": self.id + } + +class InferenceResponse: + def __init__(self, instances=None): + self.instances = instances if instances is not None else [] + self.time = None + self.visualization = None + self.inference_id = None +from inference.core.models.base import Model + +# --- Unit tests for infer_from_request --- + +@pytest.fixture +def model(): + # Returns a fresh instance of Model for each test + return Model() + +# -------------------------- +# 1. Basic Test Cases +# -------------------------- + + + + + + + + + + + +def test_visualization_true_but_no_draw_method(monkeypatch, model): + """Test with visualize_predictions=True but draw_predictions raises exception.""" + def broken_draw_predictions(request, response): + raise RuntimeError("Visualization failed") + monkeypatch.setattr(model, "draw_predictions", broken_draw_predictions) + req = InferenceRequest(image="img1", visualize_predictions=True) + with pytest.raises(RuntimeError): + model.infer_from_request(req) + + + + + +def test_large_image_list_empty_instances(model): + """Test with large image list and infer returns empty instances.""" + # Patch the model.infer to return responses with empty instances + def empty_infer(image, **kwargs): + if isinstance(image, list): + return [InferenceResponse(instances=[]) for _ in image] + return [InferenceResponse(instances=[])] + model.infer = empty_infer + images = [f"img_{i}" for i in range(900)] + req = InferenceRequest(image=images) + codeflash_output = model.infer_from_request(req); resp = codeflash_output # 1.42ms -> 471μs (201% faster) + for r in resp: + pass + + +#------------------------------------------------ +import time +from typing import Any, List, Tuple, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Minimal stubs/mocks for dependencies --- + +class DummyLogger: + def debug(self, msg): + pass + +logger = DummyLogger() + +def perf_counter(): + # Use time.monotonic() for monotonic clock + return time.monotonic() + +# --- Entities and types --- + +class InferenceRequest: + def __init__(self, image, id=None, visualize_predictions=False, **kwargs): + self.image = image + self.id = id + self.visualize_predictions = visualize_predictions + self.kwargs = kwargs + def dict(self): + d = {"image": self.image} + d.update(self.kwargs) + return d + +class InferenceResponse: + def __init__(self, result=None): + self.result = result + self.time = None + self.inference_id = None + self.visualization = None +from inference.core.models.base import Model + +# --- Unit tests --- + +# 1. BASIC TEST CASES +''' + expected = '''from time import sleep +from typing import List, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Dummy classes to mimic the actual entities used in the function --- + + +class InferenceRequest: + def __init__(self, image, visualize_predictions=False, id=None): + self.image = image + self.visualize_predictions = visualize_predictions + self.id = id + + def dict(self): + # Simulate the dict() method to unpack arguments for infer() + return {"image": self.image, "visualize_predictions": self.visualize_predictions, "id": self.id} + + +class InferenceResponse: + def __init__(self, instances=None): + self.instances = instances if instances is not None else [] + self.time = None + self.visualization = None + self.inference_id = None + + +from inference.core.models.base import Model + +# --- Unit tests for infer_from_request --- + + +@pytest.fixture +def model(): + # Returns a fresh instance of Model for each test + return Model() + + +# -------------------------- +# 1. Basic Test Cases +# -------------------------- + + +def test_visualization_true_but_no_draw_method(monkeypatch, model): + """Test with visualize_predictions=True but draw_predictions raises exception.""" + + def broken_draw_predictions(request, response): + raise RuntimeError("Visualization failed") + + monkeypatch.setattr(model, "draw_predictions", broken_draw_predictions) + req = InferenceRequest(image="img1", visualize_predictions=True) + with pytest.raises(RuntimeError): + model.infer_from_request(req) + + +def test_large_image_list_empty_instances(model): + """Test with large image list and infer returns empty instances.""" + + # Patch the model.infer to return responses with empty instances + def empty_infer(image, **kwargs): + if isinstance(image, list): + return [InferenceResponse(instances=[]) for _ in image] + return [InferenceResponse(instances=[])] + + model.infer = empty_infer + images = [f"img_{i}" for i in range(900)] + req = InferenceRequest(image=images) + codeflash_output = model.infer_from_request(req) + resp = codeflash_output # 1.42ms -> 471μs (201% faster) + for r in resp: + pass + + +# ------------------------------------------------ +import time +from typing import Any, List, Tuple, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Minimal stubs/mocks for dependencies --- + + +class DummyLogger: + def debug(self, msg): + pass + + +logger = DummyLogger() + + +def perf_counter(): + # Use time.monotonic() for monotonic clock + return time.monotonic() + + +# --- Entities and types --- + + +class InferenceRequest: + def __init__(self, image, id=None, visualize_predictions=False, **kwargs): + self.image = image + self.id = id + self.visualize_predictions = visualize_predictions + self.kwargs = kwargs + + def dict(self): + d = {"image": self.image} + d.update(self.kwargs) + return d + + +class InferenceResponse: + def __init__(self, result=None): + self.result = result + self.time = None + self.inference_id = None + self.visualization = None + + +from inference.core.models.base import Model + +# --- Unit tests --- + +# 1. BASIC TEST CASES +''' + + result = format_generated_code(test_code, ["ruff format $file"]) + assert result == expected + +def test_format_generated_code_with_ruff(): + """Test format_generated_code with ruff formatter.""" + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") + + test_code = """import os,sys +def test_function(x,y,z): + result=x+y+z + return result""" + + expected = """import os, sys + + +def test_function(x, y, z): + result = x + y + z + return result +""" + + result = format_generated_code(test_code, ["ruff format $file"]) + assert result == expected + + +def test_format_generated_code_multiple_formatters(): + """Test format_generated_code with multiple formatter commands.""" + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") + + test_code = """import sys,os # wrong order +def test_function(x,y,z): + result=x+y+z + return result""" + + # Ruff format will fix spacing + result = format_generated_code(test_code, ["ruff format $file"]) + + # Check that formatting happened + assert "result = x + y + z" in result # spacing should be fixed + assert "def test_function(x, y, z):" in result # parameters should have spaces + + +def test_format_generated_code_invalid_formatter(): + """Test format_generated_code with non-existent formatter command.""" + test_code = """def test(): + pass""" + + # Should handle gracefully and return code with normalized newlines + result = format_generated_code(test_code, ["nonexistent_formatter $file"]) + assert result == """def test(): + pass""" + + +def test_format_generated_code_syntax_error(): + """Test format_generated_code with Python code containing syntax errors.""" + test_code = """def test(: # syntax error + pass""" + + # Formatter should fail but function should handle it gracefully + result = format_generated_code(test_code, ["black $file"]) + # Should return code with normalized newlines when formatting fails + assert result == """def test(: # syntax error + pass""" + + +def test_format_generated_code_already_formatted(): + """Test format_generated_code with already well-formatted code.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """import os +import sys + + +def test_function(x, y, z): + result = x + y + z + return result +""" + + # Code is already formatted, should return the same + result = format_generated_code(test_code, ["black $file"]) + assert result == test_code + + +def test_format_generated_code_with_tabs(): + """Test format_generated_code with code containing tabs.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """def test(): +\tif True: +\t\treturn 42 +\treturn 0""" + + # Black should convert tabs to spaces + result = format_generated_code(test_code, ["black $file"]) + assert "\t" not in result # No tabs should remain + assert " " in result # Should have spaces + + +def test_format_generated_code_trailing_whitespace(): + """Test format_generated_code removes trailing whitespace.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """def test(): + pass + """ + + result = format_generated_code(test_code, ["black $file"]) + lines = result.split('\n') + for line in lines: + assert line == line.rstrip(), f"Line has trailing whitespace: {repr(line)}" + + +def test_format_generated_code_preserves_comments(): + """Test format_generated_code preserves comments.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """# This is a module comment +import os # import os module + +def test(): + # This function does something + pass # TODO: implement this +""" + + result = format_generated_code(test_code, ["black $file"]) + assert "# This is a module comment" in result + assert "# import os module" in result + assert "# This function does something" in result + assert "# TODO: implement this" in result + + +def test_format_generated_code_with_docstrings(): + """Test format_generated_code handles docstrings correctly.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = '''def test(): + """This is a docstring.""" + pass + +class TestClass: + """ + Multi-line + docstring + """ + def method(self): + \'\'\'Single quote docstring\'\'\' + pass''' + + result = format_generated_code(test_code, ["black $file"]) + assert '"""This is a docstring."""' in result + assert "Multi-line" in result + assert "docstring" in result + + +def test_format_generated_code_normalizes_multiple_newlines(): + """Test that multiple consecutive newlines are normalized to two.""" + test_code = """import os + + + + +def func1(): + pass + + + +def func2(): + pass""" + + result = format_generated_code(test_code, ["disabled"]) + # Should have at most two consecutive newlines + assert "\n\n\n" not in result + assert "import os\n\n" in result + assert "pass\n\n" in result + + +def test_format_generated_code_complex_code(): + """Test format_generated_code with complex real-world code.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """import unittest +from unittest.mock import patch,Mock,MagicMock +import os,sys +from typing import Dict,List,Optional + +class TestComplexClass(unittest.TestCase): + def setUp(self): + self.config={'key1':'value1','key2':'value2'} + self.data=[{'id':1,'name':'test1'},{'id':2,'name':'test2'}] + + def test_something(self): + result=process_data(self.data,lambda x:x['id']>0) + self.assertEqual(len(result),2) + + @patch('module.function') + def test_with_mock(self,mock_func): + mock_func.return_value={'status':'ok'} + response=make_request() + self.assertEqual(response['status'],'ok') + +def process_data(data:List[Dict],filter_func)->List[Dict]: + return [item for item in data if filter_func(item)]""" + + result = format_generated_code(test_code, ["black $file"]) + + # Check that formatting was applied + assert "self.config = {" in result + assert "self.data = [" in result + assert "result = process_data" in result + assert "mock_func.return_value = {" in result + # Check imports are formatted + assert "from unittest.mock import " in result + assert "from typing import Dict, List, Optional" in result + + +def test_format_generated_code_unicode(): + """Test format_generated_code with Unicode characters.""" + test_code = """def test(): + message = "Hello, 世界! 🌍" + return message""" + + result = format_generated_code(test_code, ["disabled"]) + assert "Hello, 世界! 🌍" in result + + +def test_format_generated_code_f_strings(): + """Test format_generated_code with f-strings.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """def test(name,age): + return f"Hello {name}, you are {age} years old" + +def test2(): + x=10 + y=20 + return f"{x}+{y}={x+y}" """ + + result = format_generated_code(test_code, ["black $file"]) + assert 'f"Hello {name}, you are {age} years old"' in result + assert "x = 10" in result + assert "y = 20" in result