From 4c1c184ff5b210b4f55fcd47f2dad8e12f57709f Mon Sep 17 00:00:00 2001 From: grjte Date: Tue, 4 Jul 2023 15:56:50 +0100 Subject: [PATCH 1/5] feat: switch TraceLde to a trait --- prover/src/constraints/evaluator.rs | 18 ++-- prover/src/lib.rs | 43 ++++----- prover/src/trace/commitment.rs | 140 ---------------------------- prover/src/trace/mod.rs | 3 - prover/src/trace/tests.rs | 111 +--------------------- prover/src/trace/trace_lde.rs | 119 ----------------------- prover/src/trace/trace_lde/mod.rs | 90 ++++++++++++++++++ 7 files changed, 121 insertions(+), 403 deletions(-) delete mode 100644 prover/src/trace/commitment.rs delete mode 100644 prover/src/trace/trace_lde.rs create mode 100644 prover/src/trace/trace_lde/mod.rs diff --git a/prover/src/constraints/evaluator.rs b/prover/src/constraints/evaluator.rs index 5acc7be27..f5a324c98 100644 --- a/prover/src/constraints/evaluator.rs +++ b/prover/src/constraints/evaluator.rs @@ -71,9 +71,9 @@ impl<'a, A: Air, E: FieldElement> ConstraintEvaluator< /// Evaluates constraints against the provided extended execution trace. Constraints are /// evaluated over a constraint evaluation domain. This is an optimization because constraint /// evaluation domain can be many times smaller than the full LDE domain. - pub fn evaluate( + pub fn evaluate>( self, - trace: &TraceLde, + trace: &T, domain: &'a StarkDomain, ) -> ConstraintEvaluationTable<'a, E> { assert_eq!( @@ -137,14 +137,14 @@ impl<'a, A: Air, E: FieldElement> ConstraintEvaluator< /// Evaluates constraints for a single fragment of the evaluation table. /// /// This evaluates constraints only over the main segment of the execution trace. - fn evaluate_fragment_main( + fn evaluate_fragment_main>( &self, - trace: &TraceLde, + trace: &T, domain: &StarkDomain, fragment: &mut EvaluationTableFragment, ) { // initialize buffers to hold trace values and evaluation results at each step; - let mut main_frame = EvaluationFrame::new(trace.main_trace_width()); + let mut main_frame = EvaluationFrame::new(trace.trace_layout().main_trace_width()); let mut evaluations = vec![E::ZERO; fragment.num_columns()]; let mut t_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; @@ -188,15 +188,15 @@ impl<'a, A: Air, E: FieldElement> ConstraintEvaluator< /// /// This evaluates constraints only over all segments of the execution trace (i.e. main segment /// and all auxiliary segments). - fn evaluate_fragment_full( + fn evaluate_fragment_full>( &self, - trace: &TraceLde, + trace: &T, domain: &StarkDomain, fragment: &mut EvaluationTableFragment, ) { // initialize buffers to hold trace values and evaluation results at each step - let mut main_frame = EvaluationFrame::new(trace.main_trace_width()); - let mut aux_frame = EvaluationFrame::new(trace.aux_trace_width()); + let mut main_frame = EvaluationFrame::new(trace.trace_layout().main_trace_width()); + let mut aux_frame = EvaluationFrame::new(trace.trace_layout().aux_trace_width()); let mut tm_evaluations = vec![E::BaseField::ZERO; self.num_main_transition_constraints()]; let mut ta_evaluations = vec![E::ZERO; self.num_aux_transition_constraints()]; let mut evaluations = vec![E::ZERO; fragment.num_columns()]; diff --git a/prover/src/lib.rs b/prover/src/lib.rs index af475e08a..6b42bbfe0 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -86,8 +86,7 @@ mod composer; use composer::DeepCompositionPoly; mod trace; -pub use trace::{Trace, TraceTable, TraceTableFragment}; -use trace::{TraceCommitment, TraceLde, TracePolyTable}; +pub use trace::{Trace, TraceLde, TracePolyTable, TraceTable, TraceTableFragment}; mod channel; use channel::ProverChannel; @@ -136,6 +135,13 @@ pub trait Prover { /// PRNG to be used for generating random field elements. type RandomCoin: RandomCoin; + // Trace low-degree extension for building the LDEs of trace segments and their commitments. + type TraceLde>: TraceLde< + BaseField = Self::BaseField, + ExtensionField = E, + HashFn = Self::HashFn, + >; + // REQUIRED METHODS // -------------------------------------------------------------------------------------------- @@ -227,22 +233,15 @@ pub trait Prover { ); // extend the main execution trace and build a Merkle tree from the extended trace - let (main_trace_lde, main_trace_tree, main_trace_polys) = - self.build_trace_commitment::(trace.main_segment(), &domain); + let (mut trace_polys, mut trace_lde): (TracePolyTable, Self::TraceLde) = + TraceLde::new(&trace.get_info(), trace.main_segment(), &domain); + + // get the commitment to the main trace segment LDE + let main_trace_root = trace_lde.get_main_trace_commitment(); // commit to the LDE of the main trace by writing the root of its Merkle tree into // the channel - channel.commit_trace(*main_trace_tree.root()); - - // initialize trace commitment and trace polynomial table structs with the main trace - // data; for multi-segment traces these structs will be used as accumulators of all - // trace segments - let mut trace_commitment = TraceCommitment::new( - main_trace_lde, - main_trace_tree, - domain.trace_to_lde_blowup(), - ); - let mut trace_polys = TracePolyTable::new(main_trace_polys); + channel.commit_trace(main_trace_root); // build auxiliary trace segments (if any), and append the resulting segments to trace // commitment and trace polynomial table structs @@ -268,15 +267,13 @@ pub trait Prover { ); // extend the auxiliary trace segment and build a Merkle tree from the extended trace - let (aux_segment_lde, aux_segment_tree, aux_segment_polys) = - self.build_trace_commitment::(&aux_segment, &domain); + let (aux_segment_polys, aux_segment_root) = + trace_lde.add_aux_segment(&aux_segment, &domain); - // commit to the LDE of the extended auxiliary trace segment by writing the root of + // commit to the LDE of the extended auxiliary trace segment by writing the root of // its Merkle tree into the channel - channel.commit_trace(*aux_segment_tree.root()); + channel.commit_trace(aux_segment_root); - // append the segment to the trace commitment and trace polynomial table structs - trace_commitment.add_segment(aux_segment_lde, aux_segment_tree); trace_polys.add_aux_segment(aux_segment_polys); aux_trace_rand_elements.add_segment_elements(rand_elements); aux_trace_segments.push(aux_segment); @@ -299,7 +296,7 @@ pub trait Prover { let now = Instant::now(); let constraint_coeffs = channel.get_constraint_composition_coeffs(); let evaluator = ConstraintEvaluator::new(&air, aux_trace_rand_elements, constraint_coeffs); - let constraint_evaluations = evaluator.evaluate(trace_commitment.trace_table(), &domain); + let constraint_evaluations = evaluator.evaluate(&trace_lde, &domain); #[cfg(feature = "std")] debug!( "Evaluated constraints over domain of 2^{} elements in {} ms", @@ -434,7 +431,7 @@ pub trait Prover { // query the execution trace at the selected position; for each query, we need the // state of the trace at that position + Merkle authentication path - let trace_queries = trace_commitment.query(&query_positions); + let trace_queries = trace_lde.query(&query_positions); // query the constraint commitment at the selected positions; for each query, we need just // a Merkle authentication path. this is because constraint evaluations for each step are diff --git a/prover/src/trace/commitment.rs b/prover/src/trace/commitment.rs deleted file mode 100644 index a84f90b5e..000000000 --- a/prover/src/trace/commitment.rs +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -use crate::RowMatrix; -use air::proof::Queries; -use crypto::{ElementHasher, MerkleTree}; -use math::FieldElement; -use utils::collections::Vec; - -use super::TraceLde; - -// TRACE COMMITMENT -// ================================================================================================ - -/// Execution trace commitment. -/// -/// The describes one or more trace segments, each consisting of the following components: -/// * Evaluations of a trace segment's polynomials over the LDE domain. -/// * Merkle tree where each leaf in the tree corresponds to a row in the trace LDE matrix. -pub struct TraceCommitment> { - trace_lde: TraceLde, - main_segment_tree: MerkleTree, - aux_segment_trees: Vec>, -} - -impl> TraceCommitment { - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Creates a new trace commitment from the provided main trace low-degree extension and the - /// corresponding Merkle tree commitment. - pub fn new( - main_trace_lde: RowMatrix, - main_trace_tree: MerkleTree, - blowup: usize, - ) -> Self { - assert_eq!( - main_trace_lde.num_rows(), - main_trace_tree.leaves().len(), - "number of rows in trace LDE must be the same as number of leaves in trace commitment" - ); - Self { - trace_lde: TraceLde::new(main_trace_lde, blowup), - main_segment_tree: main_trace_tree, - aux_segment_trees: Vec::new(), - } - } - - // STATE MUTATORS - // -------------------------------------------------------------------------------------------- - - /// Adds the provided auxiliary segment trace LDE and Merkle tree to this trace commitment. - pub fn add_segment(&mut self, aux_segment_lde: RowMatrix, aux_segment_tree: MerkleTree) { - assert_eq!( - aux_segment_lde.num_rows(), - aux_segment_tree.leaves().len(), - "number of rows in trace LDE must be the same as number of leaves in trace commitment" - ); - - self.trace_lde.add_aux_segment(aux_segment_lde); - self.aux_segment_trees.push(aux_segment_tree); - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns the execution trace for this commitment. - /// - /// The trace contains both the main trace segment and the auxiliary trace segments (if any). - pub fn trace_table(&self) -> &TraceLde { - &self.trace_lde - } - - // QUERY TRACE - // -------------------------------------------------------------------------------------------- - /// Returns trace table rows at the specified positions along with Merkle authentication paths - /// from the commitment root to these rows. - pub fn query(&self, positions: &[usize]) -> Vec { - // build queries for the main trace segment - let mut result = vec![build_segment_queries( - self.trace_lde.get_main_segment(), - &self.main_segment_tree, - positions, - )]; - - // build queries for auxiliary trace segments - for (i, segment_tree) in self.aux_segment_trees.iter().enumerate() { - let segment_lde = self.trace_lde.get_aux_segment(i); - result.push(build_segment_queries(segment_lde, segment_tree, positions)); - } - - result - } - - // TEST HELPERS - // -------------------------------------------------------------------------------------------- - - /// Returns the root of the commitment Merkle tree. - #[cfg(test)] - pub fn main_trace_root(&self) -> H::Digest { - *self.main_segment_tree.root() - } - - /// Returns the entire trace for the column at the specified index. - #[cfg(test)] - pub fn get_main_trace_column(&self, col_idx: usize) -> Vec { - let trace = self.trace_lde.get_main_segment(); - (0..trace.num_rows()) - .map(|row_idx| trace.get(col_idx, row_idx)) - .collect() - } -} - -// HELPER FUNCTIONS -// ================================================================================================ - -fn build_segment_queries( - segment_lde: &RowMatrix, - segment_tree: &MerkleTree, - positions: &[usize], -) -> Queries -where - E: FieldElement, - H: ElementHasher, -{ - // for each position, get the corresponding row from the trace segment LDE and put all these - // rows into a single vector - let trace_states = positions - .iter() - .map(|&pos| segment_lde.row(pos).to_vec()) - .collect::>(); - - // build Merkle authentication paths to the leaves specified by positions - let trace_proof = segment_tree - .prove_batch(positions) - .expect("failed to generate a Merkle proof for trace queries"); - - Queries::new(trace_proof, trace_states) -} diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index 786a50b16..af3ecdf6e 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -16,9 +16,6 @@ pub use poly_table::TracePolyTable; mod trace_table; pub use trace_table::{TraceTable, TraceTableFragment}; -mod commitment; -pub use commitment::TraceCommitment; - #[cfg(test)] mod tests; diff --git a/prover/src/trace/tests.rs b/prover/src/trace/tests.rs index fd2f66209..29460705b 100644 --- a/prover/src/trace/tests.rs +++ b/prover/src/trace/tests.rs @@ -3,20 +3,10 @@ // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. -use crate::{ - tests::{build_fib_trace, MockAir}, - trace::TracePolyTable, - RowMatrix, StarkDomain, Trace, TraceCommitment, -}; -use crypto::{hashers::Blake3_256, ElementHasher, MerkleTree}; -use math::{ - fields::f128::BaseElement, get_power_series, get_power_series_with_offset, polynom, - FieldElement, StarkField, -}; +use crate::{tests::build_fib_trace, Trace}; +use math::fields::f128::BaseElement; use utils::collections::Vec; -type Blake3 = Blake3_256; - #[test] fn new_trace_table() { let trace_length = 8; @@ -37,100 +27,3 @@ fn new_trace_table() { .collect(); assert_eq!(expected, trace.get_column(1)); } - -#[test] -fn extend_trace_table() { - // build the trace and the domain - let trace_length = 8; - let air = MockAir::with_trace_length(trace_length); - let trace = build_fib_trace(trace_length * 2); - let domain = StarkDomain::new(&air); - - // build extended trace commitment - let trace_polys = trace.main_segment().interpolate_columns(); - let trace_lde = RowMatrix::evaluate_polys_over::<8>(&trace_polys, &domain); - let trace_tree = trace_lde.commit_to_rows::(); - let trace_comm = TraceCommitment::::new( - trace_lde, - trace_tree, - domain.trace_to_lde_blowup(), - ); - let trace_polys = TracePolyTable::::new(trace_polys); - - assert_eq!(2, trace_comm.trace_table().main_trace_width()); - assert_eq!(64, trace_comm.trace_table().trace_len()); - - // make sure trace polynomials evaluate to Fibonacci trace - let trace_root = BaseElement::get_root_of_unity(trace_length.ilog2()); - let trace_domain = get_power_series(trace_root, trace_length); - assert_eq!(2, trace_polys.num_main_trace_polys()); - assert_eq!( - vec![1u32, 2, 5, 13, 34, 89, 233, 610] - .into_iter() - .map(BaseElement::from) - .collect::>(), - polynom::eval_many(trace_polys.get_main_trace_poly(0), &trace_domain) - ); - assert_eq!( - vec![1u32, 3, 8, 21, 55, 144, 377, 987] - .into_iter() - .map(BaseElement::from) - .collect::>(), - polynom::eval_many(trace_polys.get_main_trace_poly(1), &trace_domain) - ); - - // make sure column values are consistent with trace polynomials - let lde_domain = build_lde_domain(domain.lde_domain_size()); - assert_eq!( - trace_polys.get_main_trace_poly(0), - polynom::interpolate(&lde_domain, &trace_comm.get_main_trace_column(0), true) - ); - assert_eq!( - trace_polys.get_main_trace_poly(1), - polynom::interpolate(&lde_domain, &trace_comm.get_main_trace_column(1), true) - ); -} - -#[test] -fn commit_trace_table() { - // build the trade and the domain - let trace_length = 8; - let air = MockAir::with_trace_length(trace_length); - let trace = build_fib_trace(trace_length * 2); - let domain = StarkDomain::new(&air); - - // build extended trace commitment - let trace_polys = trace.main_segment().interpolate_columns(); - let trace_lde = RowMatrix::evaluate_polys_over::<8>(&trace_polys, &domain); - let trace_tree = trace_lde.commit_to_rows::(); - let trace_comm = TraceCommitment::::new( - trace_lde, - trace_tree, - domain.trace_to_lde_blowup(), - ); - - // build Merkle tree from trace rows - let trace_table = trace_comm.trace_table(); - let mut hashed_states = Vec::new(); - let mut trace_state = vec![BaseElement::ZERO; trace_table.main_trace_width()]; - #[allow(clippy::needless_range_loop)] - for i in 0..trace_table.trace_len() { - for j in 0..trace_table.main_trace_width() { - trace_state[j] = trace_table.get_main_segment().get(j, i); - } - let buf = Blake3::hash_elements(&trace_state); - hashed_states.push(buf); - } - let expected_tree = MerkleTree::::new(hashed_states).unwrap(); - - // compare the result - assert_eq!(*expected_tree.root(), trace_comm.main_trace_root()) -} - -// HELPER FUNCTIONS -// ================================================================================================ - -fn build_lde_domain(domain_size: usize) -> Vec { - let g = B::get_root_of_unity(domain_size.ilog2()); - get_power_series_with_offset(g, B::GENERATOR, domain_size) -} diff --git a/prover/src/trace/trace_lde.rs b/prover/src/trace/trace_lde.rs deleted file mode 100644 index 2ebfb356d..000000000 --- a/prover/src/trace/trace_lde.rs +++ /dev/null @@ -1,119 +0,0 @@ -// Copyright (c) Facebook, Inc. and its affiliates. -// -// This source code is licensed under the MIT license found in the -// LICENSE file in the root directory of this source tree. - -use crate::RowMatrix; -use air::EvaluationFrame; -use math::FieldElement; -use utils::collections::Vec; - -// TRACE LOW DEGREE EXTENSION -// ================================================================================================ -/// Contains all segments of the extended execution trace. -/// -/// Segments are stored in two groups: -/// - Main segment: this is the first trace segment generated by the prover. Values in this segment -/// will always be elements in the base field (even when an extension field is used). -/// - Auxiliary segments: a list of 0 or more segments for traces generated after the prover -/// commits to the first trace segment. Currently, at most 1 auxiliary segment is possible. -pub struct TraceLde { - main_segment_lde: RowMatrix, - aux_segment_ldes: Vec>, - blowup: usize, -} - -impl TraceLde { - // CONSTRUCTOR - // -------------------------------------------------------------------------------------------- - /// Creates a new trace low-degree extension table from the provided main trace segment LDE. - pub fn new(main_trace_lde: RowMatrix, blowup: usize) -> Self { - Self { - main_segment_lde: main_trace_lde, - aux_segment_ldes: Vec::new(), - blowup, - } - } - - // STATE MUTATORS - // -------------------------------------------------------------------------------------------- - - /// Adds the provided auxiliary segment LDE to this trace LDE. - pub fn add_aux_segment(&mut self, aux_segment_lde: RowMatrix) { - assert_eq!( - self.main_segment_lde.num_rows(), - aux_segment_lde.num_rows(), - "number of rows in auxiliary segment must be of the same as in the main segment" - ); - self.aux_segment_ldes.push(aux_segment_lde); - } - - // PUBLIC ACCESSORS - // -------------------------------------------------------------------------------------------- - - /// Returns number of columns in the main segment of the execution trace. - pub fn main_trace_width(&self) -> usize { - self.main_segment_lde.num_cols() - } - - /// Returns number of columns in the auxiliary segments of the execution trace. - pub fn aux_trace_width(&self) -> usize { - self.aux_segment_ldes - .iter() - .fold(0, |s, m| s + m.num_cols()) - } - - /// Returns the number of rows in the execution trace. - pub fn trace_len(&self) -> usize { - self.main_segment_lde.num_rows() - } - - /// Returns blowup factor which was used to extend original execution trace into trace LDE. - pub fn blowup(&self) -> usize { - self.blowup - } - - /// Reads current and next rows from the main trace segment into the specified frame. - pub fn read_main_trace_frame_into( - &self, - lde_step: usize, - frame: &mut EvaluationFrame, - ) { - // at the end of the trace, next state wraps around and we read the first step again - let next_lde_step = (lde_step + self.blowup()) % self.trace_len(); - - // copy main trace segment values into the frame - frame - .current_mut() - .copy_from_slice(self.main_segment_lde.row(lde_step)); - frame - .next_mut() - .copy_from_slice(self.main_segment_lde.row(next_lde_step)); - } - - /// Reads current and next rows from the auxiliary trace segment into the specified frame. - /// - /// # Panics - /// This currently assumes that there is exactly one auxiliary trace segment, and will panic - /// otherwise. - pub fn read_aux_trace_frame_into(&self, lde_step: usize, frame: &mut EvaluationFrame) { - // at the end of the trace, next state wraps around and we read the first step again - let next_lde_step = (lde_step + self.blowup()) % self.trace_len(); - - // copy auxiliary trace segment values into the frame - let segment = &self.aux_segment_ldes[0]; - frame.current_mut().copy_from_slice(segment.row(lde_step)); - frame.next_mut().copy_from_slice(segment.row(next_lde_step)); - } - - /// Returns a reference to [Matrix] representing the main trace segment. - pub fn get_main_segment(&self) -> &RowMatrix { - &self.main_segment_lde - } - - /// Returns a reference to a [Matrix] representing an auxiliary trace segment at the specified - /// index. - pub fn get_aux_segment(&self, aux_segment_idx: usize) -> &RowMatrix { - &self.aux_segment_ldes[aux_segment_idx] - } -} diff --git a/prover/src/trace/trace_lde/mod.rs b/prover/src/trace/trace_lde/mod.rs new file mode 100644 index 000000000..5ff6e6f31 --- /dev/null +++ b/prover/src/trace/trace_lde/mod.rs @@ -0,0 +1,90 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use super::{ColMatrix, EvaluationFrame, FieldElement, StarkField, TracePolyTable}; +use crate::StarkDomain; +use air::{proof::Queries, TraceInfo, TraceLayout}; +use crypto::{ElementHasher, Hasher}; +use utils::collections::Vec; + +// TRACE LOW DEGREE EXTENSION +// ================================================================================================ +/// Contains all segments of the extended execution trace and their commitments. +/// +/// Segments are stored in two groups: +/// - Main segment: this is the first trace segment generated by the prover. Values in this segment +/// will always be elements in the base field (even when an extension field is used). +/// - Auxiliary segments: a list of 0 or more segments for traces generated after the prover +/// commits to the first trace segment. Currently, at most 1 auxiliary segment is possible. +pub trait TraceLde: Sync { + /// The base field, used for computation on the main trace segment. + type BaseField: StarkField; + /// The extension field, used for computation on auxiliary trace segments. + type ExtensionField: FieldElement; + /// The hash function used for building the Merkle tree commitments to trace segment LDEs. + type HashFn: ElementHasher; + + /// Takes the main trace segment columns as input, interpolates them into polynomials in + /// coefficient form, and evaluates the polynomials over the LDE domain. + /// + /// Returns a tuple containing a [TracePolyTable] with the trace polynomials for the main trace + /// and a new [TraceLde] instance from which the LDE and trace commitments can be obtained. + fn new( + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (TracePolyTable, Self); + + /// Returns the commitment to the low-degree extension of the main trace segment. + fn get_main_trace_commitment(&self) -> ::Digest; + + /// Takes auxiliary trace segment columns as input, interpolates them into polynomials in + /// coefficient form, evaluates the polynomials over the LDE domain, and commits to the + /// polynomial evaluations. + /// + /// Returns a tuple containing the column polynomials in coefficient form and the commitment + /// to the polynomial evaluations over the LDE domain. + /// + /// # Panics + /// + /// This function is expected to panic if any of the following are true: + /// - the number of rows in the provided `aux_trace` does not match the main trace. + /// - this segment would exceed the number of segments specified by the trace layout. + fn add_aux_segment( + &mut self, + aux_trace: &ColMatrix, + domain: &StarkDomain, + ) -> ( + ColMatrix, + ::Digest, + ); + + /// Reads current and next rows from the main trace segment into the specified frame. + fn read_main_trace_frame_into( + &self, + lde_step: usize, + frame: &mut EvaluationFrame, + ); + + /// Reads current and next rows from the auxiliary trace segment into the specified frame. + fn read_aux_trace_frame_into( + &self, + lde_step: usize, + frame: &mut EvaluationFrame, + ); + + /// Returns trace table rows at the specified positions along with Merkle authentication paths + /// from the commitment root to these rows. + fn query(&self, positions: &[usize]) -> Vec; + + /// Returns the number of rows in the execution trace. + fn trace_len(&self) -> usize; + + /// Returns blowup factor which was used to extend original execution trace into trace LDE. + fn blowup(&self) -> usize; + + /// Returns the trace layout of the execution trace. + fn trace_layout(&self) -> &TraceLayout; +} From b646439ce2bf7ebc53f4f4d0956742bc769a5035 Mon Sep 17 00:00:00 2001 From: grjte Date: Tue, 4 Jul 2023 16:35:46 +0100 Subject: [PATCH 2/5] feat: add default impl of TraceLde --- prover/src/lib.rs | 49 +---- prover/src/trace/mod.rs | 2 +- prover/src/trace/trace_lde/default.rs | 263 ++++++++++++++++++++++++++ prover/src/trace/trace_lde/mod.rs | 3 + 4 files changed, 268 insertions(+), 49 deletions(-) create mode 100644 prover/src/trace/trace_lde/default.rs diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 6b42bbfe0..efa8b3af2 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -65,7 +65,7 @@ use math::{ }; pub use crypto; -use crypto::{ElementHasher, MerkleTree, RandomCoin}; +use crypto::{ElementHasher, RandomCoin}; #[cfg(feature = "std")] use log::debug; @@ -446,53 +446,6 @@ pub trait Prover { Ok(proof) } - /// Computes a low-degree extension (LDE) of the provided execution trace over the specified - /// domain and build a commitment to the extended trace. - /// - /// The extension is performed by interpolating each column of the execution trace into a - /// polynomial of degree = trace_length - 1, and then evaluating the polynomial over the LDE - /// domain. - /// - /// Trace commitment is computed by hashing each row of the extended execution trace, and then - /// building a Merkle tree from the resulting hashes. - fn build_trace_commitment( - &self, - trace: &ColMatrix, - domain: &StarkDomain, - ) -> (RowMatrix, MerkleTree, ColMatrix) - where - E: FieldElement, - { - // extend the execution trace - #[cfg(feature = "std")] - let now = Instant::now(); - let trace_polys = trace.interpolate_columns(); - let trace_lde = - RowMatrix::evaluate_polys_over::(&trace_polys, domain); - #[cfg(feature = "std")] - debug!( - "Extended execution trace of {} columns from 2^{} to 2^{} steps ({}x blowup) in {} ms", - trace_lde.num_cols(), - trace_polys.num_rows().ilog2(), - trace_lde.num_rows().ilog2(), - domain.trace_to_lde_blowup(), - now.elapsed().as_millis() - ); - - // build trace commitment - #[cfg(feature = "std")] - let now = Instant::now(); - let trace_tree = trace_lde.commit_to_rows(); - #[cfg(feature = "std")] - debug!( - "Computed execution trace commitment (Merkle tree of depth {}) in {} ms", - trace_tree.depth(), - now.elapsed().as_millis() - ); - - (trace_lde, trace_tree, trace_polys) - } - /// Evaluates constraint composition polynomial over the LDE domain and builds a commitment /// to these evaluations. /// diff --git a/prover/src/trace/mod.rs b/prover/src/trace/mod.rs index af3ecdf6e..df835697f 100644 --- a/prover/src/trace/mod.rs +++ b/prover/src/trace/mod.rs @@ -8,7 +8,7 @@ use air::{Air, AuxTraceRandElements, EvaluationFrame, TraceInfo, TraceLayout}; use math::{polynom, FieldElement, StarkField}; mod trace_lde; -pub use trace_lde::TraceLde; +pub use trace_lde::{DefaultTraceLde, TraceLde}; mod poly_table; pub use poly_table::TracePolyTable; diff --git a/prover/src/trace/trace_lde/default.rs b/prover/src/trace/trace_lde/default.rs new file mode 100644 index 000000000..de12c59b1 --- /dev/null +++ b/prover/src/trace/trace_lde/default.rs @@ -0,0 +1,263 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use super::{ + ColMatrix, ElementHasher, EvaluationFrame, FieldElement, Hasher, Queries, StarkDomain, + TraceInfo, TraceLayout, TraceLde, TracePolyTable, Vec, +}; +use crate::{RowMatrix, DEFAULT_SEGMENT_WIDTH}; +use crypto::MerkleTree; + +#[cfg(feature = "std")] +use log::debug; +#[cfg(feature = "std")] +use std::time::Instant; + +// TRACE LOW DEGREE EXTENSION +// ================================================================================================ +/// Contains all segments of the extended execution trace, the commitments to these segments, the +/// LDE blowup factor, and the [TraceInfo]. +/// +/// Segments are stored in two groups: +/// - Main segment: this is the first trace segment generated by the prover. Values in this segment +/// will always be elements in the base field (even when an extension field is used). +/// - Auxiliary segments: a list of 0 or more segments for traces generated after the prover +/// commits to the first trace segment. Currently, at most 1 auxiliary segment is possible. +pub struct DefaultTraceLde> { + // low-degree extension of the main segment of the trace + main_segment_lde: RowMatrix, + // commitment to the main segment of the trace + main_segment_tree: MerkleTree, + // low-degree extensions of the auxiliary segments of the trace + aux_segment_ldes: Vec>, + // commitment to the auxiliary segments of the trace + aux_segment_trees: Vec>, + blowup: usize, + trace_info: TraceInfo, +} + +impl> TraceLde + for DefaultTraceLde +{ + type BaseField = E::BaseField; + type ExtensionField = E; + type HashFn = H; + + /// Takes the main trace segment columns as input, interpolates them into polynomials in + /// coefficient form, evaluates the polynomials over the LDE domain, commits to the + /// polynomial evaluations, and creates a new [DefaultTraceLde] with the LDE of the main trace + /// segment and the commitment. + /// + /// Returns a tuple containing a [TracePolyTable] with the trace polynomials for the main trace + /// segment and the new [DefaultTraceLde]. + fn new( + trace_info: &TraceInfo, + main_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (TracePolyTable, Self) { + // extend the main execution trace and build a Merkle tree from the extended trace + let (main_segment_lde, main_segment_tree, main_segment_polys) = + build_trace_commitment::(main_trace, domain); + + let trace_poly_table = TracePolyTable::new(main_segment_polys); + let trace_lde = DefaultTraceLde { + main_segment_lde, + main_segment_tree, + aux_segment_ldes: Vec::new(), + aux_segment_trees: Vec::new(), + blowup: domain.trace_to_lde_blowup(), + trace_info: trace_info.clone(), + }; + + (trace_poly_table, trace_lde) + } + + /// Returns the commitment to the low-degree extension of the main trace segment. + fn get_main_trace_commitment(&self) -> ::Digest { + let root_hash = self.main_segment_tree.root(); + *root_hash + } + + /// Takes auxiliary trace segment columns as input, interpolates them into polynomials in + /// coefficient form, evaluates the polynomials over the LDE domain, and commits to the + /// polynomial evaluations. + /// + /// Returns a tuple containing the column polynomials in coefficient from and the commitment + /// to the polynomial evaluations over the LDE domain. + /// + /// # Panics + /// + /// This function will panic if any of the following are true: + /// - the number of rows in the provided `aux_trace` does not match the main trace. + /// - this segment would exceed the number of segments specified by the trace layout. + fn add_aux_segment( + &mut self, + aux_trace: &ColMatrix, + domain: &StarkDomain, + ) -> (ColMatrix, ::Digest) { + // extend the auxiliary trace segment and build a Merkle tree from the extended trace + let (aux_segment_lde, aux_segment_tree, aux_segment_polys) = + build_trace_commitment::(aux_trace, domain); + + // check errors + assert!( + self.aux_segment_ldes.len() < self.trace_info.layout().num_aux_segments(), + "the specified number of auxiliary segments has already been added" + ); + assert_eq!( + self.main_segment_lde.num_rows(), + aux_segment_lde.num_rows(), + "the number of rows in the auxiliary segment must be the same as in the main segment" + ); + + // save the lde and commitment + self.aux_segment_ldes.push(aux_segment_lde); + let root_hash = *aux_segment_tree.root(); + self.aux_segment_trees.push(aux_segment_tree); + + (aux_segment_polys, root_hash) + } + + /// Reads current and next rows from the main trace segment into the specified frame. + fn read_main_trace_frame_into( + &self, + lde_step: usize, + frame: &mut EvaluationFrame, + ) { + // at the end of the trace, next state wraps around and we read the first step again + let next_lde_step = (lde_step + self.blowup()) % self.trace_len(); + + // copy main trace segment values into the frame + frame + .current_mut() + .copy_from_slice(self.main_segment_lde.row(lde_step)); + frame + .next_mut() + .copy_from_slice(self.main_segment_lde.row(next_lde_step)); + } + + /// Reads current and next rows from the auxiliary trace segment into the specified frame. + /// + /// # Panics + /// This currently assumes that there is exactly one auxiliary trace segment, and will panic + /// otherwise. + fn read_aux_trace_frame_into(&self, lde_step: usize, frame: &mut EvaluationFrame) { + // at the end of the trace, next state wraps around and we read the first step again + let next_lde_step = (lde_step + self.blowup()) % self.trace_len(); + + // copy auxiliary trace segment values into the frame + let segment = &self.aux_segment_ldes[0]; + frame.current_mut().copy_from_slice(segment.row(lde_step)); + frame.next_mut().copy_from_slice(segment.row(next_lde_step)); + } + + /// Returns trace table rows at the specified positions along with Merkle authentication paths + /// from the commitment root to these rows. + fn query(&self, positions: &[usize]) -> Vec { + // build queries for the main trace segment + let mut result = vec![build_segment_queries( + &self.main_segment_lde, + &self.main_segment_tree, + positions, + )]; + + // build queries for auxiliary trace segments + for (i, segment_tree) in self.aux_segment_trees.iter().enumerate() { + let segment_lde = &self.aux_segment_ldes[i]; + result.push(build_segment_queries(segment_lde, segment_tree, positions)); + } + + result + } + + /// Returns the number of rows in the execution trace. + fn trace_len(&self) -> usize { + self.main_segment_lde.num_rows() + } + + /// Returns blowup factor which was used to extend original execution trace into trace LDE. + fn blowup(&self) -> usize { + self.blowup + } + + /// Returns the trace layout of the execution trace. + fn trace_layout(&self) -> &TraceLayout { + self.trace_info.layout() + } +} + +// HELPER FUNCTIONS +// ================================================================================================ + +/// Computes a low-degree extension (LDE) of the provided execution trace over the specified +/// domain and builds a commitment to the extended trace. +/// +/// The extension is performed by interpolating each column of the execution trace into a +/// polynomial of degree = trace_length - 1, and then evaluating the polynomial over the LDE +/// domain. +/// +/// The trace commitment is computed by hashing each row of the extended execution trace, then +/// building a Merkle tree from the resulting hashes. +fn build_trace_commitment( + trace: &ColMatrix, + domain: &StarkDomain, +) -> (RowMatrix, MerkleTree, ColMatrix) +where + E: FieldElement, + F: FieldElement, + H: ElementHasher, +{ + // extend the execution trace + #[cfg(feature = "std")] + let now = Instant::now(); + let trace_polys = trace.interpolate_columns(); + let trace_lde = RowMatrix::evaluate_polys_over::(&trace_polys, domain); + #[cfg(feature = "std")] + debug!( + "Extended execution trace of {} columns from 2^{} to 2^{} steps ({}x blowup) in {} ms", + trace_lde.num_cols(), + trace_polys.num_rows().ilog2(), + trace_lde.num_rows().ilog2(), + domain.trace_to_lde_blowup(), + now.elapsed().as_millis() + ); + + // build trace commitment + #[cfg(feature = "std")] + let now = Instant::now(); + let trace_tree = trace_lde.commit_to_rows(); + #[cfg(feature = "std")] + debug!( + "Computed execution trace commitment (Merkle tree of depth {}) in {} ms", + trace_tree.depth(), + now.elapsed().as_millis() + ); + + (trace_lde, trace_tree, trace_polys) +} + +fn build_segment_queries( + segment_lde: &RowMatrix, + segment_tree: &MerkleTree, + positions: &[usize], +) -> Queries +where + E: FieldElement, + H: ElementHasher, +{ + // for each position, get the corresponding row from the trace segment LDE and put all these + // rows into a single vector + let trace_states = positions + .iter() + .map(|&pos| segment_lde.row(pos).to_vec()) + .collect::>(); + + // build Merkle authentication paths to the leaves specified by positions + let trace_proof = segment_tree + .prove_batch(positions) + .expect("failed to generate a Merkle proof for trace queries"); + + Queries::new(trace_proof, trace_states) +} diff --git a/prover/src/trace/trace_lde/mod.rs b/prover/src/trace/trace_lde/mod.rs index 5ff6e6f31..cd45ca3c9 100644 --- a/prover/src/trace/trace_lde/mod.rs +++ b/prover/src/trace/trace_lde/mod.rs @@ -9,6 +9,9 @@ use air::{proof::Queries, TraceInfo, TraceLayout}; use crypto::{ElementHasher, Hasher}; use utils::collections::Vec; +mod default; +pub use default::DefaultTraceLde; + // TRACE LOW DEGREE EXTENSION // ================================================================================================ /// Contains all segments of the extended execution trace and their commitments. From 31d6078dd801607bfb88bd09190bde6fbff8fbce Mon Sep 17 00:00:00 2001 From: grjte Date: Thu, 6 Jul 2023 11:01:01 +0100 Subject: [PATCH 3/5] fix: update examples to use default TraceLde impl --- examples/src/fibonacci/fib2/mod.rs | 2 +- examples/src/fibonacci/fib2/prover.rs | 5 +++-- examples/src/fibonacci/fib8/mod.rs | 2 +- examples/src/fibonacci/fib8/prover.rs | 5 +++-- examples/src/fibonacci/fib_small/mod.rs | 2 +- examples/src/fibonacci/fib_small/prover.rs | 6 +++--- examples/src/fibonacci/mulfib2/mod.rs | 2 +- examples/src/fibonacci/mulfib2/prover.rs | 5 +++-- examples/src/fibonacci/mulfib8/mod.rs | 2 +- examples/src/fibonacci/mulfib8/prover.rs | 5 +++-- examples/src/lamport/aggregate/mod.rs | 2 +- examples/src/lamport/aggregate/prover.rs | 7 ++++--- examples/src/lamport/threshold/mod.rs | 2 +- examples/src/lamport/threshold/prover.rs | 8 +++++--- examples/src/merkle/mod.rs | 2 +- examples/src/merkle/prover.rs | 7 ++++--- examples/src/rescue/mod.rs | 2 +- examples/src/rescue/prover.rs | 6 ++++-- examples/src/rescue_raps/mod.rs | 2 +- examples/src/rescue_raps/prover.rs | 5 +++-- examples/src/vdf/exempt/mod.rs | 2 +- examples/src/vdf/exempt/prover.rs | 5 +++-- examples/src/vdf/regular/mod.rs | 2 +- examples/src/vdf/regular/prover.rs | 5 +++-- prover/src/lib.rs | 2 +- winterfell/src/lib.rs | 15 +++++++++------ 26 files changed, 63 insertions(+), 47 deletions(-) diff --git a/examples/src/fibonacci/fib2/mod.rs b/examples/src/fibonacci/fib2/mod.rs index 2fb45c582..3b1cce210 100644 --- a/examples/src/fibonacci/fib2/mod.rs +++ b/examples/src/fibonacci/fib2/mod.rs @@ -11,7 +11,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/fibonacci/fib2/prover.rs b/examples/src/fibonacci/fib2/prover.rs index 5ab260931..d14079758 100644 --- a/examples/src/fibonacci/fib2/prover.rs +++ b/examples/src/fibonacci/fib2/prover.rs @@ -4,8 +4,8 @@ // LICENSE file in the root directory of this source tree. use super::{ - BaseElement, DefaultRandomCoin, ElementHasher, FibAir, FieldElement, PhantomData, ProofOptions, - Prover, Trace, TraceTable, TRACE_WIDTH, + BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FibAir, FieldElement, + PhantomData, ProofOptions, Prover, Trace, TraceTable, TRACE_WIDTH, }; // FIBONACCI PROVER @@ -57,6 +57,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> BaseElement { let last_step = trace.length() - 1; diff --git a/examples/src/fibonacci/fib8/mod.rs b/examples/src/fibonacci/fib8/mod.rs index c11af63fa..9cfdfa052 100644 --- a/examples/src/fibonacci/fib8/mod.rs +++ b/examples/src/fibonacci/fib8/mod.rs @@ -11,7 +11,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/fibonacci/fib8/prover.rs b/examples/src/fibonacci/fib8/prover.rs index d24f054c0..5c07e7c0b 100644 --- a/examples/src/fibonacci/fib8/prover.rs +++ b/examples/src/fibonacci/fib8/prover.rs @@ -4,8 +4,8 @@ // LICENSE file in the root directory of this source tree. use super::{ - BaseElement, DefaultRandomCoin, ElementHasher, Fib8Air, FieldElement, PhantomData, - ProofOptions, Prover, Trace, TraceTable, + BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, Fib8Air, FieldElement, + PhantomData, ProofOptions, Prover, Trace, TraceTable, }; // FIBONACCI PROVER @@ -72,6 +72,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> BaseElement { let last_step = trace.length() - 1; diff --git a/examples/src/fibonacci/fib_small/mod.rs b/examples/src/fibonacci/fib_small/mod.rs index e4807bb63..b5b9d990c 100644 --- a/examples/src/fibonacci/fib_small/mod.rs +++ b/examples/src/fibonacci/fib_small/mod.rs @@ -11,7 +11,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f64::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/fibonacci/fib_small/prover.rs b/examples/src/fibonacci/fib_small/prover.rs index f7d4f82d7..12df3c2eb 100644 --- a/examples/src/fibonacci/fib_small/prover.rs +++ b/examples/src/fibonacci/fib_small/prover.rs @@ -2,10 +2,9 @@ // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. - use super::{ - air::FibSmall, BaseElement, DefaultRandomCoin, ElementHasher, FieldElement, PhantomData, - ProofOptions, Prover, Trace, TraceTable, TRACE_WIDTH, + air::FibSmall, BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FieldElement, + PhantomData, ProofOptions, Prover, Trace, TraceTable, TRACE_WIDTH, }; // FIBONACCI PROVER @@ -57,6 +56,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> BaseElement { let last_step = trace.length() - 1; diff --git a/examples/src/fibonacci/mulfib2/mod.rs b/examples/src/fibonacci/mulfib2/mod.rs index 6aa973f56..94e219454 100644 --- a/examples/src/fibonacci/mulfib2/mod.rs +++ b/examples/src/fibonacci/mulfib2/mod.rs @@ -11,7 +11,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/fibonacci/mulfib2/prover.rs b/examples/src/fibonacci/mulfib2/prover.rs index dc84c3b41..72e7fb74c 100644 --- a/examples/src/fibonacci/mulfib2/prover.rs +++ b/examples/src/fibonacci/mulfib2/prover.rs @@ -4,8 +4,8 @@ // LICENSE file in the root directory of this source tree. use super::{ - BaseElement, DefaultRandomCoin, ElementHasher, MulFib2Air, PhantomData, ProofOptions, Prover, - Trace, TraceTable, + BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FieldElement, MulFib2Air, + PhantomData, ProofOptions, Prover, Trace, TraceTable, }; // FIBONACCI PROVER @@ -53,6 +53,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> BaseElement { let last_step = trace.length() - 1; diff --git a/examples/src/fibonacci/mulfib8/mod.rs b/examples/src/fibonacci/mulfib8/mod.rs index 868ac1bb4..1274538b4 100644 --- a/examples/src/fibonacci/mulfib8/mod.rs +++ b/examples/src/fibonacci/mulfib8/mod.rs @@ -11,7 +11,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/fibonacci/mulfib8/prover.rs b/examples/src/fibonacci/mulfib8/prover.rs index a45300b02..9fe4b9852 100644 --- a/examples/src/fibonacci/mulfib8/prover.rs +++ b/examples/src/fibonacci/mulfib8/prover.rs @@ -4,8 +4,8 @@ // LICENSE file in the root directory of this source tree. use super::{ - BaseElement, DefaultRandomCoin, ElementHasher, MulFib8Air, PhantomData, ProofOptions, Prover, - Trace, TraceTable, + BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FieldElement, MulFib8Air, + PhantomData, ProofOptions, Prover, Trace, TraceTable, }; // FIBONACCI PROVER @@ -65,6 +65,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> BaseElement { let last_step = trace.length() - 1; diff --git a/examples/src/lamport/aggregate/mod.rs b/examples/src/lamport/aggregate/mod.rs index b4e934273..d81191fc8 100644 --- a/examples/src/lamport/aggregate/mod.rs +++ b/examples/src/lamport/aggregate/mod.rs @@ -13,7 +13,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, get_power_series, FieldElement, StarkField}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/lamport/aggregate/prover.rs b/examples/src/lamport/aggregate/prover.rs index c4d59c37f..49a3c0d48 100644 --- a/examples/src/lamport/aggregate/prover.rs +++ b/examples/src/lamport/aggregate/prover.rs @@ -4,9 +4,9 @@ // LICENSE file in the root directory of this source tree. use super::{ - get_power_series, rescue, BaseElement, DefaultRandomCoin, ElementHasher, FieldElement, - LamportAggregateAir, PhantomData, ProofOptions, Prover, PublicInputs, Signature, StarkField, - TraceTable, CYCLE_LENGTH, NUM_HASH_ROUNDS, SIG_CYCLE_LENGTH, TRACE_WIDTH, + get_power_series, rescue, BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, + FieldElement, LamportAggregateAir, PhantomData, ProofOptions, Prover, PublicInputs, Signature, + StarkField, TraceTable, CYCLE_LENGTH, NUM_HASH_ROUNDS, SIG_CYCLE_LENGTH, TRACE_WIDTH, }; #[cfg(feature = "concurrent")] @@ -97,6 +97,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, _trace: &Self::Trace) -> PublicInputs { self.pub_inputs.clone() diff --git a/examples/src/lamport/threshold/mod.rs b/examples/src/lamport/threshold/mod.rs index fe650f14d..84ffcab59 100644 --- a/examples/src/lamport/threshold/mod.rs +++ b/examples/src/lamport/threshold/mod.rs @@ -14,7 +14,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, get_power_series, FieldElement, StarkField}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod signature; diff --git a/examples/src/lamport/threshold/prover.rs b/examples/src/lamport/threshold/prover.rs index 1afaecc08..88ce906cb 100644 --- a/examples/src/lamport/threshold/prover.rs +++ b/examples/src/lamport/threshold/prover.rs @@ -4,9 +4,10 @@ // LICENSE file in the root directory of this source tree. use super::{ - get_power_series, rescue, AggPublicKey, BaseElement, DefaultRandomCoin, ElementHasher, - FieldElement, LamportThresholdAir, PhantomData, ProofOptions, Prover, PublicInputs, Signature, - StarkField, TraceTable, HASH_CYCLE_LENGTH, NUM_HASH_ROUNDS, SIG_CYCLE_LENGTH, TRACE_WIDTH, + get_power_series, rescue, AggPublicKey, BaseElement, DefaultRandomCoin, DefaultTraceLde, + ElementHasher, FieldElement, LamportThresholdAir, PhantomData, ProofOptions, Prover, + PublicInputs, Signature, StarkField, TraceTable, HASH_CYCLE_LENGTH, NUM_HASH_ROUNDS, + SIG_CYCLE_LENGTH, TRACE_WIDTH, }; use std::collections::HashMap; @@ -138,6 +139,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, _trace: &Self::Trace) -> PublicInputs { self.pub_inputs.clone() diff --git a/examples/src/merkle/mod.rs b/examples/src/merkle/mod.rs index 1e4a671a5..384363bfc 100644 --- a/examples/src/merkle/mod.rs +++ b/examples/src/merkle/mod.rs @@ -18,7 +18,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, Digest, ElementHasher, MerkleTree}, math::{fields::f128::BaseElement, FieldElement, StarkField}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/merkle/prover.rs b/examples/src/merkle/prover.rs index 135fd4eb5..86b0874d0 100644 --- a/examples/src/merkle/prover.rs +++ b/examples/src/merkle/prover.rs @@ -4,9 +4,9 @@ // LICENSE file in the root directory of this source tree. use super::{ - rescue, BaseElement, DefaultRandomCoin, ElementHasher, FieldElement, MerkleAir, PhantomData, - ProofOptions, Prover, PublicInputs, Trace, TraceTable, HASH_CYCLE_LEN, HASH_STATE_WIDTH, - NUM_HASH_ROUNDS, TRACE_WIDTH, + rescue, BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FieldElement, + MerkleAir, PhantomData, ProofOptions, Prover, PublicInputs, Trace, TraceTable, HASH_CYCLE_LEN, + HASH_STATE_WIDTH, NUM_HASH_ROUNDS, TRACE_WIDTH, }; // MERKLE PROVER @@ -103,6 +103,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs { let last_step = trace.length() - 1; diff --git a/examples/src/rescue/mod.rs b/examples/src/rescue/mod.rs index 5333136de..e704304b6 100644 --- a/examples/src/rescue/mod.rs +++ b/examples/src/rescue/mod.rs @@ -10,7 +10,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; #[allow(clippy::module_inception)] diff --git a/examples/src/rescue/prover.rs b/examples/src/rescue/prover.rs index 88525c9d8..29acbdad2 100644 --- a/examples/src/rescue/prover.rs +++ b/examples/src/rescue/prover.rs @@ -4,8 +4,9 @@ // LICENSE file in the root directory of this source tree. use super::{ - rescue, BaseElement, DefaultRandomCoin, ElementHasher, FieldElement, PhantomData, ProofOptions, - Prover, PublicInputs, RescueAir, Trace, TraceTable, CYCLE_LENGTH, NUM_HASH_ROUNDS, + rescue, BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FieldElement, + PhantomData, ProofOptions, Prover, PublicInputs, RescueAir, Trace, TraceTable, CYCLE_LENGTH, + NUM_HASH_ROUNDS, }; // RESCUE PROVER @@ -69,6 +70,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs { let last_step = trace.length() - 1; diff --git a/examples/src/rescue_raps/mod.rs b/examples/src/rescue_raps/mod.rs index b695113d8..4152d2bac 100644 --- a/examples/src/rescue_raps/mod.rs +++ b/examples/src/rescue_raps/mod.rs @@ -11,7 +11,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, ExtensionOf, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, VerifierError, }; mod custom_trace_table; diff --git a/examples/src/rescue_raps/prover.rs b/examples/src/rescue_raps/prover.rs index 57bb9f249..e6ab7c0eb 100644 --- a/examples/src/rescue_raps/prover.rs +++ b/examples/src/rescue_raps/prover.rs @@ -5,8 +5,8 @@ use super::{ apply_rescue_round_parallel, rescue::STATE_WIDTH, BaseElement, DefaultRandomCoin, - ElementHasher, FieldElement, PhantomData, ProofOptions, Prover, PublicInputs, RapTraceTable, - RescueRapsAir, Trace, CYCLE_LENGTH, NUM_HASH_ROUNDS, + DefaultTraceLde, ElementHasher, FieldElement, PhantomData, ProofOptions, Prover, PublicInputs, + RapTraceTable, RescueRapsAir, Trace, CYCLE_LENGTH, NUM_HASH_ROUNDS, }; // RESCUE PROVER @@ -98,6 +98,7 @@ where type Trace = RapTraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs { let last_step = trace.length() - 1; diff --git a/examples/src/vdf/exempt/mod.rs b/examples/src/vdf/exempt/mod.rs index fa3af0f5d..d19dab8fd 100644 --- a/examples/src/vdf/exempt/mod.rs +++ b/examples/src/vdf/exempt/mod.rs @@ -10,7 +10,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/vdf/exempt/prover.rs b/examples/src/vdf/exempt/prover.rs index c6970e5f7..bcd80e7c9 100644 --- a/examples/src/vdf/exempt/prover.rs +++ b/examples/src/vdf/exempt/prover.rs @@ -4,8 +4,8 @@ // LICENSE file in the root directory of this source tree. use super::{ - BaseElement, DefaultRandomCoin, ElementHasher, FieldElement, PhantomData, ProofOptions, Prover, - Trace, TraceTable, VdfAir, VdfInputs, FORTY_TWO, INV_ALPHA, + BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FieldElement, PhantomData, + ProofOptions, Prover, Trace, TraceTable, VdfAir, VdfInputs, FORTY_TWO, INV_ALPHA, }; // VDF PROVER @@ -50,6 +50,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> VdfInputs { // the result is read from the second to last step because the last last step contains diff --git a/examples/src/vdf/regular/mod.rs b/examples/src/vdf/regular/mod.rs index 2d0f7d7a1..12875afa6 100644 --- a/examples/src/vdf/regular/mod.rs +++ b/examples/src/vdf/regular/mod.rs @@ -10,7 +10,7 @@ use std::time::Instant; use winterfell::{ crypto::{DefaultRandomCoin, ElementHasher}, math::{fields::f128::BaseElement, FieldElement}, - ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, + DefaultTraceLde, ProofOptions, Prover, StarkProof, Trace, TraceTable, VerifierError, }; mod air; diff --git a/examples/src/vdf/regular/prover.rs b/examples/src/vdf/regular/prover.rs index d7b5504ff..99462a753 100644 --- a/examples/src/vdf/regular/prover.rs +++ b/examples/src/vdf/regular/prover.rs @@ -4,8 +4,8 @@ // LICENSE file in the root directory of this source tree. use super::{ - BaseElement, DefaultRandomCoin, ElementHasher, FieldElement, PhantomData, ProofOptions, Prover, - Trace, TraceTable, VdfAir, VdfInputs, FORTY_TWO, INV_ALPHA, + BaseElement, DefaultRandomCoin, DefaultTraceLde, ElementHasher, FieldElement, PhantomData, + ProofOptions, Prover, Trace, TraceTable, VdfAir, VdfInputs, FORTY_TWO, INV_ALPHA, }; // VDF PROVER @@ -47,6 +47,7 @@ where type Trace = TraceTable; type HashFn = H; type RandomCoin = DefaultRandomCoin; + type TraceLde> = DefaultTraceLde; fn get_pub_inputs(&self, trace: &Self::Trace) -> VdfInputs { let last_step = trace.length() - 1; diff --git a/prover/src/lib.rs b/prover/src/lib.rs index efa8b3af2..c3ce3c3bf 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -86,7 +86,7 @@ mod composer; use composer::DeepCompositionPoly; mod trace; -pub use trace::{Trace, TraceLde, TracePolyTable, TraceTable, TraceTableFragment}; +pub use trace::{DefaultTraceLde, Trace, TraceLde, TracePolyTable, TraceTable, TraceTableFragment}; mod channel; use channel::ProverChannel; diff --git a/winterfell/src/lib.rs b/winterfell/src/lib.rs index b3eeb9f40..d580e5510 100644 --- a/winterfell/src/lib.rs +++ b/winterfell/src/lib.rs @@ -255,8 +255,9 @@ //! //! ```no_run //! use winterfell::{ +//! crypto::{hashers::Blake3_256, DefaultRandomCoin}, //! math::{fields::f128::BaseElement, FieldElement, ToElements}, -//! ProofOptions, Prover, Trace, TraceTable, crypto::{hashers::Blake3_256, DefaultRandomCoin} +//! DefaultTraceLde, ProofOptions, Prover, Trace, TraceTable, //! }; //! //! # use winterfell::{ @@ -340,6 +341,7 @@ //! type Trace = TraceTable; //! type HashFn = Blake3_256; //! type RandomCoin = DefaultRandomCoin; +//! type TraceLde> = DefaultTraceLde; //! //! // Our public inputs consist of the first and last value in the execution trace. //! fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs { @@ -366,7 +368,7 @@ //! ``` //! # use winterfell::{ //! # math::{fields::f128::BaseElement, FieldElement, ToElements}, -//! # Air, AirContext, Assertion, ByteWriter, EvaluationFrame, TraceInfo, +//! # Air, AirContext, Assertion, ByteWriter, DefaultTraceLde, EvaluationFrame, TraceInfo, //! # TransitionConstraintDegree, TraceTable, FieldExtension, Prover, ProofOptions, //! # StarkProof, Trace, crypto::{hashers::Blake3_256, DefaultRandomCoin}, //! # }; @@ -457,6 +459,7 @@ //! # type Trace = TraceTable; //! # type HashFn = Blake3_256; //! # type RandomCoin = DefaultRandomCoin; +//! # type TraceLde> = DefaultTraceLde; //! # //! # fn get_pub_inputs(&self, trace: &Self::Trace) -> PublicInputs { //! # let last_step = trace.length() - 1; @@ -531,9 +534,9 @@ pub use prover::{ crypto, iterators, math, Air, AirContext, Assertion, AuxTraceRandElements, BoundaryConstraint, BoundaryConstraintGroup, ByteReader, ByteWriter, ColMatrix, ConstraintCompositionCoefficients, - ConstraintDivisor, DeepCompositionCoefficients, Deserializable, DeserializationError, - EvaluationFrame, FieldExtension, ProofOptions, Prover, ProverError, Serializable, SliceReader, - StarkProof, Trace, TraceInfo, TraceLayout, TraceTable, TraceTableFragment, - TransitionConstraintDegree, + ConstraintDivisor, DeepCompositionCoefficients, DefaultTraceLde, Deserializable, + DeserializationError, EvaluationFrame, FieldExtension, ProofOptions, Prover, ProverError, + Serializable, SliceReader, StarkProof, Trace, TraceInfo, TraceLayout, TraceLde, TraceTable, + TraceTableFragment, TransitionConstraintDegree, }; pub use verifier::{verify, VerifierError}; From 68c378dbabca9484880ae2b783dd16cba227b56c Mon Sep 17 00:00:00 2001 From: grjte Date: Fri, 21 Jul 2023 15:15:05 +0200 Subject: [PATCH 4/5] tests: add unit tests for DefaultTraceLde --- .../trace_lde/{default.rs => default/mod.rs} | 27 +++++ prover/src/trace/trace_lde/default/tests.rs | 107 ++++++++++++++++++ 2 files changed, 134 insertions(+) rename prover/src/trace/trace_lde/{default.rs => default/mod.rs} (91%) create mode 100644 prover/src/trace/trace_lde/default/tests.rs diff --git a/prover/src/trace/trace_lde/default.rs b/prover/src/trace/trace_lde/default/mod.rs similarity index 91% rename from prover/src/trace/trace_lde/default.rs rename to prover/src/trace/trace_lde/default/mod.rs index de12c59b1..7bde3dfdd 100644 --- a/prover/src/trace/trace_lde/default.rs +++ b/prover/src/trace/trace_lde/default/mod.rs @@ -15,6 +15,9 @@ use log::debug; #[cfg(feature = "std")] use std::time::Instant; +#[cfg(test)] +mod tests; + // TRACE LOW DEGREE EXTENSION // ================================================================================================ /// Contains all segments of the extended execution trace, the commitments to these segments, the @@ -38,6 +41,30 @@ pub struct DefaultTraceLde> DefaultTraceLde { + // TEST HELPERS + // -------------------------------------------------------------------------------------------- + + /// Returns number of columns in the main segment of the execution trace. + pub fn main_segment_width(&self) -> usize { + self.main_segment_lde.num_cols() + } + + /// Returns a reference to [Matrix] representing the main trace segment. + pub fn get_main_segment(&self) -> &RowMatrix { + &self.main_segment_lde + } + + /// Returns the entire trace for the column at the specified index. + #[cfg(test)] + pub fn get_main_segment_column(&self, col_idx: usize) -> Vec { + (0..self.main_segment_lde.num_rows()) + .map(|row_idx| self.main_segment_lde.get(col_idx, row_idx)) + .collect() + } +} + impl> TraceLde for DefaultTraceLde { diff --git a/prover/src/trace/trace_lde/default/tests.rs b/prover/src/trace/trace_lde/default/tests.rs new file mode 100644 index 000000000..ce53e1cc3 --- /dev/null +++ b/prover/src/trace/trace_lde/default/tests.rs @@ -0,0 +1,107 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use crate::{ + tests::{build_fib_trace, MockAir}, + DefaultTraceLde, StarkDomain, Trace, TraceLde, +}; +use crypto::{hashers::Blake3_256, ElementHasher, MerkleTree}; +use math::{ + fields::f128::BaseElement, get_power_series, get_power_series_with_offset, polynom, + FieldElement, StarkField, +}; +use utils::collections::Vec; + +type Blake3 = Blake3_256; + +#[test] +fn extend_trace_table() { + // build the trace and the domain + let trace_length = 8; + let air = MockAir::with_trace_length(trace_length); + let trace = build_fib_trace(trace_length * 2); + let domain = StarkDomain::new(&air); + + // build the trace polynomials, extended trace, and commitment using the default TraceLde impl + let (trace_polys, trace_lde) = DefaultTraceLde::::new( + &trace.get_info(), + trace.main_segment(), + &domain, + ); + + // check the width and length of the extended trace + assert_eq!(2, trace_lde.main_segment_width()); + assert_eq!(64, trace_lde.trace_len()); + + // make sure trace polynomials evaluate to Fibonacci trace + let trace_root = BaseElement::get_root_of_unity(trace_length.ilog2()); + let trace_domain = get_power_series(trace_root, trace_length); + assert_eq!(2, trace_polys.num_main_trace_polys()); + assert_eq!( + vec![1u32, 2, 5, 13, 34, 89, 233, 610] + .into_iter() + .map(BaseElement::from) + .collect::>(), + polynom::eval_many(trace_polys.get_main_trace_poly(0), &trace_domain) + ); + assert_eq!( + vec![1u32, 3, 8, 21, 55, 144, 377, 987] + .into_iter() + .map(BaseElement::from) + .collect::>(), + polynom::eval_many(trace_polys.get_main_trace_poly(1), &trace_domain) + ); + + // make sure column values are consistent with trace polynomials + let lde_domain = build_lde_domain(domain.lde_domain_size()); + assert_eq!( + trace_polys.get_main_trace_poly(0), + polynom::interpolate(&lde_domain, &trace_lde.get_main_segment_column(0), true) + ); + assert_eq!( + trace_polys.get_main_trace_poly(1), + polynom::interpolate(&lde_domain, &trace_lde.get_main_segment_column(1), true) + ); +} + +#[test] +fn commit_trace_table() { + // build the trace and the domain + let trace_length = 8; + let air = MockAir::with_trace_length(trace_length); + let trace = build_fib_trace(trace_length * 2); + let domain = StarkDomain::new(&air); + + // build the trace polynomials, extended trace, and commitment using the default TraceLde impl + let (_, trace_lde) = DefaultTraceLde::::new( + &trace.get_info(), + trace.main_segment(), + &domain, + ); + + // build Merkle tree from trace rows + let mut hashed_states = Vec::new(); + let mut trace_state = vec![BaseElement::ZERO; trace_lde.main_segment_width()]; + #[allow(clippy::needless_range_loop)] + for i in 0..trace_lde.trace_len() { + for j in 0..trace_lde.main_segment_width() { + trace_state[j] = trace_lde.get_main_segment().get(j, i); + } + let buf = Blake3::hash_elements(&trace_state); + hashed_states.push(buf); + } + let expected_tree = MerkleTree::::new(hashed_states).unwrap(); + + // compare the result + assert_eq!(*expected_tree.root(), trace_lde.get_main_trace_commitment()) +} + +// HELPER FUNCTIONS +// ================================================================================================ + +fn build_lde_domain(domain_size: usize) -> Vec { + let g = B::get_root_of_unity(domain_size.ilog2()); + get_power_series_with_offset(g, B::GENERATOR, domain_size) +} From 3eb276f728b4db34724c03051cfa2b29a0b069fa Mon Sep 17 00:00:00 2001 From: grjte Date: Mon, 24 Jul 2023 17:36:42 +0100 Subject: [PATCH 5/5] docs: update changelog with TraceLde trait --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d924dfd12..79e456e79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## 0.7.0 (TBD) +* [BREAKING] replaced the `TraceLde` struct with a trait (#207). + ## 0.6.4 (2023-05-26) * Simplified construction of constraint composition polynomial (#198). * Refactored serialization of OOD frame in STARK proofs (#199).