diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 56a22fd4107..b972afa2990 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -266,15 +266,6 @@ impl FunctionBuilder { self.insert_instruction(Instruction::Call { func, arguments }, Some(result_types)).results() } - /// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs - pub(crate) fn insert_shift_left(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { - let base = self.field_constant(FieldElement::from(2_u128)); - let pow = self.pow(base, rhs); - let typ = self.current_function.dfg.type_of_value(lhs); - let pow = self.insert_cast(pow, typ); - self.insert_binary(lhs, BinaryOp::Mul, pow) - } - /// Insert ssa instructions which computes lhs << rhs by doing lhs*2^rhs /// and truncate the result to bit_size pub(crate) fn insert_wrapping_shift_left( @@ -308,8 +299,9 @@ impl FunctionBuilder { let one = self.numeric_constant(FieldElement::one(), Type::unsigned(1)); let predicate = self.insert_binary(overflow, BinaryOp::Eq, one); let predicate = self.insert_cast(predicate, typ.clone()); - - let pow = self.pow(base, rhs); + // we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value + let rhs_unsigned = self.insert_cast(rhs, Type::unsigned(bit_size)); + let pow = self.pow(base, rhs_unsigned); let pow = self.insert_cast(pow, typ); (FieldElement::max_num_bits(), self.insert_binary(predicate, BinaryOp::Mul, pow)) }; @@ -323,9 +315,16 @@ impl FunctionBuilder { } /// Insert ssa instructions which computes lhs >> rhs by doing lhs/2^rhs - pub(crate) fn insert_shift_right(&mut self, lhs: ValueId, rhs: ValueId) -> ValueId { + pub(crate) fn insert_shift_right( + &mut self, + lhs: ValueId, + rhs: ValueId, + bit_size: u32, + ) -> ValueId { let base = self.field_constant(FieldElement::from(2_u128)); - let pow = self.pow(base, rhs); + // we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value + let rhs_unsigned = self.insert_cast(rhs, Type::unsigned(bit_size)); + let pow = self.pow(base, rhs_unsigned); self.insert_binary(lhs, BinaryOp::Div, pow) } diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index 501a03bcb5c..5724bf56e8e 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -326,8 +326,8 @@ impl<'a> FunctionContext<'a> { self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location); self.builder.insert_cast(result, result_type) } - BinaryOpKind::ShiftLeft => { - unreachable!("shift is not supported for signed integer") + BinaryOpKind::ShiftLeft | BinaryOpKind::ShiftRight => { + self.check_shift_overflow(result, rhs, bit_size, location, true) } _ => unreachable!("operator {} should not overflow", operator), } @@ -343,8 +343,10 @@ impl<'a> FunctionContext<'a> { if operator == BinaryOpKind::Multiply && bit_size == 1 { result - } else if operator == BinaryOpKind::ShiftLeft { - self.check_left_shift_overflow(result, rhs, bit_size, location) + } else if operator == BinaryOpKind::ShiftLeft + || operator == BinaryOpKind::ShiftRight + { + self.check_shift_overflow(result, rhs, bit_size, location, false) } else { let message = format!("attempt to {} with overflow", op_name); let range_constraint = Instruction::RangeCheck { @@ -360,26 +362,44 @@ impl<'a> FunctionContext<'a> { } } - /// Overflow checks for shift-left - /// We use Rust behavior for shift left: + /// Overflow checks for bit-shift + /// We use Rust behavior for bit-shift: /// If rhs is more or equal than the bit size, then we overflow - /// If not, we do not overflow and shift left with 0 when bits are falling out of the bit size - fn check_left_shift_overflow( + /// If not, we do not overflow and shift with 0 when bits are falling out of the bit size + fn check_shift_overflow( &mut self, result: ValueId, rhs: ValueId, bit_size: u32, location: Location, + is_signed: bool, ) -> ValueId { + let one = self.builder.numeric_constant(FieldElement::one(), Type::bool()); + let rhs = + if is_signed { self.builder.insert_cast(rhs, Type::unsigned(bit_size)) } else { rhs }; + // Bit-shift with a negative number is an overflow + if is_signed { + // We compute the sign of rhs. + let half_width = self.builder.numeric_constant( + FieldElement::from(2_i128.pow(bit_size - 1)), + Type::unsigned(bit_size), + ); + let sign = self.builder.insert_binary(rhs, BinaryOp::Lt, half_width); + self.builder.set_location(location).insert_constrain( + sign, + one, + Some("attempt to bit-shift with overflow".to_string()), + ); + } + let max = self .builder .numeric_constant(FieldElement::from(bit_size as i128), Type::unsigned(bit_size)); let overflow = self.builder.insert_binary(rhs, BinaryOp::Lt, max); - let one = self.builder.numeric_constant(FieldElement::one(), Type::bool()); self.builder.set_location(location).insert_constrain( overflow, one, - Some("attempt to left shift with overflow".to_owned()), + Some("attempt to bit-shift with overflow".to_owned()), ); self.builder.insert_truncate(result, bit_size, bit_size + 1) } @@ -466,7 +486,6 @@ impl<'a> FunctionContext<'a> { Some(message), ); } - BinaryOpKind::ShiftLeft => unreachable!("shift is not supported for signed integer"), _ => unreachable!("operator {} should not overflow", operator), } } @@ -482,19 +501,26 @@ impl<'a> FunctionContext<'a> { mut rhs: ValueId, location: Location, ) -> Values { + let result_type = self.builder.type_of_value(lhs); let mut result = match operator { BinaryOpKind::ShiftLeft => { - let result_type = self.builder.current_function.dfg.type_of_value(lhs); let bit_size = match result_type { Type::Numeric(NumericType::Signed { bit_size }) | Type::Numeric(NumericType::Unsigned { bit_size }) => bit_size, - _ => unreachable!("ICE: Truncation attempted on non-integer"), + _ => unreachable!("ICE: left-shift attempted on non-integer"), }; self.builder.insert_wrapping_shift_left(lhs, rhs, bit_size) } - BinaryOpKind::ShiftRight => self.builder.insert_shift_right(lhs, rhs), + BinaryOpKind::ShiftRight => { + let bit_size = match result_type { + Type::Numeric(NumericType::Signed { bit_size }) + | Type::Numeric(NumericType::Unsigned { bit_size }) => bit_size, + _ => unreachable!("ICE: right-shift attempted on non-integer"), + }; + self.builder.insert_shift_right(lhs, rhs, bit_size) + } BinaryOpKind::Equal | BinaryOpKind::NotEqual - if matches!(self.builder.type_of_value(lhs), Type::Array(..)) => + if matches!(result_type, Type::Array(..)) => { return self.insert_array_equality(lhs, operator, rhs, location) } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index f7154895150..e8f9f23d378 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -11,7 +11,7 @@ use crate::{ types::Type, }, node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitImplKind, TraitMethodId}, - BinaryOpKind, Signedness, TypeBinding, TypeBindings, TypeVariableKind, UnaryOp, + BinaryOpKind, TypeBinding, TypeBindings, TypeVariableKind, UnaryOp, }; use super::{errors::TypeCheckError, TypeChecker}; @@ -1121,13 +1121,7 @@ impl<'interner> TypeChecker<'interner> { span, }); } - if op.is_bit_shift() - && (*sign_x == Signedness::Signed || *sign_y == Signedness::Signed) - { - Err(TypeCheckError::InvalidInfixOp { kind: "Signed integer", span }) - } else { - Ok(Integer(*sign_x, *bit_width_x)) - } + Ok(Integer(*sign_x, *bit_width_x)) } (Integer(..), FieldElement) | (FieldElement, Integer(..)) => { Err(TypeCheckError::IntegerAndFieldBinaryOperation { span }) diff --git a/test_programs/execution_success/bit_shifts_runtime/src/main.nr b/test_programs/execution_success/bit_shifts_runtime/src/main.nr index a2c873a7e7f..33d68765598 100644 --- a/test_programs/execution_success/bit_shifts_runtime/src/main.nr +++ b/test_programs/execution_success/bit_shifts_runtime/src/main.nr @@ -5,4 +5,15 @@ fn main(x: u64, y: u64) { // runtime shifts on runtime values assert(x << y == 128); assert(x >> y == 32); + + // Bit-shift with signed integers + let mut a :i8 = y as i8; + let mut b: i8 = x as i8; + assert(b << 1 == -128); + assert(b >> 2 == 16); + assert(b >> a == 32); + a = -a; + assert(a << 7 == -128); + assert(a << -a == -2); + }