Skip to content

Commit

Permalink
Refactor TraceOodFrame (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
plafer authored and irakliyk committed May 9, 2024
1 parent 00e47cd commit 27f2574
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 201 deletions.
2 changes: 1 addition & 1 deletion air/src/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ mod queries;
pub use queries::Queries;

mod ood_frame;
pub use ood_frame::{OodFrame, OodFrameTraceStates, ParsedOodFrame};
pub use ood_frame::{OodFrame, TraceOodFrame};

mod table;
pub use table::Table;
Expand Down
202 changes: 134 additions & 68 deletions air/src/proof/ood_frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,17 @@
// LICENSE file in the root directory of this source tree.

use alloc::vec::Vec;
use crypto::ElementHasher;
use math::FieldElement;
use utils::{
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable, SliceReader,
};

use crate::LagrangeKernelEvaluationFrame;
use crate::{EvaluationFrame, LagrangeKernelEvaluationFrame};

// OUT-OF-DOMAIN FRAME
// ================================================================================================

/// Represents an [`OodFrame`] where the trace and constraint evaluations have been parsed out.
pub struct ParsedOodFrame<E> {
pub trace_evaluations: Vec<E>,
pub lagrange_kernel_trace_evaluations: Option<Vec<E>>,
pub constraint_evaluations: Vec<E>,
}

/// Trace and constraint polynomial evaluations at an out-of-domain point.
///
/// This struct contains the following evaluations:
Expand All @@ -45,49 +39,53 @@ impl OodFrame {
// UPDATERS
// --------------------------------------------------------------------------------------------

/// Updates the trace state portion of this out-of-domain frame. This also returns a compacted
/// version of the out-of-domain frame (including the Lagrange kernel frame, if any) with the
/// rows interleaved. This is done so that reseeding of the random coin needs to be done only
/// once as opposed to once per each row.
/// Updates the trace state portion of this out-of-domain frame, and returns the hash of the
/// trace states.
///
/// The out-of-domain frame is stored as one vector of interleaved values, one from the current
/// row and the other from the next row. Given the input frame
///
/// +-------+-------+-------+-------+-------+-------+-------+-------+
/// | a1 | a2 | ... | an | c1 | c2 | ... | cm |
/// +-------+-------+-------+-------+-------+-------+-------+-------+
/// | b1 | b2 | ... | bn | d1 | d2 | ... | dm |
/// +-------+-------+-------+-------+-------+-------+-------+-------+
///
/// with n being the main trace width and m the auxiliary trace width, the values are stored as
///
/// [a1, b1, a2, b2, ..., an, bn, c1, d1, c2, d2, ..., cm, dm]
///
/// into `Self::trace_states` (as byte values).
///
/// # Panics
/// Panics if evaluation frame has already been set.
pub fn set_trace_states<E: FieldElement>(
&mut self,
trace_states: &OodFrameTraceStates<E>,
) -> Vec<E> {
pub fn set_trace_states<E, H>(&mut self, trace_ood_frame: &TraceOodFrame<E>) -> H::Digest
where
E: FieldElement,
H: ElementHasher<BaseField = E::BaseField>,
{
assert!(self.trace_states.is_empty(), "trace sates have already been set");

// save the evaluations with the current and next evaluations interleaved for each polynomial

let mut result = vec![];
for col in 0..trace_states.num_columns() {
result.push(trace_states.current_row[col]);
result.push(trace_states.next_row[col]);
}
let (main_and_aux_trace_states, lagrange_trace_states) = trace_ood_frame.to_trace_states();

// there are 2 frames: current and next
let frame_size: u8 = 2;

self.trace_states.write_u8(frame_size);
self.trace_states.write_many(&result);
self.trace_states.write_many(&main_and_aux_trace_states);

// save the Lagrange kernel evaluation frame (if any)
let lagrange_trace_states = {
let lagrange_trace_states = match trace_states.lagrange_kernel_frame {
Some(ref lagrange_trace_states) => lagrange_trace_states.inner().to_vec(),
None => Vec::new(),
};

{
// trace states length will be smaller than u8::MAX, since it is `== log2(trace_len) + 1`
debug_assert!(lagrange_trace_states.len() < u8::MAX.into());
self.lagrange_kernel_trace_states.write_u8(lagrange_trace_states.len() as u8);
self.lagrange_kernel_trace_states.write_many(&lagrange_trace_states);

lagrange_trace_states
};

result.into_iter().chain(lagrange_trace_states).collect()
let elements_to_hash: Vec<E> =
main_and_aux_trace_states.into_iter().chain(lagrange_trace_states).collect();

H::hash_elements(&elements_to_hash)
}

/// Updates constraint evaluation portion of this out-of-domain frame.
Expand All @@ -104,16 +102,16 @@ impl OodFrame {

// PARSER
// --------------------------------------------------------------------------------------------
/// Returns main and auxiliary (if any) trace evaluation frames and a vector of out-of-domain
/// constraint evaluations contained in `self`.
/// Returns an out-of-domain trace frame and a vector of out-of-domain constraint evaluations
/// contained in `self`.
///
/// # Panics
/// Panics if either `main_trace_width` or `num_evaluations` are equal to zero.
///
/// # Errors
/// Returns an error if:
/// * Valid [`crate::EvaluationFrame`]s for the specified `main_trace_width` and `aux_trace_width`
/// could not be parsed from the internal bytes.
/// * Valid [`crate::EvaluationFrame`]s for the specified `main_trace_width` and
/// `aux_trace_width` could not be parsed from the internal bytes.
/// * A vector of evaluations specified by `num_evaluations` could not be parsed from the
/// internal bytes.
/// * Any unconsumed bytes remained after the parsing was complete.
Expand All @@ -122,24 +120,39 @@ impl OodFrame {
main_trace_width: usize,
aux_trace_width: usize,
num_evaluations: usize,
) -> Result<ParsedOodFrame<E>, DeserializationError> {
) -> Result<(TraceOodFrame<E>, Vec<E>), DeserializationError> {
assert!(main_trace_width > 0, "trace width cannot be zero");
assert!(num_evaluations > 0, "number of evaluations cannot be zero");

// parse main and auxiliary trace evaluation frames
let mut reader = SliceReader::new(&self.trace_states);
let frame_size = reader.read_u8()? as usize;
let trace = reader.read_many((main_trace_width + aux_trace_width) * frame_size)?;
// Parse main and auxiliary trace evaluation frames. This does the reverse operation done in
// `set_trace_states()`.
let (current_row, next_row) = {
let mut reader = SliceReader::new(&self.trace_states);
let frame_size = reader.read_u8()? as usize;
let trace = reader.read_many((main_trace_width + aux_trace_width) * frame_size)?;

if reader.has_more_bytes() {
return Err(DeserializationError::UnconsumedBytes);
}
if reader.has_more_bytes() {
return Err(DeserializationError::UnconsumedBytes);
}

let mut current_row = Vec::with_capacity(main_trace_width);
let mut next_row = Vec::with_capacity(main_trace_width);

for col in trace.chunks_exact(2) {
current_row.push(col[0]);
next_row.push(col[1]);
}

(current_row, next_row)
};

// parse Lagrange kernel column trace
let mut reader = SliceReader::new(&self.lagrange_kernel_trace_states);
let lagrange_kernel_frame_size = reader.read_u8()? as usize;
let lagrange_kernel_trace = if lagrange_kernel_frame_size > 0 {
Some(reader.read_many(lagrange_kernel_frame_size)?)
let lagrange_kernel_frame = if lagrange_kernel_frame_size > 0 {
let lagrange_kernel_trace = reader.read_many(lagrange_kernel_frame_size)?;

Some(LagrangeKernelEvaluationFrame::new(lagrange_kernel_trace))
} else {
None
};
Expand All @@ -151,11 +164,10 @@ impl OodFrame {
return Err(DeserializationError::UnconsumedBytes);
}

Ok(ParsedOodFrame {
trace_evaluations: trace,
lagrange_kernel_trace_evaluations: lagrange_kernel_trace,
constraint_evaluations: evaluations,
})
Ok((
TraceOodFrame::new(current_row, next_row, main_trace_width, lagrange_kernel_frame),
evaluations,
))
}
}

Expand Down Expand Up @@ -213,28 +225,31 @@ impl Deserializable for OodFrame {
// OOD FRAME TRACE STATES
// ================================================================================================

/// Stores the trace evaluations at `z` and `gz`, where `z` is a random Field element. If
/// the Air contains a Lagrange kernel auxiliary column, then that column interpolated polynomial
/// will be evaluated at `z`, `gz`, `g^2 z`, ... `g^(2^(v-1)) z`, where `v == log(trace_len)`, and
/// stored in `lagrange_kernel_frame`.
pub struct OodFrameTraceStates<E: FieldElement> {
/// Stores the trace evaluations at `z` and `gz`, where `z` is a random Field element in
/// `current_row` and `next_row`, respectively. If the Air contains a Lagrange kernel auxiliary
/// column, then that column interpolated polynomial will be evaluated at `z`, `gz`, `g^2 z`, ...
/// `g^(2^(v-1)) z`, where `v == log(trace_len)`, and stored in `lagrange_kernel_frame`.
pub struct TraceOodFrame<E: FieldElement> {
current_row: Vec<E>,
next_row: Vec<E>,
main_trace_width: usize,
lagrange_kernel_frame: Option<LagrangeKernelEvaluationFrame<E>>,
}

impl<E: FieldElement> OodFrameTraceStates<E> {
/// Creates a new [`OodFrameTraceStates`] from current, next and optionally Lagrange kernel frames.
impl<E: FieldElement> TraceOodFrame<E> {
/// Creates a new [`TraceOodFrame`] from current, next and optionally Lagrange kernel frames.
pub fn new(
current_frame: Vec<E>,
next_frame: Vec<E>,
current_row: Vec<E>,
next_row: Vec<E>,
main_trace_width: usize,
lagrange_kernel_frame: Option<LagrangeKernelEvaluationFrame<E>>,
) -> Self {
assert_eq!(current_frame.len(), next_frame.len());
assert_eq!(current_row.len(), next_row.len());

Self {
current_row: current_frame,
next_row: next_frame,
current_row,
next_row,
main_trace_width,
lagrange_kernel_frame,
}
}
Expand All @@ -244,18 +259,69 @@ impl<E: FieldElement> OodFrameTraceStates<E> {
self.current_row.len()
}

/// Returns the current frame.
pub fn current_frame(&self) -> &[E] {
/// Returns the current row, consisting of both main and auxiliary columns.
pub fn current_row(&self) -> &[E] {
&self.current_row
}

/// Returns the next frame.
pub fn next_frame(&self) -> &[E] {
/// Returns the next frame, consisting of both main and auxiliary columns.
pub fn next_row(&self) -> &[E] {
&self.next_row
}

/// Returns the evaluation frame for the main trace
pub fn main_frame(&self) -> EvaluationFrame<E> {
let current = self.current_row[0..self.main_trace_width].to_vec();
let next = self.next_row[0..self.main_trace_width].to_vec();

EvaluationFrame::from_rows(current, next)
}

/// Returns the evaluation frame for the auxiliary trace
pub fn aux_frame(&self) -> Option<EvaluationFrame<E>> {
if self.has_aux_frame() {
let current = self.current_row[self.main_trace_width..].to_vec();
let next = self.next_row[self.main_trace_width..].to_vec();

Some(EvaluationFrame::from_rows(current, next))
} else {
None
}
}

/// Hashes the main, auxiliary and Lagrange kernel frame in a manner consistent with
/// [`OodFrame::set_trace_states`], with the purpose of reseeding the public coin.
pub fn hash<H: ElementHasher<BaseField = E::BaseField>>(&self) -> H::Digest {
let (mut trace_states, mut lagrange_trace_states) = self.to_trace_states();
trace_states.append(&mut lagrange_trace_states);

H::hash_elements(&trace_states)
}

/// Returns the Lagrange kernel frame, if any.
pub fn lagrange_kernel_frame(&self) -> Option<&LagrangeKernelEvaluationFrame<E>> {
self.lagrange_kernel_frame.as_ref()
}

/// Returns true if an auxiliary frame is present
fn has_aux_frame(&self) -> bool {
self.current_row.len() > self.main_trace_width
}

/// Returns the main/aux frame and Lagrange kernel frame as element vectors. Specifically, the
/// main and auxiliary frames are interleaved, as described in [`OodFrame::set_trace_states`].
fn to_trace_states(&self) -> (Vec<E>, Vec<E>) {
let mut main_and_aux_frame_states = Vec::new();
for col in 0..self.current_row.len() {
main_and_aux_frame_states.push(self.current_row[col]);
main_and_aux_frame_states.push(self.next_row[col]);
}

let lagrange_frame_states = match self.lagrange_kernel_frame {
Some(ref lagrange_kernel_frame) => lagrange_kernel_frame.inner().to_vec(),
None => Vec::new(),
};

(main_and_aux_frame_states, lagrange_frame_states)
}
}
8 changes: 4 additions & 4 deletions prover/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

use air::{
proof::{Commitments, Context, OodFrame, OodFrameTraceStates, Queries, StarkProof},
proof::{Commitments, Context, OodFrame, Queries, StarkProof, TraceOodFrame},
Air, ConstraintCompositionCoefficients, DeepCompositionCoefficients,
};
use alloc::vec::Vec;
Expand Down Expand Up @@ -85,9 +85,9 @@ where

/// Saves the evaluations of trace polynomials over the out-of-domain evaluation frame. This
/// also reseeds the public coin with the hashes of the evaluation frame states.
pub fn send_ood_trace_states(&mut self, trace_states: &OodFrameTraceStates<E>) {
let result = self.ood_frame.set_trace_states(trace_states);
self.public_coin.reseed(H::hash_elements(&result));
pub fn send_ood_trace_states(&mut self, trace_ood_frame: &TraceOodFrame<E>) {
let trace_states_hash = self.ood_frame.set_trace_states::<E, H>(trace_ood_frame);
self.public_coin.reseed(trace_states_hash);
}

/// Saves the evaluations of constraint composition polynomial columns at the out-of-domain
Expand Down
12 changes: 6 additions & 6 deletions prover/src/composer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
// LICENSE file in the root directory of this source tree.

use super::{constraints::CompositionPoly, StarkDomain, TracePolyTable};
use air::{proof::OodFrameTraceStates, DeepCompositionCoefficients};
use air::{proof::TraceOodFrame, DeepCompositionCoefficients};
use alloc::vec::Vec;
use math::{add_in_place, fft, mul_acc, polynom, ExtensionOf, FieldElement, StarkField};
use utils::iter_mut;
Expand Down Expand Up @@ -64,7 +64,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
pub fn add_trace_polys(
&mut self,
trace_polys: TracePolyTable<E>,
ood_trace_states: OodFrameTraceStates<E>,
ood_trace_states: TraceOodFrame<E>,
) {
assert!(self.coefficients.is_empty());

Expand All @@ -89,7 +89,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E::BaseField, E>(
&mut t1_composition,
poly,
ood_trace_states.current_frame()[i],
ood_trace_states.current_row()[i],
self.cc.trace[i],
);

Expand All @@ -98,7 +98,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E::BaseField, E>(
&mut t2_composition,
poly,
ood_trace_states.next_frame()[i],
ood_trace_states.next_row()[i],
self.cc.trace[i],
);

Expand All @@ -112,7 +112,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E, E>(
&mut t1_composition,
poly,
ood_trace_states.current_frame()[i],
ood_trace_states.current_row()[i],
self.cc.trace[i],
);

Expand All @@ -121,7 +121,7 @@ impl<E: FieldElement> DeepCompositionPoly<E> {
acc_trace_poly::<E, E>(
&mut t2_composition,
poly,
ood_trace_states.next_frame()[i],
ood_trace_states.next_row()[i],
self.cc.trace[i],
);

Expand Down
Loading

0 comments on commit 27f2574

Please sign in to comment.