diff --git a/prover/src/trace/trace_lde/default.rs b/prover/src/trace/trace_lde/default/mod.rs similarity index 91% rename from prover/src/trace/trace_lde/default.rs rename to prover/src/trace/trace_lde/default/mod.rs index de12c59b1..7bde3dfdd 100644 --- a/prover/src/trace/trace_lde/default.rs +++ b/prover/src/trace/trace_lde/default/mod.rs @@ -15,6 +15,9 @@ use log::debug; #[cfg(feature = "std")] use std::time::Instant; +#[cfg(test)] +mod tests; + // TRACE LOW DEGREE EXTENSION // ================================================================================================ /// Contains all segments of the extended execution trace, the commitments to these segments, the @@ -38,6 +41,30 @@ pub struct DefaultTraceLde> DefaultTraceLde { + // TEST HELPERS + // -------------------------------------------------------------------------------------------- + + /// Returns number of columns in the main segment of the execution trace. + pub fn main_segment_width(&self) -> usize { + self.main_segment_lde.num_cols() + } + + /// Returns a reference to [Matrix] representing the main trace segment. + pub fn get_main_segment(&self) -> &RowMatrix { + &self.main_segment_lde + } + + /// Returns the entire trace for the column at the specified index. + #[cfg(test)] + pub fn get_main_segment_column(&self, col_idx: usize) -> Vec { + (0..self.main_segment_lde.num_rows()) + .map(|row_idx| self.main_segment_lde.get(col_idx, row_idx)) + .collect() + } +} + impl> TraceLde for DefaultTraceLde { diff --git a/prover/src/trace/trace_lde/default/tests.rs b/prover/src/trace/trace_lde/default/tests.rs new file mode 100644 index 000000000..ce53e1cc3 --- /dev/null +++ b/prover/src/trace/trace_lde/default/tests.rs @@ -0,0 +1,107 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use crate::{ + tests::{build_fib_trace, MockAir}, + DefaultTraceLde, StarkDomain, Trace, TraceLde, +}; +use crypto::{hashers::Blake3_256, ElementHasher, MerkleTree}; +use math::{ + fields::f128::BaseElement, get_power_series, get_power_series_with_offset, polynom, + FieldElement, StarkField, +}; +use utils::collections::Vec; + +type Blake3 = Blake3_256; + +#[test] +fn extend_trace_table() { + // build the trace and the domain + let trace_length = 8; + let air = MockAir::with_trace_length(trace_length); + let trace = build_fib_trace(trace_length * 2); + let domain = StarkDomain::new(&air); + + // build the trace polynomials, extended trace, and commitment using the default TraceLde impl + let (trace_polys, trace_lde) = DefaultTraceLde::::new( + &trace.get_info(), + trace.main_segment(), + &domain, + ); + + // check the width and length of the extended trace + assert_eq!(2, trace_lde.main_segment_width()); + assert_eq!(64, trace_lde.trace_len()); + + // make sure trace polynomials evaluate to Fibonacci trace + let trace_root = BaseElement::get_root_of_unity(trace_length.ilog2()); + let trace_domain = get_power_series(trace_root, trace_length); + assert_eq!(2, trace_polys.num_main_trace_polys()); + assert_eq!( + vec![1u32, 2, 5, 13, 34, 89, 233, 610] + .into_iter() + .map(BaseElement::from) + .collect::>(), + polynom::eval_many(trace_polys.get_main_trace_poly(0), &trace_domain) + ); + assert_eq!( + vec![1u32, 3, 8, 21, 55, 144, 377, 987] + .into_iter() + .map(BaseElement::from) + .collect::>(), + polynom::eval_many(trace_polys.get_main_trace_poly(1), &trace_domain) + ); + + // make sure column values are consistent with trace polynomials + let lde_domain = build_lde_domain(domain.lde_domain_size()); + assert_eq!( + trace_polys.get_main_trace_poly(0), + polynom::interpolate(&lde_domain, &trace_lde.get_main_segment_column(0), true) + ); + assert_eq!( + trace_polys.get_main_trace_poly(1), + polynom::interpolate(&lde_domain, &trace_lde.get_main_segment_column(1), true) + ); +} + +#[test] +fn commit_trace_table() { + // build the trace and the domain + let trace_length = 8; + let air = MockAir::with_trace_length(trace_length); + let trace = build_fib_trace(trace_length * 2); + let domain = StarkDomain::new(&air); + + // build the trace polynomials, extended trace, and commitment using the default TraceLde impl + let (_, trace_lde) = DefaultTraceLde::::new( + &trace.get_info(), + trace.main_segment(), + &domain, + ); + + // build Merkle tree from trace rows + let mut hashed_states = Vec::new(); + let mut trace_state = vec![BaseElement::ZERO; trace_lde.main_segment_width()]; + #[allow(clippy::needless_range_loop)] + for i in 0..trace_lde.trace_len() { + for j in 0..trace_lde.main_segment_width() { + trace_state[j] = trace_lde.get_main_segment().get(j, i); + } + let buf = Blake3::hash_elements(&trace_state); + hashed_states.push(buf); + } + let expected_tree = MerkleTree::::new(hashed_states).unwrap(); + + // compare the result + assert_eq!(*expected_tree.root(), trace_lde.get_main_trace_commitment()) +} + +// HELPER FUNCTIONS +// ================================================================================================ + +fn build_lde_domain(domain_size: usize) -> Vec { + let g = B::get_root_of_unity(domain_size.ilog2()); + get_power_series_with_offset(g, B::GENERATOR, domain_size) +}