diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 0e7bfff7b6..fdd7c66684 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -641,8 +641,25 @@ impl<'f> Context<'f> { match instruction { Instruction::Constrain(lhs, rhs, message) => { // Replace constraint `lhs == rhs` with `condition * lhs == condition * rhs`. - let lhs = self.handle_constrain_arg_side_effects(lhs, condition, &call_stack); - let rhs = self.handle_constrain_arg_side_effects(rhs, condition, &call_stack); + + // Condition needs to be cast to argument type in order to multiply them together. + let argument_type = self.inserter.function.dfg.type_of_value(lhs); + // Sanity check that we're not constraining non-primitive types + assert!(matches!(argument_type, Type::Numeric(_))); + + let casted_condition = self.insert_instruction( + Instruction::Cast(condition, argument_type), + call_stack.clone(), + ); + + let lhs = self.insert_instruction( + Instruction::binary(BinaryOp::Mul, lhs, casted_condition), + call_stack.clone(), + ); + let rhs = self.insert_instruction( + Instruction::binary(BinaryOp::Mul, rhs, casted_condition), + call_stack, + ); Instruction::Constrain(lhs, rhs, message) } @@ -673,90 +690,6 @@ impl<'f> Context<'f> { } } - /// Given the arguments of a constrain instruction, multiplying them by the branch's condition - /// requires special handling in the case of complex types. - fn handle_constrain_arg_side_effects( - &mut self, - argument: ValueId, - condition: ValueId, - call_stack: &CallStack, - ) -> ValueId { - let argument_type = self.inserter.function.dfg.type_of_value(argument); - - match &argument_type { - Type::Numeric(_) => { - // Condition needs to be cast to argument type in order to multiply them together. - let casted_condition = self.insert_instruction( - Instruction::Cast(condition, argument_type), - call_stack.clone(), - ); - - self.insert_instruction( - Instruction::binary(BinaryOp::Mul, argument, casted_condition), - call_stack.clone(), - ) - } - Type::Array(_, _) => { - self.handle_array_constrain_arg(argument_type, argument, condition, call_stack) - } - Type::Slice(_) => { - panic!("Cannot use slices directly in a constrain statement") - } - Type::Reference(_) => { - panic!("Cannot use references directly in a constrain statement") - } - Type::Function => { - panic!("Cannot use functions directly in a constrain statement") - } - } - } - - fn handle_array_constrain_arg( - &mut self, - typ: Type, - argument: ValueId, - condition: ValueId, - call_stack: &CallStack, - ) -> ValueId { - let mut new_array = im::Vector::new(); - - let (element_types, len) = match &typ { - Type::Array(elements, len) => (elements, *len), - _ => panic!("Expected array type"), - }; - - for i in 0..len { - for (element_index, element_type) in element_types.iter().enumerate() { - let index = ((i * element_types.len() + element_index) as u128).into(); - let index = self.inserter.function.dfg.make_constant(index, Type::field()); - - let typevars = Some(vec![element_type.clone()]); - - let mut get_element = |array, typevars| { - let get = Instruction::ArrayGet { array, index }; - self.inserter - .function - .dfg - .insert_instruction_and_results( - get, - self.inserter.function.entry_block(), - typevars, - CallStack::new(), - ) - .first() - }; - - let element = get_element(argument, typevars); - - new_array.push_back( - self.handle_constrain_arg_side_effects(element, condition, call_stack), - ); - } - } - - self.inserter.function.dfg.make_array(new_array, typ) - } - fn undo_stores_in_then_branch(&mut self, then_branch: &Branch) { for (address, store) in &then_branch.store_values { let address = *address; diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index d7e6b8b0a3..c00fbbbcb4 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -8,8 +8,8 @@ use context::SharedContext; use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; use noirc_frontend::{ - monomorphization::ast::{self, Binary, Expression, Program}, - BinaryOpKind, Visibility, + monomorphization::ast::{self, Expression, Program}, + Visibility, }; use crate::{ @@ -653,24 +653,10 @@ impl<'a> FunctionContext<'a> { location: Location, assert_message: Option, ) -> Result { - match expr { - // If we're constraining an equality to be true then constrain the two sides directly. - Expression::Binary(Binary { lhs, operator: BinaryOpKind::Equal, rhs, .. }) => { - let lhs = self.codegen_non_tuple_expression(lhs)?; - let rhs = self.codegen_non_tuple_expression(rhs)?; - self.builder.set_location(location).insert_constrain(lhs, rhs, assert_message); - } + let expr = self.codegen_non_tuple_expression(expr)?; + let true_literal = self.builder.numeric_constant(true, Type::bool()); + self.builder.set_location(location).insert_constrain(expr, true_literal, assert_message); - _ => { - let expr = self.codegen_non_tuple_expression(expr)?; - let true_literal = self.builder.numeric_constant(true, Type::bool()); - self.builder.set_location(location).insert_constrain( - expr, - true_literal, - assert_message, - ); - } - } Ok(Self::unit_value()) }