Skip to content

Commit

Permalink
Merge pull request #1223 from o1-labs/feature/ffmul
Browse files Browse the repository at this point in the history
Foreign fields 4: Multiplication
  • Loading branch information
mitschabaude committed Nov 21, 2023
2 parents dac55a5 + 021b001 commit 084a1e7
Show file tree
Hide file tree
Showing 12 changed files with 556 additions and 69 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm

## [Unreleased](https://github.com/o1-labs/o1js/compare/26363465d...HEAD)

### Breaking changes

- Change return signature of `ZkProgram.analyzeMethods()` to be a keyed object https://github.com/o1-labs/o1js/pull/1223

### Added

- Provable non-native field arithmetic:
- `Gadgets.ForeignField.{add, sub, sumchain}()` for addition and subtraction https://github.com/o1-labs/o1js/pull/1220
- `Gadgets.ForeignField.{mul, inv, div}()` for multiplication and division https://github.com/o1-labs/o1js/pull/1223
- Comprehensive internal testing of constraint system layouts generated by new gadgets https://github.com/o1-labs/o1js/pull/1241 https://github.com/o1-labs/o1js/pull/1220

### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/bindings
2 changes: 1 addition & 1 deletion src/lib/gadgets/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ function toVars<T extends Tuple<Field | bigint>>(
return Tuple.map(fields, toVar);
}

function assert(stmt: boolean, message?: string) {
function assert(stmt: boolean, message?: string): asserts stmt {
if (!stmt) {
throw Error(message ?? 'Assertion failed');
}
Expand Down
274 changes: 256 additions & 18 deletions src/lib/gadgets/foreign-field.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
import { mod } from '../../bindings/crypto/finite_field.js';
import {
inverse as modInverse,
mod,
} from '../../bindings/crypto/finite_field.js';
import { provableTuple } from '../../bindings/lib/provable-snarky.js';
import { Field } from '../field.js';
import { Gates, foreignFieldAdd } from '../gates.js';
import { Tuple } from '../util/types.js';
import { assert, exists, toVars } from './common.js';
import { L, lMask, multiRangeCheck, twoL, twoLMask } from './range-check.js';
import { assert, bitSlice, exists, toVars } from './common.js';
import {
l,
lMask,
multiRangeCheck,
l2,
l2Mask,
l3,
compactMultiRangeCheck,
} from './range-check.js';

export { ForeignField, Field3, Sign };

Expand All @@ -23,6 +34,10 @@ const ForeignField = {
return sum([x, y], [-1n], f);
},
sum,

mul: multiply,
inv: inverse,
div: divide,
};

/**
Expand Down Expand Up @@ -70,17 +85,17 @@ function singleAdd(x: Field3, y: Field3, sign: Sign, f: bigint) {
let y_ = toBigint3(y);

// figure out if there's overflow
let r = collapse(x_) + sign * collapse(y_);
let r = combine(x_) + sign * combine(y_);
let overflow = 0n;
if (sign === 1n && r >= f) overflow = 1n;
if (sign === -1n && r < 0n) overflow = -1n;
if (f === 0n) overflow = 0n; // special case where overflow doesn't change anything

// do the add with carry
// note: this "just works" with negative r01
let r01 = collapse2(x_) + sign * collapse2(y_) - overflow * collapse2(f_);
let carry = r01 >> twoL;
r01 &= twoLMask;
let r01 = combine2(x_) + sign * combine2(y_) - overflow * combine2(f_);
let carry = r01 >> l2;
r01 &= l2Mask;
let [r0, r1] = split2(r01);
let r2 = x_[2] + sign * y_[2] - overflow * f_[2] + carry;

Expand All @@ -92,19 +107,238 @@ function singleAdd(x: Field3, y: Field3, sign: Sign, f: bigint) {
return { result: [r0, r1, r2] satisfies Field3, overflow };
}

function multiply(a: Field3, b: Field3, f: bigint): Field3 {
assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits');

// constant case
if (a.every((x) => x.isConstant()) && b.every((x) => x.isConstant())) {
let ab = Field3.toBigint(a) * Field3.toBigint(b);
return Field3.from(mod(ab, f));
}

// provable case
let { r01, r2, q } = multiplyNoRangeCheck(a, b, f);

// limb range checks on quotient and remainder
multiRangeCheck(q);
let r = compactMultiRangeCheck(r01, r2);
return r;
}

function inverse(x: Field3, f: bigint): Field3 {
assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits');

// constant case
if (x.every((x) => x.isConstant())) {
let xInv = modInverse(Field3.toBigint(x), f);
assert(xInv !== undefined, 'inverse exists');
return Field3.from(xInv);
}

// provable case
let xInv = exists(3, () => {
let xInv = modInverse(Field3.toBigint(x), f);
return xInv === undefined ? [0n, 0n, 0n] : split(xInv);
});
multiRangeCheck(xInv);
// we need to bound xInv because it's a multiplication input
let xInv2Bound = weakBound(xInv[2], f);

let one: Field2 = [Field.from(1n), Field.from(0n)];
assertMul(x, xInv, one, f);

// range check on result bound
// TODO: this uses two RCs too many.. need global RC stack
multiRangeCheck([xInv2Bound, Field.from(0n), Field.from(0n)]);

return xInv;
}

function divide(
x: Field3,
y: Field3,
f: bigint,
{ allowZeroOverZero = false } = {}
) {
assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits');

// constant case
if (x.every((x) => x.isConstant()) && y.every((x) => x.isConstant())) {
let yInv = modInverse(Field3.toBigint(y), f);
assert(yInv !== undefined, 'inverse exists');
return Field3.from(mod(Field3.toBigint(x) * yInv, f));
}

// provable case
// to show that z = x/y, we prove that z*y = x and y != 0 (the latter avoids the unconstrained 0/0 case)
let z = exists(3, () => {
let yInv = modInverse(Field3.toBigint(y), f);
if (yInv === undefined) return [0n, 0n, 0n];
return split(mod(Field3.toBigint(x) * yInv, f));
});
multiRangeCheck(z);
let z2Bound = weakBound(z[2], f);
assertMul(z, y, x, f);

// range check on result bound
multiRangeCheck([z2Bound, Field.from(0n), Field.from(0n)]);

if (!allowZeroOverZero) {
// assert that y != 0 mod f by checking that it doesn't equal 0 or f
// this works because we assume y[2] <= f2
// TODO is this the most efficient way?
let y01 = y[0].add(y[1].mul(1n << l));
y01.equals(0n).and(y[2].equals(0n)).assertFalse();
let [f0, f1, f2] = split(f);
let f01 = combine2([f0, f1]);
y01.equals(f01).and(y[2].equals(f2)).assertFalse();
}

return z;
}

/**
* Common logic for gadgets that expect a certain multiplication result a priori, instead of just using the remainder.
*/
function assertMul(x: Field3, y: Field3, xy: Field3 | Field2, f: bigint) {
let { r01, r2, q } = multiplyNoRangeCheck(x, y, f);

// range check on quotient
multiRangeCheck(q);

// bind remainder to input xy
if (xy.length === 2) {
let [xy01, xy2] = xy;
r01.assertEquals(xy01);
r2.assertEquals(xy2);
} else {
let xy01 = xy[0].add(xy[1].mul(1n << l));
r01.assertEquals(xy01);
r2.assertEquals(xy[2]);
}
}

/**
* Core building block for all gadgets using foreign field multiplication.
*/
function multiplyNoRangeCheck(a: Field3, b: Field3, f: bigint) {
// notation follows https://github.com/o1-labs/rfcs/blob/main/0006-ffmul-revised.md
let f_ = (1n << l3) - f;
let [f_0, f_1, f_2] = split(f_);
let f2 = f >> l2;
let f2Bound = (1n << l) - f2 - 1n;

let witnesses = exists(21, () => {
// convert inputs to bigints
let [a0, a1, a2] = toBigint3(a);
let [b0, b1, b2] = toBigint3(b);

// compute q and r such that a*b = q*f + r
let ab = combine([a0, a1, a2]) * combine([b0, b1, b2]);
let q = ab / f;
let r = ab - q * f;

let [q0, q1, q2] = split(q);
let [r0, r1, r2] = split(r);
let r01 = combine2([r0, r1]);

// compute product terms
let p0 = a0 * b0 + q0 * f_0;
let p1 = a0 * b1 + a1 * b0 + q0 * f_1 + q1 * f_0;
let p2 = a0 * b2 + a1 * b1 + a2 * b0 + q0 * f_2 + q1 * f_1 + q2 * f_0;

let [p10, p110, p111] = split(p1);
let p11 = combine2([p110, p111]);

// carry bottom limbs
let c0 = (p0 + (p10 << l) - r01) >> l2;

// carry top limb
let c1 = (p2 - r2 + p11 + c0) >> l;

// split high carry
let c1_00 = bitSlice(c1, 0, 12);
let c1_12 = bitSlice(c1, 12, 12);
let c1_24 = bitSlice(c1, 24, 12);
let c1_36 = bitSlice(c1, 36, 12);
let c1_48 = bitSlice(c1, 48, 12);
let c1_60 = bitSlice(c1, 60, 12);
let c1_72 = bitSlice(c1, 72, 12);
let c1_84 = bitSlice(c1, 84, 2);
let c1_86 = bitSlice(c1, 86, 2);
let c1_88 = bitSlice(c1, 88, 2);
let c1_90 = bitSlice(c1, 90, 1);

// quotient high bound
let q2Bound = q2 + f2Bound;

// prettier-ignore
return [
r01, r2,
q0, q1, q2,
q2Bound,
p10, p110, p111,
c0,
c1_00, c1_12, c1_24, c1_36, c1_48, c1_60, c1_72,
c1_84, c1_86, c1_88, c1_90,
];
});

// prettier-ignore
let [
r01, r2,
q0, q1, q2,
q2Bound,
p10, p110, p111,
c0,
c1_00, c1_12, c1_24, c1_36, c1_48, c1_60, c1_72,
c1_84, c1_86, c1_88, c1_90,
] = witnesses;

let q: Field3 = [q0, q1, q2];

// ffmul gate. this already adds the following zero row.
Gates.foreignFieldMul({
left: a,
right: b,
remainder: [r01, r2],
quotient: q,
quotientHiBound: q2Bound,
product1: [p10, p110, p111],
carry0: c0,
carry1p: [c1_00, c1_12, c1_24, c1_36, c1_48, c1_60, c1_72],
carry1c: [c1_84, c1_86, c1_88, c1_90],
foreignFieldModulus2: f2,
negForeignFieldModulus: [f_0, f_1, f_2],
});

// multi-range check on internal values
multiRangeCheck([p10, p110, q2Bound]);

// note: this function is supposed to be the most flexible interface to the ffmul gate.
// that's why we don't add range checks on q and r here, because there are valid use cases
// for not range-checking either of them -- for example, they could be wired to other
// variables that are already range-checked, or to constants / public inputs.
return { r01, r2, q };
}

function weakBound(x: Field, f: bigint) {
return x.add(lMask - (f >> l2));
}

const Field3 = {
/**
* Turn a bigint into a 3-tuple of Fields
*/
from(x: bigint): Field3 {
return toField3(split(x));
return Tuple.map(split(x), Field.from);
},

/**
* Turn a 3-tuple of Fields into a bigint
*/
toBigint(x: Field3): bigint {
return collapse(toBigint3(x));
return combine(toBigint3(x));
},

/**
Expand All @@ -116,23 +350,27 @@ const Field3 = {
provable: provableTuple([Field, Field, Field]),
};

function toField3(x: bigint3): Field3 {
return Tuple.map(x, (x) => new Field(x));
}
type Field2 = [Field, Field];
const Field2 = {
toBigint(x: Field2): bigint {
return combine2(Tuple.map(x, (x) => x.toBigInt()));
},
};

function toBigint3(x: Field3): bigint3 {
return Tuple.map(x, (x) => x.toBigInt());
}

function collapse([x0, x1, x2]: bigint3) {
return x0 + (x1 << L) + (x2 << twoL);
function combine([x0, x1, x2]: bigint3) {
return x0 + (x1 << l) + (x2 << l2);
}
function split(x: bigint): bigint3 {
return [x & lMask, (x >> L) & lMask, (x >> twoL) & lMask];
return [x & lMask, (x >> l) & lMask, (x >> l2) & lMask];
}

function collapse2([x0, x1]: bigint3 | [bigint, bigint]) {
return x0 + (x1 << L);
function combine2([x0, x1]: bigint3 | [bigint, bigint]) {
return x0 + (x1 << l);
}
function split2(x: bigint): [bigint, bigint] {
return [x & lMask, (x >> L) & lMask];
return [x & lMask, (x >> l) & lMask];
}

0 comments on commit 084a1e7

Please sign in to comment.