diff --git a/eth/_utils/blobs.py b/eth/_utils/blobs.py index cde48da79a..7a3ee264fa 100644 --- a/eth/_utils/blobs.py +++ b/eth/_utils/blobs.py @@ -17,7 +17,7 @@ zpad_right, ) from eth._utils.merkle import ( - calc_merkle_root, + get_merkle_root_from_items, ) from eth.constants import ( @@ -44,7 +44,7 @@ def iterate_chunks(collation_body: bytes) -> Iterator[Hash32]: def calc_chunk_root(collation_body: bytes) -> Hash32: check_body_size(collation_body) chunks = list(iterate_chunks(collation_body)) - return calc_merkle_root(chunks) + return get_merkle_root_from_items(chunks) def check_body_size(body: bytes) -> bytes: diff --git a/eth/_utils/merkle.py b/eth/_utils/merkle.py index 31f1e03462..dff8d354c2 100644 --- a/eth/_utils/merkle.py +++ b/eth/_utils/merkle.py @@ -7,10 +7,10 @@ import math from typing import ( - cast, - Hashable, + Iterable, NewType, Sequence, + Union, ) from cytoolz import ( @@ -20,8 +20,8 @@ reduce, take, ) -from eth_hash.auto import ( - keccak, +from eth.beacon._utils.hash import ( + hash_eth2, ) from eth_typing import ( Hash32, @@ -36,17 +36,23 @@ def get_root(tree: MerkleTree) -> Hash32: - """Get the root hash of a Merkle tree.""" + """ + Get the root hash of a Merkle tree. + """ return tree[0][0] -def get_branch_indices(node_index: int, depth: int) -> Sequence[int]: - """Get the indices of all ancestors up until the root for a node with a given depth.""" +def get_branch_indices(node_index: int, depth: int) -> Iterable[int]: + """ + Get the indices of all ancestors up until the root for a node with a given depth. + """ return tuple(take(depth, iterate(lambda index: index // 2, node_index))) -def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]: - """Read off the Merkle proof for an item from a Merkle tree.""" +def get_merkle_proof(tree: MerkleTree, item_index: int) -> Iterable[Hash32]: + """ + Read off the Merkle proof for an item from a Merkle tree. + """ if item_index < 0 or item_index >= len(tree[-1]): raise ValidationError("Item index out of range") @@ -64,16 +70,20 @@ def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]: def _calc_parent_hash(left_node: Hash32, right_node: Hash32) -> Hash32: - """Calculate the parent hash of a node and its sibling.""" - return keccak(left_node + right_node) + """ + Calculate the parent hash of a node and its sibling. + """ + return hash_eth2(left_node + right_node) def verify_merkle_proof(root: Hash32, - item: Hashable, + item: Union[bytes, bytearray], item_index: int, proof: MerkleProof) -> bool: - """Verify a Merkle proof against a root hash.""" - leaf = keccak(item) + """ + Verify a Merkle proof against a root hash. + """ + leaf = hash_eth2(item) branch_indices = get_branch_indices(item_index, len(proof)) node_orderers = [ identity if branch_index % 2 == 0 else reversed @@ -87,28 +97,51 @@ def verify_merkle_proof(root: Hash32, return proof_root == root -def _hash_layer(layer: Sequence[Hash32]) -> Sequence[Hash32]: - """Calculate the layer on top of another one.""" - return tuple(_calc_parent_hash(left, right) for left, right in partition(2, layer)) +def _hash_layer(layer: Sequence[Hash32]) -> Iterable[Hash32]: + """ + Calculate the layer on top of another one. + """ + return tuple( + _calc_parent_hash(left, right) + for left, right in partition(2, layer) + ) + +def calc_merkle_tree(items: Sequence[Union[bytes, bytearray]]) -> MerkleTree: + """ + Calculate the Merkle tree corresponding to a list of items. + """ + leaves = tuple(hash_eth2(item) for item in items) + return calc_merkle_tree_from_leaves(leaves) -def calc_merkle_tree(items: Sequence[Hashable]) -> MerkleTree: - """Calculate the Merkle tree corresponding to a list of items.""" - if len(items) == 0: - raise ValidationError("No items given") - n_layers = math.log2(len(items)) + 1 + +def get_merkle_root_from_items(items: Sequence[Union[bytes, bytearray]]) -> Hash32: + """ + Calculate the Merkle root corresponding to a list of items. + """ + return get_root(calc_merkle_tree(items)) + + +def calc_merkle_tree_from_leaves(leaves: Sequence[Hash32]) -> MerkleTree: + if len(leaves) == 0: + raise ValueError("No leaves given") + n_layers = math.log2(len(leaves)) + 1 if not n_layers.is_integer(): - raise ValidationError("Item number is not a power of two") + raise ValueError("Number of leaves is not a power of two") n_layers = int(n_layers) - leaves = tuple(keccak(item) for item in items) - tree = cast(MerkleTree, tuple(take(n_layers, iterate(_hash_layer, leaves)))[::-1]) + reversed_tree = tuple(take(n_layers, iterate(_hash_layer, leaves))) + tree = MerkleTree(tuple(reversed(reversed_tree))) + if len(tree[0]) != 1: raise Exception("Invariant: There must only be one root") return tree -def calc_merkle_root(items: Sequence[Hashable]) -> Hash32: - """Calculate the Merkle root corresponding to a list of items.""" - return get_root(calc_merkle_tree(items)) +def get_merkle_root(leaves: Sequence[Hash32]) -> Hash32: + """ + Return the Merkle root of the given 32-byte hashes. + Note: it has to be a full tree, i.e., `len(values)` is an exact power of 2. + """ + return get_root(calc_merkle_tree_from_leaves(leaves)) diff --git a/eth/beacon/_utils/hash.py b/eth/beacon/_utils/hash.py index 8431a4a7fa..062e7787c4 100644 --- a/eth/beacon/_utils/hash.py +++ b/eth/beacon/_utils/hash.py @@ -1,8 +1,12 @@ +from typing import ( + Union, +) + from eth_typing import Hash32 from eth_hash.auto import keccak -def hash_eth2(data: bytes) -> Hash32: +def hash_eth2(data: Union[bytes, bytearray]) -> Hash32: """ Return Keccak-256 hashed result. Note: it's a placeholder and we aim to migrate to a S[T/N]ARK-friendly hash function in diff --git a/tests/core/merkle-utils/test_merkle_trees.py b/tests/core/merkle-utils/test_merkle_trees.py index 17625888e2..cbef512191 100644 --- a/tests/core/merkle-utils/test_merkle_trees.py +++ b/tests/core/merkle-utils/test_merkle_trees.py @@ -4,15 +4,16 @@ ValidationError, ) -from eth_hash.auto import ( - keccak, +from eth.beacon._utils.hash import ( + hash_eth2, ) from eth._utils.merkle import ( - calc_merkle_root, + get_merkle_root_from_items, calc_merkle_tree, get_root, get_merkle_proof, + get_merkle_root, verify_merkle_proof, ) @@ -21,41 +22,41 @@ ( (b"single leaf",), ( - (keccak(b"single leaf"),), + (hash_eth2(b"single leaf"),), ), ), ( (b"left", b"right"), ( - (keccak(keccak(b"left") + keccak(b"right")),), - (keccak(b"left"), keccak(b"right")), + (hash_eth2(hash_eth2(b"left") + hash_eth2(b"right")),), + (hash_eth2(b"left"), hash_eth2(b"right")), ), ), ( (b"1", b"2", b"3", b"4"), ( ( - keccak( - keccak( - keccak(b"1") + keccak(b"2") - ) + keccak( - keccak(b"3") + keccak(b"4") + hash_eth2( + hash_eth2( + hash_eth2(b"1") + hash_eth2(b"2") + ) + hash_eth2( + hash_eth2(b"3") + hash_eth2(b"4") ) ), ), ( - keccak( - keccak(b"1") + keccak(b"2") + hash_eth2( + hash_eth2(b"1") + hash_eth2(b"2") ), - keccak( - keccak(b"3") + keccak(b"4") + hash_eth2( + hash_eth2(b"3") + hash_eth2(b"4") ), ), ( - keccak(b"1"), - keccak(b"2"), - keccak(b"3"), - keccak(b"4"), + hash_eth2(b"1"), + hash_eth2(b"2"), + hash_eth2(b"3"), + hash_eth2(b"4"), ), ), ), @@ -64,45 +65,45 @@ def test_merkle_tree_calculation(leaves, tree): calculated_tree = calc_merkle_tree(leaves) assert calculated_tree == tree assert get_root(tree) == tree[0][0] - assert calc_merkle_root(leaves) == get_root(tree) + assert get_merkle_root_from_items(leaves) == get_root(tree) @pytest.mark.parametrize("leave_number", [0, 3, 5, 6, 7, 9]) def test_invalid_merkle_root_calculation(leave_number): - with pytest.raises(ValidationError): - calc_merkle_root((b"",) * leave_number) + with pytest.raises(ValueError): + get_merkle_root_from_items((b"",) * leave_number) @pytest.mark.parametrize("leaves,index,proof", [ ( (b"1", b"2"), 0, - (keccak(b"2"),), + (hash_eth2(b"2"),), ), ( (b"1", b"2"), 1, - (keccak(b"1"),), + (hash_eth2(b"1"),), ), ( (b"1", b"2", b"3", b"4"), 0, - (keccak(b"2"), keccak(keccak(b"3") + keccak(b"4"))), + (hash_eth2(b"2"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))), ), ( (b"1", b"2", b"3", b"4"), 1, - (keccak(b"1"), keccak(keccak(b"3") + keccak(b"4"))), + (hash_eth2(b"1"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))), ), ( (b"1", b"2", b"3", b"4"), 2, - (keccak(b"4"), keccak(keccak(b"1") + keccak(b"2"))), + (hash_eth2(b"4"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))), ), ( (b"1", b"2", b"3", b"4"), 3, - (keccak(b"3"), keccak(keccak(b"1") + keccak(b"2"))), + (hash_eth2(b"3"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))), ), ]) def test_merkle_proofs(leaves, index, proof): @@ -142,3 +143,15 @@ def test_proof_generation_index_validation(leaves): for invalid_index in [-1, len(leaves)]: with pytest.raises(ValidationError): get_merkle_proof(tree, invalid_index) + + +def test_get_merkle_root(): + hash_0 = b"0" * 32 + leaves = (hash_0,) + root = get_merkle_root(leaves) + assert root == hash_0 + + hash_1 = b"1" * 32 + leaves = (hash_0, hash_1) + root = get_merkle_root(leaves) + assert root == hash_eth2(hash_0 + hash_1)