From 5d8d4048095e576f92bd32b3e28f163618e5c80b Mon Sep 17 00:00:00 2001 From: Chih Cheng Liang Date: Thu, 28 Feb 2019 01:25:53 +0800 Subject: [PATCH] Add BLS APIs --- .gitignore | 3 + py_ecc/bls_api/__init__.py | 0 py_ecc/bls_api/api.py | 114 ++++++++++++++++++++++++ py_ecc/bls_api/hash.py | 12 +++ py_ecc/bls_api/typing.py | 10 +++ py_ecc/bls_api/utils.py | 174 +++++++++++++++++++++++++++++++++++++ setup.py | 3 + tests/test_bls_api.py | 163 ++++++++++++++++++++++++++++++++++ 8 files changed, 479 insertions(+) create mode 100644 py_ecc/bls_api/__init__.py create mode 100644 py_ecc/bls_api/api.py create mode 100644 py_ecc/bls_api/hash.py create mode 100644 py_ecc/bls_api/typing.py create mode 100644 py_ecc/bls_api/utils.py create mode 100644 tests/test_bls_api.py diff --git a/.gitignore b/.gitignore index 7aedf48..447a4a1 100644 --- a/.gitignore +++ b/.gitignore @@ -86,6 +86,9 @@ logs # Mongo Explorer plugin: .idea/mongoSettings.xml +# Mypy cache +.mypy_cache + # VIM temp files *.swp diff --git a/py_ecc/bls_api/__init__.py b/py_ecc/bls_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py_ecc/bls_api/api.py b/py_ecc/bls_api/api.py new file mode 100644 index 0000000..8fa8c43 --- /dev/null +++ b/py_ecc/bls_api/api.py @@ -0,0 +1,114 @@ +from typing import ( + Sequence, +) +from eth_utils import ( + ValidationError, +) +from py_ecc.optimized_bls12_381 import ( + FQ12, + G1, + Z1, + Z2, + add, + final_exponentiate, + multiply, + neg, + pairing, +) +from .typing import ( + BLSPubkey, + BLSSignature, + Hash32, +) +from .utils import ( + FQP_point_to_FQ2_point, + G1_to_pubkey, + G2_to_signature, + compress_G1, + compress_G2, + decompress_G1, + decompress_G2, + hash_to_G2, + pubkey_to_G1, + signature_to_G2, +) + + +def sign(message_hash: Hash32, + privkey: int, + domain: int) -> BLSSignature: + return G2_to_signature( + compress_G2( + multiply( + hash_to_G2(message_hash, domain), + privkey + ) + )) + + +def privtopub(k: int) -> BLSPubkey: + return G1_to_pubkey(compress_G1(multiply(G1, k))) + + +def verify(message_hash: Hash32, pubkey: BLSPubkey, signature: BLSSignature, domain: int) -> bool: + try: + final_exponentiation = final_exponentiate( + pairing( + FQP_point_to_FQ2_point(decompress_G2(signature_to_G2(signature))), + G1, + final_exponentiate=False, + ) * + pairing( + FQP_point_to_FQ2_point(hash_to_G2(message_hash, domain)), + neg(decompress_G1(pubkey_to_G1(pubkey))), + final_exponentiate=False, + ) + ) + return final_exponentiation == FQ12.one() + except (ValidationError, ValueError, AssertionError): + return False + + +def aggregate_signatures(signatures: Sequence[BLSSignature]) -> BLSSignature: + o = Z2 + for s in signatures: + o = FQP_point_to_FQ2_point(add(o, decompress_G2(signature_to_G2(s)))) + return G2_to_signature(compress_G2(o)) + + +def aggregate_pubkeys(pubkeys: Sequence[BLSPubkey]) -> BLSPubkey: + o = Z1 + for p in pubkeys: + o = add(o, decompress_G1(pubkey_to_G1(p))) + return G1_to_pubkey(compress_G1(o)) + + +def verify_multiple(pubkeys: Sequence[BLSPubkey], + message_hashes: Sequence[Hash32], + signature: BLSSignature, + domain: int) -> bool: + len_msgs = len(message_hashes) + + if len(pubkeys) != len_msgs: + raise ValidationError( + "len(pubkeys) (%s) should be equal to len(message_hashes) (%s)" % ( + len(pubkeys), len_msgs + ) + ) + + try: + o = FQ12([1] + [0] * 11) + for m_pubs in set(message_hashes): + # aggregate the pubs + group_pub = Z1 + for i in range(len_msgs): + if message_hashes[i] == m_pubs: + group_pub = add(group_pub, decompress_G1(pubkey_to_G1(pubkeys[i]))) + + o *= pairing(hash_to_G2(m_pubs, domain), group_pub, final_exponentiate=False) + o *= pairing(decompress_G2(signature_to_G2(signature)), neg(G1), final_exponentiate=False) + + final_exponentiation = final_exponentiate(o) + return final_exponentiation == FQ12.one() + except (ValidationError, ValueError, AssertionError): + return False diff --git a/py_ecc/bls_api/hash.py b/py_ecc/bls_api/hash.py new file mode 100644 index 0000000..c8ea4ac --- /dev/null +++ b/py_ecc/bls_api/hash.py @@ -0,0 +1,12 @@ +from .typing import Hash32 +from eth_hash.auto import keccak +from typing import Union + + +def hash_eth2(data: Union[bytes, bytearray]) -> Hash32: + """ + Return Keccak-256 hashed result. + Note: it's a placeholder and we aim to migrate to a S[T/N]ARK-friendly hash function in + a future Ethereum 2.0 deployment phase. + """ + return Hash32(keccak(data)) diff --git a/py_ecc/bls_api/typing.py b/py_ecc/bls_api/typing.py new file mode 100644 index 0000000..c9dda6d --- /dev/null +++ b/py_ecc/bls_api/typing.py @@ -0,0 +1,10 @@ +# This module will not be included in the PR. +# These types should be replaced with those in eth-typing + +from typing import ( + NewType, +) + +Hash32 = NewType("Hash32", bytes) +BLSPubkey = NewType('BLSPubkey', bytes) # bytes48 +BLSSignature = NewType('BLSSignature', bytes) # bytes96 diff --git a/py_ecc/bls_api/utils.py b/py_ecc/bls_api/utils.py new file mode 100644 index 0000000..fa1aaa5 --- /dev/null +++ b/py_ecc/bls_api/utils.py @@ -0,0 +1,174 @@ +from typing import ( # noqa: F401 + Dict, + Sequence, + Tuple, + Union, +) + +from eth_utils import ( + big_endian_to_int, +) +from py_ecc.optimized_bls12_381 import ( + FQ, + FQ2, + FQP, + b, + b2, + field_modulus as q, + is_on_curve, + multiply, + normalize, +) + +from .hash import ( + hash_eth2, +) +from .typing import ( + BLSPubkey, + BLSSignature, + Hash32, +) + +G2_cofactor = 305502333931268344200999753193121504214466019254188142667664032982267604182971884026507427359259977847832272839041616661285803823378372096355777062779109 # noqa: E501 +FQ2_order = q ** 2 - 1 +eighth_roots_of_unity = [ + FQ2([1, 1]) ** ((FQ2_order * k) // 8) + for k in range(8) +] + + +# +# Helpers +# +def FQP_point_to_FQ2_point(pt: Tuple[FQP, FQP, FQP]) -> Tuple[FQ2, FQ2, FQ2]: + """ + Transform FQP to FQ2 for type hinting. + """ + return ( + FQ2(pt[0].coeffs), + FQ2(pt[1].coeffs), + FQ2(pt[2].coeffs), + ) + + +def modular_squareroot(value: FQ2) -> FQP: + """ + ``modular_squareroot(x)`` returns the value ``y`` such that ``y**2 % q == x``, + and None if this is not possible. In cases where there are two solutions, + the value with higher imaginary component is favored; + if both solutions have equal imaginary component the value with higher real + component is favored. + """ + candidate_squareroot = value ** ((FQ2_order + 8) // 16) + check = candidate_squareroot ** 2 / value + if check in eighth_roots_of_unity[::2]: + x1 = candidate_squareroot / eighth_roots_of_unity[eighth_roots_of_unity.index(check) // 2] + x2 = FQ2([-x1.coeffs[0], -x1.coeffs[1]]) # x2 = -x1 + return x1 if (x1.coeffs[1], x1.coeffs[0]) > (x2.coeffs[1], x2.coeffs[0]) else x2 + return None + + +def _get_x_coordinate(message_hash: Hash32, domain: int) -> FQ2: + domain_in_bytes = domain.to_bytes(8, 'big') + + # Initial candidate x coordinate + x_re = big_endian_to_int(hash_eth2(message_hash + domain_in_bytes + b'\x01')) + x_im = big_endian_to_int(hash_eth2(message_hash + domain_in_bytes + b'\x02')) + x_coordinate = FQ2([x_re, x_im]) # x_re + x_im * i + + return x_coordinate + + +def hash_to_G2(message_hash: Hash32, domain: int) -> Tuple[FQ2, FQ2, FQ2]: + x_coordinate = _get_x_coordinate(message_hash, domain) + + # Test candidate y coordinates until a one is found + while 1: + y_coordinate_squared = x_coordinate ** 3 + FQ2([4, 4]) # The curve is y^2 = x^3 + 4(i + 1) + y_coordinate = modular_squareroot(y_coordinate_squared) + if y_coordinate is not None: # Check if quadratic residue found + break + x_coordinate += FQ2([1, 0]) # Add 1 and try again + + return multiply( + (x_coordinate, y_coordinate, FQ2([1, 0])), + G2_cofactor + ) + + +# +# G1 +# +def compress_G1(pt: Tuple[FQ, FQ, FQ]) -> int: + x, y = normalize(pt) + return x.n + 2**383 * (y.n % 2) + + +def decompress_G1(pt: int) -> Tuple[FQ, FQ, FQ]: + if pt == 0: + return (FQ(1), FQ(1), FQ(0)) + x = pt % 2**383 + y_mod_2 = pt // 2**383 + y = pow((x**3 + b.n) % q, (q + 1) // 4, q) + + if pow(y, 2, q) != (x**3 + b.n) % q: + raise ValueError( + "he given point is not on G1: y**2 = x**3 + b" + ) + if y % 2 != y_mod_2: + y = q - y + return (FQ(x), FQ(y), FQ(1)) + + +def G1_to_pubkey(pt: int) -> BLSPubkey: + return BLSPubkey(pt.to_bytes(48, "big")) + + +def pubkey_to_G1(pubkey: BLSPubkey) -> int: + return big_endian_to_int(pubkey) + +# +# G2 +# + + +def compress_G2(pt: Tuple[FQP, FQP, FQP]) -> Tuple[int, int]: + if not is_on_curve(pt, b2): + raise ValueError( + "The given point is not on the twisted curve over FQ**2" + ) + x, y = normalize(pt) + return ( + int(x.coeffs[0] + 2**383 * (y.coeffs[0] % 2)), + int(x.coeffs[1]) + ) + + +def decompress_G2(p: Tuple[int, int]) -> Tuple[FQP, FQP, FQP]: + x1 = p[0] % 2**383 + y1_mod_2 = p[0] // 2**383 + x2 = p[1] + x = FQ2([x1, x2]) + if x == FQ2([0, 0]): + return FQ2([1, 0]), FQ2([1, 0]), FQ2([0, 0]) + y = modular_squareroot(x**3 + b2) + if y is None: + raise ValueError("Failed to find a modular squareroot") + if y.coeffs[0] % 2 != y1_mod_2: + y = FQ2((y * -1).coeffs) + if not is_on_curve((x, y, FQ2([1, 0])), b2): + raise ValueError( + "The given point is not on the twisted curve over FQ**2" + ) + return x, y, FQ2([1, 0]) + + +def G2_to_signature(pt: Tuple[int, int]) -> BLSSignature: + return BLSSignature( + pt[0].to_bytes(48, "big") + + pt[1].to_bytes(48, "big") + ) + + +def signature_to_G2(signature: BLSSignature) -> Tuple[int, int]: + return (big_endian_to_int(signature[:48]), big_endian_to_int(signature[48:])) diff --git a/setup.py b/setup.py index b587996..fce50e5 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,9 @@ packages=find_packages(exclude=('tests', 'docs')), package_data={'py_ecc': ['py.typed']}, install_requires=[ + "eth-hash>=0.1.4,<1", + "eth-utils>=1.3.0,<2", + "eth-typing>=2.0.0,<3.0.0", ], python_requires='>=3.5, <4', extras_require=extras_require, diff --git a/tests/test_bls_api.py b/tests/test_bls_api.py new file mode 100644 index 0000000..b83c6ff --- /dev/null +++ b/tests/test_bls_api.py @@ -0,0 +1,163 @@ +from eth_utils import ( + big_endian_to_int, +) +from py_ecc.bls_api.api import ( + aggregate_pubkeys, + aggregate_signatures, + privtopub, + sign, + verify, + verify_multiple, +) +from py_ecc.bls_api.hash import ( + hash_eth2, +) +from py_ecc.bls_api.utils import ( + FQP_point_to_FQ2_point, + G1_to_pubkey, + G2_to_signature, + _get_x_coordinate, + compress_G1, + compress_G2, + decompress_G1, + decompress_G2, + hash_to_G2, + pubkey_to_G1, + signature_to_G2, +) +from py_ecc.optimized_bls12_381 import ( + FQ, + FQ2, + FQ12, + FQP, + G1, + G2, + Z1, + Z2, + add, + b, + b2, + curve_order, + field_modulus as q, + final_exponentiate, + is_on_curve, + multiply, + neg, + normalize, + pairing, +) +import pytest + + +@pytest.mark.parametrize( + 'message_hash,domain', + [ + (b'\x12' * 32, 0), + (b'\x12' * 32, 1), + (b'\x34' * 32, 0), + ] +) +def test_get_x_coordinate(message_hash, domain): + x_coordinate = _get_x_coordinate(message_hash, domain) + domain_in_bytes = domain.to_bytes(8, 'big') + assert x_coordinate == FQ2( + [ + big_endian_to_int(hash_eth2(message_hash + domain_in_bytes + b'\x01')), + big_endian_to_int(hash_eth2(message_hash + domain_in_bytes + b'\x02')), + ] + ) + + +def test_hash_to_G2(): + message_hash = b'\x12' * 32 + + domain_1 = 1 + result_1 = hash_to_G2(message_hash, domain_1) + assert is_on_curve(result_1, b2) + + +def test_decompress_G2_with_no_modular_square_root_found(): + with pytest.raises(ValueError, match="Failed to find a modular squareroot"): + decompress_G2(signature_to_G2(b'\x11' * 96)) + + +@pytest.mark.parametrize( + 'privkey', + [ + (1), + (5), + (124), + (735), + (127409812145), + (90768492698215092512159), + (0), + ] +) +def test_bls_core(privkey): + domain = 0 + p1 = multiply(G1, privkey) + p2 = multiply(G2, privkey) + msg = str(privkey).encode('utf-8') + msghash = hash_to_G2(msg, domain=domain) + + assert normalize(decompress_G1(compress_G1(p1))) == normalize(p1) + assert normalize(decompress_G2(compress_G2(p2))) == normalize(p2) + assert normalize(decompress_G2(compress_G2(msghash))) == normalize(msghash) + sig = sign(msg, privkey, domain=domain) + pub = privtopub(privkey) + assert verify(msg, pub, sig, domain=domain) + + +@pytest.mark.parametrize( + 'msg, privkeys', + [ + (b'\x12' * 32, [1, 5, 124, 735, 127409812145, 90768492698215092512159, 0]), + (b'\x34' * 32, [42, 666, 1274099945, 4389392949595]), + ] +) +def test_signature_aggregation(msg, privkeys): + domain = 0 + sigs = [sign(msg, k, domain=domain) for k in privkeys] + pubs = [privtopub(k) for k in privkeys] + aggsig = aggregate_signatures(sigs) + aggpub = aggregate_pubkeys(pubs) + assert verify(msg, aggpub, aggsig, domain=domain) + + +@pytest.mark.parametrize( + 'msg_1, msg_2', + [ + (b'\x12' * 32, b'\x34' * 32) + ] +) +@pytest.mark.parametrize( + 'privkeys_1, privkeys_2', + [ + (tuple(range(10)), tuple(range(10))), + ((0, 1, 2, 3), (4, 5, 6, 7)), + ((0, 1, 2, 3), (2, 3, 4, 5)), + ] +) +def test_multi_aggregation(msg_1, msg_2, privkeys_1, privkeys_2): + domain = 0 + + sigs_1 = [sign(msg_1, k, domain=domain) for k in privkeys_1] + pubs_1 = [privtopub(k) for k in privkeys_1] + aggsig_1 = aggregate_signatures(sigs_1) + aggpub_1 = aggregate_pubkeys(pubs_1) + + sigs_2 = [sign(msg_2, k, domain=domain) for k in privkeys_2] + pubs_2 = [privtopub(k) for k in privkeys_2] + aggsig_2 = aggregate_signatures(sigs_2) + aggpub_2 = aggregate_pubkeys(pubs_2) + + message_hashes = [msg_1, msg_2] + pubs = [aggpub_1, aggpub_2] + aggsig = aggregate_signatures([aggsig_1, aggsig_2]) + + assert verify_multiple( + pubkeys=pubs, + message_hashes=message_hashes, + signature=aggsig, + domain=domain, + )