diff --git a/CHANGELOG.md b/CHANGELOG.md index f2024564a4..abb4be7a03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,57 @@ #### Upcoming Changes +* Implement hints on field_arithmetic lib[#985](https://github.com/lambdaclass/cairo-rs/pull/983) + + `BuiltinHintProcessor` now supports the following hint: + + ```python + %{ + from starkware.python.math_utils import is_quad_residue, sqrt + + def split(num: int, num_bits_shift: int = 128, length: int = 3): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int = 128) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + + generator = pack(ids.generator) + x = pack(ids.x) + p = pack(ids.p) + + success_x = is_quad_residue(x, p) + root_x = sqrt(x, p) if success_x else None + + success_gx = is_quad_residue(generator*x, p) + root_gx = sqrt(generator*x, p) if success_gx else None + + # Check that one is 0 and the other is 1 + if x != 0: + assert success_x + success_gx ==1 + + # `None` means that no root was found, but we need to transform these into a felt no matter what + if root_x == None: + root_x = 0 + if root_gx == None: + root_gx = 0 + ids.success_x = int(success_x) + split_root_x = split(root_x) + split_root_gx = split(root_gx) + ids.sqrt_x.d0 = split_root_x[0] + ids.sqrt_x.d1 = split_root_x[1] + ids.sqrt_x.d2 = split_root_x[2] + ids.sqrt_gx.d0 = split_root_gx[0] + ids.sqrt_gx.d1 = split_root_gx[1] + ids.sqrt_gx.d2 = split_root_gx[2] + %} + ``` + * Add missing hint on uint256_improvements lib [#1016](https://github.com/lambdaclass/cairo-rs/pull/1016): `BuiltinHintProcessor` now supports the following hint: @@ -85,6 +136,7 @@ * The new version carries an 85% reduction in execution time for ECDSA signature verification * BREAKING CHANGE: refactor `Program` to optimize `Program::clone` [#999](https://github.com/lambdaclass/cairo-rs/pull/999) + * Breaking change: many fields that were (unnecessarily) public become hidden by the refactor. * BREAKING CHANGE: Add _builtin suffix to builtin names e.g.: output -> output_builtin [#1005](https://github.com/lambdaclass/cairo-rs/pull/1005) diff --git a/Cargo.lock b/Cargo.lock index 80bf85a28f..f9b54c6a45 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,17 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "ahash" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcb51a0695d8f838b1ee009b3fbf66bda078cd64590202a864a8f3e8c4315c47" +dependencies = [ + "getrandom", + "once_cell", + "version_check", +] + [[package]] name = "ahash" version = "0.8.3" @@ -214,8 +225,10 @@ dependencies = [ "nom", "num-bigint", "num-integer", + "num-prime", "num-traits", "proptest", + "rand", "rand_core", "rstest", "rusty-hook", @@ -694,6 +707,9 @@ name = "hashbrown" version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +dependencies = [ + "ahash 0.7.6", +] [[package]] name = "hashbrown" @@ -701,7 +717,7 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" dependencies = [ - "ahash", + "ahash 0.8.3", "serde", ] @@ -865,6 +881,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "lru" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999beba7b6e8345721bd280141ed958096a2e4abdf74f67ff4ce49b4b54e47a" +dependencies = [ + "hashbrown 0.12.3", +] + [[package]] name = "memchr" version = "2.5.0" @@ -920,6 +945,7 @@ dependencies = [ "autocfg", "num-integer", "num-traits", + "rand", "serde", ] @@ -933,6 +959,33 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-modular" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a5fe11d4135c3bcdf3a95b18b194afa9608a5f6ff034f5d857bc9a27fb0119" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-prime" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f4e3bc495f6e95bc15a6c0c55ac00421504a5a43d09e3cc455d1fea7015581d" +dependencies = [ + "bitvec", + "either", + "lru", + "num-bigint", + "num-integer", + "num-modular", + "num-traits", + "rand", +] + [[package]] name = "num-traits" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index 0a22e81198..70b65f0318 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,9 +32,11 @@ hooks = [] [dependencies] mimalloc = { version = "0.1.29", default-features = false, optional = true } -num-bigint = { version = "0.4", features = ["serde"], default-features = false } +num-bigint = { version = "0.4", features = ["serde", "rand"], default-features = false } +rand = { version = "0.8.3", features = ["small_rng"], default-features = false } num-traits = { version = "0.2", default-features = false } num-integer = { version = "0.1.45", default-features = false } +num-prime = {version = "0.4.3", features = ["big-int"], default-features = false } serde = { version = "1.0", features = ["derive"], default-features = false } serde_bytes = { version = "0.11.9", default-features = false, features = [ "alloc", diff --git a/cairo_programs/benchmarks/field_arithmetic_get_square_benchmark.cairo b/cairo_programs/benchmarks/field_arithmetic_get_square_benchmark.cairo new file mode 100644 index 0000000000..28b95e3c34 --- /dev/null +++ b/cairo_programs/benchmarks/field_arithmetic_get_square_benchmark.cairo @@ -0,0 +1,36 @@ +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.bool import TRUE +from cairo_programs.uint384 import uint384_lib, Uint384, Uint384_expand +from cairo_programs.uint384_extension import uint384_extension_lib +from cairo_programs.field_arithmetic import field_arithmetic + + +func run_get_square{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(prime: Uint384, generator: Uint384, num: Uint384, iterations: felt) { + alloc_locals; + if (iterations == 0) { + return (); + } + + let (square) = field_arithmetic.mul(num, num, prime); + + let (success, root_1) = field_arithmetic.get_square_root(square, prime, generator); + assert success = 1; + + // We calculate this before in order to prevent revoked range_check_ptr reference due to branching + let (root_2) = uint384_lib.sub(prime, root_1); + let (is_first_root) = uint384_lib.eq(root_1, num); + + if ( is_first_root != TRUE) { + assert root_2 = num; + } + + return run_get_square(prime, generator, square, iterations -1); +} + +func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() { + let p = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime + let x = Uint384(5, 0, 0); + let g = Uint384(7, 0, 0); + run_get_square(p, g, x, 100); + return (); +} diff --git a/cairo_programs/field_arithmetic.cairo b/cairo_programs/field_arithmetic.cairo new file mode 100644 index 0000000000..9d9c09393e --- /dev/null +++ b/cairo_programs/field_arithmetic.cairo @@ -0,0 +1,171 @@ +// Code taken from https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/fbf532651959f27037d70cd70ec6dbaf987f535c/lib/field_arithmetic.cairo +from starkware.cairo.common.bitwise import bitwise_and, bitwise_or, bitwise_xor +from starkware.cairo.common.cairo_builtins import BitwiseBuiltin +from starkware.cairo.common.math import assert_in_range, assert_le, assert_nn_le, assert_not_zero +from starkware.cairo.common.math_cmp import is_le +from starkware.cairo.common.pow import pow +from starkware.cairo.common.registers import get_ap, get_fp_and_pc +from cairo_programs.uint384 import uint384_lib, Uint384, Uint384_expand, SHIFT, HALF_SHIFT +from cairo_programs.uint384_extension import uint384_extension_lib, Uint768 + +// Functions for operating elements in a finite field F_p (i.e. modulo a prime p), with p of at most 384 bits +namespace field_arithmetic { + // Computes a * b modulo p + func mul{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) { + let (low: Uint384, high: Uint384) = uint384_lib.mul_d(a, b); + let full_mul_result: Uint768 = Uint768(low.d0, low.d1, low.d2, high.d0, high.d1, high.d2); + let ( + quotient: Uint768, remainder: Uint384 + ) = uint384_extension_lib.unsigned_div_rem_uint768_by_uint384(full_mul_result, p); + return (remainder,); + } + + // Computes a**2 modulo p + func square{range_check_ptr}(a: Uint384, p: Uint384) -> (res: Uint384) { + let (low: Uint384, high: Uint384) = uint384_lib.square_e(a); + let full_mul_result: Uint768 = Uint768(low.d0, low.d1, low.d2, high.d0, high.d1, high.d2); + let ( + quotient: Uint768, remainder: Uint384 + ) = uint384_extension_lib.unsigned_div_rem_uint768_by_uint384(full_mul_result, p); + return (remainder,); + } + + // Finds a square of x in F_p, i.e. x ≅ y**2 (mod p) for some y + // To do so, the following is done in a hint: + // 0. Assume x is not 0 mod p + // 1. Check if x is a square, if yes, find a square root r of it + // 2. If (and only if not), then gx *is* a square (for g a generator of F_p^*), so find a square root r of it + // 3. Check in Cairo that r**2 = x (mod p) or r**2 = gx (mod p), respectively + // NOTE: The function assumes that 0 <= x < p + func get_square_root{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}( + x: Uint384, p: Uint384, generator: Uint384 + ) -> (success: felt, res: Uint384) { + alloc_locals; + + // TODO: Create an equality function within field_arithmetic to avoid overflow bugs + let (is_zero) = uint384_lib.eq(x, Uint384(0, 0, 0)); + if (is_zero == 1) { + return (1, Uint384(0, 0, 0)); + } + + local success_x: felt; + local sqrt_x: Uint384; + local sqrt_gx: Uint384; + + // Compute square roots in a hint + %{ + from starkware.python.math_utils import is_quad_residue, sqrt + + def split(num: int, num_bits_shift: int = 128, length: int = 3): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int = 128) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + + generator = pack(ids.generator) + x = pack(ids.x) + p = pack(ids.p) + + success_x = is_quad_residue(x, p) + root_x = sqrt(x, p) if success_x else None + + success_gx = is_quad_residue(generator*x, p) + root_gx = sqrt(generator*x, p) if success_gx else None + + # Check that one is 0 and the other is 1 + if x != 0: + assert success_x + success_gx ==1 + + # `None` means that no root was found, but we need to transform these into a felt no matter what + if root_x == None: + root_x = 0 + if root_gx == None: + root_gx = 0 + ids.success_x = int(success_x) + split_root_x = split(root_x) + split_root_gx = split(root_gx) + ids.sqrt_x.d0 = split_root_x[0] + ids.sqrt_x.d1 = split_root_x[1] + ids.sqrt_x.d2 = split_root_x[2] + ids.sqrt_gx.d0 = split_root_gx[0] + ids.sqrt_gx.d1 = split_root_gx[1] + ids.sqrt_gx.d2 = split_root_gx[2] + %} + + // Verify that the values computed in the hint are what they are supposed to be + let (gx: Uint384) = mul(generator, x, p); + if (success_x == 1) { + uint384_lib.check(sqrt_x); + let (is_valid) = uint384_lib.lt(sqrt_x, p); + assert is_valid = 1; + let (sqrt_x_squared: Uint384) = mul(sqrt_x, sqrt_x, p); + // Note these checks may fail if the input x does not satisfy 0<= x < p + // TODO: Create a equality function within field_arithmetic to avoid overflow bugs + let (check_x) = uint384_lib.eq(x, sqrt_x_squared); + assert check_x = 1; + return (1, sqrt_x); + } else { + // In this case success_gx = 1 + uint384_lib.check(sqrt_gx); + let (is_valid) = uint384_lib.lt(sqrt_gx, p); + assert is_valid = 1; + let (sqrt_gx_squared: Uint384) = mul(sqrt_gx, sqrt_gx, p); + let (check_gx) = uint384_lib.eq(gx, sqrt_gx_squared); + assert check_gx = 1; + // No square roots were found + // Note that Uint384(0, 0, 0) is not a square root here, but something needs to be returned + return (0, Uint384(0, 0, 0)); + } + } + +} + +func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() { + // Test get_square + + //Small prime + let p_a = Uint384(7, 0, 0); + let x_a = Uint384(2, 0, 0); + let generator_a = Uint384(3, 0, 0); + let (s_a, r_a) = field_arithmetic.get_square_root(x_a, p_a, generator_a); + assert s_a = 1; + + assert r_a.d0 = 3; + assert r_a.d1 = 0; + assert r_a.d2 = 0; + + // Goldilocks Prime + let p_b = Uint384(18446744069414584321, 0, 0); // Goldilocks Prime + let x_b = Uint384(25, 0, 0); + let generator_b = Uint384(7, 0, 0); + let (s_b, r_b) = field_arithmetic.get_square_root(x_b, p_b, generator_b); + assert s_b = 1; + + assert r_b.d0 = 5; + assert r_b.d1 = 0; + assert r_b.d2 = 0; + + // Prime 2**101-99 + let p_c = Uint384(77371252455336267181195165, 32767, 0); + let x_c = Uint384(96059601, 0, 0); + let generator_c = Uint384(3, 0, 0); + let (s_c, r_c) = field_arithmetic.get_square_root(x_c, p_c, generator_c); + assert s_c = 1; + + assert r_c.d0 = 9801; + assert r_c.d1 = 0; + assert r_c.d2 = 0; + + return (); +} + +func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() { + test_field_arithmetics_extension_operations(); + return (); +} diff --git a/cairo_programs/uint384.cairo b/cairo_programs/uint384.cairo index d2d5b91f10..6a7d0ca2a4 100644 --- a/cairo_programs/uint384.cairo +++ b/cairo_programs/uint384.cairo @@ -1,4 +1,4 @@ -%builtins range_check +%builtins range_check bitwise // Code taken from https://github.com/NethermindEth/research-basic-Cairo-operations-big-integers/blob/main/lib/uint384.cairo from starkware.cairo.common.bitwise import bitwise_and, bitwise_or, bitwise_xor from starkware.cairo.common.cairo_builtins import BitwiseBuiltin @@ -564,7 +564,7 @@ func test_uint384_operations{range_check_ptr}() { return (); } -func main{range_check_ptr: felt}() { +func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() { test_uint384_operations(); return (); } diff --git a/cairo_programs/uint384_extension.cairo b/cairo_programs/uint384_extension.cairo index a4b467811d..4ddeb24499 100644 --- a/cairo_programs/uint384_extension.cairo +++ b/cairo_programs/uint384_extension.cairo @@ -298,7 +298,7 @@ func test_uint384_extension_operations{range_check_ptr}() { return (); } -func main{range_check_ptr: felt}() { +func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() { test_uint384_extension_operations(); return (); } diff --git a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs index ebdb340370..df555e2b98 100644 --- a/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs +++ b/src/hint_processor/builtin_hint_processor/builtin_hint_processor_definition.rs @@ -13,6 +13,7 @@ use crate::{ dict_squash_update_ptr, dict_update, dict_write, }, ec_utils::{chained_ec_op_random_ec_point_hint, random_ec_point_hint, recover_y_hint}, + field_arithmetic::get_square_root, find_element_hint::{find_element, search_sorted_lower}, garaga::get_felt_bitlenght, hint_code, @@ -538,6 +539,9 @@ impl HintProcessor for BuiltinHintProcessor { hint_code::UNSIGNED_DIV_REM_UINT768_BY_UINT384 => { unsigned_div_rem_uint768_by_uint384(vm, &hint_data.ids_data, &hint_data.ap_tracking) } + hint_code::GET_SQUARE_ROOT => { + get_square_root(vm, &hint_data.ids_data, &hint_data.ap_tracking) + } hint_code::UINT384_SIGNED_NN => { uint384_signed_nn(vm, &hint_data.ids_data, &hint_data.ap_tracking) } diff --git a/src/hint_processor/builtin_hint_processor/field_arithmetic.rs b/src/hint_processor/builtin_hint_processor/field_arithmetic.rs new file mode 100644 index 0000000000..ee53e5f7c8 --- /dev/null +++ b/src/hint_processor/builtin_hint_processor/field_arithmetic.rs @@ -0,0 +1,271 @@ +use felt::Felt252; +use num_bigint::BigUint; +use num_traits::Zero; + +use crate::math_utils::{is_quad_residue, sqrt_prime_power}; +use crate::serde::deserialize_program::ApTracking; +use crate::stdlib::{collections::HashMap, prelude::*}; +use crate::vm::errors::hint_errors::HintError; +use crate::{ + hint_processor::hint_processor_definition::HintReference, vm::vm_core::VirtualMachine, +}; + +use super::hint_utils::{get_relocatable_from_var_name, insert_value_from_var_name}; +use super::secp::bigint_utils::BigInt3; +use super::uint384::{pack, split}; +/* Implements Hint: + %{ + from starkware.python.math_utils import is_quad_residue, sqrt + + def split(num: int, num_bits_shift: int = 128, length: int = 3): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + + def pack(z, num_bits_shift: int = 128) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + + generator = pack(ids.generator) + x = pack(ids.x) + p = pack(ids.p) + + success_x = is_quad_residue(x, p) + root_x = sqrt(x, p) if success_x else None + + success_gx = is_quad_residue(generator*x, p) + root_gx = sqrt(generator*x, p) if success_gx else None + + # Check that one is 0 and the other is 1 + if x != 0: + assert success_x + success_gx ==1 + + # `None` means that no root was found, but we need to transform these into a felt no matter what + if root_x == None: + root_x = 0 + if root_gx == None: + root_gx = 0 + ids.success_x = int(success_x) + split_root_x = split(root_x) + split_root_gx = split(root_gx) + ids.sqrt_x.d0 = split_root_x[0] + ids.sqrt_x.d1 = split_root_x[1] + ids.sqrt_x.d2 = split_root_x[2] + ids.sqrt_gx.d0 = split_root_gx[0] + ids.sqrt_gx.d1 = split_root_gx[1] + ids.sqrt_gx.d2 = split_root_gx[2] + %} +*/ +pub fn get_square_root( + vm: &mut VirtualMachine, + ids_data: &HashMap, + ap_tracking: &ApTracking, +) -> Result<(), HintError> { + let sqrt_x_addr = get_relocatable_from_var_name("sqrt_x", vm, ids_data, ap_tracking)?; + let sqrt_gx_addr = get_relocatable_from_var_name("sqrt_gx", vm, ids_data, ap_tracking)?; + let generator = pack( + BigInt3::from_var_name("generator", vm, ids_data, ap_tracking)?, + 128, + ); + let x = pack(BigInt3::from_var_name("x", vm, ids_data, ap_tracking)?, 128); + let p = pack(BigInt3::from_var_name("p", vm, ids_data, ap_tracking)?, 128); + let success_x = is_quad_residue(&x, &p)?; + + let root_x = if success_x { + sqrt_prime_power(&x, &p).unwrap_or_default() + } else { + BigUint::zero() + }; + + let gx = generator * &x; + let success_gx = is_quad_residue(&gx, &p)?; + + let root_gx = if success_gx { + sqrt_prime_power(&gx, &p).unwrap_or_default() + } else { + BigUint::zero() + }; + + if !&x.is_zero() && !(success_x ^ success_gx) { + return Err(HintError::AssertionFailed(String::from( + "assert success_x + success_gx ==1", + ))); + } + insert_value_from_var_name( + "success_x", + Felt252::from(success_x as u8), + vm, + ids_data, + ap_tracking, + )?; + let split_root_x = split::<3>(&root_x, 128); + for (i, root_x) in split_root_x.iter().enumerate() { + vm.insert_value((sqrt_x_addr + i)?, Felt252::from(root_x))?; + } + let split_root_gx = split::<3>(&root_gx, 128); + for (i, root_gx) in split_root_gx.iter().enumerate() { + vm.insert_value((sqrt_gx_addr + i)?, Felt252::from(root_gx))?; + } + + Ok(()) +} +#[cfg(test)] +mod tests { + use super::*; + use crate::hint_processor::builtin_hint_processor::hint_code; + use crate::vm::vm_memory::memory_segments::MemorySegmentManager; + use crate::{ + any_box, + hint_processor::{ + builtin_hint_processor::builtin_hint_processor_definition::{ + BuiltinHintProcessor, HintProcessorData, + }, + hint_processor_definition::HintProcessor, + }, + types::{exec_scope::ExecutionScopes, relocatable::MaybeRelocatable}, + utils::test_utils::*, + vm::{ + errors::memory_errors::MemoryError, runners::builtin_runner::RangeCheckBuiltinRunner, + vm_core::VirtualMachine, vm_memory::memory::Memory, + }, + }; + use assert_matches::assert_matches; + + #[cfg(target_arch = "wasm32")] + use wasm_bindgen_test::*; + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_get_square_ok_goldilocks_prime() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 14; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("p", -14), + ("x", -11), + ("generator", -8), + ("sqrt_x", -5), + ("sqrt_gx", -2), + ("success_x", 1) + ]; + //Insert ids into memory + vm.segments = segments![ + //p + ((1, 0), 18446744069414584321), + ((1, 1), 0), + ((1, 2), 0), + //x + ((1, 3), 25), + ((1, 4), 0), + ((1, 5), 0), + //generator + ((1, 6), 7), + ((1, 7), 0), + ((1, 8), 0) + ]; + //Execute the hint + assert_matches!(run_hint!(vm, ids_data, hint_code::GET_SQUARE_ROOT), Ok(())); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // sqrt_x + ((1, 9), 5), + ((1, 10), 0), + ((1, 11), 0), + // sqrt_gx + ((1, 12), 0), + ((1, 13), 0), + ((1, 14), 0), + // success_x + ((1, 15), 1) + ]; + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_get_square_no_successes() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 14; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("p", -14), + ("x", -11), + ("generator", -8), + ("sqrt_x", -5), + ("sqrt_gx", -2), + ("success_x", 1) + ]; + //Insert ids into memory + vm.segments = segments![ + //p + ((1, 0), 3), + ((1, 1), 0), + ((1, 2), 0), + //x + ((1, 3), 17), + ((1, 4), 0), + ((1, 5), 0), + //generator + ((1, 6), 1), + ((1, 7), 0), + ((1, 8), 0) + ]; + //Execute the hint + assert_matches!(run_hint!(vm, ids_data, hint_code::GET_SQUARE_ROOT), + Err(HintError::AssertionFailed(s)) if s == "assert success_x + success_gx ==1" + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn run_get_square_ok_success_gx() { + let mut vm = vm_with_range_check!(); + //Initialize fp + vm.run_context.fp = 14; + //Create hint_data + let ids_data = non_continuous_ids_data![ + ("p", -14), + ("x", -11), + ("generator", -8), + ("sqrt_x", -5), + ("sqrt_gx", -2), + ("success_x", 1) + ]; + //Insert ids into memory + vm.segments = segments![ + //p + ((1, 0), 3), + ((1, 1), 0), + ((1, 2), 0), + //x + ((1, 3), 17), + ((1, 4), 0), + ((1, 5), 0), + //generator + ((1, 6), 71), + ((1, 7), 0), + ((1, 8), 0) + ]; + //Execute the hint + assert_matches!(run_hint!(vm, ids_data, hint_code::GET_SQUARE_ROOT), Ok(())); + //Check hint memory inserts + check_memory![ + vm.segments.memory, + // sqrt_x + ((1, 9), 0), + ((1, 10), 0), + ((1, 11), 0), + // sqrt_gx + ((1, 12), 1), + ((1, 13), 0), + ((1, 14), 0), + // success_x + ((1, 15), 0) + ]; + } +} diff --git a/src/hint_processor/builtin_hint_processor/hint_code.rs b/src/hint_processor/builtin_hint_processor/hint_code.rs index 1553408baf..fb5af02b06 100644 --- a/src/hint_processor/builtin_hint_processor/hint_code.rs +++ b/src/hint_processor/builtin_hint_processor/hint_code.rs @@ -850,6 +850,49 @@ ids.remainder.d1 = remainder_split[1] ids.remainder.d2 = remainder_split[2]"; pub const UINT384_SIGNED_NN: &str = "memory[ap] = 1 if 0 <= (ids.a.d2 % PRIME) < 2 ** 127 else 0"; +pub(crate) const GET_SQUARE_ROOT: &str = + "from starkware.python.math_utils import is_quad_residue, sqrt + +def split(num: int, num_bits_shift: int = 128, length: int = 3): + a = [] + for _ in range(length): + a.append( num & ((1 << num_bits_shift) - 1) ) + num = num >> num_bits_shift + return tuple(a) + +def pack(z, num_bits_shift: int = 128) -> int: + limbs = (z.d0, z.d1, z.d2) + return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs)) + + +generator = pack(ids.generator) +x = pack(ids.x) +p = pack(ids.p) + +success_x = is_quad_residue(x, p) +root_x = sqrt(x, p) if success_x else None + +success_gx = is_quad_residue(generator*x, p) +root_gx = sqrt(generator*x, p) if success_gx else None + +# Check that one is 0 and the other is 1 +if x != 0: + assert success_x + success_gx ==1 + +# `None` means that no root was found, but we need to transform these into a felt no matter what +if root_x == None: + root_x = 0 +if root_gx == None: + root_gx = 0 +ids.success_x = int(success_x) +split_root_x = split(root_x) +split_root_gx = split(root_gx) +ids.sqrt_x.d0 = split_root_x[0] +ids.sqrt_x.d1 = split_root_x[1] +ids.sqrt_x.d2 = split_root_x[2] +ids.sqrt_gx.d0 = split_root_gx[0] +ids.sqrt_gx.d1 = split_root_gx[1] +ids.sqrt_gx.d2 = split_root_gx[2]"; pub const HI_MAX_BITLEN: &str = "ids.len_hi = max(ids.scalar_u.d2.bit_length(), ids.scalar_v.d2.bit_length())-1"; diff --git a/src/hint_processor/builtin_hint_processor/mod.rs b/src/hint_processor/builtin_hint_processor/mod.rs index 53dfe1a726..d83801f382 100644 --- a/src/hint_processor/builtin_hint_processor/mod.rs +++ b/src/hint_processor/builtin_hint_processor/mod.rs @@ -5,6 +5,7 @@ pub mod cairo_keccak; pub mod dict_hint_utils; pub mod dict_manager; pub mod ec_utils; +pub mod field_arithmetic; pub mod find_element_hint; pub mod garaga; pub mod hint_code; diff --git a/src/hint_processor/builtin_hint_processor/uint384.rs b/src/hint_processor/builtin_hint_processor/uint384.rs index fd9662c873..3fd08af63c 100644 --- a/src/hint_processor/builtin_hint_processor/uint384.rs +++ b/src/hint_processor/builtin_hint_processor/uint384.rs @@ -196,9 +196,9 @@ pub fn add_no_uint384_check( let sum_d0 = a.d0.to_biguint() + b.d0.to_biguint(); let carry_d0 = Felt252::from((sum_d0 >= shift) as usize); - let sum_d1 = a.d1.to_biguint() + b.d1.to_biguint(); + let sum_d1 = a.d1.to_biguint() + b.d1.to_biguint() + carry_d0.to_biguint(); let carry_d1 = Felt252::from((sum_d1 >= shift) as usize); - let sum_d2 = a.d2.to_biguint() + b.d2.to_biguint(); + let sum_d2 = a.d2.to_biguint() + b.d2.to_biguint() + carry_d1.to_biguint(); let carry_d2 = Felt252::from((sum_d2 >= shift) as usize); insert_value_from_var_name("carry_d0", carry_d0, vm, ids_data, ap_tracking)?; diff --git a/src/math_utils.rs b/src/math_utils.rs index d7ddbb5d9a..8ca9166102 100644 --- a/src/math_utils.rs +++ b/src/math_utils.rs @@ -1,10 +1,13 @@ +use core::cmp::min; + use crate::stdlib::ops::Shr; use crate::types::errors::math_errors::MathError; use felt::Felt252; -use num_bigint::{BigInt, BigUint}; +use num_bigint::{BigInt, BigUint, RandBigInt}; use num_integer::Integer; +use num_prime::nt_funcs::is_prime; use num_traits::{Bounded, One, Pow, Signed, Zero}; - +use rand::{rngs::SmallRng, SeedableRng}; ///Returns the integer square root of the nonnegative integer n. ///This is the floor of the exact square root of n. ///Unlike math.sqrt(), this function doesn't have rounding error issues. @@ -182,14 +185,139 @@ pub fn sqrt(n: &Felt252) -> Felt252 { } } +// Adapted from sympy _sqrt_prime_power with k == 1 +pub fn sqrt_prime_power(a: &BigUint, p: &BigUint) -> Option { + if p.is_zero() || !is_prime(p, None).probably() { + return None; + } + let two = BigUint::from(2_u32); + let a = a.mod_floor(p); + if p == &two { + return Some(a); + } + if !(a < two || (a.modpow(&(p - 1_u32).div_floor(&two), p)).is_one()) { + return None; + }; + + if p.mod_floor(&BigUint::from(4_u32)) == 3_u32.into() { + let res = a.modpow(&(p + 1_u32).div_floor(&BigUint::from(4_u32)), p); + return Some(min(res.clone(), p - res)); + }; + + if p.mod_floor(&BigUint::from(8_u32)) == 5_u32.into() { + let sign = a.modpow(&(p - 1_u32).div_floor(&BigUint::from(4_u32)), p); + if sign.is_one() { + let res = a.modpow(&(p + 3_u32).div_floor(&BigUint::from(8_u32)), p); + return Some(min(res.clone(), p - res)); + } else { + let b = (4_u32 * &a).modpow(&(p - 5_u32).div_floor(&BigUint::from(8_u32)), p); + let x = (2_u32 * &a * b).mod_floor(p); + if x.modpow(&two, p) == a { + return Some(x); + } + } + }; + + Some(sqrt_tonelli_shanks(&a, p)) +} + +fn sqrt_tonelli_shanks(n: &BigUint, prime: &BigUint) -> BigUint { + // Based on Tonelli-Shanks' algorithm for finding square roots + // and sympy's library implementation of said algorithm. + if n.is_zero() || n.is_one() { + return n.clone(); + } + let s = (prime - 1_u32).trailing_zeros().unwrap_or_default(); + let t = prime >> s; + let a = n.modpow(&t, prime); + // Rng is not critical here so its safe to use a seeded value + let mut rng = SmallRng::seed_from_u64(11480028852697973135); + let mut d; + loop { + d = RandBigInt::gen_biguint_range(&mut rng, &BigUint::from(2_u32), &(prime - 1_u32)); + let r = legendre_symbol(&d, prime); + if r == -1 { + break; + }; + } + d = d.modpow(&t, prime); + let mut m = BigUint::zero(); + let mut exponent = BigUint::one() << (s - 1); + let mut adm; + for i in 0..s as u32 { + adm = &a * &d.modpow(&m, prime); + adm = adm.modpow(&exponent, prime); + exponent >>= 1; + if adm == (prime - 1_u32) { + m += BigUint::from(1_u32) << i; + } + } + let root_1 = + (n.modpow(&((t + 1_u32) >> 1), prime) * d.modpow(&(m >> 1), prime)).mod_floor(prime); + let root_2 = prime - &root_1; + if root_1 < root_2 { + root_1 + } else { + root_2 + } +} + +/* Disclaimer: Some asumptions have been taken based on the functions that rely on this function, make sure these are true before calling this function individually +Adpted from sympy implementation, asuming: + - p is an odd prime number + - a.mod_floor(p) == a +Returns the Legendre symbol `(a / p)`. + + For an integer ``a`` and an odd prime ``p``, the Legendre symbol is + defined as + + .. math :: + \genfrac(){}{}{a}{p} = \begin{cases} + 0 & \text{if } p \text{ divides } a\\ + 1 & \text{if } a \text{ is a quadratic residue modulo } p\\ + -1 & \text{if } a \text{ is a quadratic nonresidue modulo } p + \end{cases} +*/ +fn legendre_symbol(a: &BigUint, p: &BigUint) -> i8 { + if a.is_zero() { + return 0; + }; + if is_quad_residue(a, p).unwrap_or_default() { + 1 + } else { + -1 + } +} + +// Ported from sympy implementation +// Simplified as a & p are nonnegative +// Asumes p is a prime number +pub(crate) fn is_quad_residue(a: &BigUint, p: &BigUint) -> Result { + if p.is_zero() { + return Err(MathError::IsQuadResidueZeroPrime); + } + let a = if a >= p { a.mod_floor(p) } else { a.clone() }; + if a < BigUint::from(2_u8) || p < &BigUint::from(3_u8) { + return Ok(true); + } + Ok( + a.modpow(&(p - BigUint::one()).div_floor(&BigUint::from(2_u8)), p) + .is_one(), + ) +} + #[cfg(test)] mod tests { use super::*; use crate::utils::test_utils::*; use crate::utils::CAIRO_PRIME; use assert_matches::assert_matches; + use num_traits::Num; + #[cfg(not(target_arch = "wasm32"))] + use num_prime::RandPrime; + #[cfg(not(target_arch = "wasm32"))] use proptest::prelude::*; @@ -592,6 +720,7 @@ mod tests { } #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn safe_div_bigint_by_zero() { let x = BigInt::one(); let y = BigInt::zero(); @@ -599,6 +728,7 @@ mod tests { } #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn test_sqrt() { let n = Felt252::from_str_radix( "99957092485221722822822221624080199277265330641980989815386842231144616633668", @@ -614,6 +744,97 @@ mod tests { } #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sqrt_prime_power() { + let n: BigUint = 25_u32.into(); + let p: BigUint = 18446744069414584321_u128.into(); + assert_eq!(sqrt_prime_power(&n, &p), Some(5_u32.into())); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sqrt_prime_power_p_is_zero() { + let n = BigUint::one(); + let p: BigUint = BigUint::zero(); + assert_eq!(sqrt_prime_power(&n, &p), None); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sqrt_prime_power_non_prime() { + let p: BigUint = BigUint::from_bytes_be(&[ + 69, 15, 232, 82, 215, 167, 38, 143, 173, 94, 133, 111, 1, 2, 182, 229, 110, 113, 76, 0, + 47, 110, 148, 109, 6, 133, 27, 190, 158, 197, 168, 219, 165, 254, 81, 53, 25, 34, + ]); + let n = BigUint::from_bytes_be(&[ + 9, 13, 22, 191, 87, 62, 157, 83, 157, 85, 93, 105, 230, 187, 32, 101, 51, 181, 49, 202, + 203, 195, 76, 193, 149, 78, 109, 146, 240, 126, 182, 115, 161, 238, 30, 118, 157, 252, + ]); + + assert_eq!(sqrt_prime_power(&n, &p), None); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sqrt_prime_power_none() { + let n: BigUint = 10_u32.into(); + let p: BigUint = 602_u32.into(); + assert_eq!(sqrt_prime_power(&n, &p), None); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sqrt_prime_power_prime_two() { + let n: BigUint = 25_u32.into(); + let p: BigUint = 2_u32.into(); + assert_eq!(sqrt_prime_power(&n, &p), Some(BigUint::one())); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sqrt_prime_power_prime_mod_8_is_5_sign_not_one() { + let n: BigUint = 676_u32.into(); + let p: BigUint = 9956234341095173_u64.into(); + assert_eq!( + sqrt_prime_power(&n, &p), + Some(BigUint::from(9956234341095147_u64)) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_sqrt_prime_power_prime_mod_8_is_5_sign_is_one() { + let n: BigUint = 130283432663_u64.into(); + let p: BigUint = 743900351477_u64.into(); + assert_eq!( + sqrt_prime_power(&n, &p), + Some(BigUint::from(123538694848_u64)) + ); + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_legendre_symbol_zero() { + assert!(legendre_symbol(&BigUint::zero(), &BigUint::one()).is_zero()) + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_is_quad_residue_prime_zero() { + assert_eq!( + is_quad_residue(&BigUint::one(), &BigUint::zero()), + Err(MathError::IsQuadResidueZeroPrime) + ) + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] + fn test_is_quad_residue_prime_a_one_true() { + assert_eq!(is_quad_residue(&BigUint::one(), &BigUint::one()), Ok(true)) + } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn mul_inv_0_is_0() { let p = &(*CAIRO_PRIME).clone().into(); let x = &BigInt::zero(); @@ -625,16 +846,33 @@ mod tests { #[cfg(not(target_arch = "wasm32"))] proptest! { #[test] - // Test for sqrt of a quadratic residue. Result should be the minimum root. - fn sqrt_felt_test(ref x in "([1-9][0-9]*)") { - let x = &Felt252::parse_bytes(x.as_bytes(), 10).unwrap(); - let x_sq = x * x; - let sqrt = x_sq.sqrt(); + // Test for sqrt of a quadratic residue. Result should be the minimum root. + fn sqrt_felt_test(ref x in any::<[u8; 32]>()) { + let x = &Felt252::from_bytes_be(x); + let x_sq = x * x; + let sqrt = sqrt(&x_sq); if &sqrt != x { - assert_eq!(Felt252::max_value() - sqrt + 1_usize, *x); + prop_assert_eq!(&(Felt252::max_value() - sqrt + 1_usize), x); } else { - assert_eq!(&sqrt, x); + prop_assert_eq!(&sqrt, x); + } + } + + #[test] + // Test for sqrt_prime_power_ of a quadratic residue. Result should be the minimum root. + fn sqrt_prime_power_using_random_prime(ref x in any::<[u8; 38]>(), ref y in any::()) { + let mut rng = SmallRng::seed_from_u64(*y); + let x = &BigUint::from_bytes_be(x); + // Generate a prime here instead of relying on y, otherwise y may never be a prime number + let p : &BigUint = &RandPrime::gen_prime(&mut rng, 384, None); + let x_sq = x * x; + if let Some(sqrt) = sqrt_prime_power(&x_sq, p) { + if &sqrt != x { + prop_assert_eq!(&(p - sqrt), x); + } else { + prop_assert_eq!(&sqrt, x); + } } } diff --git a/src/tests/cairo_run_test.rs b/src/tests/cairo_run_test.rs index 403b006ee0..168d81f08a 100644 --- a/src/tests/cairo_run_test.rs +++ b/src/tests/cairo_run_test.rs @@ -1309,6 +1309,13 @@ fn cairo_run_uint384_extension() { run_program_simple(program_data.as_slice()); } +#[test] +#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] +fn cairo_field_arithmetic() { + let program_data = include_bytes!("../../cairo_programs/field_arithmetic.json"); + run_program_simple(program_data.as_slice()); +} + #[test] #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)] fn cairo_run_ed25519_field() { diff --git a/src/types/errors/math_errors.rs b/src/types/errors/math_errors.rs index b614453290..44382ba53e 100644 --- a/src/types/errors/math_errors.rs +++ b/src/types/errors/math_errors.rs @@ -27,6 +27,8 @@ pub enum MathError { DividedByZero, #[error("Failed to calculate the square root of: {0})")] FailedToGetSqrt(BigUint), + #[error("is_quad_residue: p must be > 0")] + IsQuadResidueZeroPrime, // Relocatable Operations #[error("Cant convert felt: {0} to Relocatable")] Felt252ToRelocatable(Felt252),