From 4d1bacd5b7ef3e2986393fd7f23b0ac24bbe1227 Mon Sep 17 00:00:00 2001 From: Charles Cooper Date: Mon, 10 Jun 2024 11:16:53 -0700 Subject: [PATCH] fix[venom]: move loop invariant assertion to entry block (#4098) loop invariant bound check was in the body of the loop, not the entry block. move it up to the entry so we don't re-check the same assertion every loop iteration. --- tests/functional/syntax/test_for_range.py | 4 ++-- vyper/venom/ir_node_to_venom.py | 20 ++++++++------------ 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/tests/functional/syntax/test_for_range.py b/tests/functional/syntax/test_for_range.py index 1de32108c5..97e77f32f7 100644 --- a/tests/functional/syntax/test_for_range.py +++ b/tests/functional/syntax/test_for_range.py @@ -368,14 +368,14 @@ def foo(): """ @external def foo(): - x: int128 = 5 + x: int128 = 4 for i: int128 in range(x, bound=4): pass """, """ @external def foo(): - x: int128 = 5 + x: int128 = 4 for i: int128 in range(0, x, bound=4): pass """, diff --git a/vyper/venom/ir_node_to_venom.py b/vyper/venom/ir_node_to_venom.py index 61b3c081ff..2c99cf5668 100644 --- a/vyper/venom/ir_node_to_venom.py +++ b/vyper/venom/ir_node_to_venom.py @@ -468,14 +468,7 @@ def emit_body_blocks(): start, end, _ = _convert_ir_bb_list(fn, ir.args[1:4], symbols) assert ir.args[3].is_literal, "repeat bound expected to be literal" - bound = ir.args[3].value - if ( - isinstance(end, IRLiteral) - and isinstance(start, IRLiteral) - and end.value + start.value <= bound - ): - bound = None body = ir.args[4] @@ -491,9 +484,15 @@ def emit_body_blocks(): counter_var = entry_block.append_instruction("store", start) symbols[sym.value] = counter_var + + if bound is not None: + # assert le end bound + invalid_end = entry_block.append_instruction("gt", bound, end) + valid_end = entry_block.append_instruction("iszero", invalid_end) + entry_block.append_instruction("assert", valid_end) + end = entry_block.append_instruction("add", start, end) - if bound: - bound = entry_block.append_instruction("add", start, bound) + entry_block.append_instruction("jmp", cond_block.label) xor_ret = cond_block.append_instruction("xor", counter_var, end) @@ -501,9 +500,6 @@ def emit_body_blocks(): fn.append_basic_block(cond_block) fn.append_basic_block(body_block) - if bound: - xor_ret = body_block.append_instruction("xor", counter_var, bound) - body_block.append_instruction("assert", xor_ret) emit_body_blocks() body_end = fn.get_basic_block()