Skip to content

Commit

Permalink
added check for negative scalar in _multi_mult; improved documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
fametrano committed Aug 28, 2020
1 parent 73da2b7 commit 181ce0a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
10 changes: 5 additions & 5 deletions btclib/curve.py
Expand Up @@ -340,8 +340,8 @@ def _mult_aff(m: int, Q: Point, ec: CurveGroup) -> Point:
It is not constant-time.
The input point is assumed to be on curve,
m is assumed to have been reduced mod n if appropriate
(e.g. cyclic groups of order n).
the m coefficient is assumed to have been reduced mod n
if appropriate (e.g. cyclic groups of order n).
"""

if m < 0:
Expand All @@ -364,12 +364,12 @@ def _mult_jac(m: int, Q: JacPoint, ec: CurveGroup) -> JacPoint:
This implementation uses 'double & add' algorithm,
binary decomposition of m,
affine coordinates.
Jacobian coordinates.
It is not constant-time.
The input point is assumed to be on curve,
m is assumed to have been reduced mod n if appropriate
(e.g. cyclic groups of order n).
the m coefficient is assumed to have been reduced mod n
if appropriate (e.g. cyclic groups of order n).
"""

if m < 0:
Expand Down
29 changes: 20 additions & 9 deletions btclib/curvemult.py
Expand Up @@ -20,10 +20,7 @@


def mult(m: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point:
"""Point multiplication, implemented using 'double and add'.
Computations use Jacobian coordinates and binary decomposition of m.
"""
"Elliptic curve scalar multiplication."
if Q is None:
QJ = ec.GJ
else:
Expand All @@ -38,6 +35,12 @@ def mult(m: Integer, Q: Point = None, ec: Curve = secp256k1) -> Point:
def _double_mult(
u: int, HJ: JacPoint, v: int, QJ: JacPoint, ec: CurveGroup
) -> JacPoint:
"""Shamir trick for efficient computation of u*H + v*Q.
The input points are assumed to be on curve,
the u and v coefficients are assumed to have been reduced mod n
if appropriate (e.g. cyclic groups of order n).
"""

if u < 0:
raise ValueError(f"negative first coefficient: {hex(u)}")
Expand All @@ -63,7 +66,7 @@ def _double_mult(
def double_mult(
u: Integer, H: Point, v: Integer, Q: Point, ec: Curve = secp256k1
) -> Point:
"""Shamir trick for efficient computation of u*H + v*Q"""
"Shamir trick for efficient computation of u*H + v*Q."

ec.require_on_curve(H)
HJ = _jac_from_aff(H)
Expand All @@ -80,20 +83,28 @@ def double_mult(
def _multi_mult(
scalars: Sequence[int], JPoints: Sequence[JacPoint], ec: CurveGroup
) -> JacPoint:
"""Return the multi scalar multiplication u1*Q1 + ... + un*Qn.
Use Bos-Coster's algorithm for efficient computation.
The input points are assumed to be on curve,
the scalar coefficients are assumed to have been reduced mod n
if appropriate (e.g. cyclic groups of order n).
"""
# source: https://cr.yp.to/badbatch/boscoster2.py

if len(scalars) != len(JPoints):
errMsg = "mismatch between number of scalars and points: "
errMsg += f"{len(scalars)} vs {len(JPoints)}"
raise ValueError(errMsg)

# FIXME
# check for negative scalars
# x = list(zip([-n for n in scalars], JPoints))
x: List[Tuple[int, JacPoint]] = []
for n, PJ in zip(scalars, JPoints):
if n == 0:
if n == 0: # mandatory check to avoid infinite loop
continue
if n < 0:
raise ValueError(f"negative coefficient: {hex(n)}")
x.append((-n, PJ))

if not x:
Expand Down Expand Up @@ -134,7 +145,7 @@ def multi_mult(
ints: List[int] = list()
for P, i in zip(Points, scalars):
i = int_from_integer(i) % ec.n
if i == 0:
if i == 0: # early optimization, even if not strictly necessary
continue
ints.append(i)
ec.require_on_curve(P)
Expand Down
4 changes: 4 additions & 0 deletions btclib/tests/test_curvemult.py
Expand Up @@ -72,6 +72,10 @@ def test_assorted_mult() -> None:
with pytest.raises(ValueError, match=err_msg):
_multi_mult([k1, k2, k3, k4], [ec.GJ, HJ, ec.GJ], ec)

err_msg = "negative coefficient: "
with pytest.raises(ValueError, match=err_msg):
_multi_mult([k1, k2, -k3], [ec.GJ, HJ, ec.GJ], ec)

with pytest.raises(ValueError, match="negative first coefficient: "):
_double_mult(-5, HJ, 1, ec.GJ, ec)
with pytest.raises(ValueError, match="negative second coefficient: "):
Expand Down

0 comments on commit 181ce0a

Please sign in to comment.