Skip to content

Commit

Permalink
Implement Hint#48 pack512 (#1000)
Browse files Browse the repository at this point in the history
* Add cairo program for integration test

* WIP implement hint

* Add pack_512 fn + unit tests

* Update pack_512.cairo

* Add hint code && implement hint

* Reimplement pack_512 function && fix hint

* handle error

* Fix merge conflicts

* Update Changelog

* Rename pack_512.rs -> inv_mod_p_uint512.rs

* Add integration test

* Rename u255.cairo -> cairo_programs/fq_uint256.cairo

* pack_512 -> BigUint

* Add unit test

* fix wasm tests

* Remove fq_uint256.cairo

* use u128::MAX

* fix tests

---------

Co-authored-by: Pedro Fontana <pedro.fontana@lamdaclass.com>
  • Loading branch information
pefontana and Pedro Fontana authored Apr 21, 2023
1 parent e116860 commit 9ed2234
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 0 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,25 @@

#### Upcoming Changes

* Add missing hint on vrf.json lib [#1000](https://github.com/lambdaclass/cairo-rs/pull/1000):

`BuiltinHintProcessor` now supports the following hint:

```python
def pack_512(u, num_bits_shift: int) -> int:
limbs = (u.d0, u.d1, u.d2, u.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

x = pack_512(ids.x, num_bits_shift = 128)
p = ids.p.low + (ids.p.high << 128)
x_inverse_mod_p = pow(x,-1, p)

x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)

ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]
ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]
```

* BREAKING CHANGE: Fix `CairoRunner::get_memory_holes` [#1027](https://github.com/lambdaclass/cairo-rs/pull/1027):

* Skip builtin segements when counting memory holes
Expand Down
43 changes: 43 additions & 0 deletions cairo_programs/inv_mod_p_uint512.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
%builtins range_check

from starkware.cairo.common.uint256 import Uint256

const P_low = 201385395114098847380338600778089168199;
const P_high = 64323764613183177041862057485226039389;

struct Uint512 {
d0: felt,
d1: felt,
d2: felt,
d3: felt,
}

func inv_mod_p_uint512{range_check_ptr}(x: Uint512) -> Uint256 {
alloc_locals;
local x_inverse_mod_p: Uint256;
local p: Uint256 = Uint256(P_low, P_high);
// To whitelist
%{
def pack_512(u, num_bits_shift: int) -> int:
limbs = (u.d0, u.d1, u.d2, u.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
x = pack_512(ids.x, num_bits_shift = 128)
p = ids.p.low + (ids.p.high << 128)
x_inverse_mod_p = pow(x,-1, p)
x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)
ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]
ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]
%}

return x_inverse_mod_p;
}

func main{range_check_ptr: felt}() {
let x = Uint512(101, 2, 15, 61);
let y = inv_mod_p_uint512(x);
assert y = Uint256(80275402838848031859800366538378848249, 5810892639608724280512701676461676039);
return ();
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ use felt::Felt252;
#[cfg(feature = "skip_next_instruction_hint")]
use crate::hint_processor::builtin_hint_processor::skip_next_instruction::skip_next_instruction;

use super::vrf::inv_mod_p_uint512::inv_mod_p_uint512;

pub struct HintProcessorData {
pub code: String,
pub ap_tracking: ApTracking,
Expand Down Expand Up @@ -567,6 +569,9 @@ impl HintProcessor for BuiltinHintProcessor {
hi_max_bitlen(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::QUAD_BIT => quad_bit(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::INV_MOD_P_UINT512 => {
inv_mod_p_uint512(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::DI_BIT => di_bit(vm, &hint_data.ids_data, &hint_data.ap_tracking),
#[cfg(feature = "skip_next_instruction_hint")]
hint_code::SKIP_NEXT_INSTRUCTION => skip_next_instruction(vm),
Expand Down
13 changes: 13 additions & 0 deletions src/hint_processor/builtin_hint_processor/hint_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,19 @@ pub const QUAD_BIT: &str = r#"ids.quad_bit = (
+ ((ids.scalar_u >> (ids.m - 1)) & 1)
)"#;

pub const INV_MOD_P_UINT512: &str = "def pack_512(u, num_bits_shift: int) -> int:
limbs = (u.d0, u.d1, u.d2, u.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
x = pack_512(ids.x, num_bits_shift = 128)
p = ids.p.low + (ids.p.high << 128)
x_inverse_mod_p = pow(x,-1, p)
x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)
ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]
ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]";

pub const DI_BIT: &str =
r#"ids.dibit = ((ids.scalar_u >> ids.m) & 1) + 2 * ((ids.scalar_v >> ids.m) & 1)"#;

Expand Down
1 change: 1 addition & 0 deletions src/hint_processor/builtin_hint_processor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ pub mod uint256_utils;
pub mod uint384;
pub mod uint384_extension;
pub mod usort;
pub mod vrf;
212 changes: 212 additions & 0 deletions src/hint_processor/builtin_hint_processor/vrf/inv_mod_p_uint512.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
use core::ops::Shl;

use crate::stdlib::prelude::String;
use crate::stdlib::vec::Vec;
use crate::{
hint_processor::{
builtin_hint_processor::hint_utils::get_relocatable_from_var_name,
hint_processor_definition::HintReference,
},
math_utils::div_mod,
serde::deserialize_program::ApTracking,
stdlib::collections::HashMap,
vm::errors::hint_errors::HintError,
};
use felt::Felt252;
use num_bigint::{BigInt, BigUint};
use num_traits::One;

use crate::vm::vm_core::VirtualMachine;

/*
def pack_512(d0, d1,d2,d3, num_bits_shift: int) -> int:
limbs = (d0, d1, d2, d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
*/
fn pack_512(limbs: &[Felt252; 4], num_bits_shift: usize) -> BigUint {
limbs
.iter()
.enumerate()
.map(|(idx, value)| value.to_biguint().shl(idx * num_bits_shift))
.sum()
}

/*
Implements hint:
%{
def pack_512(u, num_bits_shift: int) -> int:
limbs = (u.d0, u.d1, u.d2, u.d3)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
x = pack_512(ids.x, num_bits_shift = 128)
p = ids.p.low + (ids.p.high << 128)
x_inverse_mod_p = pow(x,-1, p)
x_inverse_mod_p_split = (x_inverse_mod_p & ((1 << 128) - 1), x_inverse_mod_p >> 128)
ids.x_inverse_mod_p.low = x_inverse_mod_p_split[0]
ids.x_inverse_mod_p.high = x_inverse_mod_p_split[1]
%}
*/
pub fn inv_mod_p_uint512(
vm: &mut VirtualMachine,
ids_data: &HashMap<String, HintReference>,
ap_tracking: &ApTracking,
) -> Result<(), HintError> {
let limbs_ptr = get_relocatable_from_var_name("x", vm, ids_data, ap_tracking)?;
let limbs: Vec<Felt252> = vm
.get_integer_range(limbs_ptr, 4)?
.iter()
.map(|f| f.clone().into_owned())
.collect();

let x = pack_512(
&limbs
.try_into()
.map_err(|_| HintError::FixedSizeArrayFail(4))?,
128,
);

let p_ptr = get_relocatable_from_var_name("p", vm, ids_data, ap_tracking)?;
let p_low = vm.get_integer(p_ptr)?;
let p_high = vm.get_integer((p_ptr + 1_i32)?)?;

let p = p_low.into_owned().to_biguint() + (p_high.into_owned().to_biguint() << 128_usize);
let x_inverse_mod_p =
Felt252::from(div_mod(&BigInt::one(), &BigInt::from(x), &BigInt::from(p)));

let x_inverse_mod_p_ptr =
get_relocatable_from_var_name("x_inverse_mod_p", vm, ids_data, ap_tracking)?;

vm.insert_value(
x_inverse_mod_p_ptr,
&x_inverse_mod_p & &Felt252::from(u128::MAX),
)?;

vm.insert_value((x_inverse_mod_p_ptr + 1_i32)?, x_inverse_mod_p >> 128)?;

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::any_box;
use crate::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::BuiltinHintProcessor;
use crate::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::HintProcessorData;
use crate::hint_processor::hint_processor_definition::HintProcessor;
use crate::types::relocatable::{MaybeRelocatable, Relocatable};
use crate::utils::test_utils::mayberelocatable;
use crate::utils::test_utils::memory;
use crate::utils::test_utils::memory_from_memory;
use crate::utils::test_utils::memory_inner;
use crate::vm::errors::memory_errors::MemoryError;
use crate::vm::runners::builtin_runner::RangeCheckBuiltinRunner;
use crate::vm::vm_memory::memory::Memory;
use crate::vm::vm_memory::memory_segments::MemorySegmentManager;
use crate::{
hint_processor::builtin_hint_processor::hint_code::INV_MOD_P_UINT512,
types::exec_scope::ExecutionScopes,
utils::test_utils::{
add_segments, non_continuous_ids_data, run_hint, segments, vm_with_range_check,
},
};
use num_traits::{FromPrimitive, Num};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;

#[test]
fn test_pack_512() {
assert_eq!(
pack_512(
&[
Felt252::new(13123),
Felt252::new(534354),
Felt252::new(9901823),
Felt252::new(7812371)
],
2
),
BigUint::from(660571451_u128)
);
assert_eq!(
pack_512(
&[
Felt252::new(13123),
Felt252::new(534354),
Felt252::new(9901823),
Felt252::new(7812371)
],
76
),
BigUint::from_str_radix(
"3369937688063908975412897222574435556910082026593269572342866796946053411651",
10
)
.unwrap()
);

assert_eq!(
pack_512(
&[
Felt252::new(90812398),
Felt252::new(55),
Felt252::new(83127),
Felt252::from_i128(45312309123).unwrap()
],
761
),
BigUint::from_str_radix("80853029148137605102740201774483901385926652025450340798711030404174727480763870493377667725625759764292622444803788021444434452626041518098606806141685367065099387655302625873713592439838446220691925786159227082298892378981461987274693629088875674987359669209043388107114325450518636532594445145924759095125734364345163525655691027843325303271775064263282011908012871334532482494107608759994020937000541268185418760956243245766874157401648637158526410360988956699864519559367805347900540475245570833510432301935056255005826223734865268553682118180231081037207280009003811438596531432027766301678781550463988061852846171462460595592799020846810683500364584025173048032553173114469560143047387885550", 10).unwrap()
);
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_inv_mod_p_uint512_ok() {
let mut vm = vm_with_range_check!();
add_segments!(vm, 3);

//Initialize fp
vm.run_context.fp = 25;

//Create hint data
let ids_data = non_continuous_ids_data![("x", -5), ("p", -10), ("x_inverse_mod_p", -20)];
vm.segments = segments![
((1, 20), 101), //ids.x.d0
((1, 21), 2), // ids.x.d1
((1, 22), 15), // ids.x.d2
((1, 23), 61) // ids.x.d3
// ((1, 15), 201385395114098847380338600778089168199), // ids.p.low
// ((1, 16), 64323764613183177041862057485226039389) // ids.p.high
];
vm.insert_value(
Relocatable::from((1, 15)),
Felt252::from_str_radix("201385395114098847380338600778089168199", 10).unwrap(),
)
.expect("error setting ids.p");
vm.insert_value(
Relocatable::from((1, 16)),
Felt252::from_str_radix("64323764613183177041862057485226039389", 10).unwrap(),
)
.expect("error setting ids.p");

let mut exec_scopes = ExecutionScopes::new();
//Execute the hint
assert!(run_hint!(vm, ids_data, INV_MOD_P_UINT512, &mut exec_scopes).is_ok());

// Check VM inserts
assert_eq!(
vm.get_integer(Relocatable::from((1, 5)))
.unwrap()
.into_owned(),
Felt252::from_str_radix("80275402838848031859800366538378848249", 10).unwrap()
);
assert_eq!(
vm.get_integer(Relocatable::from((1, 6)))
.unwrap()
.into_owned(),
Felt252::from_str_radix("5810892639608724280512701676461676039", 10).unwrap()
);
}
}
1 change: 1 addition & 0 deletions src/hint_processor/builtin_hint_processor/vrf/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod inv_mod_p_uint512;
7 changes: 7 additions & 0 deletions src/tests/cairo_run_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,10 @@ fn memory_holes() {
let program_data = include_bytes!("../../cairo_programs/memory_holes.json");
run_program_simple_with_memory_holes(program_data.as_slice(), 5)
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn cairo_run_inv_mod_p_uint512() {
let program_data = include_bytes!("../../cairo_programs/inv_mod_p_uint512.json");
run_program_simple(program_data.as_slice());
}

1 comment on commit 9ed2234

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.30.

Benchmark suite Current: 9ed2234 Previous: e116860 Ratio
parse program 26465647 ns/iter (± 886853) 18257884 ns/iter (± 281612) 1.45
build runner 4124234 ns/iter (± 113021) 2555412 ns/iter (± 17253) 1.61

This comment was automatically generated by workflow using github-action-benchmark.

CC: @unbalancedparentheses

Please sign in to comment.