Skip to content

Commit

Permalink
tests: add unit tests for DefaultTraceLde
Browse files Browse the repository at this point in the history
  • Loading branch information
grjte committed Jul 24, 2023
1 parent 31d6078 commit 68c378d
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 0 deletions.
Expand Up @@ -15,6 +15,9 @@ use log::debug;
#[cfg(feature = "std")]
use std::time::Instant;

#[cfg(test)]
mod tests;

// TRACE LOW DEGREE EXTENSION
// ================================================================================================
/// Contains all segments of the extended execution trace, the commitments to these segments, the
Expand All @@ -38,6 +41,30 @@ pub struct DefaultTraceLde<E: FieldElement, H: ElementHasher<BaseField = E::Base
trace_info: TraceInfo,
}

#[cfg(test)]
impl<E: FieldElement, H: ElementHasher<BaseField = E::BaseField>> DefaultTraceLde<E, H> {
// TEST HELPERS
// --------------------------------------------------------------------------------------------

/// Returns number of columns in the main segment of the execution trace.
pub fn main_segment_width(&self) -> usize {
self.main_segment_lde.num_cols()
}

/// Returns a reference to [Matrix] representing the main trace segment.
pub fn get_main_segment(&self) -> &RowMatrix<E::BaseField> {
&self.main_segment_lde
}

/// Returns the entire trace for the column at the specified index.
#[cfg(test)]
pub fn get_main_segment_column(&self, col_idx: usize) -> Vec<E::BaseField> {
(0..self.main_segment_lde.num_rows())
.map(|row_idx| self.main_segment_lde.get(col_idx, row_idx))
.collect()
}
}

impl<E: FieldElement, H: ElementHasher<BaseField = E::BaseField>> TraceLde
for DefaultTraceLde<E, H>
{
Expand Down
107 changes: 107 additions & 0 deletions prover/src/trace/trace_lde/default/tests.rs
@@ -0,0 +1,107 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

use crate::{
tests::{build_fib_trace, MockAir},
DefaultTraceLde, StarkDomain, Trace, TraceLde,
};
use crypto::{hashers::Blake3_256, ElementHasher, MerkleTree};
use math::{
fields::f128::BaseElement, get_power_series, get_power_series_with_offset, polynom,
FieldElement, StarkField,
};
use utils::collections::Vec;

type Blake3 = Blake3_256<BaseElement>;

#[test]
fn extend_trace_table() {
// build the trace and the domain
let trace_length = 8;
let air = MockAir::with_trace_length(trace_length);
let trace = build_fib_trace(trace_length * 2);
let domain = StarkDomain::new(&air);

// build the trace polynomials, extended trace, and commitment using the default TraceLde impl
let (trace_polys, trace_lde) = DefaultTraceLde::<BaseElement, Blake3>::new(
&trace.get_info(),
trace.main_segment(),
&domain,
);

// check the width and length of the extended trace
assert_eq!(2, trace_lde.main_segment_width());
assert_eq!(64, trace_lde.trace_len());

// make sure trace polynomials evaluate to Fibonacci trace
let trace_root = BaseElement::get_root_of_unity(trace_length.ilog2());
let trace_domain = get_power_series(trace_root, trace_length);
assert_eq!(2, trace_polys.num_main_trace_polys());
assert_eq!(
vec![1u32, 2, 5, 13, 34, 89, 233, 610]
.into_iter()
.map(BaseElement::from)
.collect::<Vec<BaseElement>>(),
polynom::eval_many(trace_polys.get_main_trace_poly(0), &trace_domain)
);
assert_eq!(
vec![1u32, 3, 8, 21, 55, 144, 377, 987]
.into_iter()
.map(BaseElement::from)
.collect::<Vec<BaseElement>>(),
polynom::eval_many(trace_polys.get_main_trace_poly(1), &trace_domain)
);

// make sure column values are consistent with trace polynomials
let lde_domain = build_lde_domain(domain.lde_domain_size());
assert_eq!(
trace_polys.get_main_trace_poly(0),
polynom::interpolate(&lde_domain, &trace_lde.get_main_segment_column(0), true)
);
assert_eq!(
trace_polys.get_main_trace_poly(1),
polynom::interpolate(&lde_domain, &trace_lde.get_main_segment_column(1), true)
);
}

#[test]
fn commit_trace_table() {
// build the trace and the domain
let trace_length = 8;
let air = MockAir::with_trace_length(trace_length);
let trace = build_fib_trace(trace_length * 2);
let domain = StarkDomain::new(&air);

// build the trace polynomials, extended trace, and commitment using the default TraceLde impl
let (_, trace_lde) = DefaultTraceLde::<BaseElement, Blake3>::new(
&trace.get_info(),
trace.main_segment(),
&domain,
);

// build Merkle tree from trace rows
let mut hashed_states = Vec::new();
let mut trace_state = vec![BaseElement::ZERO; trace_lde.main_segment_width()];
#[allow(clippy::needless_range_loop)]
for i in 0..trace_lde.trace_len() {
for j in 0..trace_lde.main_segment_width() {
trace_state[j] = trace_lde.get_main_segment().get(j, i);
}
let buf = Blake3::hash_elements(&trace_state);
hashed_states.push(buf);
}
let expected_tree = MerkleTree::<Blake3>::new(hashed_states).unwrap();

// compare the result
assert_eq!(*expected_tree.root(), trace_lde.get_main_trace_commitment())
}

// HELPER FUNCTIONS
// ================================================================================================

fn build_lde_domain<B: StarkField>(domain_size: usize) -> Vec<B> {
let g = B::get_root_of_unity(domain_size.ilog2());
get_power_series_with_offset(g, B::GENERATOR, domain_size)
}

0 comments on commit 68c378d

Please sign in to comment.