diff --git a/Cargo.lock b/Cargo.lock index 6d8a6bb64b..592c60d863 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -300,6 +300,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "aurora-engine-modexp" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfacad86e9e138fca0670949eb8ed4ffdf73a55bded8887efe0863cd1a3a6f70" +dependencies = [ + "hex", + "num", +] + [[package]] name = "auto_impl" version = "1.1.0" @@ -2381,9 +2391,9 @@ dependencies = [ name = "revm-precompile" version = "2.2.0" dependencies = [ + "aurora-engine-modexp", "c-kzg", "k256", - "num", "once_cell", "revm-primitives", "ripemd", diff --git a/crates/precompile/Cargo.toml b/crates/precompile/Cargo.toml index 46e9aea3ea..6a54124529 100644 --- a/crates/precompile/Cargo.toml +++ b/crates/precompile/Cargo.toml @@ -11,10 +11,11 @@ version = "2.2.0" [dependencies] revm-primitives = { path = "../primitives", version = "1.3.0", default-features = false } bn = { package = "substrate-bn", version = "0.6", default-features = false } -num = { version = "0.4.0", default-features = false, features = ["alloc"] } once_cell = { version = "1.17", default-features = false, features = ["alloc"] } ripemd = { version = "0.1", default-features = false } sha2 = { version = "0.10", default-features = false } +# modexp precompile +aurora-engine-modexp = { version = "1.0", default-features = false } # Optional KZG point evaluation precompile c-kzg = { version = "0.1.1", default-features = false, optional = true } @@ -26,12 +27,12 @@ secp256k1 = { version = "0.27.0", default-features = false, features = [ "recovery", ], optional = true } + [features] default = ["std", "c-kzg", "secp256k1"] std = [ "revm-primitives/std", "k256/std", - "num/std", "once_cell/std", "ripemd/std", "sha2/std", @@ -41,7 +42,7 @@ std = [ optimism = ["revm-primitives/optimism"] -# This library may not work on all no_std platforms as they depend on C libraries. +# These libraries may not work on all no_std platforms as they depend on C. # Enables the KZG point evaluation precompile. c-kzg = ["dep:c-kzg", "revm-primitives/c-kzg"] diff --git a/crates/precompile/src/lib.rs b/crates/precompile/src/lib.rs index 4fdef0595b..03911b183a 100644 --- a/crates/precompile/src/lib.rs +++ b/crates/precompile/src/lib.rs @@ -16,6 +16,7 @@ mod identity; pub mod kzg_point_evaluation; mod modexp; mod secp256k1; +pub mod utilities; use alloc::{boxed::Box, vec::Vec}; use core::fmt; diff --git a/crates/precompile/src/modexp.rs b/crates/precompile/src/modexp.rs index 072c2c6fe9..2f90f0cb6b 100644 --- a/crates/precompile/src/modexp.rs +++ b/crates/precompile/src/modexp.rs @@ -1,12 +1,11 @@ use crate::{ - primitives::U256, Error, Precompile, PrecompileAddress, PrecompileResult, StandardPrecompileFn, + primitives::U256, + utilities::{get_right_padded, get_right_padded_vec, left_padding, left_padding_vec}, + Error, Precompile, PrecompileAddress, PrecompileResult, StandardPrecompileFn, }; use alloc::vec::Vec; -use core::{ - cmp::{max, min, Ordering}, - mem::size_of, -}; -use num::{BigUint, One, Zero}; +use aurora_engine_modexp::modexp; +use core::cmp::{max, min}; pub const BYZANTIUM: PrecompileAddress = PrecompileAddress( crate::u64_to_address(5), @@ -32,121 +31,96 @@ pub fn berlin_run(input: &[u8], gas_limit: u64) -> PrecompileResult { }) } -fn calculate_iteration_count(exp_length: u64, exp_highp: &BigUint) -> u64 { +fn calculate_iteration_count(exp_length: u64, exp_highp: &U256) -> u64 { let mut iteration_count: u64 = 0; - if exp_length <= 32 && exp_highp.is_zero() { + if exp_length <= 32 && *exp_highp == U256::ZERO { iteration_count = 0; } else if exp_length <= 32 { - iteration_count = exp_highp.bits() - 1; + iteration_count = exp_highp.bit_len() as u64 - 1; } else if exp_length > 32 { - iteration_count = (8 * (exp_length - 32)) + max(1, exp_highp.bits()) - 1; + iteration_count = (8 * (exp_length - 32)) + max(1, exp_highp.bit_len() as u64) - 1; } max(iteration_count, 1) } -macro_rules! read_u64_with_overflow { - ($input:expr, $from:expr, $to:expr, $overflow_limit:expr) => {{ - const SPLIT: usize = 32 - size_of::(); - let len = $input.len(); - let from_zero = min($from, len); - let from = min(from_zero + SPLIT, len); - let to = min($to, len); - let overflow_bytes = &$input[from_zero..from]; - - let mut len_bytes = [0u8; size_of::()]; - len_bytes[..to - from].copy_from_slice(&$input[from..to]); - let out = u64::from_be_bytes(len_bytes) as usize; - let overflow = !(out < $overflow_limit && overflow_bytes.iter().all(|&x| x == 0)); - (out, overflow) - }}; -} - fn run_inner(input: &[u8], gas_limit: u64, min_gas: u64, calc_gas: F) -> PrecompileResult where - F: FnOnce(u64, u64, u64, &BigUint) -> u64, + F: FnOnce(u64, u64, u64, &U256) -> u64, { - let len = input.len(); - let (base_len, base_overflow) = read_u64_with_overflow!(input, 0, 32, u32::MAX as usize); - let (exp_len, exp_overflow) = read_u64_with_overflow!(input, 32, 64, u32::MAX as usize); - let (mod_len, mod_overflow) = read_u64_with_overflow!(input, 64, 96, u32::MAX as usize); - - if base_overflow || mod_overflow { - return Err(Error::ModexpBaseOverflow); + // If there is no minimum gas, return error. + if min_gas > gas_limit { + return Err(Error::OutOfGas); } - - if mod_overflow { + // The format of input is: + // + // Where every length is a 32-byte left-padded integer representing the number of bytes + // to be taken up by the next value + const HEADER_LENGTH: usize = 96; + + // Extract the header. + let base_len = U256::from_be_bytes(get_right_padded::<32>(input, 0)); + let exp_len = U256::from_be_bytes(get_right_padded::<32>(input, 32)); + let mod_len = U256::from_be_bytes(get_right_padded::<32>(input, 64)); + + // cast base and modulus to usize, it does not make sense to handle larger values + let Ok(base_len) = usize::try_from(base_len) else { + return Err(Error::ModexpBaseOverflow); + }; + let Ok(mod_len) = usize::try_from(mod_len) else { return Err(Error::ModexpModOverflow); + }; + + // Handle a special case when both the base and mod length is zero + if base_len == 0 && mod_len == 0 { + return Ok((min_gas, Vec::new())); } - let (r, gas_cost) = if base_len == 0 && mod_len == 0 { - if min_gas > gas_limit { - return Err(Error::OutOfGas); - } - (BigUint::zero(), min_gas) - } else { - // set limit for exp overflow - if exp_overflow { - return Err(Error::ModexpExpOverflow); - } - let base_start = 96; - let base_end = base_start + base_len; - let exp_end = base_end + exp_len; - let exp_highp_end = base_end + min(32, exp_len); - let mod_end = exp_end + mod_len; - - let exp_highp = { - let mut out = [0; 32]; - let from = min(base_end, len); - let to = min(exp_highp_end, len); - let target_from = 32 - (exp_highp_end - base_end); // 32 - exp length - let target_to = target_from + (to - from); // beginning + size to copy - out[target_from..target_to].copy_from_slice(&input[from..to]); - BigUint::from_bytes_be(&out) - }; - - let gas_cost = calc_gas(base_len as u64, exp_len as u64, mod_len as u64, &exp_highp); - if gas_cost > gas_limit { - return Err(Error::OutOfGas); - } + // cast exponent length to usize, it does not make sense to handle larger values. + let Ok(exp_len) = usize::try_from(exp_len) else { + return Err(Error::ModexpModOverflow); + }; - let read_big = |from: usize, to: usize| { - let mut out = vec![0; to - from]; - let from = min(from, len); - let to = min(to, len); - out[..to - from].copy_from_slice(&input[from..to]); - BigUint::from_bytes_be(&out) - }; + // Used to extract ADJUSTED_EXPONENT_LENGTH. + let exp_highp_len = min(exp_len, 32); - let base = read_big(base_start, base_end); - let exponent = read_big(base_end, exp_end); - let modulus = read_big(exp_end, mod_end); + // throw away the header data as we already extracted lengths. + let input = if input.len() >= 96 { + &input[HEADER_LENGTH..] + } else { + // or set input to zero if there is no more data + &[] + }; - if modulus.is_zero() || modulus.is_one() { - (BigUint::zero(), gas_cost) - } else { - (base.modpow(&exponent, &modulus), gas_cost) - } + let exp_highp = { + // get right padded bytes so if data.len is less then exp_len we will get right padded zeroes. + let right_padded_highp = get_right_padded::<32>(input, base_len); + // If exp_len is less then 32 bytes get only exp_len bytes and do left padding. + let out = left_padding::<32>(&right_padded_highp[..exp_highp_len]); + U256::from_be_bytes(out) }; - // write output to given memory, left padded and same length as the modulus. - let bytes = r.to_bytes_be(); - // always true except in the case of zero-length modulus, which leads to - // output of length and value 1. - match bytes.len().cmp(&mod_len) { - Ordering::Equal => Ok((gas_cost, bytes)), - Ordering::Less => { - let mut ret = Vec::with_capacity(mod_len); - ret.extend(core::iter::repeat(0).take(mod_len - bytes.len())); - ret.extend_from_slice(&bytes[..]); - Ok((gas_cost, ret)) - } - Ordering::Greater => Ok((gas_cost, Vec::new())), + // calculate gas spent. + let gas_cost = calc_gas(base_len as u64, exp_len as u64, mod_len as u64, &exp_highp); + // check if we have enough gas. + if gas_cost > gas_limit { + return Err(Error::OutOfGas); } + + // Padding is needed if the input does not contain all 3 values. + let base = get_right_padded_vec(input, 0, base_len); + let exponent = get_right_padded_vec(input, base_len, exp_len); + let modulus = get_right_padded_vec(input, base_len.saturating_add(exp_len), mod_len); + + // Call the modexp. + let output = modexp(&base, &exponent, &modulus); + + // left pad the result to modulus length. bytes will always by less or equal to modulus length. + Ok((gas_cost, left_padding_vec(&output, mod_len))) } -fn byzantium_gas_calc(base_len: u64, exp_len: u64, mod_len: u64, exp_highp: &BigUint) -> u64 { +fn byzantium_gas_calc(base_len: u64, exp_len: u64, mod_len: u64, exp_highp: &U256) -> u64 { // ouput of this function is bounded by 2^128 fn mul_complexity(x: u64) -> U256 { if x <= 64 { @@ -175,7 +149,7 @@ fn byzantium_gas_calc(base_len: u64, exp_len: u64, mod_len: u64, exp_highp: &Big // Calculate gas cost according to EIP 2565: // https://eips.ethereum.org/EIPS/eip-2565 -fn berlin_gas_calc(base_length: u64, exp_length: u64, mod_length: u64, exp_highp: &BigUint) -> u64 { +fn berlin_gas_calc(base_length: u64, exp_length: u64, mod_length: u64, exp_highp: &U256) -> u64 { fn calculate_multiplication_complexity(base_length: u64, mod_length: u64) -> U256 { let max_length = max(base_length, mod_length); let mut words = max_length / 8; diff --git a/crates/precompile/src/utilities.rs b/crates/precompile/src/utilities.rs new file mode 100644 index 0000000000..3b0f19143c --- /dev/null +++ b/crates/precompile/src/utilities.rs @@ -0,0 +1,43 @@ +use core::cmp::min; + +use alloc::vec::Vec; + +/// Get an array from the data, if data does not contain `start` to `len` bytes, add right padding with +/// zeroes +#[inline(always)] +pub fn get_right_padded(data: &[u8], offset: usize) -> [u8; S] { + let mut padded = [0; S]; + let start = min(offset, data.len()); + let end = min(start.saturating_add(S), data.len()); + padded[..end - start].copy_from_slice(&data[start..end]); + padded +} + +/// Get a vector of the data, if data does not contain the slice of `start` to `len`, right pad missing +/// part with zeroes +#[inline(always)] +pub fn get_right_padded_vec(data: &[u8], offset: usize, len: usize) -> Vec { + let mut padded = vec![0; len]; + let start = min(offset, data.len()); + let end = min(start.saturating_add(len), data.len()); + padded[..end - start].copy_from_slice(&data[start..end]); + padded +} + +/// Left padding until `len`. If data is more then len, truncate the right most bytes. +#[inline(always)] +pub fn left_padding(data: &[u8]) -> [u8; S] { + let mut padded = [0; S]; + let end = min(S, data.len()); + padded[S - end..].copy_from_slice(&data[..end]); + padded +} + +/// Left padding until `len`. If data is more then len, truncate the right most bytes. +#[inline(always)] +pub fn left_padding_vec(data: &[u8], len: usize) -> Vec { + let mut padded = vec![0; len]; + let end = min(len, data.len()); + padded[len - end..].copy_from_slice(&data[..end]); + padded +}