diff --git a/btclib/curve.py b/btclib/curve.py index be5b38886..f03a75bd7 100644 --- a/btclib/curve.py +++ b/btclib/curve.py @@ -340,8 +340,8 @@ def _mult_aff(m: int, Q: Point, ec: CurveGroup) -> Point: It is not constant-time. The input point is assumed to be on curve, - m is assumed to have been reduced mod n if appropriate - (e.g. cyclic groups of order n). + the m coefficient is assumed to have been reduced mod n + if appropriate (e.g. cyclic groups of order n). """ if m < 0: @@ -364,12 +364,12 @@ def _mult_jac(m: int, Q: JacPoint, ec: CurveGroup) -> JacPoint: This implementation uses 'double & add' algorithm, binary decomposition of m, - affine coordinates. + Jacobian coordinates. It is not constant-time. The input point is assumed to be on curve, - m is assumed to have been reduced mod n if appropriate - (e.g. cyclic groups of order n). + the m coefficient is assumed to have been reduced mod n + if appropriate (e.g. cyclic groups of order n). """ if m < 0: diff --git a/btclib/curvemult.py b/btclib/curvemult.py index afc2414ce..e1328774f 100644 --- a/btclib/curvemult.py +++ b/btclib/curvemult.py @@ -20,10 +20,7 @@ def mult(m: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point: - """Point multiplication, implemented using 'double and add'. - - Computations use Jacobian coordinates and binary decomposition of m. - """ + "Elliptic curve scalar multiplication." if Q is None: QJ = ec.GJ else: @@ -38,6 +35,12 @@ def mult(m: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point: def _double_mult( u: int, HJ: JacPoint, v: int, QJ: JacPoint, ec: CurveGroup ) -> JacPoint: + """Shamir trick for efficient computation of u*H + v*Q. + + The input points are assumed to be on curve, + the u and v coefficients are assumed to have been reduced mod n + if appropriate (e.g. cyclic groups of order n). + """ if u < 0: raise ValueError(f"negative first coefficient: {hex(u)}") @@ -63,7 +66,7 @@ def _double_mult( def double_mult( u: Integer, H: Point, v: Integer, Q: Point, ec: Curve = secp256k1 ) -> Point: - """Shamir trick for efficient computation of u*H + v*Q""" + "Shamir trick for efficient computation of u*H + v*Q." ec.require_on_curve(H) HJ = _jac_from_aff(H) @@ -80,6 +83,14 @@ def double_mult( def _multi_mult( scalars: Sequence[int], JPoints: Sequence[JacPoint], ec: CurveGroup ) -> JacPoint: + """Return the multi scalar multiplication u1*Q1 + ... + un*Qn. + + Use Bos-Coster's algorithm for efficient computation. + + The input points are assumed to be on curve, + the scalar coefficients are assumed to have been reduced mod n + if appropriate (e.g. cyclic groups of order n). + """ # source: https://cr.yp.to/badbatch/boscoster2.py if len(scalars) != len(JPoints): @@ -87,13 +98,13 @@ def _multi_mult( errMsg += f"{len(scalars)} vs {len(JPoints)}" raise ValueError(errMsg) - # FIXME - # check for negative scalars # x = list(zip([-n for n in scalars], JPoints)) x: List[Tuple[int, JacPoint]] = [] for n, PJ in zip(scalars, JPoints): - if n == 0: + if n == 0: # mandatory check to avoid infinite loop continue + if n < 0: + raise ValueError(f"negative coefficient: {hex(n)}") x.append((-n, PJ)) if not x: @@ -134,7 +145,7 @@ def multi_mult( ints: List[int] = list() for P, i in zip(Points, scalars): i = int_from_integer(i) % ec.n - if i == 0: + if i == 0: # early optimization, even if not strictly necessary continue ints.append(i) ec.require_on_curve(P) diff --git a/btclib/curvemult2.py b/btclib/curvemult2.py new file mode 100644 index 000000000..4d61fda6b --- /dev/null +++ b/btclib/curvemult2.py @@ -0,0 +1,403 @@ +#!/usr/bin/env python3 + +# Copyright (C) 2017-2020 The btclib developers +# +# This file is part of btclib. It is subject to the license terms in the +# LICENSE file found in the top-level directory of this distribution. +# +# No part of btclib including this file, may be copied, modified, propagated, +# or distributed except according to the terms contained in the LICENSE file. + +"""Elliptic curve point multiplication functions. + +The implemented algorithms are: + - Montgomery Ladder + - Scalar multiplication on basis 3 + - Fixed window + - Sliding window + - w-ary non-adjacent form (wNAF) + +References: + - https://en.wikipedia.org/wiki/Elliptic_curve_point_multiplication + - https://cryptojedi.org/peter/data/eccss-20130911b.pdf + - https://ecc2017.cs.ru.nl/slides/ecc2017school-castryck.pdf + +TODO: + - Computational cost of the different multiplications + - Add double function to make it clear the difference between sum and dubling for computational cost + - New alghoritms at the state-of-art: + -https://hal.archives-ouvertes.fr/hal-00932199/document + -https://iacr.org/workshops/ches/ches2006/presentations/Douglas%20Stebila.pdf + -1-s2.0-S1071579704000395-main + - Elegance in the code + - Solve the small problem with wNAF and w=1 + - Multi_mult algorithm: why does it work? +""" + + +from typing import List + +from .alias import INFJ, Integer, JacPoint, Point +from .curve import Curve, _jac_from_aff +from .curves import secp256k1 +from .utils import int_from_integer + + +def _double_jac(Q: JacPoint, ec: Curve = secp256k1) -> JacPoint: + + if Q[2] == 0: + return INFJ + + QZ2 = Q[2] * Q[2] + QY2 = Q[1] * Q[1] + W = (3 * Q[0] * Q[0] + ec._a * QZ2 * QZ2) % ec.p + V = (4 * Q[0] * QY2) % ec.p + X = (W * W - 2 * V) % ec.p + Y = (W * (V - X) - 8 * QY2 * QY2) % ec.p + Z = (2 * Q[1] * Q[2]) % ec.p + return X, Y, Z + + +def _mult_jac_mont_ladder(m: int, Q: JacPoint, ec: Curve) -> JacPoint: + """Scalar multiplication of a curve point in Jacobian coordinates. + + This implementation uses "montgomery ladder" algorithm, + jacobian coordinates. + It is constant-time if the binary size of Q remains the same. + The input point is assumed to be on curve, + m is assumed to have been reduced mod n if appropriate + (e.g. cyclic groups of order n). + """ + + if m < 0: + raise ValueError(f"negative m: {hex(m)}") + + if Q == INFJ: + return Q + + R = INFJ # initialize as infinity point + for m in [int(i) for i in bin(m)[2:]]: # goes through binary digits + if m == 0: + Q = ec._add_jac(R, Q) + R = _double_jac(R, ec) + else: + R = ec._add_jac(R, Q) + Q = _double_jac(Q, ec) + return R + + +def mult_mont_ladder(m: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point: + """Point multiplication, implemented using "montgomery ladder" algorithm to run in constant time. + + This can be beneficial when timing measurements are exposed to an attacker performing a side-channel attack. + This algorithm has the same speed as the double-and-add approach except that it computes the same number + of point additions and doubles regardless of the value of the multiplicand m. + + Computations use Jacobian coordinates and binary decomposition of m. + """ + if Q is None: + QJ = ec.GJ + else: + ec.require_on_curve(Q) + QJ = _jac_from_aff(Q) + + m = int_from_integer(m) % ec.n + R = _mult_jac_mont_ladder(m, QJ, ec) + return ec._aff_from_jac(R) + + +def numberToBase(n, b): + # Returns the list of the digits of n written in basis b + + if n == 0: + return [0] + digits = [] + while n: + digits.append(int(n % b)) + n //= b + return digits[::-1] + + +def _mult_jac_base_3(m: int, Q: JacPoint, ec: Curve) -> JacPoint: + """Scalar multiplication of a curve point in Jacobian coordinates. + This implementation uses the same idea of "double and add" algorithm, but with scalar radix 3. + It is not constant time. + The input point is assumed to be on curve, + m is assumed to have been reduced mod n if appropriate + (e.g. cyclic groups of order n). + """ + + if m < 0: + raise ValueError(f"negative m: {hex(m)}") + + if Q == INFJ: + return Q + + T: List[JacPoint] = [] + T.append(INFJ) + for i in range(1, 3): + T.append(ec._add_jac(T[i - 1], Q)) + + M = numberToBase(m, 3) + + R = T[M[0]] + + for i in range(1, len(M)): + R2 = _double_jac(R, ec) + R = ec._add_jac(R2, R) + R = ec._add_jac(R, T[M[i]]) + + return R + + +def mult_base_3(m: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point: + """Point multiplication, implemented using "double and add" but changing the scalar radix to 3. + + Computations use Jacobian coordinates and decomposition of m basis 3. + """ + if Q is None: + QJ = ec.GJ + else: + ec.require_on_curve(Q) + QJ = _jac_from_aff(Q) + + m = int_from_integer(m) % ec.n + R = _mult_jac_base_3(m, QJ, ec) + return ec._aff_from_jac(R) + + +def _mult_jac_fixed_window(m: int, w: int, Q: JacPoint, ec: Curve) -> JacPoint: + """Scalar multiplication of a curve point in Jacobian coordinates. + This implementation uses the method called "fixed window" + It is not constant time. + For 256-bit scalars choose w=4 or w=5 + The input point is assumed to be on curve, + m is assumed to have been reduced mod n if appropriate + (e.g. cyclic groups of order n). + """ + if m < 0: + raise ValueError(f"negative m: {hex(m)}") + + if Q == INFJ: + return Q + + # a number cannot be written in basis 1 (ie w=0) + if w <= 0: + raise ValueError(f"non positive w: {w}") + + b = pow(2, w) + + T: List[JacPoint] = [] + T.append(INFJ) + for i in range(1, b): + T.append(ec._add_jac(T[i - 1], Q)) + + M = numberToBase(m, b) + + R = T[M[0]] + + for i in range(1, len(M)): + for _ in range(w): + R = _double_jac(R, ec) + R = ec._add_jac(R, T[M[i]]) + + return R + + +def mult_fixed_window( + m: Integer, w: Integer, Q: Point = None, ec: Curve = secp256k1 +) -> Point: + """Point multiplication, implemented using "fixed window" method. + + Computations use Jacobian coordinates and decomposition of m on basis 2^w. + """ + + if Q is None: + QJ = ec.GJ + else: + ec.require_on_curve(Q) + QJ = _jac_from_aff(Q) + + m = int_from_integer(m) % ec.n + w = int_from_integer(w) + R = _mult_jac_fixed_window(m, w, QJ, ec) + return ec._aff_from_jac(R) + + +# Need some modifies to make it more elegant +def _mult_jac_sliding_window(m: int, w: int, Q: JacPoint, ec: Curve) -> JacPoint: + """Scalar multiplication of a curve point in Jacobian coordinates. + This implementation uses the method called "sliding window". + It has the benefit that the pre-computation stage is roughly half as complex as the normal windowed method . + It is not constant time. + For 256-bit scalars choose w=4 or w=5 + The input point is assumed to be on curve, + m is assumed to have been reduced mod n if appropriate + (e.g. cyclic groups of order n). + """ + + if m < 0: + raise ValueError(f"negative m: {hex(m)}") + + if Q == INFJ: + return Q + + if w <= 0: + raise ValueError(f"non positive w: {w}") + + k = w - 1 + p = pow(2, k) + + P = Q + for _ in range(k): + P = _double_jac(P, ec) + + T: List[JacPoint] = [] + T.append(P) + for i in range(1, p): + T.append(ec._add_jac(T[i - 1], Q)) + + M = numberToBase(m, 2) + + R = INFJ + + i = 0 + while i < len(M): + if M[i] == 0: + R = _double_jac(R, ec) + i += 1 + else: + j = min(len(M) - i, w) + t = M[i] + for a in range(1, j): + t = 2 * t + M[i + a] + + if j < w: + for b in range(i, (i + j)): + R = _double_jac(R, ec) + if M[b] == 1: + R = ec._add_jac(R, Q) + return R + + else: + for _ in range(w): + R = _double_jac(R, ec) + R = ec._add_jac(R, T[t - p]) + i += j + return R + + +def mult_sliding_window( + m: Integer, w: Integer, Q: Point = None, ec: Curve = secp256k1 +) -> Point: + """Point multiplication, implemented using "sliding window" method. + + Computations use Jacobian coordinates and decomposition of m on basis 2. + """ + + if Q is None: + QJ = ec.GJ + else: + ec.require_on_curve(Q) + QJ = _jac_from_aff(Q) + + m = int_from_integer(m) % ec.n + w = int_from_integer(w) + R = _mult_jac_sliding_window(m, w, QJ, ec) + return ec._aff_from_jac(R) + + +def mods(m: int, w: int) -> int: + # Signed modulo function + """ + Need minor changes: + mods does NOT work for w=1. However the function in NOT really meant to be used for w=1 + For w=1 it always gives back -1 and enters an infinte loop + """ + + w2 = pow(2, w) + M = m % w2 + if M >= (w2 / 2): + return M - w2 + else: + return M + + +def _mult_jac_w_NAF(m: int, w: int, Q: JacPoint, ec: Curve) -> JacPoint: + """Scalar multiplication of a curve point in Jacobian coordinates. + This implementation uses the same method called "w-ary non-adjacent form" (wNAF) + we make use of the fact that point subtraction is as easy as point addition to perform fewer operations compared to sliding-window + In fact, on Weierstrass curves, known P, -P can be computed on the fly. + + The input point is assumed to be on curve, + m is assumed to have been reduced mod n if appropriate + (e.g. cyclic groups of order n). + """ + if m < 0: + raise ValueError(f"negative m: {hex(m)}") + + if m == 0: + return INFJ + + if Q == INFJ: + return Q + + if w <= 0: + raise ValueError(f"non positive w: {w}") + + i = 0 + + M: List[int] = [] + while m > 0: + if (m % 2) == 1: + M.append(mods(m, w)) + m -= M[i] + else: + M.append(0) + m //= 2 + i += 1 + + p = i + + b = pow(2, w) + + Q2 = _double_jac(Q, ec) + + T: List[JacPoint] = [] + T.append(Q) + for i in range(1, (b // 2)): + T.append(ec._add_jac(T[i - 1], Q2)) + for i in range((b // 2), b): + T.append(ec.negate_jac(T[i - (b // 2)])) + + R = INFJ + + for j in range(p - 1, -1, -1): + R = _double_jac(R, ec) + if M[j] != 0: + if M[j] > 0: + # It adds the element jQ + R = ec._add_jac(R, T[(M[j] - 1) // 2]) + else: + # In this case it adds the opposite, ie -jQ + R = ec._add_jac(R, T[(b // 2) - ((M[j] + 1) // 2)]) + + return R + + +def mult_w_NAF(m: Integer, w: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point: + """Point multiplication, implemented using "w-NAF" method. + + Computations use Jacobian coordinates and decomposition of m on basis 2^w. + """ + + if Q is None: + QJ = ec.GJ + else: + ec.require_on_curve(Q) + QJ = _jac_from_aff(Q) + + m = int_from_integer(m) % ec.n + w = int_from_integer(w) + R = _mult_jac_w_NAF(m, w, QJ, ec) + return ec._aff_from_jac(R) diff --git a/btclib/tests/test_curvemult.py b/btclib/tests/test_curvemult.py index d058e0e1d..77a628407 100644 --- a/btclib/tests/test_curvemult.py +++ b/btclib/tests/test_curvemult.py @@ -72,6 +72,10 @@ def test_assorted_mult() -> None: with pytest.raises(ValueError, match=err_msg): _multi_mult([k1, k2, k3, k4], [ec.GJ, HJ, ec.GJ], ec) + err_msg = "negative coefficient: " + with pytest.raises(ValueError, match=err_msg): + _multi_mult([k1, k2, -k3], [ec.GJ, HJ, ec.GJ], ec) + with pytest.raises(ValueError, match="negative first coefficient: "): _double_mult(-5, HJ, 1, ec.GJ, ec) with pytest.raises(ValueError, match="negative second coefficient: "): diff --git a/btclib/tests/test_curvemult2.py b/btclib/tests/test_curvemult2.py new file mode 100644 index 000000000..340c6bc4b --- /dev/null +++ b/btclib/tests/test_curvemult2.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 + +# Copyright (C) 2017-2020 The btclib developers +# +# This file is part of btclib. It is subject to the license terms in the +# LICENSE file found in the top-level directory of this distribution. +# No part of btclib including this file, may be copied, modified, propagated, +# or distributed except according to the terms contained in the LICENSE file. + +"Tests for `btclib.curvemult2` module." + +import pytest + +from btclib.alias import INFJ +from btclib.curvemult2 import ( + _mult_jac_base_3, + _mult_jac_fixed_window, + _mult_jac_mont_ladder, + _mult_jac_sliding_window, + _mult_jac_w_NAF, +) +from btclib.tests.test_curves import low_card_curves + +ec23_31 = low_card_curves["ec23_31"] + + +def test_mont_ladder(): + for ec in low_card_curves.values(): + assert _mult_jac_mont_ladder(0, ec.GJ, ec) == INFJ + assert _mult_jac_mont_ladder(0, INFJ, ec) == INFJ + + assert _mult_jac_mont_ladder(1, INFJ, ec) == INFJ + assert _mult_jac_mont_ladder(1, ec.GJ, ec) == ec.GJ + + PJ = ec._add_jac(ec.GJ, ec.GJ) + assert PJ == _mult_jac_mont_ladder(2, ec.GJ, ec) + + PJ = _mult_jac_mont_ladder(ec.n - 1, ec.GJ, ec) + assert ec._jac_equality(ec.negate_jac(ec.GJ), PJ) + + assert _mult_jac_mont_ladder(ec.n - 1, INFJ, ec) == INFJ + assert ec._add_jac(PJ, ec.GJ) == INFJ + assert _mult_jac_mont_ladder(ec.n, ec.GJ, ec) == INFJ + + with pytest.raises(ValueError, match="negative m: "): + _mult_jac_mont_ladder(-1, ec.GJ, ec) + + +def test_mult_jac_base_3(): + for ec in low_card_curves.values(): + assert _mult_jac_base_3(0, ec.GJ, ec) == INFJ + assert _mult_jac_base_3(0, INFJ, ec) == INFJ + + assert _mult_jac_base_3(1, INFJ, ec) == INFJ + assert _mult_jac_base_3(1, ec.GJ, ec) == ec.GJ + + PJ = ec._add_jac(ec.GJ, ec.GJ) + assert PJ == _mult_jac_base_3(2, ec.GJ, ec) + + PJ = _mult_jac_base_3(ec.n - 1, ec.GJ, ec) + assert ec._jac_equality(ec.negate_jac(ec.GJ), PJ) + + assert _mult_jac_base_3(ec.n - 1, INFJ, ec) == INFJ + assert ec._add_jac(PJ, ec.GJ) == INFJ + assert _mult_jac_base_3(ec.n, ec.GJ, ec) == INFJ + + with pytest.raises(ValueError, match="negative m: "): + _mult_jac_base_3(-1, ec.GJ, ec) + + +def test_mult_jac_fixed_window(): + for k in range(1, 10): # Actually it makes use of w=4 or w=5, only to check + for ec in low_card_curves.values(): + assert _mult_jac_fixed_window(0, k, ec.GJ, ec) == INFJ + assert _mult_jac_fixed_window(0, k, INFJ, ec) == INFJ + + assert _mult_jac_fixed_window(1, k, INFJ, ec) == INFJ + assert _mult_jac_fixed_window(1, k, ec.GJ, ec) == ec.GJ + + PJ = ec._add_jac(ec.GJ, ec.GJ) + assert PJ == _mult_jac_fixed_window(2, k, ec.GJ, ec) + + PJ = _mult_jac_fixed_window(ec.n - 1, k, ec.GJ, ec) + assert ec._jac_equality(ec.negate_jac(ec.GJ), PJ) + + assert _mult_jac_fixed_window(ec.n - 1, k, INFJ, ec) == INFJ + assert ec._add_jac(PJ, ec.GJ) == INFJ + assert _mult_jac_fixed_window(ec.n, k, ec.GJ, ec) == INFJ + + with pytest.raises(ValueError, match="negative m: "): + _mult_jac_fixed_window(-1, k, ec.GJ, ec) + + +def test_mult_jac_sliding_window(): + for k in range(1, 10): # Actually it makes use of w=4 or w=5, only to check + for ec in low_card_curves.values(): + assert _mult_jac_sliding_window(0, k, ec.GJ, ec) == INFJ + assert _mult_jac_sliding_window(0, k, INFJ, ec) == INFJ + + assert _mult_jac_sliding_window(1, k, INFJ, ec) == INFJ + assert _mult_jac_sliding_window(1, k, ec.GJ, ec) == ec.GJ + + PJ = ec._add_jac(ec.GJ, ec.GJ) + assert PJ == _mult_jac_sliding_window(2, k, ec.GJ, ec) + + PJ = _mult_jac_sliding_window(ec.n - 1, k, ec.GJ, ec) + assert ec._jac_equality(ec.negate_jac(ec.GJ), PJ) + + assert _mult_jac_sliding_window(ec.n - 1, k, INFJ, ec) == INFJ + assert ec._add_jac(PJ, ec.GJ) == INFJ + assert _mult_jac_sliding_window(ec.n, k, ec.GJ, ec) == INFJ + + with pytest.raises(ValueError, match="negative m: "): + _mult_jac_sliding_window(-1, k, ec.GJ, ec) + + +# it does NOT work for k=1 +def test_mult_jac_w_NAF(): + for k in range(2, 10): + for ec in low_card_curves.values(): + assert _mult_jac_w_NAF(0, k, ec.GJ, ec) == INFJ + assert _mult_jac_w_NAF(0, k, INFJ, ec) == INFJ + + assert _mult_jac_w_NAF(1, k, INFJ, ec) == INFJ + assert _mult_jac_w_NAF(1, k, ec.GJ, ec) == ec.GJ + + PJ = ec._add_jac(ec.GJ, ec.GJ) + assert PJ == _mult_jac_w_NAF(2, k, ec.GJ, ec) + + PJ = _mult_jac_w_NAF(ec.n - 1, k, ec.GJ, ec) + assert ec._jac_equality(ec.negate_jac(ec.GJ), PJ) + + assert _mult_jac_w_NAF(ec.n - 1, k, INFJ, ec) == INFJ + assert ec._add_jac(PJ, ec.GJ) == INFJ + assert _mult_jac_w_NAF(ec.n, k, ec.GJ, ec) == INFJ + + with pytest.raises(ValueError, match="negative m: "): + _mult_jac_w_NAF(-1, k, ec.GJ, ec)