diff --git a/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/src/group.rs b/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/src/group.rs index 7cde7918e50..e5d063b53f0 100644 --- a/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/src/group.rs +++ b/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/src/group.rs @@ -949,7 +949,12 @@ impl EccPoint { pub fn mul_by_g(scalar: &EccScalar) -> Self { match scalar { EccScalar::K256(s) => secp256k1::Point::mul_by_g(s).into(), - EccScalar::P256(s) => secp256r1::Point::mul_by_g(s).into(), + EccScalar::P256(_) => { + // This unwrap is safe because mul can only fail if + // the point and scalar are on different curves, but we + // only invoke the function if the scalar is P256. + PRECOMPUTED_GX_P256.mul(scalar).unwrap() + } } } @@ -1158,33 +1163,59 @@ impl WindowInfo { const MASK: u8 = 0xFFu8 >> (8 - WINDOW_SIZE); const MAX: usize = 1 << WINDOW_SIZE; - const WINDOWS_IN_BYTE: usize = 8 / WINDOW_SIZE; - /// Returns the bit offset for window `w`. + /// * `bit_len` denotes the total bit size + /// * `inverted_w` denotes the window index counting from the least significant part of the scalar #[inline(always)] - fn window_bit_offset(w: usize) -> usize { - 8 - Self::SIZE - Self::SIZE * (w % Self::WINDOWS_IN_BYTE) + fn window_bit_offset(inverted_w: usize) -> usize { + (inverted_w * WINDOW_SIZE) % 8 } + /// Returns the number of windows in `curve_type`. #[inline(always)] - /// Extracts a window from a serialized scalar value. + fn number_of_windows(curve_type: EccCurveType) -> usize { + (curve_type.scalar_bits() + WINDOW_SIZE - 1) / WINDOW_SIZE + } + + /// Extract a window from a serialized scalar value /// - /// Treats the scalar as if it was a sequence of windows, each of `WINDOW_SIZE` bits, + /// Treat the scalar as if it was a sequence of windows, each of WINDOW_SIZE bits, /// and return the `w`th one of them. For 8 bit windows, this is simply the byte /// value. For smaller windows this is some subset of a single byte. + /// Note that `w=0` is the window corresponding to the largest value, i.e., if + /// out scalar spans one byte and is equal to 10101111_2=207_10, then it first, say + /// 4-bit, window will be 1010_2=10_10. /// - /// Only window sizes which are a power of 2 are supported which simplifies the - /// implementation to not require creating windows that cross byte boundaries. + /// Only window sizes in 1..=8 are supported. + #[inline(always)] fn extract(scalar: &[u8], w: usize) -> u8 { - assert!(WINDOW_SIZE == 1 || WINDOW_SIZE == 2 || WINDOW_SIZE == 4 || WINDOW_SIZE == 8); - let window_byte = scalar[w / Self::WINDOWS_IN_BYTE]; - (window_byte >> Self::window_bit_offset(w)) & Self::MASK - } + assert!((1..=8).contains(&WINDOW_SIZE)); + const BITS_IN_BYTE: usize = 8; - /// Returns the number of windows in `curve_type`. - #[inline(always)] - fn number_of_windows(curve_type: EccCurveType) -> usize { - curve_type.scalar_bits() / WINDOW_SIZE + let scalar_bytes = scalar.len(); + let windows = (scalar_bytes * 8 + WINDOW_SIZE - 1) / WINDOW_SIZE; + + // to compute the correct bit offset for bit lengths that are not a power of 2, + // we need to start from the inverted value or otherwise we will have multiple options + // for the offset. + let inverted_w = windows - w - 1; + let bit_offset = Self::window_bit_offset(inverted_w); + let byte_offset = scalar_bytes - 1 - (inverted_w * WINDOW_SIZE) / 8; + let target_byte = scalar[byte_offset]; + + let no_overflow = bit_offset + WINDOW_SIZE <= BITS_IN_BYTE; + + let non_overflow_bits = target_byte >> bit_offset; + + if no_overflow || byte_offset == 0 { + // If we can get the window out of single byte, do so + non_overflow_bits & Self::MASK + } else { + // Otherwise we must join two bytes and extract the result + let prev_byte = scalar[byte_offset - 1]; + let overflow_bits = prev_byte << (BITS_IN_BYTE - bit_offset); + (non_overflow_bits | overflow_bits) & Self::MASK + } } } @@ -1530,11 +1561,113 @@ impl NafLut { lazy_static::lazy_static! { + static ref PRECOMPUTED_GX_P256: EccPointMulTable = + EccPointMulTable::new(&EccPoint::generator_g(EccCurveType::P256)).unwrap(); + static ref PRECOMPUTED_PEDERSEN_GX_HY_P256: EccPointMul2Table = EccPointMul2Table::for_generators_of_curve(EccCurveType::P256).unwrap(); } +// Constant time table lookup +// +// This version has some special logic to simplify the multiplication +// algorithms that use it. If index is zero, then it returns the +// identity element. Otherwise it returns from[index-1]. +#[inline(always)] +fn ct_select(from: &[EccPoint], index: usize) -> ThresholdEcdsaResult { + use subtle::ConstantTimeEq; + let mut result = EccPoint::identity(from[0].curve_type()); + + let index = index.wrapping_sub(1); + for (i, val) in from.iter().enumerate() { + let choice = usize::ct_eq(&i, &index); + result = EccPoint::conditional_select(&result, val, choice)?; + } + + Ok(result) +} + +/// Structure for precomputed multiplication p*x +/// +/// It works by precomputing a table containing the powers of p +/// which allows the online phase of the scalar multiplication +/// to be effected using only additions. +/// +/// As the precomputation phase is expensive this is only worth using +/// for points which are multiplied many times (typically the standard +/// group generators). +pub struct EccPointMulTable { + table: Vec, +} + +impl EccPointMulTable { + /// The number of bits we examine in each scalar per iteration + const WINDOW_BITS: usize = 3; + // 2^w elements minus one (since one table element is always the identity) + const TABLE_ELEM_PER_WINDOW: usize = (1 << Self::WINDOW_BITS) - 1; + + pub fn new(p: &EccPoint) -> ThresholdEcdsaResult { + let curve = p.curve_type(); + + type Window = WindowInfo<{ EccPointMulTable::WINDOW_BITS }>; + let windows = Window::number_of_windows(curve); + + let mut table = Vec::with_capacity(Self::TABLE_ELEM_PER_WINDOW * windows); + + let mut accum = p.clone(); + + for _ in 0..windows { + let x1 = accum; + let x2 = x1.double(); + let x3 = x2.add_points(&x1)?; + let x4 = x2.double(); + let x5 = x4.add_points(&x1)?; + let x6 = x3.double(); + let x7 = x6.add_points(&x1)?; + let x8 = x4.double(); + + table.push(x1); + table.push(x2); + table.push(x3); + table.push(x4); + table.push(x5); + table.push(x6); + table.push(x7); + + accum = x8; + } + + Ok(Self { table }) + } + + pub fn mul(&self, x: &EccScalar) -> ThresholdEcdsaResult { + let curve = self.table[0].curve_type(); + + if x.curve_type() != curve { + return Err(ThresholdEcdsaError::CurveMismatch); + } + + type Window = WindowInfo<{ EccPointMulTable::WINDOW_BITS }>; + let windows = Window::number_of_windows(curve); + + assert_eq!(self.table.len(), windows * Self::TABLE_ELEM_PER_WINDOW); + + let s = x.serialize(); + + let mut accum = EccPoint::identity(curve); + + for i in 0..windows { + let tbl_i = &self.table + [Self::TABLE_ELEM_PER_WINDOW * i..(Self::TABLE_ELEM_PER_WINDOW * (i + 1))]; + let w = Window::extract(&s, windows - 1 - i) as usize; + accum = accum.add_points(&ct_select(tbl_i, w)?)?; + } + + Ok(accum) + } +} + /// Structure for precomputed multiplication g*x+h*y /// /// It works by precomputing a series of table, each of which @@ -1654,7 +1787,7 @@ impl EccPointMul2Table { let w = w1 + (w2 << Self::WINDOW_BITS); - accum = accum.add_points(&Self::ct_select(tbl_i, w)?)?; + accum = accum.add_points(&ct_select(tbl_i, w)?)?; } Ok(accum) @@ -1667,23 +1800,4 @@ impl EccPointMul2Table { let shift = 6 - 2 * (i % 4); ((b >> shift) % 4) as usize } - - // Constant time table lookup - // - // This version is specifically adapted to this algorithm. If - // index is zero, then it returns the identity element. Otherwise - // it returns from[index-1]. - #[inline(always)] - fn ct_select(from: &[EccPoint], index: usize) -> ThresholdEcdsaResult { - use subtle::ConstantTimeEq; - let mut result = EccPoint::identity(from[0].curve_type()); - - let index = index.wrapping_sub(1); - for (i, val) in from.iter().enumerate() { - let choice = usize::ct_eq(&i, &index); - result = EccPoint::conditional_select(&result, val, choice)?; - } - - Ok(result) - } } diff --git a/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/tests/group.rs b/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/tests/group.rs index 5dba10b2fe8..530ca06d183 100644 --- a/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/tests/group.rs +++ b/rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/tests/group.rs @@ -262,6 +262,25 @@ fn test_point_negate() -> ThresholdEcdsaResult<()> { Ok(()) } +#[test] +fn test_mul_by_g_is_correct() -> ThresholdEcdsaResult<()> { + let rng = &mut reproducible_rng(); + + for curve_type in EccCurveType::all() { + let g = EccPoint::generator_g(curve_type); + for small in 0..1024 { + let s = EccScalar::from_u64(curve_type, small); + assert_eq!(g.scalar_mul(&s)?, EccPoint::mul_by_g(&s)); + } + + for _iteration in 0..300 { + let s = EccScalar::random(curve_type, rng); + assert_eq!(g.scalar_mul(&s)?, EccPoint::mul_by_g(&s)); + } + } + Ok(()) +} + #[test] fn test_y_is_even() -> ThresholdEcdsaResult<()> { let rng = &mut reproducible_rng();