Skip to content

Commit

Permalink
refactor: Integrate zip_with macro (microsoft#292)
Browse files Browse the repository at this point in the history
- Created a new file `macros.rs` with the implementation of `zip_with` and `zip_with_for_each` macros providing syntactic sugar for zipWith patterns.
- zipWith patterns implemented through the use of the zip_with! macros now resolve to the use of zip_eq, a variant of zip that panics when the iterator arguments are of different length,
- the zip_eq implementation is the native rayon one for parallel iterators, and the one from itertools (see below) for the sequential ones,
- Optimized and refactored functions like `batch_eval_prove` and `batch_eval_verify` in `snark.rs`, methods inside `PolyEvalWitness` and `PolyEvalInstance` in `mod.rs`, and multiple functions in `multilinear.rs` through the use of implemented macros.

- Introduced the use of itertools::Itertools in various files to import the use of zip_eq on sequential iterators.
- Made use of the Itertools library for refactoring and optimizing computation in `sumcheck.rs` and `eq.rs` files.

This backports (among others) content from the following Arecibo PRS:
- lurk-lab/arecibo#149
- lurk-lab/arecibo#158
- lurk-lab/arecibo#169

Co-authored-by: porcuquine <porcuquine@users.noreply.github.com>
  • Loading branch information
huitseeker and porcuquine committed Jan 10, 2024
1 parent 0e8f5fd commit 48887f8
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 116 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ byteorder = "1.4.3"
thiserror = "1.0"
group = "0.13.0"
once_cell = "1.18.0"
itertools = "0.12.0"

[target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies]
pasta-msm = { version = "0.1.4" }
Expand Down
103 changes: 103 additions & 0 deletions src/spartan/macros.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/// Macros to give syntactic sugar for zipWith pattern and variants.
///
/// ```ignore
/// use crate::spartan::zip_with;
/// use itertools::Itertools as _; // we use zip_eq to zip!
/// let v = vec![0, 1, 2];
/// let w = vec![2, 3, 4];
/// let y = vec![4, 5, 6];
///
/// // Using the `zip_with!` macro to zip three iterators together and apply a closure
/// // that sums the elements of each iterator.
/// let res = zip_with!((v.iter(), w.iter(), y.iter()), |a, b, c| a + b + c)
/// .collect::<Vec<_>>();
///
/// println!("{:?}", res); // Output: [6, 9, 12]
/// ```

#[macro_export]
macro_rules! zip_with {
// no iterator projection specified: the macro assumes the arguments *are* iterators
// ```ignore
// zip_with!((iter1, iter2, iter3), |a, b, c| a + b + c) ->
// iter1.zip_eq(iter2.zip_eq(iter3)).map(|(a, (b, c))| a + b + c)
// ```
//
// iterator projection specified: use it on each argument
// ```ignore
// zip_with!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) ->
// vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).map(|(a, (b, c))| a + b + c)
// ````
($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{
$crate::zip_with!($($f,)? ($e $(, $rest)*), map, $($move)? |$($i),+| $($work)*)
}};
// no iterator projection specified: the macro assumes the arguments *are* iterators
// optional zipping function specified as well: use it instead of map
// ```ignore
// zip_with!((iter1, iter2, iter3), for_each, |a, b, c| a + b + c) ->
// iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c)
// ```
//
//
// iterator projection specified: use it on each argument
// optional zipping function specified as well: use it instead of map
// ```ignore
// zip_with!(par_iter, (vec1, vec2, vec3), for_each, |a, b, c| a + b + c) ->
// vec1.part_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c)
// ```
($($f:ident,)? ($e:expr $(, $rest:expr)*), $worker:ident, $($move:ident,)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{
$crate::zip_all!($($f,)? ($e $(, $rest)*))
.$worker($($move)? |$crate::nested_idents!($($i),+)| {
$($work)*
})
}};
}

/// Like `zip_with` but use `for_each` instead of `map`.
#[macro_export]
macro_rules! zip_with_for_each {
// no iterator projection specified: the macro assumes the arguments *are* iterators
// ```ignore
// zip_with_for_each!((iter1, iter2, iter3), |a, b, c| a + b + c) ->
// iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c)
// ```
//
// iterator projection specified: use it on each argument
// ```ignore
// zip_with_for_each!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) ->
// vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c)
// ````
($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{
$crate::zip_with!($($f,)? ($e $(, $rest)*), for_each, $($move)? |$($i),+| $($work)*)
}};
}

// Foldright-like nesting for idents (a, b, c) -> (a, (b, c))
#[doc(hidden)]
#[macro_export]
macro_rules! nested_idents {
($a:ident, $b:ident) => {
($a, $b)
};
($first:ident, $($rest:ident),+) => {
($first, $crate::nested_idents!($($rest),+))
};
}

// Fold-right like zipping, with an optional function `f` to apply to each argument
#[doc(hidden)]
#[macro_export]
macro_rules! zip_all {
(($e:expr,)) => {
$e
};
($f:ident, ($e:expr,)) => {
$e.$f()
};
($f:ident, ($first:expr, $second:expr $(, $rest:expr)*)) => {
($first.$f().zip_eq($crate::zip_all!($f, ($second, $( $rest),*))))
};
(($first:expr, $second:expr $(, $rest:expr)*)) => {
($first.zip_eq($crate::zip_all!(($second, $( $rest),*))))
};
}
39 changes: 16 additions & 23 deletions src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
//! We also provide direct.rs that allows proving a step circuit directly with either of the two SNARKs.
//!
//! In polynomial.rs we also provide foundational types and functions for manipulating multilinear polynomials.
#[macro_use]
mod macros;
pub mod direct;
pub(crate) mod math;
pub mod polys;
Expand All @@ -14,6 +16,7 @@ mod sumcheck;

use crate::{traits::Engine, Commitment};
use ff::Field;
use itertools::Itertools as _;
use polys::multilinear::SparsePolynomial;
use rayon::{iter::IntoParallelRefIterator, prelude::*};

Expand Down Expand Up @@ -64,20 +67,17 @@ impl<E: Engine> PolyEvalWitness<E> {

let powers_of_s = powers::<E>(s, p_vec.len());

let p = p_vec
.par_iter()
.zip(powers_of_s.par_iter())
.map(|(v, &weight)| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
acc.into_iter().zip(v).map(|(x, y)| x + y).collect()
},
);
let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| {
// compute the weighted sum for each vector
v.iter().map(|&x| x * weight).collect::<Vec<E::Scalar>>()
})
.reduce(
|| vec![E::Scalar::ZERO; p_vec[0].len()],
|acc, v| {
// perform vector addition to combine the weighted vectors
zip_with!((acc.into_iter(), v), |x, y| x + y).collect()
},
);

PolyEvalWitness { p }
}
Expand Down Expand Up @@ -113,15 +113,8 @@ impl<E: Engine> PolyEvalInstance<E> {
s: &E::Scalar,
) -> PolyEvalInstance<E> {
let powers_of_s = powers::<E>(s, c_vec.len());
let e = e_vec
.par_iter()
.zip(powers_of_s.par_iter())
.map(|(e, p)| *e * p)
.sum();
let c = c_vec
.par_iter()
.zip(powers_of_s.par_iter())
.map(|(c, p)| *c * *p)
let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum();
let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p)
.reduce(Commitment::<E>::default, |acc, item| acc + item);

PolyEvalInstance {
Expand Down
11 changes: 4 additions & 7 deletions src/spartan/polys/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,10 @@ impl<Scalar: PrimeField> EqPolynomial<Scalar> {
let (evals_left, evals_right) = evals.split_at_mut(size);
let (evals_right, _) = evals_right.split_at_mut(size);

evals_left
.par_iter_mut()
.zip(evals_right.par_iter_mut())
.for_each(|(x, y)| {
*y = *x * r;
*x -= &*y;
});
zip_with_for_each!(par_iter_mut, (evals_left, evals_right), |x, y| {
*y = *x * r;
*x -= &*y;
});

size *= 2;
}
Expand Down
29 changes: 11 additions & 18 deletions src/spartan/polys/multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use std::ops::{Add, Index};

use ff::PrimeField;
use itertools::Itertools as _;
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
Expand Down Expand Up @@ -65,12 +66,9 @@ impl<Scalar: PrimeField> MultilinearPolynomial<Scalar> {

let (left, right) = self.Z.split_at_mut(n);

left
.par_iter_mut()
.zip(right.par_iter())
.for_each(|(a, b)| {
*a += *r * (*b - *a);
});
zip_with_for_each!((left.par_iter_mut(), right.par_iter()), |a, b| {
*a += *r * (*b - *a);
});

self.Z.resize(n, Scalar::ZERO);
self.num_vars -= 1;
Expand All @@ -94,12 +92,12 @@ impl<Scalar: PrimeField> MultilinearPolynomial<Scalar> {

/// Evaluates the polynomial with the given evaluations and point.
pub fn evaluate_with(Z: &[Scalar], r: &[Scalar]) -> Scalar {
EqPolynomial::new(r.to_vec())
.evals()
.into_par_iter()
.zip(Z.into_par_iter())
.map(|(a, b)| a * b)
.sum()
zip_with!(
into_par_iter,
(EqPolynomial::new(r.to_vec()).evals(), Z),
|a, b| a * b
)
.sum()
}
}

Expand Down Expand Up @@ -167,12 +165,7 @@ impl<Scalar: PrimeField> Add for MultilinearPolynomial<Scalar> {
return Err("The two polynomials must have the same number of variables");
}

let sum: Vec<Scalar> = self
.Z
.iter()
.zip(other.Z.iter())
.map(|(a, b)| *a + *b)
.collect();
let sum: Vec<Scalar> = zip_with!(iter, (self.Z, other.Z), |a, b| *a + *b).collect();

Ok(MultilinearPolynomial::new(sum))
}
Expand Down
31 changes: 10 additions & 21 deletions src/spartan/ppsnark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ use crate::{
snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait},
Engine, TranscriptEngineTrait, TranscriptReprTrait,
},
Commitment, CommitmentKey, CompressedCommitment,
zip_with, Commitment, CommitmentKey, CompressedCommitment,
};
use core::cmp::max;
use ff::Field;
use itertools::Itertools as _;
use once_cell::sync::OnceCell;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -339,13 +340,7 @@ impl<E: Engine> MemorySumcheckInstance<E> {
let inv = batch_invert(&T.par_iter().map(|e| *e + *r).collect::<Vec<E::Scalar>>())?;

// compute inv[i] * TS[i] in parallel
Ok(
inv
.par_iter()
.zip(TS.par_iter())
.map(|(e1, e2)| *e1 * *e2)
.collect::<Vec<_>>(),
)
Ok(zip_with!(par_iter, (inv, TS), |e1, e2| *e1 * e2).collect::<Vec<_>>())
},
|| batch_invert(&W.par_iter().map(|e| *e + *r).collect::<Vec<E::Scalar>>()),
)
Expand Down Expand Up @@ -853,11 +848,7 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARK<E, EE> {
let coeffs = powers::<E>(&s, claims.len());

// compute the joint claim
let claim = claims
.iter()
.zip(coeffs.iter())
.map(|(c_1, c_2)| *c_1 * c_2)
.sum();
let claim = zip_with!(iter, (claims, coeffs), |c_1, c_2| *c_1 * c_2).sum();

let mut e = claim;
let mut r: Vec<E::Scalar> = Vec::new();
Expand Down Expand Up @@ -1086,14 +1077,12 @@ impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARKTrait<E> for Relax
);

// a sum-check instance to prove the second claim
let val = pk
.S_repr
.val_A
.par_iter()
.zip(pk.S_repr.val_B.par_iter())
.zip(pk.S_repr.val_C.par_iter())
.map(|((v_a, v_b), v_c)| *v_a + c * *v_b + c * c * *v_c)
.collect::<Vec<E::Scalar>>();
let val = zip_with!(
par_iter,
(pk.S_repr.val_A, pk.S_repr.val_B, pk.S_repr.val_C),
|v_a, v_b, v_c| *v_a + c * *v_b + c * c * *v_c
)
.collect::<Vec<E::Scalar>>();
let inner_sc_inst = InnerSumcheckInstance {
claim: eval_Az_at_tau + c * eval_Bz_at_tau + c * c * eval_Cz_at_tau,
poly_L_row: MultilinearPolynomial::new(L_row.clone()),
Expand Down

0 comments on commit 48887f8

Please sign in to comment.