Skip to content

Commit

Permalink
fix: prevent Instruction::Constrains for non-primitive types (#3916)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Followup to #3740

## Summary\*

#3740 fixed an issue where array equalities were making their way into
SSA and not having side effect predicates applied correctly by applying
the predicate to each of the array elements. We actually state that we
do not want array equalities in SSA (in `insert_array_equality`) so the
fundamental issue of array equalities in SSA still exists.

We were allowing an implicit array equality to sneak into SSA by
performing an optimization of `Constrain(Eq(x, y), 1)` into
`Constrain(x, y)` during codegen. This meant that if `x` and `y` were
arrays then we bypass the `insert_array_equality` function which the
`Eq` instruction would call, which would have calculated a primitive
predicate value for the constrain statement to act on.

This PR removes the extra logic from the `flatten_cfg` pass (while
adding an assert that we're only constraining primitive values) and
removes the faulty optimization from SSA codegen.


## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
TomAFrench committed Jan 2, 2024
1 parent a63433f commit 467948f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 105 deletions.
105 changes: 19 additions & 86 deletions compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,25 @@ impl<'f> Context<'f> {
match instruction {
Instruction::Constrain(lhs, rhs, message) => {
// Replace constraint `lhs == rhs` with `condition * lhs == condition * rhs`.
let lhs = self.handle_constrain_arg_side_effects(lhs, condition, &call_stack);
let rhs = self.handle_constrain_arg_side_effects(rhs, condition, &call_stack);

// Condition needs to be cast to argument type in order to multiply them together.
let argument_type = self.inserter.function.dfg.type_of_value(lhs);
// Sanity check that we're not constraining non-primitive types
assert!(matches!(argument_type, Type::Numeric(_)));

let casted_condition = self.insert_instruction(
Instruction::Cast(condition, argument_type),
call_stack.clone(),
);

let lhs = self.insert_instruction(
Instruction::binary(BinaryOp::Mul, lhs, casted_condition),
call_stack.clone(),
);
let rhs = self.insert_instruction(
Instruction::binary(BinaryOp::Mul, rhs, casted_condition),
call_stack,
);

Instruction::Constrain(lhs, rhs, message)
}
Expand Down Expand Up @@ -673,90 +690,6 @@ impl<'f> Context<'f> {
}
}

/// Given the arguments of a constrain instruction, multiplying them by the branch's condition
/// requires special handling in the case of complex types.
fn handle_constrain_arg_side_effects(
&mut self,
argument: ValueId,
condition: ValueId,
call_stack: &CallStack,
) -> ValueId {
let argument_type = self.inserter.function.dfg.type_of_value(argument);

match &argument_type {
Type::Numeric(_) => {
// Condition needs to be cast to argument type in order to multiply them together.
let casted_condition = self.insert_instruction(
Instruction::Cast(condition, argument_type),
call_stack.clone(),
);

self.insert_instruction(
Instruction::binary(BinaryOp::Mul, argument, casted_condition),
call_stack.clone(),
)
}
Type::Array(_, _) => {
self.handle_array_constrain_arg(argument_type, argument, condition, call_stack)
}
Type::Slice(_) => {
panic!("Cannot use slices directly in a constrain statement")
}
Type::Reference(_) => {
panic!("Cannot use references directly in a constrain statement")
}
Type::Function => {
panic!("Cannot use functions directly in a constrain statement")
}
}
}

fn handle_array_constrain_arg(
&mut self,
typ: Type,
argument: ValueId,
condition: ValueId,
call_stack: &CallStack,
) -> ValueId {
let mut new_array = im::Vector::new();

let (element_types, len) = match &typ {
Type::Array(elements, len) => (elements, *len),
_ => panic!("Expected array type"),
};

for i in 0..len {
for (element_index, element_type) in element_types.iter().enumerate() {
let index = ((i * element_types.len() + element_index) as u128).into();
let index = self.inserter.function.dfg.make_constant(index, Type::field());

let typevars = Some(vec![element_type.clone()]);

let mut get_element = |array, typevars| {
let get = Instruction::ArrayGet { array, index };
self.inserter
.function
.dfg
.insert_instruction_and_results(
get,
self.inserter.function.entry_block(),
typevars,
CallStack::new(),
)
.first()
};

let element = get_element(argument, typevars);

new_array.push_back(
self.handle_constrain_arg_side_effects(element, condition, call_stack),
);
}
}

self.inserter.function.dfg.make_array(new_array, typ)
}

fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) {
for (address, store) in &then_branch.store_values {
let address = *address;
Expand Down
24 changes: 5 additions & 19 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use context::SharedContext;
use iter_extended::{try_vecmap, vecmap};
use noirc_errors::Location;
use noirc_frontend::{
monomorphization::ast::{self, Binary, Expression, Program},
BinaryOpKind, Visibility,
monomorphization::ast::{self, Expression, Program},
Visibility,
};

use crate::{
Expand Down Expand Up @@ -653,24 +653,10 @@ impl<'a> FunctionContext<'a> {
location: Location,
assert_message: Option<String>,
) -> Result<Values, RuntimeError> {
match expr {
// If we're constraining an equality to be true then constrain the two sides directly.
Expression::Binary(Binary { lhs, operator: BinaryOpKind::Equal, rhs, .. }) => {
let lhs = self.codegen_non_tuple_expression(lhs)?;
let rhs = self.codegen_non_tuple_expression(rhs)?;
self.builder.set_location(location).insert_constrain(lhs, rhs, assert_message);
}
let expr = self.codegen_non_tuple_expression(expr)?;
let true_literal = self.builder.numeric_constant(true, Type::bool());
self.builder.set_location(location).insert_constrain(expr, true_literal, assert_message);

_ => {
let expr = self.codegen_non_tuple_expression(expr)?;
let true_literal = self.builder.numeric_constant(true, Type::bool());
self.builder.set_location(location).insert_constrain(
expr,
true_literal,
assert_message,
);
}
}
Ok(Self::unit_value())
}

Expand Down

0 comments on commit 467948f

Please sign in to comment.