Skip to content

Commit

Permalink
moved to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
fametrano committed May 24, 2020
1 parent 878c2c7 commit ec9c372
Show file tree
Hide file tree
Showing 3 changed files with 168 additions and 168 deletions.
86 changes: 44 additions & 42 deletions btclib/curvemult.py
Expand Up @@ -35,23 +35,6 @@ def mult(m: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point:
return ec._aff_from_jac(R)


def double_mult(
u: Integer, H: Point, v: Integer, Q: Point, ec: Curve = secp256k1
) -> Point:
"""Shamir trick for efficient computation of u*H + v*Q"""

ec.require_on_curve(H)
HJ = _jac_from_aff(H)

ec.require_on_curve(Q)
QJ = _jac_from_aff(Q)

u = int_from_integer(u) % ec.n
v = int_from_integer(v) % ec.n
R = _double_mult(u, HJ, v, QJ, ec)
return ec._aff_from_jac(R)


def _double_mult(
u: int, HJ: JacPoint, v: int, QJ: JacPoint, ec: CurveGroup
) -> JacPoint:
Expand All @@ -61,12 +44,6 @@ def _double_mult(
if v < 0:
raise ValueError(f"negative v: {hex(v)}")

if u == 0 or HJ[2] == 0:
return _mult_jac(v, QJ, ec)

if v == 0 or QJ[2] == 0:
return _mult_jac(u, HJ, ec)

R = INFJ # initialize as infinity point
msb = max(u.bit_length(), v.bit_length())
while msb > 0:
Expand All @@ -83,29 +60,20 @@ def _double_mult(
return R


def multi_mult(
scalars: Sequence[Integer], Points: Sequence[Point], ec: Curve = secp256k1
def double_mult(
u: Integer, H: Point, v: Integer, Q: Point, ec: Curve = secp256k1
) -> Point:
"""Return the multi scalar multiplication u1*Q1 + ... + un*Qn.
Use Bos-Coster's algorithm for efficient computation;
the input points must be on the curve.
"""

if len(scalars) != len(Points):
errMsg = f"mismatch between scalar length ({len(scalars)}) and "
errMsg += f"Points length ({len(Points)})"
raise ValueError(errMsg)
"""Shamir trick for efficient computation of u*H + v*Q"""

JPoints: List[JacPoint] = list()
ints: List[int] = list()
for P, i in zip(Points, scalars):
ec.require_on_curve(P)
JPoints.append(_jac_from_aff(P))
ints.append(int_from_integer(i) % ec.n)
ec.require_on_curve(H)
HJ = _jac_from_aff(H)

R = _multi_mult(ints, JPoints, ec)
ec.require_on_curve(Q)
QJ = _jac_from_aff(Q)

u = int_from_integer(u) % ec.n
v = int_from_integer(v) % ec.n
R = _double_mult(u, HJ, v, QJ, ec)
return ec._aff_from_jac(R)


Expand All @@ -114,6 +82,11 @@ def _multi_mult(
) -> JacPoint:
# source: https://cr.yp.to/badbatch/boscoster2.py

if len(scalars) != len(JPoints):
errMsg = "mismatch between number of scalars and points: "
errMsg += f"{len(scalars)} vs {len(JPoints)}"
raise ValueError(errMsg)

x = list(zip([-n for n in scalars], JPoints))
heapq.heapify(x)
while len(x) > 1:
Expand All @@ -128,5 +101,34 @@ def _multi_mult(
heapq.heappush(x, (-n2, p2))
np1 = heapq.heappop(x)
n1, p1 = -np1[0], np1[1]
assert n1 < ec.n, "better to take the mod n"
# n1 %= ec.n
return _mult_jac(n1, p1, ec)


def multi_mult(
scalars: Sequence[Integer], Points: Sequence[Point], ec: Curve = secp256k1
) -> Point:
"""Return the multi scalar multiplication u1*Q1 + ... + un*Qn.
Use Bos-Coster's algorithm for efficient computation.
"""

if len(scalars) != len(Points):
errMsg = "mismatch between number of scalars and points: "
errMsg += f"{len(scalars)} vs {len(Points)}"
raise ValueError(errMsg)

JPoints: List[JacPoint] = list()
ints: List[int] = list()
for P, i in zip(Points, scalars):
ec.require_on_curve(P)
JPoints.append(_jac_from_aff(P))
i = int_from_integer(i) % ec.n
if i == 0:
raise ValueError("zero coefficient in Bos-Coster's algorithm")
ints.append(i)

R = _multi_mult(ints, JPoints, ec)

return ec._aff_from_jac(R)
100 changes: 40 additions & 60 deletions btclib/tests/test_curvemult.py
Expand Up @@ -11,68 +11,53 @@
"Tests for `btclib.curvemult` module."

import secrets
import unittest
from typing import List

from btclib.alias import INF, INFJ
from btclib.curve import _mult_aff, _mult_jac
from btclib.curvemult import double_mult, mult, multi_mult
import pytest

from btclib.alias import INFJ
from btclib.curvemult import _double_mult, double_mult, mult, _multi_mult
from btclib.curve import _mult_jac, _jac_from_aff
from btclib.curves import secp256k1
from btclib.tests.test_curves import low_card_curves
from btclib.pedersen import second_generator

ec23_31 = low_card_curves["ec23_31"]


class TestEllipticCurve(unittest.TestCase):
def test_mult(self):
for ec in low_card_curves.values():
for q in range(ec.n):
Q = _mult_aff(q, ec.G, ec)
QJ = _mult_jac(q, ec.GJ, ec)
Q2 = ec._aff_from_jac(QJ)
self.assertEqual(Q, Q2)
# with last curve
self.assertEqual(INF, _mult_aff(3, INF, ec))
self.assertEqual(INFJ, _mult_jac(3, INFJ, ec))

def test_double_mult(self):
ec = ec23_31
for k1 in range(ec.n):
for k2 in range(ec.n):
shamir = double_mult(k1, ec.G, k2, ec.G, ec)
std = ec.add(mult(k1, ec.G, ec), mult(k2, ec.G, ec))
self.assertEqual(shamir, std)
shamir = double_mult(k1, INF, k2, ec.G, ec)
std = ec.add(mult(k1, INF, ec), mult(k2, ec.G, ec))
self.assertEqual(shamir, std)
shamir = double_mult(k1, ec.G, k2, INF, ec)
std = ec.add(mult(k1, ec.G, ec), mult(k2, INF, ec))
self.assertEqual(shamir, std)

def test_multi_mult(self):
ec = secp256k1

k: List[int] = list()
ksum = 0
for i in range(11):
k.append(secrets.randbits(ec.nlen) % ec.n)
ksum += k[i]

P = [ec.G] * len(k)
boscoster = multi_mult(k, P, ec)
self.assertEqual(boscoster, mult(ksum, ec.G, ec))

# mismatch between scalar length and Points length
P = [ec.G] * (len(k) - 1)
self.assertRaises(ValueError, multi_mult, k, P, ec)
# multi_mult(k, P, ec)


def test_double_mult():
H = (
0x50929B74C1A04954B78B4B6035E97A5E078A5A0F28EC96D547BFEE9ACE803AC0,
0x31D3C6863973926E049E637CB1B5F40A36DAC28AF1766968C30C2313F3A38904,
)
def test_assorted_mult():
ec = ec23_31
H = second_generator(ec)
HJ = _jac_from_aff(H)
for k1 in range(1, ec.n):
k2 = 1 + secrets.randbelow(ec.n - 1)
shamir = _double_mult(k1, ec.GJ, k2, ec.GJ, ec)
assert ec._jac_equality(shamir, _mult_jac(k1 + k2, ec.GJ, ec))
shamir = _double_mult(k1, INFJ, k2, ec.GJ, ec)
assert ec._jac_equality(shamir, _mult_jac(k2, ec.GJ, ec))
shamir = _double_mult(k1, ec.GJ, k2, INFJ, ec)
assert ec._jac_equality(shamir, _mult_jac(k1, ec.GJ, ec))

shamir = _double_mult(k1, ec.GJ, k2, HJ, ec)
std = ec._add_jac(_mult_jac(k1, ec.GJ, ec), _mult_jac(k2, HJ, ec))
assert ec._jac_equality(std, shamir)

k3 = 1 + secrets.randbelow(ec.n - 1)
std = ec._add_jac(std, _mult_jac(k3, ec.GJ, ec))
boscoster = _multi_mult([k1, k2, k3], [ec.GJ, HJ, ec.GJ], ec)
assert ec._jac_equality(std, boscoster)

k4 = 1 + secrets.randbelow(ec.n - 1)
std = ec._add_jac(std, _mult_jac(k4, HJ, ec))
boscoster = _multi_mult([k1, k2, k3, k4], [ec.GJ, HJ, ec.GJ, HJ], ec)
assert ec._jac_equality(std, boscoster)

err_msg = "mismatch between number of scalars and points: "
with pytest.raises(ValueError, match=err_msg):
_multi_mult([k1, k2, k3, k4], [ec.GJ, HJ, ec.GJ], ec)


def test_mult_double_mult():
H = second_generator(secp256k1)
G = secp256k1.G

# 0*G + 1*H
Expand Down Expand Up @@ -110,8 +95,3 @@ def test_double_mult():
# 1*G - 5*H
U = double_mult(-5, H, 1, G)
assert U == secp256k1.add(G, T)


if __name__ == "__main__":
# execute only if run as a script
unittest.main() # pragma: no cover

0 comments on commit ec9c372

Please sign in to comment.