Skip to content

Commit

Permalink
moved to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
fametrano committed May 23, 2020
1 parent 6ac5c42 commit 2d1527d
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 89 deletions.
92 changes: 57 additions & 35 deletions btclib/curve.py
Expand Up @@ -15,7 +15,9 @@

from .alias import INF, INFJ, Integer, JacPoint, Point
from .numbertheory import legendre_symbol, mod_inv, mod_sqrt
from .utils import int_from_integer
from .utils import hex_string, int_from_integer

_HEXTHRESHOLD = 0xFFFFFFFF


def _jac_from_aff(Q: Point) -> JacPoint:
Expand Down Expand Up @@ -49,7 +51,9 @@ def __init__(self, p: Integer, a: Integer, b: Integer) -> None:
# 1) check that p is a prime
# Fermat test will do as _probabilistic_ primality test...
if p < 2 or p % 2 == 0 or pow(2, p - 1, p) != 1:
raise ValueError(f"p ({hex(p)}) is not prime")
err_msg = "p is not prime: "
err_msg += f"{hex_string(p)}" if p > _HEXTHRESHOLD else f"{p}"
raise ValueError(err_msg)

plen = p.bit_length()
self.psize = (plen + 7) // 8
Expand All @@ -58,10 +62,20 @@ def __init__(self, p: Integer, a: Integer, b: Integer) -> None:
self.p = p

# 2. check that a and b are integers in the interval [0, p−1]
if not 0 <= a < p:
raise ValueError(f"invalid a ({hex(a)}) for given p ({hex(p)})")
if not 0 <= b < p:
raise ValueError(f"invalid b ({hex(b)}) for given p ({hex(p)})")
if p <= a:
err_msg = "p <= a: " + (
f"{hex_string(p)} <= {hex_string(a)}"
if p > _HEXTHRESHOLD
else f"{p} <= {a}"
)
raise ValueError(err_msg)
if p <= b:
err_msg = "p <= b: " + (
f"{hex_string(p)} <= {hex_string(b)}"
if p > _HEXTHRESHOLD
else f"{p} <= {b}"
)
raise ValueError(err_msg)

# 3. Check that 4*a^3 + 27*b^2 ≠ 0 (mod p)
d = 4 * a * a * a + 27 * b * b
Expand All @@ -72,14 +86,14 @@ def __init__(self, p: Integer, a: Integer, b: Integer) -> None:

def __str__(self) -> str:
result = "Curve"
if self.p > 0xFFFFFFFF:
result += f"\n p = {hex(self.p).upper()}"
if self.p > _HEXTHRESHOLD:
result += f"\n p = {hex_string(self.p)}"
else:
result += f"\n p = {self.p}"

if self._a > 0xFFFFFFFF or self._b > 0xFFFFFFFF:
result += f"\n a = {hex(self._a).upper()}"
result += f"\n b = {hex(self._b).upper()}"
if self._a > _HEXTHRESHOLD or self._b > _HEXTHRESHOLD:
result += f"\n a = {hex_string(self._a)}"
result += f"\n b = {hex_string(self._b)}"
else:
result += f"\n a = {self._a}"
result += f"\n b = {self._b}"
Expand All @@ -88,13 +102,13 @@ def __str__(self) -> str:

def __repr__(self) -> str:
result = "Curve("
if self.p > 0xFFFFFFFF:
result += f"{hex(self.p).upper()}"
if self.p > _HEXTHRESHOLD:
result += f"'{hex_string(self.p)}'"
else:
result += f"{self.p}"

if self._a > 0xFFFFFFFF or self._b > 0xFFFFFFFF:
result += f", {hex(self._a).upper()}, {hex(self._b).upper()}"
if self._a > _HEXTHRESHOLD or self._b > _HEXTHRESHOLD:
result += f", '{hex_string(self._a)}', '{hex_string(self._b)}'"
else:
result += f", {self._a}, {self._b}"

Expand Down Expand Up @@ -216,7 +230,9 @@ def _y2(self, x: int) -> int:
def y(self, x: int) -> int:
"""Return the y coordinate from x, as in (x, y)."""
if not 0 <= x < self.p:
raise ValueError(f"x-coordinate not in 0..p-1: {hex(x)}")
err_msg = "x-coordinate not in 0..p-1: "
err_msg += f"{hex_string(x)}" if x > _HEXTHRESHOLD else f"{x}"
raise ValueError(err_msg)
y2 = self._y2(x)
# mod_sqrt will raise a ValueError if root does not exist
return mod_sqrt(y2, self.p)
Expand All @@ -236,7 +252,7 @@ def is_on_curve(self, Q: Point) -> bool:
if Q[1] == 0: # Infinity point in affine coordinates
return True
if not 0 < Q[1] < self.p: # y cannot be zero
raise ValueError(f"y-coordinate {hex(Q[1])} not in (0, p)")
raise ValueError(f"y-coordinate {hex_string(Q[1])} not in (0, p)")
return self._y2(Q[0]) == (Q[1] * Q[1] % self.p)

def has_square_y(self, Q: Union[Point, JacPoint]) -> bool:
Expand All @@ -256,7 +272,7 @@ def require_p_ThreeModFour(self) -> None:
An Error is raised if not.
"""
if not self.pIsThreeModFour:
m = f"field prime is not equal to 3 mod 4: {hex(self.p)}"
m = f"field prime is not equal to 3 mod 4: {hex_string(self.p)}"
raise ValueError(m)

# break the y simmetry: even/odd, low/high, or quadratic residue criteria
Expand Down Expand Up @@ -351,23 +367,23 @@ def __init__(self, p: Integer, a: Integer, b: Integer, G: Point) -> None:
raise ValueError("Generator must a be a sequence[int, int]")
self.G = (int_from_integer(G[0]), int_from_integer(G[1]))
if not self.is_on_curve(self.G):
raise ValueError("Generator is not on the 'x^3 + a*x + b' curve")
raise ValueError("Generator is not on the curve")
self.GJ = self.G[0], self.G[1], 1 # Jacobian coordinates

def __str__(self) -> str:
result = super().__str__()
if self.p > 0xFFFFFFFF:
result += f"\n x_G = {hex(self.G[0]).upper()}"
result += f"\n y_G = {hex(self.G[1]).upper()}"
if self.p > _HEXTHRESHOLD:
result += f"\n x_G = {hex_string(self.G[0])}"
result += f"\n y_G = {hex_string(self.G[1])}"
else:
result += f"\n x_G = {self.G[0]}"
result += f"\n y_G = {self.G[1]}"
return result

def __repr__(self) -> str:
result = super().__repr__()[:-1]
if self.p > 0xFFFFFFFF:
result += f", ({hex(self.G[0]).upper()}, {hex(self.G[1]).upper()})"
if self.p > _HEXTHRESHOLD:
result += f", ('{hex_string(self.G[0])}', '{hex_string(self.G[1])}')"
else:
result += f", ({self.G[0]}, {self.G[1]})"
result += ")"
Expand Down Expand Up @@ -404,31 +420,37 @@ def __init__(

# 5. Check that n is prime.
if n < 2 or n % 2 == 0 or pow(2, n - 1, n) != 1:
raise ValueError(f"n ({hex(n)}) is not prime")
err_msg = "n is not prime: "
err_msg += f"{hex_string(n)}" if n > _HEXTHRESHOLD else f"{n}"
raise ValueError(err_msg)
delta = int(2 * sqrt(self.p))
# also check n with Hasse Theorem
if h < 2:
if not (self.p + 1 - delta <= n <= self.p + 1 + delta):
m = f"n ({hex(n)}) not in [p + 1 - delta, p + 1 + delta]"
raise ValueError(m)
err_msg = "n not in p+1-delta..p+1+delta: "
err_msg += f"{hex_string(n)}" if n > _HEXTHRESHOLD else f"{n}"
raise ValueError(err_msg)

# 7. Check that G ≠ INF, nG = INF
if self.G[1] == 0:
m = "INF point does not generate a prime order subgroup"
m = "INF point cannot be a generator"
raise ValueError(m)
Inf = _mult_aff(n, self.G, self)
if Inf[1] != 0:
raise ValueError(f"n ({hex(n)}) is not the group order")
err_msg = "n is not the group order: "
err_msg += f"{hex_string(n)}" if n > _HEXTHRESHOLD else f"{n}"
raise ValueError(err_msg)

# 6. Check cofactor
exp_h = int(1 / n + delta / n + self.p / n)
if h != exp_h:
raise ValueError(f"h ({h}) not as expected ({exp_h})")
raise ValueError(f"invalid h: {h}, expected {exp_h}")
self.h = h

# 8. Check that n ≠ p
assert n != p, f"n=p ({hex(n)}) -> weak curve"
assert n != p, f"n=p weak curve: {hex_string(n)}"
# raise UserWarning("n=p -> weak curve")

if weakness_check:
# 8. Check that p^i % n ≠ 1 for all 1≤i<100
for i in range(1, 100):
Expand All @@ -437,17 +459,17 @@ def __init__(

def __str__(self) -> str:
result = super().__str__()
if self.p > 0xFFFFFFFF:
result += f"\n n = {hex(self.n).upper()}"
if self.n > _HEXTHRESHOLD:
result += f"\n n = {hex_string(self.n)}"
else:
result += f"\n n = {self.n}"
result += f"\n h = {self.h}"
return result

def __repr__(self) -> str:
result = super().__repr__()[:-1]
if self.p > 0xFFFFFFFF:
result += f", {hex(self.n).upper()}"
if self.n > _HEXTHRESHOLD:
result += f", '{hex_string(self.n)}'"
else:
result += f", {self.n}"
result += f", {self.h}"
Expand Down
70 changes: 70 additions & 0 deletions btclib/tests/test_curve.py
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

# Copyright (C) 2017-2020 The btclib developers
#
# This file is part of btclib. It is subject to the license terms in the
# LICENSE file found in the top-level directory of this distribution.
#
# No part of btclib including this file, may be copied, modified, propagated,
# or distributed except according to the terms contained in the LICENSE file.

"Tests for `btclib.curve` module."

import pytest

from btclib.alias import INF
from btclib.curve import Curve


def test_exceptions():

with pytest.raises(ValueError, match="p is not prime: "):
Curve(15, 0, 2, (1, 9), 19, 1, False)

with pytest.raises(ValueError, match="negative integer: "):
Curve(13, -1, 2, (1, 9), 19, 1, False)

with pytest.raises(ValueError, match="p <= a: "):
Curve(13, 13, 2, (1, 9), 19, 1, False)

with pytest.raises(ValueError, match="negative integer: "):
Curve(13, 0, -2, (1, 9), 19, 1, False)

with pytest.raises(ValueError, match="p <= b: "):
Curve(13, 0, 13, (1, 9), 19, 1, False)

with pytest.raises(ValueError, match="zero discriminant"):
Curve(11, 7, 7, (1, 9), 19, 1, False)

err_msg = "Generator must a be a sequence\\[int, int\\]"
with pytest.raises(ValueError, match=err_msg):
Curve(13, 0, 2, (1, 9, 1), 19, 1, False)

with pytest.raises(ValueError, match="Generator is not on the curve"):
Curve(13, 0, 2, (2, 9), 19, 1, False)

with pytest.raises(ValueError, match="n is not prime: "):
Curve(13, 0, 2, (1, 9), 20, 1, False)

with pytest.raises(ValueError, match="n not in "):
Curve(13, 0, 2, (1, 9), 71, 1, False)

with pytest.raises(ValueError, match="INF point cannot be a generator"):
Curve(13, 0, 2, INF, 19, 1, False)

with pytest.raises(ValueError, match="n is not the group order: "):
Curve(13, 0, 2, (1, 9), 17, 1, False)

with pytest.raises(ValueError, match="invalid h: "):
Curve(13, 0, 2, (1, 9), 19, 2, False)

# n=p -> weak curve
# missing

with pytest.raises(UserWarning, match="weak curve"):
Curve(11, 2, 7, (6, 9), 7, 2, True)

# good curve
ec = Curve(13, 0, 2, (1, 9), 19, 1, False)
with pytest.raises(ValueError, match="x-coordinate not in 0..p-1: "):
ec.y(ec.p)
55 changes: 1 addition & 54 deletions btclib/tests/test_curves.py
Expand Up @@ -42,60 +42,7 @@
all_curves.update(CURVES)


class TestEllipticCurve(unittest.TestCase):
def test_exceptions(self):
# good
Curve(11, 2, 7, (6, 9), 7, 2, False)

# p not odd
self.assertRaises(ValueError, Curve, 10, 2, 7, (6, 9), 7, 1, False)

# p not prime
self.assertRaises(ValueError, Curve, 15, 2, 7, (6, 9), 7, 1, False)

# a > p
self.assertRaises(ValueError, Curve, 11, 12, 7, (6, 9), 13, 1, False)

# b > p
self.assertRaises(ValueError, Curve, 11, 2, 12, (6, 9), 13, 1, False)

# zero discriminant
self.assertRaises(ValueError, Curve, 11, 7, 7, (6, 9), 7, 1, False)

# G not Tuple (int, int)
self.assertRaises(ValueError, Curve, 11, 2, 7, (6, 9, 1), 7, 1, False)

# G not on curve
self.assertRaises(ValueError, Curve, 11, 2, 7, (7, 9), 7, 1, False)

# n not prime
self.assertRaises(ValueError, Curve, 11, 2, 7, (6, 9), 8, 1, False)

# n not Hesse
self.assertRaises(ValueError, Curve, 11, 2, 7, (6, 9), 71, 1, True)

# h not as expected
self.assertRaises(ValueError, Curve, 11, 2, 7, (6, 9), 7, 1, True)
# Curve(11, 2, 7, (6, 9), 7, 1, 0, True)

# n not group order
self.assertRaises(ValueError, Curve, 11, 2, 7, (6, 9), 13, 1, False)

# n=p -> weak curve
# missing

# weak curve
self.assertRaises(UserWarning, Curve, 11, 2, 7, (6, 9), 7, 2, True)

# x-coordinate not in 0..p-1:
ec = CURVES["secp256k1"]
self.assertRaises(ValueError, ec.y, ec.p)
# secp256k1.y(secp256k1.p)

# INF point does not generate a prime order subgroup
self.assertRaises(ValueError, Curve, 11, 2, 7, INF, 7, 2, False)
# Curve(11, 2, 7, INF, 7, 2, 0, False)

class TestEllipticCurves(unittest.TestCase):
def test_all_curves(self):
for ec in all_curves.values():
assert mult(0, ec.G, ec) == INF
Expand Down

0 comments on commit 2d1527d

Please sign in to comment.