Skip to content

Commit

Permalink
feat(hints): add NewHint#47 (lambdaclass#1030)
Browse files Browse the repository at this point in the history
* Refactor some cairo_programs as library

- uint384
- uint384_extension

I need this changes for later tests

* Refactor pack and split, macro hygiene

Pack and split were unified into a generic function in `uint_utils.rs`.
UintNNN specific functions were changed to `uNNN_pack` and `uNNN_split`, that internally use the generic ones.

Some macros were using items without specifying full path, and I fixed this.
With this change a lot of imports weren't needed anymore, and so were removed to appease clippy.

Also add the groundwork for the new hint to be implemented.

* Add NewHint#47

* Update changelog

* Add stdlib prelude import

* Import HashMap from stdlib

* Fix contracts

* Fix wasm-tests

* Change bitmask to `u128::MAX`

* Revert "Change bitmask to `u128::MAX`"

This reverts commit 7a1bea7.

* Compare memory_holes with correct amount

* Rename uint512 -> fq, fq -> fq_test; Add comment

* Fix contract compilation errors

* Fix compilation error

* Fix cairo_run_fq test

* Move changelog entry

* Fix failing tests

* Fix compilation errors

* Fix test error

Was using wrong pack function

* Define empty mains on libraries

* Use bare array instead of vec

Co-authored-by: fmoletta <99273364+fmoletta@users.noreply.github.com>

* Make pack and split methods instead of functions

* Fix merge errors

* Fix errors and add BigInt3::split86

* Make u512_pack and split functions into methods

* Re-add disappeared newline inside hint

* Fix compilation errors

---------

Co-authored-by: fmoletta <99273364+fmoletta@users.noreply.github.com>
  • Loading branch information
2 people authored and kariy committed Jun 23, 2023
1 parent f103bc4 commit 7db0eb3
Show file tree
Hide file tree
Showing 56 changed files with 1,052 additions and 548 deletions.
37 changes: 37 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,43 @@
ids.res.high = res_split[1]
```

* Add missing hint on vrf.json lib [#1030](https://github.com/lambdaclass/cairo-rs/pull/1030):

`BuiltinHintProcessor` now supports the following hint:

```python
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.low, z.high)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

def pack_extended(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2, z.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

x = pack_extended(ids.x, num_bits_shift = 128)
div = pack(ids.div, num_bits_shift = 128)

quotient, remainder = divmod(x, div)

quotient_split = split(quotient, num_bits_shift=128, length=4)

ids.quotient.d0 = quotient_split[0]
ids.quotient.d1 = quotient_split[1]
ids.quotient.d2 = quotient_split[2]
ids.quotient.d3 = quotient_split[3]

remainder_split = split(remainder, num_bits_shift=128, length=2)
ids.remainder.low = remainder_split[0]
ids.remainder.high = remainder_split[1]
```

* Add method `Program::data_len(&self) -> usize` to get the number of data cells in a given program [#1022](https://github.com/lambdaclass/cairo-rs/pull/1022)

* Add missing hint on uint256_improvements lib [#1013](https://github.com/lambdaclass/cairo-rs/pull/1013):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.bool import TRUE
from cairo_programs.uint384 import uint384_lib, Uint384, Uint384_expand
from cairo_programs.uint384_extension import uint384_extension_lib
from cairo_programs.uint384 import u384, Uint384, Uint384_expand
from cairo_programs.uint384_extension import u384_ext
from cairo_programs.field_arithmetic import field_arithmetic


func run_get_square{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(prime: Uint384, generator: Uint384, num: Uint384, iterations: felt) {
func run_get_square{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
prime: Uint384, generator: Uint384, num: Uint384, iterations: felt
) {
alloc_locals;
if (iterations == 0) {
return ();
}

let (square) = field_arithmetic.mul(num, num, prime);
let (square) = field_arithmetic.mul(num, num, prime);

let (success, root_1) = field_arithmetic.get_square_root(square, prime, generator);
assert success = 1;

// We calculate this before in order to prevent revoked range_check_ptr reference due to branching
let (root_2) = uint384_lib.sub(prime, root_1);
let (is_first_root) = uint384_lib.eq(root_1, num);
let (root_2) = u384.sub(prime, root_1);
let (is_first_root) = u384.eq(root_1, num);

if ( is_first_root != TRUE) {
if (is_first_root != TRUE) {
assert root_2 = num;
}

return run_get_square(prime, generator, square, iterations -1);
return run_get_square(prime, generator, square, iterations - 1);
}

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
let p = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let p = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let x = Uint384(5, 0, 0);
let g = Uint384(7, 0, 0);
run_get_square(p, g, x, 100);
Expand Down
46 changes: 24 additions & 22 deletions cairo_programs/field_arithmetic.cairo
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
%builtins range_check bitwise

// Code taken from https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/fbf532651959f27037d70cd70ec6dbaf987f535c/lib/field_arithmetic.cairo
from starkware.cairo.common.bitwise import bitwise_and, bitwise_or, bitwise_xor
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.math import assert_in_range, assert_le, assert_nn_le, assert_not_zero
from starkware.cairo.common.math_cmp import is_le
from starkware.cairo.common.pow import pow
from starkware.cairo.common.registers import get_ap, get_fp_and_pc
from cairo_programs.uint384 import uint384_lib, Uint384, Uint384_expand, SHIFT, HALF_SHIFT
from cairo_programs.uint384_extension import uint384_extension_lib, Uint768
from cairo_programs.uint384 import u384, Uint384, Uint384_expand, SHIFT, HALF_SHIFT
from cairo_programs.uint384_extension import u384_ext, Uint768

// Functions for operating elements in a finite field F_p (i.e. modulo a prime p), with p of at most 384 bits
namespace field_arithmetic {
// Computes a * b modulo p
func mul{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) {
let (low: Uint384, high: Uint384) = uint384_lib.mul_d(a, b);
let (low: Uint384, high: Uint384) = u384.mul_d(a, b);
let full_mul_result: Uint768 = Uint768(low.d0, low.d1, low.d2, high.d0, high.d1, high.d2);
let (
quotient: Uint768, remainder: Uint384
) = uint384_extension_lib.unsigned_div_rem_uint768_by_uint384(full_mul_result, p);
let (quotient: Uint768, remainder: Uint384) = u384_ext.unsigned_div_rem_uint768_by_uint384(
full_mul_result, p
);
return (remainder,);
}

// Computes a**2 modulo p
func square{range_check_ptr}(a: Uint384, p: Uint384) -> (res: Uint384) {
let (low: Uint384, high: Uint384) = uint384_lib.square_e(a);
let (low: Uint384, high: Uint384) = u384.square_e(a);
let full_mul_result: Uint768 = Uint768(low.d0, low.d1, low.d2, high.d0, high.d1, high.d2);
let (
quotient: Uint768, remainder: Uint384
) = uint384_extension_lib.unsigned_div_rem_uint768_by_uint384(full_mul_result, p);
let (quotient: Uint768, remainder: Uint384) = u384_ext.unsigned_div_rem_uint768_by_uint384(
full_mul_result, p
);
return (remainder,);
}

Expand All @@ -43,7 +45,7 @@ namespace field_arithmetic {
alloc_locals;

// TODO: Create an equality function within field_arithmetic to avoid overflow bugs
let (is_zero) = uint384_lib.eq(x, Uint384(0, 0, 0));
let (is_zero) = u384.eq(x, Uint384(0, 0, 0));
if (is_zero == 1) {
return (1, Uint384(0, 0, 0));
}
Expand Down Expand Up @@ -101,22 +103,22 @@ namespace field_arithmetic {
// Verify that the values computed in the hint are what they are supposed to be
let (gx: Uint384) = mul(generator, x, p);
if (success_x == 1) {
uint384_lib.check(sqrt_x);
let (is_valid) = uint384_lib.lt(sqrt_x, p);
assert is_valid = 1;
u384.check(sqrt_x);
let (is_valid) = u384.lt(sqrt_x, p);
assert is_valid = 1;
let (sqrt_x_squared: Uint384) = mul(sqrt_x, sqrt_x, p);
// Note these checks may fail if the input x does not satisfy 0<= x < p
// TODO: Create a equality function within field_arithmetic to avoid overflow bugs
let (check_x) = uint384_lib.eq(x, sqrt_x_squared);
let (check_x) = u384.eq(x, sqrt_x_squared);
assert check_x = 1;
return (1, sqrt_x);
} else {
// In this case success_gx = 1
uint384_lib.check(sqrt_gx);
let (is_valid) = uint384_lib.lt(sqrt_gx, p);
assert is_valid = 1;
u384.check(sqrt_gx);
let (is_valid) = u384.lt(sqrt_gx, p);
assert is_valid = 1;
let (sqrt_gx_squared: Uint384) = mul(sqrt_gx, sqrt_gx, p);
let (check_gx) = uint384_lib.eq(gx, sqrt_gx_squared);
let (check_gx) = u384.eq(gx, sqrt_gx_squared);
assert check_gx = 1;
// No square roots were found
// Note that Uint384(0, 0, 0) is not a square root here, but something needs to be returned
Expand Down Expand Up @@ -158,7 +160,7 @@ namespace field_arithmetic {
ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1]
ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]
%}
uint384_lib.check(b_inverse_mod_p);
u384.check(b_inverse_mod_p);
let (b_times_b_inverse) = mul(b, b_inverse_mod_p, p);
assert b_times_b_inverse = Uint384(1, 0, 0);

Expand All @@ -171,7 +173,7 @@ func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: B
alloc_locals;
// Test get_square

//Small prime
// Small prime
let p_a = Uint384(7, 0, 0);
let x_a = Uint384(2, 0, 0);
let generator_a = Uint384(3, 0, 0);
Expand All @@ -183,7 +185,7 @@ func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: B
assert r_a.d2 = 0;

// Goldilocks Prime
let p_b = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let p_b = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let x_b = Uint384(25, 0, 0);
let generator_b = Uint384(7, 0, 0);
let (s_b, r_b) = field_arithmetic.get_square_root(x_b, p_b, generator_b);
Expand Down
195 changes: 195 additions & 0 deletions cairo_programs/fq.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from starkware.cairo.common.uint256 import Uint256, split_64
from starkware.cairo.common.math_cmp import is_le

from cairo_programs.uint384_extension import Uint384, Uint768, u384

struct Uint512 {
d0: felt,
d1: felt,
d2: felt,
d3: felt,
}

const SHIFT = 2 ** 128;
const ALL_ONES = 2 ** 128 - 1;
const HALF_SHIFT = 2 ** 64;

namespace u512 {
func add_u512_and_u256{range_check_ptr}(a: Uint512, b: Uint256) -> Uint512 {
alloc_locals;

let a_low = Uint256(low=a.d0, high=a.d1);
let a_high = Uint256(low=a.d2, high=a.d3);

let (sum_low, carry0) = add_carry(a_low, b);

local res: Uint512;

res.d0 = sum_low.low;
res.d1 = sum_low.high;
// res.d2 = sum_low.d2;

// TODO : create add_one (high bits not needed)
let a_high_plus_carry = add(a_high, Uint256(carry0, 0));

res.d2 = a_high_plus_carry.low;
res.d3 = a_high_plus_carry.high;

return res;
}

func mul_u512_by_u256{range_check_ptr}(a: Uint512, b: Uint256) -> Uint768 {
alloc_locals;
let (a0, a1) = split_64(a.d0);
let (a2, a3) = split_64(a.d1);
let (a4, a5) = split_64(a.d2);
let (a6, a7) = split_64(a.d3);

let (b0, b1) = split_64(b.low);
let (b2, b3) = split_64(b.high);

local B0 = b0 * HALF_SHIFT;
local b12 = b1 + b2 * HALF_SHIFT;

let (res0, carry) = u384.split_128(a1 * B0 + a0 * b.low);
let (res2, carry) = u384.split_128(a3 * B0 + a2 * b.low + a1 * b12 + a0 * b.high + carry);
let (res4, carry) = u384.split_128(
a5 * B0 + a4 * b.low + a3 * b12 + a2 * b.high + a1 * b3 + carry
);
let (res6, carry) = u384.split_128(
a7 * B0 + a6 * b.low + a5 * b12 + a4 * b.high + a3 * b3 + carry
);
let (res8, carry) = u384.split_128(a7 * b12 + a6 * b.high + a5 * b3 + carry);
let (res10, carry) = u384.split_128(a7 * b3 + carry);
let res = Uint768(d0=res0, d1=res2, d2=res4, d3=res6, d4=res8, d5=res10);
return res;
}

func u512_unsigned_div_rem{range_check_ptr}(x: Uint512, div: Uint256) -> (
q: Uint512, r: Uint256
) {
alloc_locals;
local quotient: Uint512;
local remainder: Uint256;

%{
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.low, z.high)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

def pack_extended(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2, z.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

x = pack_extended(ids.x, num_bits_shift = 128)
div = pack(ids.div, num_bits_shift = 128)

quotient, remainder = divmod(x, div)

quotient_split = split(quotient, num_bits_shift=128, length=4)

ids.quotient.d0 = quotient_split[0]
ids.quotient.d1 = quotient_split[1]
ids.quotient.d2 = quotient_split[2]
ids.quotient.d3 = quotient_split[3]

remainder_split = split(remainder, num_bits_shift=128, length=2)
ids.remainder.low = remainder_split[0]
ids.remainder.high = remainder_split[1]
%}

let res_mul: Uint768 = mul_u512_by_u256(quotient, div);

assert res_mul.d4 = 0;
assert res_mul.d5 = 0;

let check_val: Uint512 = add_u512_and_u256(
Uint512(res_mul.d0, res_mul.d1, res_mul.d2, res_mul.d3), remainder
);

// assert add_carry = 0;
assert check_val = x;

let is_valid = lt(remainder, div);
assert is_valid = 1;

return (quotient, remainder);
}

// Verifies that the given integer is valid.
func check{range_check_ptr}(a: Uint256) {
// tempvar h = a.high - 2 ** 127;
[range_check_ptr] = a.low;
[range_check_ptr + 1] = a.high;
let range_check_ptr = range_check_ptr + 2;
return ();
}

// Assume a and b are lower than 2**255-19
func add{range_check_ptr}(a: Uint256, b: Uint256) -> Uint256 {
alloc_locals;
local res: Uint256;
local carry_low: felt;
// unused. added to use UINT256_ADD
local carry_high: felt;
// this hint is not implemented:
// %{
// sum_low = ids.a.low + ids.b.low
// ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
// %}
%{
sum_low = ids.a.low + ids.b.low
ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
sum_high = ids.a.high + ids.b.high + ids.carry_low
ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
%}
// changed hint, no carry_high
assert carry_low * carry_low = carry_low;

assert res.low = a.low + b.low - carry_low * SHIFT;
assert res.high = a.high + b.high + carry_low;
// check(res);

return res;
}

func add_carry{range_check_ptr}(a: Uint256, b: Uint256) -> (res: Uint256, carry: felt) {
alloc_locals;
local res: Uint256;
local carry_low: felt;
local carry_high: felt;
%{
sum_low = ids.a.low + ids.b.low
ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
sum_high = ids.a.high + ids.b.high + ids.carry_low
ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
%}

assert carry_low * carry_low = carry_low;
assert carry_high * carry_high = carry_high;

assert res.low = a.low + b.low - carry_low * SHIFT;
assert res.high = a.high + b.high + carry_low - carry_high * SHIFT;
check(res);

return (res, carry_high);
}

func lt{range_check_ptr}(a: Uint256, b: Uint256) -> felt {
if (a.high == b.high) {
return is_le(a.low + 1, b.low);
}
return is_le(a.high + 1, b.high);
}
}

func main() {
return ();
}
Loading

0 comments on commit 7db0eb3

Please sign in to comment.