Skip to content

Commit

Permalink
feat: make ACVM generic across fields (#5114)
Browse files Browse the repository at this point in the history
# Description

Step towards #5055 

## Summary\*

This PR starts us down the road towards removing the compile-time flags
for specifying which field is being used by replacing the usage of the
`FieldElement` type with an `AcirField` trait. This trait is mostly all
of the external methods from `FieldElement` dumped into it for now but
we can break it down in future PRs (XOR and AND are looking ready for
culling)

I've taken this route rather than just using `generic_ark::FieldElement`
as we'd need to have trait bounds for `PrimeField` anyway so we might as
well have our own trait.

I've had to delete a couple of trait implementations which fall afoul of
the orphan rule now it's being implemented for a generic type (e.g. we
need to use `MemoryValue::new_field` explicitly now) but nothing too
bad.

## Additional Context


## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
TomAFrench committed May 28, 2024
1 parent 56c1a85 commit 70f374c
Show file tree
Hide file tree
Showing 155 changed files with 1,760 additions and 1,519 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions acvm-repo/acir/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@ criterion.workspace = true
pprof.workspace = true

[features]
default = ["bn254"]
bn254 = ["acir_field/bn254", "brillig/bn254"]
bls12_381 = ["acir_field/bls12_381", "brillig/bls12_381"]
bn254 = ["acir_field/bn254"]
bls12_381 = ["acir_field/bls12_381"]

[[bench]]
name = "serialization"
Expand Down
8 changes: 4 additions & 4 deletions acvm-repo/acir/benches/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ use pprof::criterion::{Output, PProfProfiler};

const SIZES: [usize; 9] = [10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000];

fn sample_program(num_opcodes: usize) -> Program {
let assert_zero_opcodes: Vec<Opcode> = (0..num_opcodes)
fn sample_program(num_opcodes: usize) -> Program<FieldElement> {
let assert_zero_opcodes: Vec<Opcode<_>> = (0..num_opcodes)
.map(|i| {
Opcode::AssertZero(Expression {
mul_terms: vec![(
Expand Down Expand Up @@ -83,7 +83,7 @@ fn bench_deserialization(c: &mut Criterion) {
BenchmarkId::from_parameter(size),
&serialized_program,
|b, program| {
b.iter(|| Program::deserialize_program(program));
b.iter(|| Program::<FieldElement>::deserialize_program(program));
},
);
}
Expand All @@ -107,7 +107,7 @@ fn bench_deserialization(c: &mut Criterion) {
|b, program| {
b.iter(|| {
let mut deserializer = serde_json::Deserializer::from_slice(program);
Program::deserialize_program_base64(&mut deserializer)
Program::<FieldElement>::deserialize_program_base64(&mut deserializer)
});
},
);
Expand Down
10 changes: 5 additions & 5 deletions acvm-repo/acir/src/circuit/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ use serde::{Deserialize, Serialize};
/// Inputs for the Brillig VM. These are the initial inputs
/// that the Brillig VM will use to start.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)]
pub enum BrilligInputs {
Single(Expression),
Array(Vec<Expression>),
pub enum BrilligInputs<F> {
Single(Expression<F>),
Array(Vec<Expression<F>>),
MemoryArray(BlockId),
}

Expand All @@ -24,6 +24,6 @@ pub enum BrilligOutputs {
/// a full Brillig function to be executed by the Brillig VM.
/// This is stored separately on a program and accessed through a [BrilligPointer].
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Debug)]
pub struct BrilligBytecode {
pub bytecode: Vec<BrilligOpcode>,
pub struct BrilligBytecode<F> {
pub bytecode: Vec<BrilligOpcode<F>>,
}
4 changes: 2 additions & 2 deletions acvm-repo/acir/src/circuit/directives.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
/// Directives do not apply any constraints.
/// You can think of them as opcodes that allow one to use non-determinism
/// In the future, this can be replaced with asm non-determinism blocks
pub enum Directive {
pub enum Directive<F> {
//decomposition of a: a=\sum b[i]*radix^i where b is an array of witnesses < radix in little endian form
ToLeRadix { a: Expression, b: Vec<Witness>, radix: u32 },
ToLeRadix { a: Expression<F>, b: Vec<Witness>, radix: u32 },
}
94 changes: 50 additions & 44 deletions acvm-repo/acir/src/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub mod directives;
pub mod opcodes;

use crate::native_types::{Expression, Witness};
use acir_field::FieldElement;
use acir_field::AcirField;
pub use opcodes::Opcode;
use thiserror::Error;

Expand Down Expand Up @@ -38,17 +38,17 @@ pub enum ExpressionWidth {
/// A program represented by multiple ACIR circuits. The execution trace of these
/// circuits is dictated by construction of the [crate::native_types::WitnessStack].
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct Program {
pub functions: Vec<Circuit>,
pub unconstrained_functions: Vec<BrilligBytecode>,
pub struct Program<F> {
pub functions: Vec<Circuit<F>>,
pub unconstrained_functions: Vec<BrilligBytecode<F>>,
}

#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct Circuit {
pub struct Circuit<F> {
// current_witness_index is the highest witness index in the circuit. The next witness to be added to this circuit
// will take on this value. (The value is cached here as an optimization.)
pub current_witness_index: u32,
pub opcodes: Vec<Opcode>,
pub opcodes: Vec<Opcode<F>>,
pub expression_width: ExpressionWidth,

/// The set of private inputs to the circuit.
Expand All @@ -67,7 +67,7 @@ pub struct Circuit {
// Note: This should be a BTreeMap, but serde-reflect is creating invalid
// c++ code at the moment when it is, due to OpcodeLocation needing a comparison
// implementation which is never generated.
pub assert_messages: Vec<(OpcodeLocation, AssertionPayload)>,
pub assert_messages: Vec<(OpcodeLocation, AssertionPayload<F>)>,

/// States whether the backend should use a SNARK recursion friendly prover.
/// If implemented by a backend, this means that proofs generated with this circuit
Expand All @@ -76,15 +76,15 @@ pub struct Circuit {
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExpressionOrMemory {
Expression(Expression),
pub enum ExpressionOrMemory<F> {
Expression(Expression<F>),
Memory(BlockId),
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum AssertionPayload {
pub enum AssertionPayload<F> {
StaticString(String),
Dynamic(/* error_selector */ u64, Vec<ExpressionOrMemory>),
Dynamic(/* error_selector */ u64, Vec<ExpressionOrMemory<F>>),
}

#[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)]
Expand Down Expand Up @@ -127,15 +127,15 @@ impl<'de> Deserialize<'de> for ErrorSelector {
pub const STRING_ERROR_SELECTOR: ErrorSelector = ErrorSelector(0);

#[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub struct RawAssertionPayload {
pub struct RawAssertionPayload<F> {
pub selector: ErrorSelector,
pub data: Vec<FieldElement>,
pub data: Vec<F>,
}

#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ResolvedAssertionPayload {
pub enum ResolvedAssertionPayload<F> {
String(String),
Raw(RawAssertionPayload),
Raw(RawAssertionPayload<F>),
}

#[derive(Debug, Copy, Clone)]
Expand Down Expand Up @@ -204,7 +204,7 @@ impl FromStr for OpcodeLocation {
}
}

impl Circuit {
impl<F: AcirField> Circuit<F> {
pub fn num_vars(&self) -> u32 {
self.current_witness_index + 1
}
Expand All @@ -223,7 +223,7 @@ impl Circuit {
}
}

impl Program {
impl<F: Serialize> Program<F> {
fn write<W: std::io::Write>(&self, writer: W) -> std::io::Result<()> {
let buf = bincode::serialize(self).unwrap();
let mut encoder = flate2::write::GzEncoder::new(writer, Compression::default());
Expand All @@ -232,36 +232,38 @@ impl Program {
Ok(())
}

fn read<R: std::io::Read>(reader: R) -> std::io::Result<Self> {
let mut gz_decoder = flate2::read::GzDecoder::new(reader);
let mut buf_d = Vec::new();
gz_decoder.read_to_end(&mut buf_d)?;
bincode::deserialize(&buf_d)
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
}

pub fn serialize_program(program: &Program) -> Vec<u8> {
pub fn serialize_program(program: &Self) -> Vec<u8> {
let mut program_bytes: Vec<u8> = Vec::new();
program.write(&mut program_bytes).expect("expected circuit to be serializable");
program_bytes
}

pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
Program::read(serialized_circuit)
}

// Serialize and base64 encode program
pub fn serialize_program_base64<S>(program: &Program, s: S) -> Result<S::Ok, S::Error>
pub fn serialize_program_base64<S>(program: &Self, s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let program_bytes = Program::serialize_program(program);
let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes);
s.serialize_str(&encoded_b64)
}
}

impl<F: for<'a> Deserialize<'a>> Program<F> {
fn read<R: std::io::Read>(reader: R) -> std::io::Result<Self> {
let mut gz_decoder = flate2::read::GzDecoder::new(reader);
let mut buf_d = Vec::new();
gz_decoder.read_to_end(&mut buf_d)?;
bincode::deserialize(&buf_d)
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
}

pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result<Self> {
Program::read(serialized_circuit)
}

// Deserialize and base64 decode program
pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Program, D::Error>
pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Expand All @@ -274,7 +276,7 @@ impl Program {
}
}

impl std::fmt::Display for Circuit {
impl<F: AcirField> std::fmt::Display for Circuit<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "current witness index : {}", self.current_witness_index)?;

Expand Down Expand Up @@ -313,13 +315,13 @@ impl std::fmt::Display for Circuit {
}
}

impl std::fmt::Debug for Circuit {
impl<F: AcirField> std::fmt::Debug for Circuit<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
}

impl std::fmt::Display for Program {
impl<F: AcirField> std::fmt::Display for Program<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
for (func_index, function) in self.functions.iter().enumerate() {
writeln!(f, "func {}", func_index)?;
Expand All @@ -333,7 +335,7 @@ impl std::fmt::Display for Program {
}
}

impl std::fmt::Debug for Program {
impl<F: AcirField> std::fmt::Debug for Program<F> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(self, f)
}
Expand Down Expand Up @@ -365,21 +367,22 @@ mod tests {
circuit::{ExpressionWidth, Program},
native_types::Witness,
};
use acir_field::FieldElement;
use acir_field::{AcirField, FieldElement};
use serde::{Deserialize, Serialize};

fn and_opcode() -> Opcode {
fn and_opcode<F: AcirField>() -> Opcode<F> {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND {
lhs: FunctionInput { witness: Witness(1), num_bits: 4 },
rhs: FunctionInput { witness: Witness(2), num_bits: 4 },
output: Witness(3),
})
}
fn range_opcode() -> Opcode {
fn range_opcode<F: AcirField>() -> Opcode<F> {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput { witness: Witness(1), num_bits: 8 },
})
}
fn keccakf1600_opcode() -> Opcode {
fn keccakf1600_opcode<F: AcirField>() -> Opcode<F> {
let inputs: Box<[FunctionInput; 25]> = Box::new(std::array::from_fn(|i| FunctionInput {
witness: Witness(i as u32 + 1),
num_bits: 8,
Expand All @@ -388,7 +391,7 @@ mod tests {

Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs })
}
fn schnorr_verify_opcode() -> Opcode {
fn schnorr_verify_opcode<F: AcirField>() -> Opcode<F> {
let public_key_x =
FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() };
let public_key_y =
Expand All @@ -413,7 +416,7 @@ mod tests {
let circuit = Circuit {
current_witness_index: 5,
expression_width: ExpressionWidth::Unbounded,
opcodes: vec![and_opcode(), range_opcode(), schnorr_verify_opcode()],
opcodes: vec![and_opcode::<FieldElement>(), range_opcode(), schnorr_verify_opcode()],
private_parameters: BTreeSet::new(),
public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])),
return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])),
Expand All @@ -422,7 +425,9 @@ mod tests {
};
let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() };

fn read_write(program: Program) -> (Program, Program) {
fn read_write<F: AcirField + Serialize + for<'a> Deserialize<'a>>(
program: Program<F>,
) -> (Program<F>, Program<F>) {
let bytes = Program::serialize_program(&program);
let got_program = Program::deserialize_program(&bytes).unwrap();
(program, got_program)
Expand Down Expand Up @@ -475,7 +480,8 @@ mod tests {
encoder.write_all(bad_circuit).unwrap();
encoder.finish().unwrap();

let deserialization_result = Program::deserialize_program(&zipped_bad_circuit);
let deserialization_result: Result<Program<FieldElement>, _> =
Program::deserialize_program(&zipped_bad_circuit);
assert!(deserialization_result.is_err());
}
}

0 comments on commit 70f374c

Please sign in to comment.