Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hints): add NewHint#45 #1024

Merged
merged 3 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

#### Upcoming Changes

* Add missing hint on uint256_improvements lib [#1024](https://github.com/lambdaclass/cairo-rs/pull/1024):

`BuiltinHintProcessor` now supports the following hint:

```python
res = ids.a + ids.b
ids.carry = 1 if res >= ids.SHIFT else 0
```

* BREAKING CHANGE: move `Program::identifiers` to `SharedProgramData::identifiers` [#1023](https://github.com/lambdaclass/cairo-rs/pull/1023)
* Optimizes `CairoRunner::new`, needed for sequencers and other workflows reusing the same `Program` instance across `CairoRunner`s
* Breaking change: make all fields in `Program` and `SharedProgramData` `pub(crate)`, since we break by moving the field let's make it the last break for this struct
Expand Down
15 changes: 15 additions & 0 deletions cairo_programs/uint256_improvements.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,24 @@ func test_uint256_sub{range_check_ptr}() {
return ();
}

func test_uint128_add{range_check_ptr}() {
let (res) = uint128_add(5, 66);

assert res = Uint256(71, 0);

let (res) = uint128_add(
340282366920938463463374607431768211455, 340282366920938463463374607431768211455
);

assert res = Uint256(340282366920938463463374607431768211454, 1);

return ();
}

func main{range_check_ptr}() {
test_udiv_expanded();
test_uint256_sub();
test_uint128_add();

return ();
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ use crate::{
squash_dict_inner_used_accesses_assert,
},
uint256_utils::{
split_64, uint256_add, uint256_expanded_unsigned_div_rem, uint256_mul_div_mod,
uint256_signed_nn, uint256_sqrt, uint256_sub, uint256_unsigned_div_rem,
split_64, uint128_add, uint256_add, uint256_expanded_unsigned_div_rem,
uint256_mul_div_mod, uint256_signed_nn, uint256_sqrt, uint256_sub,
uint256_unsigned_div_rem,
},
uint384::{
add_no_uint384_check, uint384_signed_nn, uint384_split_128, uint384_sqrt,
Expand Down Expand Up @@ -336,6 +337,7 @@ impl HintProcessor for BuiltinHintProcessor {
dict_squash_update_ptr(vm, exec_scopes, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::UINT256_ADD => uint256_add(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::UINT128_ADD => uint128_add(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::UINT256_SUB => uint256_sub(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::SPLIT_64 => split_64(vm, &hint_data.ids_data, &hint_data.ap_tracking),
hint_code::UINT256_SQRT => {
Expand Down
3 changes: 3 additions & 0 deletions src/hint_processor/builtin_hint_processor/hint_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
sum_high = ids.a.high + ids.b.high + ids.carry_low
ids.carry_high = 1 if sum_high >= ids.SHIFT else 0"#;

pub const UINT128_ADD: &str = r#"res = ids.a + ids.b
ids.carry = 1 if res >= ids.SHIFT else 0"#;

pub const UINT256_SUB: &str = r#"def split(num: int, num_bits_shift: int = 128, length: int = 2):
a = []
for _ in range(length):
Expand Down
92 changes: 63 additions & 29 deletions src/hint_processor/builtin_hint_processor/uint256_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,38 +108,53 @@ pub fn uint256_add(
ap_tracking: &ApTracking,
) -> Result<(), HintError> {
let shift = Felt252::new(1_u32) << 128_u32;
let a_relocatable = get_relocatable_from_var_name("a", vm, ids_data, ap_tracking)?;
let b_relocatable = get_relocatable_from_var_name("b", vm, ids_data, ap_tracking)?;
let a_low = vm.get_integer(a_relocatable)?;
let a_high = vm.get_integer((a_relocatable + 1_usize)?)?;
let b_low = vm.get_integer(b_relocatable)?;
let b_high = vm.get_integer((b_relocatable + 1_usize)?)?;
let a_low = a_low.as_ref();
let a_high = a_high.as_ref();
let b_low = b_low.as_ref();
let b_high = b_high.as_ref();

//Main logic
//sum_low = ids.a.low + ids.b.low
//ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
//sum_high = ids.a.high + ids.b.high + ids.carry_low
//ids.carry_high = 1 if sum_high >= ids.SHIFT else 0
let a = Uint256::from_var_name("a", vm, ids_data, ap_tracking)?;
let b = Uint256::from_var_name("b", vm, ids_data, ap_tracking)?;
let a_low = a.low.as_ref();
let a_high = a.high.as_ref();
let b_low = b.low.as_ref();
let b_high = b.high.as_ref();

let carry_low = if a_low + b_low >= shift {
Felt252::one()
} else {
Felt252::zero()
};
// Main logic
// sum_low = ids.a.low + ids.b.low
// ids.carry_low = 1 if sum_low >= ids.SHIFT else 0
// sum_high = ids.a.high + ids.b.high + ids.carry_low
// ids.carry_high = 1 if sum_high >= ids.SHIFT else 0

let carry_low = Felt252::from((a_low + b_low >= shift) as u8);
let carry_high = Felt252::from((a_high + b_high + &carry_low >= shift) as u8);

let carry_high = if a_high + b_high + &carry_low >= shift {
Felt252::one()
} else {
Felt252::zero()
};
insert_value_from_var_name("carry_high", carry_high, vm, ids_data, ap_tracking)?;
insert_value_from_var_name("carry_low", carry_low, vm, ids_data, ap_tracking)
}

/*
Implements hint:
%{
res = ids.a + ids.b
ids.carry = 1 if res >= ids.SHIFT else 0
%}
*/
pub fn uint128_add(
vm: &mut VirtualMachine,
ids_data: &HashMap<String, HintReference>,
ap_tracking: &ApTracking,
) -> Result<(), HintError> {
let shift = Felt252::new(1_u32) << 128_u32;
let a = get_integer_from_var_name("a", vm, ids_data, ap_tracking)?;
let b = get_integer_from_var_name("b", vm, ids_data, ap_tracking)?;
let a = a.as_ref();
MegaRedHand marked this conversation as resolved.
Show resolved Hide resolved
let b = b.as_ref();

// Main logic
// res = ids.a + ids.b
// ids.carry = 1 if res >= ids.SHIFT else 0
let carry = Felt252::from((a + b >= shift) as u8);

insert_value_from_var_name("carry", carry, vm, ids_data, ap_tracking)
}

/*
Implements hint:
%{
Expand Down Expand Up @@ -477,29 +492,48 @@ mod tests {
#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint256_add_ok() {
let hint_code = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0";
let hint_code = hint_code::UINT256_ADD;
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
//Create hint_data
let ids_data =
non_continuous_ids_data![("a", -6), ("b", -4), ("carry_high", 3), ("carry_low", 2)];
non_continuous_ids_data![("a", -6), ("b", -4), ("carry_low", 2), ("carry_high", 3)];
vm.segments = segments![
((1, 4), 2),
((1, 5), 3),
((1, 6), 4),
((1, 7), ("340282366920938463463374607431768211456", 10))
((1, 7), ("340282366920938463463374607431768211455", 10))
];
//Execute the hint
assert_matches!(run_hint!(vm, ids_data, hint_code), Ok(()));
//Check hint memory inserts
check_memory![vm.segments.memory, ((1, 12), 0), ((1, 13), 1)];
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint128_add_ok() {
let hint_code = hint_code::UINT128_ADD;
let mut vm = vm_with_range_check!();
// Initialize fp
vm.run_context.fp = 0;
// Create hint_data
let ids_data = non_continuous_ids_data![("a", 0), ("b", 1), ("carry", 2)];
vm.segments = segments![
((1, 0), 180141183460469231731687303715884105727_u128),
((1, 1), 180141183460469231731687303715884105727_u128),
];
// Execute the hint
assert_matches!(run_hint!(vm, ids_data, hint_code), Ok(()));
// Check hint memory inserts
check_memory![vm.segments.memory, ((1, 2), 1)];
}

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn run_uint256_add_fail_inserts() {
let hint_code = "sum_low = ids.a.low + ids.b.low\nids.carry_low = 1 if sum_low >= ids.SHIFT else 0\nsum_high = ids.a.high + ids.b.high + ids.carry_low\nids.carry_high = 1 if sum_high >= ids.SHIFT else 0";
let hint_code = hint_code::UINT256_ADD;
let mut vm = vm_with_range_check!();
//Initialize fp
vm.run_context.fp = 10;
Expand Down