Skip to content

Commit

Permalink
fix(ssa refactor): Implement array equality in SSA-gen (#1704)
Browse files Browse the repository at this point in the history
* Implement array equality

* chore(ssa refactor): elements -> composite_type

---------

Co-authored-by: Joss <joss@aztecprotocol.com>
  • Loading branch information
jfecher and joss-aztec committed Jun 15, 2023
1 parent c7a7216 commit 0d31d83
Showing 1 changed file with 91 additions and 0 deletions.
91 changes: 91 additions & 0 deletions crates/noirc_evaluator/src/ssa_refactor/ssa_gen/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ impl<'a> FunctionContext<'a> {
) -> Values {
let op = convert_operator(operator);

if op == BinaryOp::Eq && matches!(self.builder.type_of_value(lhs), Type::Array(..)) {
return self.insert_array_equality(lhs, operator, rhs);
}

if operator_requires_swapped_operands(operator) {
std::mem::swap(&mut lhs, &mut rhs);
}
Expand Down Expand Up @@ -255,6 +259,93 @@ impl<'a> FunctionContext<'a> {
result.into()
}

/// The frontend claims to support equality (==) on arrays, so we must support it in SSA here.
/// The actual BinaryOp::Eq in SSA is meant only for primitive numeric types so we encode an
/// entire equality loop on each array element. The generated IR is as follows:
///
/// ...
/// result_alloc = allocate
/// store u1 1 in result_alloc
/// jmp loop_start(0)
/// loop_start(i: Field):
/// v0 = lt i, array_len
/// jmpif v0, then: loop_body, else: loop_end
/// loop_body():
/// v1 = array_get lhs, index i
/// v2 = array_get rhs, index i
/// v3 = eq v1, v2
/// v4 = load result_alloc
/// v5 = and v4, v3
/// store v5 in result_alloc
/// v6 = add i, Field 1
/// jmp loop_start(v6)
/// loop_end():
/// result = load result_alloc
fn insert_array_equality(
&mut self,
lhs: ValueId,
operator: noirc_frontend::BinaryOpKind,
rhs: ValueId,
) -> Values {
let lhs_type = self.builder.type_of_value(lhs);
let rhs_type = self.builder.type_of_value(rhs);

let (array_length, element_type) = match (lhs_type, rhs_type) {
(
Type::Array(lhs_composite_type, lhs_length),
Type::Array(rhs_composite_type, rhs_length),
) => {
assert!(
lhs_composite_type.len() == 1 && rhs_composite_type.len() == 1,
"== is unimplemented for arrays of structs"
);
assert_eq!(lhs_composite_type[0], rhs_composite_type[0]);
assert_eq!(lhs_length, rhs_length, "Expected two arrays of equal length");
(lhs_length, lhs_composite_type[0].clone())
}
_ => unreachable!("Expected two array values"),
};

let loop_start = self.builder.insert_block();
let loop_body = self.builder.insert_block();
let loop_end = self.builder.insert_block();

// pre-loop
let result_alloc = self.builder.insert_allocate();
let true_value = self.builder.numeric_constant(1u128, Type::bool());
self.builder.insert_store(result_alloc, true_value);
let zero = self.builder.field_constant(0u128);
self.builder.terminate_with_jmp(loop_start, vec![zero]);

// loop_start
self.builder.switch_to_block(loop_start);
let i = self.builder.add_block_parameter(loop_start, Type::field());
let array_length = self.builder.field_constant(array_length as u128);
let v0 = self.builder.insert_binary(i, BinaryOp::Lt, array_length);
self.builder.terminate_with_jmpif(v0, loop_body, loop_end);

// loop body
self.builder.switch_to_block(loop_body);
let v1 = self.builder.insert_array_get(lhs, i, element_type.clone());
let v2 = self.builder.insert_array_get(rhs, i, element_type);
let v3 = self.builder.insert_binary(v1, BinaryOp::Eq, v2);
let v4 = self.builder.insert_load(result_alloc, Type::bool());
let v5 = self.builder.insert_binary(v4, BinaryOp::And, v3);
self.builder.insert_store(result_alloc, v5);
let one = self.builder.field_constant(1u128);
let v6 = self.builder.insert_binary(i, BinaryOp::Add, one);
self.builder.terminate_with_jmp(loop_start, vec![v6]);

// loop end
self.builder.switch_to_block(loop_end);
let mut result = self.builder.insert_load(result_alloc, Type::bool());

if operator_requires_not(operator) {
result = self.builder.insert_not(result);
}
result.into()
}

/// Inserts a call instruction at the end of the current block and returns the results
/// of the call.
///
Expand Down

0 comments on commit 0d31d83

Please sign in to comment.