Skip to content

Commit

Permalink
feat: Implement hints on field_arithmetic lib (Part 2) (lambdaclass#1004
Browse files Browse the repository at this point in the history
)

* Add hint code for UINT348_UNSIGNED_DIV_REM

* Add file for uint348 files

* Add pack & split for uint348

* Move comment

* Implement uint348_unsigned_div_rem hint

* Add integration test

* Add integration test

* Add unit tests

* Add hint on split_128

* Test split_128 hint

* Add add_no_uint384_hint

* Fix hint + add tests

* Add hint code for UINT348_UNSIGNED_DIV_REM_EXPAND

* Msc fixes

* Add integration test

* Reduce Uint384_expand representation to the 3 used limbs

* Add unit test

* Add hint code for UINT384_SQRT

* Add implementation for hint on sqrt

* Integration test

* Add unit tests

* Fix missing directive

* Run cairo-format

* Add changelog entry

* Spelling

* Add hint code + Uint768 type

* Implement hint unsigned_div_rem_uint768_by_uint384

* Update src/hint_processor/builtin_hint_processor/uint384.rs

Co-authored-by: Mario Rugiero <mario.rugiero@lambdaclass.com>

* Update src/hint_processor/builtin_hint_processor/uint384.rs

Co-authored-by: Mario Rugiero <mario.rugiero@lambdaclass.com>

* Update src/hint_processor/builtin_hint_processor/uint384.rs

Co-authored-by: Mario Rugiero <mario.rugiero@lambdaclass.com>

* Make hint code more readable

* Add integration test

* Add test

* Add unit test

* Add changelog entry + fmt

* Fix plural

* cargo fmt

* Add first draft of get_square_root

* Fix test

* Fix syntax

* Fix test

* Add necessary lib fns

* fix fmt

* Fix test value

* Add test program

* Add hint to execute_hint

* Fix wrong hint being tested

* Implement sqrt

* Add test fix file

* Fix _sqrt_mod_tonelli_shanks implementation

* Expand integration test

* Add unit test

* Add proptests

* Fix merge conflict

* Fix merge conflict

* Add changelog entry

* Use no-std compatible rng when std is not enabled

* Clippy

* Use seeded rng instead of from_entropy

* Catch potential zero divison errors

* Implement hint on div

* Expand field_arithmetic integration test

* Expand field_arithmetic integration test

* Add test + fix hint

* Fix merge conflict

* Use mul_inv instead of div_mod

* Add Changelog entry

* Add unit tests

* remove unused feature

* Fix test value

* Update cairo_programs/field_arithmetic.cairo

Co-authored-by: Mario Rugiero <mario.rugiero@lambdaclass.com>

* Increase number of memory holes for field_arithmetic test

---------

Co-authored-by: Mario Rugiero <mario.rugiero@lambdaclass.com>
  • Loading branch information
2 people authored and kariy committed Jun 23, 2023
1 parent adaa140 commit 6b089c3
Show file tree
Hide file tree
Showing 7 changed files with 295 additions and 5 deletions.
36 changes: 36 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,42 @@

#### Upcoming Changes

* Implement hints on field_arithmetic lib (Part 2) [#1004](https://github.com/lambdaclass/cairo-rs/pull/1004)

`BuiltinHintProcessor` now supports the following hint:

```python
%{
from starkware.python.math_utils import div_mod

def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

a = pack(ids.a, num_bits_shift = 128)
b = pack(ids.b, num_bits_shift = 128)
p = pack(ids.p, num_bits_shift = 128)
# For python3.8 and above the modular inverse can be computed as follows:
# b_inverse_mod_p = pow(b, -1, p)
# Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils
b_inverse_mod_p = div_mod(1, b, p)


b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3)

ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0]
ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1]
ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]
%}
```

* Optimizations for hash builtin [#1029](https://github.com/lambdaclass/cairo-rs/pull/1029):
* Track the verified addresses by offset in a `Vec<bool>` rather than storing the address in a `Vec<Relocatable>`

Expand Down
61 changes: 61 additions & 0 deletions cairo_programs/field_arithmetic.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,51 @@ namespace field_arithmetic {
}
}

// Computes a * b^{-1} modulo p
// NOTE: The modular inverse of b modulo p is computed in a hint and verified outside the hint with a multiplicaiton
func div{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) {
alloc_locals;
local b_inverse_mod_p: Uint384;
%{
from starkware.python.math_utils import div_mod

def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

a = pack(ids.a, num_bits_shift = 128)
b = pack(ids.b, num_bits_shift = 128)
p = pack(ids.p, num_bits_shift = 128)
# For python3.8 and above the modular inverse can be computed as follows:
# b_inverse_mod_p = pow(b, -1, p)
# Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils
b_inverse_mod_p = div_mod(1, b, p)


b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3)

ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0]
ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1]
ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]
%}
uint384_lib.check(b_inverse_mod_p);
let (b_times_b_inverse) = mul(b, b_inverse_mod_p, p);
assert b_times_b_inverse = Uint384(1, 0, 0);

let (res: Uint384) = mul(a, b_inverse_mod_p, p);
return (res,);
}
}

func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
alloc_locals;
// Test get_square

//Small prime
Expand Down Expand Up @@ -162,6 +204,25 @@ func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: B
assert r_c.d1 = 0;
assert r_c.d2 = 0;

// Test div
// Small inputs
let a = Uint384(25, 0, 0);
let a_div = Uint384(5, 0, 0);
let a_p = Uint384(31, 0, 0);
let (a_r) = field_arithmetic.div(a, a_div, a_p);
assert a_r.d0 = 5;
assert a_r.d1 = 0;
assert a_r.d2 = 0;

// Cairo Prime
let b = Uint384(1, 0, 5044639098474805171426);
let b_div = Uint384(1, 0, 2);
let b_p = Uint384(1, 0, 604462909807314605178880);
let (b_r) = field_arithmetic.div(b, b_div, b_p);
assert b_r.d0 = 280171807489444591652763463227596156607;
assert b_r.d1 = 122028556426724038784654414222572127555;
assert b_r.d2 = 410614585309032623322981;

return ();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ use felt::Felt252;
#[cfg(feature = "skip_next_instruction_hint")]
use crate::hint_processor::builtin_hint_processor::skip_next_instruction::skip_next_instruction;

use super::field_arithmetic::uint384_div;
use super::vrf::inv_mod_p_uint512::inv_mod_p_uint512;

pub struct HintProcessorData {
Expand Down Expand Up @@ -562,6 +563,7 @@ impl HintProcessor for BuiltinHintProcessor {
hint_code::UINT384_SIGNED_NN => {
uint384_signed_nn(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::UINT384_DIV => uint384_div(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::UINT256_MUL_DIV_MOD => {
uint256_mul_div_mod(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
Expand Down
168 changes: 166 additions & 2 deletions src/hint_processor/builtin_hint_processor/field_arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use felt::Felt252;
use num_bigint::BigUint;
use num_bigint::{BigUint, ToBigInt};
use num_integer::Integer;
use num_traits::Zero;

use crate::math_utils::{is_quad_residue, sqrt_prime_power};
use crate::math_utils::{is_quad_residue, mul_inv, sqrt_prime_power};
use crate::serde::deserialize_program::ApTracking;
use crate::stdlib::{collections::HashMap, prelude::*};
use crate::types::errors::math_errors::MathError;
use crate::vm::errors::hint_errors::HintError;
use crate::{
hint_processor::hint_processor_definition::HintReference, vm::vm_core::VirtualMachine,
Expand Down Expand Up @@ -112,6 +114,68 @@ pub fn get_square_root(

Ok(())
}

/* Implements Hint:
%{
from starkware.python.math_utils import div_mod
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)
def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
a = pack(ids.a, num_bits_shift = 128)
b = pack(ids.b, num_bits_shift = 128)
p = pack(ids.p, num_bits_shift = 128)
# For python3.8 and above the modular inverse can be computed as follows:
# b_inverse_mod_p = pow(b, -1, p)
# Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils
b_inverse_mod_p = div_mod(1, b, p)
b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3)
ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0]
ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1]
ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]
%}
*/
pub fn uint384_div(
vm: &mut VirtualMachine,
ids_data: &HashMap<String, HintReference>,
ap_tracking: &ApTracking,
) -> Result<(), HintError> {
// Note: ids.a is not used here, nor is it used by following hints, so we dont need to extract it.
let b = pack(BigInt3::from_var_name("b", vm, ids_data, ap_tracking)?, 128)
.to_bigint()
.unwrap_or_default();
let p = pack(BigInt3::from_var_name("p", vm, ids_data, ap_tracking)?, 128)
.to_bigint()
.unwrap_or_default();
let b_inverse_mod_p_addr =
get_relocatable_from_var_name("b_inverse_mod_p", vm, ids_data, ap_tracking)?;
if b.is_zero() {
return Err(MathError::DividedByZero.into());
}
let b_inverse_mod_p = mul_inv(&b, &p)
.mod_floor(&p)
.to_biguint()
.unwrap_or_default();
let b_inverse_mod_p_split = split::<3>(&b_inverse_mod_p, 128);
for (i, b_inverse_mod_p_split) in b_inverse_mod_p_split.iter().enumerate() {
vm.insert_value(
(b_inverse_mod_p_addr + i)?,
Felt252::from(b_inverse_mod_p_split),
)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -268,4 +332,104 @@ mod tests {
((1, 15), 0)
];
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint384_div_ok() {
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 11;
//Create hint_data
let ids_data =
non_continuous_ids_data![("a", -11), ("b", -8), ("p", -5), ("b_inverse_mod_p", -2)];
//Insert ids into memory
vm.segments = segments![
//a
((1, 0), 25),
((1, 1), 0),
((1, 2), 0),
//b
((1, 3), 5),
((1, 4), 0),
((1, 5), 0),
//p
((1, 6), 31),
((1, 7), 0),
((1, 8), 0)
];
//Execute the hint
assert_matches!(run_hint!(vm, ids_data, hint_code::UINT384_DIV), Ok(()));
//Check hint memory inserts
check_memory![
vm.segments.memory,
// b_inverse_mod_p
((1, 9), 25),
((1, 10), 0),
((1, 11), 0)
];
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint384_div_b_is_zero() {
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 11;
//Create hint_data
let ids_data =
non_continuous_ids_data![("a", -11), ("b", -8), ("p", -5), ("b_inverse_mod_p", -2)];
//Insert ids into memory
vm.segments = segments![
//a
((1, 0), 25),
((1, 1), 0),
((1, 2), 0),
//b
((1, 3), 0),
((1, 4), 0),
((1, 5), 0),
//p
((1, 6), 31),
((1, 7), 0),
((1, 8), 0)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code::UINT384_DIV),
Err(HintError::Math(MathError::DividedByZero))
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint384_div_inconsistent_memory() {
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 11;
//Create hint_data
let ids_data =
non_continuous_ids_data![("a", -11), ("b", -8), ("p", -5), ("b_inverse_mod_p", -2)];
//Insert ids into memory
vm.segments = segments![
//a
((1, 0), 25),
((1, 1), 0),
((1, 2), 0),
//b
((1, 3), 5),
((1, 4), 0),
((1, 5), 0),
//p
((1, 6), 31),
((1, 7), 0),
((1, 8), 0),
//b_inverse_mod_p
((1, 9), 0)
];
//Execute the hint
assert_matches!(
run_hint!(vm, ids_data, hint_code::UINT384_DIV),
Err(HintError::Memory(MemoryError::InconsistentMemory(_, _, _)))
);
}
}
29 changes: 28 additions & 1 deletion src/hint_processor/builtin_hint_processor/hint_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ s = pack(ids.s, PRIME) % N
value = res = div_mod(x, s, N)";
pub(crate) const XS_SAFE_DIV: &str = "value = k = safe_div(res * s - x, N)";

// The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib/uint384.cairo
// The following hints support the lib https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib
pub const UINT384_UNSIGNED_DIV_REM: &str = "def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
Expand Down Expand Up @@ -914,6 +914,33 @@ ids.sqrt_x.d2 = split_root_x[2]
ids.sqrt_gx.d0 = split_root_gx[0]
ids.sqrt_gx.d1 = split_root_gx[1]
ids.sqrt_gx.d2 = split_root_gx[2]";
pub const UINT384_DIV: &str = "from starkware.python.math_utils import div_mod
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)
def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
a = pack(ids.a, num_bits_shift = 128)
b = pack(ids.b, num_bits_shift = 128)
p = pack(ids.p, num_bits_shift = 128)
# For python3.8 and above the modular inverse can be computed as follows:
# b_inverse_mod_p = pow(b, -1, p)
# Instead we use the python3.7-friendly function div_mod from starkware.python.math_utils
b_inverse_mod_p = div_mod(1, b, p)
b_inverse_mod_p_split = split(b_inverse_mod_p, num_bits_shift=128, length=3)
ids.b_inverse_mod_p.d0 = b_inverse_mod_p_split[0]
ids.b_inverse_mod_p.d1 = b_inverse_mod_p_split[1]
ids.b_inverse_mod_p.d2 = b_inverse_mod_p_split[2]";
pub const HI_MAX_BITLEN: &str =
"ids.len_hi = max(ids.scalar_u.d2.bit_length(), ids.scalar_v.d2.bit_length())-1";

Expand Down
2 changes: 1 addition & 1 deletion src/math_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub fn safe_div_usize(x: usize, y: usize) -> Result<usize, MathError> {
}

///Returns num_a^-1 mod p
fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt {
pub(crate) fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt {
if num_a.is_zero() {
return BigInt::zero();
}
Expand Down
2 changes: 1 addition & 1 deletion src/tests/cairo_run_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ fn uint384_extension() {
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn field_arithmetic() {
let program_data = include_bytes!("../../cairo_programs/field_arithmetic.json");
run_program_simple_with_memory_holes(program_data.as_slice(), 192);
run_program_simple_with_memory_holes(program_data.as_slice(), 272);
}

#[test]
Expand Down

0 comments on commit 6b089c3

Please sign in to comment.