diff --git a/.github/workflows/unit-tests.yaml b/.github/workflows/unit-tests.yaml index a1e7da8ea..0540b29d3 100644 --- a/.github/workflows/unit-tests.yaml +++ b/.github/workflows/unit-tests.yaml @@ -11,7 +11,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [3.9.18, 3.10.13, 3.11.6, 3.12.1, 3.13.0] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] continue-on-error: true runs-on: ubuntu-latest steps: @@ -32,7 +32,7 @@ jobs: run: uvx poetry install --with dev - name: Unit tests - run: uvx poetry run pytest tests/ --cov --cov-report=xml + run: uvx poetry run pytest tests/ --cov --cov-report=xml --disable-warnings - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 diff --git a/.gitignore b/.gitignore index 535acfb3e..8aab8f748 100644 --- a/.gitignore +++ b/.gitignore @@ -124,7 +124,7 @@ celerybeat.pid # Environments .env **/.env -.venv +.venv* env/ venv/ ENV/ diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 875fd0a1f..d60b45961 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -3,8 +3,10 @@ import os import shlex import subprocess +from functools import partial from typing import TYPE_CHECKING +import black import isort from codeflash.cli_cmds.console import console, logger @@ -12,6 +14,8 @@ if TYPE_CHECKING: from pathlib import Path +imports_sort = partial(isort.code, float_to_top=True) + def format_code(formatter_cmds: list[str], path: Path) -> str: # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution @@ -46,12 +50,19 @@ def format_code(formatter_cmds: list[str], path: Path) -> str: return path.read_text(encoding="utf8") -def sort_imports(code: str) -> str: +def format_code_in_memory(code: str, *, imports_only: bool = False) -> str: + if imports_only: + try: + sorted_code = imports_sort(code) + except Exception: # noqa: BLE001 + logger.debug("Failed to sort imports with isort.") + return code + return sorted_code try: - # Deduplicate and sort imports, modify the code in memory, not on disk - sorted_code = isort.code(code) - except Exception: - logger.exception("Failed to sort imports with isort.") - return code # Fall back to original code if isort fails + formatted_code = black.format_str(code, mode=black.FileMode()) + formatted_code = imports_sort(formatted_code) + except Exception: # noqa: BLE001 + logger.debug("Failed to format code with black.") + return code - return sorted_code + return formatted_code diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index e691c2ed2..cd869d944 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -4,10 +4,9 @@ from pathlib import Path from typing import TYPE_CHECKING -import isort - from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path +from codeflash.code_utils.formatter import format_code_in_memory from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent, TestingMode, VerificationType @@ -355,8 +354,7 @@ def inject_profiling_into_existing_test( if test_framework == "unittest": new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")])) tree.body = [*new_imports, create_wrapper_function(mode), *tree.body] - return True, isort.code(ast.unparse(tree), float_to_top=True) - + return True, format_code_in_memory(ast.unparse(tree)) def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef: lineno = 1 diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7fa8805c9..6272b86fc 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -35,7 +35,7 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, format_code_in_memory from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -541,14 +541,14 @@ def reformat_code_and_helpers( new_code = format_code(self.args.formatter_cmds, path) if should_sort_imports: - new_code = sort_imports(new_code) + new_code = format_code_in_memory(new_code, imports_only=True) new_helper_code: dict[Path, str] = {} helper_functions_paths = {hf.file_path for hf in helper_functions} for module_abspath in helper_functions_paths: formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) if should_sort_imports: - formatted_helper_code = sort_imports(formatted_helper_code) + formatted_helper_code = format_code_in_memory(formatted_helper_code, imports_only=True) new_helper_code[module_abspath] = formatted_helper_code return new_code, new_helper_code diff --git a/pyproject.toml b/pyproject.toml index 8a3d0d523..a2284940c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ exclude = [ [tool.poetry.dependencies] python = ">=3.9" unidiff = ">=0.7.4" -pytest = ">=7.0.0,<8.3.4" +pytest = ">=7.0.0" gitpython = ">=3.1.31" libcst = ">=1.0.1" jedi = ">=0.19.1" @@ -93,6 +93,7 @@ lxml = ">=5.3.0" crosshair-tool = ">=0.0.78" coverage = ">=7.6.4" line_profiler=">=4.2.0" #this is the minimum version which supports python 3.13 +black = "^25.1.0" [tool.poetry.group.dev] optional = true diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..8d4869136 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -5,13 +5,13 @@ import pytest from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, format_code_in_memory def test_remove_duplicate_imports(): """Test that duplicate imports are removed when should_sort_imports is True.""" original_code = "import os\nimport os\n" - new_code = sort_imports(original_code) + new_code = format_code_in_memory(original_code, imports_only=True) assert new_code == "import os\n" @@ -19,7 +19,7 @@ def test_remove_multiple_duplicate_imports(): """Test that multiple duplicate imports are removed when should_sort_imports is True.""" original_code = "import sys\nimport os\nimport sys\n" - new_code = sort_imports(original_code) + new_code = format_code_in_memory(original_code, imports_only=True) assert new_code == "import os\nimport sys\n" @@ -27,7 +27,7 @@ def test_sorting_imports(): """Test that imports are sorted when should_sort_imports is True.""" original_code = "import sys\nimport unittest\nimport os\n" - new_code = sort_imports(original_code) + new_code = format_code_in_memory(original_code, imports_only=True) assert new_code == "import os\nimport sys\nimport unittest\n" @@ -40,7 +40,7 @@ def test_sort_imports_without_formatting(): new_code = format_code(formatter_cmds=["disabled"], path=tmp_path) assert new_code is not None - new_code = sort_imports(new_code) + new_code = format_code_in_memory(new_code, imports_only=True) assert new_code == "import os\nimport sys\nimport unittest\n" @@ -63,7 +63,7 @@ def foo(): return os.path.join(sys.path[0], 'bar') """ - actual = sort_imports(original_code) + actual = format_code_in_memory(original_code, imports_only=True) assert actual == expected @@ -90,7 +90,7 @@ def foo(): return os.path.join(sys.path[0], 'bar') """ - actual = sort_imports(original_code) + actual = format_code_in_memory(original_code, imports_only=True) assert actual == expected diff --git a/tests/test_instrument_all_and_run.py b/tests/test_instrument_all_and_run.py index 5bc942fdd..c2870d45d 100644 --- a/tests/test_instrument_all_and_run.py +++ b/tests/test_instrument_all_and_run.py @@ -12,6 +12,7 @@ from codeflash.models.models import CodePosition, FunctionParent, TestFile, TestFiles, TestingMode, TestType from codeflash.optimization.optimizer import Optimizer from codeflash.verification.equivalence import compare_test_results +from codeflash.code_utils.formatter import format_code_in_memory from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture # Used by cli instrumentation @@ -119,10 +120,12 @@ def test_sort(): os.chdir(original_cwd) assert success assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory( + expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + ) + ) with test_path.open("w") as f: f.write(new_test) @@ -307,9 +310,9 @@ def test_sort(): Path(f.name), [CodePosition(7, 13), CodePosition(12, 13)], fto, Path(f.name).parent, "pytest" ) assert success - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values")) - ).replace('"', "'") + )) tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() test_path = tests_root / "test_class_method_behavior_results_temp.py" test_path_perf = tests_root / "test_class_method_behavior_results_perf_temp.py" diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index 44661912a..9b413ad26 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -8,6 +8,7 @@ from pathlib import Path from codeflash.code_utils.code_utils import get_run_tmp_file +from codeflash.code_utils.formatter import format_code_in_memory from codeflash.code_utils.instrument_existing_tests import ( FunctionImportedAsVisitor, inject_profiling_into_existing_test, @@ -102,7 +103,7 @@ def test_sort(self): input = list(reversed(range(5000))) self.assertEqual(sorter(input), list(range(5000))) """ - expected = """import gc + expected = '''import gc import os import sqlite3 import time @@ -114,21 +115,37 @@ def test_sort(self): from code_to_optimize.bubble_sort import sorter -def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): - test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' - if not hasattr(codeflash_wrap, 'index'): +def codeflash_wrap( + wrapped, + test_module_name, + test_class_name, + test_name, + function_name, + line_id, + loop_index, + codeflash_cur, + codeflash_con, + *args, + **kwargs, +): + test_id = f"{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}" + if not hasattr(codeflash_wrap, "index"): codeflash_wrap.index = {{}} if test_id in codeflash_wrap.index: codeflash_wrap.index[test_id] += 1 else: codeflash_wrap.index[test_id] = 0 codeflash_test_index = codeflash_wrap.index[test_id] - invocation_id = f'{{line_id}}_{{codeflash_test_index}}' - """ + invocation_id = f"{{line_id}}_{{codeflash_test_index}}" +''' if sys.version_info < (3, 12): - expected += """print(f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!")""" + expected += """ print( + f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!" + )""" else: - expected += """print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')""" + expected += """ print( + f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!" + )""" expected += """ exception = None gc.disable() @@ -140,30 +157,88 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi codeflash_duration = time.perf_counter_ns() - counter exception = e gc.enable() - pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) - codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + pickled_return_value = ( + pickle.dumps(exception) if exception else pickle.dumps(return_value) + ) + codeflash_cur.execute( + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + test_module_name, + test_class_name, + test_name, + function_name, + loop_index, + invocation_id, + codeflash_duration, + pickled_return_value, + "function_call", + ), + ) codeflash_con.commit() if exception: raise exception return return_value + + + class TestPigLatin(unittest.TestCase): @timeout_decorator.timeout(15) def test_sort(self): - codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) - codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION'] - codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') + codeflash_loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + codeflash_iteration = os.environ["CODEFLASH_TEST_ITERATION"] + codeflash_con = sqlite3.connect( + f"{tmp_dir_path}_{{codeflash_iteration}}.sqlite" + ) codeflash_cur = codeflash_con.cursor() - codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') + codeflash_cur.execute( + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" + ) input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap( + sorter, + "{module_path}", + "TestPigLatin", + "test_sort", + "sorter", + "1", + codeflash_loop_index, + codeflash_cur, + codeflash_con, + input, + ) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap( + sorter, + "{module_path}", + "TestPigLatin", + "test_sort", + "sorter", + "4", + codeflash_loop_index, + codeflash_cur, + codeflash_con, + input, + ) self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) input = list(reversed(range(5000))) - self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input), list(range(5000))) + self.assertEqual( + codeflash_wrap( + sorter, + "{module_path}", + "TestPigLatin", + "test_sort", + "sorter", + "7", + codeflash_loop_index, + codeflash_cur, + codeflash_con, + input, + ), + list(range(5000)), + ) codeflash_con.close() """ with tempfile.NamedTemporaryFile(mode="w") as f: @@ -200,37 +275,53 @@ def test_prepare_image_for_yolo(): ret = packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo(**args) assert compare_results(return_val_1, ret) """ - expected = """import gc + expected = '''import gc import os import sqlite3 import time import dill as pickle import pytest -from packagename.ml.yolo.image_reshaping_utils import \\ - prepare_image_for_yolo as \\ - packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo +from packagename.ml.yolo.image_reshaping_utils import ( + prepare_image_for_yolo as packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo, +) from codeflash.tracing.replay_test import get_next_arg_and_return from codeflash.validation.equivalence import compare_results -def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): - test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' - if not hasattr(codeflash_wrap, 'index'): +def codeflash_wrap( + wrapped, + test_module_name, + test_class_name, + test_name, + function_name, + line_id, + loop_index, + codeflash_cur, + codeflash_con, + *args, + **kwargs, +): + test_id = f"{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}" + if not hasattr(codeflash_wrap, "index"): codeflash_wrap.index = {{}} if test_id in codeflash_wrap.index: codeflash_wrap.index[test_id] += 1 else: codeflash_wrap.index[test_id] = 0 codeflash_test_index = codeflash_wrap.index[test_id] - invocation_id = f'{{line_id}}_{{codeflash_test_index}}' - """ + invocation_id = f"{{line_id}}_{{codeflash_test_index}}" +''' if sys.version_info < (3, 12): - expected += """print(f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!")""" + expected += """ print( + f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!" + )""" else: - expected += """print(f'!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!')""" - expected += """ + expected += """ print( + f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}######!" + )""" + expected += ''' exception = None gc.disable() try: @@ -241,26 +332,45 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi codeflash_duration = time.perf_counter_ns() - counter exception = e gc.enable() - pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value) - codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) + pickled_return_value = ( + pickle.dumps(exception) if exception else pickle.dumps(return_value) + ) + codeflash_cur.execute( + "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + test_module_name, + test_class_name, + test_name, + function_name, + loop_index, + invocation_id, + codeflash_duration, + pickled_return_value, + "function_call", + ), + ) codeflash_con.commit() if exception: raise exception return return_value def test_prepare_image_for_yolo(): - codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) - codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION'] - codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') + codeflash_loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"]) + codeflash_iteration = os.environ["CODEFLASH_TEST_ITERATION"] + codeflash_con = sqlite3.connect(f"{tmp_dir_path}_{{codeflash_iteration}}.sqlite") codeflash_cur = codeflash_con.cursor() - codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') -""" - if sys.version_info < (3, 11): + codeflash_cur.execute( + "CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)" + ) +''' + if sys.version_info < (3, 12): expected += """ for (arg_val_pkl, return_val_pkl) in get_next_arg_and_return('/home/saurabh/packagename/traces/first.trace', 3): """ else: - expected += """ for arg_val_pkl, return_val_pkl in get_next_arg_and_return('/home/saurabh/packagename/traces/first.trace', 3): -""" + expected += ''' for arg_val_pkl, return_val_pkl in get_next_arg_and_return( + "/home/saurabh/packagename/traces/first.trace", 3 + ): +''' expected += """ args = pickle.loads(arg_val_pkl) return_val_1 = pickle.loads(return_val_pkl) ret = codeflash_wrap(packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo, '{module_path}', None, 'test_prepare_image_for_yolo', 'packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '0_2', codeflash_loop_index, codeflash_cur, codeflash_con, **args) @@ -279,9 +389,9 @@ def test_prepare_image_for_yolo(): ) os.chdir(original_cwd) assert success - assert new_test == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values")) - ) + )) def test_perfinjector_bubble_sort_results() -> None: @@ -379,10 +489,10 @@ def test_sort(): os.chdir(original_cwd) assert success assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) success, new_perf_test = inject_profiling_into_existing_test( test_path, @@ -394,11 +504,10 @@ def test_sort(): ) assert success assert new_perf_test is not None - assert new_perf_test.replace('"', "'") == expected_perfonly.format( + assert format_code_in_memory(new_perf_test) == format_code_in_memory(expected_perfonly.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") - + )) with test_path.open("w") as f: f.write(new_test) @@ -627,14 +736,14 @@ def test_sort_parametrized(input, expected_output): os.chdir(original_cwd) assert success assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") - assert new_test_perf.replace('"', "'") == expected_perfonly.format( + )) + assert format_code_in_memory(new_test_perf) == format_code_in_memory(expected_perfonly.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # # Overwrite old test with new instrumented test @@ -886,19 +995,19 @@ def test_sort_parametrized_loop(input, expected_output): os.chdir(original_cwd) assert success assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # Overwrite old test with new instrumented test with test_path_behavior.open("w") as f: f.write(new_test) - assert new_test_perf.replace('"', "'") == expected_perf.format( + assert format_code_in_memory(new_test_perf) == format_code_in_memory(expected_perf.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_parametrized_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # Overwrite old test with new instrumented test with test_path_perf.open("w") as f: @@ -1225,15 +1334,15 @@ def test_sort(): os.chdir(original_cwd) assert success assert new_test_behavior is not None - assert new_test_behavior.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test_behavior) == format_code_in_memory(expected.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) - assert new_test_perf.replace('"', "'") == expected_perf.format( + assert format_code_in_memory(new_test_perf) == format_code_in_memory(expected_perf.format( module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # Overwrite old test with new instrumented test @@ -1529,14 +1638,14 @@ def test_sort(self): assert success assert new_test_behavior is not None - assert new_test_behavior.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test_behavior) == format_code_in_memory(expected.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") - assert new_test_perf.replace('"', "'") == expected_perf.format( + )) + assert format_code_in_memory(new_test_perf) == format_code_in_memory(expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # # Overwrite old test with new instrumented test with test_path_behavior.open("w") as f: @@ -1774,16 +1883,16 @@ def test_sort(self, input, expected_output): os.chdir(original_cwd) assert success assert new_test_behavior is not None - assert new_test_behavior.replace('"', "'") == expected_behavior.format( + assert format_code_in_memory(new_test_behavior) == format_code_in_memory(expected_behavior.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) assert new_test_perf is not None - assert new_test_perf.replace('"', "'") == expected_perf.format( + assert format_code_in_memory(new_test_perf) == format_code_in_memory(expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # # Overwrite old test with new instrumented test @@ -2028,14 +2137,14 @@ def test_sort(self): os.chdir(original_cwd) assert success assert new_test_behavior is not None - assert new_test_behavior.replace('"', "'") == expected_behavior.format( + assert format_code_in_memory(new_test_behavior) == format_code_in_memory(expected_behavior.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") - assert new_test_perf.replace('"', "'") == expected_perf.format( + )) + assert format_code_in_memory(new_test_perf) == format_code_in_memory(expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # # # Overwrite old test with new instrumented test with test_path_behavior.open("w") as f: @@ -2275,14 +2384,14 @@ def test_sort(self, input, expected_output): os.chdir(original_cwd) assert success assert new_test_behavior is not None - assert new_test_behavior.replace('"', "'") == expected_behavior.format( + assert format_code_in_memory(new_test_behavior) == format_code_in_memory(expected_behavior.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") - assert new_test_perf.replace('"', "'") == expected_perf.format( + )) + assert format_code_in_memory(new_test_perf) == format_code_in_memory(expected_perf.format( module_path="code_to_optimize.tests.unittest.test_perfinjector_bubble_sort_unittest_parametrized_loop_results_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # # Overwrite old test with new instrumented test with test_path_behavior.open("w") as _f: @@ -2565,10 +2674,10 @@ def test_class_name_A_function_name(): test_path.unlink(missing_ok=True) assert success assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( tmp_dir_path=get_run_tmp_file(Path("test_return_values")), module_path="tests.pytest.test_class_function_instrumentation_temp", - ).replace('"', "'") + )) def test_wrong_function_instrumentation() -> None: @@ -2635,10 +2744,10 @@ def test_common_tags_1(): os.chdir(original_cwd) assert success assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path="tests.pytest.test_wrong_function_instrumentation_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) finally: test_path.unlink(missing_ok=True) @@ -2698,10 +2807,10 @@ def test_sort(): os.chdir(original_cwd) assert success assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path="tests.pytest.test_conditional_instrumentation_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) finally: test_path.unlink(missing_ok=True) @@ -2775,12 +2884,12 @@ def test_sort(): ) os.chdir(original_cwd) assert success - formatted_expected = expected.format( + formatted_expected = format_code_in_memory(expected.format( module_path="tests.pytest.test_perfinjector_bubble_sort_results_temp", tmp_dir_path=str(get_run_tmp_file(Path("test_return_values"))), - ) + )) assert new_test is not None - assert new_test.replace('"', "'") == formatted_expected.replace('"', "'") + assert format_code_in_memory(new_test) == formatted_expected finally: test_path.unlink(missing_ok=True) @@ -2899,9 +3008,9 @@ def test_code_replacement10() -> None: ) os.chdir(original_cwd) assert success - assert new_test == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path=Path(f.name).name, tmp_dir_path=get_run_tmp_file(Path("test_return_values")) - ) + )) def test_time_correction_instrumentation() -> None: @@ -2963,10 +3072,10 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time): test_type = TestType.EXISTING_UNIT_TEST assert success, "Test instrumentation failed" assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path="code_to_optimize.tests.pytest.test_time_correction_instrumentation_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # Overwrite old test with new instrumented test with test_path.open("w") as f: f.write(new_test) @@ -3082,10 +3191,10 @@ def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time): test_type = TestType.EXISTING_UNIT_TEST assert success, "Test instrumentation failed" assert new_test is not None - assert new_test.replace('"', "'") == expected.format( + assert format_code_in_memory(new_test) == format_code_in_memory(expected.format( module_path="code_to_optimize.tests.unittest.test_time_correction_instrumentation_unittest_temp", tmp_dir_path=get_run_tmp_file(Path("test_return_values")), - ).replace('"', "'") + )) # Overwrite old test with new instrumented test with test_path.open("w") as f: f.write(new_test)