From a17086359e3636fadaf47d53c1dfdf0789fd5152 Mon Sep 17 00:00:00 2001 From: Ferdinando Ametrano Date: Sun, 17 May 2020 16:15:40 +0200 Subject: [PATCH] refactored mod_sqrt for improved clarity --- btclib/numbertheory.py | 52 ++++++++++++++++++------------- btclib/tests/test_numbertheory.py | 38 ++++++++++++++++------ 2 files changed, 59 insertions(+), 31 deletions(-) diff --git a/btclib/numbertheory.py b/btclib/numbertheory.py index 851edef1c..11206049e 100644 --- a/btclib/numbertheory.py +++ b/btclib/numbertheory.py @@ -83,45 +83,55 @@ def mod_sqrt(a: int, p: int) -> int: a %= p - # Simple cases if p % 4 == 3: # secp256k1 case # inverse candidate is pow(a, (p + 1) // 4, p) - x = pow(a, (p >> 2) + 1, p) - if x * x % p == a: - return x - raise ValueError(f"No root for {hex(a)} (mod {hex(p)})") + r = pow(a, (p >> 2) + 1, p) elif p % 8 == 5: # inverse candidate is pow(a, (p + 3) // 8, p) - x = pow(a, (p >> 3) + 1, p) - if x * x % p == a: - return x + r = pow(a, (p >> 3) + 1, p) + if r * r % p == a: + return r else: - # inverse candidate - x = x * pow(2, p >> 2, p) % p - if x * x % p == a: - return x + # another inverse candidate + r = r * pow(2, p >> 2, p) % p + else: + return tonelli(a, p) + + if r * r % p != a: raise ValueError(f"No root for {hex(a)} (mod {hex(p)})") - elif a == 0 or p == 2: + return r + + +def tonelli(a: int, p: int) -> int: + """Return a quadratic residue (mod p) of a; p must be a prime. + + The Tonelli-Shanks algorithm is used. + + https://codereview.stackexchange.com/questions/43210/tonelli-shanks-algorithm-implementation-of-prime-modular-square-root/43267 + """ + + a %= p + if a == 0 or p == 2: return a - # Check solution existence for odd primes + # Check solution existence for an odd prime p if legendre_symbol(a, p) != 1: raise ValueError(f"No root for {hex(a)} (mod {hex(p)})") - # Factor p-1 on the form q * 2^s (with Q odd) + # Factor p-1 on the form q * 2^s (with q odd) q, s = p - 1, 0 while q & 1 == 0: s += 1 q >>= 1 + if s == 1: + return pow(a, (p + 1) // 4, p) - # Select a z which is a quadratic non resudue modulo p + # Select a z which is a quadratic non residue modulo p z = 1 while legendre_symbol(z, p) != -1: z += 1 c = pow(z, q, p) - - # Search for a solution - x = pow(a, (q + 1) // 2, p) + r = pow(a, (q + 1) // 2, p) t = pow(a, q, p) m = s while t != 1: @@ -134,9 +144,9 @@ def mod_sqrt(a: int, p: int) -> int: # Update next value to iterate b = pow(c, 1 << (m - i - 1), p) - x = (x * b) % p + r = (r * b) % p c = (b * b) % p t = (t * c) % p m = i - return x + return r diff --git a/btclib/tests/test_numbertheory.py b/btclib/tests/test_numbertheory.py index 92bf4db07..22dfaf345 100644 --- a/btclib/tests/test_numbertheory.py +++ b/btclib/tests/test_numbertheory.py @@ -12,7 +12,7 @@ import pytest -from btclib.numbertheory import mod_inv, mod_sqrt +from btclib.numbertheory import mod_inv, mod_sqrt, tonelli primes = [ 2, @@ -88,25 +88,43 @@ def test_mod_inv(): def test_mod_sqrt(): for p in primes[:30]: # exhaustable only for small p - hasRoot = set() - hasRoot.add(0) - hasRoot.add(1) + has_root = set() + has_root.add(0) + has_root.add(1) for i in range(2, p): - hasRoot.add(i * i % p) + has_root.add(i * i % p) for i in range(p): - if i in hasRoot: - root = mod_sqrt(i, p) - assert i == (root * root) % p - root = p - root - assert i == (root * root) % p + if i in has_root: + root1 = mod_sqrt(i, p) + assert i == (root1 * root1) % p + root2 = p - root1 + assert i == (root2 * root2) % p root = mod_sqrt(i + p, p) assert i == (root * root) % p + if p % 4 == 3 or p % 8 == 5: + assert tonelli(i, p) in (root1, root2) else: err_msg = "No root for " with pytest.raises(ValueError, match=err_msg): mod_sqrt(i, p) +def test_mod_sqrt2(): + # https://rosettacode.org/wiki/Tonelli-Shanks_algorithm#Python + ttest = [ + (10, 13), + (56, 101), + (1030, 10009), + (44402, 100049), + (665820697, 1000000009), + (881398088036, 1000000000039), + (41660815127637347468140745042827704103445750172002, 10 ** 50 + 577), + ] + for i, p in ttest: + root = tonelli(i, p) + assert i == (root * root) % p + + def test_minus_one_quadr_res(): "Ensure that if p = 3 (mod 4) then p - 1 is not a quadratic residue" for p in primes: