Skip to content

Commit

Permalink
chore: non-deterministic array sort (#4279)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves #4171
We have an intrinsic sort function to do sorting on numeric arrays. It
allows field elements but fail with them.
At the same time we have a more generic sort_via method with works with
custom ordering function but it has bad performance.

## Summary\*
I have rationalised the sorting functions for arrays, we now have only
one which works on any type and support custom ordering functions.
The new sort function benefits from dynamic arrays and does not need
sorting networks, so I was able to remove a lot of deprecated code. See
below for the impact on performance.


The issue 4171 is resolved by specifying the Ord trait for array
elements.


## Additional Context
So we have 3 sort functions:
the new one 'ultra_sort'
the pre-ultra plonk 'array_sort'
the customisable one, 'sort_via'
With this PR, we only keep the new one 'ultra_sort'. Here how it
compares with the others:

Sorting an array of 10 elements:
- Ultra sort: 99 Acir gates, 202 UltraPlonk gates
- Array sort: 235 Acir gates, 279 UltraPlonk gates
- Sort_via: 1733 Acir gates, 1867 UltraPlonk gates

Sorting an array of 100 elements:
- Ultra sort: 999 Acir gates, 1597 UltraPlonk gates
- Array sort: 6674 Acir gates, 6673 UltraPlonk gates
- Sort_via: 1524713 Acir gates, 1522136 UltraPlonk gates



## Documentation\*

Check one:
- [X] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
guipublic committed Feb 7, 2024
1 parent 9a4bf16 commit 2ffef26
Show file tree
Hide file tree
Showing 18 changed files with 52 additions and 749 deletions.
60 changes: 1 addition & 59 deletions acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,18 +915,7 @@ namespace Circuit {
static ToLeRadix bincodeDeserialize(std::vector<uint8_t>);
};

struct PermutationSort {
std::vector<std::vector<Circuit::Expression>> inputs;
uint32_t tuple;
std::vector<Circuit::Witness> bits;
std::vector<uint32_t> sort_by;

friend bool operator==(const PermutationSort&, const PermutationSort&);
std::vector<uint8_t> bincodeSerialize() const;
static PermutationSort bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<ToLeRadix, PermutationSort> value;
std::variant<ToLeRadix> value;

friend bool operator==(const Directive&, const Directive&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4960,53 +4949,6 @@ Circuit::Directive::ToLeRadix serde::Deserializable<Circuit::Directive::ToLeRadi
return obj;
}

namespace Circuit {

inline bool operator==(const Directive::PermutationSort &lhs, const Directive::PermutationSort &rhs) {
if (!(lhs.inputs == rhs.inputs)) { return false; }
if (!(lhs.tuple == rhs.tuple)) { return false; }
if (!(lhs.bits == rhs.bits)) { return false; }
if (!(lhs.sort_by == rhs.sort_by)) { return false; }
return true;
}

inline std::vector<uint8_t> Directive::PermutationSort::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<Directive::PermutationSort>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline Directive::PermutationSort Directive::PermutationSort::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<Directive::PermutationSort>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::Directive::PermutationSort>::serialize(const Circuit::Directive::PermutationSort &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.inputs)>::serialize(obj.inputs, serializer);
serde::Serializable<decltype(obj.tuple)>::serialize(obj.tuple, serializer);
serde::Serializable<decltype(obj.bits)>::serialize(obj.bits, serializer);
serde::Serializable<decltype(obj.sort_by)>::serialize(obj.sort_by, serializer);
}

template <>
template <typename Deserializer>
Circuit::Directive::PermutationSort serde::Deserializable<Circuit::Directive::PermutationSort>::deserialize(Deserializer &deserializer) {
Circuit::Directive::PermutationSort obj;
obj.inputs = serde::Deserializable<decltype(obj.inputs)>::deserialize(deserializer);
obj.tuple = serde::Deserializable<decltype(obj.tuple)>::deserialize(deserializer);
obj.bits = serde::Deserializable<decltype(obj.bits)>::deserialize(deserializer);
obj.sort_by = serde::Deserializable<decltype(obj.sort_by)>::deserialize(deserializer);
return obj;
}

namespace Circuit {

inline bool operator==(const Expression &lhs, const Expression &rhs) {
Expand Down
15 changes: 1 addition & 14 deletions acvm-repo/acir/src/circuit/directives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,5 @@ use serde::{Deserialize, Serialize};
/// In the future, this can be replaced with asm non-determinism blocks
pub enum Directive {
//decomposition of a: a=\sum b[i]*radix^i where b is an array of witnesses < radix in little endian form
ToLeRadix {
a: Expression,
b: Vec<Witness>,
radix: u32,
},

// Sort directive, using a sorting network
// This directive is used to generate the values of the control bits for the sorting network such that its outputs are properly sorted according to sort_by
PermutationSort {
inputs: Vec<Vec<Expression>>, // Array of tuples to sort
tuple: u32, // tuple size; if 1 then inputs is a single array [a0,a1,..], if 2 then inputs=[(a0,b0),..] is [a0,b0,a1,b1,..], etc..
bits: Vec<Witness>, // control bits of the network which permutes the inputs into its sorted version
sort_by: Vec<u32>, // specify primary index to sort by, then the secondary,... For instance, if tuple is 2 and sort_by is [1,0], then a=[(a0,b0),..] is sorted by bi and then ai.
},
ToLeRadix { a: Expression, b: Vec<Witness>, radix: u32 },
}
14 changes: 0 additions & 14 deletions acvm-repo/acir/src/circuit/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,6 @@ impl std::fmt::Display for Opcode {
b.last().unwrap().witness_index(),
)
}
Opcode::Directive(Directive::PermutationSort { inputs: a, tuple, bits, sort_by }) => {
write!(f, "DIR::PERMUTATIONSORT ")?;
write!(
f,
"(permutation size: {} {}-tuples, sort_by: {:#?}, bits: [_{}..._{}]))",
a.len(),
tuple,
sort_by,
// (Note): the bits do not have contiguous index but there are too many for display
bits.first().unwrap().witness_index(),
bits.last().unwrap().witness_index(),
)
}

Opcode::Brillig(brillig) => {
write!(f, "BRILLIG: ")?;
writeln!(f, "inputs: {:?}", brillig.inputs)?;
Expand Down
5 changes: 0 additions & 5 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,6 @@ pub(super) fn transform_internal(
transformer.mark_solvable(*witness);
}
}
Directive::PermutationSort { bits, .. } => {
for witness in bits {
transformer.mark_solvable(*witness);
}
}
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode);
Expand Down
37 changes: 0 additions & 37 deletions acvm-repo/acvm/src/pwg/directives/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use std::cmp::Ordering;

use acir::{circuit::directives::Directive, native_types::WitnessMap, FieldElement};
use num_bigint::BigUint;

use crate::OpcodeResolutionError;

use super::{get_value, insert_value, ErrorLocation};

mod sorting;

/// Attempts to solve the [`Directive`] opcode `directive`.
/// If successful, `initial_witness` will be mutated to contain the new witness assignment.
///
Expand Down Expand Up @@ -48,38 +44,5 @@ pub(super) fn solve_directives(

Ok(())
}
Directive::PermutationSort { inputs: a, tuple, bits, sort_by } => {
let mut val_a = Vec::new();
let mut base = Vec::new();
for (i, element) in a.iter().enumerate() {
assert_eq!(element.len(), *tuple as usize);
let mut element_val = Vec::with_capacity(*tuple as usize + 1);
for e in element {
element_val.push(get_value(e, initial_witness)?);
}
let field_i = FieldElement::from(i as i128);
element_val.push(field_i);
base.push(field_i);
val_a.push(element_val);
}
val_a.sort_by(|a, b| {
for i in sort_by {
let int_a = BigUint::from_bytes_be(&a[*i as usize].to_be_bytes());
let int_b = BigUint::from_bytes_be(&b[*i as usize].to_be_bytes());
let cmp = int_a.cmp(&int_b);
if cmp != Ordering::Equal {
return cmp;
}
}
Ordering::Equal
});
let b = val_a.iter().map(|a| *a.last().unwrap()).collect();
let control = sorting::route(base, b);
for (w, value) in bits.iter().zip(control) {
let value = if value { FieldElement::one() } else { FieldElement::zero() };
insert_value(w, value, initial_witness)?;
}
Ok(())
}
}
}
Loading

0 comments on commit 2ffef26

Please sign in to comment.