Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.code_utils.with_pytest_remover import remove_pytest_raises
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestingMode
from codeflash.verification.test_results import VerificationType
Expand Down Expand Up @@ -91,6 +92,7 @@ def find_and_update_line_node(
if self.mode == TestingMode.BEHAVIOR
else []
),
ast.Constant(value=False),
*call_node.args,
]
node.keywords = call_node.keywords
Expand All @@ -116,6 +118,7 @@ def find_and_update_line_node(
if self.mode == TestingMode.BEHAVIOR
else []
),
ast.Constant(value=False),
*call_node.args,
]
node.keywords = call_node.keywords
Expand Down Expand Up @@ -334,6 +337,9 @@ def inject_profiling_into_existing_test(
test_code = f.read()
try:
tree = ast.parse(test_code)
# Remove pytest.raises blocks if we're using pytest
if test_framework == "pytest":
tree = remove_pytest_raises(tree)
except SyntaxError:
logger.exception(f"Syntax error in code in file - {test_path}")
return False, None
Expand Down Expand Up @@ -721,11 +727,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
),
ast.If(
test=ast.Name(id="exception", ctx=ast.Load()),
body=[ast.Raise(exc=ast.Name(id="exception", ctx=ast.Load()), cause=None, lineno=lineno + 22)],
orelse=[],
body=[
ast.If(
test=ast.Name(id="reraise_exception", ctx=ast.Load()),
body=[ast.Raise(exc=ast.Name(id="exception", ctx=ast.Load()), cause=None, lineno=lineno + 22)],
orelse=[],
lineno=lineno + 22,
)
],
orelse=[
ast.Return(value=ast.Name(id="return_value", ctx=ast.Load()), lineno=lineno + 19)
],
lineno=lineno + 22,
),
ast.Return(value=ast.Name(id="return_value", ctx=ast.Load()), lineno=lineno + 19),
)
]
return ast.FunctionDef(
name="codeflash_wrap",
Expand All @@ -740,16 +754,17 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
ast.arg(arg="loop_index", annotation=None),
*([ast.arg(arg="codeflash_cur", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
*([ast.arg(arg="codeflash_con", annotation=None)] if mode == TestingMode.BEHAVIOR else []),
*([ast.arg(arg="reraise_exception", annotation=None)]),
],
vararg=ast.arg(arg="args"),
kwarg=ast.arg(arg="kwargs"),
posonlyargs=[],
kwonlyargs=[],
kw_defaults=[],
defaults=[],
defaults=[*([ast.Constant(value=False)])],
),
body=wrapper_body,
lineno=lineno,
decorator_list=[],
returns=None,
)
)
37 changes: 37 additions & 0 deletions codeflash/code_utils/with_pytest_remover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import ast
class PytestRaisesRemover(ast.NodeTransformer):
"""Replaces 'with pytest.raises()' blocks with the content inside them."""

def visit_With(self, node: ast.With) -> ast.AST | list[ast.AST]:
# Process any nested with blocks first by recursively visiting children
node = self.generic_visit(node)

for item in node.items:
# Check if this is a pytest.raises block
if (isinstance(item.context_expr, ast.Call) and
isinstance(item.context_expr.func, ast.Attribute) and
isinstance(item.context_expr.func.value, ast.Name) and
item.context_expr.func.value.id == "pytest" and
item.context_expr.func.attr == "raises"):

Comment on lines +2 to +16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class PytestRaisesRemover(ast.NodeTransformer):
"""Replaces 'with pytest.raises()' blocks with the content inside them."""
def visit_With(self, node: ast.With) -> ast.AST | list[ast.AST]:
# Process any nested with blocks first by recursively visiting children
node = self.generic_visit(node)
for item in node.items:
# Check if this is a pytest.raises block
if (isinstance(item.context_expr, ast.Call) and
isinstance(item.context_expr.func, ast.Attribute) and
isinstance(item.context_expr.func.value, ast.Name) and
item.context_expr.func.value.id == "pytest" and
item.context_expr.func.attr == "raises"):
# Directly visit children and check if they are nested with blocks
if (
isinstance(item.context_expr, ast.Call)
and isinstance(item.context_expr.func, ast.Attribute)
and isinstance(item.context_expr.func.value, ast.Name)
and item.context_expr.func.value.id == "pytest"
and item.context_expr.func.attr == "raises"
):
return self._unwrap_body(node.body)
# Generic visit for other types of 'with' blocks
return self.generic_visit(node)
def _unwrap_body(self, body: list[ast.stmt]) -> ast.AST | list[ast.AST]:
# Unwrap the body either as a single statement or a list of statements
if len(body) == 1:
return body[0]
return body

# Return the body contents instead of the with block
# If there's multiple statements in the body, return them all
if len(node.body) == 1:
return node.body[0]
return node.body

return node


def remove_pytest_raises(tree: ast.AST) -> ast.AST:
"""Removes pytest.raises blocks and shifts their content out.

Args:
tree: The AST tree to transform

Returns:
The transformed AST with pytest.raises blocks removed

"""
transformer = PytestRaisesRemover()
return transformer.visit(tree)
24 changes: 14 additions & 10 deletions tests/test_instrument_all_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from codeflash.verification.test_results import TestType

# Used by cli instrumentation
codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
codeflash_wrap_string = """def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, reraise_exception=False, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
Expand All @@ -41,8 +41,10 @@
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception
return return_value
if reraise_exception:
raise exception
else:
return return_value
"""


Expand Down Expand Up @@ -80,10 +82,10 @@ def test_sort():
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
input = [5, 4, 3, 2, 1, 0]
output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, input)
output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, False, input)
assert output == [0, 1, 2, 3, 4, 5]
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, input)
output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, False, input)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
codeflash_con.close()
"""
Expand Down Expand Up @@ -245,7 +247,7 @@ def test_sort():
from code_to_optimize.bubble_sort_method import BubbleSorter


def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, function_name, line_id, loop_index, codeflash_cur, codeflash_con, reraise_exception=False, *args, **kwargs):
test_id = f'{{test_module_name}}:{{test_class_name}}:{{test_name}}:{{line_id}}:{{loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
Expand Down Expand Up @@ -275,8 +277,10 @@ def codeflash_wrap(wrapped, test_module_name, test_class_name, test_name, functi
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', (test_module_name, test_class_name, test_name, function_name, loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call'))
codeflash_con.commit()
if exception:
raise exception
return return_value
if reraise_exception:
raise exception
else:
return return_value
"""
expected += """
def test_sort():
Expand All @@ -287,11 +291,11 @@ def test_sort():
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
input = [5, 4, 3, 2, 1, 0]
sort_class = BubbleSorter()
output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, input)
output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, False, input)
assert output == [0, 1, 2, 3, 4, 5]
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
sort_class = BubbleSorter()
output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '6', codeflash_loop_index, codeflash_cur, codeflash_con, input)
output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '6', codeflash_loop_index, codeflash_cur, codeflash_con, False, input)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
codeflash_con.close()
"""
Expand Down
Loading
Loading