diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 032116da7..f569e9e3a 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -8,6 +8,7 @@ from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path +from codeflash.code_utils.with_pytest_remover import remove_pytest_raises from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent, TestingMode from codeflash.verification.test_results import VerificationType @@ -91,6 +92,7 @@ def find_and_update_line_node( if self.mode == TestingMode.BEHAVIOR else [] ), + ast.Constant(value=False), *call_node.args, ] node.keywords = call_node.keywords @@ -116,6 +118,7 @@ def find_and_update_line_node( if self.mode == TestingMode.BEHAVIOR else [] ), + ast.Constant(value=False), *call_node.args, ] node.keywords = call_node.keywords @@ -334,6 +337,9 @@ def inject_profiling_into_existing_test( test_code = f.read() try: tree = ast.parse(test_code) + # Remove pytest.raises blocks if we're using pytest + if test_framework == "pytest": + tree = remove_pytest_raises(tree) except SyntaxError: logger.exception(f"Syntax error in code in file - {test_path}") return False, None @@ -721,11 +727,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ), ast.If( test=ast.Name(id="exception", ctx=ast.Load()), - body=[ast.Raise(exc=ast.Name(id="exception", ctx=ast.Load()), cause=None, lineno=lineno + 22)], - orelse=[], + body=[ + ast.If( + test=ast.Name(id="reraise_exception", ctx=ast.Load()), + body=[ast.Raise(exc=ast.Name(id="exception", ctx=ast.Load()), cause=None, lineno=lineno + 22)], + orelse=[], + lineno=lineno + 22, + ) + ], + orelse=[ + ast.Return(value=ast.Name(id="return_value", ctx=ast.Load()), lineno=lineno + 19) + ], lineno=lineno + 
import ast


class PytestRaisesRemover(ast.NodeTransformer):
    """Unwrap ``with pytest.raises(...)`` blocks, keeping the statements inside them.

    Only the literal spelling ``pytest.raises(...)`` is recognised, i.e. a call
    to the attribute ``raises`` on the bare name ``pytest``. Other spellings
    (for example ``from pytest import raises``) are left untouched.

    NOTE(review): a target bound with ``as`` (``with pytest.raises(E) as info:``)
    is removed together with the context manager, so statements that read that
    name afterwards will fail with ``NameError`` - confirm callers never rely on it.
    """

    @staticmethod
    def _is_pytest_raises(item: ast.withitem) -> bool:
        """Return True when *item* is a ``pytest.raises(...)`` context manager."""
        expr = item.context_expr
        return (
            isinstance(expr, ast.Call)
            and isinstance(expr.func, ast.Attribute)
            and expr.func.attr == "raises"
            and isinstance(expr.func.value, ast.Name)
            and expr.func.value.id == "pytest"
        )

    # The annotation is quoted so the module still imports on Python 3.9, where
    # ``type | GenericAlias`` cannot be evaluated at definition time.
    def visit_With(self, node: ast.With) -> "ast.AST | list[ast.stmt]":
        """Strip ``pytest.raises(...)`` items from a ``with`` statement.

        Args:
            node: The ``with`` statement being visited.

        Returns:
            * ``node`` unchanged when it has no ``pytest.raises`` item;
            * ``node`` with only its other context managers when it mixes
              ``pytest.raises`` with something else (so bindings such as
              ``open(...) as fh`` keep working for the body);
            * the body itself when ``pytest.raises`` was the only item: the
              single statement if there is exactly one, otherwise the list of
              statements, which ``NodeTransformer`` splices into the parent.

        """
        # Visit children first so nested pytest.raises blocks are unwrapped
        # before this statement is inspected.
        node = self.generic_visit(node)

        remaining = [item for item in node.items if not self._is_pytest_raises(item)]
        if len(remaining) == len(node.items):
            # Nothing to strip.
            return node

        if remaining:
            # Keep the unrelated context managers; dropping the whole statement
            # would run the body outside of them and lose their ``as`` targets.
            node.items = remaining
            return node

        # pytest.raises was the only context manager: hoist the body.
        if len(node.body) == 1:
            return node.body[0]
        return node.body


def remove_pytest_raises(tree: ast.AST) -> ast.AST:
    """Remove ``pytest.raises`` blocks and shift their content out.

    The tree is transformed in place; existing nodes are moved, not copied.

    Args:
        tree: The AST to transform.

    Returns:
        The result of visiting ``tree`` with ``PytestRaisesRemover``. For a
        module or function this is the same (mutated) node. If ``tree`` is
        itself a ``with pytest.raises(...)`` statement, the hoisted body is
        returned (one statement, or a list of statements).

    """
    return PytestRaisesRemover().visit(tree)
output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ @@ -245,7 +247,7 @@ def test_sort(): from code_to_optimize.bubble_sort_method import BubbleSorter -def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): +def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, reraise_exception=False, *args, **kwargs): test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} @@ -275,8 +277,10 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: - raise exception - return return_value + if reraise_exception: + raise exception + else: + return return_value """ expected += """ def test_sort(): @@ -287,11 +291,11 @@ def test_sort(): codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, 
verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] sort_class = BubbleSorter() - output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] sort_class = BubbleSorter() - output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '6', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '6', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index 79f4bc5dd..13cdb2b4b 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -18,7 +18,7 @@ from codeflash.verification.test_results import TestType from codeflash.verification.verification_utils import TestConfig -codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): +codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, reraise_exception=False, *args, **kwargs): test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} @@ -43,11 +43,13 @@ codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, 
loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: - raise exception - return return_value + if reraise_exception: + raise exception + else: + return return_value """ -codeflash_wrap_perfonly_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, *args, **kwargs): +codeflash_wrap_perfonly_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, reraise_exception=False, *args, **kwargs): test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} @@ -69,8 +71,10 @@ gc.enable() print(f"!######{{test_module_name}}:{{(test_class_name + '.' if test_class_name else '')}}{{test_name}}:{{function_name}}:{{loop_index}}:{{invocation_id}}:{{codeflash_duration}}######!") if exception: - raise exception - return return_value + if reraise_exception: + raise exception + else: + return return_value """ @@ -105,7 +109,7 @@ def test_sort(self): from code_to_optimize.bubble_sort import sorter -def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): +def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, reraise_exception=False, *args, **kwargs): test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} @@ -135,8 +139,10 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, 
pickled_return_value, 'function_call')) codeflash_con.commit() if exception: - raise exception - return return_value + if reraise_exception: + raise exception + else: + return return_value class TestPigLatin(unittest.TestCase): @@ -148,13 +154,13 @@ def test_sort(self): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) input = list(reversed(range(5000))) - self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input), list(range(5000))) + self.assertEqual(codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, False, input), list(range(5000))) codeflash_con.close() """ with tempfile.NamedTemporaryFile(mode="w") as f: @@ -206,7 +212,7 @@ def test_prepare_image_for_yolo(): from codeflash.validation.equivalence import compare_results -def codeflash_wrap(wrapped, 
test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): +def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, reraise_exception=False, *args, **kwargs): test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} @@ -236,8 +242,10 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: - raise exception - return return_value + if reraise_exception: + raise exception + else: + return return_value def test_prepare_image_for_yolo(): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) @@ -254,7 +262,7 @@ def test_prepare_image_for_yolo(): """ expected += """ args = pickle.loads(arg_val_pkl) return_val_1 = pickle.loads(return_val_pkl) - ret = codeflash_wrap(packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo, '{module_path}', None, 'test_prepare_image_for_yolo', 'packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '0_2', codeflash_loop_index, codeflash_cur, codeflash_con, **args) + ret = codeflash_wrap(packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo, '{module_path}', None, 'test_prepare_image_for_yolo', 'packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo', '0_2', codeflash_loop_index, codeflash_cur, codeflash_con, False, **args) assert compare_results(return_val_1, ret) codeflash_con.close() """ @@ -309,10 +317,10 @@ def test_sort(): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results 
(test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ @@ -332,10 +340,10 @@ def test_sort(): def test_sort(): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, False, input) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, False, input) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] """ ) @@ -533,7 +541,7 @@ def test_sort_parametrized(input, expected_output): codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, 
test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == expected_output codeflash_con.close() """ @@ -555,7 +563,7 @@ def test_sort_parametrized(input, expected_output): @pytest.mark.parametrize('input, expected_output', [([5, 4, 3, 2, 1, 0], [0, 1, 2, 3, 4, 5]), ([5.0, 4.0, 3.0, 2.0, 1.0, 0.0], [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]), (list(reversed(range(50))), list(range(50)))]) def test_sort_parametrized(input, expected_output): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized', 'sorter', '0', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized', 'sorter', '0', codeflash_loop_index, False, input) assert output == expected_output """ ) @@ -761,7 +769,7 @@ def test_sort_parametrized_loop(input, expected_output): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') for i in range(2): - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized_loop', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized_loop', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert 
output == expected_output codeflash_con.close() """ @@ -783,7 +791,7 @@ def test_sort_parametrized_loop(input, expected_output): def test_sort_parametrized_loop(input, expected_output): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) for i in range(2): - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized_loop', 'sorter', '0_0', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort_parametrized_loop', 'sorter', '0_0', codeflash_loop_index, False, input) assert output == expected_output """ ) @@ -1071,7 +1079,7 @@ def test_sort(): for i in range(3): input = inputs[i] expected_output = expected_outputs[i] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == expected_output codeflash_con.close() """ @@ -1095,7 +1103,7 @@ def test_sort(): for i in range(3): input = inputs[i] expected_output = expected_outputs[i] - output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2_2', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '2_2', codeflash_loop_index, False, input) assert output == expected_output """ ) @@ -1324,13 +1332,13 @@ def test_sort(self): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = 
codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) input = list(reversed(range(50))) - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, list(range(50))) codeflash_con.close() """ @@ -1355,13 +1363,13 @@ class TestPigLatin(unittest.TestCase): def test_sort(self): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '1', codeflash_loop_index, False, input) self.assertEqual(output, [0, 1, 2, 3, 4, 5]) input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '4', codeflash_loop_index, False, input) self.assertEqual(output, [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) input = list(reversed(range(50))) - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, 
input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '7', codeflash_loop_index, False, input) self.assertEqual(output, list(range(50))) """ ) @@ -1591,7 +1599,7 @@ def test_sort(self, input, expected_output): codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, expected_output) codeflash_con.close() """ @@ -1617,7 +1625,7 @@ class TestPigLatin(unittest.TestCase): @timeout_decorator.timeout(15) def test_sort(self, input, expected_output): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0', codeflash_loop_index, False, input) self.assertEqual(output, expected_output) """ ) @@ -1841,7 +1849,7 @@ def test_sort(self): for i in range(3): input = inputs[i] expected_output = expected_outputs[i] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, 
expected_output) codeflash_con.close() """ @@ -1871,7 +1879,7 @@ def test_sort(self): for i in range(3): input = inputs[i] expected_output = expected_outputs[i] - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '2_2', codeflash_loop_index, False, input) self.assertEqual(output, expected_output) """ ) @@ -2093,7 +2101,7 @@ def test_sort(self, input, expected_output): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') for i in range(2): - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) self.assertEqual(output, expected_output) codeflash_con.close() """ @@ -2120,7 +2128,7 @@ class TestPigLatin(unittest.TestCase): def test_sort(self, input, expected_output): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) for i in range(2): - output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, input) + output = codeflash_wrap(sorter, '{module_path}', 'TestPigLatin', 'test_sort', 'sorter', '0_0', codeflash_loop_index, False, input) self.assertEqual(output, expected_output) """ ) @@ -2417,7 +2425,7 @@ def test_class_name_A_function_name(): codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite') codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results 
(test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') - ret = codeflash_wrap(class_name_A.function_name, '{module_path}', None, 'test_class_name_A_function_name', 'class_name_A.function_name', '0', codeflash_loop_index, codeflash_cur, codeflash_con, **args) + ret = codeflash_wrap(class_name_A.function_name, '{module_path}', None, 'test_class_name_A_function_name', 'class_name_A.function_name', '0', codeflash_loop_index, codeflash_cur, codeflash_con, False, **args) codeflash_con.close() """ ) @@ -2487,9 +2495,9 @@ def test_common_tags_1(): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') articles_1 = [1, 2, 3] - assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '1', codeflash_loop_index, codeflash_cur, codeflash_con, articles_1) == set(1, 2) + assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '1', codeflash_loop_index, codeflash_cur, codeflash_con, False, articles_1) == set(1, 2) articles_2 = [1, 2] - assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '3', codeflash_loop_index, codeflash_cur, codeflash_con, articles_2) == set(1) + assert codeflash_wrap(find_common_tags, '{module_path}', None, 'test_common_tags_1', 'find_common_tags', '3', codeflash_loop_index, codeflash_cur, codeflash_con, False, articles_2) == set(1) codeflash_con.close() """ ) @@ -2555,7 +2563,7 @@ def test_sort(): codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name 
TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] if len(input) > 0: - assert codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1_0', codeflash_loop_index, codeflash_cur, codeflash_con, input) == [0, 1, 2, 3, 4, 5] + assert codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1_0', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) == [0, 1, 2, 3, 4, 5] codeflash_con.close() """ ) @@ -2621,10 +2629,10 @@ def test_sort(): codeflash_cur = codeflash_con.cursor() codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)') input = [5, 4, 3, 2, 1, 0] - output = codeflash_wrap(BubbleSorter.sorter, 'tests.pytest.test_perfinjector_bubble_sort_results_temp', None, 'test_sort', 'BubbleSorter.sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(BubbleSorter.sorter, 'tests.pytest.test_perfinjector_bubble_sort_results_temp', None, 'test_sort', 'BubbleSorter.sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0, 1, 2, 3, 4, 5] input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0] - output = codeflash_wrap(BubbleSorter.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input) + output = codeflash_wrap(BubbleSorter.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, False, input) assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] codeflash_con.close() """ @@ -2712,7 +2720,7 @@ def test_code_replacement10() -> None: from codeflash.optimization.optimizer import Optimizer -def codeflash_wrap(wrapped, 
test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs): +def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, reraise_exception=False, *args, **kwargs): test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}' if not hasattr(codeflash_wrap, 'index'): codeflash_wrap.index = {{}} @@ -2742,8 +2750,10 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call')) codeflash_con.commit() if exception: - raise exception - return return_value + if reraise_exception: + raise exception + else: + return return_value def test_code_replacement10() -> None: codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) @@ -2757,9 +2767,9 @@ def test_code_replacement10() -> None: func_top_optimize = FunctionToOptimize(function_name='main_method', file_path=str(file_path), parents=[FunctionParent('MainClass', 'ClassDef')]) with open(file_path) as f: original_code = f.read() - code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code).unwrap() + code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_1', codeflash_loop_index, codeflash_cur, codeflash_con, False, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), 
original_source_code=original_code).unwrap() assert code_context.code_to_optimize_with_helpers == get_code_output - code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code) + code_context = codeflash_wrap(opt.get_code_optimization_context, '{module_path}', None, 'test_code_replacement10', 'Optimizer.get_code_optimization_context', '4_3', codeflash_loop_index, codeflash_cur, codeflash_con, False, function_to_optimize=func_top_optimize, project_root=str(file_path.parent), original_source_code=original_code) assert code_context.code_to_optimize_with_helpers == get_code_output codeflash_con.close() """ @@ -2814,7 +2824,7 @@ def test_sleepfunc_sequence_short(n, expected_total_sleep_time): @pytest.mark.parametrize('n, expected_total_sleep_time', [(0.01, 0.01), (0.02, 0.02)]) def test_sleepfunc_sequence_short(n, expected_total_sleep_time): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) - output = codeflash_wrap(accurate_sleepfunc, '{module_path}', None, 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, n) + output = codeflash_wrap(accurate_sleepfunc, '{module_path}', None, 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, False, n) assert output == expected_total_sleep_time """ ) @@ -2934,7 +2944,7 @@ class TestPigLatin(unittest.TestCase): @timeout_decorator.timeout(15) def test_sleepfunc_sequence_short(self, n, expected_total_sleep_time): codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX']) - output = codeflash_wrap(accurate_sleepfunc, '{module_path}', 'TestPigLatin', 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, n) + output = codeflash_wrap(accurate_sleepfunc, '{module_path}', 
'TestPigLatin', 'test_sleepfunc_sequence_short', 'accurate_sleepfunc', '0', codeflash_loop_index, False, n) """ ) code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/sleeptime.py").resolve() diff --git a/tests/test_with_pytest_remover.py b/tests/test_with_pytest_remover.py new file mode 100644 index 000000000..60e4df3c3 --- /dev/null +++ b/tests/test_with_pytest_remover.py @@ -0,0 +1,93 @@ +import ast + +from codeflash.code_utils.with_pytest_remover import remove_pytest_raises + +def test_remove_single_pytest_raises(): + original = """ +def test_something(): + with pytest.raises(ValueError): + raise ValueError('test') +""" + expected = """ +def test_something(): + raise ValueError('test') +""" + tree = ast.parse(original) + result = remove_pytest_raises(tree) + # Convert result AST back to code and normalize whitespace + result_code = ast.unparse(result).strip() + assert result_code == expected.strip() + + +def test_remove_multiple_pytest_raises(): + original = """ +def test_multiple(): + with pytest.raises(TypeError): + int('abc') + with pytest.raises(ValueError): + int('') +""" + expected = """ +def test_multiple(): + int('abc') + int('') +""" + tree = ast.parse(original) + result = remove_pytest_raises(tree) + result_code = ast.unparse(result).strip() + assert result_code == expected.strip() + + +def test_preserve_other_with_blocks(): + original = """ +def test_mixed(): + with open('file.txt') as f: + content = f.read() + with pytest.raises(ValueError): + int('abc') + with contextlib.contextmanager(): + pass +""" + expected = """ +def test_mixed(): + with open('file.txt') as f: + content = f.read() + int('abc') + with contextlib.contextmanager(): + pass +""" + tree = ast.parse(original) + result = remove_pytest_raises(tree) + result_code = ast.unparse(result).strip() + assert result_code == expected.strip() + + +def test_nested_with_blocks(): + original = """ +def test_nested(): + with open('file.txt') as f: + with pytest.raises(ValueError): + 
int('abc') +""" + expected = """ +def test_nested(): + with open('file.txt') as f: + int('abc') +""" + tree = ast.parse(original) + result = remove_pytest_raises(tree) + result_code = ast.unparse(result).strip() + assert result_code == expected.strip() + + +def test_no_pytest_raises(): + original = """ +def test_normal(): + x = 1 + y = 2 + assert x + y == 3 +""" + tree = ast.parse(original) + result = remove_pytest_raises(tree) + result_code = ast.unparse(result).strip() + assert result_code == original.strip()