diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ac14ae65..0e58b8839 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm ### Added - **Foreign field arithmetic** exposed through the `createForeignField()` class factory https://github.com/o1-labs/snarkyjs/pull/985 +- `Gadgets.ForeignField.assertMul()` for efficiently constraining products of sums in non-native arithmetic https://github.com/o1-labs/o1js/pull/1262 +- `Unconstrained` for safely maintaining unconstrained values in provable code https://github.com/o1-labs/o1js/pull/1262 - `Gadgets.rangeCheck8()` to assert that a value fits in 8 bits https://github.com/o1-labs/o1js/pull/1288 ### Changed diff --git a/src/index.ts b/src/index.ts index 71831d758..e759553c8 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,6 +14,7 @@ export type { FlexibleProvable, FlexibleProvablePure, InferProvable, + Unconstrained, } from './lib/circuit_value.js'; export { CircuitValue, diff --git a/src/lib/circuit_value.ts b/src/lib/circuit_value.ts index 679b78efd..85793721b 100644 --- a/src/lib/circuit_value.ts +++ b/src/lib/circuit_value.ts @@ -1,5 +1,5 @@ import 'reflect-metadata'; -import { ProvablePure } from '../snarky.js'; +import { ProvablePure, Snarky } from '../snarky.js'; import { Field, Bool, Scalar, Group } from './core.js'; import { provable, @@ -15,6 +15,8 @@ import type { IsPure, } from '../bindings/lib/provable-snarky.js'; import { Provable } from './provable.js'; +import { assert } from './errors.js'; +import { inCheckedComputation } from './provable-context.js'; // external API export { @@ -43,6 +45,7 @@ export { HashInput, InferJson, InferredProvable, + Unconstrained, }; type ProvableExtension = { @@ -474,6 +477,93 @@ function Struct< return Struct_ as any; } +/** + * Container which holds an unconstrained value. This can be used to pass values + * between the out-of-circuit blocks in provable code. + * + * Invariants: + * - An `Unconstrained`'s value can only be accessed in auxiliary contexts. + * - An `Unconstrained` can be empty when compiling, but never empty when running as the prover. + * (there is no way to create an empty `Unconstrained` in the prover) + * + * @example + * ```ts + * let x = Unconstrained.from(0n); + * + * class MyContract extends SmartContract { + * `@method` myMethod(x: Unconstrained) { + * + * Provable.witness(Field, () => { + * // we can access and modify `x` here + * let newValue = x.get() + otherField.toBigInt(); + * x.set(newValue); + * + * // ... + * }); + * + * // throws an error! + * x.get(); + * } + * ``` + */ +class Unconstrained { + private option: + | { isSome: true; value: T } + | { isSome: false; value: undefined }; + + private constructor(isSome: boolean, value?: T) { + this.option = { isSome, value: value as any }; + } + + /** + * Read an unconstrained value. + * + * Note: Can only be called outside provable code. + */ + get(): T { + if (inCheckedComputation() && !Snarky.run.inProverBlock()) + throw Error(`You cannot use Unconstrained.get() in provable code. + +The only place where you can read unconstrained values is in Provable.witness() +and Provable.asProver() blocks, which execute outside the proof. +`); + assert(this.option.isSome, 'Empty `Unconstrained`'); // never triggered + return this.option.value; + } + + /** + * Modify the unconstrained value. + */ + set(value: T) { + this.option = { isSome: true, value }; + } + + /** + * Create an `Unconstrained` with the given `value`. + */ + static from(value: T) { + return new Unconstrained(true, value); + } + + /** + * Create an `Unconstrained` from a witness computation. + */ + static witness(compute: () => T) { + return Provable.witness( + Unconstrained.provable, + () => new Unconstrained(true, compute()) + ); + } + + static provable: Provable> = { + sizeInFields: () => 0, + toFields: () => [], + toAuxiliary: (t?: any) => [t ?? new Unconstrained(false)], + fromFields: (_, [t]) => t, + check: () => {}, + }; +} + let primitives = new Set([Field, Bool, Scalar, Group]); function isPrimitive(obj: any) { for (let P of primitives) { diff --git a/src/lib/circuit_value.unit-test.ts b/src/lib/circuit_value.unit-test.ts index 6372c52a7..79fbbafa9 100644 --- a/src/lib/circuit_value.unit-test.ts +++ b/src/lib/circuit_value.unit-test.ts @@ -1,4 +1,4 @@ -import { provable, Struct } from './circuit_value.js'; +import { provable, Struct, Unconstrained } from './circuit_value.js'; import { UInt32 } from './int.js'; import { PrivateKey, PublicKey } from './signature.js'; import { expect } from 'expect'; @@ -96,6 +96,7 @@ class MyStructPure extends Struct({ class MyTuple extends Struct([PublicKey, String]) {} let targetString = 'some particular string'; +let targetBigint = 99n; let gotTargetString = false; // create a smart contract and pass auxiliary data to a method @@ -106,11 +107,22 @@ class MyContract extends SmartContract { // this works because MyStructPure only contains field elements @state(MyStructPure) x = State(); - @method myMethod(value: MyStruct, tuple: MyTuple, update: AccountUpdate) { + @method myMethod( + value: MyStruct, + tuple: MyTuple, + update: AccountUpdate, + unconstrained: Unconstrained + ) { // check if we can pass in string values if (value.other === targetString) gotTargetString = true; value.uint[0].assertEquals(UInt32.zero); + // cannot access unconstrained values in provable code + if (Provable.inCheckedComputation()) + expect(() => unconstrained.get()).toThrow( + 'You cannot use Unconstrained.get() in provable code.' + ); + Provable.asProver(() => { let err = 'wrong value in prover'; if (tuple[1] !== targetString) throw Error(err); @@ -119,6 +131,9 @@ class MyContract extends SmartContract { if (update.lazyAuthorization?.kind !== 'lazy-signature') throw Error(err); if (update.lazyAuthorization.privateKey?.toBase58() !== key.toBase58()) throw Error(err); + + // check if we can pass in unconstrained values + if (unconstrained.get() !== targetBigint) throw Error(err); }); } } @@ -141,7 +156,8 @@ let tx = await transaction(() => { uint: [UInt32.from(0), UInt32.from(10)], }, [address, targetString], - accountUpdate + accountUpdate, + Unconstrained.from(targetBigint) ); }); diff --git a/src/lib/field.ts b/src/lib/field.ts index fa17bf8a8..226689979 100644 --- a/src/lib/field.ts +++ b/src/lib/field.ts @@ -11,10 +11,12 @@ export { Field }; // internal API export { - ConstantField, FieldType, FieldVar, FieldConst, + ConstantField, + VarField, + VarFieldVar, isField, withMessage, readVarMessage, @@ -70,6 +72,7 @@ type FieldVar = | [FieldType.Scale, FieldConst, FieldVar]; type ConstantFieldVar = [FieldType.Constant, FieldConst]; +type VarFieldVar = [FieldType.Var, number]; const FieldVar = { constant(x: bigint | FieldConst): ConstantFieldVar { @@ -79,6 +82,9 @@ const FieldVar = { isConstant(x: FieldVar): x is ConstantFieldVar { return x[0] === FieldType.Constant; }, + isVar(x: FieldVar): x is VarFieldVar { + return x[0] === FieldType.Var; + }, add(x: FieldVar, y: FieldVar): FieldVar { if (FieldVar.isConstant(x) && x[1][1] === 0n) return y; if (FieldVar.isConstant(y) && y[1][1] === 0n) return x; @@ -102,6 +108,7 @@ const FieldVar = { }; type ConstantField = Field & { value: ConstantFieldVar }; +type VarField = Field & { value: VarFieldVar }; /** * A {@link Field} is an element of a prime order [finite field](https://en.wikipedia.org/wiki/Finite_field). @@ -1031,7 +1038,7 @@ class Field { seal() { if (this.isConstant()) return this; let x = Snarky.field.seal(this.value); - return new Field(x); + return VarField(x); } /** @@ -1357,3 +1364,7 @@ there is \`Provable.asProver(() => { ... })\` which allows you to use ${varName} Warning: whatever happens inside asProver() will not be part of the zk proof. `; } + +function VarField(x: VarFieldVar): VarField { + return new Field(x) as VarField; +} diff --git a/src/lib/gadgets/basic.ts b/src/lib/gadgets/basic.ts new file mode 100644 index 000000000..3fde1ed59 --- /dev/null +++ b/src/lib/gadgets/basic.ts @@ -0,0 +1,83 @@ +/** + * Basic gadgets that only use generic gates + */ +import type { Field, VarField } from '../field.js'; +import { existsOne, toVar } from './common.js'; +import { Gates } from '../gates.js'; +import { TupleN } from '../util/types.js'; + +export { assertOneOf }; + +// TODO: create constant versions of these and expose on Gadgets + +/** + * Assert that a value equals one of a finite list of constants: + * `(x - c1)*(x - c2)*...*(x - cn) === 0` + * + * TODO: what prevents us from getting the same efficiency with snarky DSL code? + */ +function assertOneOf(x: Field, allowed: [bigint, bigint, ...bigint[]]) { + let xv = toVar(x); + let [c1, c2, ...c] = allowed; + let n = c.length; + if (n === 0) { + // (x - c1)*(x - c2) === 0 + assertBilinear(xv, xv, [1n, -(c1 + c2), 0n, c1 * c2]); + return; + } + // z = (x - c1)*(x - c2) + let z = bilinear(xv, xv, [1n, -(c1 + c2), 0n, c1 * c2]); + + for (let i = 0; i < n; i++) { + if (i < n - 1) { + // z = z*(x - c) + z = bilinear(z, xv, [1n, -c[i], 0n, 0n]); + } else { + // z*(x - c) === 0 + assertBilinear(z, xv, [1n, -c[i], 0n, 0n]); + } + } +} + +// low-level helpers to create generic gates + +/** + * Compute bilinear function of x and y: + * `z = a*x*y + b*x + c*y + d` + */ +function bilinear(x: VarField, y: VarField, [a, b, c, d]: TupleN) { + let z = existsOne(() => { + let x0 = x.toBigInt(); + let y0 = y.toBigInt(); + return a * x0 * y0 + b * x0 + c * y0 + d; + }); + // b*x + c*y - z + a*x*y + d === 0 + Gates.generic( + { left: b, right: c, out: -1n, mul: a, const: d }, + { left: x, right: y, out: z } + ); + return z; +} + +/** + * Assert bilinear equation on x, y and z: + * `a*x*y + b*x + c*y + d === z` + * + * The default for z is 0. + */ +function assertBilinear( + x: VarField, + y: VarField, + [a, b, c, d]: TupleN, + z?: VarField +) { + // b*x + c*y - z? + a*x*y + d === 0 + Gates.generic( + { left: b, right: c, out: z === undefined ? 0n : -1n, mul: a, const: d }, + { left: x, right: y, out: z === undefined ? emptyCell() : z } + ); +} + +function emptyCell() { + return existsOne(() => 0n); +} diff --git a/src/lib/gadgets/common.ts b/src/lib/gadgets/common.ts index 196ca64e7..e6e6c873c 100644 --- a/src/lib/gadgets/common.ts +++ b/src/lib/gadgets/common.ts @@ -1,5 +1,5 @@ import { Provable } from '../provable.js'; -import { Field, FieldConst, FieldVar, FieldType } from '../field.js'; +import { Field, FieldConst, FieldVar, VarField } from '../field.js'; import { Tuple, TupleN } from '../util/types.js'; import { Snarky } from '../../snarky.js'; import { MlArray } from '../ml/base.js'; @@ -12,6 +12,7 @@ export { existsOne, toVars, toVar, + isVar, assert, bitSlice, witnessSlice, @@ -21,7 +22,7 @@ export { function existsOne(compute: () => bigint) { let varMl = Snarky.existsVar(() => FieldConst.fromBigint(compute())); - return new Field(varMl); + return VarField(varMl); } function exists TupleN>( @@ -31,7 +32,7 @@ function exists TupleN>( let varsMl = Snarky.exists(n, () => MlArray.mapTo(compute(), FieldConst.fromBigint) ); - let vars = MlArray.mapFrom(varsMl, (v) => new Field(v)); + let vars = MlArray.mapFrom(varsMl, VarField); return TupleN.fromArray(n, vars); } @@ -43,20 +44,24 @@ function exists TupleN>( * * Same as `Field.seal()` with the difference that `seal()` leaves constants as is. */ -function toVar(x: Field | bigint) { +function toVar(x: Field | bigint): VarField { // don't change existing vars - if (x instanceof Field && x.value[1] === FieldType.Var) return x; + if (isVar(x)) return x; let xVar = existsOne(() => Field.from(x).toBigInt()); xVar.assertEquals(x); return xVar; } +function isVar(x: Field | bigint): x is VarField { + return x instanceof Field && FieldVar.isVar(x.value); +} + /** * Apply {@link toVar} to each element of a tuple. */ function toVars>( fields: T -): { [k in keyof T]: Field } { +): { [k in keyof T]: VarField } { return Tuple.map(fields, toVar); } diff --git a/src/lib/gadgets/foreign-field.ts b/src/lib/gadgets/foreign-field.ts index 9ee2822e0..f5833a685 100644 --- a/src/lib/gadgets/foreign-field.ts +++ b/src/lib/gadgets/foreign-field.ts @@ -1,12 +1,17 @@ +/** + * Foreign field arithmetic gadgets. + */ import { inverse as modInverse, mod, } from '../../bindings/crypto/finite_field.js'; import { provableTuple } from '../../bindings/lib/provable-snarky.js'; +import { Unconstrained } from '../circuit_value.js'; import { Field } from '../field.js'; import { Gates, foreignFieldAdd } from '../gates.js'; import { Tuple, TupleN } from '../util/types.js'; -import { assert, bitSlice, exists, toVars } from './common.js'; +import { assertOneOf } from './basic.js'; +import { assert, bitSlice, exists, toVar, toVars } from './common.js'; import { l, lMask, @@ -17,7 +22,11 @@ import { compactMultiRangeCheck, } from './range-check.js'; -export { ForeignField, Field3, Sign }; +// external API +export { ForeignField, Field3 }; + +// internal API +export { bigint3, Sign, split, combine, weakBound, Sum, assertMul }; /** * A 3-tuple of Fields, representing a 3-limb bigint. @@ -37,10 +46,14 @@ const ForeignField = { return sum([Field3.from(0n), x], [-1n], f); }, sum, + Sum(x: Field3) { + return new Sum(x); + }, mul: multiply, inv: inverse, div: divide, + assertMul, assertAlmostFieldElements, @@ -55,7 +68,7 @@ const ForeignField = { // provable case // we can just use negation (f - 1) - x! because the result is range-checked, it proves that x < f: // `f - 1 - x \in [0, 2^3l) => x <= x + (f - 1 - x) = f - 1 < f` - // (note: ffadd can't add higher multiples of (f - 1). it must always use an overflow of -1, except for x = 0 or 1) + // (note: ffadd can't add higher multiples of (f - 1). it must always use an overflow of -1, except for x = 0) ForeignField.negate(x, f - 1n); }, }; @@ -69,7 +82,7 @@ function sum(x: Field3[], sign: Sign[], f: bigint) { assert(x.length === sign.length + 1, 'inputs and operators match'); // constant case - if (x.every((x) => x.every((x) => x.isConstant()))) { + if (x.every(Field3.isConstant)) { let xBig = x.map(Field3.toBigint); let sum = sign.reduce((sum, s, i) => sum + s * xBig[i + 1], xBig[0]); return Field3.from(mod(sum, f)); @@ -131,7 +144,7 @@ function multiply(a: Field3, b: Field3, f: bigint): Field3 { assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits'); // constant case - if (a.every((x) => x.isConstant()) && b.every((x) => x.isConstant())) { + if (Field3.isConstant(a) && Field3.isConstant(b)) { let ab = Field3.toBigint(a) * Field3.toBigint(b); return Field3.from(mod(ab, f)); } @@ -149,7 +162,7 @@ function inverse(x: Field3, f: bigint): Field3 { assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits'); // constant case - if (x.every((x) => x.isConstant())) { + if (Field3.isConstant(x)) { let xInv = modInverse(Field3.toBigint(x), f); assert(xInv !== undefined, 'inverse exists'); return Field3.from(xInv); @@ -165,7 +178,7 @@ function inverse(x: Field3, f: bigint): Field3 { let xInv2Bound = weakBound(xInv[2], f); let one: Field2 = [Field.from(1n), Field.from(0n)]; - assertMul(x, xInv, one, f); + assertMulInternal(x, xInv, one, f); // range check on result bound // TODO: this uses two RCs too many.. need global RC stack @@ -183,7 +196,7 @@ function divide( assert(f < 1n << 259n, 'Foreign modulus fits in 259 bits'); // constant case - if (x.every((x) => x.isConstant()) && y.every((x) => x.isConstant())) { + if (Field3.isConstant(x) && Field3.isConstant(y)) { let yInv = modInverse(Field3.toBigint(y), f); assert(yInv !== undefined, 'inverse exists'); return Field3.from(mod(Field3.toBigint(x) * yInv, f)); @@ -198,7 +211,7 @@ function divide( }); multiRangeCheck(z); let z2Bound = weakBound(z[2], f); - assertMul(z, y, x, f); + assertMulInternal(z, y, x, f); // range check on result bound multiRangeCheck([z2Bound, Field.from(0n), Field.from(0n)]); @@ -220,7 +233,12 @@ function divide( /** * Common logic for gadgets that expect a certain multiplication result a priori, instead of just using the remainder. */ -function assertMul(x: Field3, y: Field3, xy: Field3 | Field2, f: bigint) { +function assertMulInternal( + x: Field3, + y: Field3, + xy: Field3 | Field2, + f: bigint +) { let { r01, r2, q } = multiplyNoRangeCheck(x, y, f); // range check on quotient @@ -391,6 +409,13 @@ const Field3 = { return combine(toBigint3(x)); }, + /** + * Turn several 3-tuples of Fields into bigints + */ + toBigints>(...xs: T) { + return Tuple.map(xs, Field3.toBigint); + }, + /** * Check whether a 3-tuple of Fields is constant */ @@ -431,3 +456,219 @@ function combine2([x0, x1]: bigint3 | [bigint, bigint]) { function split2(x: bigint): [bigint, bigint] { return [x & lMask, (x >> l) & lMask]; } + +/** + * Optimized multiplication of sums, like (x + y)*z = a + b + c + * + * We use several optimizations over naive summing and then multiplying: + * + * - we skip the range check on the remainder sum, because ffmul is sound with r being a sum of range-checked values + * - we replace the range check on the input sums with an extra low limb sum using generic gates + * - we chain the first input's sum into the ffmul gate + * + * As usual, all values are assumed to be range checked, and the left and right multiplication inputs + * are assumed to be bounded such that `l * r < 2^264 * (native modulus)`. + * However, all extra checks that are needed on the _sums_ are handled here. + */ +function assertMul( + x: Field3 | Sum, + y: Field3 | Sum, + xy: Field3 | Sum, + f: bigint +) { + x = Sum.fromUnfinished(x); + y = Sum.fromUnfinished(y); + xy = Sum.fromUnfinished(xy); + + // conservative estimate to ensure that multiplication bound is satisfied + // we assume that all summands si are bounded with si[2] <= f[2] checks, which implies si < 2^k where k := ceil(log(f)) + // our assertion below gives us + // |x|*|y| + q*f + |r| < (x.length * y.length) 2^2k + 2^2k + 2^2k < 3 * 2^(2*258) < 2^264 * (native modulus) + assert( + BigInt(Math.ceil(Math.sqrt(x.length * y.length))) * f < 1n << 258n, + `Foreign modulus is too large for multiplication of sums of lengths ${x.length} and ${y.length}` + ); + + // finish the y and xy sums with a zero gate + let y0 = y.finishForMulInput(f); + let xy0 = xy.finish(f); + + // x is chained into the ffmul gate + let x0 = x.finishForMulInput(f, true); + + // constant case + if ( + Field3.isConstant(x0) && + Field3.isConstant(y0) && + Field3.isConstant(xy0) + ) { + let x_ = Field3.toBigint(x0); + let y_ = Field3.toBigint(y0); + let xy_ = Field3.toBigint(xy0); + assert(mod(x_ * y_, f) === xy_, 'incorrect multiplication result'); + return; + } + + assertMulInternal(x0, y0, xy0, f); +} + +class Sum { + #result?: Field3; + #summands: Field3[]; + #ops: Sign[] = []; + + constructor(x: Field3) { + this.#summands = [x]; + } + + get result() { + assert(this.#result !== undefined, 'sum not finished'); + return this.#result; + } + + get length() { + return this.#summands.length; + } + + add(y: Field3) { + assert(this.#result === undefined, 'sum already finished'); + this.#ops.push(1n); + this.#summands.push(y); + return this; + } + + sub(y: Field3) { + assert(this.#result === undefined, 'sum already finished'); + this.#ops.push(-1n); + this.#summands.push(y); + return this; + } + + #return(x: Field3) { + this.#result = x; + return x; + } + + isConstant() { + return this.#summands.every(Field3.isConstant); + } + + finish(f: bigint, isChained = false) { + assert(this.#result === undefined, 'sum already finished'); + let signs = this.#ops; + let n = signs.length; + if (n === 0) return this.#return(this.#summands[0]); + + // constant case + if (this.isConstant()) { + return this.#return(sum(this.#summands, signs, f)); + } + + // provable case + let x = this.#summands.map(toVars); + let result = x[0]; + + for (let i = 0; i < n; i++) { + ({ result } = singleAdd(result, x[i + 1], signs[i], f)); + } + if (!isChained) Gates.zero(...result); + + this.#result = result; + return result; + } + + // TODO this is complex and should be removed once we fix the ffadd gate to constrain all limbs individually + finishForMulInput(f: bigint, isChained = false) { + assert(this.#result === undefined, 'sum already finished'); + let signs = this.#ops; + let n = signs.length; + if (n === 0) return this.#return(this.#summands[0]); + + // constant case + if (this.isConstant()) { + return this.#return(sum(this.#summands, signs, f)); + } + + // provable case + let xs = this.#summands.map(toVars); + + // since the sum becomes a multiplication input, we need to constrain all limbs _individually_. + // sadly, ffadd only constrains the low and middle limb together. + // we could fix it with a RC just for the lower two limbs + // but it's cheaper to add generic gates which handle the lowest limb separately, and avoids the unfilled MRC slot + let f0 = f & lMask; + + // generic gates for low limbs + let x0 = xs[0][0]; + let x0s: Field[] = []; + let overflows: Field[] = []; + let xRef = Unconstrained.witness(() => Field3.toBigint(xs[0])); + + // this loop mirrors the computation that a chain of ffadd gates does, + // but everything is done only on the lowest limb and using generic gates. + // the output is a sequence of low limbs (x0) and overflows, which will be wired to the ffadd results at each step. + for (let i = 0; i < n; i++) { + // compute carry and overflow + let [carry, overflow] = exists(2, () => { + // this duplicates some of the logic in singleAdd + let x = xRef.get(); + let x0 = x & lMask; + let xi = toBigint3(xs[i + 1]); + let sign = signs[i]; + + // figure out if there's overflow + x += sign * combine(xi); + let overflow = 0n; + if (sign === 1n && x >= f) overflow = 1n; + if (sign === -1n && x < 0n) overflow = -1n; + if (f === 0n) overflow = 0n; + xRef.set(x - overflow * f); + + // add with carry, only on the lowest limb + x0 = x0 + sign * xi[0] - overflow * f0; + let carry = x0 >> l; + return [carry, overflow]; + }); + overflows.push(overflow); + + // constrain carry + assertOneOf(carry, [0n, 1n, -1n]); + + // x0 <- x0 + s*xi0 - o*f0 - c*2^l + x0 = toVar( + x0 + .add(xs[i + 1][0].mul(signs[i])) + .sub(overflow.mul(f0)) + .sub(carry.mul(1n << l)) + ); + x0s.push(x0); + } + + // ffadd chain + let x = xs[0]; + for (let i = 0; i < n; i++) { + let { result, overflow } = singleAdd(x, xs[i + 1], signs[i], f); + // wire low limb and overflow to previous values + result[0].assertEquals(x0s[i]); + overflow.assertEquals(overflows[i]); + x = result; + } + if (!isChained) Gates.zero(...x); + + this.#result = x; + return x; + } + + rangeCheck() { + assert(this.#result !== undefined, 'sum not finished'); + if (this.#ops.length > 0) multiRangeCheck(this.#result); + } + + static fromUnfinished(x: Field3 | Sum) { + if (x instanceof Sum) { + assert(x.#result === undefined, 'sum already finished'); + return x; + } + return new Sum(x); + } +} diff --git a/src/lib/gadgets/foreign-field.unit-test.ts b/src/lib/gadgets/foreign-field.unit-test.ts index 4cb0d2d97..a3af9587c 100644 --- a/src/lib/gadgets/foreign-field.unit-test.ts +++ b/src/lib/gadgets/foreign-field.unit-test.ts @@ -7,6 +7,7 @@ import { equivalentProvable, fromRandom, record, + unit, } from '../testing/equivalent.js'; import { Random } from '../testing/random.js'; import { Gadgets } from './gadgets.js'; @@ -19,10 +20,12 @@ import { contains, equals, ifNotAllConstant, + not, repeat, withoutGenerics, } from '../testing/constraint-system.js'; import { GateType } from '../../snarky.js'; +import { AnyTuple } from '../util/types.js'; const { ForeignField, Field3 } = Gadgets; @@ -85,6 +88,24 @@ for (let F of fields) { (x, y) => ForeignField.div(x, y, F.modulus), 'div' ); + equivalentProvable({ from: [f, f], to: unit })( + (x, y) => assertMulExampleNaive(Field3.from(x), Field3.from(y), F.modulus), + (x, y) => assertMulExample(x, y, F.modulus), + 'assertMul' + ); + // test for assertMul which mostly tests the negative case because for random inputs, we expect + // (x - y) * z != a + b + equivalentProvable({ from: [f, f, f, f, f], to: unit })( + (x, y, z, a, b) => assert(F.mul(F.sub(x, y), z) === F.add(a, b)), + (x, y, z, a, b) => + ForeignField.assertMul( + ForeignField.Sum(x).sub(y), + z, + ForeignField.Sum(a).add(b), + F.modulus + ), + 'assertMul negative' + ); // tests with inputs that aren't reduced mod f let big264 = unreducedForeignField(264, F); // this is the max size supported by our range checks / ffadd @@ -185,7 +206,9 @@ let ffProgram = ZkProgram({ // tests for constraint system -let addChain = repeat(chainLength - 1, 'ForeignFieldAdd').concat('Zero'); +function addChain(length: number) { + return repeat(length - 1, 'ForeignFieldAdd').concat('Zero'); +} let mrc: GateType[] = ['RangeCheck0', 'RangeCheck0', 'RangeCheck1', 'Zero']; constraintSystem.fromZkProgram( @@ -193,8 +216,8 @@ constraintSystem.fromZkProgram( 'sumchain', ifNotAllConstant( and( - contains([addChain, mrc]), - withoutGenerics(equals([...addChain, ...mrc])) + contains([addChain(chainLength), mrc]), + withoutGenerics(equals([...addChain(chainLength), ...mrc])) ) ) ); @@ -219,9 +242,11 @@ constraintSystem.fromZkProgram(ffProgram, 'div', invLayout); // tests with proving +const runs = 2; + await ffProgram.compile(); -await equivalentAsync({ from: [array(f, chainLength)], to: f }, { runs: 3 })( +await equivalentAsync({ from: [array(f, chainLength)], to: f }, { runs })( (xs) => sum(xs, signs, F), async (xs) => { let proof = await ffProgram.sumchain(xs); @@ -231,7 +256,7 @@ await equivalentAsync({ from: [array(f, chainLength)], to: f }, { runs: 3 })( 'prove chain' ); -await equivalentAsync({ from: [f, f], to: f }, { runs: 3 })( +await equivalentAsync({ from: [f, f], to: f }, { runs })( F.mul, async (x, y) => { let proof = await ffProgram.mul(x, y); @@ -241,7 +266,7 @@ await equivalentAsync({ from: [f, f], to: f }, { runs: 3 })( 'prove mul' ); -await equivalentAsync({ from: [f, f], to: f }, { runs: 3 })( +await equivalentAsync({ from: [f, f], to: f }, { runs })( (x, y) => F.div(x, y) ?? throwError('no inverse'), async (x, y) => { let proof = await ffProgram.div(x, y); @@ -251,8 +276,76 @@ await equivalentAsync({ from: [f, f], to: f }, { runs: 3 })( 'prove div' ); +// assert mul example +// (x - y) * (x + y) = x^2 - y^2 + +function assertMulExample(x: Gadgets.Field3, y: Gadgets.Field3, f: bigint) { + // witness x^2, y^2 + let x2 = Provable.witness(Field3.provable, () => ForeignField.mul(x, x, f)); + let y2 = Provable.witness(Field3.provable, () => ForeignField.mul(y, y, f)); + + // assert (x - y) * (x + y) = x^2 - y^2 + let xMinusY = ForeignField.Sum(x).sub(y); + let xPlusY = ForeignField.Sum(x).add(y); + let x2MinusY2 = ForeignField.Sum(x2).sub(y2); + ForeignField.assertMul(xMinusY, xPlusY, x2MinusY2, f); +} + +function assertMulExampleNaive( + x: Gadgets.Field3, + y: Gadgets.Field3, + f: bigint +) { + // witness x^2, y^2 + let x2 = Provable.witness(Field3.provable, () => ForeignField.mul(x, x, f)); + let y2 = Provable.witness(Field3.provable, () => ForeignField.mul(y, y, f)); + + // assert (x - y) * (x + y) = x^2 - y^2 + let lhs = ForeignField.mul( + ForeignField.sub(x, y, f), + ForeignField.add(x, y, f), + f + ); + let rhs = ForeignField.sub(x2, y2, f); + Provable.assertEqual(Field3.provable, lhs, rhs); +} + +let from2 = { from: [f, f] satisfies AnyTuple }; +let gates = constraintSystem.size(from2, (x, y) => + assertMulExample(x, y, F.modulus) +); +let gatesNaive = constraintSystem.size(from2, (x, y) => + assertMulExampleNaive(x, y, F.modulus) +); +// the assertMul() version should save 11.5 rows: +// -2*1.5 rows by replacing input MRCs with low-limb ffadd +// -2*4 rows for avoiding the MRC on both mul() and sub() outputs +// -1 row for chaining one ffadd into ffmul +// +0.5 rows for having to combine the two lower result limbs before wiring to ffmul remainder +assert(gates + 11 <= gatesNaive, 'assertMul() saves at least 11 constraints'); + +let addChainedIntoMul: GateType[] = ['ForeignFieldAdd', ...mulChain]; + +constraintSystem( + 'assert mul', + from2, + (x, y) => assertMulExample(x, y, F.modulus), + and( + contains([addChain(1), addChain(1), addChainedIntoMul]), + // assertMul() doesn't use any range checks besides on internal values and the quotient + containsNTimes(2, mrc) + ) +); + // helper +function containsNTimes(n: number, pattern: readonly GateType[]) { + return and( + contains(repeat(n, pattern)), + not(contains(repeat(n + 1, pattern))) + ); +} + function sum(xs: bigint[], signs: (1n | -1n)[], F: FiniteField) { let sum = xs[0]; for (let i = 0; i < signs.length; i++) { diff --git a/src/lib/gadgets/gadgets.ts b/src/lib/gadgets/gadgets.ts index e4b1ba577..b8a31cb0a 100644 --- a/src/lib/gadgets/gadgets.ts +++ b/src/lib/gadgets/gadgets.ts @@ -9,7 +9,7 @@ import { } from './range-check.js'; import { not, rotate, xor, and, leftShift, rightShift } from './bitwise.js'; import { Field } from '../field.js'; -import { ForeignField, Field3 } from './foreign-field.js'; +import { ForeignField, Field3, Sum } from './foreign-field.js'; export { Gadgets }; @@ -497,6 +497,45 @@ const Gadgets = { return ForeignField.div(x, y, f); }, + /** + * Optimized multiplication of sums in a foreign field, for example: `(x - y)*z = a + b + c mod f` + * + * Note: This is much more efficient than using {@link ForeignField.add} and {@link ForeignField.sub} separately to + * compute the multiplication inputs and outputs, and then using {@link ForeignField.mul} to constrain the result. + * + * The sums passed into this method are "lazy sums" created with {@link ForeignField.Sum}. + * You can also pass in plain {@link Field3} elements. + * + * **Assumptions**: The assumptions on the _summands_ are analogous to the assumptions described in {@link ForeignField.mul}: + * - each summand's limbs are in the range [0, 2^88) + * - summands that are part of a multiplication input satisfy `x[2] <= f[2]` + * + * @throws if the modulus is so large that the second assumption no longer suffices for validity of the multiplication. + * For small sums and moduli < 2^256, this will not fail. + * + * @throws if the provided multiplication result is not correct modulo f. + * + * @example + * ```ts + * // we assume that x, y, z, a, b, c are range-checked, analogous to `ForeignField.mul()` + * let xMinusY = ForeignField.Sum(x).sub(y); + * let aPlusBPlusC = ForeignField.Sum(a).add(b).add(c); + * + * // assert that (x - y)*z = a + b + c mod f + * ForeignField.assertMul(xMinusY, z, aPlusBPlusC, f); + * ``` + */ + assertMul(x: Field3 | Sum, y: Field3 | Sum, z: Field3 | Sum, f: bigint) { + return ForeignField.assertMul(x, y, z, f); + }, + + /** + * Lazy sum of {@link Field3} elements, which can be used as input to {@link ForeignField.assertMul}. + */ + Sum(x: Field3) { + return ForeignField.Sum(x); + }, + /** * Prove that each of the given {@link Field3} elements is "almost" reduced modulo f, * i.e., satisfies the assumptions required by {@link ForeignField.mul} and other gadgets: @@ -572,4 +611,12 @@ export namespace Gadgets { * A 3-tuple of Fields, representing a 3-limb bigint. */ export type Field3 = [Field, Field, Field]; + + export namespace ForeignField { + /** + * Lazy sum of {@link Field3} elements, which can be used as input to {@link ForeignField.assertMul}. + */ + export type Sum = Sum_; + } } +type Sum_ = Sum; diff --git a/src/lib/testing/constraint-system.ts b/src/lib/testing/constraint-system.ts index 90f3ec159..d39c6f699 100644 --- a/src/lib/testing/constraint-system.ts +++ b/src/lib/testing/constraint-system.ts @@ -18,6 +18,7 @@ export { not, and, or, + satisfies, equals, contains, allConstant, @@ -26,6 +27,7 @@ export { withoutGenerics, print, repeat, + printGates, ConstraintSystemTest, }; @@ -181,6 +183,16 @@ function or(...tests: ConstraintSystemTest[]): ConstraintSystemTest { return { kind: 'or', tests, label: `or(${tests.map((t) => t.label)})` }; } +/** + * General test + */ +function satisfies( + label: string, + run: (cs: Gate[], inputs: TypeAndValue[]) => boolean +): ConstraintSystemTest { + return { run, label }; +} + /** * Test for precise equality of the constraint system with a given list of gates. */ @@ -262,14 +274,12 @@ function ifNotAllConstant(test: ConstraintSystemTest): ConstraintSystemTest { } /** - * Test whether all inputs are constant. + * Test whether constraint system is empty. */ -const isEmpty: ConstraintSystemTest = { - run(cs) { - return cs.length === 0; - }, - label: 'cs is empty', -}; +const isEmpty = satisfies( + 'constraint system is empty', + (cs) => cs.length === 0 +); /** * Modifies a test so that it runs on the constraint system with generic gates filtered out. @@ -299,9 +309,50 @@ const print: ConstraintSystemTest = { label: '', }; -function repeat(n: number, gates: GateType | GateType[]): readonly GateType[] { +// Do other useful things with constraint systems + +/** + * Get constraint system as a list of gates. + */ +constraintSystem.gates = function gates>>( + inputs: { from: Input }, + main: (...args: CsParams) => void +) { + let types = inputs.from.map(provable); + let { gates } = Provable.constraintSystem(() => { + let values = types.map((type) => + Provable.witness(type, (): unknown => { + throw Error('not needed'); + }) + ) as CsParams; + main(...values); + }); + return gates; +}; + +function map(transform: (gates: Gate[]) => T) { + return >>( + inputs: { from: Input }, + main: (...args: CsParams) => void + ) => transform(constraintSystem.gates(inputs, main)); +} + +/** + * Get size of constraint system. + */ +constraintSystem.size = map((gates) => gates.length); + +/** + * Print constraint system. + */ +constraintSystem.print = map(printGates); + +function repeat( + n: number, + gates: GateType | readonly GateType[] +): readonly GateType[] { gates = Array.isArray(gates) ? gates : [gates]; - return Array(n).fill(gates).flat(); + return Array(n).fill(gates).flat(); } function toGatess( @@ -444,9 +495,7 @@ function wiresToPretty(wires: Gate['wires'], row: number) { if (wire.row === row) { strWires.push(`${col}->${wire.col}`); } else { - let rowDelta = wire.row - row; - let rowStr = rowDelta > 0 ? `+${rowDelta}` : `${rowDelta}`; - strWires.push(`${col}->(${rowStr},${wire.col})`); + strWires.push(`${col}->(${wire.row},${wire.col})`); } } return strWires.join(', '); diff --git a/src/lib/util/types.ts b/src/lib/util/types.ts index 79cc38d08..7096aa4e7 100644 --- a/src/lib/util/types.ts +++ b/src/lib/util/types.ts @@ -1,8 +1,9 @@ import { assert } from '../errors.js'; -export { Tuple, TupleN, TupleMap }; +export { Tuple, TupleN, AnyTuple, TupleMap }; type Tuple = [T, ...T[]] | []; +type AnyTuple = Tuple; type TupleMap, B> = [ ...{ diff --git a/src/snarky.d.ts b/src/snarky.d.ts index 69fe3bcb3..343d714c4 100644 --- a/src/snarky.d.ts +++ b/src/snarky.d.ts @@ -1,5 +1,5 @@ import type { Account as JsonAccount } from './bindings/mina-transaction/gen/transaction-json.js'; -import type { Field, FieldConst, FieldVar } from './lib/field.js'; +import type { Field, FieldConst, FieldVar, VarFieldVar } from './lib/field.js'; import type { BoolVar, Bool } from './lib/bool.js'; import type { ScalarConst } from './lib/scalar.js'; import type { @@ -181,11 +181,11 @@ declare const Snarky: { exists( sizeInFields: number, compute: () => MlArray - ): MlArray; + ): MlArray; /** * witness a single field element variable */ - existsVar(compute: () => FieldConst): FieldVar; + existsVar(compute: () => FieldConst): VarFieldVar; /** * APIs that have to do with running provable code @@ -281,7 +281,7 @@ declare const Snarky: { * returns a new witness from an AST * (implemented with toConstantAndTerms) */ - seal(x: FieldVar): FieldVar; + seal(x: FieldVar): VarFieldVar; /** * Unfolds AST to get `x = c + c0*Var(i0) + ... + cn*Var(in)`, * returns `(c, [(c0, i0), ..., (cn, in)])`;