Skip to content

Commit

Permalink
feat: remove unnecessary predicate from Lt instruction (#3922)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves <!-- Link to GitHub Issue -->

## Summary\*

This PR removes the predicate from the ACIR codegened from
`Instruction::Lt`.

I think this predicate got added as a knee-jerk reaction to there being
divisions involved however all these divisions are done with a known
non-zero constant divisor so we can always perform the division safely.

This fixes a discrepancy as we're applying
`current_side_effects_enabled_var` despite
https://github.com/noir-lang/noir/blob/aabe5c15bc9d509e6953e689a5f479b26b972384/compiler/noirc_evaluator/src/ssa/ir/instruction.rs#L263-L273
stating that `Lt` not having side effects. As a bonus we can benefit
from optimisations in acir gen.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x ] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
TomAFrench committed Jan 2, 2024
1 parent e58844d commit a63433f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 37 deletions.
29 changes: 14 additions & 15 deletions compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1002,23 +1002,25 @@ impl AcirContext {
lhs: AcirVar,
rhs: AcirVar,
bit_count: u32,
predicate: AcirVar,
) -> Result<AcirVar, RuntimeError> {
let pow_last = self.add_constant(FieldElement::from(1_u128 << (bit_count - 1)));
let pow = self.add_constant(FieldElement::from(1_u128 << (bit_count)));

// We check whether the inputs have same sign or not by computing the XOR of their bit sign

// Predicate is always active as `pow_last` is known to be non-zero.
let one = self.add_constant(1_u128);
let lhs_sign = self.div_var(
lhs,
pow_last,
AcirType::NumericType(NumericType::Unsigned { bit_size: bit_count }),
predicate,
one,
)?;
let rhs_sign = self.div_var(
rhs,
pow_last,
AcirType::NumericType(NumericType::Unsigned { bit_size: bit_count }),
predicate,
one,
)?;
let same_sign = self.xor_var(
lhs_sign,
Expand All @@ -1031,7 +1033,7 @@ impl AcirContext {
let diff = self.sub_var(no_underflow, rhs)?;

// We check the 'bit sign' of the difference
let diff_sign = self.less_than_var(diff, pow, bit_count + 1, predicate)?;
let diff_sign = self.less_than_var(diff, pow, bit_count + 1)?;

// Then the result is simply diff_sign XOR same_sign (can be checked with a truth table)
self.xor_var(
Expand All @@ -1048,7 +1050,6 @@ impl AcirContext {
lhs: AcirVar,
rhs: AcirVar,
max_bits: u32,
predicate: AcirVar,
) -> Result<AcirVar, RuntimeError> {
// Returns a `Witness` that is constrained to be:
// - `1` if lhs >= rhs
Expand All @@ -1073,6 +1074,7 @@ impl AcirContext {
//
// TODO: perhaps this should be a user error, instead of an assert
assert!(max_bits + 1 < FieldElement::max_num_bits());

let two_max_bits = self
.add_constant(FieldElement::from(2_i128).pow(&FieldElement::from(max_bits as i128)));
let diff = self.sub_var(lhs, rhs)?;
Expand Down Expand Up @@ -1102,13 +1104,11 @@ impl AcirContext {
// let k = b - a
// - 2^{max_bits} - k == q * 2^{max_bits} + r
// - This is only the case when q == 0 and r == 2^{max_bits} - k
//
let (q, _) = self.euclidean_division_var(
comparison_evaluation,
two_max_bits,
max_bits + 1,
predicate,
)?;

// Predicate is always active as we know `two_max_bits` is always non-zero.
let one = self.add_constant(1_u128);
let (q, _) =
self.euclidean_division_var(comparison_evaluation, two_max_bits, max_bits + 1, one)?;
Ok(q)
}

Expand All @@ -1119,11 +1119,10 @@ impl AcirContext {
lhs: AcirVar,
rhs: AcirVar,
bit_size: u32,
predicate: AcirVar,
) -> Result<AcirVar, RuntimeError> {
// Flip the result of calling more than equal method to
// compute less than.
let comparison = self.more_than_eq_var(lhs, rhs, bit_size, predicate)?;
let comparison = self.more_than_eq_var(lhs, rhs, bit_size)?;

let one = self.add_constant(FieldElement::one());
self.sub_var(one, comparison) // comparison_negated
Expand Down Expand Up @@ -1519,7 +1518,7 @@ impl AcirContext {
bit_size: u32,
predicate: AcirVar,
) -> Result<(), RuntimeError> {
let lhs_less_than_rhs = self.more_than_eq_var(rhs, lhs, bit_size, predicate)?;
let lhs_less_than_rhs = self.more_than_eq_var(rhs, lhs, bit_size)?;
self.maybe_eq_predicate(lhs_less_than_rhs, predicate)
}

Expand Down
30 changes: 8 additions & 22 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1571,15 +1571,10 @@ impl Context {
// this Eq instruction is being used for a constrain statement
BinaryOp::Eq => self.acir_context.eq_var(lhs, rhs),
BinaryOp::Lt => match binary_type {
AcirType::NumericType(NumericType::Signed { .. }) => self
.acir_context
.less_than_signed(lhs, rhs, bit_count, self.current_side_effects_enabled_var),
_ => self.acir_context.less_than_var(
lhs,
rhs,
bit_count,
self.current_side_effects_enabled_var,
),
AcirType::NumericType(NumericType::Signed { .. }) => {
self.acir_context.less_than_signed(lhs, rhs, bit_count)
}
_ => self.acir_context.less_than_var(lhs, rhs, bit_count),
},
BinaryOp::Xor => self.acir_context.xor_var(lhs, rhs, binary_type),
BinaryOp::And => self.acir_context.and_var(lhs, rhs, binary_type),
Expand Down Expand Up @@ -2141,19 +2136,11 @@ impl Context {
let current_index = self.acir_context.add_constant(i);

// Check that we are above the lower bound of the insertion index
let greater_eq_than_idx = self.acir_context.more_than_eq_var(
current_index,
flat_user_index,
64,
self.current_side_effects_enabled_var,
)?;
let greater_eq_than_idx =
self.acir_context.more_than_eq_var(current_index, flat_user_index, 64)?;
// Check that we are below the upper bound of the insertion index
let less_than_idx = self.acir_context.less_than_var(
current_index,
max_flat_user_index,
64,
self.current_side_effects_enabled_var,
)?;
let less_than_idx =
self.acir_context.less_than_var(current_index, max_flat_user_index, 64)?;

// Read from the original slice the value we want to insert into our new slice.
// We need to make sure that we read the previous element when our current index is greater than insertion index.
Expand Down Expand Up @@ -2328,7 +2315,6 @@ impl Context {
current_index,
flat_user_index,
64,
self.current_side_effects_enabled_var,
)?;

let shifted_value_pred =
Expand Down

0 comments on commit a63433f

Please sign in to comment.