Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Avoid unnecessary range checks by inspecting instructions for casts #4039

Merged
merged 4 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,9 @@ impl FunctionBuilder {
if let Some(rhs_constant) = self.current_function.dfg.get_numeric_constant(rhs) {
// Happy case is that we know precisely by how many bits the the integer will
// increase: lhs_bit_size + rhs
let (rhs_bit_size_pow_2, overflows) =
2_u128.overflowing_pow(rhs_constant.to_u128() as u32);
let bit_shift_size = rhs_constant.to_u128() as u32;

let (rhs_bit_size_pow_2, overflows) = 2_u128.overflowing_pow(bit_shift_size);
if overflows {
assert!(bit_size < 128, "ICE - shift left with big integers are not supported");
if bit_size < 128 {
Expand All @@ -303,7 +304,10 @@ impl FunctionBuilder {
}
}
let pow = self.numeric_constant(FieldElement::from(rhs_bit_size_pow_2), typ);
(bit_size + (rhs_constant.to_u128() as u32), pow)

let max_lhs_bits = self.current_function.dfg.get_value_max_num_bits(lhs);

(max_lhs_bits + bit_shift_size, pow)
} else {
// we use a predicate to nullify the result in case of overflow
let bit_size_var =
Expand Down
19 changes: 19 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,25 @@ impl DataFlowGraph {
self.values[value].get_type().clone()
}

/// Returns the maximum possible number of bits that `value` can potentially be.
///
/// Should `value` be a numeric constant then this function will return the exact number of bits required,
/// otherwise it will return the minimum number of bits based on type information.
pub(crate) fn get_value_max_num_bits(&self, value: ValueId) -> u32 {
match self[value] {
Value::Instruction { instruction, .. } => {
if let Instruction::Cast(original_value, _) = self[instruction] {
self.type_of_value(original_value).bit_size()
} else {
self.type_of_value(value).bit_size()
}
}

Value::NumericConstant { constant, .. } => constant.num_bits(),
_ => self.type_of_value(value).bit_size(),
}
}

/// True if the type of this value is Type::Reference.
/// Using this method over type_of_value avoids cloning the value's type.
pub(crate) fn value_is_reference(&self, value: ValueId) -> bool {
Expand Down
84 changes: 63 additions & 21 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,29 +335,71 @@ impl<'a> FunctionContext<'a> {
}
}
Type::Numeric(NumericType::Unsigned { bit_size }) => {
let op_name = match operator {
BinaryOpKind::Add => "add",
BinaryOpKind::Subtract => "subtract",
BinaryOpKind::Multiply => "multiply",
BinaryOpKind::ShiftLeft => "left shift",
_ => unreachable!("operator {} should not overflow", operator),
};
let dfg = &self.builder.current_function.dfg;

if operator == BinaryOpKind::Multiply && bit_size == 1 {
result
} else if operator == BinaryOpKind::ShiftLeft
|| operator == BinaryOpKind::ShiftRight
{
self.check_shift_overflow(result, rhs, bit_size, location, false)
} else {
let message = format!("attempt to {} with overflow", op_name);
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
result
let max_lhs_bits = self.builder.current_function.dfg.get_value_max_num_bits(lhs);
let max_rhs_bits = self.builder.current_function.dfg.get_value_max_num_bits(rhs);

match operator {
BinaryOpKind::Add => {
if std::cmp::max(max_lhs_bits, max_rhs_bits) < bit_size {
// `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return result;
}

let message = "attempt to add with overflow".to_string();
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
}
BinaryOpKind::Subtract => {
if dfg.is_constant(lhs) && max_lhs_bits > max_rhs_bits {
// `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0`
// Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible.
return result;
}

let message = "attempt to subtract with overflow".to_string();
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
}
BinaryOpKind::Multiply => {
if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size {
// Either performing boolean multiplication (which cannot overflow),
// or `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow.
return result;
}

let message = "attempt to multiply with overflow".to_string();
self.builder.set_location(location).insert_range_check(
result,
bit_size,
Some(message),
);
}
BinaryOpKind::ShiftLeft => {
if let Some(rhs_const) = dfg.get_numeric_constant(rhs) {
let bit_shift_size = rhs_const.to_u128() as u32;

if max_lhs_bits + bit_shift_size <= bit_size {
// `lhs` has been casted up from a smaller type such that shifting it by a constant
// `rhs` is known not to exceed the maximum bit size.
return result;
}
}

self.check_shift_overflow(result, rhs, bit_size, location, false);
}

_ => unreachable!("operator {} should not overflow", operator),
}

result
}
_ => result,
}
Expand Down
Loading