Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: simplify some memoset generics #1045

Merged
merged 2 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/coprocessor/memoset/demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {
type CQ = DemoCircuitQuery<F>;

// DemoQuery and Scope depend on each other.
fn eval(&self, s: &Store<F>, scope: &mut Scope<F, Self, LogMemo<F>>) -> Ptr {
fn eval(&self, s: &Store<F>, scope: &mut Scope<Self, LogMemo<F>>) -> Ptr {
match self {
Self::Factorial(n) => {
let n_zptr = s.hash_ptr(n);
Expand All @@ -51,7 +51,7 @@ impl<F: LurkField> Query<F> for DemoQuery<F> {

fn recursive_eval(
&self,
scope: &mut Scope<F, Self, LogMemo<F>>,
scope: &mut Scope<Self, LogMemo<F>>,
s: &Store<F>,
subquery: Self,
) -> Ptr {
Expand Down Expand Up @@ -102,7 +102,7 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, Self, LogMemo<F>>,
scope: &mut CircuitScope<F, LogMemo<F>>,
acc: &AllocatedPtr<F>,
transcript: &CircuitTranscript<F>,
) -> Result<(AllocatedPtr<F>, AllocatedPtr<F>, CircuitTranscript<F>), SynthesisError> {
Expand Down Expand Up @@ -238,7 +238,7 @@ mod test {
#[test]
fn test_factorial() {
let s = Store::default();
let mut scope: Scope<F, DemoQuery<F>, LogMemo<F>> = Scope::default();
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::default();
let zero = s.num(F::ZERO);
let one = s.num(F::ONE);
let two = s.num(F::from_u64(2));
Expand Down
36 changes: 16 additions & 20 deletions src/coprocessor/memoset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use crate::tag::{ExprTag, Tag as XTag};
use crate::z_ptr::ZPtr;

use multiset::MultiSet;
use query::{CircuitQuery, Query};
pub use query::{CircuitQuery, Query};

mod demo;
mod multiset;
Expand Down Expand Up @@ -180,7 +180,7 @@ impl<F: LurkField> CircuitTranscript<F> {
/// A `Scope` tracks the queries made while evaluating, including the subqueries that result from evaluating other
/// queries -- then makes use of the bookkeeping performed at evaluation time to synthesize proof of each query
/// performed.
pub struct Scope<F, Q, M> {
pub struct Scope<Q, M> {
memoset: M,
/// k => v
queries: HashMap<Ptr, Ptr>,
Expand All @@ -192,10 +192,9 @@ pub struct Scope<F, Q, M> {
internal_insertions: Vec<Ptr>,
/// unique keys
all_insertions: Vec<Ptr>,
_p: PhantomData<F>,
}

impl<F: LurkField, Q> Default for Scope<F, Q, LogMemo<F>> {
impl<F: LurkField, Q> Default for Scope<Q, LogMemo<F>> {
fn default() -> Self {
Self {
memoset: Default::default(),
Expand All @@ -204,22 +203,20 @@ impl<F: LurkField, Q> Default for Scope<F, Q, LogMemo<F>> {
toplevel_insertions: Default::default(),
internal_insertions: Default::default(),
all_insertions: Default::default(),
_p: Default::default(),
}
}
}

pub struct CircuitScope<F: LurkField, CQ: CircuitQuery<F>, M: MemoSet<F>> {
pub struct CircuitScope<F: LurkField, M> {
memoset: M,
/// k -> v
queries: HashMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
/// k -> allocated v
transcript: CircuitTranscript<F>,
acc: Option<AllocatedPtr<F>>,
_p: PhantomData<CQ>,
}

impl<F: LurkField, Q: Query<F>> Scope<F, Q, LogMemo<F>> {
impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>> {
pub fn query(&mut self, s: &Store<F>, form: Ptr) -> Ptr {
let (response, kv_ptr) = self.query_aux(s, form);

Expand Down Expand Up @@ -357,12 +354,12 @@ impl<F: LurkField, Q: Query<F>> Scope<F, Q, LogMemo<F>> {
}
}

impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
fn from_scope<CS: ConstraintSystem<F>, Q: Query<F, CQ = CQ>>(
impl<F: LurkField> CircuitScope<F, LogMemo<F>> {
fn from_scope<CS: ConstraintSystem<F>, Q: Query<F>>(
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
scope: &Scope<F, Q, LogMemo<F>>,
scope: &Scope<Q, LogMemo<F>>,
) -> Self {
let queries = scope
.queries
Expand All @@ -374,7 +371,6 @@ impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
queries,
transcript: CircuitTranscript::new(cs, g, s),
acc: Default::default(),
_p: Default::default(),
}
}

Expand Down Expand Up @@ -496,9 +492,9 @@ impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
Ok((value, new_acc, new_insertion_transcript))
}

fn synthesize_insert_toplevel_queries<CS: ConstraintSystem<F>, Q: Query<F, CQ = CQ>>(
fn synthesize_insert_toplevel_queries<CS: ConstraintSystem<F>, Q: Query<F>>(
&mut self,
scope: &mut Scope<F, Q, LogMemo<F>>,
scope: &mut Scope<Q, LogMemo<F>>,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
Expand Down Expand Up @@ -546,20 +542,20 @@ impl<F: LurkField, CQ: CircuitQuery<F>> CircuitScope<F, CQ, LogMemo<F>> {
Ok(())
}

fn synthesize_prove_all_queries<CS: ConstraintSystem<F>, Q: Query<F, CQ = CQ>>(
fn synthesize_prove_all_queries<CS: ConstraintSystem<F>, Q: Query<F>>(
&mut self,
scope: &mut Scope<F, Q, LogMemo<F>>,
scope: &mut Scope<Q, LogMemo<F>>,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
s: &Store<F>,
) -> Result<(), SynthesisError> {
for (i, kv) in scope.all_insertions.iter().enumerate() {
self.synthesize_prove_query(cs, g, s, i, kv)?;
self.synthesize_prove_query::<_, Q::CQ>(cs, g, s, i, kv)?;
}
Ok(())
}

fn synthesize_prove_query<CS: ConstraintSystem<F>>(
fn synthesize_prove_query<CS: ConstraintSystem<F>, CQ: CircuitQuery<F>>(
&mut self,
cs: &mut CS,
g: &mut GlobalAllocator<F>,
Expand Down Expand Up @@ -752,7 +748,7 @@ mod test {
#[test]
fn test_query() {
let s = &Store::<F>::default();
let mut scope: Scope<F, DemoQuery<F>, LogMemo<F>> = Scope::default();
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::default();
let state = State::init_lurk_state();

let fact_4 = s.read_with_default_state("(factorial 4)").unwrap();
Expand Down Expand Up @@ -803,7 +799,7 @@ mod test {
assert!(cs.is_satisfied());
}
{
let mut scope: Scope<F, DemoQuery<F>, LogMemo<F>> = Scope::default();
let mut scope: Scope<DemoQuery<F>, LogMemo<F>> = Scope::default();
scope.query(s, fact_4);
scope.query(s, fact_3);

Expand Down
6 changes: 3 additions & 3 deletions src/coprocessor/memoset/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ where
{
type CQ: CircuitQuery<F>;

fn eval(&self, s: &Store<F>, scope: &mut Scope<F, Self, LogMemo<F>>) -> Ptr;
fn eval(&self, s: &Store<F>, scope: &mut Scope<Self, LogMemo<F>>) -> Ptr;
fn recursive_eval(
&self,
scope: &mut Scope<F, Self, LogMemo<F>>,
scope: &mut Scope<Self, LogMemo<F>>,
s: &Store<F>,
subquery: Self,
) -> Ptr;
Expand All @@ -39,7 +39,7 @@ where
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, Self, LogMemo<F>>,
scope: &mut CircuitScope<F, LogMemo<F>>,
acc: &AllocatedPtr<F>,
transcript: &CircuitTranscript<F>,
) -> Result<(AllocatedPtr<F>, AllocatedPtr<F>, CircuitTranscript<F>), SynthesisError>;
Expand Down
Loading