Skip to content

Commit

Permalink
perf(crypto): CRP-2258 Add support for multiplication by generator
Browse files Browse the repository at this point in the history
  • Loading branch information
randombit committed Jan 29, 2024
1 parent fcaf9fd commit 7405dfa
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 37 deletions.
188 changes: 151 additions & 37 deletions rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/src/group.rs
Expand Up @@ -949,7 +949,12 @@ impl EccPoint {
/// Multiplies the standard generator of the scalar's curve by `scalar`.
///
/// The curve is selected by the variant of `scalar`:
/// * K256 delegates to `secp256k1::Point::mul_by_g`.
/// * P256 uses the lazily built `PRECOMPUTED_GX_P256` table, so the online
///   work consists only of table lookups and point additions.
///
/// # Panics
///
/// Panics if the precomputed table rejects the scalar. The table is built
/// for P256 and this arm is only taken for P256 scalars, so that would be
/// an internal invariant violation.
pub fn mul_by_g(scalar: &EccScalar) -> Self {
    match scalar {
        EccScalar::K256(s) => secp256k1::Point::mul_by_g(s).into(),
        // Exactly one P256 arm: a second, earlier P256 arm would shadow this
        // one and the precomputed table would never be consulted.
        EccScalar::P256(_) => PRECOMPUTED_GX_P256
            .mul(scalar)
            .expect("P256 generator table must accept a P256 scalar"),
    }
}

Expand Down Expand Up @@ -1158,33 +1163,59 @@ impl<const WINDOW_SIZE: usize> WindowInfo<WINDOW_SIZE> {

const MASK: u8 = 0xFFu8 >> (8 - WINDOW_SIZE);
const MAX: usize = 1 << WINDOW_SIZE;
const WINDOWS_IN_BYTE: usize = 8 / WINDOW_SIZE;

/// Returns the bit offset for window `w`.
/// * `bit_len` denotes the total bit size
/// * `inverted_w` denotes the window index counting from the least significant part of the scalar
#[inline(always)]
fn window_bit_offset(w: usize) -> usize {
8 - Self::SIZE - Self::SIZE * (w % Self::WINDOWS_IN_BYTE)
fn window_bit_offset(inverted_w: usize) -> usize {
(inverted_w * WINDOW_SIZE) % 8
}

/// Returns the number of windows in `curve_type`.
#[inline(always)]
/// Extracts a window from a serialized scalar value.
fn number_of_windows(curve_type: EccCurveType) -> usize {
(curve_type.scalar_bits() + WINDOW_SIZE - 1) / WINDOW_SIZE
}

/// Extract a window from a serialized scalar value
///
/// Treats the scalar as if it was a sequence of windows, each of `WINDOW_SIZE` bits,
/// Treat the scalar as if it was a sequence of windows, each of WINDOW_SIZE bits,
/// and return the `w`th one of them. For 8 bit windows, this is simply the byte
/// value. For smaller windows this is a subset of a single byte, or of two adjacent bytes when the window straddles a byte boundary.
/// Note that `w=0` is the window corresponding to the largest value, i.e., if
/// our scalar spans one byte and is equal to 10101111_2=175_10, then its first, say
/// 4-bit, window will be 1010_2=10_10.
///
/// Only window sizes which are a power of 2 are supported which simplifies the
/// implementation to not require creating windows that cross byte boundaries.
/// Only window sizes in 1..=8 are supported.
#[inline(always)]
fn extract(scalar: &[u8], w: usize) -> u8 {
assert!(WINDOW_SIZE == 1 || WINDOW_SIZE == 2 || WINDOW_SIZE == 4 || WINDOW_SIZE == 8);
let window_byte = scalar[w / Self::WINDOWS_IN_BYTE];
(window_byte >> Self::window_bit_offset(w)) & Self::MASK
}
assert!((1..=8).contains(&WINDOW_SIZE));
const BITS_IN_BYTE: usize = 8;

/// Returns the number of windows in `curve_type`.
#[inline(always)]
fn number_of_windows(curve_type: EccCurveType) -> usize {
curve_type.scalar_bits() / WINDOW_SIZE
let scalar_bytes = scalar.len();
let windows = (scalar_bytes * 8 + WINDOW_SIZE - 1) / WINDOW_SIZE;

// to compute the correct bit offset for bit lengths that are not a power of 2,
// we need to start from the inverted value or otherwise we will have multiple options
// for the offset.
let inverted_w = windows - w - 1;
let bit_offset = Self::window_bit_offset(inverted_w);
let byte_offset = scalar_bytes - 1 - (inverted_w * WINDOW_SIZE) / 8;
let target_byte = scalar[byte_offset];

let no_overflow = bit_offset + WINDOW_SIZE <= BITS_IN_BYTE;

let non_overflow_bits = target_byte >> bit_offset;

if no_overflow || byte_offset == 0 {
// If we can get the window out of single byte, do so
non_overflow_bits & Self::MASK
} else {
// Otherwise we must join two bytes and extract the result
let prev_byte = scalar[byte_offset - 1];
let overflow_bits = prev_byte << (BITS_IN_BYTE - bit_offset);
(non_overflow_bits | overflow_bits) & Self::MASK
}
}
}

Expand Down Expand Up @@ -1530,11 +1561,113 @@ impl NafLut {

// Precomputed multiplication tables. `lazy_static` builds each table on
// first access rather than at program start; the `unwrap` calls mean a
// failed table construction panics at that first access.
lazy_static::lazy_static! {

    /// Single-point multiplication table for the P256 generator `g`,
    /// used to compute `g*x`.
    static ref PRECOMPUTED_GX_P256: EccPointMulTable =
        EccPointMulTable::new(&EccPoint::generator_g(EccCurveType::P256)).unwrap();

    /// Two-point multiplication table for the P256 generators, as built by
    /// `EccPointMul2Table::for_generators_of_curve`.
    // NOTE(review): presumably the Pedersen pair (g, h) used for g*x+h*y — confirm.
    static ref PRECOMPUTED_PEDERSEN_GX_HY_P256: EccPointMul2Table =
        EccPointMul2Table::for_generators_of_curve(EccCurveType::P256).unwrap();

}

// Table lookup that touches every entry regardless of `index`.
//
// The index is one-based to suit the windowed multiplication routines:
// `index == 0` yields the identity element of the table's curve, and any
// other value yields `from[index - 1]`. The slice must be non-empty since
// the curve is taken from its first entry.
#[inline(always)]
fn ct_select(from: &[EccPoint], index: usize) -> ThresholdEcdsaResult<EccPoint> {
    use subtle::ConstantTimeEq;

    let identity = EccPoint::identity(from[0].curve_type());

    // Shift to a zero-based position; an `index` of 0 wraps to usize::MAX,
    // which is not expected to equal any position in the table, so the
    // identity is kept.
    let wanted = index.wrapping_sub(1);

    // Walk the whole table, conditionally replacing the running selection.
    from.iter()
        .enumerate()
        .try_fold(identity, |selected, (position, candidate)| {
            let is_wanted = usize::ct_eq(&position, &wanted);
            EccPoint::conditional_select(&selected, candidate, is_wanted)
        })
}

/// Structure for precomputed multiplication p*x
///
/// It works by precomputing a table containing small multiples of p,
/// one group of multiples per scalar window, which allows the online
/// phase of the scalar multiplication to be effected using only
/// table lookups and additions.
///
/// As the precomputation phase is expensive this is only worth using
/// for points which are multiplied many times (typically the standard
/// group generators).
pub struct EccPointMulTable {
    // Flat table: one run of TABLE_ELEM_PER_WINDOW consecutive points per
    // scalar window. The identity is not stored; lookups synthesize it.
    table: Vec<EccPoint>,
}

impl EccPointMulTable {
    /// The number of bits we examine in each scalar per iteration
    const WINDOW_BITS: usize = 3;
    // 2^w elements minus one (since one table element is always the identity)
    const TABLE_ELEM_PER_WINDOW: usize = (1 << Self::WINDOW_BITS) - 1;

    /// Builds the multiplication table for the point `p`.
    ///
    /// For every scalar window the table stores the multiples
    /// `1*b, 2*b, ..., (2^WINDOW_BITS - 1)*b`, where `b` starts as `p` and is
    /// multiplied by `2^WINDOW_BITS` when moving to the next window.
    ///
    /// The entries are derived from `WINDOW_BITS` / `TABLE_ELEM_PER_WINDOW`
    /// rather than being written out by hand, so changing the window size
    /// cannot leave the table out of sync with `mul`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `EccPoint::add_points`.
    pub fn new(p: &EccPoint) -> ThresholdEcdsaResult<Self> {
        type Window = WindowInfo<{ EccPointMulTable::WINDOW_BITS }>;
        let windows = Window::number_of_windows(p.curve_type());

        let mut table = Vec::with_capacity(Self::TABLE_ELEM_PER_WINDOW * windows);

        let mut base = p.clone();

        for _ in 0..windows {
            // Position of 1*base for this window; k*base lives at start + k - 1.
            let start = table.len();
            table.push(base);

            for k in 2..=Self::TABLE_ELEM_PER_WINDOW {
                let multiple = if k % 2 == 0 {
                    // k*base = 2 * ((k/2)*base)
                    table[start + k / 2 - 1].double()
                } else {
                    // k*base = (k-1)*base + base
                    table[start + k - 2].add_points(&table[start])?
                };
                table.push(multiple);
            }

            // Base of the next window: 2^W * base = 2 * (2^(W-1) * base).
            base = table[start + (1 << (Self::WINDOW_BITS - 1)) - 1].double();
        }

        Ok(Self { table })
    }

    /// Computes `p*x` for the point `p` this table was built from.
    ///
    /// The scalar is serialized and consumed one window at a time. Each
    /// window value selects, via `ct_select`, one precomputed multiple (or the
    /// identity for a zero window), and the selections are summed.
    ///
    /// # Errors
    ///
    /// Returns `ThresholdEcdsaError::CurveMismatch` if `x` is not on the same
    /// curve as the table; propagates errors from `ct_select` and
    /// `EccPoint::add_points`.
    ///
    /// # Panics
    ///
    /// Panics if the table length does not match the window count for the
    /// curve, which would indicate a corrupted table.
    pub fn mul(&self, x: &EccScalar) -> ThresholdEcdsaResult<EccPoint> {
        let curve = self.table[0].curve_type();

        if x.curve_type() != curve {
            return Err(ThresholdEcdsaError::CurveMismatch);
        }

        type Window = WindowInfo<{ EccPointMulTable::WINDOW_BITS }>;
        let windows = Window::number_of_windows(curve);

        assert_eq!(self.table.len(), windows * Self::TABLE_ELEM_PER_WINDOW);

        let s = x.serialize();

        let mut accum = EccPoint::identity(curve);

        // Chunk `i` holds the multiples for the i-th least significant window,
        // while `extract` counts windows from the most significant end.
        let per_window = self.table.chunks_exact(Self::TABLE_ELEM_PER_WINDOW);
        for (i, tbl_i) in per_window.enumerate() {
            let w = Window::extract(&s, windows - 1 - i) as usize;
            accum = accum.add_points(&ct_select(tbl_i, w)?)?;
        }

        Ok(accum)
    }
}

/// Structure for precomputed multiplication g*x+h*y
///
/// It works by precomputing a series of tables, each of which
Expand Down Expand Up @@ -1654,7 +1787,7 @@ impl EccPointMul2Table {

let w = w1 + (w2 << Self::WINDOW_BITS);

accum = accum.add_points(&Self::ct_select(tbl_i, w)?)?;
accum = accum.add_points(&ct_select(tbl_i, w)?)?;
}

Ok(accum)
Expand All @@ -1667,23 +1800,4 @@ impl EccPointMul2Table {
let shift = 6 - 2 * (i % 4);
((b >> shift) % 4) as usize
}

// Table lookup that scans every entry regardless of `index`.
//
// Tailored to this algorithm: an `index` of zero produces the identity
// element, while any other value produces from[index-1]. The slice must
// be non-empty because the curve is read from its first entry.
#[inline(always)]
fn ct_select(from: &[EccPoint], index: usize) -> ThresholdEcdsaResult<EccPoint> {
    use subtle::ConstantTimeEq;

    // Zero wraps to usize::MAX, which is not expected to equal any
    // position in the table.
    let target = index.wrapping_sub(1);
    let mut selected = EccPoint::identity(from[0].curve_type());

    for (position, entry) in (0usize..).zip(from) {
        let hit = position.ct_eq(&target);
        selected = EccPoint::conditional_select(&selected, entry, hit)?;
    }

    Ok(selected)
}
}
19 changes: 19 additions & 0 deletions rs/crypto/internal/crypto_lib/threshold_sig/tecdsa/tests/group.rs
Expand Up @@ -262,6 +262,25 @@ fn test_point_negate() -> ThresholdEcdsaResult<()> {
Ok(())
}

/// Checks `EccPoint::mul_by_g` against plain scalar multiplication of the
/// generator, for every curve, over the scalars 0..1024 followed by 300
/// random scalars.
#[test]
fn test_mul_by_g_is_correct() -> ThresholdEcdsaResult<()> {
    let rng = &mut reproducible_rng();

    for curve_type in EccCurveType::all() {
        let generator = EccPoint::generator_g(curve_type);

        // Small scalars first, then random ones; the chain is lazy, so the
        // RNG is only consumed once the small scalars are exhausted.
        let small_scalars = (0..1024).map(|value| EccScalar::from_u64(curve_type, value));
        let random_scalars =
            std::iter::repeat_with(|| EccScalar::random(curve_type, rng)).take(300);

        for scalar in small_scalars.chain(random_scalars) {
            let expected = generator.scalar_mul(&scalar)?;
            assert_eq!(expected, EccPoint::mul_by_g(&scalar));
        }
    }

    Ok(())
}

#[test]
fn test_y_is_even() -> ThresholdEcdsaResult<()> {
let rng = &mut reproducible_rng();
Expand Down

0 comments on commit 7405dfa

Please sign in to comment.