Skip to content

Commit

Permalink
feat: optimize out unnecessary truncation instructions (#3717)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves <!-- Link to GitHub Issue -->

## Summary\*

We're currently adding a double truncation when performing overflow
checks on left shifts. We could/should address this directly but as a
more general rule we can optimize out any truncation which is truncating
a value to fit into a type which is equal or larger to the source type.

I've also replaced a few instances of `insert_instruction` with the more
instruction-specific version to reduce verbosity.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: kevaundray <kevtheappdev@gmail.com>
  • Loading branch information
TomAFrench and kevaundray committed Dec 9, 2023
1 parent 61fe99d commit c9c72ae
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 46 deletions.
11 changes: 3 additions & 8 deletions compiler/noirc_evaluator/src/ssa/function_builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,11 @@ impl FunctionBuilder {
(FieldElement::max_num_bits(), self.insert_binary(predicate, BinaryOp::Mul, pow))
};

let instruction = Instruction::Binary(Binary { lhs, rhs: pow, operator: BinaryOp::Mul });
if max_bit <= bit_size {
self.insert_instruction(instruction, None).first()
self.insert_binary(lhs, BinaryOp::Mul, pow)
} else {
let result = self.insert_instruction(instruction, None).first();
self.insert_instruction(
Instruction::Truncate { value: result, bit_size, max_bit_size: max_bit },
None,
)
.first()
let result = self.insert_binary(lhs, BinaryOp::Mul, pow);
self.insert_truncate(result, bit_size, max_bit)
}
}

Expand Down
14 changes: 13 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ir/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,23 @@ impl Instruction {
}
None
}
Instruction::Truncate { value, bit_size, .. } => {
Instruction::Truncate { value, bit_size, max_bit_size } => {
if let Some((numeric_constant, typ)) = dfg.get_numeric_constant_with_type(*value) {
let integer_modulus = 2_u128.pow(*bit_size);
let truncated = numeric_constant.to_u128() % integer_modulus;
SimplifiedTo(dfg.make_constant(truncated.into(), typ))
} else if let Value::Instruction { instruction, .. } = &dfg[dfg.resolve(*value)] {
if let Instruction::Truncate { bit_size: src_bit_size, .. } = &dfg[*instruction]
{
// If we're truncating the value to fit into the same or larger bit size then this is a noop.
if src_bit_size <= bit_size && src_bit_size <= max_bit_size {
SimplifiedTo(*value)
} else {
None
}
} else {
None
}
} else {
None
}
Expand Down
49 changes: 12 additions & 37 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,16 +275,11 @@ impl<'a> FunctionContext<'a> {
let bit_width =
self.builder.numeric_constant(FieldElement::from(2_i128.pow(bit_size)), Type::field());
let sign_not = self.builder.insert_binary(one, BinaryOp::Sub, sign);
let as_field =
self.builder.insert_instruction(Instruction::Cast(input, Type::field()), None).first();
let sign_field =
self.builder.insert_instruction(Instruction::Cast(sign, Type::field()), None).first();
let as_field = self.builder.insert_cast(input, Type::field());
let sign_field = self.builder.insert_cast(sign, Type::field());
let positive_predicate = self.builder.insert_binary(sign_field, BinaryOp::Mul, as_field);
let two_complement = self.builder.insert_binary(bit_width, BinaryOp::Sub, as_field);
let sign_not_field = self
.builder
.insert_instruction(Instruction::Cast(sign_not, Type::field()), None)
.first();
let sign_not_field = self.builder.insert_cast(sign_not, Type::field());
let negative_predicate =
self.builder.insert_binary(sign_not_field, BinaryOp::Mul, two_complement);
self.builder.insert_binary(positive_predicate, BinaryOp::Add, negative_predicate)
Expand Down Expand Up @@ -315,17 +310,8 @@ impl<'a> FunctionContext<'a> {
match operator {
BinaryOpKind::Add | BinaryOpKind::Subtract => {
// Result is computed modulo the bit size
let mut result = self
.builder
.insert_instruction(
Instruction::Truncate {
value: result,
bit_size,
max_bit_size: bit_size + 1,
},
None,
)
.first();
let mut result =
self.builder.insert_truncate(result, bit_size, bit_size + 1);
result = self.builder.insert_cast(result, Type::unsigned(bit_size));

self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
Expand All @@ -335,17 +321,7 @@ impl<'a> FunctionContext<'a> {
// Result is computed modulo the bit size
let mut result =
self.builder.insert_cast(result, Type::unsigned(2 * bit_size));
result = self
.builder
.insert_instruction(
Instruction::Truncate {
value: result,
bit_size,
max_bit_size: 2 * bit_size,
},
None,
)
.first();
result = self.builder.insert_truncate(result, bit_size, 2 * bit_size);

self.check_signed_overflow(result, lhs, rhs, operator, bit_size, location);
self.builder.insert_cast(result, result_type)
Expand Down Expand Up @@ -476,17 +452,16 @@ impl<'a> FunctionContext<'a> {

// Then we check the signed product fits in a signed integer of bit_size-bits
let not_same = self.builder.insert_binary(one, BinaryOp::Sub, same_sign);
let not_same_sign_field = self
.builder
.insert_instruction(Instruction::Cast(not_same, Type::unsigned(bit_size)), None)
.first();
let not_same_sign_field =
self.builder.insert_cast(not_same, Type::unsigned(bit_size));
let positive_maximum_with_offset =
self.builder.insert_binary(half_width, BinaryOp::Add, not_same_sign_field);
let product_overflow_check =
self.builder.insert_binary(product, BinaryOp::Lt, positive_maximum_with_offset);
self.builder.set_location(location).insert_instruction(
Instruction::Constrain(product_overflow_check, one, Some(message)),
None,
self.builder.set_location(location).insert_constrain(
product_overflow_check,
one,
Some(message),
);
}
BinaryOpKind::ShiftLeft => unreachable!("shift is not supported for signed integer"),
Expand Down

0 comments on commit c9c72ae

Please sign in to comment.