Skip to content

Commit

Permalink
rewrite poly trait (#12)
Browse files Browse the repository at this point in the history
* rewrite poly trait

* docs
  • Loading branch information
montekki committed Sep 5, 2022
1 parent ec6117f commit 206039c
Show file tree
Hide file tree
Showing 3 changed files with 296 additions and 203 deletions.
83 changes: 37 additions & 46 deletions matrix-multiplication/src/lib.rs
Expand Up @@ -34,15 +34,15 @@ impl<F: Field> G<F> {
where
M: IntoIterator<Item = F>,
{
let f_a = DenseMultilinearExtension::from_evaluations_vec(n, a.into_iter().collect());
let f_a = f_a.relabel(0, n / 2, n / 2);
let f_a = f_a.fix_variables(&point[..n / 2]);
let f_a = DenseMultilinearExtension::from_evaluations_vec(n * 2, a.into_iter().collect());
let f_a = f_a.relabel(0, n, n);
let f_a = f_a.fix_variables(&point[..n]);

let f_b = DenseMultilinearExtension::from_evaluations_vec(n, b.into_iter().collect());
let f_b = f_b.fix_variables(&point[n / 2..]);
let f_b = DenseMultilinearExtension::from_evaluations_vec(n * 2, b.into_iter().collect());
let f_b = f_b.fix_variables(&point[n..]);

assert_eq!(f_a.num_vars(), n / 2);
assert_eq!(f_b.num_vars(), n / 2);
assert_eq!(f_a.num_vars(), n);
assert_eq!(f_b.num_vars(), n);

Self { f_a, f_b }
}
Expand All @@ -56,42 +56,33 @@ impl<F: FftField> SumCheckPolynomial<F> for G<F> {
Some(f_a * f_b)
}

fn to_univariate_at_point(&self, at: usize, point: &[F]) -> Option<SparsePolynomial<F>> {
let n = self.f_a.num_vars();
let mut poly_f_a = self.f_a.fix_variables(&point[..at]);
let mut poly_f_b = self.f_b.fix_variables(&point[..at]);
fn fix_variables(&self, partial_point: &[F]) -> Self {
let f_a = self.f_a.fix_variables(partial_point);
let f_b = self.f_b.fix_variables(partial_point);

if at != n - 1 {
poly_f_a.relabel_inplace(0, poly_f_a.num_vars() - 1, 1);

poly_f_a = poly_f_a.fix_variables(&[point[n - 1]]);
poly_f_a = poly_f_a.fix_variables(&point[at + 1..n - 1]);

poly_f_b.relabel_inplace(0, poly_f_b.num_vars() - 1, 1);

poly_f_b = poly_f_b.fix_variables(&[point[n - 1]]);
poly_f_b = poly_f_b.fix_variables(&point[at + 1..n - 1]);
}

let domain = GeneralEvaluationDomain::new(3).unwrap();
Self { f_a, f_b }
}

let evaluations_f_a = domain
.elements()
.map(|e| poly_f_a.evaluate(&[e]).unwrap())
.collect();
fn to_univariate(&self) -> SparsePolynomial<F> {
let domain: GeneralEvaluationDomain<F> = GeneralEvaluationDomain::new(3).unwrap();

let evaluations_f_b = domain
let evals = domain
.elements()
.map(|e| poly_f_b.evaluate(&[e]).unwrap())
.map(|e| {
let f_a_evals = self.f_a.fix_variables(&[e]).to_evaluations();
let f_b_evals = self.f_b.fix_variables(&[e]).to_evaluations();
f_a_evals
.into_iter()
.zip(f_b_evals.into_iter())
.map(|(a, b)| a * b)
.sum()
})
.collect();

let evaluations_f_a = Evaluations::from_vec_and_domain(evaluations_f_a, domain);
let evaluations_f_b = Evaluations::from_vec_and_domain(evaluations_f_b, domain);

let p_a = evaluations_f_a.interpolate();
let p_b = evaluations_f_b.interpolate();
let evaluations = Evaluations::from_vec_and_domain(evals, domain);
let p = evaluations.interpolate();

Some((&p_a * &p_b).into())
p.into()
}

fn num_vars(&self) -> usize {
Expand Down Expand Up @@ -168,12 +159,12 @@ mod tests {
fn matrix_test_from_book() {
let a = vec![
vec![
Fp5::from_bigint(0u32.into()).unwrap(), // P(0, 0)
Fp5::from_bigint(1u32.into()).unwrap(), // P(0, 1)
Fp5::from_bigint(0u32.into()).unwrap(),
Fp5::from_bigint(1u32.into()).unwrap(),
],
vec![
Fp5::from_bigint(2u32.into()).unwrap(), // P(1, 0)
Fp5::from_bigint(0u32.into()).unwrap(), // P(1, 1)
Fp5::from_bigint(2u32.into()).unwrap(),
Fp5::from_bigint(0u32.into()).unwrap(),
],
];

Expand Down Expand Up @@ -212,12 +203,12 @@ mod tests {
let rng = &mut test_rng();
let a = vec![
vec![
Fp5::from_bigint(0u32.into()).unwrap(), // P(0, 0)
Fp5::from_bigint(1u32.into()).unwrap(), // P(0, 1)
Fp5::from_bigint(0u32.into()).unwrap(),
Fp5::from_bigint(1u32.into()).unwrap(),
],
vec![
Fp5::from_bigint(2u32.into()).unwrap(), // P(1, 0)
Fp5::from_bigint(0u32.into()).unwrap(), // P(1, 1)
Fp5::from_bigint(2u32.into()).unwrap(),
Fp5::from_bigint(0u32.into()).unwrap(),
],
];

Expand All @@ -238,7 +229,7 @@ mod tests {
point.append(&mut u32_to_boolean_vec(j, 1usize));

let g = G::new(
2,
1,
a.iter().flatten().cloned(),
b.iter().flatten().cloned(),
&point,
Expand Down Expand Up @@ -295,7 +286,7 @@ mod tests {
let mut point: Vec<_> = u32_to_boolean_vec(i, p as usize);
point.append(&mut u32_to_boolean_vec(j, p as usize));
let g = G::new(
p as usize * 2,
p as usize,
a.0.iter().flatten().cloned(),
b.0.iter().flatten().cloned(),
&point,
Expand Down

0 comments on commit 206039c

Please sign in to comment.