Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: prevent Instruction::Constrains for non-primitive types #3916

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
105 changes: 19 additions & 86 deletions compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,25 @@ impl<'f> Context<'f> {
match instruction {
Instruction::Constrain(lhs, rhs, message) => {
// Replace constraint `lhs == rhs` with `condition * lhs == condition * rhs`.
let lhs = self.handle_constrain_arg_side_effects(lhs, condition, &call_stack);
let rhs = self.handle_constrain_arg_side_effects(rhs, condition, &call_stack);

// Condition needs to be cast to argument type in order to multiply them together.
let argument_type = self.inserter.function.dfg.type_of_value(lhs);
// Sanity check that we're not constraining non-primitive types
assert!(matches!(argument_type, Type::Numeric(_)));

let casted_condition = self.insert_instruction(
Instruction::Cast(condition, argument_type),
call_stack.clone(),
);

let lhs = self.insert_instruction(
Instruction::binary(BinaryOp::Mul, lhs, casted_condition),
call_stack.clone(),
);
let rhs = self.insert_instruction(
Instruction::binary(BinaryOp::Mul, rhs, casted_condition),
call_stack,
);

Instruction::Constrain(lhs, rhs, message)
}
Expand Down Expand Up @@ -673,90 +690,6 @@ impl<'f> Context<'f> {
}
}

/// Given the arguments of a constrain instruction, multiplying them by the branch's condition
/// requires special handling in the case of complex types.
fn handle_constrain_arg_side_effects(
&mut self,
argument: ValueId,
condition: ValueId,
call_stack: &CallStack,
) -> ValueId {
let argument_type = self.inserter.function.dfg.type_of_value(argument);

match &argument_type {
Type::Numeric(_) => {
// Condition needs to be cast to argument type in order to multiply them together.
let casted_condition = self.insert_instruction(
Instruction::Cast(condition, argument_type),
call_stack.clone(),
);

self.insert_instruction(
Instruction::binary(BinaryOp::Mul, argument, casted_condition),
call_stack.clone(),
)
}
Type::Array(_, _) => {
self.handle_array_constrain_arg(argument_type, argument, condition, call_stack)
}
Type::Slice(_) => {
panic!("Cannot use slices directly in a constrain statement")
}
Type::Reference(_) => {
panic!("Cannot use references directly in a constrain statement")
}
Type::Function => {
panic!("Cannot use functions directly in a constrain statement")
}
}
}

fn handle_array_constrain_arg(
&mut self,
typ: Type,
argument: ValueId,
condition: ValueId,
call_stack: &CallStack,
) -> ValueId {
let mut new_array = im::Vector::new();

let (element_types, len) = match &typ {
Type::Array(elements, len) => (elements, *len),
_ => panic!("Expected array type"),
};

for i in 0..len {
for (element_index, element_type) in element_types.iter().enumerate() {
let index = ((i * element_types.len() + element_index) as u128).into();
let index = self.inserter.function.dfg.make_constant(index, Type::field());

let typevars = Some(vec![element_type.clone()]);

let mut get_element = |array, typevars| {
let get = Instruction::ArrayGet { array, index };
self.inserter
.function
.dfg
.insert_instruction_and_results(
get,
self.inserter.function.entry_block(),
typevars,
CallStack::new(),
)
.first()
};

let element = get_element(argument, typevars);

new_array.push_back(
self.handle_constrain_arg_side_effects(element, condition, call_stack),
);
}
}

self.inserter.function.dfg.make_array(new_array, typ)
}

fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) {
for (address, store) in &then_branch.store_values {
let address = *address;
Expand Down
24 changes: 5 additions & 19 deletions compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
use iter_extended::{try_vecmap, vecmap};
use noirc_errors::Location;
use noirc_frontend::{
monomorphization::ast::{self, Binary, Expression, Program},
BinaryOpKind, Visibility,
monomorphization::ast::{self, Expression, Program},
Visibility,
};

use crate::{
Expand Down Expand Up @@ -448,7 +448,7 @@
/// br loop_entry(v0)
/// loop_entry(i: Field):
/// v2 = lt i v1
/// brif v2, then: loop_body, else: loop_end

Check warning on line 451 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// loop_body():
/// v3 = ... codegen body ...
/// v4 = add 1, i
Expand Down Expand Up @@ -502,7 +502,7 @@
/// For example, the expression `if cond { a } else { b }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: else_block

Check warning on line 505 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block():
/// v1 = ... codegen a ...
/// br end_if(v1)
Expand All @@ -515,7 +515,7 @@
/// As another example, the expression `if cond { a }` is codegen'd as:
///
/// v0 = ... codegen cond ...
/// brif v0, then: then_block, else: end_block

Check warning on line 518 in compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (brif)
/// then_block:
/// v1 = ... codegen a ...
/// br end_if()
Expand Down Expand Up @@ -653,24 +653,10 @@
location: Location,
assert_message: Option<String>,
) -> Result<Values, RuntimeError> {
match expr {
// If we're constraining an equality to be true then constrain the two sides directly.
Expression::Binary(Binary { lhs, operator: BinaryOpKind::Equal, rhs, .. }) => {
let lhs = self.codegen_non_tuple_expression(lhs)?;
let rhs = self.codegen_non_tuple_expression(rhs)?;
self.builder.set_location(location).insert_constrain(lhs, rhs, assert_message);
}
let expr = self.codegen_non_tuple_expression(expr)?;
let true_literal = self.builder.numeric_constant(true, Type::bool());
self.builder.set_location(location).insert_constrain(expr, true_literal, assert_message);

_ => {
let expr = self.codegen_non_tuple_expression(expr)?;
let true_literal = self.builder.numeric_constant(true, Type::bool());
self.builder.set_location(location).insert_constrain(
expr,
true_literal,
assert_message,
);
}
}
Ok(Self::unit_value())
}

Expand Down