Skip to content

Commit

Permalink
Merge pull request ethereum#1689 from hwwhww/get_merkle_root
Browse files Browse the repository at this point in the history
Add `get_merkle_root`
  • Loading branch information
hwwhww committed Jan 8, 2019
2 parents 000333a + 88445f1 commit ee8f72d
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 59 deletions.
4 changes: 2 additions & 2 deletions eth/_utils/blobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
zpad_right,
)
from eth._utils.merkle import (
calc_merkle_root,
get_merkle_root_from_items,
)

from eth.constants import (
Expand All @@ -44,7 +44,7 @@ def iterate_chunks(collation_body: bytes) -> Iterator[Hash32]:
def calc_chunk_root(collation_body: bytes) -> Hash32:
check_body_size(collation_body)
chunks = list(iterate_chunks(collation_body))
return calc_merkle_root(chunks)
return get_merkle_root_from_items(chunks)


def check_body_size(body: bytes) -> bytes:
Expand Down
89 changes: 61 additions & 28 deletions eth/_utils/merkle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

import math
from typing import (
cast,
Hashable,
Iterable,
NewType,
Sequence,
Union,
)

from cytoolz import (
Expand All @@ -20,8 +20,8 @@
reduce,
take,
)
from eth_hash.auto import (
keccak,
from eth.beacon._utils.hash import (
hash_eth2,
)
from eth_typing import (
Hash32,
Expand All @@ -36,17 +36,23 @@


def get_root(tree: MerkleTree) -> Hash32:
"""Get the root hash of a Merkle tree."""
"""
Get the root hash of a Merkle tree.
"""
return tree[0][0]


def get_branch_indices(node_index: int, depth: int) -> Sequence[int]:
"""Get the indices of all ancestors up until the root for a node with a given depth."""
def get_branch_indices(node_index: int, depth: int) -> Iterable[int]:
"""
Get the indices of all ancestors up until the root for a node with a given depth.
"""
return tuple(take(depth, iterate(lambda index: index // 2, node_index)))


def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]:
"""Read off the Merkle proof for an item from a Merkle tree."""
def get_merkle_proof(tree: MerkleTree, item_index: int) -> Iterable[Hash32]:
"""
Read off the Merkle proof for an item from a Merkle tree.
"""
if item_index < 0 or item_index >= len(tree[-1]):
raise ValidationError("Item index out of range")

Expand All @@ -64,16 +70,20 @@ def get_merkle_proof(tree: MerkleTree, item_index: int) -> Sequence[Hash32]:


def _calc_parent_hash(left_node: Hash32, right_node: Hash32) -> Hash32:
"""Calculate the parent hash of a node and its sibling."""
return keccak(left_node + right_node)
"""
Calculate the parent hash of a node and its sibling.
"""
return hash_eth2(left_node + right_node)


def verify_merkle_proof(root: Hash32,
item: Hashable,
item: Union[bytes, bytearray],
item_index: int,
proof: MerkleProof) -> bool:
"""Verify a Merkle proof against a root hash."""
leaf = keccak(item)
"""
Verify a Merkle proof against a root hash.
"""
leaf = hash_eth2(item)
branch_indices = get_branch_indices(item_index, len(proof))
node_orderers = [
identity if branch_index % 2 == 0 else reversed
Expand All @@ -87,28 +97,51 @@ def verify_merkle_proof(root: Hash32,
return proof_root == root


def _hash_layer(layer: Sequence[Hash32]) -> Sequence[Hash32]:
"""Calculate the layer on top of another one."""
return tuple(_calc_parent_hash(left, right) for left, right in partition(2, layer))
def _hash_layer(layer: Sequence[Hash32]) -> Iterable[Hash32]:
"""
Calculate the layer on top of another one.
"""
return tuple(
_calc_parent_hash(left, right)
for left, right in partition(2, layer)
)


def calc_merkle_tree(items: Sequence[Union[bytes, bytearray]]) -> MerkleTree:
"""
Calculate the Merkle tree corresponding to a list of items.
"""
leaves = tuple(hash_eth2(item) for item in items)
return calc_merkle_tree_from_leaves(leaves)

def calc_merkle_tree(items: Sequence[Hashable]) -> MerkleTree:
"""Calculate the Merkle tree corresponding to a list of items."""
if len(items) == 0:
raise ValidationError("No items given")
n_layers = math.log2(len(items)) + 1

def get_merkle_root_from_items(items: Sequence[Union[bytes, bytearray]]) -> Hash32:
"""
Calculate the Merkle root corresponding to a list of items.
"""
return get_root(calc_merkle_tree(items))


def calc_merkle_tree_from_leaves(leaves: Sequence[Hash32]) -> MerkleTree:
if len(leaves) == 0:
raise ValueError("No leaves given")
n_layers = math.log2(len(leaves)) + 1
if not n_layers.is_integer():
raise ValidationError("Item number is not a power of two")
raise ValueError("Number of leaves is not a power of two")
n_layers = int(n_layers)

leaves = tuple(keccak(item) for item in items)
tree = cast(MerkleTree, tuple(take(n_layers, iterate(_hash_layer, leaves)))[::-1])
reversed_tree = tuple(take(n_layers, iterate(_hash_layer, leaves)))
tree = MerkleTree(tuple(reversed(reversed_tree)))

if len(tree[0]) != 1:
raise Exception("Invariant: There must only be one root")

return tree


def calc_merkle_root(items: Sequence[Hashable]) -> Hash32:
"""Calculate the Merkle root corresponding to a list of items."""
return get_root(calc_merkle_tree(items))
def get_merkle_root(leaves: Sequence[Hash32]) -> Hash32:
"""
Return the Merkle root of the given 32-byte hashes.
Note: it has to be a full tree, i.e., `len(values)` is an exact power of 2.
"""
return get_root(calc_merkle_tree_from_leaves(leaves))
6 changes: 5 additions & 1 deletion eth/beacon/_utils/hash.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import (
Union,
)

from eth_typing import Hash32
from eth_hash.auto import keccak


def hash_eth2(data: bytes) -> Hash32:
def hash_eth2(data: Union[bytes, bytearray]) -> Hash32:
"""
Return Keccak-256 hashed result.
Note: it's a placeholder and we aim to migrate to a S[T/N]ARK-friendly hash function in
Expand Down
69 changes: 41 additions & 28 deletions tests/core/merkle-utils/test_merkle_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@
ValidationError,
)

from eth_hash.auto import (
keccak,
from eth.beacon._utils.hash import (
hash_eth2,
)

from eth._utils.merkle import (
calc_merkle_root,
get_merkle_root_from_items,
calc_merkle_tree,
get_root,
get_merkle_proof,
get_merkle_root,
verify_merkle_proof,
)

Expand All @@ -21,41 +22,41 @@
(
(b"single leaf",),
(
(keccak(b"single leaf"),),
(hash_eth2(b"single leaf"),),
),
),
(
(b"left", b"right"),
(
(keccak(keccak(b"left") + keccak(b"right")),),
(keccak(b"left"), keccak(b"right")),
(hash_eth2(hash_eth2(b"left") + hash_eth2(b"right")),),
(hash_eth2(b"left"), hash_eth2(b"right")),
),
),
(
(b"1", b"2", b"3", b"4"),
(
(
keccak(
keccak(
keccak(b"1") + keccak(b"2")
) + keccak(
keccak(b"3") + keccak(b"4")
hash_eth2(
hash_eth2(
hash_eth2(b"1") + hash_eth2(b"2")
) + hash_eth2(
hash_eth2(b"3") + hash_eth2(b"4")
)
),
),
(
keccak(
keccak(b"1") + keccak(b"2")
hash_eth2(
hash_eth2(b"1") + hash_eth2(b"2")
),
keccak(
keccak(b"3") + keccak(b"4")
hash_eth2(
hash_eth2(b"3") + hash_eth2(b"4")
),
),
(
keccak(b"1"),
keccak(b"2"),
keccak(b"3"),
keccak(b"4"),
hash_eth2(b"1"),
hash_eth2(b"2"),
hash_eth2(b"3"),
hash_eth2(b"4"),
),
),
),
Expand All @@ -64,45 +65,45 @@ def test_merkle_tree_calculation(leaves, tree):
calculated_tree = calc_merkle_tree(leaves)
assert calculated_tree == tree
assert get_root(tree) == tree[0][0]
assert calc_merkle_root(leaves) == get_root(tree)
assert get_merkle_root_from_items(leaves) == get_root(tree)


@pytest.mark.parametrize("leave_number", [0, 3, 5, 6, 7, 9])
def test_invalid_merkle_root_calculation(leave_number):
with pytest.raises(ValidationError):
calc_merkle_root((b"",) * leave_number)
with pytest.raises(ValueError):
get_merkle_root_from_items((b"",) * leave_number)


@pytest.mark.parametrize("leaves,index,proof", [
(
(b"1", b"2"),
0,
(keccak(b"2"),),
(hash_eth2(b"2"),),
),
(
(b"1", b"2"),
1,
(keccak(b"1"),),
(hash_eth2(b"1"),),
),
(
(b"1", b"2", b"3", b"4"),
0,
(keccak(b"2"), keccak(keccak(b"3") + keccak(b"4"))),
(hash_eth2(b"2"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))),
),
(
(b"1", b"2", b"3", b"4"),
1,
(keccak(b"1"), keccak(keccak(b"3") + keccak(b"4"))),
(hash_eth2(b"1"), hash_eth2(hash_eth2(b"3") + hash_eth2(b"4"))),
),
(
(b"1", b"2", b"3", b"4"),
2,
(keccak(b"4"), keccak(keccak(b"1") + keccak(b"2"))),
(hash_eth2(b"4"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))),
),
(
(b"1", b"2", b"3", b"4"),
3,
(keccak(b"3"), keccak(keccak(b"1") + keccak(b"2"))),
(hash_eth2(b"3"), hash_eth2(hash_eth2(b"1") + hash_eth2(b"2"))),
),
])
def test_merkle_proofs(leaves, index, proof):
Expand Down Expand Up @@ -142,3 +143,15 @@ def test_proof_generation_index_validation(leaves):
for invalid_index in [-1, len(leaves)]:
with pytest.raises(ValidationError):
get_merkle_proof(tree, invalid_index)


def test_get_merkle_root():
hash_0 = b"0" * 32
leaves = (hash_0,)
root = get_merkle_root(leaves)
assert root == hash_0

hash_1 = b"1" * 32
leaves = (hash_0, hash_1)
root = get_merkle_root(leaves)
assert root == hash_eth2(hash_0 + hash_1)

0 comments on commit ee8f72d

Please sign in to comment.