diff --git a/acvm-repo/acvm/src/compiler/mod.rs b/acvm-repo/acvm/src/compiler/mod.rs index 4abf94a2e7..0ed5b59167 100644 --- a/acvm-repo/acvm/src/compiler/mod.rs +++ b/acvm-repo/acvm/src/compiler/mod.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use acir::{ circuit::{opcodes::UnsupportedMemoryOpcode, Circuit, Opcode, OpcodeLocation}, BlackBoxFunc, @@ -27,12 +29,22 @@ pub enum CompileError { /// metadata they had about the opcodes to the new opcode structure generated after the transformation. #[derive(Debug)] pub struct AcirTransformationMap { - /// This is a vector of pointers to the old acir opcodes. The index of the vector is the new opcode index. - /// The value of the vector is the old opcode index pointed. - acir_opcode_positions: Vec, + /// Maps the old acir indices to the new acir indices + old_indices_to_new_indices: HashMap>, } impl AcirTransformationMap { + /// Builds a map from a vector of pointers to the old acir opcodes. + /// The index of the vector is the new opcode index. + /// The value of the vector is the old opcode index pointed. + fn new(acir_opcode_positions: Vec) -> Self { + let mut old_indices_to_new_indices = HashMap::with_capacity(acir_opcode_positions.len()); + for (new_index, old_index) in acir_opcode_positions.into_iter().enumerate() { + old_indices_to_new_indices.entry(old_index).or_insert_with(Vec::new).push(new_index); + } + AcirTransformationMap { old_indices_to_new_indices } + } + pub fn new_locations( &self, old_location: OpcodeLocation, @@ -42,16 +54,16 @@ impl AcirTransformationMap { OpcodeLocation::Brillig { acir_index, .. } => acir_index, }; - self.acir_opcode_positions - .iter() - .enumerate() - .filter(move |(_, &old_index)| old_index == old_acir_index) - .map(move |(new_index, _)| match old_location { - OpcodeLocation::Acir(_) => OpcodeLocation::Acir(new_index), - OpcodeLocation::Brillig { brillig_index, .. } => { - OpcodeLocation::Brillig { acir_index: new_index, brillig_index } - } - }) + self.old_indices_to_new_indices.get(&old_acir_index).into_iter().flat_map( + move |new_indices| { + new_indices.iter().map(move |new_index| match old_location { + OpcodeLocation::Acir(_) => OpcodeLocation::Acir(*new_index), + OpcodeLocation::Brillig { brillig_index, .. } => { + OpcodeLocation::Brillig { acir_index: *new_index, brillig_index } + } + }) + }, + ) } } @@ -74,11 +86,13 @@ pub fn compile( np_language: Language, is_opcode_supported: impl Fn(&Opcode) -> bool, ) -> Result<(Circuit, AcirTransformationMap), CompileError> { - let (acir, AcirTransformationMap { acir_opcode_positions }) = optimize_internal(acir); + let (acir, acir_opcode_positions) = optimize_internal(acir); - let (mut acir, transformation_map) = + let (mut acir, acir_opcode_positions) = transform_internal(acir, np_language, is_opcode_supported, acir_opcode_positions)?; + let transformation_map = AcirTransformationMap::new(acir_opcode_positions); + acir.assert_messages = transform_assert_messages(acir.assert_messages, &transformation_map); Ok((acir, transformation_map)) diff --git a/acvm-repo/acvm/src/compiler/optimizers/mod.rs b/acvm-repo/acvm/src/compiler/optimizers/mod.rs index 6e2af4f58a..85a97c2c7d 100644 --- a/acvm-repo/acvm/src/compiler/optimizers/mod.rs +++ b/acvm-repo/acvm/src/compiler/optimizers/mod.rs @@ -13,7 +13,9 @@ use super::{transform_assert_messages, AcirTransformationMap}; /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] independent optimizations to a [`Circuit`]. pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) { - let (mut acir, transformation_map) = optimize_internal(acir); + let (mut acir, new_opcode_positions) = optimize_internal(acir); + + let transformation_map = AcirTransformationMap::new(new_opcode_positions); acir.assert_messages = transform_assert_messages(acir.assert_messages, &transformation_map); @@ -21,7 +23,7 @@ pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) { } /// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] independent optimizations to a [`Circuit`]. -pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, AcirTransformationMap) { +pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, Vec) { log::trace!("Start circuit optimization"); // General optimizer pass @@ -52,9 +54,7 @@ pub(super) fn optimize_internal(acir: Circuit) -> (Circuit, AcirTransformationMa let (acir, acir_opcode_positions) = range_optimizer.replace_redundant_ranges(acir_opcode_positions); - let transformation_map = AcirTransformationMap { acir_opcode_positions }; - log::trace!("Finish circuit optimization"); - (acir, transformation_map) + (acir, acir_opcode_positions) } diff --git a/acvm-repo/acvm/src/compiler/transformers/mod.rs b/acvm-repo/acvm/src/compiler/transformers/mod.rs index 00ee9dc7ce..c4c94e371b 100644 --- a/acvm-repo/acvm/src/compiler/transformers/mod.rs +++ b/acvm-repo/acvm/src/compiler/transformers/mod.rs @@ -27,9 +27,11 @@ pub fn transform( // by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert) let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect(); - let (mut acir, transformation_map) = + let (mut acir, acir_opcode_positions) = transform_internal(acir, np_language, is_opcode_supported, acir_opcode_positions)?; + let transformation_map = AcirTransformationMap::new(acir_opcode_positions); + acir.assert_messages = transform_assert_messages(acir.assert_messages, &transformation_map); Ok((acir, transformation_map)) @@ -43,7 +45,7 @@ pub(super) fn transform_internal( np_language: Language, is_opcode_supported: impl Fn(&Opcode) -> bool, acir_opcode_positions: Vec, -) -> Result<(Circuit, AcirTransformationMap), CompileError> { +) -> Result<(Circuit, Vec), CompileError> { log::trace!("Start circuit transformation"); // Fallback transformer pass @@ -52,9 +54,8 @@ pub(super) fn transform_internal( let mut transformer = match &np_language { crate::Language::R1CS => { - let transformation_map = AcirTransformationMap { acir_opcode_positions }; let transformer = R1CSTransformer::new(acir); - return Ok((transformer.transform(), transformation_map)); + return Ok((transformer.transform(), acir_opcode_positions)); } crate::Language::PLONKCSat { width } => { let mut csat = CSatTransformer::new(*width); @@ -216,10 +217,7 @@ pub(super) fn transform_internal( ..acir }; - let transformation_map = - AcirTransformationMap { acir_opcode_positions: new_acir_opcode_positions }; - log::trace!("Finish circuit transformation"); - Ok((acir, transformation_map)) + Ok((acir, new_acir_opcode_positions)) }