Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improving Shplonk implementation #326

Merged
merged 4 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 130 additions & 60 deletions src/provider/shplonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@ use crate::{CommitmentEngineTrait, NovaError};
use ff::{Field, PrimeFieldBits};
use group::{Curve, Group as group_Group};
use pairing::{Engine, MillerLoopResult, MultiMillerLoop};
use rayon::iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
};
use rayon::prelude::*;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::marker::PhantomData;

use crate::provider::hyperkzg::EvaluationEngine as HyperKZG;
use crate::spartan::math::Math;
use group::prime::PrimeCurveAffine;
use itertools::Itertools;
use ref_cast::RefCast as _;
Expand Down Expand Up @@ -60,7 +64,7 @@ where
transcript.squeeze(b"a").unwrap()
}

fn compute_pi_polynomials(hat_P: &[E::Fr], point: &[E::Fr], eval: &E::Fr) -> Vec<Vec<E::Fr>> {
fn compute_pi_polynomials(hat_P: &[E::Fr], point: &[E::Fr]) -> Vec<Vec<E::Fr>> {
let mut polys: Vec<Vec<E::Fr>> = Vec::new();
polys.push(hat_P.to_vec());

Expand All @@ -78,26 +82,20 @@ where
polys.push(Pi);
}

// TODO avoid including last constant polynomial, known to verifier
polys.push(vec![*eval]);

assert_eq!(polys.len(), 1 + (hat_P.len() as f32).log2().ceil() as usize);
assert_eq!(polys.len(), hat_P.len().log_2());

polys
}

fn compute_commitments(
ck: &UniversalKZGParam<E>,
C: &Commitment<NE>,
_C: &Commitment<NE>,
polys: &[Vec<E::Fr>],
) -> Vec<E::G1Affine> {
// TODO avoid computing commitment to constant polynomial
let mut comms: Vec<NE::GE> = (1..polys.len())
let comms: Vec<NE::GE> = (1..polys.len())
.into_par_iter()
.map(|i| <NE::CE as CommitmentEngineTrait<NE>>::commit(ck, &polys[i]).comm)
.collect();
// TODO avoid inserting commitment known to verifier
comms.insert(0, C.comm);

let mut comms_affine: Vec<E::G1Affine> = vec![E::G1Affine::identity(); comms.len()];
NE::GE::batch_normalize(&comms, &mut comms_affine);
Expand Down Expand Up @@ -169,15 +167,15 @@ where
C: &Commitment<NE>,
hat_P: &[E::Fr],
point: &[E::Fr],
eval: &E::Fr,
_eval: &E::Fr,
) -> Result<EvaluationArgument<E>, NovaError> {
let x: Vec<E::Fr> = point.to_vec();
let ell = x.len();
let n = hat_P.len();
assert_eq!(n, 1 << ell);

// Phase 1 (similar to hyperkzg)
let polys = Self::compute_pi_polynomials(hat_P, point, eval);
let polys = Self::compute_pi_polynomials(hat_P, point);
let comms = Self::compute_commitments(ck, C, &polys);

// Phase 2 (similar to hyperkzg)
Expand Down Expand Up @@ -226,9 +224,9 @@ where
fn verify(
vk: &KZGVerifierKey<E>,
transcript: &mut <NE as NovaEngine>::TE,
_C: &Commitment<NE>,
C: &Commitment<NE>,
point: &[E::Fr],
_P_of_x: &E::Fr,
P_of_x: &E::Fr,
pi: &EvaluationArgument<E>,
) -> Result<(), NovaError> {
let r = HyperKZG::<E, NE>::compute_challenge(&pi.comms, transcript);
adr1anh marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -241,39 +239,22 @@ where
return Err(NovaError::ProofVerifyError);
}

// TODO:
// insert _P_of_x into every pi.evals_i[last]
// insert _C into pi.comms[0]
// compute commitment for eval and insert it into pi.comms[last]
let mut comms = pi.comms.to_vec();
comms.insert(0, C.comm.to_affine());

let q = HyperKZG::<E, NE>::get_batch_challenge(&pi.evals, transcript);
//let q_powers = HyperKZG::<E, NE>::batch_challenge_powers(q, pi.comms.len());

let R_x = UniPoly::new(pi.R_x.clone());

let mut evals_at_r = vec![];
let mut evals_at_minus_r = vec![];
let mut evals_at_r_squared = vec![];
for (i, evals_i) in pi.evals.iter().enumerate() {
if i == 0 {
evals_at_r = evals_i.clone();
}
if i == 1 {
evals_at_minus_r = evals_i.clone();
}
if i == 2 {
evals_at_r_squared = evals_i.clone();
}

let batched_eval = UniPoly::ref_cast(evals_i).evaluate(&q);

let verification_failed = pi.evals.iter().zip_eq(u.iter()).any(|(evals_i, u_i)| {
// here we check correlation between R polynomial and batched evals, e.g.:
// 1) R(r) == eval at r
// 2) R(-r) == eval at -r
// 3) R(r^2) == eval at r^2
if batched_eval != R_x.evaluate(&u[i]) {
return Err(NovaError::ProofVerifyError);
}
let batched_eval = UniPoly::ref_cast(evals_i).evaluate(&q);
batched_eval != R_x.evaluate(u_i)
});
if verification_failed {
return Err(NovaError::ProofVerifyError);
}

// here we check that Pi polynomials were correctly constructed by the prover, using 'r' as a random point, e.g:
Expand All @@ -282,23 +263,33 @@ where
// P_i+1(r^2) == (1 - point_i) * P_i_even + point_i * P_i_odd -> should hold, according to Gemini transformation
let mut point = point.to_vec();
point.reverse();

let r_mul_2 = E::Fr::from(2) * r;
#[allow(clippy::disallowed_methods)]
for (index, ((eval_r, eval_minus_r), eval_r_squared)) in evals_at_r
.iter()
.zip_eq(evals_at_minus_r.iter())
// TODO: Ask Adrian if we need evals_at_r_squared[0] for some additional checks
.zip(evals_at_r_squared[1..].iter())
let verification_failed = pi.evals[0]
.par_iter()
.chain(&[*P_of_x])
.zip_eq(pi.evals[1].par_iter().chain(&[*P_of_x]))
.zip(pi.evals[2][1..].par_iter().chain(&[*P_of_x]))
.enumerate()
{
let even = (*eval_r + eval_minus_r) * (E::Fr::from(2).invert().unwrap());
let odd = (*eval_r - eval_minus_r) * ((E::Fr::from(2) * r).invert().unwrap());
.any(|(index, ((eval_r, eval_minus_r), eval_r_squared))| {
// some optimisation to avoid using expensive inversions:
// P_i+1(r^2) == (1 - point_i) * (P_i(r) + P_i(-r)) * 1/2 + point_i * (P_i(r) - P_i(-r)) * 1/2 * r
// is equivalent to:
// 2 * r * P_i+1(r^2) == r * (1 - point_i) * (P_i(r) + P_i(-r)) + point_i * (P_i(r) - P_i(-r))

let even = *eval_r + eval_minus_r;
let odd = *eval_r - eval_minus_r;
let right = r * ((E::Fr::ONE - point[index]) * even) + (point[index] * odd);
let left = *eval_r_squared * r_mul_2;
left != right
});

if *eval_r_squared != ((E::Fr::ONE - point[index]) * even) + (point[index] * odd) {
return Err(NovaError::ProofVerifyError);
}
if verification_failed {
return Err(NovaError::ProofVerifyError);
}

let C_P: E::G1 = pi.comms.par_iter().map(|comm| comm.to_curve()).rlc(&q);
let C_P: E::G1 = comms.par_iter().map(|comm| comm.to_curve()).rlc(&q);
let C_Q = pi.C_Q;
let C_H = pi.C_H;
let r_squared = u[2];
Expand Down Expand Up @@ -352,10 +343,11 @@ mod tests {
C: &Commitment<NE>,
poly: &[Fr],
point: &[Fr],
eval: &Fr,
_eval: &Fr,
) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
let comms = EvaluationEngine::<E, NE>::compute_commitments(ck, C, &polys);
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);
let mut comms = EvaluationEngine::<E, NE>::compute_commitments(ck, C, &polys);
comms.insert(0, C.comm.to_affine());

let q = Fr::from(8165763);
let q_powers = HyperKZG::<E, NE>::batch_challenge_powers(q, polys.len());
Expand Down Expand Up @@ -404,8 +396,8 @@ mod tests {
assert_eq!(C_K_expected, C_K.to_affine());
}

fn test_k_polynomial_correctness(poly: &[Fr], point: &[Fr], eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
fn test_k_polynomial_correctness(poly: &[Fr], point: &[Fr], _eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);
let q = Fr::from(8165763);
let batched_Pi: UniPoly<Fr> = polys.clone().into_iter().map(UniPoly::new).rlc(&q);

Expand All @@ -428,8 +420,8 @@ mod tests {
assert_eq!(Fr::from(0), K_x.evaluate(&a));
}

fn test_d_polynomial_correctness(poly: &[Fr], point: &[Fr], eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
fn test_d_polynomial_correctness(poly: &[Fr], point: &[Fr], _eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);
let q = Fr::from(8165763);
let batched_Pi: UniPoly<Fr> = polys.clone().into_iter().map(UniPoly::new).rlc(&q);

Expand Down Expand Up @@ -471,8 +463,8 @@ mod tests {
assert_eq!(Q_x, Q_x_recomputed);
}

fn test_batching_property_on_evaluation(poly: &[Fr], point: &[Fr], eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point, eval);
fn test_batching_property_on_evaluation(poly: &[Fr], point: &[Fr], _eval: &Fr) {
let polys = EvaluationEngine::<E, NE>::compute_pi_polynomials(poly, point);

let q = Fr::from(97652);
let u = [Fr::from(10), Fr::from(20), Fr::from(50)];
Expand Down Expand Up @@ -648,4 +640,82 @@ mod tests {
)
.is_err());
}

#[test]
fn test_shplonk_pcs_negative_wrong_commitment() {
let n = 8;
// poly = [1, 2, 1, 4, 1, 2, 1, 4]
let poly = vec![
Fr::ONE,
Fr::from(2),
Fr::from(1),
Fr::from(4),
Fr::ONE,
Fr::from(2),
Fr::from(1),
Fr::from(4),
];
// point = [4,3,8]
let point = vec![Fr::from(4), Fr::from(3), Fr::from(8)];
// eval = 57
let eval = Fr::from(57);

// altered_poly = [85, 84, 83, 82, 81, 80, 79, 78]
let altered_poly = vec![
Fr::from(85),
Fr::from(84),
Fr::from(83),
Fr::from(82),
Fr::from(81),
Fr::from(80),
Fr::from(79),
Fr::from(78),
];

let ck: CommitmentKey<NE> =
<KZGCommitmentEngine<E> as CommitmentEngineTrait<NE>>::setup(b"test", n);

let C1: Commitment<NE> = KZGCommitmentEngine::commit(&ck, &poly); // correct commitment
let C2: Commitment<NE> = KZGCommitmentEngine::commit(&ck, &altered_poly); // wrong commitment

test_negative_inner_commitment(&poly, &point, &eval, &ck, &C1, &C2); // here we check detection when proof and commitment do not correspond
test_negative_inner_commitment(&poly, &point, &eval, &ck, &C2, &C2); // here we check detection when proof was built with wrong commitment
}

fn test_negative_inner_commitment(
poly: &[Fr],
point: &[Fr],
eval: &Fr,
ck: &CommitmentKey<NE>,
C_prover: &Commitment<NE>,
C_verifier: &Commitment<NE>,
) {
let ck = Arc::new(ck.clone());
let (pk, vk): (KZGProverKey<E>, KZGVerifierKey<E>) =
EvaluationEngine::<E, NE>::setup(ck.clone());

let mut prover_transcript = Keccak256Transcript::new(b"TestEval");
let mut verifier_transcript = Keccak256Transcript::<NE>::new(b"TestEval");

let proof = EvaluationEngine::<E, NE>::prove(
&ck,
&pk,
&mut prover_transcript,
C_prover,
poly,
point,
eval,
)
.unwrap();

assert!(EvaluationEngine::<E, NE>::verify(
&vk,
&mut verifier_transcript,
C_verifier,
point,
eval,
&proof
)
.is_err());
}
}
2 changes: 1 addition & 1 deletion src/spartan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod batched;
pub mod batched_ppsnark;
#[macro_use]
mod macros;
mod math;
pub(crate) mod math;
pub mod polys;
pub mod ppsnark;
pub mod snark;
Expand Down