From 70f374c06642962d8f2b95b80f8c938fcf7761d7 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Tue, 28 May 2024 16:55:30 +0100 Subject: [PATCH] feat: make ACVM generic across fields (#5114) # Description Step towards #5055 ## Summary\* This PR starts us down the road towards removing the compile-time flags for specifying which field is being used by replacing the usage of the `FieldElement` type with an `AcirField` trait. This trait is mostly all of the external methods from `FieldElement` dumped into it for now but we can break it down in future PRs (XOR and AND are looking ready for culling) I've taken this route rather than just using `generic_ark::FieldElement` as we'd need to have trait bounds for `PrimeField` anyway so we might as well have our own trait. I've had to delete a couple of trait implementations which fall afoul of the orphan rule now it's being implemented for a generic type (e.g. we need to use `MemoryValue::new_field` explicitly now) but nothing too bad. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --- Cargo.lock | 2 + acvm-repo/acir/Cargo.toml | 5 +- acvm-repo/acir/benches/serialization.rs | 8 +- acvm-repo/acir/src/circuit/brillig.rs | 10 +- acvm-repo/acir/src/circuit/directives.rs | 4 +- acvm-repo/acir/src/circuit/mod.rs | 94 +++--- acvm-repo/acir/src/circuit/opcodes.rs | 21 +- .../opcodes/black_box_function_call.rs | 10 +- .../src/circuit/opcodes/memory_operation.rs | 15 +- acvm-repo/acir/src/lib.rs | 23 +- .../acir/src/native_types/expression/mod.rs | 135 ++++---- .../src/native_types/expression/operators.rs | 212 ++++++------ .../src/native_types/expression/ordering.rs | 10 +- acvm-repo/acir/src/native_types/witness.rs | 13 - .../acir/src/native_types/witness_map.rs | 37 +-- .../acir/src/native_types/witness_stack.rs | 26 +- .../acir/tests/test_program_serialization.rs | 37 ++- acvm-repo/acir_field/Cargo.toml | 15 +- acvm-repo/acir_field/src/generic_ark.rs | 312 +++++++++++------- acvm-repo/acir_field/src/lib.rs | 17 +- acvm-repo/acvm/Cargo.toml | 2 +- acvm-repo/acvm/src/compiler/mod.rs | 17 +- .../acvm/src/compiler/optimizers/general.rs | 20 +- acvm-repo/acvm/src/compiler/optimizers/mod.rs | 11 +- .../compiler/optimizers/redundant_range.rs | 21 +- .../src/compiler/optimizers/unused_memory.rs | 12 +- .../acvm/src/compiler/transformers/csat.rs | 244 +++++++------- .../acvm/src/compiler/transformers/mod.rs | 18 +- acvm-repo/acvm/src/lib.rs | 2 +- acvm-repo/acvm/src/pwg/arithmetic.rs | 139 ++++---- acvm-repo/acvm/src/pwg/blackbox/aes128.rs | 10 +- acvm-repo/acvm/src/pwg/blackbox/bigint.rs | 20 +- .../src/pwg/blackbox/embedded_curve_ops.rs | 17 +- acvm-repo/acvm/src/pwg/blackbox/hash.rs | 48 ++- acvm-repo/acvm/src/pwg/blackbox/logic.rs | 22 +- acvm-repo/acvm/src/pwg/blackbox/mod.rs | 18 +- acvm-repo/acvm/src/pwg/blackbox/pedersen.rs | 17 +- acvm-repo/acvm/src/pwg/blackbox/range.rs | 8 +- .../acvm/src/pwg/blackbox/signature/ecdsa.rs | 18 +- .../src/pwg/blackbox/signature/schnorr.rs | 16 +- acvm-repo/acvm/src/pwg/blackbox/utils.rs | 14 +- acvm-repo/acvm/src/pwg/brillig.rs | 68 ++-- acvm-repo/acvm/src/pwg/directives/mod.rs | 14 +- acvm-repo/acvm/src/pwg/memory_op.rs | 35 +- acvm-repo/acvm/src/pwg/mod.rs | 154 +++++---- acvm-repo/acvm/tests/solver.rs | 43 ++- acvm-repo/acvm_js/src/black_box_solvers.rs | 2 +- acvm-repo/acvm_js/src/execute.rs | 37 ++- acvm-repo/acvm_js/src/foreign_call/inputs.rs | 4 +- acvm-repo/acvm_js/src/foreign_call/mod.rs | 6 +- acvm-repo/acvm_js/src/foreign_call/outputs.rs | 12 +- acvm-repo/acvm_js/src/js_execution_error.rs | 7 +- acvm-repo/acvm_js/src/js_witness_map.rs | 13 +- acvm-repo/acvm_js/src/js_witness_stack.rs | 8 +- acvm-repo/acvm_js/src/public_witness.rs | 19 +- acvm-repo/blackbox_solver/Cargo.toml | 1 - .../src/curve_specific_solver.rs | 82 ++--- .../benches/criterion.rs | 2 +- .../src/embedded_curve_ops.rs | 7 +- acvm-repo/bn254_blackbox_solver/src/lib.rs | 2 +- .../src/pedersen/commitment.rs | 2 +- .../src/pedersen/hash.rs | 2 +- .../bn254_blackbox_solver/src/poseidon2.rs | 4 +- .../bn254_blackbox_solver/src/schnorr/mod.rs | 2 +- acvm-repo/brillig/Cargo.toml | 1 - acvm-repo/brillig/src/foreign_call.rs | 40 +-- acvm-repo/brillig/src/opcodes.rs | 6 +- acvm-repo/brillig_vm/Cargo.toml | 1 - acvm-repo/brillig_vm/src/arithmetic.rs | 37 ++- acvm-repo/brillig_vm/src/black_box.rs | 95 ++++-- acvm-repo/brillig_vm/src/lib.rs | 140 ++++---- acvm-repo/brillig_vm/src/memory.rs | 124 +++---- aztec_macros/Cargo.toml | 1 + aztec_macros/src/transforms/storage.rs | 1 + aztec_macros/src/utils/hir_utils.rs | 1 + compiler/noirc_driver/src/contract.rs | 4 +- compiler/noirc_driver/src/program.rs | 4 +- .../src/brillig/brillig_gen/brillig_block.rs | 2 +- .../brillig/brillig_gen/brillig_directive.rs | 1 + .../noirc_evaluator/src/brillig/brillig_ir.rs | 15 +- .../src/brillig/brillig_ir/artifact.rs | 14 +- .../brillig/brillig_ir/brillig_variable.rs | 1 + .../brillig/brillig_ir/codegen_intrinsic.rs | 1 + .../src/brillig/brillig_ir/entry_point.rs | 2 +- .../src/brillig/brillig_ir/instructions.rs | 5 +- compiler/noirc_evaluator/src/ssa.rs | 19 +- .../src/ssa/acir_gen/acir_ir/acir_variable.rs | 62 ++-- .../src/ssa/acir_gen/acir_ir/big_int.rs | 2 +- .../ssa/acir_gen/acir_ir/generated_acir.rs | 36 +- .../noirc_evaluator/src/ssa/acir_gen/mod.rs | 7 +- .../src/ssa/function_builder/mod.rs | 2 +- compiler/noirc_evaluator/src/ssa/ir/dfg.rs | 2 +- .../noirc_evaluator/src/ssa/ir/instruction.rs | 1 + .../src/ssa/ir/instruction/binary.rs | 2 +- .../src/ssa/ir/instruction/call.rs | 2 +- .../src/ssa/ir/instruction/cast.rs | 2 +- .../src/ssa/ir/instruction/constrain.rs | 2 +- .../noirc_evaluator/src/ssa/ir/printer.rs | 1 + compiler/noirc_evaluator/src/ssa/ir/types.rs | 2 +- .../src/ssa/opt/constant_folding.rs | 3 +- .../src/ssa/opt/flatten_cfg.rs | 4 +- .../ssa/opt/flatten_cfg/capacity_tracker.rs | 2 +- .../src/ssa/opt/flatten_cfg/value_merger.rs | 2 +- .../noirc_evaluator/src/ssa/opt/inlining.rs | 3 +- .../noirc_evaluator/src/ssa/opt/mem2reg.rs | 2 +- .../src/ssa/opt/remove_bit_shifts.rs | 2 +- .../src/ssa/opt/remove_enable_side_effects.rs | 2 +- .../src/ssa/opt/remove_if_else.rs | 2 +- .../src/ssa/opt/simplify_cfg.rs | 3 + .../noirc_evaluator/src/ssa/opt/unrolling.rs | 2 + .../src/ssa/ssa_gen/context.rs | 2 +- compiler/noirc_frontend/src/ast/expression.rs | 2 +- compiler/noirc_frontend/src/ast/mod.rs | 1 + compiler/noirc_frontend/src/ast/statement.rs | 1 + .../noirc_frontend/src/elaborator/types.rs | 1 + .../noirc_frontend/src/hir/comptime/errors.rs | 2 +- .../src/hir/comptime/interpreter.rs | 2 +- .../src/hir/resolution/resolver.rs | 2 + .../noirc_frontend/src/hir/type_check/stmt.rs | 1 + compiler/noirc_frontend/src/lexer/token.rs | 2 +- .../src/monomorphization/debug.rs | 1 + .../src/monomorphization/mod.rs | 2 +- compiler/noirc_printable_type/src/lib.rs | 14 +- compiler/wasm/Cargo.toml | 2 +- tooling/acvm_cli/src/cli/execute_cmd.rs | 7 +- tooling/acvm_cli/src/cli/fs/inputs.rs | 6 +- tooling/acvm_cli/src/cli/fs/witness.rs | 12 +- tooling/debugger/src/context.rs | 36 +- tooling/debugger/src/dap.rs | 16 +- tooling/debugger/src/foreign_calls.rs | 6 +- tooling/debugger/src/lib.rs | 16 +- tooling/debugger/src/repl.rs | 30 +- tooling/lsp/src/lib.rs | 12 +- tooling/lsp/src/solver.rs | 4 +- tooling/nargo/src/artifacts/contract.rs | 4 +- tooling/nargo/src/artifacts/program.rs | 3 +- tooling/nargo/src/errors.rs | 5 +- tooling/nargo/src/ops/execute.rs | 52 +-- tooling/nargo/src/ops/foreign_calls.rs | 52 +-- tooling/nargo/src/ops/optimize.rs | 7 +- tooling/nargo/src/ops/test.rs | 6 +- tooling/nargo/src/ops/transform.rs | 9 +- tooling/nargo_cli/Cargo.toml | 2 +- tooling/nargo_cli/src/cli/dap_cmd.rs | 3 +- tooling/nargo_cli/src/cli/debug_cmd.rs | 5 +- tooling/nargo_cli/src/cli/execute_cmd.rs | 5 +- tooling/nargo_cli/src/cli/fs/witness.rs | 4 +- tooling/nargo_cli/src/cli/test_cmd.rs | 6 +- tooling/noirc_abi/src/input_parser/json.rs | 2 +- tooling/noirc_abi/src/input_parser/mod.rs | 6 +- tooling/noirc_abi/src/input_parser/toml.rs | 2 +- tooling/noirc_abi/src/lib.rs | 8 +- tooling/noirc_abi_wasm/Cargo.toml | 2 +- tooling/noirc_abi_wasm/src/js_witness_map.rs | 10 +- tooling/noirc_abi_wasm/src/lib.rs | 15 +- 155 files changed, 1760 insertions(+), 1519 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bfc012d23f..919bdc4874 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -44,6 +44,7 @@ version = "0.46.0" dependencies = [ "acir", "acvm_blackbox_solver", + "ark-bls12-381", "brillig_vm", "indexmap 1.9.3", "num-bigint", @@ -445,6 +446,7 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" name = "aztec_macros" version = "0.30.0" dependencies = [ + "acvm", "convert_case 0.6.0", "iter-extended", "noirc_errors", diff --git a/acvm-repo/acir/Cargo.toml b/acvm-repo/acir/Cargo.toml index 32a9bbe830..101ce7a0f3 100644 --- a/acvm-repo/acir/Cargo.toml +++ b/acvm-repo/acir/Cargo.toml @@ -33,9 +33,8 @@ criterion.workspace = true pprof.workspace = true [features] -default = ["bn254"] -bn254 = ["acir_field/bn254", "brillig/bn254"] -bls12_381 = ["acir_field/bls12_381", "brillig/bls12_381"] +bn254 = ["acir_field/bn254"] +bls12_381 = ["acir_field/bls12_381"] [[bench]] name = "serialization" diff --git a/acvm-repo/acir/benches/serialization.rs b/acvm-repo/acir/benches/serialization.rs index a7f32b4a4c..792200c891 100644 --- a/acvm-repo/acir/benches/serialization.rs +++ b/acvm-repo/acir/benches/serialization.rs @@ -11,8 +11,8 @@ use pprof::criterion::{Output, PProfProfiler}; const SIZES: [usize; 9] = [10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000]; -fn sample_program(num_opcodes: usize) -> Program { - let assert_zero_opcodes: Vec = (0..num_opcodes) +fn sample_program(num_opcodes: usize) -> Program { + let assert_zero_opcodes: Vec> = (0..num_opcodes) .map(|i| { Opcode::AssertZero(Expression { mul_terms: vec![( @@ -83,7 +83,7 @@ fn bench_deserialization(c: &mut Criterion) { BenchmarkId::from_parameter(size), &serialized_program, |b, program| { - b.iter(|| Program::deserialize_program(program)); + b.iter(|| Program::::deserialize_program(program)); }, ); } @@ -107,7 +107,7 @@ fn bench_deserialization(c: &mut Criterion) { |b, program| { b.iter(|| { let mut deserializer = serde_json::Deserializer::from_slice(program); - Program::deserialize_program_base64(&mut deserializer) + Program::::deserialize_program_base64(&mut deserializer) }); }, ); diff --git a/acvm-repo/acir/src/circuit/brillig.rs b/acvm-repo/acir/src/circuit/brillig.rs index ecf6f7a976..ee25d05afb 100644 --- a/acvm-repo/acir/src/circuit/brillig.rs +++ b/acvm-repo/acir/src/circuit/brillig.rs @@ -6,9 +6,9 @@ use serde::{Deserialize, Serialize}; /// Inputs for the Brillig VM. These are the initial inputs /// that the Brillig VM will use to start. #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] -pub enum BrilligInputs { - Single(Expression), - Array(Vec), +pub enum BrilligInputs { + Single(Expression), + Array(Vec>), MemoryArray(BlockId), } @@ -24,6 +24,6 @@ pub enum BrilligOutputs { /// a full Brillig function to be executed by the Brillig VM. /// This is stored separately on a program and accessed through a [BrilligPointer]. #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default, Debug)] -pub struct BrilligBytecode { - pub bytecode: Vec, +pub struct BrilligBytecode { + pub bytecode: Vec>, } diff --git a/acvm-repo/acir/src/circuit/directives.rs b/acvm-repo/acir/src/circuit/directives.rs index 099d063439..3bc6628859 100644 --- a/acvm-repo/acir/src/circuit/directives.rs +++ b/acvm-repo/acir/src/circuit/directives.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; /// Directives do not apply any constraints. /// You can think of them as opcodes that allow one to use non-determinism /// In the future, this can be replaced with asm non-determinism blocks -pub enum Directive { +pub enum Directive { //decomposition of a: a=\sum b[i]*radix^i where b is an array of witnesses < radix in little endian form - ToLeRadix { a: Expression, b: Vec, radix: u32 }, + ToLeRadix { a: Expression, b: Vec, radix: u32 }, } diff --git a/acvm-repo/acir/src/circuit/mod.rs b/acvm-repo/acir/src/circuit/mod.rs index 6a26a45d88..7632afda42 100644 --- a/acvm-repo/acir/src/circuit/mod.rs +++ b/acvm-repo/acir/src/circuit/mod.rs @@ -4,7 +4,7 @@ pub mod directives; pub mod opcodes; use crate::native_types::{Expression, Witness}; -use acir_field::FieldElement; +use acir_field::AcirField; pub use opcodes::Opcode; use thiserror::Error; @@ -38,17 +38,17 @@ pub enum ExpressionWidth { /// A program represented by multiple ACIR circuits. The execution trace of these /// circuits is dictated by construction of the [crate::native_types::WitnessStack]. #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub struct Program { - pub functions: Vec, - pub unconstrained_functions: Vec, +pub struct Program { + pub functions: Vec>, + pub unconstrained_functions: Vec>, } #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Default)] -pub struct Circuit { +pub struct Circuit { // current_witness_index is the highest witness index in the circuit. The next witness to be added to this circuit // will take on this value. (The value is cached here as an optimization.) pub current_witness_index: u32, - pub opcodes: Vec, + pub opcodes: Vec>, pub expression_width: ExpressionWidth, /// The set of private inputs to the circuit. @@ -67,7 +67,7 @@ pub struct Circuit { // Note: This should be a BTreeMap, but serde-reflect is creating invalid // c++ code at the moment when it is, due to OpcodeLocation needing a comparison // implementation which is never generated. - pub assert_messages: Vec<(OpcodeLocation, AssertionPayload)>, + pub assert_messages: Vec<(OpcodeLocation, AssertionPayload)>, /// States whether the backend should use a SNARK recursion friendly prover. /// If implemented by a backend, this means that proofs generated with this circuit @@ -76,15 +76,15 @@ pub struct Circuit { } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum ExpressionOrMemory { - Expression(Expression), +pub enum ExpressionOrMemory { + Expression(Expression), Memory(BlockId), } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum AssertionPayload { +pub enum AssertionPayload { StaticString(String), - Dynamic(/* error_selector */ u64, Vec), + Dynamic(/* error_selector */ u64, Vec>), } #[derive(Debug, Copy, PartialEq, Eq, Hash, Clone, PartialOrd, Ord)] @@ -127,15 +127,15 @@ impl<'de> Deserialize<'de> for ErrorSelector { pub const STRING_ERROR_SELECTOR: ErrorSelector = ErrorSelector(0); #[derive(Clone, PartialEq, Eq, Debug, Serialize, Deserialize)] -pub struct RawAssertionPayload { +pub struct RawAssertionPayload { pub selector: ErrorSelector, - pub data: Vec, + pub data: Vec, } #[derive(Clone, PartialEq, Eq, Debug)] -pub enum ResolvedAssertionPayload { +pub enum ResolvedAssertionPayload { String(String), - Raw(RawAssertionPayload), + Raw(RawAssertionPayload), } #[derive(Debug, Copy, Clone)] @@ -204,7 +204,7 @@ impl FromStr for OpcodeLocation { } } -impl Circuit { +impl Circuit { pub fn num_vars(&self) -> u32 { self.current_witness_index + 1 } @@ -223,7 +223,7 @@ impl Circuit { } } -impl Program { +impl Program { fn write(&self, writer: W) -> std::io::Result<()> { let buf = bincode::serialize(self).unwrap(); let mut encoder = flate2::write::GzEncoder::new(writer, Compression::default()); @@ -232,26 +232,14 @@ impl Program { Ok(()) } - fn read(reader: R) -> std::io::Result { - let mut gz_decoder = flate2::read::GzDecoder::new(reader); - let mut buf_d = Vec::new(); - gz_decoder.read_to_end(&mut buf_d)?; - bincode::deserialize(&buf_d) - .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) - } - - pub fn serialize_program(program: &Program) -> Vec { + pub fn serialize_program(program: &Self) -> Vec { let mut program_bytes: Vec = Vec::new(); program.write(&mut program_bytes).expect("expected circuit to be serializable"); program_bytes } - pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result { - Program::read(serialized_circuit) - } - // Serialize and base64 encode program - pub fn serialize_program_base64(program: &Program, s: S) -> Result + pub fn serialize_program_base64(program: &Self, s: S) -> Result where S: Serializer, { @@ -259,9 +247,23 @@ impl Program { let encoded_b64 = base64::engine::general_purpose::STANDARD.encode(program_bytes); s.serialize_str(&encoded_b64) } +} + +impl Deserialize<'a>> Program { + fn read(reader: R) -> std::io::Result { + let mut gz_decoder = flate2::read::GzDecoder::new(reader); + let mut buf_d = Vec::new(); + gz_decoder.read_to_end(&mut buf_d)?; + bincode::deserialize(&buf_d) + .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) + } + + pub fn deserialize_program(serialized_circuit: &[u8]) -> std::io::Result { + Program::read(serialized_circuit) + } // Deserialize and base64 decode program - pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result + pub fn deserialize_program_base64<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, { @@ -274,7 +276,7 @@ impl Program { } } -impl std::fmt::Display for Circuit { +impl std::fmt::Display for Circuit { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { writeln!(f, "current witness index : {}", self.current_witness_index)?; @@ -313,13 +315,13 @@ impl std::fmt::Display for Circuit { } } -impl std::fmt::Debug for Circuit { +impl std::fmt::Debug for Circuit { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } } -impl std::fmt::Display for Program { +impl std::fmt::Display for Program { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for (func_index, function) in self.functions.iter().enumerate() { writeln!(f, "func {}", func_index)?; @@ -333,7 +335,7 @@ impl std::fmt::Display for Program { } } -impl std::fmt::Debug for Program { +impl std::fmt::Debug for Program { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } @@ -365,21 +367,22 @@ mod tests { circuit::{ExpressionWidth, Program}, native_types::Witness, }; - use acir_field::FieldElement; + use acir_field::{AcirField, FieldElement}; + use serde::{Deserialize, Serialize}; - fn and_opcode() -> Opcode { + fn and_opcode() -> Opcode { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND { lhs: FunctionInput { witness: Witness(1), num_bits: 4 }, rhs: FunctionInput { witness: Witness(2), num_bits: 4 }, output: Witness(3), }) } - fn range_opcode() -> Opcode { + fn range_opcode() -> Opcode { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input: FunctionInput { witness: Witness(1), num_bits: 8 }, }) } - fn keccakf1600_opcode() -> Opcode { + fn keccakf1600_opcode() -> Opcode { let inputs: Box<[FunctionInput; 25]> = Box::new(std::array::from_fn(|i| FunctionInput { witness: Witness(i as u32 + 1), num_bits: 8, @@ -388,7 +391,7 @@ mod tests { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs }) } - fn schnorr_verify_opcode() -> Opcode { + fn schnorr_verify_opcode() -> Opcode { let public_key_x = FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() }; let public_key_y = @@ -413,7 +416,7 @@ mod tests { let circuit = Circuit { current_witness_index: 5, expression_width: ExpressionWidth::Unbounded, - opcodes: vec![and_opcode(), range_opcode(), schnorr_verify_opcode()], + opcodes: vec![and_opcode::(), range_opcode(), schnorr_verify_opcode()], private_parameters: BTreeSet::new(), public_parameters: PublicInputs(BTreeSet::from_iter(vec![Witness(2), Witness(12)])), return_values: PublicInputs(BTreeSet::from_iter(vec![Witness(4), Witness(12)])), @@ -422,7 +425,9 @@ mod tests { }; let program = Program { functions: vec![circuit], unconstrained_functions: Vec::new() }; - fn read_write(program: Program) -> (Program, Program) { + fn read_write Deserialize<'a>>( + program: Program, + ) -> (Program, Program) { let bytes = Program::serialize_program(&program); let got_program = Program::deserialize_program(&bytes).unwrap(); (program, got_program) @@ -475,7 +480,8 @@ mod tests { encoder.write_all(bad_circuit).unwrap(); encoder.finish().unwrap(); - let deserialization_result = Program::deserialize_program(&zipped_bad_circuit); + let deserialization_result: Result, _> = + Program::deserialize_program(&zipped_bad_circuit); assert!(deserialization_result.is_err()); } } diff --git a/acvm-repo/acir/src/circuit/opcodes.rs b/acvm-repo/acir/src/circuit/opcodes.rs index 6043196dff..20c6903dc5 100644 --- a/acvm-repo/acir/src/circuit/opcodes.rs +++ b/acvm-repo/acir/src/circuit/opcodes.rs @@ -3,6 +3,7 @@ use super::{ directives::Directive, }; use crate::native_types::{Expression, Witness}; +use acir_field::AcirField; use serde::{Deserialize, Serialize}; mod black_box_function_call; @@ -26,19 +27,19 @@ impl BlockType { #[allow(clippy::large_enum_variant)] #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum Opcode { - AssertZero(Expression), +pub enum Opcode { + AssertZero(Expression), /// Calls to "gadgets" which rely on backends implementing support for specialized constraints. /// /// Often used for exposing more efficient implementations of SNARK-unfriendly computations. BlackBoxFuncCall(BlackBoxFuncCall), - Directive(Directive), + Directive(Directive), /// Atomic operation on a block of memory MemoryOp { block_id: BlockId, - op: MemOp, + op: MemOp, /// Predicate of the memory operation - indicates if it should be skipped - predicate: Option, + predicate: Option>, }, MemoryInit { block_id: BlockId, @@ -51,11 +52,11 @@ pub enum Opcode { /// to fetch the appropriate Brillig bytecode from this id. id: u32, /// Inputs to the function call - inputs: Vec, + inputs: Vec>, /// Outputs to the function call outputs: Vec, /// Predicate of the Brillig execution - indicates if it should be skipped - predicate: Option, + predicate: Option>, }, /// Calls to functions represented as a separate circuit. A call opcode allows us /// to build a call stack when executing the outer-most circuit. @@ -68,11 +69,11 @@ pub enum Opcode { /// Outputs of the function call outputs: Vec, /// Predicate of the circuit execution - indicates if it should be skipped - predicate: Option, + predicate: Option>, }, } -impl std::fmt::Display for Opcode { +impl std::fmt::Display for Opcode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Opcode::AssertZero(expr) => { @@ -147,7 +148,7 @@ impl std::fmt::Display for Opcode { } } -impl std::fmt::Debug for Opcode { +impl std::fmt::Debug for Opcode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } diff --git a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs index b0e77b15c2..362e9ba593 100644 --- a/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs +++ b/acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs @@ -481,11 +481,11 @@ where mod tests { use crate::{circuit::Opcode, native_types::Witness}; - use acir_field::FieldElement; + use acir_field::{AcirField, FieldElement}; use super::{BlackBoxFuncCall, FunctionInput}; - fn keccakf1600_opcode() -> Opcode { + fn keccakf1600_opcode() -> Opcode { let inputs: Box<[FunctionInput; 25]> = Box::new(std::array::from_fn(|i| FunctionInput { witness: Witness(i as u32 + 1), num_bits: 8, @@ -494,7 +494,7 @@ mod tests { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::Keccakf1600 { inputs, outputs }) } - fn schnorr_verify_opcode() -> Opcode { + fn schnorr_verify_opcode() -> Opcode { let public_key_x = FunctionInput { witness: Witness(1), num_bits: FieldElement::max_num_bits() }; let public_key_y = @@ -516,7 +516,7 @@ mod tests { #[test] fn keccakf1600_serialization_roundtrip() { - let opcode = keccakf1600_opcode(); + let opcode = keccakf1600_opcode::(); let buf = bincode::serialize(&opcode).unwrap(); let recovered_opcode = bincode::deserialize(&buf).unwrap(); assert_eq!(opcode, recovered_opcode); @@ -524,7 +524,7 @@ mod tests { #[test] fn schnorr_serialization_roundtrip() { - let opcode = schnorr_verify_opcode(); + let opcode = schnorr_verify_opcode::(); let buf = bincode::serialize(&opcode).unwrap(); let recovered_opcode = bincode::deserialize(&buf).unwrap(); assert_eq!(opcode, recovered_opcode); diff --git a/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs b/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs index 0e94c0f051..2147e7430b 100644 --- a/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs +++ b/acvm-repo/acir/src/circuit/opcodes/memory_operation.rs @@ -1,4 +1,5 @@ use crate::native_types::{Expression, Witness}; +use acir_field::AcirField; use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash, Copy, Default)] @@ -7,22 +8,22 @@ pub struct BlockId(pub u32); /// Operation on a block of memory /// We can either write or read at an index in memory #[derive(Clone, PartialEq, Eq, Serialize, Deserialize, Debug)] -pub struct MemOp { +pub struct MemOp { /// Can be 0 (read) or 1 (write) - pub operation: Expression, - pub index: Expression, - pub value: Expression, + pub operation: Expression, + pub index: Expression, + pub value: Expression, } -impl MemOp { +impl MemOp { /// Creates a `MemOp` which reads from memory at `index` and inserts the read value /// into the [`WitnessMap`][crate::native_types::WitnessMap] at `witness` - pub fn read_at_mem_index(index: Expression, witness: Witness) -> Self { + pub fn read_at_mem_index(index: Expression, witness: Witness) -> Self { MemOp { operation: Expression::zero(), index, value: witness.into() } } /// Creates a `MemOp` which writes the [`Expression`] `value` into memory at `index`. - pub fn write_to_mem_index(index: Expression, value: Expression) -> Self { + pub fn write_to_mem_index(index: Expression, value: Expression) -> Self { MemOp { operation: Expression::one(), index, value } } } diff --git a/acvm-repo/acir/src/lib.rs b/acvm-repo/acir/src/lib.rs index f60f1b46b6..f064cfaca0 100644 --- a/acvm-repo/acir/src/lib.rs +++ b/acvm-repo/acir/src/lib.rs @@ -9,7 +9,7 @@ pub mod circuit; pub mod native_types; pub use acir_field; -pub use acir_field::FieldElement; +pub use acir_field::{AcirField, FieldElement}; pub use brillig; pub use circuit::black_box_functions::BlackBoxFunc; @@ -31,6 +31,7 @@ mod reflection { path::{Path, PathBuf}, }; + use acir_field::FieldElement; use brillig::{ BinaryFieldOp, BinaryIntOp, BlackBoxOp, HeapValueType, Opcode as BrilligOpcode, ValueOrArray, @@ -61,23 +62,23 @@ mod reflection { let mut tracer = Tracer::new(TracerConfig::default()); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); + tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::().unwrap(); tracer.trace_simple_type::().unwrap(); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::().unwrap(); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); tracer.trace_simple_type::().unwrap(); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); + tracer.trace_simple_type::>().unwrap(); let registry = tracer.registry().unwrap(); @@ -110,8 +111,8 @@ mod reflection { let mut tracer = Tracer::new(TracerConfig::default()); tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); - tracer.trace_simple_type::().unwrap(); + tracer.trace_simple_type::>().unwrap(); + tracer.trace_simple_type::>().unwrap(); let registry = tracer.registry().unwrap(); diff --git a/acvm-repo/acir/src/native_types/expression/mod.rs b/acvm-repo/acir/src/native_types/expression/mod.rs index 402aa3eb3a..b34862429e 100644 --- a/acvm-repo/acir/src/native_types/expression/mod.rs +++ b/acvm-repo/acir/src/native_types/expression/mod.rs @@ -1,8 +1,7 @@ use crate::native_types::Witness; -use acir_field::FieldElement; +use acir_field::AcirField; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; - mod operators; mod ordering; @@ -14,30 +13,26 @@ mod ordering; // In the multiplication polynomial // XXX: If we allow the degree of the quotient polynomial to be arbitrary, then we will need a vector of wire values #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] -pub struct Expression { +pub struct Expression { // To avoid having to create intermediate variables pre-optimization // We collect all of the multiplication terms in the assert-zero opcode // A multiplication term if of the form q_M * wL * wR // Hence this vector represents the following sum: q_M1 * wL1 * wR1 + q_M2 * wL2 * wR2 + .. + - pub mul_terms: Vec<(FieldElement, Witness, Witness)>, + pub mul_terms: Vec<(F, Witness, Witness)>, - pub linear_combinations: Vec<(FieldElement, Witness)>, + pub linear_combinations: Vec<(F, Witness)>, // TODO: rename q_c to `constant` moreover q_X is not clear to those who // TODO are not familiar with PLONK - pub q_c: FieldElement, + pub q_c: F, } -impl Default for Expression { - fn default() -> Expression { - Expression { - mul_terms: Vec::new(), - linear_combinations: Vec::new(), - q_c: FieldElement::zero(), - } +impl Default for Expression { + fn default() -> Self { + Expression { mul_terms: Vec::new(), linear_combinations: Vec::new(), q_c: F::zero() } } } -impl std::fmt::Display for Expression { +impl std::fmt::Display for Expression { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { if let Some(witness) = self.to_witness() { write!(f, "x{}", witness.witness_index()) @@ -47,7 +42,7 @@ impl std::fmt::Display for Expression { } } -impl Expression { +impl Expression { // TODO: possibly remove, and move to noir repo. pub const fn can_defer_constraint(&self) -> bool { false @@ -58,30 +53,25 @@ impl Expression { self.mul_terms.len() } - pub fn from_field(q_c: FieldElement) -> Expression { + pub fn from_field(q_c: F) -> Self { Self { q_c, ..Default::default() } } - pub fn one() -> Expression { - Self::from_field(FieldElement::one()) + pub fn one() -> Self { + Self::from_field(F::one()) } - pub fn zero() -> Expression { + pub fn zero() -> Self { Self::default() } /// Adds a new linear term to the `Expression`. - pub fn push_addition_term(&mut self, coefficient: FieldElement, variable: Witness) { + pub fn push_addition_term(&mut self, coefficient: F, variable: Witness) { self.linear_combinations.push((coefficient, variable)); } /// Adds a new quadratic term to the `Expression`. - pub fn push_multiplication_term( - &mut self, - coefficient: FieldElement, - lhs: Witness, - rhs: Witness, - ) { + pub fn push_multiplication_term(&mut self, coefficient: F, lhs: Witness, rhs: Witness) { self.mul_terms.push((coefficient, lhs, rhs)); } @@ -145,7 +135,7 @@ impl Expression { /// - f(x,y) = 2*y + 6 would return `None` /// - f(x,y) = x + y would return `None` /// - f(x,y) = 5 would return `FieldElement(5)` - pub fn to_const(&self) -> Option { + pub fn to_const(&self) -> Option { self.is_const().then_some(self.q_c) } @@ -216,7 +206,7 @@ impl Expression { let mul_term = &self.mul_terms[0]; // The coefficient should be non-zero, as this method is ran after the compiler removes all zero coefficient terms - assert_ne!(mul_term.0, FieldElement::zero()); + assert_ne!(mul_term.0, F::zero()); let mut found_x = false; let mut found_y = false; @@ -240,18 +230,19 @@ impl Expression { } /// Returns `self + k*b` - pub fn add_mul(&self, k: FieldElement, b: &Expression) -> Expression { + pub fn add_mul(&self, k: F, b: &Self) -> Self { if k.is_zero() { return self.clone(); } else if self.is_const() { - return self.q_c + (k * b); + let kb = b * k; + return kb + self.q_c; } else if b.is_const() { return self.clone() + (k * b.q_c); } - let mut mul_terms: Vec<(FieldElement, Witness, Witness)> = + let mut mul_terms: Vec<(F, Witness, Witness)> = Vec::with_capacity(self.mul_terms.len() + b.mul_terms.len()); - let mut linear_combinations: Vec<(FieldElement, Witness)> = + let mut linear_combinations: Vec<(F, Witness)> = Vec::with_capacity(self.linear_combinations.len() + b.linear_combinations.len()); let q_c = self.q_c + k * b.q_c; @@ -338,7 +329,7 @@ impl Expression { while i2 < b.mul_terms.len() { let (b_c, b_wl, b_wr) = b.mul_terms[i2]; let coeff = b_c * k; - if coeff != FieldElement::zero() { + if coeff != F::zero() { mul_terms.push((coeff, b_wl, b_wr)); } i2 += 1; @@ -348,57 +339,63 @@ impl Expression { } } -impl From for Expression { - fn from(constant: FieldElement) -> Expression { +impl From for Expression { + fn from(constant: F) -> Self { Expression { q_c: constant, linear_combinations: Vec::new(), mul_terms: Vec::new() } } } -impl From for Expression { +impl From for Expression { /// Creates an Expression from a Witness. /// /// This is infallible since an `Expression` is /// a multi-variate polynomial and a `Witness` /// can be seen as a univariate polynomial - fn from(wit: Witness) -> Expression { + fn from(wit: Witness) -> Self { Expression { - q_c: FieldElement::zero(), - linear_combinations: vec![(FieldElement::one(), wit)], + q_c: F::zero(), + linear_combinations: vec![(F::one(), wit)], mul_terms: Vec::new(), } } } -#[test] -fn add_mul_smoketest() { - let a = Expression { - mul_terms: vec![(FieldElement::from(2u128), Witness(1), Witness(2))], - ..Default::default() - }; - - let k = FieldElement::from(10u128); - - let b = Expression { - mul_terms: vec![ - (FieldElement::from(3u128), Witness(0), Witness(2)), - (FieldElement::from(3u128), Witness(1), Witness(2)), - (FieldElement::from(4u128), Witness(4), Witness(5)), - ], - linear_combinations: vec![(FieldElement::from(4u128), Witness(4))], - q_c: FieldElement::one(), - }; - - let result = a.add_mul(k, &b); - assert_eq!( - result, - Expression { +#[cfg(test)] +mod tests { + use super::*; + use acir_field::{AcirField, FieldElement}; + + #[test] + fn add_mul_smoketest() { + let a = Expression { + mul_terms: vec![(FieldElement::from(2u128), Witness(1), Witness(2))], + ..Default::default() + }; + + let k = FieldElement::from(10u128); + + let b = Expression { mul_terms: vec![ - (FieldElement::from(30u128), Witness(0), Witness(2)), - (FieldElement::from(32u128), Witness(1), Witness(2)), - (FieldElement::from(40u128), Witness(4), Witness(5)), + (FieldElement::from(3u128), Witness(0), Witness(2)), + (FieldElement::from(3u128), Witness(1), Witness(2)), + (FieldElement::from(4u128), Witness(4), Witness(5)), ], - linear_combinations: vec![(FieldElement::from(40u128), Witness(4))], - q_c: FieldElement::from(10u128) - } - ); + linear_combinations: vec![(FieldElement::from(4u128), Witness(4))], + q_c: FieldElement::one(), + }; + + let result = a.add_mul(k, &b); + assert_eq!( + result, + Expression { + mul_terms: vec![ + (FieldElement::from(30u128), Witness(0), Witness(2)), + (FieldElement::from(32u128), Witness(1), Witness(2)), + (FieldElement::from(40u128), Witness(4), Witness(5)), + ], + linear_combinations: vec![(FieldElement::from(40u128), Witness(4))], + q_c: FieldElement::from(10u128) + } + ); + } } diff --git a/acvm-repo/acir/src/native_types/expression/operators.rs b/acvm-repo/acir/src/native_types/expression/operators.rs index 29cdc6967b..a8f5dc8e7a 100644 --- a/acvm-repo/acir/src/native_types/expression/operators.rs +++ b/acvm-repo/acir/src/native_types/expression/operators.rs @@ -1,5 +1,5 @@ use crate::native_types::Witness; -use acir_field::FieldElement; +use acir_field::AcirField; use std::{ cmp::Ordering, ops::{Add, Mul, Neg, Sub}, @@ -9,8 +9,8 @@ use super::Expression; // Negation -impl Neg for &Expression { - type Output = Expression; +impl Neg for &Expression { + type Output = Expression; fn neg(self) -> Self::Output { // XXX(med) : Implement an efficient way to do this @@ -27,9 +27,9 @@ impl Neg for &Expression { // FieldElement -impl Add for Expression { - type Output = Expression; - fn add(self, rhs: FieldElement) -> Self::Output { +impl Add for Expression { + type Output = Self; + fn add(self, rhs: F) -> Self::Output { // Increase the constant let q_c = self.q_c + rhs; @@ -37,17 +37,9 @@ impl Add for Expression { } } -impl Add for FieldElement { - type Output = Expression; - #[inline] - fn add(self, rhs: Expression) -> Self::Output { - rhs + self - } -} - -impl Sub for Expression { - type Output = Expression; - fn sub(self, rhs: FieldElement) -> Self::Output { +impl Sub for Expression { + type Output = Self; + fn sub(self, rhs: F) -> Self::Output { // Increase the constant let q_c = self.q_c - rhs; @@ -55,17 +47,9 @@ impl Sub for Expression { } } -impl Sub for FieldElement { - type Output = Expression; - #[inline] - fn sub(self, rhs: Expression) -> Self::Output { - rhs - self - } -} - -impl Mul for &Expression { - type Output = Expression; - fn mul(self, rhs: FieldElement) -> Self::Output { +impl Mul for &Expression { + type Output = Expression; + fn mul(self, rhs: F) -> Self::Output { // Scale the mul terms let mul_terms: Vec<_> = self.mul_terms.iter().map(|(q_m, w_l, w_r)| (*q_m * rhs, *w_l, *w_r)).collect(); @@ -81,42 +65,34 @@ impl Mul for &Expression { } } -impl Mul<&Expression> for FieldElement { - type Output = Expression; - #[inline] - fn mul(self, rhs: &Expression) -> Self::Output { - rhs * self - } -} - // Witness -impl Add for &Expression { - type Output = Expression; - fn add(self, rhs: Witness) -> Expression { +impl Add for &Expression { + type Output = Expression; + fn add(self, rhs: Witness) -> Self::Output { self + &Expression::from(rhs) } } -impl Add<&Expression> for Witness { - type Output = Expression; +impl Add<&Expression> for Witness { + type Output = Expression; #[inline] - fn add(self, rhs: &Expression) -> Expression { + fn add(self, rhs: &Expression) -> Self::Output { rhs + self } } -impl Sub for &Expression { - type Output = Expression; - fn sub(self, rhs: Witness) -> Expression { +impl Sub for &Expression { + type Output = Expression; + fn sub(self, rhs: Witness) -> Self::Output { self - &Expression::from(rhs) } } -impl Sub<&Expression> for Witness { - type Output = Expression; +impl Sub<&Expression> for Witness { + type Output = Expression; #[inline] - fn sub(self, rhs: &Expression) -> Expression { + fn sub(self, rhs: &Expression) -> Self::Output { rhs - self } } @@ -125,25 +101,25 @@ impl Sub<&Expression> for Witness { // Expression -impl Add<&Expression> for &Expression { - type Output = Expression; - fn add(self, rhs: &Expression) -> Expression { - self.add_mul(FieldElement::one(), rhs) +impl Add<&Expression> for &Expression { + type Output = Expression; + fn add(self, rhs: &Expression) -> Self::Output { + self.add_mul(F::one(), rhs) } } -impl Sub<&Expression> for &Expression { - type Output = Expression; - fn sub(self, rhs: &Expression) -> Expression { - self.add_mul(-FieldElement::one(), rhs) +impl Sub<&Expression> for &Expression { + type Output = Expression; + fn sub(self, rhs: &Expression) -> Self::Output { + self.add_mul(-F::one(), rhs) } } -impl Mul<&Expression> for &Expression { - type Output = Option; - fn mul(self, rhs: &Expression) -> Option { +impl Mul<&Expression> for &Expression { + type Output = Option>; + fn mul(self, rhs: &Expression) -> Self::Output { if self.is_const() { - return Some(self.q_c * rhs); + return Some(rhs * self.q_c); } else if rhs.is_const() { return Some(self * rhs.q_c); } else if !(self.is_linear() && rhs.is_linear()) { @@ -215,7 +191,7 @@ impl Mul<&Expression> for &Expression { } /// Returns `w*b.linear_combinations` -fn single_mul(w: Witness, b: &Expression) -> Expression { +fn single_mul(w: Witness, b: &Expression) -> Expression { Expression { mul_terms: b .linear_combinations @@ -229,62 +205,68 @@ fn single_mul(w: Witness, b: &Expression) -> Expression { } } -#[test] -fn add_smoke_test() { - let a = Expression { - mul_terms: vec![], - linear_combinations: vec![(FieldElement::from(2u128), Witness(2))], - q_c: FieldElement::from(2u128), - }; - - let b = Expression { - mul_terms: vec![], - linear_combinations: vec![(FieldElement::from(4u128), Witness(4))], - q_c: FieldElement::one(), - }; - - assert_eq!( - &a + &b, - Expression { +#[cfg(test)] +mod tests { + use super::*; + use acir_field::{AcirField, FieldElement}; + + #[test] + fn add_smoke_test() { + let a = Expression { mul_terms: vec![], - linear_combinations: vec![ - (FieldElement::from(2u128), Witness(2)), - (FieldElement::from(4u128), Witness(4)) - ], - q_c: FieldElement::from(3u128) - } - ); + linear_combinations: vec![(FieldElement::from(2u128), Witness(2))], + q_c: FieldElement::from(2u128), + }; - // Enforce commutativity - assert_eq!(&a + &b, &b + &a); -} + let b = Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::from(4u128), Witness(4))], + q_c: FieldElement::one(), + }; + + assert_eq!( + &a + &b, + Expression { + mul_terms: vec![], + linear_combinations: vec![ + (FieldElement::from(2u128), Witness(2)), + (FieldElement::from(4u128), Witness(4)) + ], + q_c: FieldElement::from(3u128) + } + ); -#[test] -fn mul_smoke_test() { - let a = Expression { - mul_terms: vec![], - linear_combinations: vec![(FieldElement::from(2u128), Witness(2))], - q_c: FieldElement::from(2u128), - }; - - let b = Expression { - mul_terms: vec![], - linear_combinations: vec![(FieldElement::from(4u128), Witness(4))], - q_c: FieldElement::one(), - }; - - assert_eq!( - (&a * &b).unwrap(), - Expression { - mul_terms: vec![(FieldElement::from(8u128), Witness(2), Witness(4)),], - linear_combinations: vec![ - (FieldElement::from(2u128), Witness(2)), - (FieldElement::from(8u128), Witness(4)) - ], - q_c: FieldElement::from(2u128) - } - ); + // Enforce commutativity + assert_eq!(&a + &b, &b + &a); + } - // Enforce commutativity - assert_eq!(&a * &b, &b * &a); + #[test] + fn mul_smoke_test() { + let a = Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::from(2u128), Witness(2))], + q_c: FieldElement::from(2u128), + }; + + let b = Expression { + mul_terms: vec![], + linear_combinations: vec![(FieldElement::from(4u128), Witness(4))], + q_c: FieldElement::one(), + }; + + assert_eq!( + (&a * &b).unwrap(), + Expression { + mul_terms: vec![(FieldElement::from(8u128), Witness(2), Witness(4)),], + linear_combinations: vec![ + (FieldElement::from(2u128), Witness(2)), + (FieldElement::from(8u128), Witness(4)) + ], + q_c: FieldElement::from(2u128) + } + ); + + // Enforce commutativity + assert_eq!(&a * &b, &b * &a); + } } diff --git a/acvm-repo/acir/src/native_types/expression/ordering.rs b/acvm-repo/acir/src/native_types/expression/ordering.rs index e24a25ec3a..b0e5e88454 100644 --- a/acvm-repo/acir/src/native_types/expression/ordering.rs +++ b/acvm-repo/acir/src/native_types/expression/ordering.rs @@ -1,3 +1,5 @@ +use acir_field::AcirField; + use crate::native_types::Witness; use std::cmp::Ordering; @@ -6,7 +8,7 @@ use super::Expression; // TODO: It's undecided whether `Expression` should implement `Ord/PartialOrd`. // This is currently used in ACVM in the compiler. -impl Ord for Expression { +impl Ord for Expression { fn cmp(&self, other: &Self) -> Ordering { let mut i1 = self.get_max_idx(); let mut i2 = other.get_max_idx(); @@ -17,13 +19,13 @@ impl Ord for Expression { if m1.is_none() && m2.is_none() { return Ordering::Equal; } - result = Expression::cmp_max(m1, m2); + result = Expression::::cmp_max(m1, m2); } result } } -impl PartialOrd for Expression { +impl PartialOrd for Expression { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } @@ -35,7 +37,7 @@ struct WitnessIdx { second_term: bool, } -impl Expression { +impl Expression { fn get_max_idx(&self) -> WitnessIdx { WitnessIdx { linear: self.linear_combinations.len(), diff --git a/acvm-repo/acir/src/native_types/witness.rs b/acvm-repo/acir/src/native_types/witness.rs index 740d10d295..3e9beb510b 100644 --- a/acvm-repo/acir/src/native_types/witness.rs +++ b/acvm-repo/acir/src/native_types/witness.rs @@ -1,10 +1,5 @@ -use std::ops::Add; - -use acir_field::FieldElement; use serde::{Deserialize, Serialize}; -use super::Expression; - // Witness might be a misnomer. This is an index that represents the position a witness will take #[derive( Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize, @@ -33,11 +28,3 @@ impl From for Witness { Self(value) } } - -impl Add for Witness { - type Output = Expression; - - fn add(self, rhs: Witness) -> Self::Output { - Expression::from(self).add_mul(FieldElement::one(), &Expression::from(rhs)) - } -} diff --git a/acvm-repo/acir/src/native_types/witness_map.rs b/acvm-repo/acir/src/native_types/witness_map.rs index 00245d5842..e508fe5b18 100644 --- a/acvm-repo/acir/src/native_types/witness_map.rs +++ b/acvm-repo/acir/src/native_types/witness_map.rs @@ -4,7 +4,6 @@ use std::{ ops::Index, }; -use acir_field::FieldElement; use flate2::bufread::GzDecoder; use flate2::bufread::GzEncoder; use flate2::Compression; @@ -25,63 +24,63 @@ pub struct WitnessMapError(#[from] SerializationError); /// A map from the witnesses in a constraint system to the field element values #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] -pub struct WitnessMap(BTreeMap); +pub struct WitnessMap(BTreeMap); -impl WitnessMap { +impl WitnessMap { pub fn new() -> Self { Self(BTreeMap::new()) } - pub fn get(&self, witness: &Witness) -> Option<&FieldElement> { + pub fn get(&self, witness: &Witness) -> Option<&F> { self.0.get(witness) } - pub fn get_index(&self, index: u32) -> Option<&FieldElement> { + pub fn get_index(&self, index: u32) -> Option<&F> { self.0.get(&index.into()) } pub fn contains_key(&self, key: &Witness) -> bool { self.0.contains_key(key) } - pub fn insert(&mut self, key: Witness, value: FieldElement) -> Option { + pub fn insert(&mut self, key: Witness, value: F) -> Option { self.0.insert(key, value) } } -impl Index<&Witness> for WitnessMap { - type Output = FieldElement; +impl Index<&Witness> for WitnessMap { + type Output = F; fn index(&self, index: &Witness) -> &Self::Output { &self.0[index] } } -pub struct IntoIter(btree_map::IntoIter); +pub struct IntoIter(btree_map::IntoIter); -impl Iterator for IntoIter { - type Item = (Witness, FieldElement); +impl Iterator for IntoIter { + type Item = (Witness, F); fn next(&mut self) -> Option { self.0.next() } } -impl IntoIterator for WitnessMap { - type Item = (Witness, FieldElement); - type IntoIter = IntoIter; +impl IntoIterator for WitnessMap { + type Item = (Witness, F); + type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { IntoIter(self.0.into_iter()) } } -impl From> for WitnessMap { - fn from(value: BTreeMap) -> Self { +impl From> for WitnessMap { + fn from(value: BTreeMap) -> Self { Self(value) } } -impl TryFrom for Vec { +impl TryFrom> for Vec { type Error = WitnessMapError; - fn try_from(val: WitnessMap) -> Result { + fn try_from(val: WitnessMap) -> Result { let buf = bincode::serialize(&val).unwrap(); let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best()); let mut buf_c = Vec::new(); @@ -90,7 +89,7 @@ impl TryFrom for Vec { } } -impl TryFrom<&[u8]> for WitnessMap { +impl Deserialize<'a>> TryFrom<&[u8]> for WitnessMap { type Error = WitnessMapError; fn try_from(bytes: &[u8]) -> Result { diff --git a/acvm-repo/acir/src/native_types/witness_stack.rs b/acvm-repo/acir/src/native_types/witness_stack.rs index 7c79e3db43..8a4fffa157 100644 --- a/acvm-repo/acir/src/native_types/witness_stack.rs +++ b/acvm-repo/acir/src/native_types/witness_stack.rs @@ -20,28 +20,28 @@ pub struct WitnessStackError(#[from] SerializationError); /// An ordered set of witness maps for separate circuits #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] -pub struct WitnessStack { - stack: Vec, +pub struct WitnessStack { + stack: Vec>, } #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Default, Serialize, Deserialize)] -pub struct StackItem { +pub struct StackItem { /// Index into a [crate::circuit::Program] function list for which we have an associated witness pub index: u32, /// A full witness for the respective constraint system specified by the index - pub witness: WitnessMap, + pub witness: WitnessMap, } -impl WitnessStack { - pub fn push(&mut self, index: u32, witness: WitnessMap) { +impl WitnessStack { + pub fn push(&mut self, index: u32, witness: WitnessMap) { self.stack.push(StackItem { index, witness }); } - pub fn pop(&mut self) -> Option { + pub fn pop(&mut self) -> Option> { self.stack.pop() } - pub fn peek(&self) -> Option<&StackItem> { + pub fn peek(&self) -> Option<&StackItem> { self.stack.last() } @@ -50,17 +50,17 @@ impl WitnessStack { } } -impl From for WitnessStack { - fn from(witness: WitnessMap) -> Self { +impl From> for WitnessStack { + fn from(witness: WitnessMap) -> Self { let stack = vec![StackItem { index: 0, witness }]; Self { stack } } } -impl TryFrom for Vec { +impl TryFrom> for Vec { type Error = WitnessStackError; - fn try_from(val: WitnessStack) -> Result { + fn try_from(val: WitnessStack) -> Result { let buf = bincode::serialize(&val).unwrap(); let mut deflater = GzEncoder::new(buf.as_slice(), Compression::best()); let mut buf_c = Vec::new(); @@ -69,7 +69,7 @@ impl TryFrom for Vec { } } -impl TryFrom<&[u8]> for WitnessStack { +impl Deserialize<'a>> TryFrom<&[u8]> for WitnessStack { type Error = WitnessStackError; fn try_from(bytes: &[u8]) -> Result { diff --git a/acvm-repo/acir/tests/test_program_serialization.rs b/acvm-repo/acir/tests/test_program_serialization.rs index 19e4beb615..d4c7a8782a 100644 --- a/acvm-repo/acir/tests/test_program_serialization.rs +++ b/acvm-repo/acir/tests/test_program_serialization.rs @@ -19,7 +19,7 @@ use acir::{ }, native_types::{Expression, Witness}, }; -use acir_field::FieldElement; +use acir_field::{AcirField, FieldElement}; use brillig::{HeapArray, HeapValueType, MemoryAddress, ValueOrArray}; #[test] @@ -34,12 +34,12 @@ fn addition_circuit() { q_c: FieldElement::zero(), }); - let circuit = Circuit { + let circuit: Circuit = Circuit { current_witness_index: 4, opcodes: vec![addition], private_parameters: BTreeSet::from([Witness(1), Witness(2)]), return_values: PublicInputs([Witness(3)].into()), - ..Circuit::default() + ..Circuit::::default() }; let program = Program { functions: vec![circuit], unconstrained_functions: vec![] }; @@ -59,18 +59,19 @@ fn addition_circuit() { #[test] fn multi_scalar_mul_circuit() { - let multi_scalar_mul = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::MultiScalarMul { - points: vec![ - FunctionInput { witness: Witness(1), num_bits: 128 }, - FunctionInput { witness: Witness(2), num_bits: 128 }, - FunctionInput { witness: Witness(3), num_bits: 1 }, - ], - scalars: vec![ - FunctionInput { witness: Witness(4), num_bits: 128 }, - FunctionInput { witness: Witness(5), num_bits: 128 }, - ], - outputs: (Witness(6), Witness(7), Witness(8)), - }); + let multi_scalar_mul: Opcode = + Opcode::BlackBoxFuncCall(BlackBoxFuncCall::MultiScalarMul { + points: vec![ + FunctionInput { witness: Witness(1), num_bits: 128 }, + FunctionInput { witness: Witness(2), num_bits: 128 }, + FunctionInput { witness: Witness(3), num_bits: 1 }, + ], + scalars: vec![ + FunctionInput { witness: Witness(4), num_bits: 128 }, + FunctionInput { witness: Witness(5), num_bits: 128 }, + ], + outputs: (Witness(6), Witness(7), Witness(8)), + }); let circuit = Circuit { current_witness_index: 9, @@ -107,7 +108,7 @@ fn pedersen_circuit() { domain_separator: 0, }); - let circuit = Circuit { + let circuit: Circuit = Circuit { current_witness_index: 4, opcodes: vec![pedersen], private_parameters: BTreeSet::from([Witness(1)]), @@ -151,7 +152,7 @@ fn schnorr_verify_circuit() { output, }); - let circuit = Circuit { + let circuit: Circuit = Circuit { current_witness_index: 100, opcodes: vec![schnorr], private_parameters: BTreeSet::from_iter((1..=last_input).map(Witness)), @@ -219,7 +220,7 @@ fn simple_brillig_foreign_call() { predicate: None, }]; - let circuit = Circuit { + let circuit: Circuit = Circuit { current_witness_index: 8, opcodes, private_parameters: BTreeSet::from([Witness(1), Witness(2)]), diff --git a/acvm-repo/acir_field/Cargo.toml b/acvm-repo/acir_field/Cargo.toml index dd4b7af9ff..8d8a1e9e2b 100644 --- a/acvm-repo/acir_field/Cargo.toml +++ b/acvm-repo/acir_field/Cargo.toml @@ -18,17 +18,12 @@ num-bigint.workspace = true serde.workspace = true num-traits.workspace = true -ark-bn254 = { version = "^0.4.0", optional = true, default-features = false, features = [ - "curve", -] } -ark-bls12-381 = { version = "^0.4.0", optional = true, default-features = false, features = [ - "curve", -] } -ark-ff = { version = "^0.4.0", optional = true, default-features = false } +ark-bn254 = { version = "^0.4.0", default-features = false, features = ["curve"] } +ark-bls12-381 = { version = "^0.4.0", optional = true, default-features = false, features = ["curve"] } +ark-ff = { version = "^0.4.0", default-features = false } cfg-if = "1.0.0" [features] -default = ["bn254"] -bn254 = ["dep:ark-bn254", "dep:ark-ff"] -bls12_381 = ["dep:ark-bls12-381", "dep:ark-ff"] +bn254 = [] +bls12_381 = ["dep:ark-bls12-381"] diff --git a/acvm-repo/acir_field/src/generic_ark.rs b/acvm-repo/acir_field/src/generic_ark.rs index 3178011a07..c1bc797192 100644 --- a/acvm-repo/acir_field/src/generic_ark.rs +++ b/acvm-repo/acir_field/src/generic_ark.rs @@ -4,9 +4,90 @@ use num_bigint::BigUint; use serde::{Deserialize, Serialize}; use std::borrow::Cow; +/// This trait is extremely unstable and WILL have breaking changes. +pub trait AcirField: + std::marker::Sized + + Display + + Debug + + Default + + Clone + + Copy + + Neg + + Add + + Sub + + Mul + + Div + + PartialOrd + + AddAssign + + SubAssign + + From + + From + // + From + // + From + // + From + // + From + + From + + Hash + + std::cmp::Eq +{ + fn one() -> Self; + fn zero() -> Self; + + fn is_zero(&self) -> bool; + fn is_one(&self) -> bool; + + fn pow(&self, exponent: &Self) -> Self; + + /// Maximum number of bits needed to represent a field element + /// This is not the amount of bits being used to represent a field element + /// Example, you only need 254 bits to represent a field element in BN256 + /// But the representation uses 256 bits, so the top two bits are always zero + /// This method would return 254 + fn max_num_bits() -> u32; + + /// Maximum numbers of bytes needed to represent a field element + /// We are not guaranteed that the number of bits being used to represent a field element + /// will always be divisible by 8. If the case that it is not, we add one to the max number of bytes + /// For example, a max bit size of 254 would give a max byte size of 32. + fn max_num_bytes() -> u32; + + fn modulus() -> BigUint; + + /// This is the number of bits required to represent this specific field element + fn num_bits(&self) -> u32; + + fn to_u128(self) -> u128; + + fn try_into_u128(self) -> Option; + + fn to_i128(self) -> i128; + + fn try_to_u64(&self) -> Option; + + /// Computes the inverse or returns zero if the inverse does not exist + /// Before using this FieldElement, please ensure that this behavior is necessary + fn inverse(&self) -> Self; + + fn to_hex(self) -> String; + + fn from_hex(hex_str: &str) -> Option; + + fn to_be_bytes(self) -> Vec; + + /// Converts bytes into a FieldElement and applies a reduction if needed. + fn from_be_bytes_reduce(bytes: &[u8]) -> Self; + + /// Returns the closest number of bytes to the bits specified + /// This method truncates + fn fetch_nearest_bytes(&self, num_bits: usize) -> Vec; + + fn and(&self, rhs: &Self, num_bits: u32) -> Self; + fn xor(&self, rhs: &Self, num_bits: u32) -> Self; +} + // XXX: Switch out for a trait and proper implementations // This implementation is in-efficient, can definitely remove hex usage and Iterator instances for trivial functionality -#[derive(Clone, Copy, Eq, PartialOrd, Ord)] +#[derive(Default, Clone, Copy, Eq, PartialOrd, Ord)] pub struct FieldElement(F); impl std::fmt::Display for FieldElement { @@ -161,25 +242,104 @@ impl From for FieldElement { } impl FieldElement { - pub fn one() -> FieldElement { + pub fn from_repr(field: F) -> Self { + Self(field) + } + + // XXX: This method is used while this field element + // implementation is not generic. + pub fn into_repr(self) -> F { + self.0 + } + + fn is_negative(&self) -> bool { + self.neg().num_bits() < self.num_bits() + } + + fn fits_in_u128(&self) -> bool { + self.num_bits() <= 128 + } + + /// Returns None, if the string is not a canonical + /// representation of a field element; less than the order + /// or if the hex string is invalid. + /// This method can be used for both hex and decimal representations. + pub fn try_from_str(input: &str) -> Option> { + if input.contains('x') { + return FieldElement::from_hex(input); + } + + let fr = F::from_str(input).ok()?; + Some(FieldElement(fr)) + } + + // mask_to methods will not remove any bytes from the field + // they are simply zeroed out + // Whereas truncate_to will remove those bits and make the byte array smaller + fn mask_to_be_bytes(&self, num_bits: u32) -> Vec { + let mut bytes = self.to_be_bytes(); + mask_vector_le(&mut bytes, num_bits as usize); + bytes + } + + fn bits(&self) -> Vec { + fn byte_to_bit(byte: u8) -> Vec { + let mut bits = Vec::with_capacity(8); + for index in (0..=7).rev() { + bits.push((byte & (1 << index)) >> index == 1); + } + bits + } + + let bytes = self.to_be_bytes(); + let mut bits = Vec::with_capacity(bytes.len() * 8); + for byte in bytes { + let _bits = byte_to_bit(byte); + bits.extend(_bits); + } + bits + } + + fn and_xor(&self, rhs: &FieldElement, num_bits: u32, is_xor: bool) -> FieldElement { + // XXX: Gadgets like SHA256 need to have their input be a multiple of 8 + // This is not a restriction caused by SHA256, as it works on bits + // but most backends assume bytes. + // We could implicitly pad, however this may not be intuitive for users. + // assert!( + // num_bits % 8 == 0, + // "num_bits is not a multiple of 8, it is {}", + // num_bits + // ); + + let lhs_bytes = self.mask_to_be_bytes(num_bits); + let rhs_bytes = rhs.mask_to_be_bytes(num_bits); + + let and_byte_arr: Vec<_> = lhs_bytes + .into_iter() + .zip(rhs_bytes) + .map(|(lhs, rhs)| if is_xor { lhs ^ rhs } else { lhs & rhs }) + .collect(); + + FieldElement::from_be_bytes_reduce(&and_byte_arr) + } +} + +impl AcirField for FieldElement { + fn one() -> FieldElement { FieldElement(F::one()) } - pub fn zero() -> FieldElement { + fn zero() -> FieldElement { FieldElement(F::zero()) } - pub fn is_zero(&self) -> bool { + fn is_zero(&self) -> bool { self == &Self::zero() } - pub fn is_one(&self) -> bool { + fn is_one(&self) -> bool { self == &Self::one() } - pub fn is_negative(&self) -> bool { - self.neg().num_bits() < self.num_bits() - } - - pub fn pow(&self, exponent: &Self) -> Self { + fn pow(&self, exponent: &Self) -> Self { FieldElement(self.0.pow(exponent.0.into_bigint())) } @@ -188,7 +348,7 @@ impl FieldElement { /// Example, you only need 254 bits to represent a field element in BN256 /// But the representation uses 256 bits, so the top two bits are always zero /// This method would return 254 - pub const fn max_num_bits() -> u32 { + fn max_num_bits() -> u32 { F::MODULUS_BIT_SIZE } @@ -196,7 +356,7 @@ impl FieldElement { /// We are not guaranteed that the number of bits being used to represent a field element /// will always be divisible by 8. If the case that it is not, we add one to the max number of bytes /// For example, a max bit size of 254 would give a max byte size of 32. - pub const fn max_num_bytes() -> u32 { + fn max_num_bytes() -> u32 { let num_bytes = Self::max_num_bits() / 8; if Self::max_num_bits() % 8 == 0 { num_bytes @@ -205,24 +365,12 @@ impl FieldElement { } } - pub fn modulus() -> BigUint { + fn modulus() -> BigUint { F::MODULUS.into() } - /// Returns None, if the string is not a canonical - /// representation of a field element; less than the order - /// or if the hex string is invalid. - /// This method can be used for both hex and decimal representations. - pub fn try_from_str(input: &str) -> Option> { - if input.contains('x') { - return FieldElement::from_hex(input); - } - - let fr = F::from_str(input).ok()?; - Some(FieldElement(fr)) - } /// This is the number of bits required to represent this specific field element - pub fn num_bits(&self) -> u32 { + fn num_bits(&self) -> u32 { let bits = self.bits(); // Iterate the number of bits and pop off all leading zeroes let iter = bits.iter().skip_while(|x| !(**x)); @@ -231,57 +379,39 @@ impl FieldElement { iter.count() as u32 } - pub fn fits_in_u128(&self) -> bool { - self.num_bits() <= 128 - } - - pub fn to_u128(self) -> u128 { + fn to_u128(self) -> u128 { let bytes = self.to_be_bytes(); u128::from_be_bytes(bytes[16..32].try_into().unwrap()) } - pub fn try_into_u128(self) -> Option { + fn try_into_u128(self) -> Option { self.fits_in_u128().then(|| self.to_u128()) } - pub fn to_i128(self) -> i128 { + fn to_i128(self) -> i128 { let is_negative = self.is_negative(); let bytes = if is_negative { self.neg() } else { self }.to_be_bytes(); i128::from_be_bytes(bytes[16..32].try_into().unwrap()) * if is_negative { -1 } else { 1 } } - pub fn try_to_u64(&self) -> Option { + fn try_to_u64(&self) -> Option { (self.num_bits() <= 64).then(|| self.to_u128() as u64) } /// Computes the inverse or returns zero if the inverse does not exist /// Before using this FieldElement, please ensure that this behavior is necessary - pub fn inverse(&self) -> FieldElement { + fn inverse(&self) -> FieldElement { let inv = self.0.inverse().unwrap_or_else(F::zero); FieldElement(inv) } - pub fn try_inverse(mut self) -> Option { - self.0.inverse_in_place().map(|f| FieldElement(*f)) - } - - pub fn from_repr(field: F) -> Self { - Self(field) - } - - // XXX: This method is used while this field element - // implementation is not generic. - pub fn into_repr(self) -> F { - self.0 - } - - pub fn to_hex(self) -> String { + fn to_hex(self) -> String { let mut bytes = Vec::new(); self.0.serialize_uncompressed(&mut bytes).unwrap(); bytes.reverse(); hex::encode(bytes) } - pub fn from_hex(hex_str: &str) -> Option> { + fn from_hex(hex_str: &str) -> Option> { let value = hex_str.strip_prefix("0x").unwrap_or(hex_str); // Values of odd length require an additional "0" prefix let sanitized_value = @@ -290,7 +420,7 @@ impl FieldElement { Some(FieldElement::from_be_bytes_reduce(&hex_as_bytes)) } - pub fn to_be_bytes(self) -> Vec { + fn to_be_bytes(self) -> Vec { // to_be_bytes! uses little endian which is why we reverse the output // TODO: Add a little endian equivalent, so the caller can use whichever one // TODO they desire @@ -302,31 +432,13 @@ impl FieldElement { /// Converts bytes into a FieldElement and applies a /// reduction if needed. - pub fn from_be_bytes_reduce(bytes: &[u8]) -> FieldElement { + fn from_be_bytes_reduce(bytes: &[u8]) -> FieldElement { FieldElement(F::from_be_bytes_mod_order(bytes)) } - pub fn bits(&self) -> Vec { - let bytes = self.to_be_bytes(); - let mut bits = Vec::with_capacity(bytes.len() * 8); - for byte in bytes { - let _bits = FieldElement::::byte_to_bit(byte); - bits.extend(_bits); - } - bits - } - - fn byte_to_bit(byte: u8) -> Vec { - let mut bits = Vec::with_capacity(8); - for index in (0..=7).rev() { - bits.push((byte & (1 << index)) >> index == 1); - } - bits - } - /// Returns the closest number of bytes to the bits specified /// This method truncates - pub fn fetch_nearest_bytes(&self, num_bits: usize) -> Vec { + fn fetch_nearest_bytes(&self, num_bits: usize) -> Vec { fn nearest_bytes(num_bits: usize) -> usize { ((num_bits + 7) / 8) * 8 } @@ -340,45 +452,17 @@ impl FieldElement { bytes[0..num_elements].to_vec() } - // mask_to methods will not remove any bytes from the field - // they are simply zeroed out - // Whereas truncate_to will remove those bits and make the byte array smaller - fn mask_to_be_bytes(&self, num_bits: u32) -> Vec { - let mut bytes = self.to_be_bytes(); - mask_vector_le(&mut bytes, num_bits as usize); - bytes - } - - fn and_xor(&self, rhs: &FieldElement, num_bits: u32, is_xor: bool) -> FieldElement { - // XXX: Gadgets like SHA256 need to have their input be a multiple of 8 - // This is not a restriction caused by SHA256, as it works on bits - // but most backends assume bytes. - // We could implicitly pad, however this may not be intuitive for users. - // assert!( - // num_bits % 8 == 0, - // "num_bits is not a multiple of 8, it is {}", - // num_bits - // ); - - let lhs_bytes = self.mask_to_be_bytes(num_bits); - let rhs_bytes = rhs.mask_to_be_bytes(num_bits); - - let and_byte_arr: Vec<_> = lhs_bytes - .into_iter() - .zip(rhs_bytes) - .map(|(lhs, rhs)| if is_xor { lhs ^ rhs } else { lhs & rhs }) - .collect(); - - FieldElement::from_be_bytes_reduce(&and_byte_arr) - } - pub fn and(&self, rhs: &FieldElement, num_bits: u32) -> FieldElement { + fn and(&self, rhs: &FieldElement, num_bits: u32) -> FieldElement { self.and_xor(rhs, num_bits, false) } - pub fn xor(&self, rhs: &FieldElement, num_bits: u32) -> FieldElement { + fn xor(&self, rhs: &FieldElement, num_bits: u32) -> FieldElement { self.and_xor(rhs, num_bits, true) } } +use std::fmt::Debug; +use std::fmt::Display; +use std::hash::Hash; use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub, SubAssign}; impl Neg for FieldElement { @@ -489,6 +573,8 @@ fn superscript(n: u64) -> String { #[cfg(test)] mod tests { + use super::{AcirField, FieldElement}; + #[test] fn and() { let max = 10_000u32; @@ -496,7 +582,7 @@ mod tests { let num_bits = (std::mem::size_of::() * 8) as u32 - max.leading_zeros(); for x in 0..max { - let x = crate::generic_ark::FieldElement::::from(x as i128); + let x = FieldElement::::from(x as i128); let res = x.and(&x, num_bits); assert_eq!(res.to_be_bytes(), x.to_be_bytes()); } @@ -513,8 +599,7 @@ mod tests { ]; for (i, string) in hex_strings.into_iter().enumerate() { - let minus_i_field_element = - -crate::generic_ark::FieldElement::::from(i as i128); + let minus_i_field_element = -FieldElement::::from(i as i128); assert_eq!(minus_i_field_element.to_hex(), string); } } @@ -525,12 +610,9 @@ mod tests { let hex_strings = vec![("0x0", "0x00"), ("0x1", "0x01"), ("0x002", "0x0002"), ("0x00003", "0x000003")]; for (i, case) in hex_strings.into_iter().enumerate() { - let i_field_element = - crate::generic_ark::FieldElement::::from(i as i128); - let odd_field_element = - crate::generic_ark::FieldElement::::from_hex(case.0).unwrap(); - let even_field_element = - crate::generic_ark::FieldElement::::from_hex(case.1).unwrap(); + let i_field_element = FieldElement::::from(i as i128); + let odd_field_element = FieldElement::::from_hex(case.0).unwrap(); + let even_field_element = FieldElement::::from_hex(case.1).unwrap(); assert_eq!(i_field_element, odd_field_element); assert_eq!(odd_field_element, even_field_element); @@ -539,7 +621,7 @@ mod tests { #[test] fn max_num_bits_smoke() { - let max_num_bits_bn254 = crate::generic_ark::FieldElement::::max_num_bits(); + let max_num_bits_bn254 = FieldElement::::max_num_bits(); assert_eq!(max_num_bits_bn254, 254); } } diff --git a/acvm-repo/acir_field/src/lib.rs b/acvm-repo/acir_field/src/lib.rs index eafe4bb2ad..7f06330ed8 100644 --- a/acvm-repo/acir_field/src/lib.rs +++ b/acvm-repo/acir_field/src/lib.rs @@ -5,19 +5,20 @@ use num_bigint::BigUint; use num_traits::Num; +mod generic_ark; -cfg_if::cfg_if! { - if #[cfg(feature = "bn254")] { - mod generic_ark; - pub type FieldElement = generic_ark::FieldElement; - pub const CHOSEN_FIELD : FieldOptions = FieldOptions::BN254; +pub use generic_ark::AcirField; + +/// Temporarily exported generic field to aid migration to `AcirField` +pub use generic_ark::FieldElement as GenericFieldElement; - } else if #[cfg(feature = "bls12_381")] { - mod generic_ark; +cfg_if::cfg_if! { + if #[cfg(feature = "bls12_381")] { pub type FieldElement = generic_ark::FieldElement; pub const CHOSEN_FIELD : FieldOptions = FieldOptions::BLS12_381; } else { - compile_error!("please specify a field to compile with"); + pub type FieldElement = generic_ark::FieldElement; + pub const CHOSEN_FIELD : FieldOptions = FieldOptions::BN254; } } diff --git a/acvm-repo/acvm/Cargo.toml b/acvm-repo/acvm/Cargo.toml index 74aed429f9..577978939b 100644 --- a/acvm-repo/acvm/Cargo.toml +++ b/acvm-repo/acvm/Cargo.toml @@ -24,7 +24,6 @@ acvm_blackbox_solver.workspace = true indexmap = "1.7.0" [features] -default = ["bn254"] bn254 = [ "acir/bn254", "brillig_vm/bn254", @@ -40,3 +39,4 @@ bls12_381 = [ rand = "0.8.5" proptest = "1.2.0" paste = "1.0.14" +ark-bls12-381 = { version = "^0.4.0", default-features = false, features = ["curve"] } \ No newline at end of file diff --git a/acvm-repo/acvm/src/compiler/mod.rs b/acvm-repo/acvm/src/compiler/mod.rs index 436db648ea..5ece3d19a6 100644 --- a/acvm-repo/acvm/src/compiler/mod.rs +++ b/acvm-repo/acvm/src/compiler/mod.rs @@ -1,6 +1,9 @@ use std::collections::HashMap; -use acir::circuit::{AssertionPayload, Circuit, ExpressionWidth, OpcodeLocation}; +use acir::{ + circuit::{AssertionPayload, Circuit, ExpressionWidth, OpcodeLocation}, + AcirField, +}; // The various passes that we can use over ACIR mod optimizers; @@ -53,10 +56,10 @@ impl AcirTransformationMap { } } -fn transform_assert_messages( - assert_messages: Vec<(OpcodeLocation, AssertionPayload)>, +fn transform_assert_messages( + assert_messages: Vec<(OpcodeLocation, AssertionPayload)>, map: &AcirTransformationMap, -) -> Vec<(OpcodeLocation, AssertionPayload)> { +) -> Vec<(OpcodeLocation, AssertionPayload)> { assert_messages .into_iter() .flat_map(|(location, message)| { @@ -67,10 +70,10 @@ fn transform_assert_messages( } /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. -pub fn compile( - acir: Circuit, +pub fn compile( + acir: Circuit, expression_width: ExpressionWidth, -) -> (Circuit, AcirTransformationMap) { +) -> (Circuit, AcirTransformationMap) { let (acir, acir_opcode_positions) = optimize_internal(acir); let (mut acir, acir_opcode_positions) = diff --git a/acvm-repo/acvm/src/compiler/optimizers/general.rs b/acvm-repo/acvm/src/compiler/optimizers/general.rs index a48a590a05..39a01a38ca 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/general.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/general.rs @@ -1,6 +1,6 @@ use acir::{ native_types::{Expression, Witness}, - FieldElement, + AcirField, }; use indexmap::IndexMap; @@ -10,7 +10,7 @@ use indexmap::IndexMap; pub(crate) struct GeneralOptimizer; impl GeneralOptimizer { - pub(crate) fn optimize(opcode: Expression) -> Expression { + pub(crate) fn optimize(opcode: Expression) -> Expression { // XXX: Perhaps this optimization can be done on the fly let opcode = remove_zero_coefficients(opcode); let opcode = simplify_mul_terms(opcode); @@ -19,7 +19,7 @@ impl GeneralOptimizer { } // Remove all terms with zero as a coefficient -fn remove_zero_coefficients(mut opcode: Expression) -> Expression { +fn remove_zero_coefficients(mut opcode: Expression) -> Expression { // Check the mul terms opcode.mul_terms.retain(|(scale, _, _)| !scale.is_zero()); // Check the linear combination terms @@ -28,8 +28,8 @@ fn remove_zero_coefficients(mut opcode: Expression) -> Expression { } // Simplifies all mul terms with the same bi-variate variables -fn simplify_mul_terms(mut gate: Expression) -> Expression { - let mut hash_map: IndexMap<(Witness, Witness), FieldElement> = IndexMap::new(); +fn simplify_mul_terms(mut gate: Expression) -> Expression { + let mut hash_map: IndexMap<(Witness, Witness), F> = IndexMap::new(); // Canonicalize the ordering of the multiplication, lets just order by variable name for (scale, w_l, w_r) in gate.mul_terms.into_iter() { @@ -37,7 +37,7 @@ fn simplify_mul_terms(mut gate: Expression) -> Expression { // Sort using rust sort algorithm pair.sort(); - *hash_map.entry((pair[0], pair[1])).or_insert_with(FieldElement::zero) += scale; + *hash_map.entry((pair[0], pair[1])).or_insert_with(F::zero) += scale; } gate.mul_terms = hash_map.into_iter().map(|((w_l, w_r), scale)| (scale, w_l, w_r)).collect(); @@ -45,17 +45,17 @@ fn simplify_mul_terms(mut gate: Expression) -> Expression { } // Simplifies all linear terms with the same variables -fn simplify_linear_terms(mut gate: Expression) -> Expression { - let mut hash_map: IndexMap = IndexMap::new(); +fn simplify_linear_terms(mut gate: Expression) -> Expression { + let mut hash_map: IndexMap = IndexMap::new(); // Canonicalize the ordering of the terms, lets just order by variable name for (scale, witness) in gate.linear_combinations.into_iter() { - *hash_map.entry(witness).or_insert_with(FieldElement::zero) += scale; + *hash_map.entry(witness).or_insert_with(F::zero) += scale; } gate.linear_combinations = hash_map .into_iter() - .filter(|(_, scale)| scale != &FieldElement::zero()) + .filter(|(_, scale)| !scale.is_zero()) .map(|(witness, scale)| (scale, witness)) .collect(); gate diff --git a/acvm-repo/acvm/src/compiler/optimizers/mod.rs b/acvm-repo/acvm/src/compiler/optimizers/mod.rs index dfe348d4ff..e20ad97a10 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/mod.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/mod.rs @@ -1,4 +1,7 @@ -use acir::circuit::{Circuit, Opcode}; +use acir::{ + circuit::{Circuit, Opcode}, + AcirField, +}; // mod constant_backpropagation; mod general; @@ -15,7 +18,7 @@ use self::unused_memory::UnusedMemoryOptimizer; use super::{transform_assert_messages, AcirTransformationMap}; /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] independent optimizations to a [`Circuit`]. -pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) { +pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) { let (mut acir, new_opcode_positions) = optimize_internal(acir); let transformation_map = AcirTransformationMap::new(new_opcode_positions); @@ -27,7 +30,7 @@ pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) { /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] independent optimizations to a [`Circuit`]. #[tracing::instrument(level = "trace", name = "optimize_acir" skip(acir))] -pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, Vec) { +pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, Vec) { // Track original acir opcode positions throughout the transformation passes of the compilation // by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert) let acir_opcode_positions = (0..acir.opcodes.len()).collect(); @@ -40,7 +43,7 @@ pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, Vec) { info!("Number of opcodes before: {}", acir.opcodes.len()); // General optimizer pass - let opcodes: Vec = acir + let opcodes: Vec> = acir .opcodes .into_iter() .map(|opcode| { diff --git a/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs b/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs index 0e1629717b..7001c953d6 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs @@ -4,6 +4,7 @@ use acir::{ Circuit, Opcode, }, native_types::Witness, + AcirField, }; use std::collections::{BTreeMap, HashSet}; @@ -26,16 +27,16 @@ use std::collections::{BTreeMap, HashSet}; /// /// This optimization pass will keep the 16-bit range constraint /// and remove the 32-bit range constraint opcode. -pub(crate) struct RangeOptimizer { +pub(crate) struct RangeOptimizer { /// Maps witnesses to their lowest known bit sizes. lists: BTreeMap, - circuit: Circuit, + circuit: Circuit, } -impl RangeOptimizer { +impl RangeOptimizer { /// Creates a new `RangeOptimizer` by collecting all known range /// constraints from `Circuit`. - pub(crate) fn new(circuit: Circuit) -> Self { + pub(crate) fn new(circuit: Circuit) -> Self { let range_list = Self::collect_ranges(&circuit); Self { circuit, lists: range_list } } @@ -46,7 +47,7 @@ impl RangeOptimizer { /// both 32 bits and 16 bits. This function will /// only store the fact that we have constrained it to /// be 16 bits. - fn collect_ranges(circuit: &Circuit) -> BTreeMap { + fn collect_ranges(circuit: &Circuit) -> BTreeMap { let mut witness_to_bit_sizes: BTreeMap = BTreeMap::new(); for opcode in &circuit.opcodes { @@ -95,7 +96,10 @@ impl RangeOptimizer { /// Returns a `Circuit` where each Witness is only range constrained /// once to the lowest number `bit size` possible. - pub(crate) fn replace_redundant_ranges(self, order_list: Vec) -> (Circuit, Vec) { + pub(crate) fn replace_redundant_ranges( + self, + order_list: Vec, + ) -> (Circuit, Vec) { let mut already_seen_witness = HashSet::new(); let mut new_order_list = Vec::with_capacity(order_list.len()); @@ -148,10 +152,11 @@ mod tests { Circuit, ExpressionWidth, Opcode, PublicInputs, }, native_types::{Expression, Witness}, + FieldElement, }; - fn test_circuit(ranges: Vec<(Witness, u32)>) -> Circuit { - fn test_range_constraint(witness: Witness, num_bits: u32) -> Opcode { + fn test_circuit(ranges: Vec<(Witness, u32)>) -> Circuit { + fn test_range_constraint(witness: Witness, num_bits: u32) -> Opcode { Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input: FunctionInput { witness, num_bits }, }) diff --git a/acvm-repo/acvm/src/compiler/optimizers/unused_memory.rs b/acvm-repo/acvm/src/compiler/optimizers/unused_memory.rs index 5fdcf54a49..1963430210 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/unused_memory.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/unused_memory.rs @@ -2,15 +2,15 @@ use acir::circuit::{opcodes::BlockId, Circuit, Opcode}; use std::collections::HashSet; /// `UnusedMemoryOptimizer` will remove initializations of memory blocks which are unused. -pub(crate) struct UnusedMemoryOptimizer { +pub(crate) struct UnusedMemoryOptimizer { unused_memory_initializations: HashSet, - circuit: Circuit, + circuit: Circuit, } -impl UnusedMemoryOptimizer { +impl UnusedMemoryOptimizer { /// Creates a new `UnusedMemoryOptimizer ` by collecting unused memory init /// opcodes from `Circuit`. - pub(crate) fn new(circuit: Circuit) -> Self { + pub(crate) fn new(circuit: Circuit) -> Self { let unused_memory_initializations = Self::collect_unused_memory_initializations(&circuit); Self { circuit, unused_memory_initializations } } @@ -18,7 +18,7 @@ impl UnusedMemoryOptimizer { /// Creates a set of ids for memory blocks for which no [`Opcode::MemoryOp`]s exist. /// /// These memory blocks can be safely removed. - fn collect_unused_memory_initializations(circuit: &Circuit) -> HashSet { + fn collect_unused_memory_initializations(circuit: &Circuit) -> HashSet { let mut unused_memory_initialization = HashSet::new(); for opcode in &circuit.opcodes { @@ -39,7 +39,7 @@ impl UnusedMemoryOptimizer { pub(crate) fn remove_unused_memory_initializations( self, order_list: Vec, - ) -> (Circuit, Vec) { + ) -> (Circuit, Vec) { let mut new_order_list = Vec::with_capacity(order_list.len()); let mut optimized_opcodes = Vec::with_capacity(self.circuit.opcodes.len()); for (idx, opcode) in self.circuit.opcodes.into_iter().enumerate() { diff --git a/acvm-repo/acvm/src/compiler/transformers/csat.rs b/acvm-repo/acvm/src/compiler/transformers/csat.rs index 12a37e3e37..6cf74c0420 100644 --- a/acvm-repo/acvm/src/compiler/transformers/csat.rs +++ b/acvm-repo/acvm/src/compiler/transformers/csat.rs @@ -2,7 +2,7 @@ use std::{cmp::Ordering, collections::HashSet}; use acir::{ native_types::{Expression, Witness}, - FieldElement, + AcirField, }; use indexmap::IndexMap; @@ -30,7 +30,7 @@ impl CSatTransformer { } /// Check if the equation 'expression=0' can be solved, and if yes, add the solved witness to set of solvable witness - fn try_solve(&mut self, opcode: &Expression) { + fn try_solve(&mut self, opcode: &Expression) { let mut unresolved = Vec::new(); for (_, w1, w2) in &opcode.mul_terms { if !self.solvable_witness.contains(w1) { @@ -64,12 +64,12 @@ impl CSatTransformer { // Still missing dead witness optimization. // To do this, we will need the whole set of assert-zero opcodes // I think it can also be done before the local optimization seen here, as dead variables will come from the user - pub(crate) fn transform( + pub(crate) fn transform( &mut self, - opcode: Expression, - intermediate_variables: &mut IndexMap, + opcode: Expression, + intermediate_variables: &mut IndexMap, (F, Witness)>, num_witness: &mut u32, - ) -> Expression { + ) -> Expression { // Here we create intermediate variables and constrain them to be equal to any subset of the polynomial that can be represented as a full opcode let opcode = self.full_opcode_scan_optimization(opcode, intermediate_variables, num_witness); @@ -107,12 +107,12 @@ impl CSatTransformer { // The polynomial now looks like so t + t2 // We can no longer extract another full opcode, hence the algorithm terminates. Creating two intermediate variables t and t2. // This stage of preprocessing does not guarantee that all polynomials can fit into a opcode. It only guarantees that all full opcodes have been extracted from each polynomial - fn full_opcode_scan_optimization( + fn full_opcode_scan_optimization( &mut self, - mut opcode: Expression, - intermediate_variables: &mut IndexMap, + mut opcode: Expression, + intermediate_variables: &mut IndexMap, (F, Witness)>, num_witness: &mut u32, - ) -> Expression { + ) -> Expression { // We pass around this intermediate variable IndexMap, so that we do not create intermediate variables that we have created before // One instance where this might happen is t1 = wL * wR and t2 = wR * wL @@ -245,7 +245,7 @@ impl CSatTransformer { /// Normalize an expression by dividing it by its first coefficient /// The first coefficient here means coefficient of the first linear term, or of the first quadratic term if no linear terms exist. /// The function panic if the input expression is constant - fn normalize(mut expr: Expression) -> (FieldElement, Expression) { + fn normalize(mut expr: Expression) -> (F, Expression) { expr.sort(); let a = if !expr.linear_combinations.is_empty() { expr.linear_combinations[0].0 @@ -259,11 +259,11 @@ impl CSatTransformer { /// The sets of previously generated witness and their (normalized) expression is cached in the intermediate_variables map /// If there is no cache hit, we generate a new witness (and add the expression to the cache) /// else, we return the cached witness along with the scaling factor so it is equal to the provided expression - fn get_or_create_intermediate_vars( - intermediate_variables: &mut IndexMap, - expr: Expression, + fn get_or_create_intermediate_vars( + intermediate_variables: &mut IndexMap, (F, Witness)>, + expr: Expression, num_witness: &mut u32, - ) -> (FieldElement, Witness) { + ) -> (F, Witness) { let (k, normalized_expr) = Self::normalize(expr); if intermediate_variables.contains_key(&normalized_expr) { @@ -274,7 +274,7 @@ impl CSatTransformer { *num_witness += 1; // Add intermediate opcode and variable to map intermediate_variables.insert(normalized_expr, (k, inter_var)); - (FieldElement::one(), inter_var) + (F::one(), inter_var) } } @@ -315,12 +315,12 @@ impl CSatTransformer { // Also remember that since we did full opcode scan, there is no way we can have a non-zero mul term along with the wL and wR terms being non-zero // // Cases, a lot of mul terms, a lot of fan-in terms, 50/50 - fn partial_opcode_scan_optimization( + fn partial_opcode_scan_optimization( &mut self, - mut opcode: Expression, - intermediate_variables: &mut IndexMap, + mut opcode: Expression, + intermediate_variables: &mut IndexMap, (F, Witness)>, num_witness: &mut u32, - ) -> Expression { + ) -> Expression { // We will go for the easiest route, which is to convert all multiplications into additions using intermediate variables // Then use intermediate variables again to squash the fan-in, so that it can fit into the appropriate width @@ -409,101 +409,113 @@ impl CSatTransformer { } } -#[test] -fn simple_reduction_smoke_test() { - let a = Witness(0); - let b = Witness(1); - let c = Witness(2); - let d = Witness(3); - - // a = b + c + d; - let opcode_a = Expression { - mul_terms: vec![], - linear_combinations: vec![ - (FieldElement::one(), a), - (-FieldElement::one(), b), - (-FieldElement::one(), c), - (-FieldElement::one(), d), - ], - q_c: FieldElement::zero(), - }; - - let mut intermediate_variables: IndexMap = IndexMap::new(); - - let mut num_witness = 4; - - let mut optimizer = CSatTransformer::new(3); - optimizer.mark_solvable(b); - optimizer.mark_solvable(c); - optimizer.mark_solvable(d); - let got_optimized_opcode_a = - optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness); - - // a = b + c + d => a - b - c - d = 0 - // For width3, the result becomes: - // a - d + e = 0 - // - c - b - e = 0 - // - // a - b + e = 0 - let e = Witness(4); - let expected_optimized_opcode_a = Expression { - mul_terms: vec![], - linear_combinations: vec![ - (FieldElement::one(), a), - (-FieldElement::one(), d), - (FieldElement::one(), e), - ], - q_c: FieldElement::zero(), - }; - assert_eq!(expected_optimized_opcode_a, got_optimized_opcode_a); - - assert_eq!(intermediate_variables.len(), 1); - - // e = - c - b - let expected_intermediate_opcode = Expression { - mul_terms: vec![], - linear_combinations: vec![(-FieldElement::one(), c), (-FieldElement::one(), b)], - q_c: FieldElement::zero(), - }; - let (_, normalized_opcode) = CSatTransformer::normalize(expected_intermediate_opcode); - assert!(intermediate_variables.contains_key(&normalized_opcode)); - assert_eq!(intermediate_variables[&normalized_opcode].1, e); -} +#[cfg(test)] +mod tests { + use super::*; + use acir::{AcirField, FieldElement}; + + #[test] + fn simple_reduction_smoke_test() { + let a = Witness(0); + let b = Witness(1); + let c = Witness(2); + let d = Witness(3); + + // a = b + c + d; + let opcode_a = Expression { + mul_terms: vec![], + linear_combinations: vec![ + (FieldElement::one(), a), + (-FieldElement::one(), b), + (-FieldElement::one(), c), + (-FieldElement::one(), d), + ], + q_c: FieldElement::zero(), + }; + + let mut intermediate_variables: IndexMap< + Expression, + (FieldElement, Witness), + > = IndexMap::new(); + + let mut num_witness = 4; + + let mut optimizer = CSatTransformer::new(3); + optimizer.mark_solvable(b); + optimizer.mark_solvable(c); + optimizer.mark_solvable(d); + let got_optimized_opcode_a = + optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness); + + // a = b + c + d => a - b - c - d = 0 + // For width3, the result becomes: + // a - d + e = 0 + // - c - b - e = 0 + // + // a - b + e = 0 + let e = Witness(4); + let expected_optimized_opcode_a = Expression { + mul_terms: vec![], + linear_combinations: vec![ + (FieldElement::one(), a), + (-FieldElement::one(), d), + (FieldElement::one(), e), + ], + q_c: FieldElement::zero(), + }; + assert_eq!(expected_optimized_opcode_a, got_optimized_opcode_a); + + assert_eq!(intermediate_variables.len(), 1); -#[test] -fn stepwise_reduction_test() { - let a = Witness(0); - let b = Witness(1); - let c = Witness(2); - let d = Witness(3); - let e = Witness(4); - - // a = b + c + d + e; - let opcode_a = Expression { - mul_terms: vec![], - linear_combinations: vec![ - (-FieldElement::one(), a), - (FieldElement::one(), b), - (FieldElement::one(), c), - (FieldElement::one(), d), - (FieldElement::one(), e), - ], - q_c: FieldElement::zero(), - }; - - let mut intermediate_variables: IndexMap = IndexMap::new(); - - let mut num_witness = 4; - - let mut optimizer = CSatTransformer::new(3); - optimizer.mark_solvable(a); - optimizer.mark_solvable(c); - optimizer.mark_solvable(d); - optimizer.mark_solvable(e); - let got_optimized_opcode_a = - optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness); - - // Since b is not known, it cannot be put inside intermediate opcodes, so it must belong to the transformed opcode. - let contains_b = got_optimized_opcode_a.linear_combinations.iter().any(|(_, w)| *w == b); - assert!(contains_b); + // e = - c - b + let expected_intermediate_opcode = Expression { + mul_terms: vec![], + linear_combinations: vec![(-FieldElement::one(), c), (-FieldElement::one(), b)], + q_c: FieldElement::zero(), + }; + let (_, normalized_opcode) = CSatTransformer::normalize(expected_intermediate_opcode); + assert!(intermediate_variables.contains_key(&normalized_opcode)); + assert_eq!(intermediate_variables[&normalized_opcode].1, e); + } + + #[test] + fn stepwise_reduction_test() { + let a = Witness(0); + let b = Witness(1); + let c = Witness(2); + let d = Witness(3); + let e = Witness(4); + + // a = b + c + d + e; + let opcode_a = Expression { + mul_terms: vec![], + linear_combinations: vec![ + (-FieldElement::one(), a), + (FieldElement::one(), b), + (FieldElement::one(), c), + (FieldElement::one(), d), + (FieldElement::one(), e), + ], + q_c: FieldElement::zero(), + }; + + let mut intermediate_variables: IndexMap< + Expression, + (FieldElement, Witness), + > = IndexMap::new(); + + let mut num_witness = 4; + + let mut optimizer = CSatTransformer::new(3); + optimizer.mark_solvable(a); + optimizer.mark_solvable(c); + optimizer.mark_solvable(d); + optimizer.mark_solvable(e); + let got_optimized_opcode_a = + optimizer.transform(opcode_a, &mut intermediate_variables, &mut num_witness); + + // Since b is not known, it cannot be put inside intermediate opcodes, so it must belong to the transformed opcode. + let contains_b = got_optimized_opcode_a.linear_combinations.iter().any(|(_, w)| *w == b); + assert!(contains_b); + } } diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index 0099519e4b..4fd8ba7883 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -1,7 +1,7 @@ use acir::{ circuit::{brillig::BrilligOutputs, directives::Directive, Circuit, ExpressionWidth, Opcode}, native_types::{Expression, Witness}, - FieldElement, + AcirField, }; use indexmap::IndexMap; @@ -12,10 +12,10 @@ pub(crate) use csat::CSatTransformer; use super::{transform_assert_messages, AcirTransformationMap}; /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] specific optimizations to a [`Circuit`]. -pub fn transform( - acir: Circuit, +pub fn transform( + acir: Circuit, expression_width: ExpressionWidth, -) -> (Circuit, AcirTransformationMap) { +) -> (Circuit, AcirTransformationMap) { // Track original acir opcode positions throughout the transformation passes of the compilation // by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert) let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect(); @@ -34,11 +34,11 @@ pub fn transform( /// /// Accepts an injected `acir_opcode_positions` to allow transformations to be applied directly after optimizations. #[tracing::instrument(level = "trace", name = "transform_acir", skip(acir, acir_opcode_positions))] -pub(super) fn transform_internal( - acir: Circuit, +pub(super) fn transform_internal( + acir: Circuit, expression_width: ExpressionWidth, acir_opcode_positions: Vec, -) -> (Circuit, Vec) { +) -> (Circuit, Vec) { let mut transformer = match &expression_width { ExpressionWidth::Unbounded => { return (acir, acir_opcode_positions); @@ -64,7 +64,7 @@ pub(super) fn transform_internal( let mut next_witness_index = acir.current_witness_index + 1; // maps a normalized expression to the intermediate variable which represents the expression, along with its 'norm' // the 'norm' is simply the value of the first non zero coefficient in the expression, taken from the linear terms, or quadratic terms if there is none. - let mut intermediate_variables: IndexMap = IndexMap::new(); + let mut intermediate_variables: IndexMap, (F, Witness)> = IndexMap::new(); for (index, opcode) in acir.opcodes.into_iter().enumerate() { match opcode { Opcode::AssertZero(arith_expr) => { @@ -83,7 +83,7 @@ pub(super) fn transform_internal( // de-normalize let mut intermediate_opcode = g * *norm; // constrain the intermediate opcode to the intermediate variable - intermediate_opcode.linear_combinations.push((-FieldElement::one(), *w)); + intermediate_opcode.linear_combinations.push((-F::one(), *w)); intermediate_opcode.sort(); new_opcodes.push(intermediate_opcode); } diff --git a/acvm-repo/acvm/src/lib.rs b/acvm-repo/acvm/src/lib.rs index 00a253fde0..4c64e1da74 100644 --- a/acvm-repo/acvm/src/lib.rs +++ b/acvm-repo/acvm/src/lib.rs @@ -11,7 +11,7 @@ use pwg::OpcodeResolutionError; // re-export acir pub use acir; -pub use acir::FieldElement; +pub use acir::{AcirField, FieldElement}; // re-export brillig vm pub use brillig_vm; // re-export blackbox solver diff --git a/acvm-repo/acvm/src/pwg/arithmetic.rs b/acvm-repo/acvm/src/pwg/arithmetic.rs index b971e4a0ef..5eeabd8a83 100644 --- a/acvm-repo/acvm/src/pwg/arithmetic.rs +++ b/acvm-repo/acvm/src/pwg/arithmetic.rs @@ -1,6 +1,6 @@ use acir::{ native_types::{Expression, Witness, WitnessMap}, - FieldElement, + AcirField, }; use super::{insert_value, ErrorLocation, OpcodeNotSolvable, OpcodeResolutionError}; @@ -10,24 +10,24 @@ use super::{insert_value, ErrorLocation, OpcodeNotSolvable, OpcodeResolutionErro pub(crate) struct ExpressionSolver; #[allow(clippy::enum_variant_names)] -pub(super) enum OpcodeStatus { - OpcodeSatisfied(FieldElement), - OpcodeSolvable(FieldElement, (FieldElement, Witness)), +pub(super) enum OpcodeStatus { + OpcodeSatisfied(F), + OpcodeSolvable(F, (F, Witness)), OpcodeUnsolvable, } -pub(crate) enum MulTerm { - OneUnknown(FieldElement, Witness), // (qM * known_witness, unknown_witness) +pub(crate) enum MulTerm { + OneUnknown(F, Witness), // (qM * known_witness, unknown_witness) TooManyUnknowns, - Solved(FieldElement), + Solved(F), } impl ExpressionSolver { /// Derives the rest of the witness based on the initial low level variables - pub(crate) fn solve( - initial_witness: &mut WitnessMap, - opcode: &Expression, - ) -> Result<(), OpcodeResolutionError> { + pub(crate) fn solve( + initial_witness: &mut WitnessMap, + opcode: &Expression, + ) -> Result<(), OpcodeResolutionError> { let opcode = &ExpressionSolver::evaluate(opcode, initial_witness); // Evaluate multiplication term let mul_result = @@ -133,14 +133,14 @@ impl ExpressionSolver { /// If the witness values are not known, then the function returns a None /// XXX: Do we need to account for the case where 5xy + 6x = 0 ? We do not know y, but it can be solved given x . But I believe x can be solved with another opcode /// XXX: What about making a mul opcode = a constant 5xy + 7 = 0 ? This is the same as the above. - fn solve_mul_term( - arith_opcode: &Expression, - witness_assignments: &WitnessMap, - ) -> Result { + fn solve_mul_term( + arith_opcode: &Expression, + witness_assignments: &WitnessMap, + ) -> Result, OpcodeStatus> { // First note that the mul term can only contain one/zero term // We are assuming it has been optimized. match arith_opcode.mul_terms.len() { - 0 => Ok(MulTerm::Solved(FieldElement::zero())), + 0 => Ok(MulTerm::Solved(F::zero())), 1 => Ok(ExpressionSolver::solve_mul_term_helper( &arith_opcode.mul_terms[0], witness_assignments, @@ -149,10 +149,10 @@ impl ExpressionSolver { } } - fn solve_mul_term_helper( - term: &(FieldElement, Witness, Witness), - witness_assignments: &WitnessMap, - ) -> MulTerm { + fn solve_mul_term_helper( + term: &(F, Witness, Witness), + witness_assignments: &WitnessMap, + ) -> MulTerm { let (q_m, w_l, w_r) = term; // Check if these values are in the witness assignments let w_l_value = witness_assignments.get(w_l); @@ -166,10 +166,10 @@ impl ExpressionSolver { } } - fn solve_fan_in_term_helper( - term: &(FieldElement, Witness), - witness_assignments: &WitnessMap, - ) -> Option { + fn solve_fan_in_term_helper( + term: &(F, Witness), + witness_assignments: &WitnessMap, + ) -> Option { let (q_l, w_l) = term; // Check if we have w_l let w_l_value = witness_assignments.get(w_l); @@ -179,17 +179,17 @@ impl ExpressionSolver { /// Returns the summation of all of the variables, plus the unknown variable /// Returns None, if there is more than one unknown variable /// We cannot assign - pub(super) fn solve_fan_in_term( - arith_opcode: &Expression, - witness_assignments: &WitnessMap, - ) -> OpcodeStatus { + pub(super) fn solve_fan_in_term( + arith_opcode: &Expression, + witness_assignments: &WitnessMap, + ) -> OpcodeStatus { // This is assuming that the fan-in is more than 0 // This is the variable that we want to assign the value to - let mut unknown_variable = (FieldElement::zero(), Witness::default()); + let mut unknown_variable = (F::zero(), Witness::default()); let mut num_unknowns = 0; // This is the sum of all of the known variables - let mut result = FieldElement::zero(); + let mut result = F::zero(); for term in arith_opcode.linear_combinations.iter() { let value = ExpressionSolver::solve_fan_in_term_helper(term, witness_assignments); @@ -215,7 +215,10 @@ impl ExpressionSolver { } // Partially evaluate the opcode using the known witnesses - pub(crate) fn evaluate(expr: &Expression, initial_witness: &WitnessMap) -> Expression { + pub(crate) fn evaluate( + expr: &Expression, + initial_witness: &WitnessMap, + ) -> Expression { let mut result = Expression::default(); for &(c, w1, w2) in &expr.mul_terms { let mul_result = ExpressionSolver::solve_mul_term_helper(&(c, w1, w2), initial_witness); @@ -245,43 +248,49 @@ impl ExpressionSolver { } } -#[test] -fn expression_solver_smoke_test() { - let a = Witness(0); - let b = Witness(1); - let c = Witness(2); - let d = Witness(3); +#[cfg(test)] +mod tests { + use super::*; + use acir::FieldElement; - // a = b + c + d; - let opcode_a = Expression { - mul_terms: vec![], - linear_combinations: vec![ - (FieldElement::one(), a), - (-FieldElement::one(), b), - (-FieldElement::one(), c), - (-FieldElement::one(), d), - ], - q_c: FieldElement::zero(), - }; + #[test] + fn expression_solver_smoke_test() { + let a = Witness(0); + let b = Witness(1); + let c = Witness(2); + let d = Witness(3); - let e = Witness(4); - let opcode_b = Expression { - mul_terms: vec![], - linear_combinations: vec![ - (FieldElement::one(), e), - (-FieldElement::one(), a), - (-FieldElement::one(), b), - ], - q_c: FieldElement::zero(), - }; + // a = b + c + d; + let opcode_a = Expression { + mul_terms: vec![], + linear_combinations: vec![ + (FieldElement::one(), a), + (-FieldElement::one(), b), + (-FieldElement::one(), c), + (-FieldElement::one(), d), + ], + q_c: FieldElement::zero(), + }; - let mut values = WitnessMap::new(); - values.insert(b, FieldElement::from(2_i128)); - values.insert(c, FieldElement::from(1_i128)); - values.insert(d, FieldElement::from(1_i128)); + let e = Witness(4); + let opcode_b = Expression { + mul_terms: vec![], + linear_combinations: vec![ + (FieldElement::one(), e), + (-FieldElement::one(), a), + (-FieldElement::one(), b), + ], + q_c: FieldElement::zero(), + }; - assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(())); - assert_eq!(ExpressionSolver::solve(&mut values, &opcode_b), Ok(())); + let mut values = WitnessMap::new(); + values.insert(b, FieldElement::from(2_i128)); + values.insert(c, FieldElement::from(1_i128)); + values.insert(d, FieldElement::from(1_i128)); - assert_eq!(values.get(&a).unwrap(), &FieldElement::from(4_i128)); + assert_eq!(ExpressionSolver::solve(&mut values, &opcode_a), Ok(())); + assert_eq!(ExpressionSolver::solve(&mut values, &opcode_b), Ok(())); + + assert_eq!(values.get(&a).unwrap(), &FieldElement::from(4_i128)); + } } diff --git a/acvm-repo/acvm/src/pwg/blackbox/aes128.rs b/acvm-repo/acvm/src/pwg/blackbox/aes128.rs index c02c59a174..181a78a2a6 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/aes128.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/aes128.rs @@ -1,7 +1,7 @@ use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, }; use acvm_blackbox_solver::aes128_encrypt; @@ -9,13 +9,13 @@ use crate::{pwg::insert_value, OpcodeResolutionError}; use super::utils::{to_u8_array, to_u8_vec}; -pub(super) fn solve_aes128_encryption_opcode( - initial_witness: &mut WitnessMap, +pub(super) fn solve_aes128_encryption_opcode( + initial_witness: &mut WitnessMap, inputs: &[FunctionInput], iv: &[FunctionInput; 16], key: &[FunctionInput; 16], outputs: &[Witness], -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let scalars = to_u8_vec(initial_witness, inputs)?; let iv = to_u8_array(initial_witness, iv)?; @@ -25,7 +25,7 @@ pub(super) fn solve_aes128_encryption_opcode( // Write witness assignments for (output_witness, value) in outputs.iter().zip(ciphertext.into_iter()) { - insert_value(output_witness, FieldElement::from(value as u128), initial_witness)?; + insert_value(output_witness, F::from(value as u128), initial_witness)?; } Ok(()) diff --git a/acvm-repo/acvm/src/pwg/blackbox/bigint.rs b/acvm-repo/acvm/src/pwg/blackbox/bigint.rs index 3c05fb2761..be5a4613a5 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/bigint.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/bigint.rs @@ -1,7 +1,7 @@ use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, - BlackBoxFunc, FieldElement, + AcirField, BlackBoxFunc, }; use acvm_blackbox_solver::BigIntSolver; @@ -18,13 +18,13 @@ pub(crate) struct AcvmBigIntSolver { } impl AcvmBigIntSolver { - pub(crate) fn bigint_from_bytes( + pub(crate) fn bigint_from_bytes( &mut self, inputs: &[FunctionInput], modulus: &[u8], output: u32, - initial_witness: &mut WitnessMap, - ) -> Result<(), OpcodeResolutionError> { + initial_witness: &mut WitnessMap, + ) -> Result<(), OpcodeResolutionError> { let bytes = inputs .iter() .map(|input| initial_witness.get(&input.witness).unwrap().to_u128() as u8) @@ -33,29 +33,29 @@ impl AcvmBigIntSolver { Ok(()) } - pub(crate) fn bigint_to_bytes( + pub(crate) fn bigint_to_bytes( &self, input: u32, outputs: &[Witness], - initial_witness: &mut WitnessMap, - ) -> Result<(), OpcodeResolutionError> { + initial_witness: &mut WitnessMap, + ) -> Result<(), OpcodeResolutionError> { let mut bytes = self.bigint_solver.bigint_to_bytes(input)?; while bytes.len() < outputs.len() { bytes.push(0); } bytes.iter().zip(outputs.iter()).for_each(|(byte, output)| { - initial_witness.insert(*output, FieldElement::from(*byte as u128)); + initial_witness.insert(*output, F::from(*byte as u128)); }); Ok(()) } - pub(crate) fn bigint_op( + pub(crate) fn bigint_op( &mut self, lhs: u32, rhs: u32, output: u32, func: BlackBoxFunc, - ) -> Result<(), OpcodeResolutionError> { + ) -> Result<(), OpcodeResolutionError> { self.bigint_solver.bigint_op(lhs, rhs, output, func)?; Ok(()) } diff --git a/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs b/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs index 0b52ae295a..411a6d1b73 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/embedded_curve_ops.rs @@ -1,18 +1,19 @@ use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, + AcirField, }; use acvm_blackbox_solver::BlackBoxFunctionSolver; use crate::pwg::{insert_value, witness_to_value, OpcodeResolutionError}; -pub(super) fn multi_scalar_mul( - backend: &impl BlackBoxFunctionSolver, - initial_witness: &mut WitnessMap, +pub(super) fn multi_scalar_mul( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, points: &[FunctionInput], scalars: &[FunctionInput], outputs: (Witness, Witness, Witness), -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let points: Result, _> = points.iter().map(|input| witness_to_value(initial_witness, input.witness)).collect(); let points: Vec<_> = points?.into_iter().cloned().collect(); @@ -39,13 +40,13 @@ pub(super) fn multi_scalar_mul( Ok(()) } -pub(super) fn embedded_curve_add( - backend: &impl BlackBoxFunctionSolver, - initial_witness: &mut WitnessMap, +pub(super) fn embedded_curve_add( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, input1: [FunctionInput; 3], input2: [FunctionInput; 3], outputs: (Witness, Witness, Witness), -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let input1_x = witness_to_value(initial_witness, input1[0].witness)?; let input1_y = witness_to_value(initial_witness, input1[1].witness)?; let input1_infinite = witness_to_value(initial_witness, input1[2].witness)?; diff --git a/acvm-repo/acvm/src/pwg/blackbox/hash.rs b/acvm-repo/acvm/src/pwg/blackbox/hash.rs index caa09ea897..fe9bd46b09 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/hash.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/hash.rs @@ -1,7 +1,7 @@ use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, }; use acvm_blackbox_solver::{sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError}; @@ -10,13 +10,13 @@ use crate::OpcodeResolutionError; /// Attempts to solve a 256 bit hash function opcode. /// If successful, `initial_witness` will be mutated to contain the new witness assignment. -pub(super) fn solve_generic_256_hash_opcode( - initial_witness: &mut WitnessMap, +pub(super) fn solve_generic_256_hash_opcode( + initial_witness: &mut WitnessMap, inputs: &[FunctionInput], var_message_size: Option<&FunctionInput>, outputs: &[Witness; 32], hash_function: fn(data: &[u8]) -> Result<[u8; 32], BlackBoxResolutionError>, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let message_input = get_hash_input(initial_witness, inputs, var_message_size)?; let digest: [u8; 32] = hash_function(&message_input)?; @@ -24,11 +24,11 @@ pub(super) fn solve_generic_256_hash_opcode( } /// Reads the hash function input from a [`WitnessMap`]. -fn get_hash_input( - initial_witness: &WitnessMap, +fn get_hash_input( + initial_witness: &WitnessMap, inputs: &[FunctionInput], message_size: Option<&FunctionInput>, -) -> Result, OpcodeResolutionError> { +) -> Result, OpcodeResolutionError> { // Read witness assignments. let mut message_input = Vec::new(); for input in inputs.iter() { @@ -62,26 +62,22 @@ fn get_hash_input( } /// Writes a `digest` to the [`WitnessMap`] at witness indices `outputs`. -fn write_digest_to_outputs( - initial_witness: &mut WitnessMap, +fn write_digest_to_outputs( + initial_witness: &mut WitnessMap, outputs: &[Witness; 32], digest: [u8; 32], -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { for (output_witness, value) in outputs.iter().zip(digest.into_iter()) { - insert_value( - output_witness, - FieldElement::from_be_bytes_reduce(&[value]), - initial_witness, - )?; + insert_value(output_witness, F::from_be_bytes_reduce(&[value]), initial_witness)?; } Ok(()) } -fn to_u32_array( - initial_witness: &WitnessMap, +fn to_u32_array( + initial_witness: &WitnessMap, inputs: &[FunctionInput; N], -) -> Result<[u32; N], OpcodeResolutionError> { +) -> Result<[u32; N], OpcodeResolutionError> { let mut result = [0; N]; for (it, input) in result.iter_mut().zip(inputs) { let witness_value = witness_to_value(initial_witness, input.witness)?; @@ -90,31 +86,31 @@ fn to_u32_array( Ok(result) } -pub(crate) fn solve_sha_256_permutation_opcode( - initial_witness: &mut WitnessMap, +pub(crate) fn solve_sha_256_permutation_opcode( + initial_witness: &mut WitnessMap, inputs: &[FunctionInput; 16], hash_values: &[FunctionInput; 8], outputs: &[Witness; 8], -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let message = to_u32_array(initial_witness, inputs)?; let mut state = to_u32_array(initial_witness, hash_values)?; sha256compression(&mut state, &message); for (output_witness, value) in outputs.iter().zip(state.into_iter()) { - insert_value(output_witness, FieldElement::from(value as u128), initial_witness)?; + insert_value(output_witness, F::from(value as u128), initial_witness)?; } Ok(()) } -pub(crate) fn solve_poseidon2_permutation_opcode( - backend: &impl BlackBoxFunctionSolver, - initial_witness: &mut WitnessMap, +pub(crate) fn solve_poseidon2_permutation_opcode( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, inputs: &[FunctionInput], outputs: &[Witness], len: u32, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { if len as usize != inputs.len() { return Err(OpcodeResolutionError::BlackBoxFunctionFailed( acir::BlackBoxFunc::Poseidon2Permutation, diff --git a/acvm-repo/acvm/src/pwg/blackbox/logic.rs b/acvm-repo/acvm/src/pwg/blackbox/logic.rs index 8e69730f71..6e2ade3c49 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/logic.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/logic.rs @@ -3,17 +3,17 @@ use crate::OpcodeResolutionError; use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, }; /// Solves a [`BlackBoxFunc::And`][acir::circuit::black_box_functions::BlackBoxFunc::AND] opcode and inserts /// the result into the supplied witness map -pub(super) fn and( - initial_witness: &mut WitnessMap, +pub(super) fn and( + initial_witness: &mut WitnessMap, lhs: &FunctionInput, rhs: &FunctionInput, output: &Witness, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { assert_eq!( lhs.num_bits, rhs.num_bits, "number of bits specified for each input must be the same" @@ -25,12 +25,12 @@ pub(super) fn and( /// Solves a [`BlackBoxFunc::XOR`][acir::circuit::black_box_functions::BlackBoxFunc::XOR] opcode and inserts /// the result into the supplied witness map -pub(super) fn xor( - initial_witness: &mut WitnessMap, +pub(super) fn xor( + initial_witness: &mut WitnessMap, lhs: &FunctionInput, rhs: &FunctionInput, output: &Witness, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { assert_eq!( lhs.num_bits, rhs.num_bits, "number of bits specified for each input must be the same" @@ -41,13 +41,13 @@ pub(super) fn xor( } /// Derives the rest of the witness based on the initial low level variables -fn solve_logic_opcode( - initial_witness: &mut WitnessMap, +fn solve_logic_opcode( + initial_witness: &mut WitnessMap, a: &Witness, b: &Witness, result: Witness, - logic_op: impl Fn(&FieldElement, &FieldElement) -> FieldElement, -) -> Result<(), OpcodeResolutionError> { + logic_op: impl Fn(&F, &F) -> F, +) -> Result<(), OpcodeResolutionError> { let w_l_value = witness_to_value(initial_witness, *a)?; let w_r_value = witness_to_value(initial_witness, *b)?; let assignment = logic_op(w_l_value, w_r_value); diff --git a/acvm-repo/acvm/src/pwg/blackbox/mod.rs b/acvm-repo/acvm/src/pwg/blackbox/mod.rs index 99ed09a52e..8bda9221d8 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/mod.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/mod.rs @@ -1,7 +1,7 @@ use acir::{ circuit::opcodes::{BlackBoxFuncCall, FunctionInput}, native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, }; use acvm_blackbox_solver::{blake2s, blake3, keccak256, keccakf1600, sha256}; @@ -37,8 +37,8 @@ use signature::{ /// Check if all of the inputs to the function have assignments /// /// Returns the first missing assignment if any are missing -fn first_missing_assignment( - witness_assignments: &WitnessMap, +fn first_missing_assignment( + witness_assignments: &WitnessMap, inputs: &[FunctionInput], ) -> Option { inputs.iter().find_map(|input| { @@ -51,16 +51,16 @@ fn first_missing_assignment( } /// Check if all of the inputs to the function have assignments -fn contains_all_inputs(witness_assignments: &WitnessMap, inputs: &[FunctionInput]) -> bool { +fn contains_all_inputs(witness_assignments: &WitnessMap, inputs: &[FunctionInput]) -> bool { inputs.iter().all(|input| witness_assignments.contains_key(&input.witness)) } -pub(crate) fn solve( - backend: &impl BlackBoxFunctionSolver, - initial_witness: &mut WitnessMap, +pub(crate) fn solve( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, bb_func: &BlackBoxFuncCall, bigint_solver: &mut AcvmBigIntSolver, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let inputs = bb_func.get_inputs_vec(); if !contains_all_inputs(initial_witness, &inputs) { let unassigned_witness = first_missing_assignment(initial_witness, &inputs) @@ -108,7 +108,7 @@ pub(crate) fn solve( } let output_state = keccakf1600(state)?; for (output_witness, value) in outputs.iter().zip(output_state.into_iter()) { - insert_value(output_witness, FieldElement::from(value as u128), initial_witness)?; + insert_value(output_witness, F::from(value as u128), initial_witness)?; } Ok(()) } diff --git a/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs b/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs index bb214c1cea..f64a3a7946 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/pedersen.rs @@ -1,6 +1,7 @@ use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, + AcirField, }; use crate::{ @@ -8,13 +9,13 @@ use crate::{ BlackBoxFunctionSolver, }; -pub(super) fn pedersen( - backend: &impl BlackBoxFunctionSolver, - initial_witness: &mut WitnessMap, +pub(super) fn pedersen( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, inputs: &[FunctionInput], domain_separator: u32, outputs: (Witness, Witness), -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let scalars: Result, _> = inputs.iter().map(|input| witness_to_value(initial_witness, input.witness)).collect(); let scalars: Vec<_> = scalars?.into_iter().cloned().collect(); @@ -27,13 +28,13 @@ pub(super) fn pedersen( Ok(()) } -pub(super) fn pedersen_hash( - backend: &impl BlackBoxFunctionSolver, - initial_witness: &mut WitnessMap, +pub(super) fn pedersen_hash( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, inputs: &[FunctionInput], domain_separator: u32, output: Witness, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let scalars: Result, _> = inputs.iter().map(|input| witness_to_value(initial_witness, input.witness)).collect(); let scalars: Vec<_> = scalars?.into_iter().cloned().collect(); diff --git a/acvm-repo/acvm/src/pwg/blackbox/range.rs b/acvm-repo/acvm/src/pwg/blackbox/range.rs index aac50b32fc..0ca001aff7 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/range.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/range.rs @@ -2,12 +2,12 @@ use crate::{ pwg::{witness_to_value, ErrorLocation}, OpcodeResolutionError, }; -use acir::{circuit::opcodes::FunctionInput, native_types::WitnessMap}; +use acir::{circuit::opcodes::FunctionInput, native_types::WitnessMap, AcirField}; -pub(crate) fn solve_range_opcode( - initial_witness: &WitnessMap, +pub(crate) fn solve_range_opcode( + initial_witness: &WitnessMap, input: &FunctionInput, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let w_value = witness_to_value(initial_witness, input.witness)?; if w_value.num_bits() > input.num_bits { return Err(OpcodeResolutionError::UnsatisfiedConstrain { diff --git a/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs b/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs index ce2e57e0bd..707e3f26af 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/signature/ecdsa.rs @@ -1,7 +1,7 @@ use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, }; use acvm_blackbox_solver::{ecdsa_secp256k1_verify, ecdsa_secp256r1_verify}; @@ -13,14 +13,14 @@ use crate::{ OpcodeResolutionError, }; -pub(crate) fn secp256k1_prehashed( - initial_witness: &mut WitnessMap, +pub(crate) fn secp256k1_prehashed( + initial_witness: &mut WitnessMap, public_key_x_inputs: &[FunctionInput; 32], public_key_y_inputs: &[FunctionInput; 32], signature_inputs: &[FunctionInput; 64], hashed_message_inputs: &[FunctionInput], output: Witness, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let hashed_message = to_u8_vec(initial_witness, hashed_message_inputs)?; let pub_key_x: [u8; 32] = to_u8_array(initial_witness, public_key_x_inputs)?; @@ -29,17 +29,17 @@ pub(crate) fn secp256k1_prehashed( let is_valid = ecdsa_secp256k1_verify(&hashed_message, &pub_key_x, &pub_key_y, &signature)?; - insert_value(&output, FieldElement::from(is_valid), initial_witness) + insert_value(&output, F::from(is_valid), initial_witness) } -pub(crate) fn secp256r1_prehashed( - initial_witness: &mut WitnessMap, +pub(crate) fn secp256r1_prehashed( + initial_witness: &mut WitnessMap, public_key_x_inputs: &[FunctionInput; 32], public_key_y_inputs: &[FunctionInput; 32], signature_inputs: &[FunctionInput; 64], hashed_message_inputs: &[FunctionInput], output: Witness, -) -> Result<(), OpcodeResolutionError> { +) -> Result<(), OpcodeResolutionError> { let hashed_message = to_u8_vec(initial_witness, hashed_message_inputs)?; let pub_key_x: [u8; 32] = to_u8_array(initial_witness, public_key_x_inputs)?; @@ -48,5 +48,5 @@ pub(crate) fn secp256r1_prehashed( let is_valid = ecdsa_secp256r1_verify(&hashed_message, &pub_key_x, &pub_key_y, &signature)?; - insert_value(&output, FieldElement::from(is_valid), initial_witness) + insert_value(&output, F::from(is_valid), initial_witness) } diff --git a/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs b/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs index 7b085d9ff4..5e0ac94f8b 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/signature/schnorr.rs @@ -8,21 +8,21 @@ use crate::{ use acir::{ circuit::opcodes::FunctionInput, native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, }; #[allow(clippy::too_many_arguments)] -pub(crate) fn schnorr_verify( - backend: &impl BlackBoxFunctionSolver, - initial_witness: &mut WitnessMap, +pub(crate) fn schnorr_verify( + backend: &impl BlackBoxFunctionSolver, + initial_witness: &mut WitnessMap, public_key_x: FunctionInput, public_key_y: FunctionInput, signature: &[FunctionInput; 64], message: &[FunctionInput], output: Witness, -) -> Result<(), OpcodeResolutionError> { - let public_key_x: &FieldElement = witness_to_value(initial_witness, public_key_x.witness)?; - let public_key_y: &FieldElement = witness_to_value(initial_witness, public_key_y.witness)?; +) -> Result<(), OpcodeResolutionError> { + let public_key_x: &F = witness_to_value(initial_witness, public_key_x.witness)?; + let public_key_y: &F = witness_to_value(initial_witness, public_key_y.witness)?; let signature = to_u8_array(initial_witness, signature)?; let message = to_u8_vec(initial_witness, message)?; @@ -30,7 +30,7 @@ pub(crate) fn schnorr_verify( let valid_signature = backend.schnorr_verify(public_key_x, public_key_y, &signature, &message)?; - insert_value(&output, FieldElement::from(valid_signature), initial_witness)?; + insert_value(&output, F::from(valid_signature), initial_witness)?; Ok(()) } diff --git a/acvm-repo/acvm/src/pwg/blackbox/utils.rs b/acvm-repo/acvm/src/pwg/blackbox/utils.rs index 700f30890a..6880d21a32 100644 --- a/acvm-repo/acvm/src/pwg/blackbox/utils.rs +++ b/acvm-repo/acvm/src/pwg/blackbox/utils.rs @@ -1,11 +1,11 @@ -use acir::{circuit::opcodes::FunctionInput, native_types::WitnessMap}; +use acir::{circuit::opcodes::FunctionInput, native_types::WitnessMap, AcirField}; use crate::pwg::{witness_to_value, OpcodeResolutionError}; -pub(crate) fn to_u8_array( - initial_witness: &WitnessMap, +pub(crate) fn to_u8_array( + initial_witness: &WitnessMap, inputs: &[FunctionInput; N], -) -> Result<[u8; N], OpcodeResolutionError> { +) -> Result<[u8; N], OpcodeResolutionError> { let mut result = [0; N]; for (it, input) in result.iter_mut().zip(inputs) { let witness_value_bytes = witness_to_value(initial_witness, input.witness)?.to_be_bytes(); @@ -17,10 +17,10 @@ pub(crate) fn to_u8_array( Ok(result) } -pub(crate) fn to_u8_vec( - initial_witness: &WitnessMap, +pub(crate) fn to_u8_vec( + initial_witness: &WitnessMap, inputs: &[FunctionInput], -) -> Result, OpcodeResolutionError> { +) -> Result, OpcodeResolutionError> { let mut result = Vec::with_capacity(inputs.len()); for input in inputs { let witness_value_bytes = witness_to_value(initial_witness, input.witness)?.to_be_bytes(); diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index c911202c82..7e6c207b69 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -9,7 +9,7 @@ use acir::{ STRING_ERROR_SELECTOR, }, native_types::WitnessMap, - FieldElement, + AcirField, }; use acvm_blackbox_solver::BlackBoxFunctionSolver; use brillig_vm::{FailureReason, MemoryValue, VMStatus, VM}; @@ -19,31 +19,31 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError}; use super::{get_value, insert_value, memory_op::MemoryOpSolver}; #[derive(Debug)] -pub enum BrilligSolverStatus { +pub enum BrilligSolverStatus { Finished, InProgress, - ForeignCallWait(ForeignCallWaitInfo), + ForeignCallWait(ForeignCallWaitInfo), } -pub struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { - vm: VM<'b, B>, +pub struct BrilligSolver<'b, F, B: BlackBoxFunctionSolver> { + vm: VM<'b, F, B>, acir_index: usize, } -impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { +impl<'b, B: BlackBoxFunctionSolver, F: AcirField> BrilligSolver<'b, F, B> { /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. pub(super) fn zero_out_brillig_outputs( - initial_witness: &mut WitnessMap, + initial_witness: &mut WitnessMap, outputs: &[BrilligOutputs], - ) -> Result<(), OpcodeResolutionError> { + ) -> Result<(), OpcodeResolutionError> { for output in outputs { match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, FieldElement::zero(), initial_witness)?; + insert_value(witness, F::zero(), initial_witness)?; } BrilligOutputs::Array(witness_arr) => { for witness in witness_arr { - insert_value(witness, FieldElement::zero(), initial_witness)?; + insert_value(witness, F::zero(), initial_witness)?; } } } @@ -54,27 +54,27 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { /// Constructs a solver for a Brillig block given the bytecode and initial /// witness. pub(crate) fn new_call( - initial_witness: &WitnessMap, - memory: &HashMap, - inputs: &'b [BrilligInputs], - brillig_bytecode: &'b [BrilligOpcode], + initial_witness: &WitnessMap, + memory: &HashMap>, + inputs: &'b [BrilligInputs], + brillig_bytecode: &'b [BrilligOpcode], bb_solver: &'b B, acir_index: usize, - ) -> Result { + ) -> Result> { let vm = Self::setup_brillig_vm(initial_witness, memory, inputs, brillig_bytecode, bb_solver)?; Ok(Self { vm, acir_index }) } fn setup_brillig_vm( - initial_witness: &WitnessMap, - memory: &HashMap, - inputs: &[BrilligInputs], - brillig_bytecode: &'b [BrilligOpcode], + initial_witness: &WitnessMap, + memory: &HashMap>, + inputs: &[BrilligInputs], + brillig_bytecode: &'b [BrilligOpcode], bb_solver: &'b B, - ) -> Result, OpcodeResolutionError> { + ) -> Result, OpcodeResolutionError> { // Set input values - let mut calldata: Vec = Vec::new(); + let mut calldata: Vec = Vec::new(); // Each input represents an expression or array of expressions to evaluate. // Iterate over each input and evaluate the expression(s) associated with it. // Push the results into memory. @@ -123,11 +123,11 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { Ok(vm) } - pub fn get_memory(&self) -> &[MemoryValue] { + pub fn get_memory(&self) -> &[MemoryValue] { self.vm.get_memory() } - pub fn write_memory_at(&mut self, ptr: usize, value: MemoryValue) { + pub fn write_memory_at(&mut self, ptr: usize, value: MemoryValue) { self.vm.write_memory_at(ptr, value); } @@ -135,12 +135,12 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { self.vm.get_call_stack() } - pub(crate) fn solve(&mut self) -> Result { + pub(crate) fn solve(&mut self) -> Result, OpcodeResolutionError> { let status = self.vm.process_opcodes(); self.handle_vm_status(status) } - pub fn step(&mut self) -> Result { + pub fn step(&mut self) -> Result, OpcodeResolutionError> { let status = self.vm.process_opcode(); self.handle_vm_status(status) } @@ -151,8 +151,8 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { fn handle_vm_status( &self, - vm_status: VMStatus, - ) -> Result { + vm_status: VMStatus, + ) -> Result, OpcodeResolutionError> { // Check the status of the Brillig VM and return a resolution. // It may be finished, in-progress, failed, or may be waiting for results of a foreign call. // Return the "resolution" to the caller who may choose to make subsequent calls @@ -227,9 +227,9 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { pub(crate) fn finalize( self, - witness: &mut WitnessMap, + witness: &mut WitnessMap, outputs: &[BrilligOutputs], - ) -> Result<(), OpcodeResolutionError> { + ) -> Result<(), OpcodeResolutionError> { // Finish the Brillig execution by writing the outputs to the witness map let vm_status = self.vm.get_status(); match vm_status { @@ -243,11 +243,11 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { fn write_brillig_outputs( &self, - witness_map: &mut WitnessMap, + witness_map: &mut WitnessMap, return_data_offset: usize, return_data_size: usize, outputs: &[BrilligOutputs], - ) -> Result<(), OpcodeResolutionError> { + ) -> Result<(), OpcodeResolutionError> { // Write VM execution results into the witness map let memory = self.vm.get_memory(); let mut current_ret_data_idx = return_data_offset; @@ -274,7 +274,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { Ok(()) } - pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { match self.vm.get_status() { VMStatus::ForeignCallWait { .. } => self.vm.resolve_foreign_call(foreign_call_result), _ => unreachable!("Brillig VM is not waiting for a foreign call"), @@ -287,9 +287,9 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { /// /// The caller must resolve this opcode externally based upon the information in the request. #[derive(Debug, PartialEq, Clone)] -pub struct ForeignCallWaitInfo { +pub struct ForeignCallWaitInfo { /// An identifier interpreted by the caller process pub function: String, /// Resolved inputs to a foreign call computed in the previous steps of a Brillig VM process - pub inputs: Vec, + pub inputs: Vec>, } diff --git a/acvm-repo/acvm/src/pwg/directives/mod.rs b/acvm-repo/acvm/src/pwg/directives/mod.rs index db79379a37..d7bee88c27 100644 --- a/acvm-repo/acvm/src/pwg/directives/mod.rs +++ b/acvm-repo/acvm/src/pwg/directives/mod.rs @@ -1,4 +1,4 @@ -use acir::{circuit::directives::Directive, native_types::WitnessMap, FieldElement}; +use acir::{circuit::directives::Directive, native_types::WitnessMap, AcirField}; use num_bigint::BigUint; use crate::OpcodeResolutionError; @@ -11,10 +11,10 @@ use super::{get_value, insert_value, ErrorLocation}; /// Returns `Ok(OpcodeResolution)` to signal whether the directive was successful solved. /// /// Returns `Err(OpcodeResolutionError)` if a circuit constraint is unsatisfied. -pub(crate) fn solve_directives( - initial_witness: &mut WitnessMap, - directive: &Directive, -) -> Result<(), OpcodeResolutionError> { +pub(crate) fn solve_directives( + initial_witness: &mut WitnessMap, + directive: &Directive, +) -> Result<(), OpcodeResolutionError> { match directive { Directive::ToLeRadix { a, b, radix } => { let value_a = get_value(a, initial_witness)?; @@ -36,8 +36,8 @@ pub(crate) fn solve_directives( // If it is not available, which can happen when the decomposed integer // list is shorter than the witness list, we return 0. let value = match decomposed_integer.get(i) { - Some(digit) => FieldElement::from_be_bytes_reduce(&[*digit]), - None => FieldElement::zero(), + Some(digit) => F::from_be_bytes_reduce(&[*digit]), + None => F::zero(), }; insert_value(witness, value, initial_witness)?; diff --git a/acvm-repo/acvm/src/pwg/memory_op.rs b/acvm-repo/acvm/src/pwg/memory_op.rs index 672c13e11c..a9ed7f5d15 100644 --- a/acvm-repo/acvm/src/pwg/memory_op.rs +++ b/acvm-repo/acvm/src/pwg/memory_op.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use acir::{ circuit::opcodes::MemOp, native_types::{Expression, Witness, WitnessMap}, - FieldElement, + AcirField, }; use super::{ @@ -15,17 +15,17 @@ type MemoryIndex = u32; /// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes. #[derive(Default)] -pub(crate) struct MemoryOpSolver { - pub(super) block_value: HashMap, +pub(crate) struct MemoryOpSolver { + pub(super) block_value: HashMap, pub(super) block_len: u32, } -impl MemoryOpSolver { +impl MemoryOpSolver { fn write_memory_index( &mut self, index: MemoryIndex, - value: FieldElement, - ) -> Result<(), OpcodeResolutionError> { + value: F, + ) -> Result<(), OpcodeResolutionError> { if index >= self.block_len { return Err(OpcodeResolutionError::IndexOutOfBounds { opcode_location: ErrorLocation::Unresolved, @@ -37,7 +37,7 @@ impl MemoryOpSolver { Ok(()) } - fn read_memory_index(&self, index: MemoryIndex) -> Result { + fn read_memory_index(&self, index: MemoryIndex) -> Result> { self.block_value.get(&index).copied().ok_or(OpcodeResolutionError::IndexOutOfBounds { opcode_location: ErrorLocation::Unresolved, index, @@ -49,8 +49,8 @@ impl MemoryOpSolver { pub(crate) fn init( &mut self, init: &[Witness], - initial_witness: &WitnessMap, - ) -> Result<(), OpcodeResolutionError> { + initial_witness: &WitnessMap, + ) -> Result<(), OpcodeResolutionError> { self.block_len = init.len() as u32; for (memory_index, witness) in init.iter().enumerate() { self.write_memory_index( @@ -63,10 +63,10 @@ impl MemoryOpSolver { pub(crate) fn solve_memory_op( &mut self, - op: &MemOp, - initial_witness: &mut WitnessMap, - predicate: &Option, - ) -> Result<(), OpcodeResolutionError> { + op: &MemOp, + initial_witness: &mut WitnessMap, + predicate: &Option>, + ) -> Result<(), OpcodeResolutionError> { let operation = get_value(&op.operation, initial_witness)?; // Find the memory index associated with this memory operation. @@ -96,11 +96,8 @@ impl MemoryOpSolver { // A zero predicate indicates that we should skip the read operation // and zero out the operation's output. - let value_in_array = if skip_operation { - FieldElement::zero() - } else { - self.read_memory_index(memory_index)? - }; + let value_in_array = + if skip_operation { F::zero() } else { self.read_memory_index(memory_index)? }; insert_value(&value_read_witness, value_in_array, initial_witness) } else { // `arr[memory_index] = value_write` @@ -129,7 +126,7 @@ mod tests { use acir::{ circuit::opcodes::MemOp, native_types::{Expression, Witness, WitnessMap}, - FieldElement, + AcirField, FieldElement, }; use super::MemoryOpSolver; diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index f2649b9399..da4510db63 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -10,7 +10,7 @@ use acir::{ STRING_ERROR_SELECTOR, }, native_types::{Expression, Witness, WitnessMap}, - BlackBoxFunc, FieldElement, + AcirField, BlackBoxFunc, }; use acvm_blackbox_solver::BlackBoxResolutionError; @@ -36,7 +36,7 @@ pub use self::brillig::{BrilligSolver, BrilligSolverStatus}; pub use brillig::ForeignCallWaitInfo; #[derive(Debug, Clone, PartialEq)] -pub enum ACVMStatus { +pub enum ACVMStatus { /// All opcodes have been solved. Solved, @@ -45,23 +45,23 @@ pub enum ACVMStatus { /// The ACVM has encountered an irrecoverable error while executing the circuit and can not progress. /// Most commonly this will be due to an unsatisfied constraint due to invalid inputs to the circuit. - Failure(OpcodeResolutionError), + Failure(OpcodeResolutionError), /// The ACVM has encountered a request for a Brillig [foreign call][acir::brillig_vm::Opcode::ForeignCall] /// to retrieve information from outside of the ACVM. The result of the foreign call must be passed back /// to the ACVM using [`ACVM::resolve_pending_foreign_call`]. /// /// Once this is done, the ACVM can be restarted to solve the remaining opcodes. - RequiresForeignCall(ForeignCallWaitInfo), + RequiresForeignCall(ForeignCallWaitInfo), /// The ACVM has encountered a request for an ACIR [call][acir::circuit::Opcode] /// to execute a separate ACVM instance. The result of the ACIR call must be passd back to the ACVM. /// /// Once this is done, the ACVM can be restarted to solve the remaining opcodes. - RequiresAcirCall(AcirCallWaitInfo), + RequiresAcirCall(AcirCallWaitInfo), } -impl std::fmt::Display for ACVMStatus { +impl std::fmt::Display for ACVMStatus { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { ACVMStatus::Solved => write!(f, "Solved"), @@ -73,9 +73,9 @@ impl std::fmt::Display for ACVMStatus { } } -pub enum StepResult<'a, B: BlackBoxFunctionSolver> { - Status(ACVMStatus), - IntoBrillig(BrilligSolver<'a, B>), +pub enum StepResult<'a, F, B: BlackBoxFunctionSolver> { + Status(ACVMStatus), + IntoBrillig(BrilligSolver<'a, F, B>), } // This enum represents the different cases in which an @@ -87,13 +87,13 @@ pub enum StepResult<'a, B: BlackBoxFunctionSolver> { // TODO: we could have a error enum for expression solver failure cases in that module // TODO that can be converted into an OpcodeNotSolvable or OpcodeResolutionError enum #[derive(Clone, PartialEq, Eq, Debug, Error)] -pub enum OpcodeNotSolvable { +pub enum OpcodeNotSolvable { #[error("missing assignment for witness index {0}")] MissingAssignment(u32), #[error("Attempted to load uninitialized memory block")] MissingMemoryBlock(u32), #[error("expression has too many unknowns {0}")] - ExpressionHasTooManyUnknowns(Expression), + ExpressionHasTooManyUnknowns(Expression), } /// Allows to point to a specific opcode as cause in errors. @@ -117,13 +117,13 @@ impl std::fmt::Display for ErrorLocation { } #[derive(Clone, PartialEq, Eq, Debug, Error)] -pub enum OpcodeResolutionError { +pub enum OpcodeResolutionError { #[error("Cannot solve opcode: {0}")] - OpcodeNotSolvable(#[from] OpcodeNotSolvable), + OpcodeNotSolvable(#[from] OpcodeNotSolvable), #[error("Cannot satisfy constraint")] UnsatisfiedConstrain { opcode_location: ErrorLocation, - payload: Option, + payload: Option>, }, #[error("Index out of bounds, array has size {array_size:?}, but index was {index:?}")] IndexOutOfBounds { opcode_location: ErrorLocation, index: u32, array_size: u32 }, @@ -132,7 +132,7 @@ pub enum OpcodeResolutionError { #[error("Failed to solve brillig function")] BrilligFunctionFailed { call_stack: Vec, - payload: Option, + payload: Option>, }, #[error("Attempted to call `main` with a `Call` opcode")] AcirMainCallAttempted { opcode_location: ErrorLocation }, @@ -140,7 +140,7 @@ pub enum OpcodeResolutionError { AcirCallOutputsMismatch { opcode_location: ErrorLocation, results_size: u32, outputs_size: u32 }, } -impl From for OpcodeResolutionError { +impl From for OpcodeResolutionError { fn from(value: BlackBoxResolutionError) -> Self { match value { BlackBoxResolutionError::Failed(func, reason) => { @@ -150,45 +150,45 @@ impl From for OpcodeResolutionError { } } -pub struct ACVM<'a, B: BlackBoxFunctionSolver> { - status: ACVMStatus, +pub struct ACVM<'a, F, B: BlackBoxFunctionSolver> { + status: ACVMStatus, backend: &'a B, /// Stores the solver for memory operations acting on blocks of memory disambiguated by [block][`BlockId`]. - block_solvers: HashMap, + block_solvers: HashMap>, bigint_solver: AcvmBigIntSolver, /// A list of opcodes which are to be executed by the ACVM. - opcodes: &'a [Opcode], + opcodes: &'a [Opcode], /// Index of the next opcode to be executed. instruction_pointer: usize, - witness_map: WitnessMap, + witness_map: WitnessMap, - brillig_solver: Option>, + brillig_solver: Option>, /// A counter maintained throughout an ACVM process that determines /// whether the caller has resolved the results of an ACIR [call][Opcode::Call]. acir_call_counter: usize, /// Represents the outputs of all ACIR calls during an ACVM process /// List is appended onto by the caller upon reaching a [ACVMStatus::RequiresAcirCall] - acir_call_results: Vec>, + acir_call_results: Vec>, // Each unconstrained function referenced in the program - unconstrained_functions: &'a [BrilligBytecode], + unconstrained_functions: &'a [BrilligBytecode], - assertion_payloads: &'a [(OpcodeLocation, AssertionPayload)], + assertion_payloads: &'a [(OpcodeLocation, AssertionPayload)], } -impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { +impl<'a, F: AcirField, B: BlackBoxFunctionSolver> ACVM<'a, F, B> { pub fn new( backend: &'a B, - opcodes: &'a [Opcode], - initial_witness: WitnessMap, - unconstrained_functions: &'a [BrilligBytecode], - assertion_payloads: &'a [(OpcodeLocation, AssertionPayload)], + opcodes: &'a [Opcode], + initial_witness: WitnessMap, + unconstrained_functions: &'a [BrilligBytecode], + assertion_payloads: &'a [(OpcodeLocation, AssertionPayload)], ) -> Self { let status = if opcodes.is_empty() { ACVMStatus::Solved } else { ACVMStatus::InProgress }; ACVM { @@ -210,20 +210,16 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { /// Returns a reference to the current state of the ACVM's [`WitnessMap`]. /// /// Once execution has completed, the witness map can be extracted using [`ACVM::finalize`] - pub fn witness_map(&self) -> &WitnessMap { + pub fn witness_map(&self) -> &WitnessMap { &self.witness_map } - pub fn overwrite_witness( - &mut self, - witness: Witness, - value: FieldElement, - ) -> Option { + pub fn overwrite_witness(&mut self, witness: Witness, value: F) -> Option { self.witness_map.insert(witness, value) } /// Returns a slice containing the opcodes of the circuit being executed. - pub fn opcodes(&self) -> &[Opcode] { + pub fn opcodes(&self) -> &[Opcode] { self.opcodes } @@ -233,7 +229,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { } /// Finalize the ACVM execution, returning the resulting [`WitnessMap`]. - pub fn finalize(self) -> WitnessMap { + pub fn finalize(self) -> WitnessMap { if self.status != ACVMStatus::Solved { panic!("ACVM execution is not complete: ({})", self.status); } @@ -242,29 +238,29 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { /// Updates the current status of the VM. /// Returns the given status. - fn status(&mut self, status: ACVMStatus) -> ACVMStatus { + fn status(&mut self, status: ACVMStatus) -> ACVMStatus { self.status = status.clone(); status } - pub fn get_status(&self) -> &ACVMStatus { + pub fn get_status(&self) -> &ACVMStatus { &self.status } /// Sets the VM status to [ACVMStatus::Failure] using the provided `error`. /// Returns the new status. - fn fail(&mut self, error: OpcodeResolutionError) -> ACVMStatus { + fn fail(&mut self, error: OpcodeResolutionError) -> ACVMStatus { self.status(ACVMStatus::Failure(error)) } /// Sets the status of the VM to `RequiresForeignCall`. /// Indicating that the VM is now waiting for a foreign call to be resolved. - fn wait_for_foreign_call(&mut self, foreign_call: ForeignCallWaitInfo) -> ACVMStatus { + fn wait_for_foreign_call(&mut self, foreign_call: ForeignCallWaitInfo) -> ACVMStatus { self.status(ACVMStatus::RequiresForeignCall(foreign_call)) } /// Return a reference to the arguments for the next pending foreign call, if one exists. - pub fn get_pending_foreign_call(&self) -> Option<&ForeignCallWaitInfo> { + pub fn get_pending_foreign_call(&self) -> Option<&ForeignCallWaitInfo> { if let ACVMStatus::RequiresForeignCall(foreign_call) = &self.status { Some(foreign_call) } else { @@ -275,7 +271,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { /// Resolves a foreign call's [result][acir::brillig_vm::ForeignCallResult] using a result calculated outside of the ACVM. /// /// The ACVM can then be restarted to solve the remaining Brillig VM process as well as the remaining ACIR opcodes. - pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + pub fn resolve_pending_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { if !matches!(self.status, ACVMStatus::RequiresForeignCall(_)) { panic!("ACVM is not expecting a foreign call response as no call was made"); } @@ -289,14 +285,14 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { /// Sets the status of the VM to `RequiresAcirCall` /// Indicating that the VM is now waiting for an ACIR call to be resolved - fn wait_for_acir_call(&mut self, acir_call: AcirCallWaitInfo) -> ACVMStatus { + fn wait_for_acir_call(&mut self, acir_call: AcirCallWaitInfo) -> ACVMStatus { self.status(ACVMStatus::RequiresAcirCall(acir_call)) } /// Resolves an ACIR call's result (simply a list of fields) using a result calculated by a separate ACVM instance. /// /// The current ACVM instance can then be restarted to solve the remaining ACIR opcodes. - pub fn resolve_pending_acir_call(&mut self, call_result: Vec) { + pub fn resolve_pending_acir_call(&mut self, call_result: Vec) { if !matches!(self.status, ACVMStatus::RequiresAcirCall(_)) { panic!("ACVM is not expecting an ACIR call response as no call was made"); } @@ -316,14 +312,14 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { /// 1. All opcodes have been executed successfully. /// 2. The circuit has been found to be unsatisfiable. /// 2. A Brillig [foreign call][`ForeignCallWaitInfo`] has been encountered and must be resolved. - pub fn solve(&mut self) -> ACVMStatus { + pub fn solve(&mut self) -> ACVMStatus { while self.status == ACVMStatus::InProgress { self.solve_opcode(); } self.status.clone() } - pub fn solve_opcode(&mut self) -> ACVMStatus { + pub fn solve_opcode(&mut self) -> ACVMStatus { let opcode = &self.opcodes[self.instruction_pointer]; let resolution = match opcode { @@ -357,8 +353,8 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { fn handle_opcode_resolution( &mut self, - resolution: Result<(), OpcodeResolutionError>, - ) -> ACVMStatus { + resolution: Result<(), OpcodeResolutionError>, + ) -> ACVMStatus { match resolution { Ok(()) => { self.instruction_pointer += 1; @@ -400,7 +396,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { fn extract_assertion_payload( &self, location: OpcodeLocation, - ) -> Option { + ) -> Option> { let (_, found_assertion_payload) = self.assertion_payloads.iter().find(|(loc, _)| location == *loc)?; match found_assertion_payload { @@ -458,7 +454,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { fn solve_brillig_call_opcode( &mut self, - ) -> Result, OpcodeResolutionError> { + ) -> Result>, OpcodeResolutionError> { let Opcode::BrilligCall { id, inputs, outputs, predicate } = &self.opcodes[self.instruction_pointer] else { @@ -466,13 +462,13 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { }; if is_predicate_false(&self.witness_map, predicate)? { - return BrilligSolver::::zero_out_brillig_outputs(&mut self.witness_map, outputs) + return BrilligSolver::::zero_out_brillig_outputs(&mut self.witness_map, outputs) .map(|_| None); } // If we're resuming execution after resolving a foreign call then // there will be a cached `BrilligSolver` to avoid recomputation. - let mut solver: BrilligSolver<'_, B> = match self.brillig_solver.take() { + let mut solver: BrilligSolver<'_, F, B> = match self.brillig_solver.take() { Some(solver) => solver, None => BrilligSolver::new_call( &self.witness_map, @@ -503,7 +499,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { } } - fn map_brillig_error(&self, mut err: OpcodeResolutionError) -> OpcodeResolutionError { + fn map_brillig_error(&self, mut err: OpcodeResolutionError) -> OpcodeResolutionError { match &mut err { OpcodeResolutionError::BrilligFunctionFailed { call_stack, payload } => { // Some brillig errors have static strings as payloads, we can resolve them here @@ -528,7 +524,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { } } - pub fn step_into_brillig(&mut self) -> StepResult<'a, B> { + pub fn step_into_brillig(&mut self) -> StepResult<'a, F, B> { let Opcode::BrilligCall { id, inputs, outputs, predicate } = &self.opcodes[self.instruction_pointer] else { @@ -541,7 +537,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { Err(err) => return StepResult::Status(self.handle_opcode_resolution(Err(err))), }; if should_skip { - let resolution = BrilligSolver::::zero_out_brillig_outputs(witness, outputs); + let resolution = BrilligSolver::::zero_out_brillig_outputs(witness, outputs); return StepResult::Status(self.handle_opcode_resolution(resolution)); } @@ -559,7 +555,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { } } - pub fn finish_brillig_with_solver(&mut self, solver: BrilligSolver<'a, B>) -> ACVMStatus { + pub fn finish_brillig_with_solver(&mut self, solver: BrilligSolver<'a, F, B>) -> ACVMStatus { if !matches!(self.opcodes[self.instruction_pointer], Opcode::BrilligCall { .. }) { unreachable!("Not executing a Brillig/BrilligCall opcode"); } @@ -567,7 +563,9 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { self.solve_opcode() } - pub fn solve_call_opcode(&mut self) -> Result, OpcodeResolutionError> { + pub fn solve_call_opcode( + &mut self, + ) -> Result>, OpcodeResolutionError> { let Opcode::Call { id, inputs, outputs, predicate } = &self.opcodes[self.instruction_pointer] else { @@ -584,7 +582,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { if is_predicate_false(&self.witness_map, predicate)? { // Zero out the outputs if we have a false predicate for output in outputs { - insert_value(output, FieldElement::zero(), &mut self.witness_map)?; + insert_value(output, F::zero(), &mut self.witness_map)?; } return Ok(None); } @@ -621,10 +619,10 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> { // Returns the concrete value for a particular witness // If the witness has no assignment, then // an error is returned -pub fn witness_to_value( - initial_witness: &WitnessMap, +pub fn witness_to_value( + initial_witness: &WitnessMap, witness: Witness, -) -> Result<&FieldElement, OpcodeResolutionError> { +) -> Result<&F, OpcodeResolutionError> { match initial_witness.get(&witness) { Some(value) => Ok(value), None => Err(OpcodeNotSolvable::MissingAssignment(witness.0).into()), @@ -633,10 +631,10 @@ pub fn witness_to_value( // TODO: There is an issue open to decide on whether we need to get values from Expressions // TODO versus just getting values from Witness -pub fn get_value( - expr: &Expression, - initial_witness: &WitnessMap, -) -> Result { +pub fn get_value( + expr: &Expression, + initial_witness: &WitnessMap, +) -> Result> { let expr = ExpressionSolver::evaluate(expr, initial_witness); match expr.to_const() { Some(value) => Ok(value), @@ -650,11 +648,11 @@ pub fn get_value( /// /// Returns an error if there was already a value in the map /// which does not match the value that one is about to insert -pub fn insert_value( +pub fn insert_value( witness: &Witness, - value_to_insert: FieldElement, - initial_witness: &mut WitnessMap, -) -> Result<(), OpcodeResolutionError> { + value_to_insert: F, + initial_witness: &mut WitnessMap, +) -> Result<(), OpcodeResolutionError> { let optional_old_value = initial_witness.insert(*witness, value_to_insert); let old_value = match optional_old_value { @@ -675,7 +673,7 @@ pub fn insert_value( // Returns one witness belonging to an expression, in no relevant order // Returns None if the expression is const // The function is used during partial witness generation to report unsolved witness -fn any_witness_from_expression(expr: &Expression) -> Option { +fn any_witness_from_expression(expr: &Expression) -> Option { if expr.linear_combinations.is_empty() { if expr.mul_terms.is_empty() { None @@ -690,10 +688,10 @@ fn any_witness_from_expression(expr: &Expression) -> Option { /// Returns `true` if the predicate is zero /// A predicate is used to indicate whether we should skip a certain operation. /// If we have a zero predicate it means the operation should be skipped. -pub(crate) fn is_predicate_false( - witness: &WitnessMap, - predicate: &Option, -) -> Result { +pub(crate) fn is_predicate_false( + witness: &WitnessMap, + predicate: &Option>, +) -> Result> { match predicate { Some(pred) => get_value(pred, witness).map(|pred_value| pred_value.is_zero()), // If the predicate is `None`, then we treat it as an unconditional `true` @@ -702,9 +700,9 @@ pub(crate) fn is_predicate_false( } #[derive(Debug, Clone, PartialEq)] -pub struct AcirCallWaitInfo { +pub struct AcirCallWaitInfo { /// Index in the list of ACIR function's that should be called pub id: u32, /// Initial witness for the given circuit to be called - pub initial_witness: WitnessMap, + pub initial_witness: WitnessMap, } diff --git a/acvm-repo/acvm/tests/solver.rs b/acvm-repo/acvm/tests/solver.rs index 495389d7b3..e55dbb73ae 100644 --- a/acvm-repo/acvm/tests/solver.rs +++ b/acvm-repo/acvm/tests/solver.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use acir::{ + acir_field::GenericFieldElement, brillig::{BinaryFieldOp, HeapArray, MemoryAddress, Opcode as BrilligOpcode, ValueOrArray}, circuit::{ brillig::{BrilligBytecode, BrilligInputs, BrilligOutputs}, @@ -8,7 +9,7 @@ use acir::{ Opcode, OpcodeLocation, }, native_types::{Expression, Witness, WitnessMap}, - FieldElement, + AcirField, FieldElement, }; use acvm::pwg::{ACVMStatus, ErrorLocation, ForeignCallWaitInfo, OpcodeResolutionError, ACVM}; @@ -17,6 +18,38 @@ use brillig_vm::brillig::HeapValueType; // Reenable these test cases once we move the brillig implementation of inversion down into the acvm stdlib. +#[test] +fn bls12_381_circuit() { + type Bls12FieldElement = GenericFieldElement; + + let addition = Opcode::AssertZero(Expression { + mul_terms: Vec::new(), + linear_combinations: vec![ + (Bls12FieldElement::one(), Witness(1)), + (Bls12FieldElement::one(), Witness(2)), + (-Bls12FieldElement::one(), Witness(3)), + ], + q_c: Bls12FieldElement::zero(), + }); + let opcodes = [addition]; + + let witness_assignments = BTreeMap::from([ + (Witness(1), Bls12FieldElement::from(2u128)), + (Witness(2), Bls12FieldElement::from(3u128)), + ]) + .into(); + + let mut acvm = ACVM::new(&StubbedBlackBoxSolver, &opcodes, witness_assignments, &[], &[]); + // use the partial witness generation solver with our acir program + let solver_status = acvm.solve(); + assert_eq!(solver_status, ACVMStatus::Solved, "should be fully solved"); + + // ACVM should be able to be finalized in `Solved` state. + let witness_stack = acvm.finalize(); + + assert_eq!(witness_stack.get(&Witness(3)).unwrap(), &Bls12FieldElement::from(5u128)); +} + #[test] fn inversion_brillig_oracle_equivalence() { // Opcodes below describe the following: @@ -123,7 +156,7 @@ fn inversion_brillig_oracle_equivalence() { ); assert_eq!(acvm.instruction_pointer(), 0, "brillig should have been removed"); - let foreign_call_wait_info: &ForeignCallWaitInfo = + let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); @@ -268,7 +301,7 @@ fn double_inversion_brillig_oracle() { ); assert_eq!(acvm.instruction_pointer(), 0, "should stall on brillig"); - let foreign_call_wait_info: &ForeignCallWaitInfo = + let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); @@ -405,7 +438,7 @@ fn oracle_dependent_execution() { ); assert_eq!(acvm.instruction_pointer(), 1, "should stall on brillig"); - let foreign_call_wait_info: &ForeignCallWaitInfo = + let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); @@ -421,7 +454,7 @@ fn oracle_dependent_execution() { ); assert_eq!(acvm.instruction_pointer(), 1, "should stall on brillig"); - let foreign_call_wait_info: &ForeignCallWaitInfo = + let foreign_call_wait_info: &ForeignCallWaitInfo = acvm.get_pending_foreign_call().expect("should have a brillig foreign call request"); assert_eq!(foreign_call_wait_info.inputs.len(), 1, "Should be waiting for a single input"); diff --git a/acvm-repo/acvm_js/src/black_box_solvers.rs b/acvm-repo/acvm_js/src/black_box_solvers.rs index 188e5334ed..4f2676f8d2 100644 --- a/acvm-repo/acvm_js/src/black_box_solvers.rs +++ b/acvm-repo/acvm_js/src/black_box_solvers.rs @@ -2,7 +2,7 @@ use js_sys::JsString; use wasm_bindgen::prelude::*; use crate::js_witness_map::{field_element_to_js_string, js_value_to_field_element}; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; /// Performs a bitwise AND operation between `lhs` and `rhs` #[wasm_bindgen] diff --git a/acvm-repo/acvm_js/src/execute.rs b/acvm-repo/acvm_js/src/execute.rs index 9f2b07b31f..85cc3f455a 100644 --- a/acvm-repo/acvm_js/src/execute.rs +++ b/acvm-repo/acvm_js/src/execute.rs @@ -2,12 +2,12 @@ use std::{future::Future, pin::Pin}; use acvm::acir::circuit::brillig::BrilligBytecode; use acvm::acir::circuit::ResolvedAssertionPayload; -use acvm::BlackBoxFunctionSolver; use acvm::{ acir::circuit::{Circuit, Program}, acir::native_types::{WitnessMap, WitnessStack}, pwg::{ACVMStatus, ErrorLocation, OpcodeResolutionError, ACVM}, }; +use acvm::{BlackBoxFunctionSolver, FieldElement}; use bn254_blackbox_solver::Bn254BlackBoxSolver; use js_sys::Error; @@ -72,7 +72,7 @@ pub async fn execute_circuit_with_return_witness( ) -> Result { console_error_panic_hook::set_once(); - let program: Program = Program::deserialize_program(&program) + let program: Program = Program::deserialize_program(&program) .map_err(|_| JsExecutionError::new("Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None, None))?; let mut witness_stack = execute_program_with_native_program_and_return( @@ -148,8 +148,8 @@ async fn execute_program_with_native_type_return( program: Vec, initial_witness: JsWitnessMap, foreign_call_executor: &ForeignCallHandler, -) -> Result { - let program: Program = Program::deserialize_program(&program) +) -> Result, Error> { + let program: Program = Program::deserialize_program(&program) .map_err(|_| JsExecutionError::new( "Failed to deserialize circuit. This is likely due to differing serialization formats between ACVM_JS and your compiler".to_string(), None, @@ -160,10 +160,10 @@ async fn execute_program_with_native_type_return( } async fn execute_program_with_native_program_and_return( - program: &Program, + program: &Program, initial_witness: JsWitnessMap, foreign_call_executor: &ForeignCallHandler, -) -> Result { +) -> Result, Error> { let blackbox_solver = Bn254BlackBoxSolver; let executor = ProgramExecutor::new( &program.functions, @@ -176,20 +176,20 @@ async fn execute_program_with_native_program_and_return( Ok(witness_stack) } -struct ProgramExecutor<'a, B: BlackBoxFunctionSolver> { - functions: &'a [Circuit], +struct ProgramExecutor<'a, B: BlackBoxFunctionSolver> { + functions: &'a [Circuit], - unconstrained_functions: &'a [BrilligBytecode], + unconstrained_functions: &'a [BrilligBytecode], blackbox_solver: &'a B, foreign_call_handler: &'a ForeignCallHandler, } -impl<'a, B: BlackBoxFunctionSolver> ProgramExecutor<'a, B> { +impl<'a, B: BlackBoxFunctionSolver> ProgramExecutor<'a, B> { fn new( - functions: &'a [Circuit], - unconstrained_functions: &'a [BrilligBytecode], + functions: &'a [Circuit], + unconstrained_functions: &'a [BrilligBytecode], blackbox_solver: &'a B, foreign_call_handler: &'a ForeignCallHandler, ) -> Self { @@ -201,7 +201,10 @@ impl<'a, B: BlackBoxFunctionSolver> ProgramExecutor<'a, B> { } } - async fn execute(&self, initial_witness: WitnessMap) -> Result { + async fn execute( + &self, + initial_witness: WitnessMap, + ) -> Result, Error> { let main = &self.functions[0]; let mut witness_stack = WitnessStack::default(); @@ -212,10 +215,10 @@ impl<'a, B: BlackBoxFunctionSolver> ProgramExecutor<'a, B> { fn execute_circuit( &'a self, - circuit: &'a Circuit, - initial_witness: WitnessMap, - witness_stack: &'a mut WitnessStack, - ) -> Pin> + 'a>> { + circuit: &'a Circuit, + initial_witness: WitnessMap, + witness_stack: &'a mut WitnessStack, + ) -> Pin, Error>> + 'a>> { Box::pin(async { let mut acvm = ACVM::new( self.blackbox_solver, diff --git a/acvm-repo/acvm_js/src/foreign_call/inputs.rs b/acvm-repo/acvm_js/src/foreign_call/inputs.rs index ebd29fb7d5..dd12bc639e 100644 --- a/acvm-repo/acvm_js/src/foreign_call/inputs.rs +++ b/acvm-repo/acvm_js/src/foreign_call/inputs.rs @@ -1,9 +1,9 @@ -use acvm::brillig_vm::brillig::ForeignCallParam; +use acvm::{brillig_vm::brillig::ForeignCallParam, FieldElement}; use crate::js_witness_map::field_element_to_js_string; pub(super) fn encode_foreign_call_inputs( - foreign_call_inputs: &[ForeignCallParam], + foreign_call_inputs: &[ForeignCallParam], ) -> js_sys::Array { let inputs = js_sys::Array::default(); for input in foreign_call_inputs { diff --git a/acvm-repo/acvm_js/src/foreign_call/mod.rs b/acvm-repo/acvm_js/src/foreign_call/mod.rs index 9ccaf733f8..4884c1173d 100644 --- a/acvm-repo/acvm_js/src/foreign_call/mod.rs +++ b/acvm-repo/acvm_js/src/foreign_call/mod.rs @@ -1,4 +1,4 @@ -use acvm::{brillig_vm::brillig::ForeignCallResult, pwg::ForeignCallWaitInfo}; +use acvm::{brillig_vm::brillig::ForeignCallResult, pwg::ForeignCallWaitInfo, FieldElement}; use js_sys::{Error, JsString}; use wasm_bindgen::{prelude::wasm_bindgen, JsValue}; @@ -29,8 +29,8 @@ extern "C" { pub(super) async fn resolve_brillig( foreign_call_callback: &ForeignCallHandler, - foreign_call_wait_info: &ForeignCallWaitInfo, -) -> Result { + foreign_call_wait_info: &ForeignCallWaitInfo, +) -> Result, Error> { // Prepare to call let name = JsString::from(foreign_call_wait_info.function.clone()); let inputs = inputs::encode_foreign_call_inputs(&foreign_call_wait_info.inputs); diff --git a/acvm-repo/acvm_js/src/foreign_call/outputs.rs b/acvm-repo/acvm_js/src/foreign_call/outputs.rs index 78fa520aa1..2b3f44fe98 100644 --- a/acvm-repo/acvm_js/src/foreign_call/outputs.rs +++ b/acvm-repo/acvm_js/src/foreign_call/outputs.rs @@ -1,9 +1,12 @@ -use acvm::brillig_vm::brillig::{ForeignCallParam, ForeignCallResult}; +use acvm::{ + brillig_vm::brillig::{ForeignCallParam, ForeignCallResult}, + FieldElement, +}; use wasm_bindgen::JsValue; use crate::js_witness_map::js_value_to_field_element; -fn decode_foreign_call_output(output: JsValue) -> Result { +fn decode_foreign_call_output(output: JsValue) -> Result, String> { if output.is_string() { let value = js_value_to_field_element(output)?; Ok(ForeignCallParam::Single(value)) @@ -22,8 +25,9 @@ fn decode_foreign_call_output(output: JsValue) -> Result Result { - let mut values: Vec = Vec::with_capacity(js_array.length() as usize); +) -> Result, String> { + let mut values: Vec> = + Vec::with_capacity(js_array.length() as usize); for elem in js_array.iter() { values.push(decode_foreign_call_output(elem)?); } diff --git a/acvm-repo/acvm_js/src/js_execution_error.rs b/acvm-repo/acvm_js/src/js_execution_error.rs index b34ea5ddb6..e51a912a63 100644 --- a/acvm-repo/acvm_js/src/js_execution_error.rs +++ b/acvm-repo/acvm_js/src/js_execution_error.rs @@ -1,4 +1,7 @@ -use acvm::acir::circuit::{OpcodeLocation, RawAssertionPayload}; +use acvm::{ + acir::circuit::{OpcodeLocation, RawAssertionPayload}, + FieldElement, +}; use gloo_utils::format::JsValueSerdeExt; use js_sys::{Array, Error, JsString, Reflect}; use wasm_bindgen::prelude::{wasm_bindgen, JsValue}; @@ -34,7 +37,7 @@ impl JsExecutionError { pub fn new( message: String, call_stack: Option>, - assertion_payload: Option, + assertion_payload: Option>, ) -> Self { let mut error = JsExecutionError::constructor(JsString::from(message)); let js_call_stack = match call_stack { diff --git a/acvm-repo/acvm_js/src/js_witness_map.rs b/acvm-repo/acvm_js/src/js_witness_map.rs index c4482c4a23..8316059e21 100644 --- a/acvm-repo/acvm_js/src/js_witness_map.rs +++ b/acvm-repo/acvm_js/src/js_witness_map.rs @@ -1,5 +1,6 @@ use acvm::{ acir::native_types::{Witness, WitnessMap}, + acir::AcirField, FieldElement, }; use js_sys::{JsString, Map, Object}; @@ -51,8 +52,8 @@ impl Default for JsSolvedAndReturnWitness { } } -impl From for JsWitnessMap { - fn from(witness_map: WitnessMap) -> Self { +impl From> for JsWitnessMap { + fn from(witness_map: WitnessMap) -> Self { let js_map = JsWitnessMap::new(); for (key, value) in witness_map { js_map.set( @@ -64,7 +65,7 @@ impl From for JsWitnessMap { } } -impl From for WitnessMap { +impl From for WitnessMap { fn from(js_map: JsWitnessMap) -> Self { let mut witness_map = WitnessMap::new(); js_map.for_each(&mut |value, key| { @@ -76,8 +77,8 @@ impl From for WitnessMap { } } -impl From<(WitnessMap, WitnessMap)> for JsSolvedAndReturnWitness { - fn from(witness_maps: (WitnessMap, WitnessMap)) -> Self { +impl From<(WitnessMap, WitnessMap)> for JsSolvedAndReturnWitness { + fn from(witness_maps: (WitnessMap, WitnessMap)) -> Self { let js_solved_witness = JsWitnessMap::from(witness_maps.0); let js_return_witness = JsWitnessMap::from(witness_maps.1); @@ -113,7 +114,7 @@ mod test { use acvm::{ acir::native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, FieldElement, }; use wasm_bindgen::JsValue; diff --git a/acvm-repo/acvm_js/src/js_witness_stack.rs b/acvm-repo/acvm_js/src/js_witness_stack.rs index 59f2dbc051..d59ee50808 100644 --- a/acvm-repo/acvm_js/src/js_witness_stack.rs +++ b/acvm-repo/acvm_js/src/js_witness_stack.rs @@ -1,4 +1,4 @@ -use acvm::acir::native_types::WitnessStack; +use acvm::{acir::native_types::WitnessStack, FieldElement}; use js_sys::{Array, Map, Object}; use wasm_bindgen::prelude::{wasm_bindgen, JsValue}; @@ -38,8 +38,8 @@ impl Default for JsWitnessStack { } } -impl From for JsWitnessStack { - fn from(mut witness_stack: WitnessStack) -> Self { +impl From> for JsWitnessStack { + fn from(mut witness_stack: WitnessStack) -> Self { let js_witness_stack = JsWitnessStack::new(); while let Some(stack_item) = witness_stack.pop() { let js_map = JsWitnessMap::from(stack_item.witness); @@ -57,7 +57,7 @@ impl From for JsWitnessStack { } } -impl From for WitnessStack { +impl From for WitnessStack { fn from(js_witness_stack: JsWitnessStack) -> Self { let mut witness_stack = WitnessStack::default(); js_witness_stack.for_each(&mut |stack_item, _, _| { diff --git a/acvm-repo/acvm_js/src/public_witness.rs b/acvm-repo/acvm_js/src/public_witness.rs index 4ba054732d..245d5b4dd0 100644 --- a/acvm-repo/acvm_js/src/public_witness.rs +++ b/acvm-repo/acvm_js/src/public_witness.rs @@ -1,6 +1,9 @@ -use acvm::acir::{ - circuit::Program, - native_types::{Witness, WitnessMap}, +use acvm::{ + acir::{ + circuit::Program, + native_types::{Witness, WitnessMap}, + }, + FieldElement, }; use js_sys::JsString; use wasm_bindgen::prelude::wasm_bindgen; @@ -8,9 +11,9 @@ use wasm_bindgen::prelude::wasm_bindgen; use crate::JsWitnessMap; pub(crate) fn extract_indices( - witness_map: &WitnessMap, + witness_map: &WitnessMap, indices: Vec, -) -> Result { +) -> Result, String> { let mut extracted_witness_map = WitnessMap::new(); for witness in indices { let witness_value = witness_map.get(&witness).ok_or(format!( @@ -36,7 +39,7 @@ pub fn get_return_witness( witness_map: JsWitnessMap, ) -> Result { console_error_panic_hook::set_once(); - let program: Program = + let program: Program = Program::deserialize_program(&program).expect("Failed to deserialize circuit"); let circuit = match program.functions.len() { 0 => return Ok(JsWitnessMap::from(WitnessMap::new())), @@ -63,7 +66,7 @@ pub fn get_public_parameters_witness( solved_witness: JsWitnessMap, ) -> Result { console_error_panic_hook::set_once(); - let program: Program = + let program: Program = Program::deserialize_program(&program).expect("Failed to deserialize circuit"); let circuit = match program.functions.len() { 0 => return Ok(JsWitnessMap::from(WitnessMap::new())), @@ -90,7 +93,7 @@ pub fn get_public_witness( solved_witness: JsWitnessMap, ) -> Result { console_error_panic_hook::set_once(); - let program: Program = + let program: Program = Program::deserialize_program(&program).expect("Failed to deserialize circuit"); let circuit = match program.functions.len() { 0 => return Ok(JsWitnessMap::from(WitnessMap::new())), diff --git a/acvm-repo/blackbox_solver/Cargo.toml b/acvm-repo/blackbox_solver/Cargo.toml index 6c1cd19025..00c87bbca7 100644 --- a/acvm-repo/blackbox_solver/Cargo.toml +++ b/acvm-repo/blackbox_solver/Cargo.toml @@ -40,6 +40,5 @@ p256 = { version = "0.11.0", features = [ libaes = "0.7.0" [features] -default = ["bn254"] bn254 = ["acir/bn254"] bls12_381 = ["acir/bls12_381"] diff --git a/acvm-repo/blackbox_solver/src/curve_specific_solver.rs b/acvm-repo/blackbox_solver/src/curve_specific_solver.rs index 73f64d3d9d..0ee3a25284 100644 --- a/acvm-repo/blackbox_solver/src/curve_specific_solver.rs +++ b/acvm-repo/blackbox_solver/src/curve_specific_solver.rs @@ -1,4 +1,4 @@ -use acir::{BlackBoxFunc, FieldElement}; +use acir::BlackBoxFunc; use crate::BlackBoxResolutionError; @@ -6,44 +6,44 @@ use crate::BlackBoxResolutionError; /// doesn't have a canonical Rust implementation. /// /// Returns an [`BlackBoxResolutionError`] if the backend does not support the given [`acir::BlackBoxFunc`]. -pub trait BlackBoxFunctionSolver { +pub trait BlackBoxFunctionSolver { fn schnorr_verify( &self, - public_key_x: &FieldElement, - public_key_y: &FieldElement, + public_key_x: &F, + public_key_y: &F, signature: &[u8; 64], message: &[u8], ) -> Result; fn pedersen_commitment( &self, - inputs: &[FieldElement], + inputs: &[F], domain_separator: u32, - ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError>; + ) -> Result<(F, F), BlackBoxResolutionError>; fn pedersen_hash( &self, - inputs: &[FieldElement], + inputs: &[F], domain_separator: u32, - ) -> Result; + ) -> Result; fn multi_scalar_mul( &self, - points: &[FieldElement], - scalars_lo: &[FieldElement], - scalars_hi: &[FieldElement], - ) -> Result<(FieldElement, FieldElement, FieldElement), BlackBoxResolutionError>; + points: &[F], + scalars_lo: &[F], + scalars_hi: &[F], + ) -> Result<(F, F, F), BlackBoxResolutionError>; fn ec_add( &self, - input1_x: &FieldElement, - input1_y: &FieldElement, - input1_infinite: &FieldElement, - input2_x: &FieldElement, - input2_y: &FieldElement, - input2_infinite: &FieldElement, - ) -> Result<(FieldElement, FieldElement, FieldElement), BlackBoxResolutionError>; + input1_x: &F, + input1_y: &F, + input1_infinite: &F, + input2_x: &F, + input2_y: &F, + input2_infinite: &F, + ) -> Result<(F, F, F), BlackBoxResolutionError>; fn poseidon2_permutation( &self, - _inputs: &[FieldElement], + _inputs: &[F], _len: u32, - ) -> Result, BlackBoxResolutionError>; + ) -> Result, BlackBoxResolutionError>; } pub struct StubbedBlackBoxSolver; @@ -57,11 +57,11 @@ impl StubbedBlackBoxSolver { } } -impl BlackBoxFunctionSolver for StubbedBlackBoxSolver { +impl BlackBoxFunctionSolver for StubbedBlackBoxSolver { fn schnorr_verify( &self, - _public_key_x: &FieldElement, - _public_key_y: &FieldElement, + _public_key_x: &F, + _public_key_y: &F, _signature: &[u8; 64], _message: &[u8], ) -> Result { @@ -69,42 +69,42 @@ impl BlackBoxFunctionSolver for StubbedBlackBoxSolver { } fn pedersen_commitment( &self, - _inputs: &[FieldElement], + _inputs: &[F], _domain_separator: u32, - ) -> Result<(FieldElement, FieldElement), BlackBoxResolutionError> { + ) -> Result<(F, F), BlackBoxResolutionError> { Err(Self::fail(BlackBoxFunc::PedersenCommitment)) } fn pedersen_hash( &self, - _inputs: &[FieldElement], + _inputs: &[F], _domain_separator: u32, - ) -> Result { + ) -> Result { Err(Self::fail(BlackBoxFunc::PedersenHash)) } fn multi_scalar_mul( &self, - _points: &[FieldElement], - _scalars_lo: &[FieldElement], - _scalars_hi: &[FieldElement], - ) -> Result<(FieldElement, FieldElement, FieldElement), BlackBoxResolutionError> { + _points: &[F], + _scalars_lo: &[F], + _scalars_hi: &[F], + ) -> Result<(F, F, F), BlackBoxResolutionError> { Err(Self::fail(BlackBoxFunc::MultiScalarMul)) } fn ec_add( &self, - _input1_x: &FieldElement, - _input1_y: &FieldElement, - _input1_infinite: &FieldElement, - _input2_x: &FieldElement, - _input2_y: &FieldElement, - _input2_infinite: &FieldElement, - ) -> Result<(FieldElement, FieldElement, FieldElement), BlackBoxResolutionError> { + _input1_x: &F, + _input1_y: &F, + _input1_infinite: &F, + _input2_x: &F, + _input2_y: &F, + _input2_infinite: &F, + ) -> Result<(F, F, F), BlackBoxResolutionError> { Err(Self::fail(BlackBoxFunc::EmbeddedCurveAdd)) } fn poseidon2_permutation( &self, - _inputs: &[FieldElement], + _inputs: &[F], _len: u32, - ) -> Result, BlackBoxResolutionError> { + ) -> Result, BlackBoxResolutionError> { Err(Self::fail(BlackBoxFunc::Poseidon2Permutation)) } } diff --git a/acvm-repo/bn254_blackbox_solver/benches/criterion.rs b/acvm-repo/bn254_blackbox_solver/benches/criterion.rs index b86414423c..cbcb75a329 100644 --- a/acvm-repo/bn254_blackbox_solver/benches/criterion.rs +++ b/acvm-repo/bn254_blackbox_solver/benches/criterion.rs @@ -1,7 +1,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; use std::{hint::black_box, time::Duration}; -use acir::FieldElement; +use acir::{AcirField, FieldElement}; use acvm_blackbox_solver::BlackBoxFunctionSolver; use bn254_blackbox_solver::{poseidon2_permutation, Bn254BlackBoxSolver}; diff --git a/acvm-repo/bn254_blackbox_solver/src/embedded_curve_ops.rs b/acvm-repo/bn254_blackbox_solver/src/embedded_curve_ops.rs index 901eb9d5a0..148accd8b2 100644 --- a/acvm-repo/bn254_blackbox_solver/src/embedded_curve_ops.rs +++ b/acvm-repo/bn254_blackbox_solver/src/embedded_curve_ops.rs @@ -3,7 +3,8 @@ use ark_ec::AffineRepr; use ark_ff::MontConfig; use num_bigint::BigUint; -use acir::{BlackBoxFunc, FieldElement}; +use acir::BlackBoxFunc; +use acir::{AcirField, FieldElement}; use crate::BlackBoxResolutionError; @@ -117,10 +118,10 @@ fn create_point( #[cfg(test)] mod tests { - use ark_ff::BigInteger; - use super::*; + use ark_ff::BigInteger; + fn get_generator() -> [FieldElement; 3] { let generator = grumpkin::SWAffine::generator(); let generator_x = FieldElement::from_repr(*generator.x().unwrap()); diff --git a/acvm-repo/bn254_blackbox_solver/src/lib.rs b/acvm-repo/bn254_blackbox_solver/src/lib.rs index ae5a1c3db6..e3cea1153b 100644 --- a/acvm-repo/bn254_blackbox_solver/src/lib.rs +++ b/acvm-repo/bn254_blackbox_solver/src/lib.rs @@ -18,7 +18,7 @@ pub use poseidon2::poseidon2_permutation; #[derive(Default)] pub struct Bn254BlackBoxSolver; -impl BlackBoxFunctionSolver for Bn254BlackBoxSolver { +impl BlackBoxFunctionSolver for Bn254BlackBoxSolver { fn schnorr_verify( &self, public_key_x: &FieldElement, diff --git a/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs b/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs index 6769150508..5ff2c269d8 100644 --- a/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs +++ b/acvm-repo/bn254_blackbox_solver/src/pedersen/commitment.rs @@ -27,7 +27,7 @@ pub(crate) fn commit_native_with_index( #[cfg(test)] mod test { - use acir::FieldElement; + use acir::{AcirField, FieldElement}; use ark_ec::short_weierstrass::Affine; use ark_std::{One, Zero}; use grumpkin::Fq; diff --git a/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs b/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs index 28bf354edc..5c63720749 100644 --- a/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs +++ b/acvm-repo/bn254_blackbox_solver/src/pedersen/hash.rs @@ -31,7 +31,7 @@ pub(crate) mod test { use super::*; - use acir::FieldElement; + use acir::{AcirField, FieldElement}; use ark_std::One; use grumpkin::Fq; diff --git a/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs b/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs index 65058e1509..95c620aab0 100644 --- a/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs +++ b/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs @@ -1,4 +1,4 @@ -use acir::FieldElement; +use acir::{AcirField, FieldElement}; use acvm_blackbox_solver::BlackBoxResolutionError; use lazy_static::lazy_static; @@ -543,7 +543,7 @@ impl<'a> Poseidon2<'a> { #[cfg(test)] mod test { - use acir::FieldElement; + use acir::{AcirField, FieldElement}; use super::{field_from_hex, poseidon2_permutation}; diff --git a/acvm-repo/bn254_blackbox_solver/src/schnorr/mod.rs b/acvm-repo/bn254_blackbox_solver/src/schnorr/mod.rs index cb21372697..62e515a079 100644 --- a/acvm-repo/bn254_blackbox_solver/src/schnorr/mod.rs +++ b/acvm-repo/bn254_blackbox_solver/src/schnorr/mod.rs @@ -65,7 +65,7 @@ fn schnorr_generate_challenge( #[cfg(test)] mod schnorr_tests { - use acir::FieldElement; + use acir::{AcirField, FieldElement}; use super::verify_signature; diff --git a/acvm-repo/brillig/Cargo.toml b/acvm-repo/brillig/Cargo.toml index f60bde6f07..245767dcec 100644 --- a/acvm-repo/brillig/Cargo.toml +++ b/acvm-repo/brillig/Cargo.toml @@ -17,6 +17,5 @@ acir_field.workspace = true serde.workspace = true [features] -default = ["bn254"] bn254 = ["acir_field/bn254"] bls12_381 = ["acir_field/bls12_381"] diff --git a/acvm-repo/brillig/src/foreign_call.rs b/acvm-repo/brillig/src/foreign_call.rs index e547b99f0e..a439d5c320 100644 --- a/acvm-repo/brillig/src/foreign_call.rs +++ b/acvm-repo/brillig/src/foreign_call.rs @@ -1,34 +1,34 @@ -use acir_field::FieldElement; +use acir_field::AcirField; use serde::{Deserialize, Serialize}; /// Single output of a [foreign call][crate::Opcode::ForeignCall]. #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone)] -pub enum ForeignCallParam { - Single(FieldElement), - Array(Vec), +pub enum ForeignCallParam { + Single(F), + Array(Vec), } -impl From for ForeignCallParam { - fn from(value: FieldElement) -> Self { +impl From for ForeignCallParam { + fn from(value: F) -> Self { ForeignCallParam::Single(value) } } -impl From> for ForeignCallParam { - fn from(values: Vec) -> Self { +impl From> for ForeignCallParam { + fn from(values: Vec) -> Self { ForeignCallParam::Array(values) } } -impl ForeignCallParam { - pub fn fields(&self) -> Vec { +impl ForeignCallParam { + pub fn fields(&self) -> Vec { match self { ForeignCallParam::Single(value) => vec![*value], - ForeignCallParam::Array(values) => values.clone(), + ForeignCallParam::Array(values) => values.to_vec(), } } - pub fn unwrap_field(&self) -> FieldElement { + pub fn unwrap_field(&self) -> F { match self { ForeignCallParam::Single(value) => *value, ForeignCallParam::Array(_) => panic!("Expected single value, found array"), @@ -38,25 +38,25 @@ impl ForeignCallParam { /// Represents the full output of a [foreign call][crate::Opcode::ForeignCall]. #[derive(Debug, PartialEq, Eq, Serialize, Deserialize, Clone, Default)] -pub struct ForeignCallResult { +pub struct ForeignCallResult { /// Resolved output values of the foreign call. - pub values: Vec, + pub values: Vec>, } -impl From for ForeignCallResult { - fn from(value: FieldElement) -> Self { +impl From for ForeignCallResult { + fn from(value: F) -> Self { ForeignCallResult { values: vec![value.into()] } } } -impl From> for ForeignCallResult { - fn from(values: Vec) -> Self { +impl From> for ForeignCallResult { + fn from(values: Vec) -> Self { ForeignCallResult { values: vec![values.into()] } } } -impl From> for ForeignCallResult { - fn from(values: Vec) -> Self { +impl From>> for ForeignCallResult { + fn from(values: Vec>) -> Self { ForeignCallResult { values } } } diff --git a/acvm-repo/brillig/src/opcodes.rs b/acvm-repo/brillig/src/opcodes.rs index a060aa83d4..78c6ba8097 100644 --- a/acvm-repo/brillig/src/opcodes.rs +++ b/acvm-repo/brillig/src/opcodes.rs @@ -1,5 +1,5 @@ use crate::black_box::BlackBoxOp; -use acir_field::FieldElement; +use acir_field::{AcirField, FieldElement}; use serde::{Deserialize, Serialize}; pub type Label = usize; @@ -89,7 +89,7 @@ pub enum ValueOrArray { } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub enum BrilligOpcode { +pub enum BrilligOpcode { /// Takes the fields in addresses `lhs` and `rhs` /// Performs the specified binary operation /// and stores the value in the `result` address. @@ -142,7 +142,7 @@ pub enum BrilligOpcode { Const { destination: MemoryAddress, bit_size: u32, - value: FieldElement, + value: F, }, Return, /// Used to get data from an outside source. diff --git a/acvm-repo/brillig_vm/Cargo.toml b/acvm-repo/brillig_vm/Cargo.toml index 7dd1191244..4735514c9a 100644 --- a/acvm-repo/brillig_vm/Cargo.toml +++ b/acvm-repo/brillig_vm/Cargo.toml @@ -20,6 +20,5 @@ num-traits.workspace = true thiserror.workspace = true [features] -default = ["bn254"] bn254 = ["acir/bn254"] bls12_381 = ["acir/bls12_381"] diff --git a/acvm-repo/brillig_vm/src/arithmetic.rs b/acvm-repo/brillig_vm/src/arithmetic.rs index c17c019d11..c88e06e2b9 100644 --- a/acvm-repo/brillig_vm/src/arithmetic.rs +++ b/acvm-repo/brillig_vm/src/arithmetic.rs @@ -1,5 +1,5 @@ use acir::brillig::{BinaryFieldOp, BinaryIntOp}; -use acir::FieldElement; +use acir::AcirField; use num_bigint::BigUint; use num_traits::ToPrimitive; use num_traits::{One, Zero}; @@ -19,36 +19,36 @@ pub(crate) enum BrilligArithmeticError { } /// Evaluate a binary operation on two FieldElement memory values. -pub(crate) fn evaluate_binary_field_op( +pub(crate) fn evaluate_binary_field_op( op: &BinaryFieldOp, - lhs: MemoryValue, - rhs: MemoryValue, -) -> Result { + lhs: MemoryValue, + rhs: MemoryValue, +) -> Result, BrilligArithmeticError> { let MemoryValue::Field(a) = lhs else { return Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: lhs.bit_size(), - op_bit_size: FieldElement::max_num_bits(), + op_bit_size: F::max_num_bits(), }); }; let MemoryValue::Field(b) = rhs else { return Err(BrilligArithmeticError::MismatchedLhsBitSize { lhs_bit_size: rhs.bit_size(), - op_bit_size: FieldElement::max_num_bits(), + op_bit_size: F::max_num_bits(), }); }; Ok(match op { // Perform addition, subtraction, multiplication, and division based on the BinaryOp variant. - BinaryFieldOp::Add => (a + b).into(), - BinaryFieldOp::Sub => (a - b).into(), - BinaryFieldOp::Mul => (a * b).into(), - BinaryFieldOp::Div => (a / b).into(), + BinaryFieldOp::Add => MemoryValue::new_field(a + b), + BinaryFieldOp::Sub => MemoryValue::new_field(a - b), + BinaryFieldOp::Mul => MemoryValue::new_field(a * b), + BinaryFieldOp::Div => MemoryValue::new_field(a / b), BinaryFieldOp::IntegerDiv => { let a_big = BigUint::from_bytes_be(&a.to_be_bytes()); let b_big = BigUint::from_bytes_be(&b.to_be_bytes()); let result = a_big / b_big; - FieldElement::from_be_bytes_reduce(&result.to_bytes_be()).into() + MemoryValue::new_field(F::from_be_bytes_reduce(&result.to_bytes_be())) } BinaryFieldOp::Equals => (a == b).into(), BinaryFieldOp::LessThan => (a < b).into(), @@ -57,12 +57,12 @@ pub(crate) fn evaluate_binary_field_op( } /// Evaluate a binary operation on two unsigned big integers with a given bit size. -pub(crate) fn evaluate_binary_int_op( +pub(crate) fn evaluate_binary_int_op( op: &BinaryIntOp, - lhs: MemoryValue, - rhs: MemoryValue, + lhs: MemoryValue, + rhs: MemoryValue, bit_size: u32, -) -> Result { +) -> Result, BrilligArithmeticError> { let lhs = lhs.expect_integer_with_bit_size(bit_size).map_err(|err| match err { MemoryTypeError::MismatchedBitSize { value_bit_size, expected_bit_size } => { BrilligArithmeticError::MismatchedLhsBitSize { @@ -82,7 +82,7 @@ pub(crate) fn evaluate_binary_int_op( } })?; - if bit_size == FieldElement::max_num_bits() { + if bit_size == F::max_num_bits() { return Err(BrilligArithmeticError::IntegerOperationOnField { op: *op }); } @@ -155,6 +155,7 @@ pub(crate) fn evaluate_binary_int_op( #[cfg(test)] mod tests { use super::*; + use acir::{AcirField, FieldElement}; struct TestParams { a: u128, @@ -163,7 +164,7 @@ mod tests { } fn evaluate_u128(op: &BinaryIntOp, a: u128, b: u128, bit_size: u32) -> u128 { - let result_value = evaluate_binary_int_op( + let result_value: MemoryValue = evaluate_binary_int_op( op, MemoryValue::new_integer(a.into(), bit_size), MemoryValue::new_integer(b.into(), bit_size), diff --git a/acvm-repo/brillig_vm/src/black_box.rs b/acvm-repo/brillig_vm/src/black_box.rs index ebaa697628..2053f4e7c8 100644 --- a/acvm-repo/brillig_vm/src/black_box.rs +++ b/acvm-repo/brillig_vm/src/black_box.rs @@ -1,5 +1,5 @@ use acir::brillig::{BlackBoxOp, HeapArray, HeapVector}; -use acir::{BlackBoxFunc, FieldElement}; +use acir::{AcirField, BlackBoxFunc}; use acvm_blackbox_solver::BigIntSolver; use acvm_blackbox_solver::{ aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256, @@ -10,17 +10,23 @@ use num_bigint::BigUint; use crate::memory::MemoryValue; use crate::Memory; -fn read_heap_vector<'a>(memory: &'a Memory, vector: &HeapVector) -> &'a [MemoryValue] { +fn read_heap_vector<'a, F: AcirField>( + memory: &'a Memory, + vector: &HeapVector, +) -> &'a [MemoryValue] { let size = memory.read(vector.size); memory.read_slice(memory.read_ref(vector.pointer), size.to_usize()) } -fn read_heap_array<'a>(memory: &'a Memory, array: &HeapArray) -> &'a [MemoryValue] { +fn read_heap_array<'a, F: AcirField>( + memory: &'a Memory, + array: &HeapArray, +) -> &'a [MemoryValue] { memory.read_slice(memory.read_ref(array.pointer), array.size) } /// Extracts the last byte of every value -fn to_u8_vec(inputs: &[MemoryValue]) -> Vec { +fn to_u8_vec(inputs: &[MemoryValue]) -> Vec { let mut result = Vec::with_capacity(inputs.len()); for input in inputs { result.push(input.try_into().unwrap()); @@ -28,14 +34,14 @@ fn to_u8_vec(inputs: &[MemoryValue]) -> Vec { result } -fn to_value_vec(input: &[u8]) -> Vec { +fn to_value_vec(input: &[u8]) -> Vec> { input.iter().map(|&x| x.into()).collect() } -pub(crate) fn evaluate_black_box( +pub(crate) fn evaluate_black_box>( op: &BlackBoxOp, solver: &Solver, - memory: &mut Memory, + memory: &mut Memory, bigint_solver: &mut BigIntSolver, ) -> Result<(), BlackBoxResolutionError> { match op { @@ -91,7 +97,7 @@ pub(crate) fn evaluate_black_box( let new_state = keccakf1600(state)?; - let new_state: Vec = new_state.into_iter().map(|x| x.into()).collect(); + let new_state: Vec> = new_state.into_iter().map(|x| x.into()).collect(); memory.write_slice(memory.read_ref(output.pointer), &new_state); Ok(()) } @@ -146,8 +152,8 @@ pub(crate) fn evaluate_black_box( Ok(()) } BlackBoxOp::SchnorrVerify { public_key_x, public_key_y, message, signature, result } => { - let public_key_x = memory.read(*public_key_x).try_into().unwrap(); - let public_key_y = memory.read(*public_key_y).try_into().unwrap(); + let public_key_x = *memory.read(*public_key_x).extract_field().unwrap(); + let public_key_y = *memory.read(*public_key_y).extract_field().unwrap(); let message: Vec = to_u8_vec(read_heap_vector(memory, message)); let signature: [u8; 64] = to_u8_vec(read_heap_vector(memory, signature)).try_into().unwrap(); @@ -157,20 +163,22 @@ pub(crate) fn evaluate_black_box( Ok(()) } BlackBoxOp::MultiScalarMul { points, scalars, outputs: result } => { - let points: Vec = read_heap_vector(memory, points) + let points: Vec = read_heap_vector(memory, points) .iter() .enumerate() .map(|(i, x)| { if i % 3 == 2 { let is_infinite: bool = x.try_into().unwrap(); - FieldElement::from(is_infinite as u128) + F::from(is_infinite as u128) } else { - x.try_into().unwrap() + *x.extract_field().unwrap() } }) .collect(); - let scalars: Vec = - read_heap_vector(memory, scalars).iter().map(|x| x.try_into().unwrap()).collect(); + let scalars: Vec = read_heap_vector(memory, scalars) + .iter() + .map(|x| *x.extract_field().unwrap()) + .collect(); let mut scalars_lo = Vec::with_capacity(scalars.len() / 2); let mut scalars_hi = Vec::with_capacity(scalars.len() / 2); for (i, scalar) in scalars.iter().enumerate() { @@ -183,7 +191,11 @@ pub(crate) fn evaluate_black_box( let (x, y, is_infinite) = solver.multi_scalar_mul(&points, &scalars_lo, &scalars_hi)?; memory.write_slice( memory.read_ref(result.pointer), - &[x.into(), y.into(), is_infinite.into()], + &[ + MemoryValue::new_field(x), + MemoryValue::new_field(y), + MemoryValue::new_field(is_infinite), + ], ); Ok(()) } @@ -196,11 +208,11 @@ pub(crate) fn evaluate_black_box( input1_infinite, input2_infinite, } => { - let input1_x = memory.read(*input1_x).try_into().unwrap(); - let input1_y = memory.read(*input1_y).try_into().unwrap(); + let input1_x = *memory.read(*input1_x).extract_field().unwrap(); + let input1_y = *memory.read(*input1_y).extract_field().unwrap(); let input1_infinite: bool = memory.read(*input1_infinite).try_into().unwrap(); - let input2_x = memory.read(*input2_x).try_into().unwrap(); - let input2_y = memory.read(*input2_y).try_into().unwrap(); + let input2_x = *memory.read(*input2_x).extract_field().unwrap(); + let input2_y = *memory.read(*input2_y).extract_field().unwrap(); let input2_infinite: bool = memory.read(*input2_infinite).try_into().unwrap(); let (x, y, infinite) = solver.ec_add( &input1_x, @@ -212,13 +224,19 @@ pub(crate) fn evaluate_black_box( )?; memory.write_slice( memory.read_ref(result.pointer), - &[x.into(), y.into(), infinite.into()], + &[ + MemoryValue::new_field(x), + MemoryValue::new_field(y), + MemoryValue::new_field(infinite), + ], ); Ok(()) } BlackBoxOp::PedersenCommitment { inputs, domain_separator, output } => { - let inputs: Vec = - read_heap_vector(memory, inputs).iter().map(|x| x.try_into().unwrap()).collect(); + let inputs: Vec = read_heap_vector(memory, inputs) + .iter() + .map(|x| *x.extract_field().unwrap()) + .collect(); let domain_separator: u32 = memory.read(*domain_separator).try_into().map_err(|_| { BlackBoxResolutionError::Failed( @@ -227,12 +245,17 @@ pub(crate) fn evaluate_black_box( ) })?; let (x, y) = solver.pedersen_commitment(&inputs, domain_separator)?; - memory.write_slice(memory.read_ref(output.pointer), &[x.into(), y.into()]); + memory.write_slice( + memory.read_ref(output.pointer), + &[MemoryValue::new_field(x), MemoryValue::new_field(y)], + ); Ok(()) } BlackBoxOp::PedersenHash { inputs, domain_separator, output } => { - let inputs: Vec = - read_heap_vector(memory, inputs).iter().map(|x| x.try_into().unwrap()).collect(); + let inputs: Vec = read_heap_vector(memory, inputs) + .iter() + .map(|x| *x.extract_field().unwrap()) + .collect(); let domain_separator: u32 = memory.read(*domain_separator).try_into().map_err(|_| { BlackBoxResolutionError::Failed( @@ -241,7 +264,7 @@ pub(crate) fn evaluate_black_box( ) })?; let hash = solver.pedersen_hash(&inputs, domain_separator)?; - memory.write(*output, hash.into()); + memory.write(*output, MemoryValue::new_field(hash)); Ok(()) } BlackBoxOp::BigIntAdd { lhs, rhs, output } => { @@ -297,12 +320,12 @@ pub(crate) fn evaluate_black_box( } BlackBoxOp::Poseidon2Permutation { message, output, len } => { let input = read_heap_vector(memory, message); - let input: Vec = input.iter().map(|x| x.try_into().unwrap()).collect(); + let input: Vec = input.iter().map(|x| *x.extract_field().unwrap()).collect(); let len = memory.read(*len).try_into().unwrap(); let result = solver.poseidon2_permutation(&input, len)?; let mut values = Vec::new(); for i in result { - values.push(i.into()); + values.push(MemoryValue::new_field(i)); } memory.write_slice(memory.read_ref(output.pointer), &values); Ok(()) @@ -338,17 +361,16 @@ pub(crate) fn evaluate_black_box( Ok(()) } BlackBoxOp::ToRadix { input, radix, output } => { - let input: FieldElement = - memory.read(*input).try_into().expect("ToRadix input not a field"); + let input: F = *memory.read(*input).extract_field().expect("ToRadix input not a field"); let mut input = BigUint::from_bytes_be(&input.to_be_bytes()); let radix = BigUint::from(*radix); - let mut limbs: Vec = Vec::with_capacity(output.size); + let mut limbs: Vec> = Vec::with_capacity(output.size); for _ in 0..output.size { let limb = &input % &radix; - limbs.push(FieldElement::from_be_bytes_reduce(&limb.to_bytes_be()).into()); + limbs.push(MemoryValue::new_field(F::from_be_bytes_reduce(&limb.to_bytes_be()))); input /= &radix; } @@ -388,7 +410,10 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { #[cfg(test)] mod test { - use acir::brillig::{BlackBoxOp, MemoryAddress}; + use acir::{ + brillig::{BlackBoxOp, MemoryAddress}, + FieldElement, + }; use acvm_blackbox_solver::{BigIntSolver, StubbedBlackBoxSolver}; use crate::{ @@ -401,7 +426,7 @@ mod test { let message: Vec = b"hello world".to_vec(); let message_length = message.len(); - let mut memory = Memory::default(); + let mut memory: Memory = Memory::default(); let message_pointer = 3; let result_pointer = message_pointer + message_length; memory.write(MemoryAddress(0), message_pointer.into()); diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index 7901c31359..aaf5505658 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -15,7 +15,7 @@ use acir::brillig::{ BinaryFieldOp, BinaryIntOp, ForeignCallParam, ForeignCallResult, HeapArray, HeapValueType, HeapVector, MemoryAddress, Opcode, ValueOrArray, }; -use acir::FieldElement; +use acir::AcirField; use acvm_blackbox_solver::{BigIntSolver, BlackBoxFunctionSolver}; use arithmetic::{evaluate_binary_field_op, evaluate_binary_int_op, BrilligArithmeticError}; use black_box::evaluate_black_box; @@ -39,7 +39,7 @@ pub enum FailureReason { } #[derive(Debug, PartialEq, Eq, Clone)] -pub enum VMStatus { +pub enum VMStatus { Finished { return_data_offset: usize, return_data_size: usize, @@ -60,15 +60,15 @@ pub enum VMStatus { function: String, /// Input values /// Each input is a list of values as an input can be either a single value or a memory pointer - inputs: Vec, + inputs: Vec>, }, } #[derive(Debug, PartialEq, Eq, Clone)] /// VM encapsulates the state of the Brillig VM during execution. -pub struct VM<'a, B: BlackBoxFunctionSolver> { +pub struct VM<'a, F, B: BlackBoxFunctionSolver> { /// Calldata to the brillig function - calldata: Vec, + calldata: Vec, /// Instruction pointer program_counter: usize, /// A counter maintained throughout a Brillig process that determines @@ -76,13 +76,13 @@ pub struct VM<'a, B: BlackBoxFunctionSolver> { foreign_call_counter: usize, /// Represents the outputs of all foreign calls during a Brillig process /// List is appended onto by the caller upon reaching a [VMStatus::ForeignCallWait] - foreign_call_results: Vec, + foreign_call_results: Vec>, /// Executable opcodes - bytecode: &'a [Opcode], + bytecode: &'a [Opcode], /// Status of the VM - status: VMStatus, + status: VMStatus, /// Memory of the VM - memory: Memory, + memory: Memory, /// Call stack call_stack: Vec, /// The solver for blackbox functions @@ -91,12 +91,12 @@ pub struct VM<'a, B: BlackBoxFunctionSolver> { bigint_solver: BigIntSolver, } -impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { +impl<'a, F: AcirField, B: BlackBoxFunctionSolver> VM<'a, F, B> { /// Constructs a new VM instance pub fn new( - calldata: Vec, - bytecode: &'a [Opcode], - foreign_call_results: Vec, + calldata: Vec, + bytecode: &'a [Opcode], + foreign_call_results: Vec>, black_box_solver: &'a B, ) -> Self { Self { @@ -115,17 +115,17 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { /// Updates the current status of the VM. /// Returns the given status. - fn status(&mut self, status: VMStatus) -> VMStatus { + fn status(&mut self, status: VMStatus) -> VMStatus { self.status = status.clone(); status } - pub fn get_status(&self) -> VMStatus { + pub fn get_status(&self) -> VMStatus { self.status.clone() } /// Sets the current status of the VM to Finished (completed execution). - fn finish(&mut self, return_data_offset: usize, return_data_size: usize) -> VMStatus { + fn finish(&mut self, return_data_offset: usize, return_data_size: usize) -> VMStatus { self.status(VMStatus::Finished { return_data_offset, return_data_size }) } @@ -134,12 +134,12 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { fn wait_for_foreign_call( &mut self, function: String, - inputs: Vec, - ) -> VMStatus { + inputs: Vec>, + ) -> VMStatus { self.status(VMStatus::ForeignCallWait { function, inputs }) } - pub fn resolve_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { + pub fn resolve_foreign_call(&mut self, foreign_call_result: ForeignCallResult) { if self.foreign_call_counter < self.foreign_call_results.len() { panic!("No unresolved foreign calls"); } @@ -156,7 +156,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { /// Sets the current status of the VM to `fail`. /// Indicating that the VM encountered a `Trap` Opcode /// or an invalid state. - fn trap(&mut self, revert_data_offset: usize, revert_data_size: usize) -> VMStatus { + fn trap(&mut self, revert_data_offset: usize, revert_data_size: usize) -> VMStatus { self.status(VMStatus::Failure { call_stack: self.get_error_stack(), reason: FailureReason::Trap { revert_data_offset, revert_data_size }, @@ -164,7 +164,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { self.status.clone() } - fn fail(&mut self, message: String) -> VMStatus { + fn fail(&mut self, message: String) -> VMStatus { self.status(VMStatus::Failure { call_stack: self.get_error_stack(), reason: FailureReason::RuntimeError { message }, @@ -173,7 +173,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { } /// Loop over the bytecode and update the program counter - pub fn process_opcodes(&mut self) -> VMStatus { + pub fn process_opcodes(&mut self) -> VMStatus { while !matches!( self.process_opcode(), VMStatus::Finished { .. } | VMStatus::Failure { .. } | VMStatus::ForeignCallWait { .. } @@ -181,11 +181,11 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { self.status.clone() } - pub fn get_memory(&self) -> &[MemoryValue] { + pub fn get_memory(&self) -> &[MemoryValue] { self.memory.values() } - pub fn write_memory_at(&mut self, ptr: usize, value: MemoryValue) { + pub fn write_memory_at(&mut self, ptr: usize, value: MemoryValue) { self.memory.write(MemoryAddress(ptr), value); } @@ -196,7 +196,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { } /// Process a single opcode and modify the program counter. - pub fn process_opcode(&mut self) -> VMStatus { + pub fn process_opcode(&mut self) -> VMStatus { let opcode = &self.bytecode[self.program_counter]; match opcode { Opcode::BinaryFieldOp { op, lhs, rhs, destination: result } => { @@ -360,14 +360,14 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { } /// Increments the program counter by 1. - fn increment_program_counter(&mut self) -> VMStatus { + fn increment_program_counter(&mut self) -> VMStatus { self.set_program_counter(self.program_counter + 1) } /// Increments the program counter by `value`. /// If the program counter no longer points to an opcode /// in the bytecode, then the VMStatus reports halted. - fn set_program_counter(&mut self, value: usize) -> VMStatus { + fn set_program_counter(&mut self, value: usize) -> VMStatus { assert!(self.program_counter < self.bytecode.len()); self.program_counter = value; if self.program_counter >= self.bytecode.len() { @@ -380,10 +380,10 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { &self, input: ValueOrArray, value_type: &HeapValueType, - ) -> ForeignCallParam { + ) -> ForeignCallParam { match (input, value_type) { (ValueOrArray::MemoryAddress(value_index), HeapValueType::Simple(_)) => { - self.memory.read(value_index).to_field().into() + ForeignCallParam::Single(self.memory.read(value_index).to_field()) } ( ValueOrArray::HeapArray(HeapArray { pointer: pointer_index, size }), @@ -421,7 +421,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { start: MemoryAddress, size: usize, value_types: &[HeapValueType], - ) -> Vec { + ) -> Vec> { if HeapValueType::all_simple(value_types) { self.memory.read_slice(start, size).to_vec() } else { @@ -618,7 +618,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { } /// Casts a value to a different bit size. - fn cast(&self, bit_size: u32, source_value: MemoryValue) -> MemoryValue { + fn cast(&self, bit_size: u32, source_value: MemoryValue) -> MemoryValue { let lhs_big = source_value.to_integer(); let mask = BigUint::from(2_u32).pow(bit_size) - 1_u32; MemoryValue::new_from_integer(lhs_big & mask, bit_size) @@ -627,6 +627,7 @@ impl<'a, B: BlackBoxFunctionSolver> VM<'a, B> { #[cfg(test)] mod tests { + use acir::{AcirField, FieldElement}; use acvm_blackbox_solver::StubbedBlackBoxSolver; use super::*; @@ -664,7 +665,7 @@ mod tests { #[test] fn jmpif_opcode() { - let mut calldata = vec![]; + let mut calldata: Vec = vec![]; let mut opcodes = vec![]; let lhs = { @@ -709,7 +710,7 @@ mod tests { #[test] fn jmpifnot_opcode() { - let calldata = vec![1u128.into(), 2u128.into()]; + let calldata: Vec = vec![1u128.into(), 2u128.into()]; let calldata_copy = Opcode::CalldataCopy { destination_address: MemoryAddress::from(0), @@ -779,7 +780,7 @@ mod tests { #[test] fn cast_opcode() { - let calldata = vec![((2_u128.pow(32)) - 1).into()]; + let calldata: Vec = vec![((2_u128.pow(32)) - 1).into()]; let opcodes = &[ Opcode::CalldataCopy { @@ -813,7 +814,7 @@ mod tests { #[test] fn mov_opcode() { - let calldata = vec![(1u128).into(), (2u128).into(), (3u128).into()]; + let calldata: Vec = vec![(1u128).into(), (2u128).into(), (3u128).into()]; let calldata_copy = Opcode::CalldataCopy { destination_address: MemoryAddress::from(0), @@ -844,7 +845,8 @@ mod tests { #[test] fn cmov_opcode() { - let calldata = vec![(0u128).into(), (1u128).into(), (2u128).into(), (3u128).into()]; + let calldata: Vec = + vec![(0u128).into(), (1u128).into(), (2u128).into(), (3u128).into()]; let calldata_copy = Opcode::CalldataCopy { destination_address: MemoryAddress::from(0), @@ -910,7 +912,7 @@ mod tests { #[test] fn cmp_binary_ops() { let bit_size = 32; - let calldata = + let calldata: Vec = vec![(2u128).into(), (2u128).into(), (0u128).into(), (5u128).into(), (6u128).into()]; let calldata_size = calldata.len(); @@ -1010,14 +1012,14 @@ mod tests { /// memory[i] = i as Value; /// i += 1; /// } - fn brillig_write_memory(item_count: usize) -> Vec { + fn brillig_write_memory(item_count: usize) -> Vec> { let bit_size = 64; let r_i = MemoryAddress::from(0); let r_len = MemoryAddress::from(1); let r_tmp = MemoryAddress::from(2); let r_pointer = MemoryAddress::from(3); - let start = [ + let start: [Opcode; 3] = [ // i = 0 Opcode::Const { destination: r_i, value: 0u128.into(), bit_size }, // len = memory.len() (approximation) @@ -1091,7 +1093,7 @@ mod tests { let r_tmp = MemoryAddress::from(3); let r_pointer = MemoryAddress::from(4); - let start = [ + let start: [Opcode; 5] = [ // sum = 0 Opcode::Const { destination: r_sum, @@ -1179,18 +1181,18 @@ mod tests { /// recursive_write(memory, i + 1, len); /// } /// Note we represent a 100% in-stack optimized form in brillig - fn brillig_recursive_write_memory(size: usize) -> Vec { + fn brillig_recursive_write_memory(size: usize) -> Vec> { let bit_size = 64; let r_i = MemoryAddress::from(0); let r_len = MemoryAddress::from(1); let r_tmp = MemoryAddress::from(2); let r_pointer = MemoryAddress::from(3); - let start = [ + let start: [Opcode; 5] = [ // i = 0 Opcode::Const { destination: r_i, value: 0u128.into(), bit_size }, // len = size - Opcode::Const { destination: r_len, value: size.into(), bit_size }, + Opcode::Const { destination: r_len, value: (size as u128).into(), bit_size }, // pointer = free_memory_ptr Opcode::Const { destination: r_pointer, value: 4u128.into(), bit_size }, // call recursive_fn @@ -1245,28 +1247,28 @@ mod tests { vm.get_memory()[4..].to_vec() } - let memory = brillig_recursive_write_memory(5); + let memory = brillig_recursive_write_memory::(5); let expected = vec![(0u64).into(), (1u64).into(), (2u64).into(), (3u64).into(), (4u64).into()]; assert_eq!(memory, expected); - let memory = brillig_recursive_write_memory(1024); + let memory = brillig_recursive_write_memory::(1024); let expected: Vec<_> = (0..1024).map(|i: u64| i.into()).collect(); assert_eq!(memory, expected); } /// Helper to execute brillig code - fn brillig_execute_and_get_vm( - calldata: Vec, - opcodes: &[Opcode], - ) -> VM<'_, StubbedBlackBoxSolver> { + fn brillig_execute_and_get_vm( + calldata: Vec, + opcodes: &[Opcode], + ) -> VM<'_, F, StubbedBlackBoxSolver> { let mut vm = VM::new(calldata, opcodes, vec![], &StubbedBlackBoxSolver); brillig_execute(&mut vm); assert_eq!(vm.call_stack, vec![]); vm } - fn brillig_execute(vm: &mut VM) { + fn brillig_execute(vm: &mut VM) { loop { let status = vm.process_opcode(); if matches!(status, VMStatus::Finished { .. } | VMStatus::ForeignCallWait { .. }) { @@ -1330,7 +1332,8 @@ mod tests { let r_output = MemoryAddress::from(1); // Define a simple 2x2 matrix in memory - let initial_matrix = vec![(1u128).into(), (2u128).into(), (3u128).into(), (4u128).into()]; + let initial_matrix: Vec = + vec![(1u128).into(), (2u128).into(), (3u128).into(), (4u128).into()]; // Transpose of the matrix (but arbitrary for this test, the 'correct value') let expected_result: Vec = @@ -1505,7 +1508,8 @@ mod tests { let r_output = MemoryAddress::from(1); // Define a simple 2x2 matrix in memory - let initial_matrix = vec![(1u128).into(), (2u128).into(), (3u128).into(), (4u128).into()]; + let initial_matrix: Vec = + vec![(1u128).into(), (2u128).into(), (3u128).into(), (4u128).into()]; // Transpose of the matrix (but arbitrary for this test, the 'correct value') let expected_result: Vec = @@ -1592,9 +1596,11 @@ mod tests { let r_output = MemoryAddress::from(2); // Define a simple 2x2 matrix in memory - let matrix_a = vec![(1u128).into(), (2u128).into(), (3u128).into(), (4u128).into()]; + let matrix_a: Vec = + vec![(1u128).into(), (2u128).into(), (3u128).into(), (4u128).into()]; - let matrix_b = vec![(10u128).into(), (11u128).into(), (12u128).into(), (13u128).into()]; + let matrix_b: Vec = + vec![(10u128).into(), (11u128).into(), (12u128).into(), (13u128).into()]; // Transpose of the matrix (but arbitrary for this test, the 'correct value') let expected_result: Vec = @@ -1678,17 +1684,19 @@ mod tests { fn foreign_call_opcode_nested_arrays_and_slices_input() { // [(1, <2,3>, [4]), (5, <6,7,8>, [9])] - let v2: Vec = vec![ - MemoryValue::from(FieldElement::from(2u128)), - MemoryValue::from(FieldElement::from(3u128)), + let v2: Vec> = vec![ + MemoryValue::new_field(FieldElement::from(2u128)), + MemoryValue::new_field(FieldElement::from(3u128)), ]; - let a4: Vec = vec![FieldElement::from(4u128).into()]; - let v6: Vec = vec![ - MemoryValue::from(FieldElement::from(6u128)), - MemoryValue::from(FieldElement::from(7u128)), - MemoryValue::from(FieldElement::from(8u128)), + let a4: Vec> = + vec![MemoryValue::new_field(FieldElement::from(4u128))]; + let v6: Vec> = vec![ + MemoryValue::new_field(FieldElement::from(6u128)), + MemoryValue::new_field(FieldElement::from(7u128)), + MemoryValue::new_field(FieldElement::from(8u128)), ]; - let a9: Vec = vec![FieldElement::from(9u128).into()]; + let a9: Vec> = + vec![MemoryValue::new_field(FieldElement::from(9u128))]; // construct memory by declaring all inner arrays/vectors first let v2_ptr: usize = 0usize; @@ -1710,11 +1718,11 @@ mod tests { // finally we add the contents of the outer array let outer_ptr = memory.len(); let outer_array = vec![ - MemoryValue::from(FieldElement::from(1u128)), + MemoryValue::new_field(FieldElement::from(1u128)), MemoryValue::from(v2.len()), MemoryValue::from(v2_start), MemoryValue::from(a4_start), - MemoryValue::from(FieldElement::from(5u128)), + MemoryValue::new_field(FieldElement::from(5u128)), MemoryValue::from(v6.len()), MemoryValue::from(v6_start), MemoryValue::from(a9_start), @@ -1801,7 +1809,7 @@ mod tests { // Check result let result_value = vm.memory.read(r_output); - assert_eq!(result_value, MemoryValue::from(FieldElement::from(45u128))); + assert_eq!(result_value, MemoryValue::new_field(FieldElement::from(45u128))); // Ensure the foreign call counter has been incremented assert_eq!(vm.foreign_call_counter, 1); diff --git a/acvm-repo/brillig_vm/src/memory.rs b/acvm-repo/brillig_vm/src/memory.rs index feeb3706bd..4092cd06ae 100644 --- a/acvm-repo/brillig_vm/src/memory.rs +++ b/acvm-repo/brillig_vm/src/memory.rs @@ -1,12 +1,12 @@ -use acir::{brillig::MemoryAddress, FieldElement}; +use acir::{brillig::MemoryAddress, AcirField}; use num_bigint::BigUint; use num_traits::{One, Zero}; pub const MEMORY_ADDRESSING_BIT_SIZE: u32 = 64; #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub enum MemoryValue { - Field(FieldElement), +pub enum MemoryValue { + Field(F), Integer(BigUint, u32), } @@ -16,10 +16,10 @@ pub enum MemoryTypeError { MismatchedBitSize { value_bit_size: u32, expected_bit_size: u32 }, } -impl MemoryValue { +impl MemoryValue { /// Builds a memory value from a field element. - pub fn new_from_field(value: FieldElement, bit_size: u32) -> Self { - if bit_size == FieldElement::max_num_bits() { + pub fn new_from_field(value: F, bit_size: u32) -> Self { + if bit_size == F::max_num_bits() { MemoryValue::new_field(value) } else { MemoryValue::new_integer(BigUint::from_bytes_be(&value.to_be_bytes()), bit_size) @@ -28,16 +28,16 @@ impl MemoryValue { /// Builds a memory value from an integer pub fn new_from_integer(value: BigUint, bit_size: u32) -> Self { - if bit_size == FieldElement::max_num_bits() { - MemoryValue::new_field(FieldElement::from_be_bytes_reduce(&value.to_bytes_be())) + if bit_size == F::max_num_bits() { + MemoryValue::new_field(F::from_be_bytes_reduce(&value.to_bytes_be())) } else { MemoryValue::new_integer(value, bit_size) } } /// Builds a memory value from a field element, checking that the value is within the bit size. - pub fn new_checked(value: FieldElement, bit_size: u32) -> Option { - if bit_size < FieldElement::max_num_bits() && value.num_bits() > bit_size { + pub fn new_checked(value: F, bit_size: u32) -> Option { + if bit_size < F::max_num_bits() && value.num_bits() > bit_size { return None; } @@ -45,21 +45,21 @@ impl MemoryValue { } /// Builds a field-typed memory value. - pub fn new_field(value: FieldElement) -> Self { + pub fn new_field(value: F) -> Self { MemoryValue::Field(value) } /// Builds an integer-typed memory value. pub fn new_integer(value: BigUint, bit_size: u32) -> Self { assert!( - bit_size != FieldElement::max_num_bits(), + bit_size != F::max_num_bits(), "Tried to build a field memory value via new_integer" ); MemoryValue::Integer(value, bit_size) } /// Extracts the field element from the memory value, if it is typed as field element. - pub fn extract_field(&self) -> Option<&FieldElement> { + pub fn extract_field(&self) -> Option<&F> { match self { MemoryValue::Field(value) => Some(value), _ => None, @@ -75,12 +75,10 @@ impl MemoryValue { } /// Converts the memory value to a field element, independent of its type. - pub fn to_field(&self) -> FieldElement { + pub fn to_field(&self) -> F { match self { MemoryValue::Field(value) => *value, - MemoryValue::Integer(value, _) => { - FieldElement::from_be_bytes_reduce(&value.to_bytes_be()) - } + MemoryValue::Integer(value, _) => F::from_be_bytes_reduce(&value.to_bytes_be()), } } @@ -94,7 +92,7 @@ impl MemoryValue { pub fn bit_size(&self) -> u32 { match self { - MemoryValue::Field(_) => FieldElement::max_num_bits(), + MemoryValue::Field(_) => F::max_num_bits(), MemoryValue::Integer(_, bit_size) => *bit_size, } } @@ -107,11 +105,11 @@ impl MemoryValue { self.extract_integer().unwrap().0.try_into().unwrap() } - pub fn expect_field(&self) -> Result<&FieldElement, MemoryTypeError> { + pub fn expect_field(&self) -> Result<&F, MemoryTypeError> { match self { MemoryValue::Integer(_, bit_size) => Err(MemoryTypeError::MismatchedBitSize { value_bit_size: *bit_size, - expected_bit_size: FieldElement::max_num_bits(), + expected_bit_size: F::max_num_bits(), }), MemoryValue::Field(field) => Ok(field), } @@ -132,14 +130,14 @@ impl MemoryValue { Ok(value) } MemoryValue::Field(_) => Err(MemoryTypeError::MismatchedBitSize { - value_bit_size: FieldElement::max_num_bits(), + value_bit_size: F::max_num_bits(), expected_bit_size, }), } } } -impl std::fmt::Display for MemoryValue { +impl std::fmt::Display for MemoryValue { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> Result<(), ::std::fmt::Error> { match self { MemoryValue::Field(value) => write!(f, "{}: field", value), @@ -155,85 +153,71 @@ impl std::fmt::Display for MemoryValue { } } -impl Default for MemoryValue { +impl Default for MemoryValue { fn default() -> Self { MemoryValue::new_integer(BigUint::zero(), 0) } } -impl From for MemoryValue { - fn from(field: FieldElement) -> Self { - MemoryValue::new_field(field) - } -} - -impl From for MemoryValue { +impl From for MemoryValue { fn from(value: usize) -> Self { MemoryValue::new_integer(value.into(), MEMORY_ADDRESSING_BIT_SIZE) } } -impl From for MemoryValue { +impl From for MemoryValue { fn from(value: u64) -> Self { MemoryValue::new_integer(value.into(), 64) } } -impl From for MemoryValue { +impl From for MemoryValue { fn from(value: u32) -> Self { MemoryValue::new_integer(value.into(), 32) } } -impl From for MemoryValue { +impl From for MemoryValue { fn from(value: u8) -> Self { MemoryValue::new_integer(value.into(), 8) } } -impl From for MemoryValue { +impl From for MemoryValue { fn from(value: bool) -> Self { let value = if value { BigUint::one() } else { BigUint::zero() }; MemoryValue::new_integer(value, 1) } } -impl TryFrom for FieldElement { - type Error = MemoryTypeError; - - fn try_from(memory_value: MemoryValue) -> Result { - memory_value.expect_field().copied() - } -} - -impl TryFrom for u64 { +impl TryFrom> for u64 { type Error = MemoryTypeError; - fn try_from(memory_value: MemoryValue) -> Result { + fn try_from(memory_value: MemoryValue) -> Result { memory_value.expect_integer_with_bit_size(64).map(|value| value.try_into().unwrap()) } } -impl TryFrom for u32 { +impl TryFrom> for u32 { type Error = MemoryTypeError; - fn try_from(memory_value: MemoryValue) -> Result { + fn try_from(memory_value: MemoryValue) -> Result { memory_value.expect_integer_with_bit_size(32).map(|value| value.try_into().unwrap()) } } -impl TryFrom for u8 { +impl TryFrom> for u8 { type Error = MemoryTypeError; - fn try_from(memory_value: MemoryValue) -> Result { + fn try_from(memory_value: MemoryValue) -> Result { memory_value.expect_integer_with_bit_size(8).map(|value| value.try_into().unwrap()) } } -impl TryFrom for bool { +impl TryFrom> for bool { type Error = MemoryTypeError; - fn try_from(memory_value: MemoryValue) -> Result { + fn try_from(memory_value: MemoryValue) -> Result { let as_integer = memory_value.expect_integer_with_bit_size(1)?; if as_integer.is_zero() { @@ -246,48 +230,40 @@ impl TryFrom for bool { } } -impl TryFrom<&MemoryValue> for FieldElement { - type Error = MemoryTypeError; - - fn try_from(memory_value: &MemoryValue) -> Result { - memory_value.expect_field().copied() - } -} - -impl TryFrom<&MemoryValue> for u64 { +impl TryFrom<&MemoryValue> for u64 { type Error = MemoryTypeError; - fn try_from(memory_value: &MemoryValue) -> Result { + fn try_from(memory_value: &MemoryValue) -> Result { memory_value.expect_integer_with_bit_size(64).map(|value| { value.try_into().expect("memory_value has been asserted to contain a 64 bit integer") }) } } -impl TryFrom<&MemoryValue> for u32 { +impl TryFrom<&MemoryValue> for u32 { type Error = MemoryTypeError; - fn try_from(memory_value: &MemoryValue) -> Result { + fn try_from(memory_value: &MemoryValue) -> Result { memory_value.expect_integer_with_bit_size(32).map(|value| { value.try_into().expect("memory_value has been asserted to contain a 32 bit integer") }) } } -impl TryFrom<&MemoryValue> for u8 { +impl TryFrom<&MemoryValue> for u8 { type Error = MemoryTypeError; - fn try_from(memory_value: &MemoryValue) -> Result { + fn try_from(memory_value: &MemoryValue) -> Result { memory_value.expect_integer_with_bit_size(8).map(|value| { value.try_into().expect("memory_value has been asserted to contain an 8 bit integer") }) } } -impl TryFrom<&MemoryValue> for bool { +impl TryFrom<&MemoryValue> for bool { type Error = MemoryTypeError; - fn try_from(memory_value: &MemoryValue) -> Result { + fn try_from(memory_value: &MemoryValue) -> Result { let as_integer = memory_value.expect_integer_with_bit_size(1)?; if as_integer.is_zero() { @@ -301,15 +277,15 @@ impl TryFrom<&MemoryValue> for bool { } #[derive(Debug, Clone, Default, PartialEq, Eq)] -pub struct Memory { +pub struct Memory { // Memory is a vector of values. // We grow the memory when values past the end are set, extending with 0s. - inner: Vec, + inner: Vec>, } -impl Memory { +impl Memory { /// Gets the value at pointer - pub fn read(&self, ptr: MemoryAddress) -> MemoryValue { + pub fn read(&self, ptr: MemoryAddress) -> MemoryValue { self.inner.get(ptr.to_usize()).cloned().unwrap_or_default() } @@ -317,12 +293,12 @@ impl Memory { MemoryAddress(self.read(ptr).to_usize()) } - pub fn read_slice(&self, addr: MemoryAddress, len: usize) -> &[MemoryValue] { + pub fn read_slice(&self, addr: MemoryAddress, len: usize) -> &[MemoryValue] { &self.inner[addr.to_usize()..(addr.to_usize() + len)] } /// Sets the value at pointer `ptr` to `value` - pub fn write(&mut self, ptr: MemoryAddress, value: MemoryValue) { + pub fn write(&mut self, ptr: MemoryAddress, value: MemoryValue) { self.resize_to_fit(ptr.to_usize() + 1); self.inner[ptr.to_usize()] = value; } @@ -335,13 +311,13 @@ impl Memory { } /// Sets the values after pointer `ptr` to `values` - pub fn write_slice(&mut self, ptr: MemoryAddress, values: &[MemoryValue]) { + pub fn write_slice(&mut self, ptr: MemoryAddress, values: &[MemoryValue]) { self.resize_to_fit(ptr.to_usize() + values.len()); self.inner[ptr.to_usize()..(ptr.to_usize() + values.len())].clone_from_slice(values); } /// Returns the values of the memory - pub fn values(&self) -> &[MemoryValue] { + pub fn values(&self) -> &[MemoryValue] { &self.inner } } diff --git a/aztec_macros/Cargo.toml b/aztec_macros/Cargo.toml index 355036d28a..ed70066af2 100644 --- a/aztec_macros/Cargo.toml +++ b/aztec_macros/Cargo.toml @@ -10,6 +10,7 @@ repository.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +acvm.workspace = true noirc_frontend.workspace = true noirc_errors.workspace = true iter-extended.workspace = true diff --git a/aztec_macros/src/transforms/storage.rs b/aztec_macros/src/transforms/storage.rs index 0a21093482..bd9fff3c3d 100644 --- a/aztec_macros/src/transforms/storage.rs +++ b/aztec_macros/src/transforms/storage.rs @@ -1,3 +1,4 @@ +use acvm::acir::AcirField; use noirc_errors::Span; use noirc_frontend::ast::{ BlockExpression, Expression, ExpressionKind, FunctionDefinition, Ident, Literal, NoirFunction, diff --git a/aztec_macros/src/utils/hir_utils.rs b/aztec_macros/src/utils/hir_utils.rs index 99b02acd60..34aed3b34a 100644 --- a/aztec_macros/src/utils/hir_utils.rs +++ b/aztec_macros/src/utils/hir_utils.rs @@ -1,3 +1,4 @@ +use acvm::acir::AcirField; use iter_extended::vecmap; use noirc_errors::Location; use noirc_frontend::ast; diff --git a/compiler/noirc_driver/src/contract.rs b/compiler/noirc_driver/src/contract.rs index d6c3dc6205..11fc1bb637 100644 --- a/compiler/noirc_driver/src/contract.rs +++ b/compiler/noirc_driver/src/contract.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; -use acvm::acir::circuit::Program; +use acvm::{acir::circuit::Program, FieldElement}; use fm::FileId; use noirc_abi::{Abi, AbiType, AbiValue}; use noirc_errors::debug_info::DebugInfo; @@ -51,7 +51,7 @@ pub struct ContractFunction { serialize_with = "Program::serialize_program_base64", deserialize_with = "Program::deserialize_program_base64" )] - pub bytecode: Program, + pub bytecode: Program, pub debug: Vec, diff --git a/compiler/noirc_driver/src/program.rs b/compiler/noirc_driver/src/program.rs index ed7ddb29f5..8e02de0b8b 100644 --- a/compiler/noirc_driver/src/program.rs +++ b/compiler/noirc_driver/src/program.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use acvm::acir::circuit::Program; +use acvm::{acir::circuit::Program, FieldElement}; use fm::FileId; use noirc_errors::debug_info::DebugInfo; @@ -22,7 +22,7 @@ pub struct CompiledProgram { serialize_with = "Program::serialize_program_base64", deserialize_with = "Program::deserialize_program_base64" )] - pub program: Program, + pub program: Program, pub abi: noirc_abi::Abi, pub debug: Vec, pub file_map: BTreeMap, diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index 0dbab80dac..1fa4f41b29 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -18,7 +18,7 @@ use crate::ssa::ir::{ }; use acvm::acir::brillig::{MemoryAddress, ValueOrArray}; use acvm::brillig_vm::brillig::HeapVector; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use iter_extended::vecmap; use num_bigint::BigUint; diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs index 4b97a61491..7431959579 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs @@ -1,5 +1,6 @@ use acvm::{ acir::brillig::{BinaryFieldOp, BinaryIntOp, MemoryAddress, Opcode as BrilligOpcode}, + acir::AcirField, FieldElement, }; diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs index 2bd57dc948..9341db2ead 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir.rs @@ -27,7 +27,10 @@ pub(crate) use instructions::BrilligBinaryOp; use self::{artifact::BrilligArtifact, registers::BrilligRegistersContext}; use crate::ssa::ir::dfg::CallStack; -use acvm::acir::brillig::{MemoryAddress, Opcode as BrilligOpcode}; +use acvm::{ + acir::brillig::{MemoryAddress, Opcode as BrilligOpcode}, + FieldElement, +}; use debug_show::DebugShow; /// The Brillig VM does not apply a limit to the memory address space, @@ -110,7 +113,7 @@ impl BrilligContext { result } /// Adds a brillig instruction to the brillig byte code - fn push_opcode(&mut self, opcode: BrilligOpcode) { + fn push_opcode(&mut self, opcode: BrilligOpcode) { self.obj.push_opcode(opcode); } @@ -143,7 +146,7 @@ pub(crate) mod tests { pub(crate) struct DummyBlackBoxSolver; - impl BlackBoxFunctionSolver for DummyBlackBoxSolver { + impl BlackBoxFunctionSolver for DummyBlackBoxSolver { fn schnorr_verify( &self, _public_key_x: &FieldElement, @@ -217,8 +220,8 @@ pub(crate) mod tests { pub(crate) fn create_and_run_vm( calldata: Vec, - bytecode: &[BrilligOpcode], - ) -> (VM<'_, DummyBlackBoxSolver>, usize, usize) { + bytecode: &[BrilligOpcode], + ) -> (VM<'_, FieldElement, DummyBlackBoxSolver>, usize, usize) { let mut vm = VM::new(calldata, bytecode, vec![], &DummyBlackBoxSolver); let status = vm.process_opcodes(); @@ -277,7 +280,7 @@ pub(crate) mod tests { context.stop_instruction(); - let bytecode: Vec = context.artifact().finish().byte_code; + let bytecode: Vec> = context.artifact().finish().byte_code; let number_sequence: Vec = (0_usize..12_usize).map(FieldElement::from).collect(); let mut vm = VM::new( diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index dee6c6076f..99e922c158 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -1,4 +1,4 @@ -use acvm::acir::brillig::Opcode as BrilligOpcode; +use acvm::{acir::brillig::Opcode as BrilligOpcode, FieldElement}; use std::collections::{BTreeMap, HashMap}; use crate::ssa::ir::dfg::CallStack; @@ -19,7 +19,7 @@ pub(crate) enum BrilligParameter { /// This is ready to run bytecode with attached metadata. #[derive(Debug, Default)] pub(crate) struct GeneratedBrillig { - pub(crate) byte_code: Vec, + pub(crate) byte_code: Vec>, pub(crate) locations: BTreeMap, pub(crate) assert_messages: BTreeMap, } @@ -28,7 +28,7 @@ pub(crate) struct GeneratedBrillig { /// Artifacts resulting from the compilation of a function into brillig byte code. /// It includes the bytecode of the function and all the metadata that allows linking with other functions. pub(crate) struct BrilligArtifact { - pub(crate) byte_code: Vec, + pub(crate) byte_code: Vec>, /// A map of bytecode positions to assertion messages. /// Some error messages (compiler intrinsics) are not emitted via revert data, /// instead, they are handled externally so they don't add size to user programs. @@ -154,7 +154,7 @@ impl BrilligArtifact { } /// Adds a brillig instruction to the brillig byte code - pub(crate) fn push_opcode(&mut self, opcode: BrilligOpcode) { + pub(crate) fn push_opcode(&mut self, opcode: BrilligOpcode) { if !self.call_stack.is_empty() { self.locations.insert(self.index_of_next_opcode(), self.call_stack.clone()); } @@ -164,7 +164,7 @@ impl BrilligArtifact { /// Adds a unresolved jump to be fixed at the end of bytecode processing. pub(crate) fn add_unresolved_jump( &mut self, - jmp_instruction: BrilligOpcode, + jmp_instruction: BrilligOpcode, destination: UnresolvedJumpLocation, ) { assert!( @@ -178,7 +178,7 @@ impl BrilligArtifact { /// Adds a unresolved external call that will be fixed once linking has been done. pub(crate) fn add_unresolved_external_call( &mut self, - call_instruction: BrilligOpcode, + call_instruction: BrilligOpcode, destination: UnresolvedJumpLocation, ) { // TODO: Add a check to ensure that the opcode is a call instruction @@ -188,7 +188,7 @@ impl BrilligArtifact { } /// Returns true if the opcode is a jump instruction - fn is_jmp_instruction(instruction: &BrilligOpcode) -> bool { + fn is_jmp_instruction(instruction: &BrilligOpcode) -> bool { matches!( instruction, BrilligOpcode::JumpIfNot { .. } diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs index bbfdbb69f7..cf1fd55519 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/brillig_variable.rs @@ -1,4 +1,5 @@ use acvm::{ + acir::AcirField, brillig_vm::brillig::{HeapArray, HeapValueType, HeapVector, MemoryAddress, ValueOrArray}, FieldElement, }; diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs index 58166554e1..42f3b34aea 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/codegen_intrinsic.rs @@ -1,5 +1,6 @@ use acvm::{ acir::brillig::{BlackBoxOp, HeapArray}, + acir::AcirField, FieldElement, }; diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs index 38e9bdfa8b..9023183eb3 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/entry_point.rs @@ -5,7 +5,7 @@ use super::{ registers::BrilligRegistersContext, BrilligBinaryOp, BrilligContext, ReservedRegisters, }; -use acvm::{acir::brillig::MemoryAddress, FieldElement}; +use acvm::{acir::brillig::MemoryAddress, acir::AcirField, FieldElement}; pub(crate) const MAX_STACK_SIZE: usize = 2048; diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs index 5d2430208e..03a9216b73 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/instructions.rs @@ -3,6 +3,7 @@ use acvm::{ BinaryFieldOp, BinaryIntOp, BlackBoxOp, HeapArray, HeapValueType, MemoryAddress, Opcode as BrilligOpcode, ValueOrArray, }, + acir::AcirField, FieldElement, }; @@ -212,7 +213,7 @@ impl BrilligContext { /// Adds a unresolved `Jump` to the bytecode. fn add_unresolved_jump( &mut self, - jmp_instruction: BrilligOpcode, + jmp_instruction: BrilligOpcode, destination: UnresolvedJumpLocation, ) { self.obj.add_unresolved_jump(jmp_instruction, destination); @@ -380,7 +381,7 @@ impl BrilligContext { constant, result.bit_size ); - if result.bit_size > 128 && !constant.fits_in_u128() { + if result.bit_size > 128 && constant.num_bits() > 128 { let high = FieldElement::from_be_bytes_reduce( constant.to_be_bytes().get(0..16).expect("FieldElement::to_be_bytes() too short!"), ); diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index c2fe7878bf..d38601bfc1 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -10,12 +10,15 @@ use std::collections::{BTreeMap, BTreeSet}; use crate::errors::{RuntimeError, SsaReport}; -use acvm::acir::{ - circuit::{ - brillig::BrilligBytecode, Circuit, ErrorSelector, ExpressionWidth, Program as AcirProgram, - PublicInputs, +use acvm::{ + acir::{ + circuit::{ + brillig::BrilligBytecode, Circuit, ErrorSelector, ExpressionWidth, + Program as AcirProgram, PublicInputs, + }, + native_types::Witness, }, - native_types::Witness, + FieldElement, }; use noirc_errors::debug_info::{DebugFunctions, DebugInfo, DebugTypes, DebugVariables}; @@ -102,7 +105,7 @@ fn time(name: &str, print_timings: bool, f: impl FnOnce() -> T) -> T { #[derive(Default)] pub struct SsaProgramArtifact { - pub program: AcirProgram, + pub program: AcirProgram, pub debug: Vec, pub warnings: Vec, pub main_input_witnesses: Vec, @@ -113,7 +116,7 @@ pub struct SsaProgramArtifact { impl SsaProgramArtifact { fn new( - unconstrained_functions: Vec, + unconstrained_functions: Vec>, error_types: BTreeMap, ) -> Self { let program = AcirProgram { functions: Vec::default(), unconstrained_functions }; @@ -194,7 +197,7 @@ pub fn create_program( pub struct SsaCircuitArtifact { name: String, - circuit: Circuit, + circuit: Circuit, debug_info: DebugInfo, warnings: Vec, input_witnesses: Vec, diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index b6e88a8d4b..4a0f9f798f 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -13,6 +13,7 @@ use acvm::acir::circuit::{AssertionPayload, ExpressionOrMemory, Opcode}; use acvm::blackbox_solver; use acvm::brillig_vm::{MemoryValue, VMStatus, VM}; use acvm::{ + acir::AcirField, acir::{ brillig::Opcode as BrilligOpcode, circuit::opcodes::FunctionInput, @@ -269,7 +270,10 @@ impl AcirContext { } /// Converts an [`AcirVar`] to an [`Expression`] - pub(crate) fn var_to_expression(&self, var: AcirVar) -> Result { + pub(crate) fn var_to_expression( + &self, + var: AcirVar, + ) -> Result, InternalError> { let var_data = match self.vars.get(&var) { Some(var_data) => var_data, None => { @@ -498,7 +502,7 @@ impl AcirContext { &mut self, lhs: AcirVar, rhs: AcirVar, - assert_message: Option, + assert_message: Option>, ) -> Result<(), RuntimeError> { let lhs_expr = self.var_to_expression(lhs)?; let rhs_expr = self.var_to_expression(rhs)?; @@ -527,7 +531,7 @@ impl AcirContext { pub(crate) fn vars_to_expressions_or_memory( &self, values: &[AcirValue], - ) -> Result, RuntimeError> { + ) -> Result>, RuntimeError> { let mut result = Vec::with_capacity(values.len()); for value in values { match value { @@ -901,7 +905,7 @@ impl AcirContext { // Optimization when rhs is const and fits within a u128 let rhs_expr = self.var_to_expression(rhs)?; - if rhs_expr.is_const() && rhs_expr.q_c.fits_in_u128() { + if rhs_expr.is_const() && rhs_expr.q_c.num_bits() <= 128 { // We try to move the offset to rhs let rhs_offset = if self.is_constant_one(&offset) && rhs_expr.q_c.to_u128() >= 1 { lhs_offset = lhs; @@ -1554,21 +1558,24 @@ impl AcirContext { brillig_function_index: u32, brillig_stdlib_func: Option, ) -> Result, RuntimeError> { - let brillig_inputs = try_vecmap(inputs, |i| -> Result<_, InternalError> { - match i { - AcirValue::Var(var, _) => Ok(BrilligInputs::Single(self.var_to_expression(var)?)), - AcirValue::Array(vars) => { - let mut var_expressions: Vec = Vec::new(); - for var in vars { - self.brillig_array_input(&mut var_expressions, var)?; + let brillig_inputs: Vec> = + try_vecmap(inputs, |i| -> Result<_, InternalError> { + match i { + AcirValue::Var(var, _) => { + Ok(BrilligInputs::Single(self.var_to_expression(var)?)) + } + AcirValue::Array(vars) => { + let mut var_expressions: Vec> = Vec::new(); + for var in vars { + self.brillig_array_input(&mut var_expressions, var)?; + } + Ok(BrilligInputs::Array(var_expressions)) + } + AcirValue::DynamicArray(AcirDynamicArray { block_id, .. }) => { + Ok(BrilligInputs::MemoryArray(block_id)) } - Ok(BrilligInputs::Array(var_expressions)) - } - AcirValue::DynamicArray(AcirDynamicArray { block_id, .. }) => { - Ok(BrilligInputs::MemoryArray(block_id)) } - } - })?; + })?; // Optimistically try executing the brillig now, if we can complete execution they just return the results. // This is a temporary measure pending SSA optimizations being applied to Brillig which would remove constant-input opcodes (See #2066) @@ -1645,7 +1652,7 @@ impl AcirContext { fn brillig_array_input( &mut self, - var_expressions: &mut Vec, + var_expressions: &mut Vec>, input: AcirValue, ) -> Result<(), InternalError> { match input { @@ -1704,8 +1711,8 @@ impl AcirContext { fn execute_brillig( &mut self, - code: &[BrilligOpcode], - inputs: &[BrilligInputs], + code: &[BrilligOpcode], + inputs: &[BrilligInputs], outputs_types: &[AcirType], ) -> Option> { let mut memory = (execute_brillig(code, inputs)?).into_iter(); @@ -1730,7 +1737,7 @@ impl AcirContext { &mut self, element_types: &[AcirType], size: usize, - memory_iter: &mut impl Iterator, + memory_iter: &mut impl Iterator>, ) -> AcirValue { let mut array_values = im::Vector::new(); for _ in 0..size { @@ -1885,7 +1892,7 @@ impl AcirContext { #[derive(Debug, Eq, Clone)] enum AcirVarData { Witness(Witness), - Expr(Expression), + Expr(Expression), Const(FieldElement), } @@ -1917,7 +1924,7 @@ impl AcirVarData { None } /// Converts all enum variants to an Expression. - pub(crate) fn to_expression(&self) -> Cow { + pub(crate) fn to_expression(&self) -> Cow> { match self { AcirVarData::Witness(witness) => Cow::Owned(Expression::from(*witness)), AcirVarData::Expr(expr) => Cow::Borrowed(expr), @@ -1938,8 +1945,8 @@ impl From for AcirVarData { } } -impl From for AcirVarData { - fn from(expr: Expression) -> Self { +impl From> for AcirVarData { + fn from(expr: Expression) -> Self { // Prefer simpler variants if possible. if let Some(constant) = expr.to_const() { AcirVarData::from(constant) @@ -1960,7 +1967,10 @@ pub(crate) struct AcirVar(usize); /// Returns the finished state of the Brillig VM if execution can complete. /// /// Returns `None` if complete execution of the Brillig bytecode is not possible. -fn execute_brillig(code: &[BrilligOpcode], inputs: &[BrilligInputs]) -> Option> { +fn execute_brillig( + code: &[BrilligOpcode], + inputs: &[BrilligInputs], +) -> Option>> { // Set input values let mut calldata: Vec = Vec::new(); diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/big_int.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/big_int.rs index c21188a8db..b9c596d80c 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/big_int.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/big_int.rs @@ -1,4 +1,4 @@ -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use num_bigint::BigUint; /// Represents a bigint value in the form (id, modulus) where diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index d23f4abe5f..6c79c0a228 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -17,6 +17,7 @@ use acvm::acir::{ BlackBoxFunc, }; use acvm::{ + acir::AcirField, acir::{circuit::directives::Directive, native_types::Expression}, FieldElement, }; @@ -41,7 +42,7 @@ pub(crate) struct GeneratedAcir { current_witness_index: Option, /// The opcodes of which the compiled ACIR will comprise. - opcodes: Vec, + opcodes: Vec>, /// All witness indices that comprise the final return value of the program /// @@ -60,7 +61,7 @@ pub(crate) struct GeneratedAcir { pub(crate) call_stack: CallStack, /// Correspondence between an opcode index and the error message associated with it. - pub(crate) assertion_payloads: BTreeMap, + pub(crate) assertion_payloads: BTreeMap>, pub(crate) warnings: Vec, @@ -99,18 +100,18 @@ impl GeneratedAcir { } /// Adds a new opcode into ACIR. - pub(crate) fn push_opcode(&mut self, opcode: AcirOpcode) { + pub(crate) fn push_opcode(&mut self, opcode: AcirOpcode) { self.opcodes.push(opcode); if !self.call_stack.is_empty() { self.locations.insert(self.last_acir_opcode_location(), self.call_stack.clone()); } } - pub(crate) fn opcodes(&self) -> &[AcirOpcode] { + pub(crate) fn opcodes(&self) -> &[AcirOpcode] { &self.opcodes } - pub(crate) fn take_opcodes(&mut self) -> Vec { + pub(crate) fn take_opcodes(&mut self) -> Vec> { std::mem::take(&mut self.opcodes) } @@ -129,7 +130,7 @@ impl GeneratedAcir { /// /// If `expr` can be represented as a `Witness` then this function will return it, /// else a new opcode will be added to create a `Witness` that is equal to `expr`. - pub(crate) fn get_or_create_witness(&mut self, expr: &Expression) -> Witness { + pub(crate) fn get_or_create_witness(&mut self, expr: &Expression) -> Witness { match expr.to_witness() { Some(witness) => witness, None => self.create_witness_for_expression(expr), @@ -142,7 +143,10 @@ impl GeneratedAcir { /// This means you cannot multiply an infinite amount of `Expression`s together. /// Once the `Expression` goes over degree-2, then it needs to be reduced to a `Witness` /// which has degree-1 in order to be able to continue the multiplication chain. - pub(crate) fn create_witness_for_expression(&mut self, expression: &Expression) -> Witness { + pub(crate) fn create_witness_for_expression( + &mut self, + expression: &Expression, + ) -> Witness { let fresh_witness = self.next_witness_index(); // Create a constraint that sets them to be equal to each other @@ -392,7 +396,7 @@ impl GeneratedAcir { /// Only radix that are a power of two are supported pub(crate) fn radix_le_decompose( &mut self, - input_expr: &Expression, + input_expr: &Expression, radix: u32, limb_count: u32, bit_size: u32, @@ -438,7 +442,7 @@ impl GeneratedAcir { /// /// Safety: It is the callers responsibility to ensure that the /// resulting `Witness` is constrained to be the inverse. - pub(crate) fn brillig_inverse(&mut self, expr: Expression) -> Witness { + pub(crate) fn brillig_inverse(&mut self, expr: Expression) -> Witness { // Create the witness for the result let inverted_witness = self.next_witness_index(); @@ -462,14 +466,18 @@ impl GeneratedAcir { /// /// If `expr` is not zero, then the constraint system will /// fail upon verification. - pub(crate) fn assert_is_zero(&mut self, expr: Expression) { + pub(crate) fn assert_is_zero(&mut self, expr: Expression) { self.push_opcode(AcirOpcode::AssertZero(expr)); } /// Returns a `Witness` that is constrained to be: /// - `1` if `lhs == rhs` /// - `0` otherwise - pub(crate) fn is_equal(&mut self, lhs: &Expression, rhs: &Expression) -> Witness { + pub(crate) fn is_equal( + &mut self, + lhs: &Expression, + rhs: &Expression, + ) -> Witness { let t = lhs - rhs; self.is_zero(&t) @@ -527,7 +535,7 @@ impl GeneratedAcir { /// By setting `z` to be `0`, we can make `y` equal to `1`. /// This is easily observed: `y = 1 - t * 0` /// Now since `y` is one, this means that `t` needs to be zero, or else `y * t == 0` will fail. - fn is_zero(&mut self, t_expr: &Expression) -> Witness { + fn is_zero(&mut self, t_expr: &Expression) -> Witness { // We're checking for equality with zero so we can negate the expression without changing the result. // This is useful as it will sometimes allow us to simplify an expression down to a witness. let t_witness = if let Some(witness) = t_expr.to_witness() { @@ -588,9 +596,9 @@ impl GeneratedAcir { pub(crate) fn brillig_call( &mut self, - predicate: Option, + predicate: Option>, generated_brillig: &GeneratedBrillig, - inputs: Vec, + inputs: Vec>, outputs: Vec, brillig_function_index: u32, stdlib_func: Option, diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index 30fd83d770..e9b2d5f46d 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -37,6 +37,7 @@ use acvm::acir::circuit::{AssertionPayload, ErrorSelector, OpcodeLocation}; use acvm::acir::native_types::Witness; use acvm::acir::BlackBoxFunc; use acvm::{ + acir::AcirField, acir::{circuit::opcodes::BlockId, native_types::Expression}, FieldElement, }; @@ -278,7 +279,7 @@ impl AcirValue { } pub(crate) type Artifacts = - (Vec, Vec, BTreeMap); + (Vec, Vec>, BTreeMap); impl Ssa { #[tracing::instrument(level = "trace", skip_all)] @@ -3097,7 +3098,7 @@ mod test { } fn check_call_opcode( - opcode: &Opcode, + opcode: &Opcode, expected_id: u32, expected_inputs: Vec, expected_outputs: Vec, @@ -3425,7 +3426,7 @@ mod test { fn check_brillig_calls( brillig_stdlib_function_locations: &BTreeMap, - opcodes: &[Opcode], + opcodes: &[Opcode], num_normal_brillig_functions: u32, expected_num_stdlib_calls: u32, expected_num_normal_calls: u32, diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index f5afbfae1b..b24c5632b2 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -516,7 +516,7 @@ impl std::ops::Index for FunctionBuilder { mod tests { use std::rc::Rc; - use acvm::FieldElement; + use acvm::{acir::AcirField, FieldElement}; use crate::ssa::ir::{ instruction::{Endian, Intrinsic}, diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index c71781557b..545827df1c 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -13,7 +13,7 @@ use super::{ value::{Value, ValueId}, }; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use fxhash::FxHashMap as HashMap; use iter_extended::vecmap; use noirc_errors::Location; diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 93ea703721..5110140bfc 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -1,6 +1,7 @@ use std::hash::{Hash, Hasher}; use acvm::{ + acir::AcirField, acir::{ circuit::{ErrorSelector, STRING_ERROR_SELECTOR}, BlackBoxFunc, diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs index 9099268ace..dbb717b0a1 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs @@ -1,4 +1,4 @@ -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use super::{ DataFlowGraph, Instruction, InstructionResultType, NumericType, SimplifyResult, Type, ValueId, diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs index 8f57d9de36..74e5653c7b 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/call.rs @@ -1,7 +1,7 @@ use fxhash::FxHashMap as HashMap; use std::{collections::VecDeque, rc::Rc}; -use acvm::{acir::BlackBoxFunc, BlackBoxResolutionError, FieldElement}; +use acvm::{acir::AcirField, acir::BlackBoxFunc, BlackBoxResolutionError, FieldElement}; use iter_extended::vecmap; use num_bigint::BigUint; diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/cast.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/cast.rs index 671820e801..d0ed5a1fa9 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/cast.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/cast.rs @@ -1,4 +1,4 @@ -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use num_bigint::BigUint; use super::{DataFlowGraph, Instruction, NumericType, SimplifyResult, Type, Value, ValueId}; diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction/constrain.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction/constrain.rs index d844f35092..66f50440d6 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction/constrain.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction/constrain.rs @@ -1,4 +1,4 @@ -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use super::{Binary, BinaryOp, ConstrainError, DataFlowGraph, Instruction, Type, Value, ValueId}; diff --git a/compiler/noirc_evaluator/src/ssa/ir/printer.rs b/compiler/noirc_evaluator/src/ssa/ir/printer.rs index 58c593b0ad..f7ffe2406e 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/printer.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/printer.rs @@ -5,6 +5,7 @@ use std::{ }; use acvm::acir::circuit::{ErrorSelector, STRING_ERROR_SELECTOR}; +use acvm::acir::AcirField; use iter_extended::vecmap; use super::{ diff --git a/compiler/noirc_evaluator/src/ssa/ir/types.rs b/compiler/noirc_evaluator/src/ssa/ir/types.rs index d72ad487f6..ded385d2d3 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/types.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/types.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; /// A numeric type in the Intermediate representation diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index ac2f642433..48bd70ff13 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -21,7 +21,7 @@ //! different blocks are merged, i.e. after the [`flatten_cfg`][super::flatten_cfg] pass. use std::collections::HashSet; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; use crate::ssa::{ @@ -288,6 +288,7 @@ mod test { value::{Value, ValueId}, }, }; + use acvm::acir::AcirField; #[test] fn simple_constant_fold() { diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 0f8b49b40e..690c0244f6 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -134,7 +134,7 @@ use fxhash::FxHashMap as HashMap; use std::collections::{BTreeMap, HashSet}; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; use crate::ssa::{ @@ -792,6 +792,8 @@ impl<'f> Context<'f> { mod test { use std::rc::Rc; + use acvm::acir::AcirField; + use crate::ssa::{ function_builder::FunctionBuilder, ir::{ diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs index 4fc19acd2a..f0760f2900 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/capacity_tracker.rs @@ -5,7 +5,7 @@ use crate::ssa::ir::{ value::{Value, ValueId}, }; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use fxhash::FxHashMap as HashMap; pub(crate) struct SliceCapacityTracker<'a> { diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs index 80f6529b7b..c59134e4ec 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs @@ -1,4 +1,4 @@ -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use fxhash::{FxHashMap as HashMap, FxHashSet}; use crate::ssa::ir::{ diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 73dc388818..1293671da5 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -4,6 +4,7 @@ //! be a single function remaining when the pass finishes. use std::collections::{BTreeSet, HashSet}; +use acvm::acir::AcirField; use iter_extended::{btree_map, vecmap}; use crate::ssa::{ @@ -562,7 +563,7 @@ impl<'function> PerFunctionContext<'function> { #[cfg(test)] mod test { - use acvm::FieldElement; + use acvm::{acir::AcirField, FieldElement}; use noirc_frontend::monomorphization::ast::InlineType; use crate::ssa::{ diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 7b87142d82..5b1139e5b9 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -406,7 +406,7 @@ impl<'f> PerFunctionContext<'f> { mod tests { use std::rc::Rc; - use acvm::FieldElement; + use acvm::{acir::AcirField, FieldElement}; use im::vector; use crate::ssa::{ diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs index 65a77552c7..628e1bd741 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs @@ -1,6 +1,6 @@ use std::{borrow::Cow, rc::Rc}; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use crate::ssa::{ ir::{ diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs index 464faa5732..6db7699674 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs @@ -10,7 +10,7 @@ //! before the [Instruction]. Continue inserting instructions until the next [Instruction::EnableSideEffects] is encountered. use std::collections::HashSet; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use crate::ssa::{ ir::{ diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs index 91b455dbf2..6ca7eb74e9 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs @@ -1,6 +1,6 @@ use std::collections::hash_map::Entry; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use fxhash::FxHashMap as HashMap; use crate::ssa::ir::value::ValueId; diff --git a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs index f524b10f1f..9d5d7879dc 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/simplify_cfg.rs @@ -11,6 +11,8 @@ //! Currently, 1 and 4 are unimplemented. use std::collections::HashSet; +use acvm::acir::AcirField; + use crate::ssa::{ ir::{ basic_block::BasicBlockId, cfg::ControlFlowGraph, dfg::CallStack, function::Function, @@ -159,6 +161,7 @@ mod test { types::Type, }, }; + use acvm::acir::AcirField; #[test] fn inline_blocks() { diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index cfb5cfac32..5f58be4142 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -16,6 +16,8 @@ //! we remove reference count instructions because they are only used by Brillig bytecode use std::collections::HashSet; +use acvm::acir::AcirField; + use crate::{ errors::RuntimeError, ssa::{ diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs index ebcbfbabe7..df14ee99bd 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/context.rs @@ -1,7 +1,7 @@ use std::rc::Rc; use std::sync::{Mutex, RwLock}; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; use noirc_errors::Location; use noirc_frontend::ast::{BinaryOpKind, Signedness}; diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 0173b17d28..21131c7121 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -6,7 +6,7 @@ use crate::ast::{ UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, Visibility, }; use crate::token::{Attributes, Token}; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; use noirc_errors::{Span, Spanned}; diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index c3556dac6a..090a41fa7d 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -26,6 +26,7 @@ use crate::{ token::IntType, BinaryTypeOperator, }; +use acvm::acir::AcirField; use iter_extended::vecmap; #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Ord, PartialOrd)] diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 863615da53..9b2c0fbfee 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -1,6 +1,7 @@ use std::fmt::Display; use std::sync::atomic::{AtomicU32, Ordering}; +use acvm::acir::AcirField; use acvm::FieldElement; use iter_extended::vecmap; use noirc_errors::{Span, Spanned}; diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 3c8d805d80..4c2b58580c 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1,5 +1,6 @@ use std::rc::Rc; +use acvm::{acir::AcirField, FieldElement}; use iter_extended::vecmap; use noirc_errors::{Location, Span}; diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index af5ba9a44c..34cecf0ece 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -1,5 +1,5 @@ use crate::{hir::def_collector::dc_crate::CompilationError, Type}; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use noirc_errors::{CustomDiagnostic, Location}; use super::value::Value; diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 5984e454f7..c0aeb910f2 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -1,6 +1,6 @@ use std::{collections::hash_map::Entry, rc::Rc}; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use im::Vector; use iter_extended::try_vecmap; use noirc_errors::Location; diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 1f00669735..8beac340c4 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -11,6 +11,8 @@ // XXX: Change mentions of intern to resolve. In regards to the above comment // // XXX: Resolver does not check for unused functions +use acvm::acir::AcirField; + use crate::hir_def::expr::{ HirArrayLiteral, HirBinaryOp, HirBlockExpression, HirCallExpression, HirCapturedVar, HirCastExpression, HirConstructorExpression, HirExpression, HirIdent, HirIfExpression, diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index 0760749c9e..3a570922c8 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -1,3 +1,4 @@ +use acvm::acir::AcirField; use iter_extended::vecmap; use noirc_errors::Span; diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index ebbc7fc981..fdda271e79 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -1,4 +1,4 @@ -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use noirc_errors::{Position, Span, Spanned}; use std::{fmt, iter::Map, vec::IntoIter}; diff --git a/compiler/noirc_frontend/src/monomorphization/debug.rs b/compiler/noirc_frontend/src/monomorphization/debug.rs index 88943be727..3b399c7570 100644 --- a/compiler/noirc_frontend/src/monomorphization/debug.rs +++ b/compiler/noirc_frontend/src/monomorphization/debug.rs @@ -1,3 +1,4 @@ +use acvm::acir::AcirField; use iter_extended::vecmap; use noirc_errors::debug_info::DebugVarId; use noirc_errors::Location; diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 54a6af9774..2e74eb87e6 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -21,7 +21,7 @@ use crate::{ token::FunctionAttribute, Type, TypeBinding, TypeBindings, TypeVariable, TypeVariableKind, }; -use acvm::FieldElement; +use acvm::{acir::AcirField, FieldElement}; use iter_extended::{btree_map, try_vecmap, vecmap}; use noirc_errors::Location; use noirc_printable_type::PrintableType; diff --git a/compiler/noirc_printable_type/src/lib.rs b/compiler/noirc_printable_type/src/lib.rs index cc0dbca247..a12ecf01b5 100644 --- a/compiler/noirc_printable_type/src/lib.rs +++ b/compiler/noirc_printable_type/src/lib.rs @@ -1,6 +1,6 @@ use std::{collections::BTreeMap, str}; -use acvm::{brillig_vm::brillig::ForeignCallParam, FieldElement}; +use acvm::{acir::AcirField, brillig_vm::brillig::ForeignCallParam, FieldElement}; use iter_extended::vecmap; use regex::{Captures, Regex}; use serde::{Deserialize, Serialize}; @@ -81,10 +81,12 @@ pub enum ForeignCallError { ResolvedAssertMessage(String), } -impl TryFrom<&[ForeignCallParam]> for PrintableValueDisplay { +impl TryFrom<&[ForeignCallParam]> for PrintableValueDisplay { type Error = ForeignCallError; - fn try_from(foreign_call_inputs: &[ForeignCallParam]) -> Result { + fn try_from( + foreign_call_inputs: &[ForeignCallParam], + ) -> Result { let (is_fmt_str, foreign_call_inputs) = foreign_call_inputs.split_last().ok_or(ForeignCallError::MissingForeignCallInputs)?; @@ -97,7 +99,7 @@ impl TryFrom<&[ForeignCallParam]> for PrintableValueDisplay { } fn convert_string_inputs( - foreign_call_inputs: &[ForeignCallParam], + foreign_call_inputs: &[ForeignCallParam], ) -> Result { // Fetch the PrintableType from the foreign call input // The remaining input values should hold what is to be printed @@ -114,7 +116,7 @@ fn convert_string_inputs( } fn convert_fmt_string_inputs( - foreign_call_inputs: &[ForeignCallParam], + foreign_call_inputs: &[ForeignCallParam], ) -> Result { let (message, input_and_printable_types) = foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?; @@ -143,7 +145,7 @@ fn convert_fmt_string_inputs( } fn fetch_printable_type( - printable_type: &ForeignCallParam, + printable_type: &ForeignCallParam, ) -> Result { let printable_type_as_fields = printable_type.fields(); let printable_type_as_string = decode_string_value(&printable_type_as_fields); diff --git a/compiler/wasm/Cargo.toml b/compiler/wasm/Cargo.toml index 31ef516101..23686cc4ea 100644 --- a/compiler/wasm/Cargo.toml +++ b/compiler/wasm/Cargo.toml @@ -13,7 +13,7 @@ license.workspace = true crate-type = ["cdylib"] [dependencies] -acvm.workspace = true +acvm = { workspace = true, features = ["bn254"] } fm.workspace = true nargo.workspace = true noirc_driver.workspace = true diff --git a/tooling/acvm_cli/src/cli/execute_cmd.rs b/tooling/acvm_cli/src/cli/execute_cmd.rs index 5f9651c913..ac3af03684 100644 --- a/tooling/acvm_cli/src/cli/execute_cmd.rs +++ b/tooling/acvm_cli/src/cli/execute_cmd.rs @@ -2,6 +2,7 @@ use std::io::{self, Write}; use acir::circuit::Program; use acir::native_types::{WitnessMap, WitnessStack}; +use acir::FieldElement; use bn254_blackbox_solver::Bn254BlackBoxSolver; use clap::Args; @@ -63,11 +64,11 @@ pub(crate) fn run(args: ExecuteCommand) -> Result { } pub(crate) fn execute_program_from_witness( - inputs_map: WitnessMap, + inputs_map: WitnessMap, bytecode: &[u8], foreign_call_resolver_url: Option<&str>, -) -> Result { - let program: Program = Program::deserialize_program(bytecode) +) -> Result, CliError> { + let program: Program = Program::deserialize_program(bytecode) .map_err(|_| CliError::CircuitDeserializationError())?; execute_program( &program, diff --git a/tooling/acvm_cli/src/cli/fs/inputs.rs b/tooling/acvm_cli/src/cli/fs/inputs.rs index 2a46cfba88..a0b6e3a954 100644 --- a/tooling/acvm_cli/src/cli/fs/inputs.rs +++ b/tooling/acvm_cli/src/cli/fs/inputs.rs @@ -1,6 +1,6 @@ use acir::{ native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, FieldElement, }; use toml::Table; @@ -11,7 +11,7 @@ use std::{fs::read, path::Path}; pub(crate) fn read_inputs_from_file>( working_directory: P, file_name: &String, -) -> Result { +) -> Result, CliError> { let file_path = working_directory.as_ref().join(file_name); if !file_path.exists() { return Err(CliError::FilesystemError(FilesystemError::MissingTomlFile( @@ -25,7 +25,7 @@ pub(crate) fn read_inputs_from_file>( let input_map = input_string .parse::() .map_err(|_| FilesystemError::InvalidTomlFile(file_name.clone()))?; - let mut witnesses: WitnessMap = WitnessMap::new(); + let mut witnesses: WitnessMap = WitnessMap::new(); for (key, value) in input_map.into_iter() { let index = Witness(key.trim().parse().map_err(|_| CliError::WitnessIndexError(key.clone()))?); diff --git a/tooling/acvm_cli/src/cli/fs/witness.rs b/tooling/acvm_cli/src/cli/fs/witness.rs index 30ef4278f4..6ecba9792c 100644 --- a/tooling/acvm_cli/src/cli/fs/witness.rs +++ b/tooling/acvm_cli/src/cli/fs/witness.rs @@ -5,7 +5,11 @@ use std::{ path::{Path, PathBuf}, }; -use acvm::acir::native_types::{WitnessMap, WitnessStack}; +use acir::FieldElement; +use acvm::acir::{ + native_types::{WitnessMap, WitnessStack}, + AcirField, +}; use crate::errors::{CliError, FilesystemError}; @@ -31,7 +35,9 @@ fn write_to_file(bytes: &[u8], path: &Path) -> String { } /// Creates a toml representation of the provided witness map -pub(crate) fn create_output_witness_string(witnesses: &WitnessMap) -> Result { +pub(crate) fn create_output_witness_string( + witnesses: &WitnessMap, +) -> Result { let mut witness_map: BTreeMap = BTreeMap::new(); for (key, value) in witnesses.clone().into_iter() { witness_map.insert(key.0.to_string(), format!("0x{}", value.to_hex())); @@ -41,7 +47,7 @@ pub(crate) fn create_output_witness_string(witnesses: &WitnessMap) -> Result>( - witnesses: WitnessStack, + witnesses: WitnessStack, witness_name: &str, witness_dir: P, ) -> Result { diff --git a/tooling/debugger/src/context.rs b/tooling/debugger/src/context.rs index a031d127d8..110e3211e2 100644 --- a/tooling/debugger/src/context.rs +++ b/tooling/debugger/src/context.rs @@ -26,28 +26,28 @@ pub(super) enum DebugCommandResult { Error(NargoError), } -pub(super) struct DebugContext<'a, B: BlackBoxFunctionSolver> { - acvm: ACVM<'a, B>, - brillig_solver: Option>, +pub(super) struct DebugContext<'a, B: BlackBoxFunctionSolver> { + acvm: ACVM<'a, FieldElement, B>, + brillig_solver: Option>, foreign_call_executor: Box, debug_artifact: &'a DebugArtifact, breakpoints: HashSet, source_to_opcodes: BTreeMap>, - unconstrained_functions: &'a [BrilligBytecode], + unconstrained_functions: &'a [BrilligBytecode], // Absolute (in terms of all the opcodes ACIR+Brillig) addresses of the ACIR // opcodes with one additional entry for to indicate the last valid address. acir_opcode_addresses: Vec, } -impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { +impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { pub(super) fn new( blackbox_solver: &'a B, - circuit: &'a Circuit, + circuit: &'a Circuit, debug_artifact: &'a DebugArtifact, - initial_witness: WitnessMap, + initial_witness: WitnessMap, foreign_call_executor: Box, - unconstrained_functions: &'a [BrilligBytecode], + unconstrained_functions: &'a [BrilligBytecode], ) -> Self { let source_to_opcodes = build_source_to_opcode_debug_mappings(debug_artifact); let acir_opcode_addresses = build_acir_opcode_offsets(circuit, unconstrained_functions); @@ -70,11 +70,11 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { } } - pub(super) fn get_opcodes(&self) -> &[Opcode] { + pub(super) fn get_opcodes(&self) -> &[Opcode] { self.acvm.opcodes() } - pub(super) fn get_witness_map(&self) -> &WitnessMap { + pub(super) fn get_witness_map(&self) -> &WitnessMap { self.acvm.witness_map() } @@ -302,7 +302,10 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { } } - fn handle_foreign_call(&mut self, foreign_call: ForeignCallWaitInfo) -> DebugCommandResult { + fn handle_foreign_call( + &mut self, + foreign_call: ForeignCallWaitInfo, + ) -> DebugCommandResult { let foreign_call_result = self.foreign_call_executor.execute(&foreign_call); match foreign_call_result { Ok(foreign_call_result) => { @@ -319,7 +322,7 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { } } - fn handle_acvm_status(&mut self, status: ACVMStatus) -> DebugCommandResult { + fn handle_acvm_status(&mut self, status: ACVMStatus) -> DebugCommandResult { if let ACVMStatus::RequiresForeignCall(foreign_call) = status { return self.handle_foreign_call(foreign_call); } @@ -465,7 +468,7 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { } } - pub(super) fn get_brillig_memory(&self) -> Option<&[MemoryValue]> { + pub(super) fn get_brillig_memory(&self) -> Option<&[MemoryValue]> { self.brillig_solver.as_ref().map(|solver| solver.get_memory()) } @@ -539,7 +542,7 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { matches!(self.acvm.get_status(), ACVMStatus::Solved) } - pub fn finalize(self) -> WitnessMap { + pub fn finalize(self) -> WitnessMap { self.acvm.finalize() } } @@ -591,8 +594,8 @@ fn build_source_to_opcode_debug_mappings( } fn build_acir_opcode_offsets( - circuit: &Circuit, - unconstrained_functions: &[BrilligBytecode], + circuit: &Circuit, + unconstrained_functions: &[BrilligBytecode], ) -> Vec { let mut result = Vec::with_capacity(circuit.opcodes.len() + 1); // address of the first opcode is always 0 @@ -625,6 +628,7 @@ mod tests { opcodes::BlockId, }, native_types::Expression, + AcirField, }, blackbox_solver::StubbedBlackBoxSolver, brillig_vm::brillig::{ diff --git a/tooling/debugger/src/dap.rs b/tooling/debugger/src/dap.rs index c9b6b816a7..d0ed0559b8 100644 --- a/tooling/debugger/src/dap.rs +++ b/tooling/debugger/src/dap.rs @@ -4,7 +4,7 @@ use std::io::{Read, Write}; use acvm::acir::circuit::brillig::BrilligBytecode; use acvm::acir::circuit::{Circuit, OpcodeLocation}; use acvm::acir::native_types::WitnessMap; -use acvm::BlackBoxFunctionSolver; +use acvm::{BlackBoxFunctionSolver, FieldElement}; use crate::context::DebugCommandResult; use crate::context::DebugContext; @@ -31,7 +31,7 @@ use noirc_driver::CompiledProgram; type BreakpointId = i64; -pub struct DapSession<'a, R: Read, W: Write, B: BlackBoxFunctionSolver> { +pub struct DapSession<'a, R: Read, W: Write, B: BlackBoxFunctionSolver> { server: Server, context: DebugContext<'a, B>, debug_artifact: &'a DebugArtifact, @@ -57,14 +57,14 @@ impl From for ScopeReferences { } } -impl<'a, R: Read, W: Write, B: BlackBoxFunctionSolver> DapSession<'a, R, W, B> { +impl<'a, R: Read, W: Write, B: BlackBoxFunctionSolver> DapSession<'a, R, W, B> { pub fn new( server: Server, solver: &'a B, - circuit: &'a Circuit, + circuit: &'a Circuit, debug_artifact: &'a DebugArtifact, - initial_witness: WitnessMap, - unconstrained_functions: &'a [BrilligBytecode], + initial_witness: WitnessMap, + unconstrained_functions: &'a [BrilligBytecode], ) -> Self { let context = DebugContext::new( solver, @@ -602,11 +602,11 @@ impl<'a, R: Read, W: Write, B: BlackBoxFunctionSolver> DapSession<'a, R, W, B> { } } -pub fn run_session( +pub fn run_session>( server: Server, solver: &B, program: CompiledProgram, - initial_witness: WitnessMap, + initial_witness: WitnessMap, ) -> Result<(), ServerError> { let debug_artifact = DebugArtifact { debug_symbols: program.debug, diff --git a/tooling/debugger/src/foreign_calls.rs b/tooling/debugger/src/foreign_calls.rs index 209439f5f9..6989936ae9 100644 --- a/tooling/debugger/src/foreign_calls.rs +++ b/tooling/debugger/src/foreign_calls.rs @@ -1,7 +1,7 @@ use acvm::{ acir::brillig::{ForeignCallParam, ForeignCallResult}, pwg::ForeignCallWaitInfo, - FieldElement, + AcirField, FieldElement, }; use nargo::{ artifacts::debug::{DebugArtifact, DebugVars, StackFrame}, @@ -93,8 +93,8 @@ fn debug_fn_id(value: &FieldElement) -> DebugFnId { impl ForeignCallExecutor for DefaultDebugForeignCallExecutor { fn execute( &mut self, - foreign_call: &ForeignCallWaitInfo, - ) -> Result { + foreign_call: &ForeignCallWaitInfo, + ) -> Result, ForeignCallError> { let foreign_call_name = foreign_call.function.as_str(); match DebugForeignCall::lookup(foreign_call_name) { Some(DebugForeignCall::VarAssign) => { diff --git a/tooling/debugger/src/lib.rs b/tooling/debugger/src/lib.rs index a8fc61c893..d7a1337c82 100644 --- a/tooling/debugger/src/lib.rs +++ b/tooling/debugger/src/lib.rs @@ -10,29 +10,29 @@ use std::io::{Read, Write}; use ::dap::errors::ServerError; use ::dap::server::Server; use acvm::acir::circuit::brillig::BrilligBytecode; -use acvm::BlackBoxFunctionSolver; use acvm::{acir::circuit::Circuit, acir::native_types::WitnessMap}; +use acvm::{BlackBoxFunctionSolver, FieldElement}; use nargo::artifacts::debug::DebugArtifact; use nargo::NargoError; use noirc_driver::CompiledProgram; -pub fn debug_circuit( +pub fn debug_circuit>( blackbox_solver: &B, - circuit: &Circuit, + circuit: &Circuit, debug_artifact: DebugArtifact, - initial_witness: WitnessMap, - unconstrained_functions: &[BrilligBytecode], -) -> Result, NargoError> { + initial_witness: WitnessMap, + unconstrained_functions: &[BrilligBytecode], +) -> Result>, NargoError> { repl::run(blackbox_solver, circuit, &debug_artifact, initial_witness, unconstrained_functions) } -pub fn run_dap_loop( +pub fn run_dap_loop>( server: Server, solver: &B, program: CompiledProgram, - initial_witness: WitnessMap, + initial_witness: WitnessMap, ) -> Result<(), ServerError> { dap::run_session(server, solver, program, initial_witness) } diff --git a/tooling/debugger/src/repl.rs b/tooling/debugger/src/repl.rs index 8f908a38ff..5aef12ad8d 100644 --- a/tooling/debugger/src/repl.rs +++ b/tooling/debugger/src/repl.rs @@ -15,23 +15,23 @@ use std::cell::RefCell; use crate::source_code_printer::print_source_code_location; -pub struct ReplDebugger<'a, B: BlackBoxFunctionSolver> { +pub struct ReplDebugger<'a, B: BlackBoxFunctionSolver> { context: DebugContext<'a, B>, blackbox_solver: &'a B, - circuit: &'a Circuit, + circuit: &'a Circuit, debug_artifact: &'a DebugArtifact, - initial_witness: WitnessMap, + initial_witness: WitnessMap, last_result: DebugCommandResult, - unconstrained_functions: &'a [BrilligBytecode], + unconstrained_functions: &'a [BrilligBytecode], } -impl<'a, B: BlackBoxFunctionSolver> ReplDebugger<'a, B> { +impl<'a, B: BlackBoxFunctionSolver> ReplDebugger<'a, B> { pub fn new( blackbox_solver: &'a B, - circuit: &'a Circuit, + circuit: &'a Circuit, debug_artifact: &'a DebugArtifact, - initial_witness: WitnessMap, - unconstrained_functions: &'a [BrilligBytecode], + initial_witness: WitnessMap, + unconstrained_functions: &'a [BrilligBytecode], ) -> Self { let foreign_call_executor = Box::new(DefaultDebugForeignCallExecutor::from_artifact(true, debug_artifact)); @@ -161,7 +161,7 @@ impl<'a, B: BlackBoxFunctionSolver> ReplDebugger<'a, B> { "" } }; - let print_brillig_bytecode = |acir_index, bytecode: &[BrilligOpcode]| { + let print_brillig_bytecode = |acir_index, bytecode: &[BrilligOpcode]| { for (brillig_index, brillig_opcode) in bytecode.iter().enumerate() { println!( "{:>3}.{:<2} |{:2} {:?}", @@ -371,18 +371,18 @@ impl<'a, B: BlackBoxFunctionSolver> ReplDebugger<'a, B> { self.context.is_solved() } - fn finalize(self) -> WitnessMap { + fn finalize(self) -> WitnessMap { self.context.finalize() } } -pub fn run( +pub fn run>( blackbox_solver: &B, - circuit: &Circuit, + circuit: &Circuit, debug_artifact: &DebugArtifact, - initial_witness: WitnessMap, - unconstrained_functions: &[BrilligBytecode], -) -> Result, NargoError> { + initial_witness: WitnessMap, + unconstrained_functions: &[BrilligBytecode], +) -> Result>, NargoError> { let context = RefCell::new(ReplDebugger::new( blackbox_solver, circuit, diff --git a/tooling/lsp/src/lib.rs b/tooling/lsp/src/lib.rs index 05345b96c8..304a2d34e4 100644 --- a/tooling/lsp/src/lib.rs +++ b/tooling/lsp/src/lib.rs @@ -13,7 +13,7 @@ use std::{ task::{self, Poll}, }; -use acvm::BlackBoxFunctionSolver; +use acvm::{BlackBoxFunctionSolver, FieldElement}; use async_lsp::{ router::Router, AnyEvent, AnyNotification, AnyRequest, ClientSocket, Error, LspService, ResponseError, @@ -79,7 +79,10 @@ pub struct LspState { } impl LspState { - fn new(client: &ClientSocket, solver: impl BlackBoxFunctionSolver + 'static) -> Self { + fn new( + client: &ClientSocket, + solver: impl BlackBoxFunctionSolver + 'static, + ) -> Self { Self { client: client.clone(), root_path: None, @@ -99,7 +102,10 @@ pub struct NargoLspService { } impl NargoLspService { - pub fn new(client: &ClientSocket, solver: impl BlackBoxFunctionSolver + 'static) -> Self { + pub fn new( + client: &ClientSocket, + solver: impl BlackBoxFunctionSolver + 'static, + ) -> Self { let state = LspState::new(client, solver); let mut router = Router::new(state); router diff --git a/tooling/lsp/src/solver.rs b/tooling/lsp/src/solver.rs index 87327b01e3..0fcac73b90 100644 --- a/tooling/lsp/src/solver.rs +++ b/tooling/lsp/src/solver.rs @@ -3,9 +3,9 @@ use acvm::BlackBoxFunctionSolver; // This is a struct that wraps a dynamically dispatched `BlackBoxFunctionSolver` // where we proxy the unimplemented stuff to the wrapped backend, but it // allows us to avoid changing function signatures to include the `Box` -pub(super) struct WrapperSolver(pub(super) Box); +pub(super) struct WrapperSolver(pub(super) Box>); -impl BlackBoxFunctionSolver for WrapperSolver { +impl BlackBoxFunctionSolver for WrapperSolver { fn schnorr_verify( &self, public_key_x: &acvm::FieldElement, diff --git a/tooling/nargo/src/artifacts/contract.rs b/tooling/nargo/src/artifacts/contract.rs index a864da7c33..1afc7977ae 100644 --- a/tooling/nargo/src/artifacts/contract.rs +++ b/tooling/nargo/src/artifacts/contract.rs @@ -1,4 +1,4 @@ -use acvm::acir::circuit::Program; +use acvm::{acir::circuit::Program, FieldElement}; use noirc_abi::{Abi, AbiType, AbiValue}; use noirc_driver::{CompiledContract, CompiledContractOutputs, ContractFunction}; use serde::{Deserialize, Serialize}; @@ -65,7 +65,7 @@ pub struct ContractFunctionArtifact { serialize_with = "Program::serialize_program_base64", deserialize_with = "Program::deserialize_program_base64" )] - pub bytecode: Program, + pub bytecode: Program, #[serde( serialize_with = "ProgramDebugInfo::serialize_compressed_base64_json", diff --git a/tooling/nargo/src/artifacts/program.rs b/tooling/nargo/src/artifacts/program.rs index 3c25b9e334..91f0215741 100644 --- a/tooling/nargo/src/artifacts/program.rs +++ b/tooling/nargo/src/artifacts/program.rs @@ -1,6 +1,7 @@ use std::collections::BTreeMap; use acvm::acir::circuit::Program; +use acvm::FieldElement; use fm::FileId; use noirc_abi::Abi; use noirc_driver::CompiledProgram; @@ -24,7 +25,7 @@ pub struct ProgramArtifact { serialize_with = "Program::serialize_program_base64", deserialize_with = "Program::deserialize_program_base64" )] - pub bytecode: Program, + pub bytecode: Program, #[serde( serialize_with = "ProgramDebugInfo::serialize_compressed_base64_json", diff --git a/tooling/nargo/src/errors.rs b/tooling/nargo/src/errors.rs index 63a72247e2..200420e5ce 100644 --- a/tooling/nargo/src/errors.rs +++ b/tooling/nargo/src/errors.rs @@ -6,6 +6,7 @@ use acvm::{ ResolvedOpcodeLocation, }, pwg::{ErrorLocation, OpcodeResolutionError}, + FieldElement, }; use noirc_abi::{display_abi_error, Abi, AbiErrorType}; use noirc_errors::{ @@ -95,10 +96,10 @@ impl NargoError { #[derive(Debug, Error)] pub enum ExecutionError { #[error("Failed assertion")] - AssertionFailed(ResolvedAssertionPayload, Vec), + AssertionFailed(ResolvedAssertionPayload, Vec), #[error("Failed to solve program: '{}'", .0)] - SolvingError(OpcodeResolutionError, Option>), + SolvingError(OpcodeResolutionError, Option>), } /// Extracts the opcode locations from a nargo error. diff --git a/tooling/nargo/src/ops/execute.rs b/tooling/nargo/src/ops/execute.rs index 4a75212ba4..42e93e0e3c 100644 --- a/tooling/nargo/src/ops/execute.rs +++ b/tooling/nargo/src/ops/execute.rs @@ -1,22 +1,24 @@ use acvm::acir::circuit::brillig::BrilligBytecode; -use acvm::acir::circuit::{OpcodeLocation, Program, ResolvedOpcodeLocation}; +use acvm::acir::circuit::{ + OpcodeLocation, Program, ResolvedAssertionPayload, ResolvedOpcodeLocation, +}; use acvm::acir::native_types::WitnessStack; use acvm::pwg::{ACVMStatus, ErrorLocation, OpcodeNotSolvable, OpcodeResolutionError, ACVM}; -use acvm::BlackBoxFunctionSolver; use acvm::{acir::circuit::Circuit, acir::native_types::WitnessMap}; +use acvm::{BlackBoxFunctionSolver, FieldElement}; use crate::errors::ExecutionError; use crate::NargoError; use super::foreign_calls::ForeignCallExecutor; -struct ProgramExecutor<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> { - functions: &'a [Circuit], +struct ProgramExecutor<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> { + functions: &'a [Circuit], - unconstrained_functions: &'a [BrilligBytecode], + unconstrained_functions: &'a [BrilligBytecode], // This gets built as we run through the program looking at each function call - witness_stack: WitnessStack, + witness_stack: WitnessStack, blackbox_solver: &'a B, @@ -32,10 +34,12 @@ struct ProgramExecutor<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> { current_function_index: usize, } -impl<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> ProgramExecutor<'a, B, F> { +impl<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> + ProgramExecutor<'a, B, F> +{ fn new( - functions: &'a [Circuit], - unconstrained_functions: &'a [BrilligBytecode], + functions: &'a [Circuit], + unconstrained_functions: &'a [BrilligBytecode], blackbox_solver: &'a B, foreign_call_executor: &'a mut F, ) -> Self { @@ -50,12 +54,15 @@ impl<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> ProgramExecutor<'a, } } - fn finalize(self) -> WitnessStack { + fn finalize(self) -> WitnessStack { self.witness_stack } #[tracing::instrument(level = "trace", skip_all)] - fn execute_circuit(&mut self, initial_witness: WitnessMap) -> Result { + fn execute_circuit( + &mut self, + initial_witness: WitnessMap, + ) -> Result, NargoError> { let circuit = &self.functions[self.current_function_index]; let mut acvm = ACVM::new( self.blackbox_solver, @@ -102,13 +109,14 @@ impl<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> ProgramExecutor<'a, _ => None, }; - let assertion_payload = match &error { - OpcodeResolutionError::BrilligFunctionFailed { payload, .. } - | OpcodeResolutionError::UnsatisfiedConstrain { payload, .. } => { - payload.clone() - } - _ => None, - }; + let assertion_payload: Option> = + match &error { + OpcodeResolutionError::BrilligFunctionFailed { payload, .. } + | OpcodeResolutionError::UnsatisfiedConstrain { payload, .. } => { + payload.clone() + } + _ => None, + }; return Err(NargoError::ExecutionError(match assertion_payload { Some(payload) => ExecutionError::AssertionFailed( @@ -166,12 +174,12 @@ impl<'a, B: BlackBoxFunctionSolver, F: ForeignCallExecutor> ProgramExecutor<'a, } #[tracing::instrument(level = "trace", skip_all)] -pub fn execute_program( - program: &Program, - initial_witness: WitnessMap, +pub fn execute_program, F: ForeignCallExecutor>( + program: &Program, + initial_witness: WitnessMap, blackbox_solver: &B, foreign_call_executor: &mut F, -) -> Result { +) -> Result, NargoError> { let mut executor = ProgramExecutor::new( &program.functions, &program.unconstrained_functions, diff --git a/tooling/nargo/src/ops/foreign_calls.rs b/tooling/nargo/src/ops/foreign_calls.rs index c314a230ce..c6b284beb1 100644 --- a/tooling/nargo/src/ops/foreign_calls.rs +++ b/tooling/nargo/src/ops/foreign_calls.rs @@ -1,7 +1,7 @@ use acvm::{ acir::brillig::{ForeignCallParam, ForeignCallResult}, pwg::ForeignCallWaitInfo, - FieldElement, + AcirField, FieldElement, }; use jsonrpc::{arg as build_json_rpc_arg, minreq_http::Builder, Client}; use noirc_printable_type::{decode_string_value, ForeignCallError, PrintableValueDisplay}; @@ -9,8 +9,8 @@ use noirc_printable_type::{decode_string_value, ForeignCallError, PrintableValue pub trait ForeignCallExecutor { fn execute( &mut self, - foreign_call: &ForeignCallWaitInfo, - ) -> Result; + foreign_call: &ForeignCallWaitInfo, + ) -> Result, ForeignCallError>; } /// This enumeration represents the Brillig foreign calls that are natively supported by nargo. @@ -66,11 +66,11 @@ struct MockedCall { /// The oracle it's mocking name: String, /// Optionally match the parameters - params: Option>, + params: Option>>, /// The parameters with which the mock was last called - last_called_params: Option>, + last_called_params: Option>>, /// The result to return when this mock is called - result: ForeignCallResult, + result: ForeignCallResult, /// How many times should this mock be called before it is removed times_left: Option, } @@ -89,7 +89,7 @@ impl MockedCall { } impl MockedCall { - fn matches(&self, name: &str, params: &[ForeignCallParam]) -> bool { + fn matches(&self, name: &str, params: &[ForeignCallParam]) -> bool { self.name == name && (self.params.is_none() || self.params.as_deref() == Some(params)) } } @@ -130,8 +130,8 @@ impl DefaultForeignCallExecutor { impl DefaultForeignCallExecutor { fn extract_mock_id( - foreign_call_inputs: &[ForeignCallParam], - ) -> Result<(usize, &[ForeignCallParam]), ForeignCallError> { + foreign_call_inputs: &[ForeignCallParam], + ) -> Result<(usize, &[ForeignCallParam]), ForeignCallError> { let (id, params) = foreign_call_inputs.split_first().ok_or(ForeignCallError::MissingForeignCallInputs)?; let id = @@ -148,12 +148,14 @@ impl DefaultForeignCallExecutor { self.mocked_responses.iter_mut().find(|response| response.id == id) } - fn parse_string(param: &ForeignCallParam) -> String { + fn parse_string(param: &ForeignCallParam) -> String { let fields: Vec<_> = param.fields().to_vec(); decode_string_value(&fields) } - fn execute_print(foreign_call_inputs: &[ForeignCallParam]) -> Result<(), ForeignCallError> { + fn execute_print( + foreign_call_inputs: &[ForeignCallParam], + ) -> Result<(), ForeignCallError> { let skip_newline = foreign_call_inputs[0].unwrap_field().is_zero(); let foreign_call_inputs = @@ -166,7 +168,7 @@ impl DefaultForeignCallExecutor { } fn format_printable_value( - foreign_call_inputs: &[ForeignCallParam], + foreign_call_inputs: &[ForeignCallParam], skip_newline: bool, ) -> Result { let display_values: PrintableValueDisplay = foreign_call_inputs.try_into()?; @@ -180,8 +182,8 @@ impl DefaultForeignCallExecutor { impl ForeignCallExecutor for DefaultForeignCallExecutor { fn execute( &mut self, - foreign_call: &ForeignCallWaitInfo, - ) -> Result { + foreign_call: &ForeignCallWaitInfo, + ) -> Result, ForeignCallError> { let foreign_call_name = foreign_call.function.as_str(); match ForeignCall::lookup(foreign_call_name) { Some(ForeignCall::Print) => { @@ -280,7 +282,7 @@ impl ForeignCallExecutor for DefaultForeignCallExecutor { let response = external_resolver.send_request(req)?; - let parsed_response: ForeignCallResult = response.result()?; + let parsed_response: ForeignCallResult = response.result()?; Ok(parsed_response) } else { @@ -314,20 +316,32 @@ mod tests { #[rpc] pub trait OracleResolver { #[rpc(name = "echo")] - fn echo(&self, param: ForeignCallParam) -> RpcResult; + fn echo( + &self, + param: ForeignCallParam, + ) -> RpcResult>; #[rpc(name = "sum")] - fn sum(&self, array: ForeignCallParam) -> RpcResult; + fn sum( + &self, + array: ForeignCallParam, + ) -> RpcResult>; } struct OracleResolverImpl; impl OracleResolver for OracleResolverImpl { - fn echo(&self, param: ForeignCallParam) -> RpcResult { + fn echo( + &self, + param: ForeignCallParam, + ) -> RpcResult> { Ok(vec![param].into()) } - fn sum(&self, array: ForeignCallParam) -> RpcResult { + fn sum( + &self, + array: ForeignCallParam, + ) -> RpcResult> { let mut res: FieldElement = 0_usize.into(); for value in array.fields() { diff --git a/tooling/nargo/src/ops/optimize.rs b/tooling/nargo/src/ops/optimize.rs index a62f469632..07adfb57df 100644 --- a/tooling/nargo/src/ops/optimize.rs +++ b/tooling/nargo/src/ops/optimize.rs @@ -1,4 +1,4 @@ -use acvm::acir::circuit::Program; +use acvm::{acir::circuit::Program, FieldElement}; use iter_extended::vecmap; use noirc_driver::{CompiledContract, CompiledProgram}; use noirc_errors::debug_info::DebugInfo; @@ -18,7 +18,10 @@ pub fn optimize_contract(contract: CompiledContract) -> CompiledContract { CompiledContract { functions, ..contract } } -fn optimize_program_internal(mut program: Program, debug: &mut [DebugInfo]) -> Program { +fn optimize_program_internal( + mut program: Program, + debug: &mut [DebugInfo], +) -> Program { let functions = std::mem::take(&mut program.functions); let optimized_functions = functions diff --git a/tooling/nargo/src/ops/test.rs b/tooling/nargo/src/ops/test.rs index 86dd8cd7cd..ed45251ac8 100644 --- a/tooling/nargo/src/ops/test.rs +++ b/tooling/nargo/src/ops/test.rs @@ -1,6 +1,6 @@ use acvm::{ acir::native_types::{WitnessMap, WitnessStack}, - BlackBoxFunctionSolver, + BlackBoxFunctionSolver, FieldElement, }; use noirc_abi::Abi; use noirc_driver::{compile_no_check, CompileError, CompileOptions}; @@ -23,7 +23,7 @@ impl TestStatus { } } -pub fn run_test( +pub fn run_test>( blackbox_solver: &B, context: &mut Context, test_function: &TestFunction, @@ -76,7 +76,7 @@ fn test_status_program_compile_pass( test_function: &TestFunction, abi: Abi, debug: Vec, - circuit_execution: Result, + circuit_execution: Result, NargoError>, ) -> TestStatus { let circuit_execution_err = match circuit_execution { // Circuit execution was successful; ie no errors or unsatisfied constraints diff --git a/tooling/nargo/src/ops/transform.rs b/tooling/nargo/src/ops/transform.rs index b4811bd578..9255ac3e0e 100644 --- a/tooling/nargo/src/ops/transform.rs +++ b/tooling/nargo/src/ops/transform.rs @@ -1,4 +1,7 @@ -use acvm::acir::circuit::{ExpressionWidth, Program}; +use acvm::{ + acir::circuit::{ExpressionWidth, Program}, + FieldElement, +}; use iter_extended::vecmap; use noirc_driver::{CompiledContract, CompiledProgram}; use noirc_errors::debug_info::DebugInfo; @@ -30,10 +33,10 @@ pub fn transform_contract( } fn transform_program_internal( - mut program: Program, + mut program: Program, debug: &mut [DebugInfo], expression_width: ExpressionWidth, -) -> Program { +) -> Program { let functions = std::mem::take(&mut program.functions); let optimized_functions = functions diff --git a/tooling/nargo_cli/Cargo.toml b/tooling/nargo_cli/Cargo.toml index d10dd6a22f..9e886bc700 100644 --- a/tooling/nargo_cli/Cargo.toml +++ b/tooling/nargo_cli/Cargo.toml @@ -32,7 +32,7 @@ noirc_driver.workspace = true noirc_frontend.workspace = true noirc_abi.workspace = true noirc_errors.workspace = true -acvm.workspace = true +acvm = { workspace = true, features = ["bn254"] } bn254_blackbox_solver.workspace = true toml.workspace = true serde.workspace = true diff --git a/tooling/nargo_cli/src/cli/dap_cmd.rs b/tooling/nargo_cli/src/cli/dap_cmd.rs index eded2bfd8d..a84e961cfe 100644 --- a/tooling/nargo_cli/src/cli/dap_cmd.rs +++ b/tooling/nargo_cli/src/cli/dap_cmd.rs @@ -1,5 +1,6 @@ use acvm::acir::circuit::ExpressionWidth; use acvm::acir::native_types::WitnessMap; +use acvm::FieldElement; use bn254_blackbox_solver::Bn254BlackBoxSolver; use clap::Args; use nargo::constants::PROVER_INPUT_FILE; @@ -101,7 +102,7 @@ fn load_and_compile_project( expression_width: ExpressionWidth, acir_mode: bool, skip_instrumentation: bool, -) -> Result<(CompiledProgram, WitnessMap), LoadError> { +) -> Result<(CompiledProgram, WitnessMap), LoadError> { let workspace = find_workspace(project_folder, package) .ok_or(LoadError::Generic(workspace_not_found_error_msg(project_folder, package)))?; let package = workspace diff --git a/tooling/nargo_cli/src/cli/debug_cmd.rs b/tooling/nargo_cli/src/cli/debug_cmd.rs index 7865b60826..adc17f4159 100644 --- a/tooling/nargo_cli/src/cli/debug_cmd.rs +++ b/tooling/nargo_cli/src/cli/debug_cmd.rs @@ -1,6 +1,7 @@ use std::path::PathBuf; use acvm::acir::native_types::{WitnessMap, WitnessStack}; +use acvm::FieldElement; use bn254_blackbox_solver::Bn254BlackBoxSolver; use clap::Args; @@ -199,7 +200,7 @@ fn debug_program_and_decode( program: CompiledProgram, package: &Package, prover_name: &str, -) -> Result<(Option, Option), CliError> { +) -> Result<(Option, Option>), CliError> { // Parse the initial witness values from Prover.toml let (inputs_map, _) = read_inputs_from_file(&package.root_dir, prover_name, Format::Toml, &program.abi)?; @@ -218,7 +219,7 @@ fn debug_program_and_decode( pub(crate) fn debug_program( compiled_program: &CompiledProgram, inputs_map: &InputMap, -) -> Result, CliError> { +) -> Result>, CliError> { let initial_witness = compiled_program.abi.encode(inputs_map, None)?; let debug_artifact = DebugArtifact { diff --git a/tooling/nargo_cli/src/cli/execute_cmd.rs b/tooling/nargo_cli/src/cli/execute_cmd.rs index 3fcedbb8f5..c312cc4cfd 100644 --- a/tooling/nargo_cli/src/cli/execute_cmd.rs +++ b/tooling/nargo_cli/src/cli/execute_cmd.rs @@ -1,4 +1,5 @@ use acvm::acir::native_types::WitnessStack; +use acvm::FieldElement; use bn254_blackbox_solver::Bn254BlackBoxSolver; use clap::Args; @@ -91,7 +92,7 @@ fn execute_program_and_decode( package: &Package, prover_name: &str, foreign_call_resolver_url: Option<&str>, -) -> Result<(Option, WitnessStack), CliError> { +) -> Result<(Option, WitnessStack), CliError> { // Parse the initial witness values from Prover.toml let (inputs_map, _) = read_inputs_from_file(&package.root_dir, prover_name, Format::Toml, &program.abi)?; @@ -109,7 +110,7 @@ pub(crate) fn execute_program( compiled_program: &CompiledProgram, inputs_map: &InputMap, foreign_call_resolver_url: Option<&str>, -) -> Result { +) -> Result, CliError> { let initial_witness = compiled_program.abi.encode(inputs_map, None)?; let solved_witness_stack_err = nargo::ops::execute_program( diff --git a/tooling/nargo_cli/src/cli/fs/witness.rs b/tooling/nargo_cli/src/cli/fs/witness.rs index 613cdec28d..f95eb3d7a4 100644 --- a/tooling/nargo_cli/src/cli/fs/witness.rs +++ b/tooling/nargo_cli/src/cli/fs/witness.rs @@ -1,13 +1,13 @@ use std::path::{Path, PathBuf}; -use acvm::acir::native_types::WitnessStack; +use acvm::{acir::native_types::WitnessStack, FieldElement}; use nargo::constants::WITNESS_EXT; use super::{create_named_dir, write_to_file}; use crate::errors::FilesystemError; pub(crate) fn save_witness_to_dir>( - witness_stack: WitnessStack, + witness_stack: WitnessStack, witness_name: &str, witness_dir: P, ) -> Result { diff --git a/tooling/nargo_cli/src/cli/test_cmd.rs b/tooling/nargo_cli/src/cli/test_cmd.rs index 51e21248af..99c284e501 100644 --- a/tooling/nargo_cli/src/cli/test_cmd.rs +++ b/tooling/nargo_cli/src/cli/test_cmd.rs @@ -1,6 +1,6 @@ use std::io::Write; -use acvm::BlackBoxFunctionSolver; +use acvm::{BlackBoxFunctionSolver, FieldElement}; use bn254_blackbox_solver::Bn254BlackBoxSolver; use clap::Args; use fm::FileManager; @@ -119,7 +119,7 @@ pub(crate) fn run(args: TestCommand, config: NargoConfig) -> Result<(), CliError } } -fn run_tests( +fn run_tests + Default>( file_manager: &FileManager, parsed_files: &ParsedFiles, package: &Package, @@ -157,7 +157,7 @@ fn run_tests( Ok(test_report) } -fn run_test( +fn run_test + Default>( file_manager: &FileManager, parsed_files: &ParsedFiles, package: &Package, diff --git a/tooling/noirc_abi/src/input_parser/json.rs b/tooling/noirc_abi/src/input_parser/json.rs index 7618cd6c15..070f9effe4 100644 --- a/tooling/noirc_abi/src/input_parser/json.rs +++ b/tooling/noirc_abi/src/input_parser/json.rs @@ -1,6 +1,6 @@ use super::{parse_str_to_field, InputValue}; use crate::{errors::InputParserError, Abi, AbiType, MAIN_RETURN_NAME}; -use acvm::FieldElement; +use acvm::{AcirField, FieldElement}; use iter_extended::{try_btree_map, try_vecmap}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; diff --git a/tooling/noirc_abi/src/input_parser/mod.rs b/tooling/noirc_abi/src/input_parser/mod.rs index 9629ddc87a..14d92bc71b 100644 --- a/tooling/noirc_abi/src/input_parser/mod.rs +++ b/tooling/noirc_abi/src/input_parser/mod.rs @@ -3,7 +3,7 @@ use num_traits::{Num, Zero}; use std::collections::{BTreeMap, HashSet}; use thiserror::Error; -use acvm::FieldElement; +use acvm::{AcirField, FieldElement}; use serde::Serialize; use crate::errors::InputParserError; @@ -229,7 +229,7 @@ impl Format { mod serialization_tests { use std::collections::BTreeMap; - use acvm::FieldElement; + use acvm::{AcirField, FieldElement}; use strum::IntoEnumIterator; use crate::{ @@ -362,7 +362,7 @@ fn field_from_big_int(bigint: BigInt) -> FieldElement { #[cfg(test)] mod test { - use acvm::FieldElement; + use acvm::{AcirField, FieldElement}; use num_bigint::BigUint; use super::parse_str_to_field; diff --git a/tooling/noirc_abi/src/input_parser/toml.rs b/tooling/noirc_abi/src/input_parser/toml.rs index b216fe5879..321d3511b5 100644 --- a/tooling/noirc_abi/src/input_parser/toml.rs +++ b/tooling/noirc_abi/src/input_parser/toml.rs @@ -1,6 +1,6 @@ use super::{parse_str_to_field, parse_str_to_signed, InputValue}; use crate::{errors::InputParserError, Abi, AbiType, MAIN_RETURN_NAME}; -use acvm::FieldElement; +use acvm::{AcirField, FieldElement}; use iter_extended::{try_btree_map, try_vecmap}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; diff --git a/tooling/noirc_abi/src/lib.rs b/tooling/noirc_abi/src/lib.rs index 7a1d1787ca..0acace71fb 100644 --- a/tooling/noirc_abi/src/lib.rs +++ b/tooling/noirc_abi/src/lib.rs @@ -8,7 +8,7 @@ use acvm::{ circuit::ErrorSelector, native_types::{Witness, WitnessMap}, }, - FieldElement, + AcirField, FieldElement, }; use errors::AbiError; use input_parser::InputValue; @@ -331,7 +331,7 @@ impl Abi { &self, input_map: &InputMap, return_value: Option, - ) -> Result { + ) -> Result, AbiError> { // Check that no extra witness values have been provided. let param_names = self.parameter_names(); if param_names.len() < input_map.len() { @@ -439,7 +439,7 @@ impl Abi { /// Decode a `WitnessMap` into the types specified in the ABI. pub fn decode( &self, - witness_map: &WitnessMap, + witness_map: &WitnessMap, ) -> Result<(InputMap, Option), AbiError> { let public_inputs_map = try_btree_map(self.parameters.clone(), |AbiParameter { name, typ, .. }| { @@ -652,7 +652,7 @@ pub fn display_abi_error( mod test { use std::collections::BTreeMap; - use acvm::{acir::native_types::Witness, FieldElement}; + use acvm::{acir::native_types::Witness, AcirField, FieldElement}; use crate::{ input_parser::InputValue, Abi, AbiParameter, AbiReturnType, AbiType, AbiVisibility, diff --git a/tooling/noirc_abi_wasm/Cargo.toml b/tooling/noirc_abi_wasm/Cargo.toml index c78c3ead0c..5692c757d3 100644 --- a/tooling/noirc_abi_wasm/Cargo.toml +++ b/tooling/noirc_abi_wasm/Cargo.toml @@ -12,7 +12,7 @@ license.workspace = true crate-type = ["cdylib"] [dependencies] -acvm.workspace = true +acvm = { workspace = true, features = ["bn254"] } noirc_abi.workspace = true iter-extended.workspace = true wasm-bindgen.workspace = true diff --git a/tooling/noirc_abi_wasm/src/js_witness_map.rs b/tooling/noirc_abi_wasm/src/js_witness_map.rs index 293c5c089f..a82621822e 100644 --- a/tooling/noirc_abi_wasm/src/js_witness_map.rs +++ b/tooling/noirc_abi_wasm/src/js_witness_map.rs @@ -2,7 +2,7 @@ use acvm::{ acir::native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, FieldElement, }; use js_sys::{JsString, Map}; use wasm_bindgen::prelude::{wasm_bindgen, JsValue}; @@ -25,8 +25,8 @@ impl Default for JsWitnessMap { } } -impl From for JsWitnessMap { - fn from(witness_map: WitnessMap) -> Self { +impl From> for JsWitnessMap { + fn from(witness_map: WitnessMap) -> Self { let js_map = JsWitnessMap::new(); for (key, value) in witness_map { js_map.set( @@ -38,7 +38,7 @@ impl From for JsWitnessMap { } } -impl From for WitnessMap { +impl From for WitnessMap { fn from(js_map: JsWitnessMap) -> Self { let mut witness_map = WitnessMap::new(); js_map.for_each(&mut |value, key| { @@ -73,7 +73,7 @@ mod test { use acvm::{ acir::native_types::{Witness, WitnessMap}, - FieldElement, + AcirField, FieldElement, }; use wasm_bindgen::JsValue; diff --git a/tooling/noirc_abi_wasm/src/lib.rs b/tooling/noirc_abi_wasm/src/lib.rs index 10c0c43b35..ef4a468b66 100644 --- a/tooling/noirc_abi_wasm/src/lib.rs +++ b/tooling/noirc_abi_wasm/src/lib.rs @@ -5,9 +5,12 @@ // See Cargo.toml for explanation. use getrandom as _; -use acvm::acir::{ - circuit::RawAssertionPayload, - native_types::{WitnessMap, WitnessStack}, +use acvm::{ + acir::{ + circuit::RawAssertionPayload, + native_types::{WitnessMap, WitnessStack}, + }, + FieldElement, }; use iter_extended::try_btree_map; use noirc_abi::{ @@ -125,8 +128,8 @@ pub fn abi_decode(abi: JsAbi, witness_map: JsWitnessMap) -> Result Result, JsAbiError> { console_error_panic_hook::set_once(); - let converted_witness: WitnessMap = witness_map.into(); - let witness_stack: WitnessStack = converted_witness.into(); + let converted_witness: WitnessMap = witness_map.into(); + let witness_stack: WitnessStack = converted_witness.into(); let output = witness_stack.try_into(); output.map_err(|_| JsAbiError::new("Failed to convert to Vec".to_string())) } @@ -140,7 +143,7 @@ pub fn abi_decode_error( let mut abi: Abi = JsValueSerdeExt::into_serde(&JsValue::from(abi)).map_err(|err| err.to_string())?; - let raw_error: RawAssertionPayload = + let raw_error: RawAssertionPayload = JsValueSerdeExt::into_serde(&JsValue::from(raw_error)).map_err(|err| err.to_string())?; let error_type = abi.error_types.remove(&raw_error.selector).expect("Missing error type");