Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hints): add NewHint#47 #1030

Merged
merged 34 commits into from
Apr 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ee6d7e3
Refactor some cairo_programs as library
MegaRedHand Apr 20, 2023
59d9f35
Refactor pack and split, macro hygiene
MegaRedHand Apr 21, 2023
35d7275
Add NewHint#47
MegaRedHand Apr 21, 2023
2e6cd35
Update changelog
MegaRedHand Apr 21, 2023
70c22fa
Add stdlib prelude import
MegaRedHand Apr 21, 2023
a01dae9
Import HashMap from stdlib
MegaRedHand Apr 21, 2023
1de8d7d
Fix contracts
MegaRedHand Apr 21, 2023
f16eb18
Fix wasm-tests
MegaRedHand Apr 21, 2023
4a55fd4
Merge branch 'main' into newhint47-u512_udiv_rem
MegaRedHand Apr 21, 2023
7a1bea7
Change bitmask to `u128::MAX`
MegaRedHand Apr 21, 2023
a941a8e
Revert "Change bitmask to `u128::MAX`"
MegaRedHand Apr 21, 2023
14bedfe
Compare memory_holes with correct amount
MegaRedHand Apr 21, 2023
a3cb6aa
Rename uint512 -> fq, fq -> fq_test; Add comment
MegaRedHand Apr 21, 2023
726a3ef
Fix contract compilation errors
MegaRedHand Apr 21, 2023
425b0d0
Merge branch 'main' into newhint47-u512_udiv_rem
MegaRedHand Apr 21, 2023
af3d056
Fix compilation error
MegaRedHand Apr 21, 2023
b06e426
Fix cairo_run_fq test
MegaRedHand Apr 21, 2023
89ac8d2
Merge branch 'main' into newhint47-u512_udiv_rem
MegaRedHand Apr 21, 2023
976220c
Move changelog entry
MegaRedHand Apr 21, 2023
e233696
Fix failing tests
MegaRedHand Apr 24, 2023
44d2ddc
Merge branch 'main' into newhint47-u512_udiv_rem
MegaRedHand Apr 24, 2023
a20f4f3
Fix compilation errors
MegaRedHand Apr 24, 2023
e1c8a3f
Fix test error
MegaRedHand Apr 24, 2023
f07c7bf
Define empty mains on libraries
MegaRedHand Apr 24, 2023
fbfd226
Use bare array instead of vec
MegaRedHand Apr 24, 2023
99181d5
Make pack and split methods instead of functions
MegaRedHand Apr 24, 2023
2244805
Merge branch 'main' into newhint47-u512_udiv_rem
MegaRedHand Apr 24, 2023
f895c4f
Fix merge errors
MegaRedHand Apr 24, 2023
32e7e95
Fix errors and add BigInt3::split86
MegaRedHand Apr 24, 2023
c141e93
Merge branch 'main' into newhint47-u512_udiv_rem
MegaRedHand Apr 24, 2023
3cf7081
Make u512_pack and split functions into methods
MegaRedHand Apr 24, 2023
e1a5ae1
Re-add disappeared newline inside hint
MegaRedHand Apr 24, 2023
6888ed8
Fix compilation errors
MegaRedHand Apr 24, 2023
679121e
Merge branch 'main' into newhint47-u512_udiv_rem
MegaRedHand Apr 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,43 @@
ids.res.high = res_split[1]
```

* Add missing hint on vrf.json lib [#1030](https://github.com/lambdaclass/cairo-rs/pull/1030):

`BuiltinHintProcessor` now supports the following hint:

```python
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.low, z.high)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

def pack_extended(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2, z.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

x = pack_extended(ids.x, num_bits_shift = 128)
div = pack(ids.div, num_bits_shift = 128)

quotient, remainder = divmod(x, div)

quotient_split = split(quotient, num_bits_shift=128, length=4)

ids.quotient.d0 = quotient_split[0]
ids.quotient.d1 = quotient_split[1]
ids.quotient.d2 = quotient_split[2]
ids.quotient.d3 = quotient_split[3]

remainder_split = split(remainder, num_bits_shift=128, length=2)
ids.remainder.low = remainder_split[0]
ids.remainder.high = remainder_split[1]
```

* Add method `Program::data_len(&self) -> usize` to get the number of data cells in a given program [#1022](https://github.com/lambdaclass/cairo-rs/pull/1022)

* Add missing hint on uint256_improvements lib [#1013](https://github.com/lambdaclass/cairo-rs/pull/1013):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.bool import TRUE
from cairo_programs.uint384 import uint384_lib, Uint384, Uint384_expand
from cairo_programs.uint384_extension import uint384_extension_lib
from cairo_programs.uint384 import u384, Uint384, Uint384_expand
from cairo_programs.uint384_extension import u384_ext
from cairo_programs.field_arithmetic import field_arithmetic


func run_get_square{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(prime: Uint384, generator: Uint384, num: Uint384, iterations: felt) {
func run_get_square{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
prime: Uint384, generator: Uint384, num: Uint384, iterations: felt
) {
alloc_locals;
if (iterations == 0) {
return ();
}

let (square) = field_arithmetic.mul(num, num, prime);
let (square) = field_arithmetic.mul(num, num, prime);

let (success, root_1) = field_arithmetic.get_square_root(square, prime, generator);
assert success = 1;

// We calculate this before in order to prevent revoked range_check_ptr reference due to branching
let (root_2) = uint384_lib.sub(prime, root_1);
let (is_first_root) = uint384_lib.eq(root_1, num);
let (root_2) = u384.sub(prime, root_1);
let (is_first_root) = u384.eq(root_1, num);

if ( is_first_root != TRUE) {
if (is_first_root != TRUE) {
assert root_2 = num;
}

return run_get_square(prime, generator, square, iterations -1);
return run_get_square(prime, generator, square, iterations - 1);
}

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
let p = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let p = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let x = Uint384(5, 0, 0);
let g = Uint384(7, 0, 0);
run_get_square(p, g, x, 100);
Expand Down
46 changes: 24 additions & 22 deletions cairo_programs/field_arithmetic.cairo
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
%builtins range_check bitwise

// Code taken from https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/fbf532651959f27037d70cd70ec6dbaf987f535c/lib/field_arithmetic.cairo
from starkware.cairo.common.bitwise import bitwise_and, bitwise_or, bitwise_xor
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.math import assert_in_range, assert_le, assert_nn_le, assert_not_zero
from starkware.cairo.common.math_cmp import is_le
from starkware.cairo.common.pow import pow
from starkware.cairo.common.registers import get_ap, get_fp_and_pc
from cairo_programs.uint384 import uint384_lib, Uint384, Uint384_expand, SHIFT, HALF_SHIFT
from cairo_programs.uint384_extension import uint384_extension_lib, Uint768
from cairo_programs.uint384 import u384, Uint384, Uint384_expand, SHIFT, HALF_SHIFT
from cairo_programs.uint384_extension import u384_ext, Uint768

// Functions for operating elements in a finite field F_p (i.e. modulo a prime p), with p of at most 384 bits
namespace field_arithmetic {
// Computes a * b modulo p
func mul{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) {
let (low: Uint384, high: Uint384) = uint384_lib.mul_d(a, b);
let (low: Uint384, high: Uint384) = u384.mul_d(a, b);
let full_mul_result: Uint768 = Uint768(low.d0, low.d1, low.d2, high.d0, high.d1, high.d2);
let (
quotient: Uint768, remainder: Uint384
) = uint384_extension_lib.unsigned_div_rem_uint768_by_uint384(full_mul_result, p);
let (quotient: Uint768, remainder: Uint384) = u384_ext.unsigned_div_rem_uint768_by_uint384(
full_mul_result, p
);
return (remainder,);
}

// Computes a**2 modulo p
func square{range_check_ptr}(a: Uint384, p: Uint384) -> (res: Uint384) {
let (low: Uint384, high: Uint384) = uint384_lib.square_e(a);
let (low: Uint384, high: Uint384) = u384.square_e(a);
let full_mul_result: Uint768 = Uint768(low.d0, low.d1, low.d2, high.d0, high.d1, high.d2);
let (
quotient: Uint768, remainder: Uint384
) = uint384_extension_lib.unsigned_div_rem_uint768_by_uint384(full_mul_result, p);
let (quotient: Uint768, remainder: Uint384) = u384_ext.unsigned_div_rem_uint768_by_uint384(
full_mul_result, p
);
return (remainder,);
}

Expand All @@ -43,7 +45,7 @@ namespace field_arithmetic {
alloc_locals;

// TODO: Create an equality function within field_arithmetic to avoid overflow bugs
let (is_zero) = uint384_lib.eq(x, Uint384(0, 0, 0));
let (is_zero) = u384.eq(x, Uint384(0, 0, 0));
if (is_zero == 1) {
return (1, Uint384(0, 0, 0));
}
Expand Down Expand Up @@ -101,22 +103,22 @@ namespace field_arithmetic {
// Verify that the values computed in the hint are what they are supposed to be
let (gx: Uint384) = mul(generator, x, p);
if (success_x == 1) {
uint384_lib.check(sqrt_x);
let (is_valid) = uint384_lib.lt(sqrt_x, p);
assert is_valid = 1;
u384.check(sqrt_x);
let (is_valid) = u384.lt(sqrt_x, p);
assert is_valid = 1;
let (sqrt_x_squared: Uint384) = mul(sqrt_x, sqrt_x, p);
// Note these checks may fail if the input x does not satisfy 0<= x < p
// TODO: Create a equality function within field_arithmetic to avoid overflow bugs
let (check_x) = uint384_lib.eq(x, sqrt_x_squared);
let (check_x) = u384.eq(x, sqrt_x_squared);
assert check_x = 1;
return (1, sqrt_x);
} else {
// In this case success_gx = 1
uint384_lib.check(sqrt_gx);
let (is_valid) = uint384_lib.lt(sqrt_gx, p);
assert is_valid = 1;
u384.check(sqrt_gx);
let (is_valid) = u384.lt(sqrt_gx, p);
assert is_valid = 1;
let (sqrt_gx_squared: Uint384) = mul(sqrt_gx, sqrt_gx, p);
let (check_gx) = uint384_lib.eq(gx, sqrt_gx_squared);
let (check_gx) = u384.eq(gx, sqrt_gx_squared);
assert check_gx = 1;
// No square roots were found
// Note that Uint384(0, 0, 0) is not a square root here, but something needs to be returned
Expand Down Expand Up @@ -158,7 +160,7 @@ namespace field_arithmetic {
ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1]
ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]
%}
uint384_lib.check(b_inverse_mod_p);
u384.check(b_inverse_mod_p);
let (b_times_b_inverse) = mul(b, b_inverse_mod_p, p);
assert b_times_b_inverse = Uint384(1, 0, 0);

Expand All @@ -171,7 +173,7 @@ func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: B
alloc_locals;
// Test get_square

//Small prime
// Small prime
let p_a = Uint384(7, 0, 0);
let x_a = Uint384(2, 0, 0);
let generator_a = Uint384(3, 0, 0);
Expand All @@ -183,7 +185,7 @@ func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: B
assert r_a.d2 = 0;

// Goldilocks Prime
let p_b = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let p_b = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime
let x_b = Uint384(25, 0, 0);
let generator_b = Uint384(7, 0, 0);
let (s_b, r_b) = field_arithmetic.get_square_root(x_b, p_b, generator_b);
Expand Down
195 changes: 195 additions & 0 deletions cairo_programs/fq.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from starkware.cairo.common.uint256 import Uint256, split_64
from starkware.cairo.common.math_cmp import is_le

from cairo_programs.uint384_extension import Uint384, Uint768, u384

struct Uint512 {
d0: felt,
d1: felt,
d2: felt,
d3: felt,
}

const SHIFT = 2 ** 128;
const ALL_ONES = 2 ** 128 - 1;
const HALF_SHIFT = 2 ** 64;

namespace u512 {
func add_u512_and_u256{range_check_ptr}(a: Uint512, b: Uint256) -> Uint512 {
alloc_locals;

let a_low = Uint256(low=a.d0, high=a.d1);
let a_high = Uint256(low=a.d2, high=a.d3);

let (sum_low, carry0) = add_carry(a_low, b);

local res: Uint512;

res.d0 = sum_low.low;
res.d1 = sum_low.high;
// res.d2 = sum_low.d2;

// TODO : create add_one (high bits not needed)
let a_high_plus_carry = add(a_high, Uint256(carry0, 0));

res.d2 = a_high_plus_carry.low;
res.d3 = a_high_plus_carry.high;

return res;
}

func mul_u512_by_u256{range_check_ptr}(a: Uint512, b: Uint256) -> Uint768 {
alloc_locals;
let (a0, a1) = split_64(a.d0);
let (a2, a3) = split_64(a.d1);
let (a4, a5) = split_64(a.d2);
let (a6, a7) = split_64(a.d3);

let (b0, b1) = split_64(b.low);
let (b2, b3) = split_64(b.high);

local B0 = b0 * HALF_SHIFT;
local b12 = b1 + b2 * HALF_SHIFT;

let (res0, carry) = u384.split_128(a1 * B0 + a0 * b.low);
let (res2, carry) = u384.split_128(a3 * B0 + a2 * b.low + a1 * b12 + a0 * b.high + carry);
let (res4, carry) = u384.split_128(
a5 * B0 + a4 * b.low + a3 * b12 + a2 * b.high + a1 * b3 + carry
);
let (res6, carry) = u384.split_128(
a7 * B0 + a6 * b.low + a5 * b12 + a4 * b.high + a3 * b3 + carry
);
let (res8, carry) = u384.split_128(a7 * b12 + a6 * b.high + a5 * b3 + carry);
let (res10, carry) = u384.split_128(a7 * b3 + carry);
let res = Uint768(d0=res0, d1=res2, d2=res4, d3=res6, d4=res8, d5=res10);
return res;
}

func u512_unsigned_div_rem{range_check_ptr}(x: Uint512, div: Uint256) -> (
q: Uint512, r: Uint256
) {
alloc_locals;
local quotient: Uint512;
local remainder: Uint256;

%{
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.low, z.high)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

def pack_extended(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2, z.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

x = pack_extended(ids.x, num_bits_shift = 128)
div = pack(ids.div, num_bits_shift = 128)

quotient, remainder = divmod(x, div)

quotient_split = split(quotient, num_bits_shift=128, length=4)

ids.quotient.d0 = quotient_split[0]
ids.quotient.d1 = quotient_split[1]
ids.quotient.d2 = quotient_split[2]
ids.quotient.d3 = quotient_split[3]

remainder_split = split(remainder, num_bits_shift=128, length=2)
ids.remainder.low = remainder_split[0]
ids.remainder.high = remainder_split[1]
%}

let res_mul: Uint768 = mul_u512_by_u256(quotient, div);

assert res_mul.d4 = 0;
assert res_mul.d5 = 0;

let check_val: Uint512 = add_u512_and_u256(
Uint512(res_mul.d0, res_mul.d1, res_mul.d2, res_mul.d3), remainder
);

// assert add_carry = 0;
assert check_val = x;

let is_valid = lt(remainder, div);
assert is_valid = 1;

return (quotient, remainder);
}

// Verifies that the given integer is valid.
func check{range_check_ptr}(a: Uint256) {
// tempvar h = a.high - 2 ** 127;
[range_check_ptr] = a.low;
[range_check_ptr + 1] = a.high;
let range_check_ptr = range_check_ptr + 2;
return ();
}

// Assume a and b are lower than 2**255-19
func add{range_check_ptr}(a: Uint256, b: Uint256) -> Uint256 {
alloc_locals;
local res: Uint256;
local carry_low: felt;
// unused. added to use UINT256_ADD
local carry_high: felt;
// this hint is not implemented:
// %{
// sum_low = ids.a.low + ids.b.low
// ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
// %}
%{
sum_low = ids.a.low + ids.b.low
ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
sum_high = ids.a.high + ids.b.high + ids.carry_low
ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
%}
// changed hint, no carry_high
assert carry_low * carry_low = carry_low;

assert res.low = a.low + b.low - carry_low * SHIFT;
assert res.high = a.high + b.high + carry_low;
// check(res);

return res;
}

func add_carry{range_check_ptr}(a: Uint256, b: Uint256) -> (res: Uint256, carry: felt) {
alloc_locals;
local res: Uint256;
local carry_low: felt;
local carry_high: felt;
%{
sum_low = ids.a.low + ids.b.low
ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
sum_high = ids.a.high + ids.b.high + ids.carry_low
ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
%}

assert carry_low * carry_low = carry_low;
assert carry_high * carry_high = carry_high;

assert res.low = a.low + b.low - carry_low * SHIFT;
assert res.high = a.high + b.high + carry_low - carry_high * SHIFT;
check(res);

return (res, carry_high);
}

func lt{range_check_ptr}(a: Uint256, b: Uint256) -> felt {
if (a.high == b.high) {
return is_le(a.low + 1, b.low);
}
return is_le(a.high + 1, b.high);
}
}

func main() {
return ();
}
Loading
Loading