Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse merkle tree part1 #58

Closed
wants to merge 11 commits into from
77 changes: 77 additions & 0 deletions tests/test_sparse_merkle_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from hypothesis import (
given,
strategies as st,
settings,
)

from eth_hash.auto import (
keccak,
)

from trie.sparse_merkle_tree import (
SparseMerkleTree,
)
from trie.constants import (
EMPTY_NODE_HASHES,
)


@given(k=st.lists(st.binary(min_size=20, max_size=20), min_size=100, max_size=100, unique=True),
v=st.lists(st.binary(min_size=1), min_size=100, max_size=100),
chosen_numbers=st.lists(
st.integers(min_value=1, max_value=99),
min_size=50,
max_size=100,
unique=True),
random=st.randoms())
@settings(max_examples=10)
def test_sparse_merkle_tree(k, v, chosen_numbers, random):
kv_pairs = list(zip(k, v))

# Test basic get/set
trie = SparseMerkleTree(db={})
for k, v in kv_pairs:
assert not trie.exists(k)
trie.set(k, v)
prev_root = trie.root_hash
for k, v in kv_pairs:
assert trie.get(k) == v
trie.delete(k)
for k, _ in kv_pairs:
assert not trie.exists(k)
assert trie.root_hash == keccak(EMPTY_NODE_HASHES[0] + EMPTY_NODE_HASHES[0])

# Test single update
random.shuffle(kv_pairs)
for k, v in kv_pairs:
trie.set(k, v)
# Check trie root remains the same even in different insert order
assert trie.root_hash == prev_root
prior_to_update_root = trie.root_hash
for i in chosen_numbers:
# If new value is the same as current value, skip the update
if i.to_bytes(i, byteorder='big') == trie.get(kv_pairs[i][0]):
continue
# Update
trie.set(kv_pairs[i][0], i.to_bytes(i, byteorder='big'))
assert trie.get(kv_pairs[i][0]) == i.to_bytes(i, byteorder='big')
assert trie.root_hash != prior_to_update_root
# Un-update
trie.set(kv_pairs[i][0], kv_pairs[i][1])
assert trie.root_hash == prior_to_update_root

# Test batch update with different update order
# First batch update
for i in chosen_numbers:
trie.set(kv_pairs[i][0], i.to_bytes(i, byteorder='big'))
batch_updated_root = trie.root_hash
# Un-update
random.shuffle(chosen_numbers)
for i in chosen_numbers:
trie.set(kv_pairs[i][0], kv_pairs[i][1])
assert trie.root_hash == prior_to_update_root
# Second batch update
random.shuffle(chosen_numbers)
for i in chosen_numbers:
trie.set(kv_pairs[i][0], i.to_bytes(i, byteorder='big'))
assert trie.root_hash == batch_updated_root
9 changes: 9 additions & 0 deletions trie/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from eth_hash.auto import keccak

BLANK_NODE = b''
# keccak(b'')
BLANK_HASH = b"\xc5\xd2F\x01\x86\xf7#<\x92~}\xb2\xdc\xc7\x03\xc0\xe5\x00\xb6S\xca\x82';{\xfa\xd8\x04]\x85\xa4p" # noqa: E501
Expand Down Expand Up @@ -34,3 +36,10 @@

BYTE_1 = bytes([1])
BYTE_0 = bytes([0])

# Constants for Sparse Merkle Tree
TREE_HEIGHT = 160
EMPTY_LEAF_NODE_HASH = BLANK_HASH
EMPTY_NODE_HASHES = [EMPTY_LEAF_NODE_HASH]
for _ in range(TREE_HEIGHT - 1):
EMPTY_NODE_HASHES.insert(0, keccak(EMPTY_NODE_HASHES[0] + EMPTY_NODE_HASHES[0]))
112 changes: 112 additions & 0 deletions trie/sparse_merkle_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from eth_hash.auto import (
keccak,
)

from trie.constants import (
TREE_HEIGHT,
EMPTY_LEAF_NODE_HASH,
EMPTY_NODE_HASHES,
)
from trie.validation import (
validate_is_bytes,
validate_length,
)


# sanity check
assert EMPTY_LEAF_NODE_HASH == keccak(b'')


class SparseMerkleTree:
def __init__(self, db):
self.db = db
# Initialize an empty tree with one branch
self.root_hash = keccak(EMPTY_NODE_HASHES[0] + EMPTY_NODE_HASHES[0])
self.db[self.root_hash] = EMPTY_NODE_HASHES[0] + EMPTY_NODE_HASHES[0]
for i in range(TREE_HEIGHT - 1):
self.db[EMPTY_NODE_HASHES[i]] = EMPTY_NODE_HASHES[i+1] + EMPTY_NODE_HASHES[i+1]
self.db[EMPTY_LEAF_NODE_HASH] = b''

def get(self, key):
validate_is_bytes(key)
validate_length(key, 20)

target_bit = 1 << TREE_HEIGHT - 1
path = int.from_bytes(key, byteorder='big')
node_hash = self.root_hash
for i in range(TREE_HEIGHT):
if path & target_bit:
node_hash = self.db[node_hash][32:]
else:
node_hash = self.db[node_hash][:32]
target_bit >>= 1

if node_hash == EMPTY_LEAF_NODE_HASH:
raise KeyError("Key does not exist")
return self.db[node_hash]

def set(self, key, value):
validate_is_bytes(key)
validate_length(key, 20)
validate_is_bytes(value)

path = int.from_bytes(key, byteorder='big')
self.root_hash = self._set(value, path, 0, self.root_hash)
return

def _set(self, value, path, depth, node_hash):
if depth == TREE_HEIGHT:
return self._hash_and_save(value)
else:
node = self.db[node_hash]
target_bit = 1 << (TREE_HEIGHT - depth - 1)
if (path & target_bit):
return self._hash_and_save(node[:32] + self._set(value, path, depth+1, node[32:]))
else:
return self._hash_and_save(self._set(value, path, depth+1, node[:32]) + node[32:])

def exists(self, key):
validate_is_bytes(key)
validate_length(key, 20)

try:
self.get(key)
return True
except KeyError:
return False

def delete(self, key):
"""
Equals to setting the value to None
"""
validate_is_bytes(key)
validate_length(key, 20)

self.set(key, b'')

#
# Utils
#
def _hash_and_save(self, node):
"""
Saves a node into the database and returns its hash
"""

node_hash = keccak(node)
self.db[node_hash] = node
return node_hash

#
# Dictionary API
#
def __getitem__(self, key):
return self.get(key)

def __setitem__(self, key, value):
return self.set(key, value)

def __delitem__(self, key):
return self.delete(key)

def __contains__(self, key):
return self.exists(key)