From 60bb705b9a94d3eaf21132e6f1f3afd6247ab083 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Mon, 10 Nov 2025 18:44:09 -0500 Subject: [PATCH 1/7] wip --- codeflash/code_utils/formatter.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 498a8078b..bbfd0668d 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -96,6 +96,20 @@ def is_diff_line(line: str) -> bool: return len(diff_lines) +def format_generated_code(generated_test_source: str, formatter_cmds: Union[list[str], None] = None) -> str: + formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" + if formatter_name == "disabled": + return re.sub(r"\n{2,}", "\n\n", generated_test_source) + # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) + original_temp, test_dir_str, exit_on_failure = None, None, True + formatted_temp, formatted_code, changed = apply_formatter_cmds( + formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure + ) + if not changed: + return re.sub(r"\n{2,}", "\n\n", formatted_code) + return formatted_code + + def format_code( formatter_cmds: list[str], path: Union[str, Path], @@ -120,7 +134,7 @@ def format_code( original_code_lines = len(original_code.split("\n")) if check_diff and original_code_lines > 50: - # we dont' count the formatting diff for the optimized function as it should be well-formatted + # we don't count the formatting diff for the optimized function as it should be well-formatted original_code_without_opfunc = original_code.replace(optimized_code, "") original_temp = Path(test_dir_str) / "original_temp.py" From e642720b494efa590bf7a9f86c67d235dab97aa0 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Mon, 10 Nov 2025 18:50:45 -0500 Subject: [PATCH 2/7] wip --- codeflash/code_utils/formatter.py | 25 ++++++++++---------- codeflash/optimization/function_optimizer.py | 8 ++++--- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index bbfd0668d..0d0c046de 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -96,18 +96,19 @@ def is_diff_line(line: str) -> bool: return len(diff_lines) -def format_generated_code(generated_test_source: str, formatter_cmds: Union[list[str], None] = None) -> str: - formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" - if formatter_name == "disabled": - return re.sub(r"\n{2,}", "\n\n", generated_test_source) - # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) - original_temp, test_dir_str, exit_on_failure = None, None, True - formatted_temp, formatted_code, changed = apply_formatter_cmds( - formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure - ) - if not changed: - return re.sub(r"\n{2,}", "\n\n", formatted_code) - return formatted_code +def format_generated_code(generated_test_source: str) -> str: + return re.sub(r"\n{2,}", "\n\n", generated_test_source) + # formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" + # if formatter_name == "disabled": + # return re.sub(r"\n{2,}", "\n\n", generated_test_source) + # # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) + # original_temp, test_dir_str, exit_on_failure = None, None, True + # formatted_temp, formatted_code, changed = apply_formatter_cmds( + # formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure + # ) + # if not changed: + # return re.sub(r"\n{2,}", "\n\n", formatted_code) + # return formatted_code def format_code( diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 5f4ab8767..7df8fd96b 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -55,7 +55,7 @@ remove_functions_from_generated_tests, ) from codeflash.code_utils.env_utils import get_pr_number -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports from codeflash.code_utils.git_utils import git_root_dir from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports @@ -1413,11 +1413,13 @@ def process_review( generated_tests_str = "" for test in generated_tests.generated_tests: - generated_tests_str += f"```python\n{test.generated_original_test_source}\n```" + formatted_generated_test = format_generated_code(test.generated_original_test_source) + generated_tests_str += f"```python\n{formatted_generated_test}\n```" generated_tests_str += "\n\n" if concolic_test_str: - generated_tests_str += f"```python\n{concolic_test_str}\n```\n\n" + formatted_generated_test = format_generated_code(concolic_test_str) + generated_tests_str += f"```python\n{formatted_generated_test}\n```\n\n" existing_tests, replay_tests, concolic_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), From 410ef8e8feeef896251f758bc0817f28e8c83660 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 11 Nov 2025 20:25:23 -0500 Subject: [PATCH 3/7] inference test --- codeflash/code_utils/formatter.py | 27 +- tests/test_formatter.py | 615 +++++++++++++++++++++++++++++- 2 files changed, 628 insertions(+), 14 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 8a8acb2df..f877033d8 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -96,19 +96,20 @@ def is_diff_line(line: str) -> bool: return len(diff_lines) -def format_generated_code(generated_test_source: str) -> str: - return re.sub(r"\n{2,}", "\n\n", generated_test_source) - # formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" - # if formatter_name == "disabled": - # return re.sub(r"\n{2,}", "\n\n", generated_test_source) - # # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) - # original_temp, test_dir_str, exit_on_failure = None, None, True - # formatted_temp, formatted_code, changed = apply_formatter_cmds( - # formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=exit_on_failure - # ) - # if not changed: - # return re.sub(r"\n{2,}", "\n\n", formatted_code) - # return formatted_code +def format_generated_code(generated_test_source: str, formatter_cmds: Union[list[str], None]) -> str: + formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" + if formatter_name == "disabled": + return re.sub(r"\n{2,}", "\n\n", generated_test_source) + with tempfile.TemporaryDirectory() as test_dir_str: + # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) + original_temp = Path(test_dir_str) / "original_temp.py" + original_temp.write_text(generated_test_source, encoding="utf8") + _, formatted_code, changed = apply_formatter_cmds( + formatter_cmds, original_temp, test_dir_str, print_status=False, exit_on_failure=False + ) + if not changed: + return re.sub(r"\n{2,}", "\n\n", formatted_code) + return formatted_code def format_code( diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 1703f572b..2af0e296a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -7,7 +7,7 @@ import shutil from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, format_generated_code, sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import CodeString, CodeStringsMarkdown @@ -805,3 +805,616 @@ def test_sort_imports_skip_file(): import sys, os, json # isort will ignore this file completely""" new_code = sort_imports(code) assert new_code == code + + +# ==================== Tests for format_generated_code ==================== + +def test_format_generated_code_disabled(): + """Test that format_generated_code returns code with normalized newlines when formatter is disabled.""" + test_code = """import os + + +def test_function(): + pass + + + +def another_function(): + return 42""" + + # Test with None formatter + result = format_generated_code(test_code, None) + # Multiple newlines (3+) are reduced to 2 + expected = """import os + +def test_function(): + pass + +def another_function(): + return 42""" + assert result == expected + + # Test with ["disabled"] formatter + result = format_generated_code(test_code, ["disabled"]) + assert result == expected + + +def test_format_generated_code_disabled_case_insensitive(): + """Test that format_generated_code handles 'Disabled', 'DISABLED' etc.""" + test_code = """def test(): + + + pass""" + + # Multiple newlines are reduced to at most 2 + expected = """def test(): + + pass""" + + # Test various cases + assert format_generated_code(test_code, ["Disabled"]) == expected + assert format_generated_code(test_code, ["DISABLED"]) == expected + assert format_generated_code(test_code, ["DiSaBlEd"]) == expected + + +def test_format_generated_code_empty_string(): + """Test format_generated_code with empty string.""" + result = format_generated_code("", None) + assert result == "" + + result = format_generated_code("", ["disabled"]) + assert result == "" + + +def test_format_generated_code_with_black(): + """Test format_generated_code with black formatter.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """import os,sys +def test_function(x,y,z): + result=x+y+z + return result""" + + expected = """import os, sys + + +def test_function(x, y, z): + result = x + y + z + return result +""" + + result = format_generated_code(test_code, ["black $file"]) + assert result == expected + +def test_format_generated_code_with_inference(): + """Test format_generated_code with ruff formatter.""" + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") + + test_code = '''from time import sleep +from typing import List, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Dummy classes to mimic the actual entities used in the function --- + +class InferenceRequest: + def __init__(self, image, visualize_predictions=False, id=None): + self.image = image + self.visualize_predictions = visualize_predictions + self.id = id + + def dict(self): + # Simulate the dict() method to unpack arguments for infer() + return { + "image": self.image, + "visualize_predictions": self.visualize_predictions, + "id": self.id + } + +class InferenceResponse: + def __init__(self, instances=None): + self.instances = instances if instances is not None else [] + self.time = None + self.visualization = None + self.inference_id = None +from inference.core.models.base import Model + +# --- Unit tests for infer_from_request --- + +@pytest.fixture +def model(): + # Returns a fresh instance of Model for each test + return Model() + +# -------------------------- +# 1. Basic Test Cases +# -------------------------- + + + + + + + + + + + +def test_visualization_true_but_no_draw_method(monkeypatch, model): + """Test with visualize_predictions=True but draw_predictions raises exception.""" + def broken_draw_predictions(request, response): + raise RuntimeError("Visualization failed") + monkeypatch.setattr(model, "draw_predictions", broken_draw_predictions) + req = InferenceRequest(image="img1", visualize_predictions=True) + with pytest.raises(RuntimeError): + model.infer_from_request(req) + + + + + +def test_large_image_list_empty_instances(model): + """Test with large image list and infer returns empty instances.""" + # Patch the model.infer to return responses with empty instances + def empty_infer(image, **kwargs): + if isinstance(image, list): + return [InferenceResponse(instances=[]) for _ in image] + return [InferenceResponse(instances=[])] + model.infer = empty_infer + images = [f"img_{i}" for i in range(900)] + req = InferenceRequest(image=images) + codeflash_output = model.infer_from_request(req); resp = codeflash_output # 1.42ms -> 471ΞΌs (201% faster) + for r in resp: + pass + + +#------------------------------------------------ +import time +from typing import Any, List, Tuple, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Minimal stubs/mocks for dependencies --- + +class DummyLogger: + def debug(self, msg): + pass + +logger = DummyLogger() + +def perf_counter(): + # Use time.monotonic() for monotonic clock + return time.monotonic() + +# --- Entities and types --- + +class InferenceRequest: + def __init__(self, image, id=None, visualize_predictions=False, **kwargs): + self.image = image + self.id = id + self.visualize_predictions = visualize_predictions + self.kwargs = kwargs + def dict(self): + d = {"image": self.image} + d.update(self.kwargs) + return d + +class InferenceResponse: + def __init__(self, result=None): + self.result = result + self.time = None + self.inference_id = None + self.visualization = None +from inference.core.models.base import Model + +# --- Unit tests --- + +# 1. BASIC TEST CASES +''' + expected = '''from time import sleep +from typing import List, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Dummy classes to mimic the actual entities used in the function --- + + +class InferenceRequest: + def __init__(self, image, visualize_predictions=False, id=None): + self.image = image + self.visualize_predictions = visualize_predictions + self.id = id + + def dict(self): + # Simulate the dict() method to unpack arguments for infer() + return {"image": self.image, "visualize_predictions": self.visualize_predictions, "id": self.id} + + +class InferenceResponse: + def __init__(self, instances=None): + self.instances = instances if instances is not None else [] + self.time = None + self.visualization = None + self.inference_id = None + + +from inference.core.models.base import Model + +# --- Unit tests for infer_from_request --- + + +@pytest.fixture +def model(): + # Returns a fresh instance of Model for each test + return Model() + + +# -------------------------- +# 1. Basic Test Cases +# -------------------------- + + +def test_visualization_true_but_no_draw_method(monkeypatch, model): + """Test with visualize_predictions=True but draw_predictions raises exception.""" + + def broken_draw_predictions(request, response): + raise RuntimeError("Visualization failed") + + monkeypatch.setattr(model, "draw_predictions", broken_draw_predictions) + req = InferenceRequest(image="img1", visualize_predictions=True) + with pytest.raises(RuntimeError): + model.infer_from_request(req) + + +def test_large_image_list_empty_instances(model): + """Test with large image list and infer returns empty instances.""" + + # Patch the model.infer to return responses with empty instances + def empty_infer(image, **kwargs): + if isinstance(image, list): + return [InferenceResponse(instances=[]) for _ in image] + return [InferenceResponse(instances=[])] + + model.infer = empty_infer + images = [f"img_{i}" for i in range(900)] + req = InferenceRequest(image=images) + codeflash_output = model.infer_from_request(req) + resp = codeflash_output # 1.42ms -> 471ΞΌs (201% faster) + for r in resp: + pass + + +# ------------------------------------------------ +import time +from typing import Any, List, Tuple, Union + +# imports +import pytest +from inference.core.models.base import Model + +# --- Minimal stubs/mocks for dependencies --- + + +class DummyLogger: + def debug(self, msg): + pass + + +logger = DummyLogger() + + +def perf_counter(): + # Use time.monotonic() for monotonic clock + return time.monotonic() + + +# --- Entities and types --- + + +class InferenceRequest: + def __init__(self, image, id=None, visualize_predictions=False, **kwargs): + self.image = image + self.id = id + self.visualize_predictions = visualize_predictions + self.kwargs = kwargs + + def dict(self): + d = {"image": self.image} + d.update(self.kwargs) + return d + + +class InferenceResponse: + def __init__(self, result=None): + self.result = result + self.time = None + self.inference_id = None + self.visualization = None + + +from inference.core.models.base import Model + +# --- Unit tests --- + +# 1. BASIC TEST CASES +''' + + result = format_generated_code(test_code, ["ruff format $file"]) + assert result == expected + +def test_format_generated_code_with_ruff(): + """Test format_generated_code with ruff formatter.""" + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") + + test_code = """import os,sys +def test_function(x,y,z): + result=x+y+z + return result""" + + expected = """import os, sys + + +def test_function(x, y, z): + result = x + y + z + return result +""" + + result = format_generated_code(test_code, ["ruff format $file"]) + assert result == expected + + +def test_format_generated_code_multiple_formatters(): + """Test format_generated_code with multiple formatter commands.""" + try: + import ruff # type: ignore + except ImportError: + pytest.skip("ruff is not installed") + + test_code = """import sys,os # wrong order +def test_function(x,y,z): + result=x+y+z + return result""" + + # Ruff format will fix spacing + result = format_generated_code(test_code, ["ruff format $file"]) + + # Check that formatting happened + assert "result = x + y + z" in result # spacing should be fixed + assert "def test_function(x, y, z):" in result # parameters should have spaces + + +def test_format_generated_code_invalid_formatter(): + """Test format_generated_code with non-existent formatter command.""" + test_code = """def test(): + pass""" + + # Should handle gracefully and return code with normalized newlines + result = format_generated_code(test_code, ["nonexistent_formatter $file"]) + assert result == """def test(): + pass""" + + +def test_format_generated_code_syntax_error(): + """Test format_generated_code with Python code containing syntax errors.""" + test_code = """def test(: # syntax error + pass""" + + # Formatter should fail but function should handle it gracefully + result = format_generated_code(test_code, ["black $file"]) + # Should return code with normalized newlines when formatting fails + assert result == """def test(: # syntax error + pass""" + + +def test_format_generated_code_already_formatted(): + """Test format_generated_code with already well-formatted code.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """import os +import sys + + +def test_function(x, y, z): + result = x + y + z + return result +""" + + # Code is already formatted, should return the same + result = format_generated_code(test_code, ["black $file"]) + assert result == test_code + + +def test_format_generated_code_with_tabs(): + """Test format_generated_code with code containing tabs.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """def test(): +\tif True: +\t\treturn 42 +\treturn 0""" + + # Black should convert tabs to spaces + result = format_generated_code(test_code, ["black $file"]) + assert "\t" not in result # No tabs should remain + assert " " in result # Should have spaces + + +def test_format_generated_code_trailing_whitespace(): + """Test format_generated_code removes trailing whitespace.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """def test(): + pass + """ + + result = format_generated_code(test_code, ["black $file"]) + lines = result.split('\n') + for line in lines: + assert line == line.rstrip(), f"Line has trailing whitespace: {repr(line)}" + + +def test_format_generated_code_preserves_comments(): + """Test format_generated_code preserves comments.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """# This is a module comment +import os # import os module + +def test(): + # This function does something + pass # TODO: implement this +""" + + result = format_generated_code(test_code, ["black $file"]) + assert "# This is a module comment" in result + assert "# import os module" in result + assert "# This function does something" in result + assert "# TODO: implement this" in result + + +def test_format_generated_code_with_docstrings(): + """Test format_generated_code handles docstrings correctly.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = '''def test(): + """This is a docstring.""" + pass + +class TestClass: + """ + Multi-line + docstring + """ + def method(self): + \'\'\'Single quote docstring\'\'\' + pass''' + + result = format_generated_code(test_code, ["black $file"]) + assert '"""This is a docstring."""' in result + assert "Multi-line" in result + assert "docstring" in result + + +def test_format_generated_code_normalizes_multiple_newlines(): + """Test that multiple consecutive newlines are normalized to two.""" + test_code = """import os + + + + +def func1(): + pass + + + +def func2(): + pass""" + + result = format_generated_code(test_code, None) + # Should have at most two consecutive newlines + assert "\n\n\n" not in result + assert "import os\n\n" in result + assert "pass\n\n" in result + + +def test_format_generated_code_complex_code(): + """Test format_generated_code with complex real-world code.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """import unittest +from unittest.mock import patch,Mock,MagicMock +import os,sys +from typing import Dict,List,Optional + +class TestComplexClass(unittest.TestCase): + def setUp(self): + self.config={'key1':'value1','key2':'value2'} + self.data=[{'id':1,'name':'test1'},{'id':2,'name':'test2'}] + + def test_something(self): + result=process_data(self.data,lambda x:x['id']>0) + self.assertEqual(len(result),2) + + @patch('module.function') + def test_with_mock(self,mock_func): + mock_func.return_value={'status':'ok'} + response=make_request() + self.assertEqual(response['status'],'ok') + +def process_data(data:List[Dict],filter_func)->List[Dict]: + return [item for item in data if filter_func(item)]""" + + result = format_generated_code(test_code, ["black $file"]) + + # Check that formatting was applied + assert "self.config = {" in result + assert "self.data = [" in result + assert "result = process_data" in result + assert "mock_func.return_value = {" in result + # Check imports are formatted + assert "from unittest.mock import " in result + assert "from typing import Dict, List, Optional" in result + + +def test_format_generated_code_unicode(): + """Test format_generated_code with Unicode characters.""" + test_code = """def test(): + message = "Hello, δΈ–η•Œ! 🌍" + return message""" + + result = format_generated_code(test_code, None) + assert "Hello, δΈ–η•Œ! 🌍" in result + + +def test_format_generated_code_f_strings(): + """Test format_generated_code with f-strings.""" + try: + import black + except ImportError: + pytest.skip("black is not installed") + + test_code = """def test(name,age): + return f"Hello {name}, you are {age} years old" + +def test2(): + x=10 + y=20 + return f"{x}+{y}={x+y}" """ + + result = format_generated_code(test_code, ["black $file"]) + assert 'f"Hello {name}, you are {age} years old"' in result + assert "x = 10" in result + assert "y = 20" in result From ee37f9a405c6bd8f4d93afa39a6eca1188a87a13 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 11 Nov 2025 20:27:24 -0500 Subject: [PATCH 4/7] inference test --- codeflash/code_utils/formatter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index f877033d8..5d7b23d35 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -96,7 +96,7 @@ def is_diff_line(line: str) -> bool: return len(diff_lines) -def format_generated_code(generated_test_source: str, formatter_cmds: Union[list[str], None]) -> str: +def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str: formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" if formatter_name == "disabled": return re.sub(r"\n{2,}", "\n\n", generated_test_source) From 41ebb4182467f5f14f654725fa287d50c803da8a Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 11 Nov 2025 20:29:52 -0500 Subject: [PATCH 5/7] inference test --- codeflash/optimization/function_optimizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7df8fd96b..387babf54 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -1413,12 +1413,14 @@ def process_review( generated_tests_str = "" for test in generated_tests.generated_tests: - formatted_generated_test = format_generated_code(test.generated_original_test_source) + formatted_generated_test = format_generated_code( + test.generated_original_test_source, self.args.formatter_cmds + ) generated_tests_str += f"```python\n{formatted_generated_test}\n```" generated_tests_str += "\n\n" if concolic_test_str: - formatted_generated_test = format_generated_code(concolic_test_str) + formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds) generated_tests_str += f"```python\n{formatted_generated_test}\n```\n\n" existing_tests, replay_tests, concolic_tests = existing_tests_source_for( From 7bff589e1101154e6d58566f6726b9cf987393d1 Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 11 Nov 2025 18:50:45 -0800 Subject: [PATCH 6/7] Update formatter.py --- codeflash/code_utils/formatter.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 5d7b23d35..82bbf0067 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -97,11 +97,8 @@ def is_diff_line(line: str) -> bool: def format_generated_code(generated_test_source: str, formatter_cmds: list[str]) -> str: - formatter_name = formatter_cmds[0].lower() if formatter_cmds else "disabled" - if formatter_name == "disabled": - return re.sub(r"\n{2,}", "\n\n", generated_test_source) with tempfile.TemporaryDirectory() as test_dir_str: - # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) + # try running formatter, if nothing changes (could be due to formatting failing or no actual formatting needed) return code with 2 or more newlines substituted with 2 newlines original_temp = Path(test_dir_str) / "original_temp.py" original_temp.write_text(generated_test_source, encoding="utf8") _, formatted_code, changed = apply_formatter_cmds( From a7e3b3f2c1bae0f7ad4892be534ddd774e7b679d Mon Sep 17 00:00:00 2001 From: Aseem Saxena Date: Tue, 11 Nov 2025 22:02:25 -0500 Subject: [PATCH 7/7] None not supported for formatter commands --- tests/test_formatter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 2af0e296a..635a7d10b 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -823,7 +823,7 @@ def another_function(): return 42""" # Test with None formatter - result = format_generated_code(test_code, None) + result = format_generated_code(test_code, ["disabled"]) # Multiple newlines (3+) are reduced to 2 expected = """import os @@ -859,7 +859,7 @@ def test_format_generated_code_disabled_case_insensitive(): def test_format_generated_code_empty_string(): """Test format_generated_code with empty string.""" - result = format_generated_code("", None) + result = format_generated_code("", ["disabled"]) assert result == "" result = format_generated_code("", ["disabled"]) @@ -1340,7 +1340,7 @@ def func1(): def func2(): pass""" - result = format_generated_code(test_code, None) + result = format_generated_code(test_code, ["disabled"]) # Should have at most two consecutive newlines assert "\n\n\n" not in result assert "import os\n\n" in result @@ -1395,7 +1395,7 @@ def test_format_generated_code_unicode(): message = "Hello, δΈ–η•Œ! 🌍" return message""" - result = format_generated_code(test_code, None) + result = format_generated_code(test_code, ["disabled"]) assert "Hello, δΈ–η•Œ! 🌍" in result