Skip to content

Commit

Permalink
implement remaining Field API
Browse files Browse the repository at this point in the history
  • Loading branch information
mitschabaude committed May 15, 2023
1 parent 4f2dd19 commit 455f2ca
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 44 deletions.
2 changes: 1 addition & 1 deletion src/bindings
133 changes: 93 additions & 40 deletions src/lib/field.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ type ConstantFieldVar = [FieldType.Constant, FieldConst];

const FieldVar = {
constant(x: bigint | FieldConst): [FieldType.Constant, FieldConst] {
if (typeof x === 'bigint') return [0, constFromBigint(x)];
if (typeof x === 'bigint') return [0, FieldConst.fromBigint(x)];
return [FieldType.Constant, x];
},
// TODO: handle (special) constants
Expand Down Expand Up @@ -111,7 +111,7 @@ const Field = toFunctionConstructor(
if (this.isConstant()) return this;
// TODO: fix OCaml error message, `Can't evaluate prover code outside an as_prover block`
let value = Snarky.field.readVar(this.value);
return Field.#fromConst(value);
return new Field(FieldVar.constant(value)) as Field & ConstantFieldRaw;
}

toBigInt() {
Expand Down Expand Up @@ -225,48 +225,74 @@ const Field = toFunctionConstructor(
return new Field(z);
}

equals(y: Field | bigint | number | string): Bool {
if (this.isConstant() && isConstant(y)) {
return Bool(this.toBigInt() === toFp(y));
isZero() {
if (this.isConstant()) {
return Bool(this.toBigInt() === 0n);
}
// create x - y
let xMinusY = this.sub(y);
// create witnesses z = -1/(x - y), or z=0 if x=y,
// and b = 1 + z(x - y)
// create witnesses z = -1/x, or z=0 if x=0,
// and b = 1 + zx
let [, b, z] = Snarky.exists(2, () => {
let delta = xMinusY.toBigInt();
let z = Fp.negate(Fp.inverse(delta) ?? 0n);
let b = Fp.add(1n, Fp.mul(z, delta));
let x = this.toBigInt();
let z = Fp.negate(Fp.inverse(x) ?? 0n);
let b = Fp.add(1n, Fp.mul(z, x));
return [0, FieldConst.fromBigint(b), FieldConst.fromBigint(z)];
});
// add constraints
// z * (x - y) === b - 1
Snarky.field.assertMul(z, xMinusY.value, FieldVar.add(b, FieldVar[-1]));
// b * (x - y) === 0
Snarky.field.assertMul(b, xMinusY.value, FieldVar[0]);
// ^^^ these prove that b = Bool(x === y):
// if x = y, the 1st equation implies b = 1
// if x != y, the 2nd implies b = 0
// z * x === b - 1
Snarky.field.assertMul(z, this.value, FieldVar.add(b, FieldVar[-1]));
// b * x === 0
Snarky.field.assertMul(b, this.value, FieldVar[0]);
// ^^^ these prove that b = Bool(x === 0):
// if x = 0, the 1st equation implies b = 1
// if x != 0, the 2nd implies b = 0
return Bool.Unsafe.ofField(new Field(b));
}

equals(y: Field | bigint | number | string): Bool {
// x == y is equivalent to x - y == 0
// if one of the two is constant, we just need the two constraints in `isZero`
if (this.isConstant() || isConstant(y)) {
return this.sub(y).isZero();
}
// if both are variables, we create one new variable for x-y so that `isZero` doesn't create two
let xMinusY = Snarky.existsVar(() =>
FieldConst.fromBigint(Fp.sub(this.toBigInt(), toFp(y)))
);
Snarky.field.assertEqual(this.sub(y).value, xMinusY);
return new Field(xMinusY).isZero();
}

// internal base method for all comparisons
#compare(y: FieldVar) {
// TODO: support all bit lengths
let length = Fp.sizeInBits - 2;
let [, less, lessOrEqual] = Snarky.field.compare(length, this.value, y);
return {
less: Bool.Unsafe.ofField(new Field(less)),
lessOrEqual: Bool.Unsafe.ofField(new Field(lessOrEqual)),
};
}

lessThan(y: Field | bigint | number | string): Bool {
if (this.isConstant() && isConstant(y)) {
return Bool(this.toBigInt() < toFp(y));
}
return SnarkyField(this).lessThan(y);
return this.#compare(Field.#toVar(y)).less;
}

lessThanOrEqual(y: Field | bigint | number | string): Bool {
if (this.isConstant() && isConstant(y)) {
return Bool(this.toBigInt() <= toFp(y));
}
return SnarkyField(this).lessThanOrEqual(y);
return this.#compare(Field.#toVar(y)).lessOrEqual;
}

greaterThan(y: Field | bigint | number | string) {
return new Field(y).lessThan(this);
return Field.from(y).lessThan(this);
}

greaterThanOrEqual(y: Field | bigint | number | string) {
return new Field(y).lessThanOrEqual(this);
return Field.from(y).lessThanOrEqual(this);
}

assertLessThan(y: Field | bigint | number | string, message?: string) {
Expand All @@ -277,11 +303,13 @@ const Field = toFunctionConstructor(
}
return;
}
SnarkyField(this).assertLessThan(y);
let { less } = this.#compare(Field.#toVar(y));
less.assertTrue();
} catch (err) {
throw withMessage(err, message);
}
}

assertLessThanOrEqual(
y: Field | bigint | number | string,
message?: string
Expand All @@ -293,27 +321,39 @@ const Field = toFunctionConstructor(
}
return;
}
SnarkyField(this).assertLessThanOrEqual(y);
let { lessOrEqual } = this.#compare(Field.#toVar(y));
lessOrEqual.assertTrue();
} catch (err) {
throw withMessage(err, message);
}
}

assertGreaterThan(y: Field | bigint | number | string, message?: string) {
new Field(y).assertLessThan(this, message);
Field.from(y).assertLessThan(this, message);
}

assertGreaterThanOrEqual(
y: Field | bigint | number | string,
message?: string
) {
new Field(y).assertLessThanOrEqual(this, message);
Field.from(y).assertLessThanOrEqual(this, message);
}

isZero() {
if (this.isConstant()) {
return Bool(this.toBigInt() === 0n);
assertNonZero(message?: string) {
try {
if (this.isConstant()) {
if (!(this.toBigInt() !== 0n)) {
throw Error(`Field.assertNonZero(): expected 0, got ${this}`);
}
return;
}
// proving the inverse also proves that the field element is non-zero
this.inv();
} catch (err) {
throw withMessage(err, message);
}
return SnarkyField(this).isZero();
}

assertBool(message?: string) {
try {
if (this.isConstant()) {
Expand All @@ -323,7 +363,8 @@ const Field = toFunctionConstructor(
}
return;
}
SnarkyField(this).assertBool();
// x^2 = x <--> x(1 - x) = 0 <--> x is 0 or 1
Snarky.field.assertMul(this.value, this.value, this.value);
} catch (err) {
throw withMessage(err, message);
}
Expand Down Expand Up @@ -351,7 +392,11 @@ const Field = toFunctionConstructor(
}
return bits.map(Bool);
}
return SnarkyField(this).toBits();
let [, ...bits] = Snarky.field.toBits(
length ?? Fp.sizeInBits,
this.value
);
return bits.map((b) => Bool.Unsafe.ofField(new Field(b)));
}

static fromBits(bits: (Bool | boolean)[]) {
Expand All @@ -365,23 +410,31 @@ const Field = toFunctionConstructor(
.concat(Array(Fp.sizeInBits - length).fill(false));
return new Field(Fp.fromBits(bits_));
}
return SnarkyField.fromBits(bits);
let bitsVars = bits.map((b): FieldVar => {
if (typeof b === 'boolean') return b ? FieldVar[1] : FieldVar[0];
return b.toField().value;
});
let x = Snarky.field.fromBits([0, ...bitsVars]);
return new Field(x);
}

// TODO rename
rangeCheckHelper(numBits: number) {
Field.#checkBitLength('Field.rangeCheckHelper()', numBits);
rangeCheckHelper(length: number) {
Field.#checkBitLength('Field.rangeCheckHelper()', length);
if (this.isConstant()) {
let bits = Fp.toBits(this.toBigInt())
.slice(0, numBits)
.concat(Array(Fp.sizeInBits - numBits).fill(false));
.slice(0, length)
.concat(Array(Fp.sizeInBits - length).fill(false));
return new Field(Fp.fromBits(bits));
}
return SnarkyField(this).rangeCheckHelper(numBits);
let x = Snarky.field.truncateToBits(length, this.value);
return new Field(x);
}

seal() {
if (this.isConstant()) return this;
return SnarkyField(this).seal();
let x = Snarky.field.seal(this.value);
return new Field(x);
}

static random() {
Expand Down
9 changes: 6 additions & 3 deletions src/snarky.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ declare namespace Snarky {
type VerificationKey = unknown;
type Proof = unknown;
}
// same representation, but use a different name to communicate intent / constraints
type BoolVar = FieldVar;

/**
* Internal interface to snarky-ml
Expand Down Expand Up @@ -127,17 +129,18 @@ declare const Snarky: {
* check x < y and x <= y
*/
compare(
bitLength: number,
x: FieldVar,
y: FieldVar
): [flag: 0, less: FieldVar, lessOrEqual: FieldVar];
): [flag: 0, less: BoolVar, lessOrEqual: BoolVar];
/**
*
*/
toBits(length: number, x: FieldVar): MlList<FieldVar>;
toBits(length: number, x: FieldVar): MlArray<BoolVar>;
/**
*
*/
fromBits(bits: MlList<FieldVar>): FieldVar;
fromBits(bits: MlArray<BoolVar>): FieldVar;
/**
* returns x truncated to the lowest `length` bits
* => can be used to assert that x fits in `length` bits.
Expand Down

0 comments on commit 455f2ca

Please sign in to comment.