From 4cfc5e3c5a3ad4f7ffd213de8fe925da30716a9f Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Thu, 30 Mar 2017 13:27:15 -0600 Subject: [PATCH] Initial Commit --- .gitignore | 51 ++ .gitmodules | 3 + .travis.yml | 22 + Makefile | 54 ++ README.md | 1 + pytest.ini | 3 + reference.py | 1046 +++++++++++++++++++++++++++++++++++ requirements-dev.txt | 3 + setup.py | 43 ++ tests/test_nibbles_utils.py | 15 + tests/test_nodes_utils.py | 42 ++ tests/test_trie.py | 106 ++++ tox.ini | 24 + trie/__init__.py | 5 + trie/constants.py | 16 + trie/db/__init__.py | 0 trie/db/base.py | 35 ++ trie/db/memory.py | 22 + trie/exceptions.py | 10 + trie/trie.py | 377 +++++++++++++ trie/utils/__init__.py | 0 trie/utils/nibbles.py | 121 ++++ trie/utils/nodes.py | 89 +++ trie/utils/sha3.py | 10 + trie/validation.py | 49 ++ 25 files changed, 2147 insertions(+) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 .travis.yml create mode 100644 Makefile create mode 100644 README.md create mode 100644 pytest.ini create mode 100644 reference.py create mode 100644 requirements-dev.txt create mode 100644 setup.py create mode 100644 tests/test_nibbles_utils.py create mode 100644 tests/test_nodes_utils.py create mode 100644 tests/test_trie.py create mode 100644 tox.ini create mode 100644 trie/__init__.py create mode 100644 trie/constants.py create mode 100644 trie/db/__init__.py create mode 100644 trie/db/base.py create mode 100644 trie/db/memory.py create mode 100644 trie/exceptions.py create mode 100644 trie/trie.py create mode 100644 trie/utils/__init__.py create mode 100644 trie/utils/nibbles.py create mode 100644 trie/utils/nodes.py create mode 100644 trie/utils/sha3.py create mode 100644 trie/validation.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..7451909c --- /dev/null +++ b/.gitignore @@ -0,0 +1,51 @@ +*.py[cod] + +# C extensions +*.so + +# Packages +*.egg +*.egg-info +dist +build +eggs +parts +var +sdist +develop-eggs +.installed.cfg 
+lib +lib64 + +# Installer logs +pip-log.txt + +# Unit test / coverage reports +.coverage +.tox +nosetests.xml + +# Translations +*.mo + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# Complexity +output/*.html +output/*/index.html + +# Sphinx +docs/_build +docs/modules.rst + +# pytest +.cache/ + +# fixtures +fixtures/** + +# profiling +prof/** diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..e1ed7112 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "fixtures"] + path = fixtures + url = git@github.com:ethereum/tests.git diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 00000000..d5c9cbd9 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,22 @@ +sudo: false +language: python +python: + - "3.5" +dist: trusty +env: + matrix: + - TOX_ENV=py27 + - TOX_ENV=py3 + - TOX_ENV=py35-stdlib + - TOX_ENV=flake8 +cache: + pip: true +install: + - "travis_retry pip install pip setuptools --upgrade" + - "travis_retry pip install tox" +before_script: + - pip freeze +script: + - tox -e $TOX_ENV --recreate +after_script: + - cat .tox/$TOX_ENV/log/*.log diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..a887b0a3 --- /dev/null +++ b/Makefile @@ -0,0 +1,54 @@ +.PHONY: clean-pyc clean-build docs + +help: + @echo "clean-build - remove build artifacts" + @echo "clean-pyc - remove Python file artifacts" + @echo "lint - check style with flake8" + @echo "test - run tests quickly with the default Python" + @echo "testall - run tests on every Python version with tox" + @echo "coverage - check code coverage quickly with the default Python" + @echo "docs - generate Sphinx HTML documentation, including API docs" + @echo "release - package and upload a release" + @echo "sdist - package" + +clean: clean-build clean-pyc + +clean-build: + rm -fr build/ + rm -fr dist/ + rm -fr *.egg-info + +clean-pyc: + find . -name '*.pyc' -exec rm -f {} + + find . -name '*.pyo' -exec rm -f {} + + find . 
-name '*~' -exec rm -f {} + + +lint: + flake8 trie + +test: + py.test --tb native tests + +test-all: + tox + +coverage: + coverage run --source trie + coverage report -m + coverage html + open htmlcov/index.html + +docs: + rm -f docs/trie.rst + rm -f docs/modules.rst + sphinx-apidoc -o docs/ -d 2 trie/ + $(MAKE) -C docs clean + $(MAKE) -C docs html + open docs/_build/html/index.html + +release: clean + python setup.py sdist bdist bdist_wheel upload + +sdist: clean + python setup.py sdist bdist bdist_wheel + ls -l dist diff --git a/README.md b/README.md new file mode 100644 index 00000000..d268336a --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# Python Implementation of the Ethereum Trie structure diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..1c782ddc --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +addopts= -v --showlocals --durations 10 +python_paths= . diff --git a/reference.py b/reference.py new file mode 100644 index 00000000..cc64eea5 --- /dev/null +++ b/reference.py @@ -0,0 +1,1046 @@ +#!/usr/bin/env python +import os +import rlp +from ethereum import utils +from ethereum.utils import to_string +from ethereum.abi import is_string +import copy +from rlp.utils import decode_hex, ascii_chr, str_to_bytes +from ethereum.utils import encode_hex +from ethereum.fast_rlp import encode_optimized +rlp_encode = encode_optimized + +bin_to_nibbles_cache = {} + +hti = {} +for i, c in enumerate(b'0123456789abcdef'): + hti[c] = i + + +def bin_to_nibbles(s): + """convert string s to nibbles (half-bytes) + + >>> bin_to_nibbles("") + [] + >>> bin_to_nibbles("h") + [6, 8] + >>> bin_to_nibbles("he") + [6, 8, 6, 5] + >>> bin_to_nibbles("hello") + [6, 8, 6, 5, 6, 12, 6, 12, 6, 15] + """ + return [hti[c] for c in encode_hex(s)] + + +def nibbles_to_bin(nibbles): + if any(x > 15 or x < 0 for x in nibbles): + raise Exception("nibbles can only be [0,..15]") + + if len(nibbles) % 2: + raise Exception("nibbles must be of even numbers") + + res = b'' + 
for i in range(0, len(nibbles), 2): + res += ascii_chr(16 * nibbles[i] + nibbles[i + 1]) + return res + + +NIBBLE_TERMINATOR = 16 +RECORDING = 1 +NONE = 0 +VERIFYING = -1 + +proving = False + + +class ProofConstructor(): + + def __init__(self): + self.mode = [] + self.nodes = [] + self.exempt = [] + + def push(self, mode, nodes=[]): + global proving + proving = True + self.mode.append(mode) + self.exempt.append(set()) + if mode == VERIFYING: + self.nodes.append(set([rlp_encode(x) for x in nodes])) + else: + self.nodes.append(set()) + + def pop(self): + global proving + self.mode.pop() + self.nodes.pop() + self.exempt.pop() + if not self.mode: + proving = False + + def get_nodelist(self): + return list(map(rlp.decode, list(self.nodes[-1]))) + + def get_nodes(self): + return self.nodes[-1] + + def add_node(self, node): + node = rlp_encode(node) + if node not in self.exempt[-1]: + self.nodes[-1].add(node) + + def add_exempt(self, node): + self.exempt[-1].add(rlp_encode(node)) + + def get_mode(self): + return self.mode[-1] + +proof = ProofConstructor() + + +class InvalidSPVProof(Exception): + pass + + +def with_terminator(nibbles): + nibbles = nibbles[:] + if not nibbles or nibbles[-1] != NIBBLE_TERMINATOR: + nibbles.append(NIBBLE_TERMINATOR) + return nibbles + + +def without_terminator(nibbles): + nibbles = nibbles[:] + if nibbles and nibbles[-1] == NIBBLE_TERMINATOR: + del nibbles[-1] + return nibbles + + +def adapt_terminator(nibbles, has_terminator): + if has_terminator: + return with_terminator(nibbles) + else: + return without_terminator(nibbles) + + +def pack_nibbles(nibbles): + """pack nibbles to binary + + :param nibbles: a nibbles sequence. 
may have a terminator + """ + + if nibbles[-1:] == [NIBBLE_TERMINATOR]: + flags = 2 + nibbles = nibbles[:-1] + else: + flags = 0 + + oddlen = len(nibbles) % 2 + flags |= oddlen # set lowest bit if odd number of nibbles + if oddlen: + nibbles = [flags] + nibbles + else: + nibbles = [flags, 0] + nibbles + o = b'' + for i in range(0, len(nibbles), 2): + o += ascii_chr(16 * nibbles[i] + nibbles[i + 1]) + return o + + +def unpack_to_nibbles(bindata): + """unpack packed binary data to nibbles + + :param bindata: binary packed from nibbles + :return: nibbles sequence, may have a terminator + """ + o = bin_to_nibbles(bindata) + flags = o[0] + if flags & 2: + o.append(NIBBLE_TERMINATOR) + if flags & 1 == 1: + o = o[1:] + else: + o = o[2:] + return o + + +def starts_with(full, part): + ''' test whether the items in the part is + the leading items of the full + ''' + if len(full) < len(part): + return False + return full[:len(part)] == part + + +( + NODE_TYPE_BLANK, + NODE_TYPE_LEAF, + NODE_TYPE_EXTENSION, + NODE_TYPE_BRANCH +) = tuple(range(4)) + + +def is_key_value_type(node_type): + return node_type in [NODE_TYPE_LEAF, + NODE_TYPE_EXTENSION] + +BLANK_NODE = b'' +BLANK_ROOT = utils.sha3rlp(b'') + + +def transient_trie_exception(*args): + raise Exception("Transient trie") + + +class Trie(object): + + def __init__(self, db, root_hash=BLANK_ROOT, transient=False): + '''it also present a dictionary like interface + + :param db key value database + :root: blank or trie node in form of [key, value] or [v0,v1..v15,v] + ''' + self.db = db # Pass in a database object directly + self.transient = transient + if self.transient: + self.update = self.get = self.delete = transient_trie_exception + self.set_root_hash(root_hash) + + # def __init__(self, dbfile, root_hash=BLANK_ROOT): + # '''it also present a dictionary like interface + + # :param dbfile: key value database + # :root: blank or trie node in form of [key, value] or [v0,v1..v15,v] + # ''' + # if isinstance(dbfile, str): + # 
dbfile = os.path.abspath(dbfile) + # self.db = DB(dbfile) + # else: + # self.db = dbfile # Pass in a database object directly + # self.set_root_hash(root_hash) + + # For SPV proof production/verification purposes + def spv_grabbing(self, node): + global proving + if not proving: + pass + elif proof.get_mode() == RECORDING: + proof.add_node(copy.copy(node)) + # print('recording %s' % encode_hex(utils.sha3(rlp_encode(node)))) + elif proof.get_mode() == VERIFYING: + # print('verifying %s' % encode_hex(utils.sha3(rlp_encode(node)))) + if rlp_encode(node) not in proof.get_nodes(): + raise InvalidSPVProof("Proof invalid!") + + def spv_storing(self, node): + global proving + if not proving: + pass + elif proof.get_mode() == RECORDING: + proof.add_exempt(copy.copy(node)) + elif proof.get_mode() == VERIFYING: + proof.add_node(copy.copy(node)) + + @property + def root_hash(self): + '''always empty or a 32 bytes string + ''' + return self.get_root_hash() + + def get_root_hash(self): + if self.transient: + return self.transient_root_hash + if self.root_node == BLANK_NODE: + return BLANK_ROOT + assert isinstance(self.root_node, list) + val = rlp_encode(self.root_node) + key = utils.sha3(val) + self.db.put(key, val) + self.spv_grabbing(self.root_node) + return key + + @root_hash.setter + def root_hash(self, value): + self.set_root_hash(value) + + def set_root_hash(self, root_hash): + assert is_string(root_hash) + assert len(root_hash) in [0, 32] + if self.transient: + self.transient_root_hash = root_hash + return + if root_hash == BLANK_ROOT: + self.root_node = BLANK_NODE + return + self.root_node = self._decode_to_node(root_hash) + + def clear(self): + ''' clear all tree data + ''' + self._delete_child_storage(self.root_node) + self._delete_node_storage(self.root_node) + self.root_node = BLANK_NODE + + def _delete_child_storage(self, node): + node_type = self._get_node_type(node) + if node_type == NODE_TYPE_BRANCH: + for item in node[:16]: + 
self._delete_child_storage(self._decode_to_node(item)) + elif node_type == NODE_TYPE_EXTENSION: + self._delete_child_storage(self._decode_to_node(node[1])) + + def _encode_node(self, node): + if node == BLANK_NODE: + return BLANK_NODE + assert isinstance(node, list) + rlpnode = rlp_encode(node) + if len(rlpnode) < 32: + return node + + hashkey = utils.sha3(rlpnode) + self.db.put(hashkey, rlpnode) + self.spv_storing(node) + return hashkey + + def _decode_to_node(self, encoded): + if encoded == BLANK_NODE: + return BLANK_NODE + if isinstance(encoded, list): + return encoded + o = rlp.decode(self.db.get(encoded)) + self.spv_grabbing(o) + return o + + def _get_node_type(self, node): + ''' get node type and content + + :param node: node in form of list, or BLANK_NODE + :return: node type + ''' + if node == BLANK_NODE: + return NODE_TYPE_BLANK + + if len(node) == 2: + nibbles = unpack_to_nibbles(node[0]) + has_terminator = (nibbles and nibbles[-1] == NIBBLE_TERMINATOR) + return NODE_TYPE_LEAF if has_terminator\ + else NODE_TYPE_EXTENSION + if len(node) == 17: + return NODE_TYPE_BRANCH + + def _get(self, node, key): + """ get value inside a node + + :param node: node in form of list, or BLANK_NODE + :param key: nibble list without terminator + :return: + BLANK_NODE if does not exist, otherwise value or hash + """ + node_type = self._get_node_type(node) + + if node_type == NODE_TYPE_BLANK: + return BLANK_NODE + + if node_type == NODE_TYPE_BRANCH: + # already reach the expected node + if not key: + return node[-1] + sub_node = self._decode_to_node(node[key[0]]) + return self._get(sub_node, key[1:]) + + # key value node + curr_key = without_terminator(unpack_to_nibbles(node[0])) + if node_type == NODE_TYPE_LEAF: + return node[1] if key == curr_key else BLANK_NODE + + if node_type == NODE_TYPE_EXTENSION: + # traverse child nodes + if starts_with(key, curr_key): + sub_node = self._decode_to_node(node[1]) + return self._get(sub_node, key[len(curr_key):]) + else: + return 
BLANK_NODE + + def _update(self, node, key, value): + """ update item inside a node + + :param node: node in form of list, or BLANK_NODE + :param key: nibble list without terminator + .. note:: key may be [] + :param value: value string + :return: new node + + if this node is changed to a new node, it's parent will take the + responsibility to *store* the new node storage, and delete the old + node storage + """ + node_type = self._get_node_type(node) + + if node_type == NODE_TYPE_BLANK: + return [pack_nibbles(with_terminator(key)), value] + + elif node_type == NODE_TYPE_BRANCH: + if not key: + node[-1] = value + else: + new_node = self._update_and_delete_storage( + self._decode_to_node(node[key[0]]), + key[1:], value) + node[key[0]] = self._encode_node(new_node) + return node + + elif is_key_value_type(node_type): + return self._update_kv_node(node, key, value) + + def _update_and_delete_storage(self, node, key, value): + old_node = node[:] + new_node = self._update(node, key, value) + if old_node != new_node: + self._delete_node_storage(old_node) + return new_node + + def _update_kv_node(self, node, key, value): + node_type = self._get_node_type(node) + curr_key = without_terminator(unpack_to_nibbles(node[0])) + is_inner = node_type == NODE_TYPE_EXTENSION + + # find longest common prefix + prefix_length = 0 + for i in range(min(len(curr_key), len(key))): + if key[i] != curr_key[i]: + break + prefix_length = i + 1 + + remain_key = key[prefix_length:] + remain_curr_key = curr_key[prefix_length:] + + if remain_key == [] == remain_curr_key: + if not is_inner: + return [node[0], value] + new_node = self._update_and_delete_storage( + self._decode_to_node(node[1]), remain_key, value) + + elif remain_curr_key == []: + if is_inner: + new_node = self._update_and_delete_storage( + self._decode_to_node(node[1]), remain_key, value) + else: + new_node = [BLANK_NODE] * 17 + new_node[-1] = node[1] + new_node[remain_key[0]] = self._encode_node([ + 
pack_nibbles(with_terminator(remain_key[1:])), + value + ]) + else: + new_node = [BLANK_NODE] * 17 + if len(remain_curr_key) == 1 and is_inner: + new_node[remain_curr_key[0]] = node[1] + else: + new_node[remain_curr_key[0]] = self._encode_node([ + pack_nibbles( + adapt_terminator(remain_curr_key[1:], not is_inner) + ), + node[1] + ]) + + if remain_key == []: + new_node[-1] = value + else: + new_node[remain_key[0]] = self._encode_node([ + pack_nibbles(with_terminator(remain_key[1:])), value + ]) + + if prefix_length: + # create node for key prefix + return [pack_nibbles(curr_key[:prefix_length]), + self._encode_node(new_node)] + else: + return new_node + + def _getany(self, node, reverse=False, path=[]): + # print('getany', node, 'reverse=', reverse, path) + node_type = self._get_node_type(node) + if node_type == NODE_TYPE_BLANK: + return None + if node_type == NODE_TYPE_BRANCH: + if node[16] and not reverse: + # print('found!', [16], path) + return [16] + scan_range = list(range(16)) + if reverse: + scan_range.reverse() + for i in scan_range: + o = self._getany(self._decode_to_node(node[i]), reverse=reverse, path=path + [i]) + if o is not None: + # print('found@', [i] + o, path) + return [i] + o + if node[16] and reverse: + # print('found!', [16], path) + return [16] + return None + curr_key = without_terminator(unpack_to_nibbles(node[0])) + if node_type == NODE_TYPE_LEAF: + # print('found#', curr_key, path) + return curr_key + + if node_type == NODE_TYPE_EXTENSION: + curr_key = without_terminator(unpack_to_nibbles(node[0])) + sub_node = self._decode_to_node(node[1]) + return curr_key + self._getany(sub_node, reverse=reverse, path=path + curr_key) + + def _split(self, node, key): + node_type = self._get_node_type(node) + if node_type == NODE_TYPE_BLANK: + return BLANK_NODE, BLANK_NODE + elif not key: + return BLANK_NODE, node + elif node_type == NODE_TYPE_BRANCH: + b1 = node[:key[0]] + b1 += [''] * (17 - len(b1)) + b2 = node[key[0]+1:] + b2 = [''] * (17 - len(b2)) 
+ b2 + b1[16], b2[16] = b2[16], b1[16] + sub = self._decode_to_node(node[key[0]]) + sub1, sub2 = self._split(sub, key[1:]) + b1[key[0]] = self._encode_node(sub1) if sub1 else '' + b2[key[0]] = self._encode_node(sub2) if sub2 else '' + return self._normalize_branch_node(b1) if len([x for x in b1 if x]) else BLANK_NODE, \ + self._normalize_branch_node(b2) if len([x for x in b2 if x]) else BLANK_NODE + + descend_key = without_terminator(unpack_to_nibbles(node[0])) + if node_type == NODE_TYPE_LEAF: + if descend_key < key: + return node, BLANK_NODE + else: + return BLANK_NODE, node + elif node_type == NODE_TYPE_EXTENSION: + sub_node = self._decode_to_node(node[1]) + sub_key = key[len(descend_key):] + if starts_with(key, descend_key): + sub1, sub2 = self._split(sub_node, sub_key) + subtype1 = self._get_node_type(sub1) + subtype2 = self._get_node_type(sub2) + if not sub1: + o1 = BLANK_NODE + elif subtype1 in (NODE_TYPE_LEAF, NODE_TYPE_EXTENSION): + new_key = key[:len(descend_key)] + unpack_to_nibbles(sub1[0]) + o1 = [pack_nibbles(new_key), sub1[1]] + else: + o1 = [pack_nibbles(key[:len(descend_key)]), self._encode_node(sub1)] + if not sub2: + o2 = BLANK_NODE + elif subtype2 in (NODE_TYPE_LEAF, NODE_TYPE_EXTENSION): + new_key = key[:len(descend_key)] + unpack_to_nibbles(sub2[0]) + o2 = [pack_nibbles(new_key), sub2[1]] + else: + o2 = [pack_nibbles(key[:len(descend_key)]), self._encode_node(sub2)] + return o1, o2 + elif descend_key < key[:len(descend_key)]: + return node, BLANK_NODE + elif descend_key > key[:len(descend_key)]: + return BLANK_NODE, node + else: + return BLANK_NODE, BLANK_NODE + + def split(self, key): + key = bin_to_nibbles(key) + r1, r2 = self._split(self.root_node, key) + t1, t2 = Trie(self.db), Trie(self.db) + t1.root_node, t2.root_node = r1, r2 + return t1, t2 + + def _merge(self, node1, node2): + assert isinstance(node1, list) or not node1 + assert isinstance(node2, list) or not node2 + node_type1 = self._get_node_type(node1) + node_type2 = 
self._get_node_type(node2) + if not node1: + return node2 + if not node2: + return node1 + if node_type1 != NODE_TYPE_BRANCH and node_type2 != NODE_TYPE_BRANCH: + descend_key1 = unpack_to_nibbles(node1[0]) + descend_key2 = unpack_to_nibbles(node2[0]) + # find longest common prefix + prefix_length = 0 + for i in range(min(len(descend_key1), len(descend_key2))): + if descend_key1[i] != descend_key2[i]: + break + prefix_length = i + 1 + if prefix_length: + sub1 = self._decode_to_node(node1[1]) if node_type1 == NODE_TYPE_EXTENSION else node1[1] + new_sub1 = [ + pack_nibbles(descend_key1[prefix_length:]), + sub1 + ] if descend_key1[prefix_length:] else sub1 + sub2 = self._decode_to_node(node2[1]) if node_type2 == NODE_TYPE_EXTENSION else node2[1] + new_sub2 = [ + pack_nibbles(descend_key2[prefix_length:]), + sub2 + ] if descend_key2[prefix_length:] else sub2 + return [pack_nibbles(descend_key1[:prefix_length]), + self._encode_node(self._merge(new_sub1, new_sub2))] + + nodes = [[node1], [node2]] + for (node, node_type) in zip(nodes, [node_type1, node_type2]): + if node_type != NODE_TYPE_BRANCH: + new_node = [BLANK_NODE] * 17 + curr_key = unpack_to_nibbles(node[0][0]) + new_node[curr_key[0]] = self._encode_node([ + pack_nibbles(curr_key[1:]), + node[0][1] + ]) if curr_key[0] < 16 and curr_key[1:] else node[0][1] + node[0] = new_node + node1, node2 = nodes[0][0], nodes[1][0] + assert len([i for i in range(17) if node1[i] and node2[i]]) <= 1 + new_node = [self._encode_node(self._merge(self._decode_to_node(node1[i]), self._decode_to_node(node2[i]))) if node1[i] and node2[i] else node1[i] or node2[i] for i in range(17)] + return new_node + + @classmethod + def unsafe_merge(cls, trie1, trie2): + t = Trie(trie1.db) + t.root_node = t._merge(trie1.root_node, trie2.root_node) + return t + + def _iter(self, node, key, reverse=False, path=[]): + # print('iter', node, key, 'reverse =', reverse, 'path =', path) + node_type = self._get_node_type(node) + + if node_type == 
NODE_TYPE_BLANK: + return None + + elif node_type == NODE_TYPE_BRANCH: + # print('b') + if len(key): + sub_node = self._decode_to_node(node[key[0]]) + o = self._iter(sub_node, key[1:], reverse, path + [key[0]]) + if o is not None: + # print('returning', [key[0]] + o, path) + return [key[0]] + o + if reverse: + scan_range = reversed(list(range(key[0] if len(key) else 0))) + else: + scan_range = list(range(key[0] + 1 if len(key) else 0, 16)) + for i in scan_range: + sub_node = self._decode_to_node(node[i]) + # print('prelim getany', path+[i]) + o = self._getany(sub_node, reverse, path + [i]) + if o is not None: + # print('returning', [i] + o, path) + return [i] + o + if reverse and key and node[16]: + # print('o') + return [16] + return None + + descend_key = without_terminator(unpack_to_nibbles(node[0])) + if node_type == NODE_TYPE_LEAF: + if reverse: + # print('L', descend_key, key, descend_key if descend_key < key else None, path) + return descend_key if descend_key < key else None + else: + # print('L', descend_key, key, descend_key if descend_key > key else None, path) + return descend_key if descend_key > key else None + + if node_type == NODE_TYPE_EXTENSION: + # traverse child nodes + sub_node = self._decode_to_node(node[1]) + sub_key = key[len(descend_key):] + # print('amhere', key, descend_key, descend_key > key[:len(descend_key)]) + if starts_with(key, descend_key): + o = self._iter(sub_node, sub_key, reverse, path + descend_key) + elif descend_key > key[:len(descend_key)] and not reverse: + # print(1) + # print('prelim getany', path+descend_key) + o = self._getany(sub_node, False, path + descend_key) + elif descend_key < key[:len(descend_key)] and reverse: + # print(2) + # print('prelim getany', path+descend_key) + o = self._getany(sub_node, True, path + descend_key) + else: + o = None + # print('returning@', descend_key + o if o else None, path) + return descend_key + o if o else None + + def next(self, key): + # print('nextting') + key = 
bin_to_nibbles(key) + o = self._iter(self.root_node, key) + # print('answer', o) + return nibbles_to_bin(without_terminator(o)) if o else None + + def prev(self, key): + # print('prevving') + key = bin_to_nibbles(key) + o = self._iter(self.root_node, key, reverse=True) + # print('answer', o) + return nibbles_to_bin(without_terminator(o)) if o else None + + def _delete_node_storage(self, node): + '''delete storage + :param node: node in form of list, or BLANK_NODE + ''' + if node == BLANK_NODE: + return + assert isinstance(node, list) + encoded = self._encode_node(node) + if len(encoded) < 32: + return + """ + ===== FIXME ==== + in the current trie implementation two nodes can share identical subtrees + thus we can not safely delete nodes for now + """ + # self.db.delete(encoded) # FIXME + + def _delete(self, node, key): + """ update item inside a node + + :param node: node in form of list, or BLANK_NODE + :param key: nibble list without terminator + .. note:: key may be [] + :return: new node + + if this node is changed to a new node, it's parent will take the + responsibility to *store* the new node storage, and delete the old + node storage + """ + node_type = self._get_node_type(node) + if node_type == NODE_TYPE_BLANK: + return BLANK_NODE + + if node_type == NODE_TYPE_BRANCH: + return self._delete_branch_node(node, key) + + if is_key_value_type(node_type): + return self._delete_kv_node(node, key) + + def _normalize_branch_node(self, node): + '''node should have only one item changed + ''' + not_blank_items_count = sum(1 for x in range(17) if node[x]) + assert not_blank_items_count >= 1 + + if not_blank_items_count > 1: + return node + + # now only one item is not blank + not_blank_index = [i for i, item in enumerate(node) if item][0] + + # the value item is not blank + if not_blank_index == 16: + return [pack_nibbles(with_terminator([])), node[16]] + + # normal item is not blank + sub_node = self._decode_to_node(node[not_blank_index]) + sub_node_type = 
self._get_node_type(sub_node) + + if is_key_value_type(sub_node_type): + # collape subnode to this node, not this node will have same + # terminator with the new sub node, and value does not change + new_key = [not_blank_index] + \ + unpack_to_nibbles(sub_node[0]) + return [pack_nibbles(new_key), sub_node[1]] + if sub_node_type == NODE_TYPE_BRANCH: + return [pack_nibbles([not_blank_index]), + self._encode_node(sub_node)] + assert False + + def _delete_and_delete_storage(self, node, key): + old_node = node[:] + new_node = self._delete(node, key) + if old_node != new_node: + self._delete_node_storage(old_node) + return new_node + + def _delete_branch_node(self, node, key): + # already reach the expected node + if not key: + node[-1] = BLANK_NODE + return self._normalize_branch_node(node) + + encoded_new_sub_node = self._encode_node( + self._delete_and_delete_storage( + self._decode_to_node(node[key[0]]), key[1:]) + ) + + if encoded_new_sub_node == node[key[0]]: + return node + + node[key[0]] = encoded_new_sub_node + if encoded_new_sub_node == BLANK_NODE: + return self._normalize_branch_node(node) + + return node + + def _delete_kv_node(self, node, key): + node_type = self._get_node_type(node) + assert is_key_value_type(node_type) + curr_key = without_terminator(unpack_to_nibbles(node[0])) + + if not starts_with(key, curr_key): + # key not found + return node + + if node_type == NODE_TYPE_LEAF: + return BLANK_NODE if key == curr_key else node + + # for inner key value type + new_sub_node = self._delete_and_delete_storage( + self._decode_to_node(node[1]), key[len(curr_key):]) + + if self._encode_node(new_sub_node) == node[1]: + return node + + # new sub node is BLANK_NODE + if new_sub_node == BLANK_NODE: + return BLANK_NODE + + assert isinstance(new_sub_node, list) + + # new sub node not blank, not value and has changed + new_sub_node_type = self._get_node_type(new_sub_node) + + if is_key_value_type(new_sub_node_type): + # collape subnode to this node, not this node 
will have same + # terminator with the new sub node, and value does not change + new_key = curr_key + unpack_to_nibbles(new_sub_node[0]) + return [pack_nibbles(new_key), new_sub_node[1]] + + if new_sub_node_type == NODE_TYPE_BRANCH: + return [pack_nibbles(curr_key), self._encode_node(new_sub_node)] + + # should be no more cases + assert False + + def delete(self, key): + ''' + :param key: a string with length of [0, 32] + ''' + if not is_string(key): + raise Exception("Key must be string") + + if len(key) > 32: + raise Exception("Max key length is 32") + + self.root_node = self._delete_and_delete_storage( + self.root_node, + bin_to_nibbles(to_string(key))) + self.get_root_hash() + + def _get_size(self, node): + '''Get counts of (key, value) stored in this and the descendant nodes + + :param node: node in form of list, or BLANK_NODE + ''' + if node == BLANK_NODE: + return 0 + + node_type = self._get_node_type(node) + + if is_key_value_type(node_type): + value_is_node = node_type == NODE_TYPE_EXTENSION + if value_is_node: + return self._get_size(self._decode_to_node(node[1])) + else: + return 1 + elif node_type == NODE_TYPE_BRANCH: + sizes = [self._get_size(self._decode_to_node(node[x])) + for x in range(16)] + sizes = sizes + [1 if node[-1] else 0] + return sum(sizes) + + def _iter_branch(self, node): + '''yield (key, value) stored in this and the descendant nodes + :param node: node in form of list, or BLANK_NODE + + .. 
note:: + Here key is in full form, rather than key of the individual node + ''' + if node == BLANK_NODE: + raise StopIteration + + node_type = self._get_node_type(node) + + if is_key_value_type(node_type): + nibbles = without_terminator(unpack_to_nibbles(node[0])) + key = b'+'.join([to_string(x) for x in nibbles]) + if node_type == NODE_TYPE_EXTENSION: + sub_tree = self._iter_branch(self._decode_to_node(node[1])) + else: + sub_tree = [(to_string(NIBBLE_TERMINATOR), node[1])] + + # prepend key of this node to the keys of children + for sub_key, sub_value in sub_tree: + full_key = (key + b'+' + sub_key).strip(b'+') + yield (full_key, sub_value) + + elif node_type == NODE_TYPE_BRANCH: + for i in range(16): + sub_tree = self._iter_branch(self._decode_to_node(node[i])) + for sub_key, sub_value in sub_tree: + full_key = (str_to_bytes(str(i)) + b'+' + sub_key).strip(b'+') + yield (full_key, sub_value) + if node[16]: + yield (to_string(NIBBLE_TERMINATOR), node[-1]) + + def iter_branch(self): + for key_str, value in self._iter_branch(self.root_node): + if key_str: + nibbles = [int(x) for x in key_str.split(b'+')] + else: + nibbles = [] + key = nibbles_to_bin(without_terminator(nibbles)) + yield key, value + + def _to_dict(self, node): + '''convert (key, value) stored in this and the descendant nodes + to dict items. + + :param node: node in form of list, or BLANK_NODE + + .. 
note:: + + Here key is in full form, rather than key of the individual node + ''' + if node == BLANK_NODE: + return {} + + node_type = self._get_node_type(node) + + if is_key_value_type(node_type): + nibbles = without_terminator(unpack_to_nibbles(node[0])) + key = b'+'.join([to_string(x) for x in nibbles]) + if node_type == NODE_TYPE_EXTENSION: + sub_dict = self._to_dict(self._decode_to_node(node[1])) + else: + sub_dict = {to_string(NIBBLE_TERMINATOR): node[1]} + + # prepend key of this node to the keys of children + res = {} + for sub_key, sub_value in sub_dict.items(): + full_key = (key + b'+' + sub_key).strip(b'+') + res[full_key] = sub_value + return res + + elif node_type == NODE_TYPE_BRANCH: + res = {} + for i in range(16): + sub_dict = self._to_dict(self._decode_to_node(node[i])) + + for sub_key, sub_value in sub_dict.items(): + full_key = (str_to_bytes(str(i)) + b'+' + sub_key).strip(b'+') + res[full_key] = sub_value + + if node[16]: + res[to_string(NIBBLE_TERMINATOR)] = node[-1] + return res + + def to_dict(self): + d = self._to_dict(self.root_node) + res = {} + for key_str, value in d.items(): + if key_str: + nibbles = [int(x) for x in key_str.split(b'+')] + else: + nibbles = [] + key = nibbles_to_bin(without_terminator(nibbles)) + res[key] = value + return res + + def get(self, key): + return self._get(self.root_node, bin_to_nibbles(to_string(key))) + + def __len__(self): + return self._get_size(self.root_node) + + def __getitem__(self, key): + return self.get(key) + + def __setitem__(self, key, value): + return self.update(key, value) + + def __delitem__(self, key): + return self.delete(key) + + def __iter__(self): + return iter(self.to_dict()) + + def __contains__(self, key): + return self.get(key) != BLANK_NODE + + def update(self, key, value): + ''' + :param key: a string + :value: a string + ''' + if not is_string(key): + raise Exception("Key must be string") + + # if len(key) > 32: + # raise Exception("Max key length is 32") + + if not 
is_string(value): + raise Exception("Value must be string") + + # if value == '': + # return self.delete(key) + self.root_node = self._update_and_delete_storage( + self.root_node, + bin_to_nibbles(to_string(key)), + to_string(value)) + self.get_root_hash() + + def root_hash_valid(self): + if self.root_hash == BLANK_ROOT: + return True + return self.root_hash in self.db + + def produce_spv_proof(self, key): + proof.push(RECORDING) + self.get(key) + o = proof.get_nodelist() + proof.pop() + return o + + +def verify_spv_proof(root, key, proof): + proof.push(VERIFYING, proof) + t = Trie(db.EphemDB()) + + for i, node in enumerate(proof): + R = rlp_encode(node) + H = utils.sha3(R) + t.db.put(H, R) + try: + t.root_hash = root + t.get(key) + proof.pop() + return True + except Exception as e: + print(e) + proof.pop() + return False + + +if __name__ == "__main__": + import sys + from . import db + + _db = db.DB(sys.argv[2]) + + def encode_node(nd): + if is_string(nd): + return encode_hex(nd) + else: + return encode_hex(rlp_encode(nd)) + + if len(sys.argv) >= 2: + if sys.argv[1] == 'insert': + t = Trie(_db, decode_hex(sys.argv[3])) + t.update(sys.argv[4], sys.argv[5]) + print(encode_node(t.root_hash)) + elif sys.argv[1] == 'get': + t = Trie(_db, decode_hex(sys.argv[3])) + print(t.get(sys.argv[4])) diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..f1150e71 --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,3 @@ +tox==2.6.0 +pytest==3.0.7 +hypothesis==3.7.0 diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..d425f36c --- /dev/null +++ b/setup.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import os + +DIR = os.path.dirname(os.path.abspath(__file__)) + + +from setuptools import setup, find_packages + + +readme = open(os.path.join(DIR, 'README.md')).read() + + +setup( + name='ethereum-trie', + version="0.1.0", + description="""Python implementation of the Ethereum Trie structure""", + 
long_description=readme, + author='Piper Merriam', + author_email='pipermerriam@gmail.com', + url='https://github.com/pipermerriam/trie', + include_package_data=True, + py_modules=['trie'], + install_requires=[ + "ethereum-utils>=0.2.0", + "rlp==0.4.7", + ], + license="MIT", + zip_safe=False, + keywords='ethereum blockchain evm trie merkle', + packages=find_packages(exclude=["tests", "tests.*"]), + classifiers=[ + 'Development Status :: 2 - Pre-Alpha', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: MIT License', + 'Natural Language :: English', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + ], +) diff --git a/tests/test_nibbles_utils.py b/tests/test_nibbles_utils.py new file mode 100644 index 00000000..e35937e7 --- /dev/null +++ b/tests/test_nibbles_utils.py @@ -0,0 +1,15 @@ +from hypothesis import ( + given, + strategies as st, +) +from trie.utils.nibbles import ( + nibbles_to_bytes, + bytes_to_nibbles, +) + + +@given(value=st.binary(min_size=0, max_size=1024)) +def test_round_trip_nibbling(value): + value_as_nibbles = bytes_to_nibbles(value) + result = nibbles_to_bytes(value_as_nibbles) + assert result == value diff --git a/tests/test_nodes_utils.py b/tests/test_nodes_utils.py new file mode 100644 index 00000000..9a0fe587 --- /dev/null +++ b/tests/test_nodes_utils.py @@ -0,0 +1,42 @@ +import pytest + +from trie.utils.nodes import ( + get_common_prefix_length, + consume_common_prefix, +) + + +@pytest.mark.parametrize( + 'left,right,expected', + ( + ([], [], 0), + ([], [1], 0), + ([1], [1], 1), + ([1], [1, 1], 1), + ([1, 2], [1, 1], 1), + ([1, 2, 3, 4, 5, 6], [1, 2, 3, 5, 6], 3), + ), +) +def test_get_common_prefix_length(left, right, expected): + actual_a = get_common_prefix_length(left, right) + actual_b = get_common_prefix_length(right, left) + assert actual_a == actual_b == 
expected + + +@pytest.mark.parametrize( + 'left,right,expected', + ( + ([], [], ([], [], [])), + ([], [1], ([], [], [1])), + ([1], [1], ([1], [], [])), + ([1], [1, 1], ([1], [], [1])), + ([1, 2], [1, 1], ([1], [2], [1])), + ([1, 2, 3, 4, 5, 6], [1, 2, 3, 5, 6], ([1, 2, 3], [4, 5, 6], [5, 6])), + ), +) +def test_consume_common_prefix(left, right, expected): + actual_a = consume_common_prefix(left, right) + actual_b = consume_common_prefix(right, left) + expected_b = (expected[0], expected[2], expected[1]) + assert actual_a == expected + assert actual_b == expected_b diff --git a/tests/test_trie.py b/tests/test_trie.py new file mode 100644 index 00000000..f36805e0 --- /dev/null +++ b/tests/test_trie.py @@ -0,0 +1,106 @@ +import pytest + +import itertools +import fnmatch +import json +import os + +from eth_utils import ( + is_0x_prefixed, + decode_hex, + force_bytes, +) + +from trie import ( + Trie, +) +from trie.db.memory import ( + MemoryDB, +) + + +def normalize_fixture(fixture): + normalized_fixture = { + 'in': tuple( + ( + decode_hex(key) if is_0x_prefixed(key) else force_bytes(key), + ( + decode_hex(value) if is_0x_prefixed(value) else force_bytes(value) + ) if value is not None else None, + ) + for key, value + in (fixture['in'].items() if isinstance(fixture['in'], dict) else fixture['in']) + ), + 'root': decode_hex(fixture['root']) + } + return normalized_fixture + + +ROOT_PROJECT_DIR = os.path.dirname(os.path.dirname(__file__)) + + +def recursive_find_files(base_dir, pattern): + for dirpath, _, filenames in os.walk(base_dir): + for filename in filenames: + if fnmatch.fnmatch(filename, pattern): + yield os.path.join(dirpath, filename) + + +BASE_FIXTURE_PATH = os.path.join(ROOT_PROJECT_DIR, 'fixtures', 'TrieTests') + + +FIXTURES_PATHS = tuple(recursive_find_files(BASE_FIXTURE_PATH, "trietest.json")) + + +RAW_FIXTURES = tuple( + ( + os.path.basename(fixture_path), + json.load(open(fixture_path)), + ) for fixture_path in FIXTURES_PATHS +) + + +FIXTURES = tuple( + 
( + "{0}:{1}".format(fixture_filename, key), + normalize_fixture(fixtures[key]), + ) + for fixture_filename, fixtures in RAW_FIXTURES + for key in sorted(fixtures.keys()) +) + + +@pytest.mark.parametrize( + 'fixture_name,fixture', FIXTURES, +) +def test_trie_using_fixtures(fixture_name, fixture): + + keys_and_values = fixture['in'] + deletes = tuple(k for k, v in keys_and_values if v is None) + remaining = { + k: v + for k, v + in keys_and_values + if k not in deletes + } + + for kv_permutation in itertools.islice(itertools.permutations(keys_and_values), 100): + print("in it") + trie = Trie(db=MemoryDB()) + + for key, value in kv_permutation: + if value is None: + del trie[key] + else: + trie[key] = value + for key in deletes: + del trie[key] + + for key, expected_value in remaining.items(): + actual_value = trie[key] + assert actual_value == expected_value + + expected_root = fixture['root'] + actual_root = trie.root_hash + + assert actual_root == expected_root diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..74def155 --- /dev/null +++ b/tox.ini @@ -0,0 +1,24 @@ +[tox] +envlist= + py{27,34,35} + flake8 + +[flake8] +max-line-length= 100 +exclude= tests/* + +[testenv] +usedevelop=True +commands= + py.test {posargs:tests} +deps = + -r{toxinidir}/requirements-dev.txt +basepython = + py27: python2.7 + py34: python3.4 + py35: python3.5 + +[testenv:flake8] +basepython=python +deps=flake8 +commands=flake8 {toxinidir}/trie diff --git a/trie/__init__.py b/trie/__init__.py new file mode 100644 index 00000000..599915f2 --- /dev/null +++ b/trie/__init__.py @@ -0,0 +1,5 @@ +from __future__ import absolute_import + +from .trie import ( + Trie, +) diff --git a/trie/constants.py b/trie/constants.py new file mode 100644 index 00000000..512c020f --- /dev/null +++ b/trie/constants.py @@ -0,0 +1,16 @@ +BLANK_NODE = b'' +# sha3(rlp.encode(b'')) +BLANK_NODE_HASH = b'V\xe8\x1f\x17\x1b\xccU\xa6\xff\x83E\xe6\x92\xc0\xf8n[H\xe0\x1b\x99l\xad\xc0\x01b/\xb5\xe3c\xb4!' 
+ + +NIBBLES_LOOKUP = {hex_char: idx for idx, hex_char in enumerate('0123456789abcdef')} +NIBBLE_TERMINATOR = 16 + +HP_FLAG_2 = 2 +HP_FLAG_0 = 0 + + +NODE_TYPE_BLANK = 0 +NODE_TYPE_LEAF = 1 +NODE_TYPE_EXTENSION = 2 +NODE_TYPE_BRANCH = 3 diff --git a/trie/db/__init__.py b/trie/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trie/db/base.py b/trie/db/base.py new file mode 100644 index 00000000..11933c34 --- /dev/null +++ b/trie/db/base.py @@ -0,0 +1,35 @@ +class BaseDB(object): + def get(self, key): + raise NotImplementedError( + "The `_get` method must be implemented by subclasses of BaseDB" + ) + + def set(self, key, value): + raise NotImplementedError( + "The `_set` method must be implemented by subclasses of BaseDB" + ) + + def exists(self, key): + raise NotImplementedError( + "The `_exists` method must be implemented by subclasses of BaseDB" + ) + + def delete(self, key): + raise NotImplementedError( + "The `_delete` method must be implemented by subclasses of BaseDB" + ) + + # + # Dictionary API + # + def __getitem__(self, key): + return self.get(key) + + def __setitem__(self, key, value): + return self.set(key, value) + + def __delitem__(self, key): + return self.delete(key) + + def __contains__(self, key): + return self.exists(key) diff --git a/trie/db/memory.py b/trie/db/memory.py new file mode 100644 index 00000000..945543ce --- /dev/null +++ b/trie/db/memory.py @@ -0,0 +1,22 @@ +from .base import ( + BaseDB, +) + + +class MemoryDB(BaseDB): + kv_store = None + + def __init__(self): + self.kv_store = {} + + def get(self, key): + return self.kv_store[key] + + def set(self, key, value): + self.kv_store[key] = value + + def exists(self, key): + return key in self.kv_store + + def delete(self, key): + del self.kv_store[key] diff --git a/trie/exceptions.py b/trie/exceptions.py new file mode 100644 index 00000000..af5652f6 --- /dev/null +++ b/trie/exceptions.py @@ -0,0 +1,10 @@ +class InvalidNibbles(Exception): + pass + + +class 
InvalidNode(Exception): + pass + + +class ValidationError(Exception): + pass diff --git a/trie/trie.py b/trie/trie.py new file mode 100644 index 00000000..3a1b56c8 --- /dev/null +++ b/trie/trie.py @@ -0,0 +1,377 @@ +from __future__ import absolute_import + +import itertools + +import rlp + +from trie.constants import ( + BLANK_NODE, + BLANK_NODE_HASH, + NODE_TYPE_BLANK, + NODE_TYPE_LEAF, + NODE_TYPE_EXTENSION, + NODE_TYPE_BRANCH, +) +from trie.exceptions import ( + InvalidNode, +) +from trie.validation import ( + validate_is_node, + validate_is_bytes, +) + +from trie.utils.sha3 import ( + keccak, +) +from trie.utils.nibbles import ( + bytes_to_nibbles, + decode_nibbles, + encode_nibbles, +) +from trie.utils.nodes import ( + get_node_type, + extract_key, + compute_leaf_key, + compute_extension_key, + is_extension_node, + is_leaf_node, + is_blank_node, + consume_common_prefix, + key_starts_with, +) + + +class Trie(object): + db = None + root_hash = None + + def __init__(self, db, root_hash=BLANK_NODE_HASH): + self.db = db + validate_is_bytes(root_hash) + self.root_hash = root_hash + + def get(self, key): + validate_is_bytes(key) + + trie_key = bytes_to_nibbles(key) + root_node = self._get_node(self.root_hash) + + return self._get(root_node, trie_key) + + def _get(self, node, trie_key): + node_type = get_node_type(node) + + if node_type == NODE_TYPE_BLANK: + return BLANK_NODE + elif node_type in {NODE_TYPE_LEAF, NODE_TYPE_EXTENSION}: + return self._get_kv_node(node, trie_key) + elif node_type == NODE_TYPE_BRANCH: + return self._get_branch_node(node, trie_key) + else: + raise Exception("Invariant: This shouldn't ever happen") + + def set(self, key, value): + validate_is_bytes(key) + validate_is_bytes(value) + + trie_key = bytes_to_nibbles(key) + root_node = self._get_node(self.root_hash) + + new_node = self._set(root_node, trie_key, value) + self._set_root_node(new_node) + + def _set(self, node, trie_key, value): + node_type = get_node_type(node) + + if node_type == 
NODE_TYPE_BLANK:
+            return [
+                compute_leaf_key(trie_key),
+                value,
+            ]
+        elif node_type in {NODE_TYPE_LEAF, NODE_TYPE_EXTENSION}:
+            return self._set_kv_node(node, trie_key, value)
+        elif node_type == NODE_TYPE_BRANCH:
+            return self._set_branch_node(node, trie_key, value)
+        else:
+            raise Exception("Invariant: This shouldn't ever happen")
+
+    def exists(self, key):
+        validate_is_bytes(key)
+
+        return self.get(key) != BLANK_NODE
+
+    def delete(self, key):
+        validate_is_bytes(key)
+
+        trie_key = bytes_to_nibbles(key)
+        root_node = self._get_node(self.root_hash)
+
+        new_node = self._delete(root_node, trie_key)
+        self._set_root_node(new_node)
+
+    def _delete(self, node, trie_key):
+        node_type = get_node_type(node)
+
+        if node_type == NODE_TYPE_BLANK:
+            return BLANK_NODE
+        elif node_type in {NODE_TYPE_LEAF, NODE_TYPE_EXTENSION}:
+            return self._delete_kv_node(node, trie_key)
+        elif node_type == NODE_TYPE_BRANCH:
+            return self._delete_branch_node(node, trie_key)
+        else:
+            raise Exception("Invariant: This shouldn't ever happen")
+
+    #
+    # Convenience
+    #
+    @property
+    def root_node(self):
+        return self._get_node(self.root_hash)
+
+    @root_node.setter
+    def root_node(self, value):
+        self._set_root_node(value)
+
+    #
+    # Utils
+    #
+    def _set_root_node(self, root_node):
+        validate_is_node(root_node)
+        encoded_root_node = rlp.encode(root_node)
+        self.root_hash = keccak(encoded_root_node)
+        self.db[self.root_hash] = encoded_root_node
+
+    def _get_node(self, node_hash):
+        if node_hash == BLANK_NODE:
+            return BLANK_NODE
+        elif node_hash == BLANK_NODE_HASH:
+            return BLANK_NODE
+
+        if len(node_hash) < 32:
+            encoded_node = node_hash
+        else:
+            encoded_node = self.db.get(node_hash)
+        node = self._decode_node(encoded_node)
+
+        return node
+
+    def _persist_node(self, node):
+        validate_is_node(node)
+        if is_blank_node(node):
+            return BLANK_NODE
+        encoded_node = rlp.encode(node)
+        if len(encoded_node) < 32:
+            return node
+
+        encoded_node_hash = keccak(encoded_node)
self.db[encoded_node_hash] = encoded_node + return encoded_node_hash + + def _decode_node(self, encoded_node_or_hash): + if encoded_node_or_hash == BLANK_NODE: + return BLANK_NODE + elif isinstance(encoded_node_or_hash, list): + return encoded_node_or_hash + else: + return rlp.decode(encoded_node_or_hash) + + # + # Node Operation Helpers + def _normalize_branch_node(self, node): + """ + A branch node which is left with only a single non-blank item should be + turned into either a leaf or extension node. + """ + iter_node = iter(node) + if any(iter_node) and any(iter_node): + return node + + if node[16]: + return [compute_leaf_key([]), node[16]] + + sub_node_idx, sub_node_hash = next( + (idx, v) + for idx, v + in enumerate(node[:16]) + if v + ) + sub_node = self._get_node(sub_node_hash) + sub_node_type = get_node_type(sub_node) + + if sub_node_type in {NODE_TYPE_LEAF, NODE_TYPE_EXTENSION}: + new_subnode_key = encode_nibbles(tuple(itertools.chain( + [sub_node_idx], + decode_nibbles(sub_node[0]), + ))) + return [new_subnode_key, sub_node[1]] + elif sub_node_type == NODE_TYPE_BRANCH: + subnode_hash = self._persist_node(sub_node) + return [encode_nibbles([sub_node_idx]), subnode_hash] + else: + raise Exception("Invariant: this code block should be unreachable") + + # + # Node Operations + # + def _delete_branch_node(self, node, trie_key): + if not trie_key: + node[-1] = BLANK_NODE + return self._normalize_branch_node(node) + + node_to_delete = self._get_node(node[trie_key[0]]) + + sub_node = self._delete(node_to_delete, trie_key[1:]) + encoded_sub_node = self._persist_node(sub_node) + + if encoded_sub_node == node[trie_key[0]]: + return node + + node[trie_key[0]] = encoded_sub_node + if encoded_sub_node == BLANK_NODE: + return self._normalize_branch_node(node) + + return node + + def _delete_kv_node(self, node, trie_key): + current_key = extract_key(node) + + if not key_starts_with(trie_key, current_key): + # key not present?.... 
+ return node + + node_type = get_node_type(node) + + if node_type == NODE_TYPE_LEAF: + if trie_key == current_key: + return BLANK_NODE + else: + return node + + sub_node_key = trie_key[len(current_key):] + sub_node = self._get_node(node[1]) + + new_sub_node = self._delete(sub_node, sub_node_key) + encoded_new_sub_node = self._persist_node(new_sub_node) + + if encoded_new_sub_node == node[1]: + return node + + if new_sub_node == BLANK_NODE: + return BLANK_NODE + + new_sub_node_type = get_node_type(new_sub_node) + if new_sub_node_type in {NODE_TYPE_LEAF, NODE_TYPE_EXTENSION}: + new_key = current_key + decode_nibbles(new_sub_node[0]) + return [encode_nibbles(new_key), new_sub_node[1]] + + if new_sub_node_type == NODE_TYPE_BRANCH: + return [encode_nibbles(current_key), encoded_new_sub_node] + + raise Exception("Invariant, this code path should not be reachable") + + def _set_branch_node(self, node, trie_key, value): + if trie_key: + sub_node = self._get_node(node[trie_key[0]]) + + new_node = self._set(sub_node, trie_key[1:], value) + node[trie_key[0]] = self._persist_node(new_node) + else: + node[-1] = value + return node + + def _set_kv_node(self, node, trie_key, value): + current_key = extract_key(node) + common_prefix, current_key_remainder, trie_key_remainder = consume_common_prefix( + current_key, + trie_key, + ) + is_extension = is_extension_node(node) + + if not current_key_remainder and not trie_key_remainder: + if is_leaf_node(node): + return [node[0], value] + else: + sub_node = self._get_node(node[1]) + # TODO: this needs to cleanup old storage. + new_node = self._set(sub_node, trie_key_remainder, value) + elif not current_key_remainder: + if is_extension: + sub_node = self._get_node(node[1]) + # TODO: this needs to cleanup old storage. 
+ new_node = self._set(sub_node, trie_key_remainder, value) + else: + subnode_position = trie_key_remainder[0] + subnode_key = compute_leaf_key(trie_key_remainder[1:]) + sub_node = [subnode_key, value] + + new_node = [BLANK_NODE] * 16 + [node[1]] + new_node[subnode_position] = self._persist_node(sub_node) + else: + new_node = [BLANK_NODE] * 17 + + if len(current_key_remainder) == 1 and is_extension: + new_node[current_key_remainder[0]] = node[1] + else: + if is_extension: + compute_key_fn = compute_extension_key + else: + compute_key_fn = compute_leaf_key + + new_node[current_key_remainder[0]] = self._persist_node([ + compute_key_fn(current_key_remainder[1:]), + node[1], + ]) + + if trie_key_remainder: + new_node[trie_key_remainder[0]] = self._persist_node([ + compute_leaf_key(trie_key_remainder[1:]), + value, + ]) + else: + new_node[-1] = value + + if common_prefix: + new_node_key = self._persist_node(new_node) + return [compute_extension_key(common_prefix), new_node_key] + else: + return new_node + + def _get_branch_node(self, node, trie_key): + if not trie_key: + return node[16] + else: + sub_node = self._get_node(node[trie_key[0]]) + return self._get(sub_node, trie_key[1:]) + + def _get_kv_node(self, node, trie_key): + current_key = extract_key(node) + node_type = get_node_type(node) + + if node_type == NODE_TYPE_LEAF: + if trie_key == current_key: + return node[1] + else: + return BLANK_NODE + elif node_type == NODE_TYPE_EXTENSION: + if key_starts_with(trie_key, current_key): + sub_node = self._get_node(node[1]) + return self._get(sub_node, trie_key[len(current_key):]) + else: + return BLANK_NODE + else: + raise Exception("Invariant: unreachable code path") + + # + # Dictionary API + # + def __getitem__(self, key): + return self.get(key) + + def __setitem__(self, key, value): + return self.set(key, value) + + def __delitem__(self, key): + return self.delete(key) + + def __contains__(self, key): + return self.exists(key) diff --git a/trie/utils/__init__.py 
b/trie/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/trie/utils/nibbles.py b/trie/utils/nibbles.py new file mode 100644 index 00000000..df23dd96 --- /dev/null +++ b/trie/utils/nibbles.py @@ -0,0 +1,121 @@ +import itertools + +from eth_utils import ( + encode_hex, + remove_0x_prefix, + to_tuple, +) + +from trie.constants import ( + NIBBLES_LOOKUP, + NIBBLE_TERMINATOR, + HP_FLAG_2, + HP_FLAG_0, +) +from trie.exceptions import ( + InvalidNibbles, +) + + +@to_tuple +def bytes_to_nibbles(value): + """ + Convert a byte string to nibbles + """ + for nibble in remove_0x_prefix(encode_hex(value)): + yield NIBBLES_LOOKUP[nibble] + + +@to_tuple +def pairwise(iterable): + if len(iterable) % 2: + raise ValueError("Odd length value. Cannot apply pairwise operation") + + for left, right in zip(*[iter(iterable)] * 2): + yield left, right + + +def nibbles_to_bytes(nibbles): + if any(nibble > 15 or nibble < 0 for nibble in nibbles): + raise InvalidNibbles( + "Nibbles contained invalid value. 
Must be constrained between [0, 15]" + ) + + if len(nibbles) % 2: + raise InvalidNibbles("Nibbles must be even in length") + + value = bytes(bytearray(tuple( + 16 * left + right + for left, right in pairwise(nibbles) + ))) + return value + + +def is_nibbles_terminated(nibbles): + return nibbles and nibbles[-1] == NIBBLE_TERMINATOR + + +@to_tuple +def add_nibbles_terminator(nibbles): + if is_nibbles_terminated(nibbles): + return nibbles + return itertools.chain(nibbles, (NIBBLE_TERMINATOR,)) + + +@to_tuple +def remove_nibbles_terminator(nibbles): + if is_nibbles_terminated(nibbles): + return nibbles[:-1] + return nibbles + + +def encode_nibbles(nibbles): + """ + The Hex Prefix function + """ + if is_nibbles_terminated(nibbles): + flag = HP_FLAG_2 + else: + flag = HP_FLAG_0 + + raw_nibbles = remove_nibbles_terminator(nibbles) + + is_odd = len(raw_nibbles) % 2 + + if is_odd: + flagged_nibbles = tuple(itertools.chain( + (flag + 1,), + raw_nibbles, + )) + else: + flagged_nibbles = tuple(itertools.chain( + (flag, 0), + raw_nibbles, + )) + + prefixed_value = nibbles_to_bytes(flagged_nibbles) + + return prefixed_value + + +def decode_nibbles(value): + """ + The inverse of the Hex Prefix function + """ + nibbles_with_flag = bytes_to_nibbles(value) + flag = nibbles_with_flag[0] + + needs_terminator = flag in {HP_FLAG_2, HP_FLAG_2 + 1} + is_odd_length = flag in {HP_FLAG_0 + 1, HP_FLAG_2 + 1} + + if is_odd_length: + raw_nibbles = nibbles_with_flag[1:] + else: + raw_nibbles = nibbles_with_flag[2:] + + if needs_terminator: + nibbles = add_nibbles_terminator(raw_nibbles) + else: + nibbles = raw_nibbles + + return nibbles diff --git a/trie/utils/nodes.py b/trie/utils/nodes.py new file mode 100644 index 00000000..c4eefba8 --- /dev/null +++ b/trie/utils/nodes.py @@ -0,0 +1,89 @@ +import rlp + +from trie.constants import ( + NODE_TYPE_BLANK, + NODE_TYPE_LEAF, + NODE_TYPE_EXTENSION, + NODE_TYPE_BRANCH, + BLANK_NODE, +) +from .nibbles import ( + decode_nibbles, + encode_nibbles, + 
is_nibbles_terminated,
+    add_nibbles_terminator,
+    remove_nibbles_terminator,
+)
+from trie.exceptions import (
+    InvalidNode,
+)
+
+
+def get_node_type(node):
+    if node == BLANK_NODE:
+        return NODE_TYPE_BLANK
+    elif len(node) == 2:
+        key, _ = node
+        nibbles = decode_nibbles(key)
+        if is_nibbles_terminated(nibbles):
+            return NODE_TYPE_LEAF
+        else:
+            return NODE_TYPE_EXTENSION
+    elif len(node) == 17:
+        return NODE_TYPE_BRANCH
+    else:
+        raise InvalidNode("Unable to determine node type")
+
+
+def is_blank_node(node):
+    return node == BLANK_NODE
+
+
+def is_leaf_node(node):
+    if len(node) != 2:
+        return False
+    key, _ = node
+    nibbles = decode_nibbles(key)
+    return is_nibbles_terminated(nibbles)
+
+
+def is_extension_node(node):
+    if len(node) != 2:
+        return False
+    key, _ = node
+    nibbles = decode_nibbles(key)
+    return not is_nibbles_terminated(nibbles)
+
+
+def is_branch_node(node):
+    return len(node) == 17
+
+
+def extract_key(node):
+    prefixed_key, _ = node
+    key = remove_nibbles_terminator(decode_nibbles(prefixed_key))
+    return key
+
+
+def compute_leaf_key(nibbles):
+    return encode_nibbles(add_nibbles_terminator(nibbles))
+
+
+def compute_extension_key(nibbles):
+    return encode_nibbles(nibbles)
+
+
+def get_common_prefix_length(left_key, right_key):
+    for idx, (left_nibble, right_nibble) in enumerate(zip(left_key, right_key)):
+        if left_nibble != right_nibble:
+            return idx
+    return min(len(left_key), len(right_key))
+
+
+def consume_common_prefix(left_key, right_key):
+    common_prefix_length = get_common_prefix_length(left_key, right_key)
+    common_prefix = left_key[:common_prefix_length]
+    left_remainder = left_key[common_prefix_length:]
+    right_remainder = right_key[common_prefix_length:]
+    return common_prefix, left_remainder, right_remainder
+
+
+def key_starts_with(full_key, partial_key):
+    if len(partial_key) > len(full_key):
+        return False
+    return all(left == right for left, right in zip(full_key, partial_key))
diff --git a/trie/utils/sha3.py b/trie/utils/sha3.py
new file mode 100644
index 00000000..f25fa25f
--- /dev/null
+++ b/trie/utils/sha3.py
@@ -0,0
+1,10 @@
+from __future__ import absolute_import
+
+from sha3 import keccak_256
+
+
+def keccak(value):
+    return keccak_256(value).digest()
+
+
+assert keccak(b'') == b"\xc5\xd2F\x01\x86\xf7#<\x92~}\xb2\xdc\xc7\x03\xc0\xe5\x00\xb6S\xca\x82';{\xfa\xd8\x04]\x85\xa4p", "Incorrect sha3. Make sure it's keccak"  # noqa
diff --git a/trie/validation.py b/trie/validation.py
new file mode 100644
index 00000000..0b0a6150
--- /dev/null
+++ b/trie/validation.py
@@ -0,0 +1,49 @@
+from __future__ import absolute_import
+
+from trie.constants import (
+    BLANK_NODE,
+)
+from trie.exceptions import (
+    ValidationError,
+)
+
+
+def validate_is_bytes(value):
+    if not isinstance(value, bytes):
+        raise ValidationError("Value is not of type `bytes`: got '{0}'".format(type(value)))
+
+
+def validate_length(value, length):
+    if len(value) != length:
+        raise ValidationError("Value is of length {0}. Must be {1}".format(len(value), length))
+
+
+def validate_is_node(node):
+    if node == BLANK_NODE:
+        return
+    elif len(node) == 2:
+        key, value = node
+        validate_is_bytes(key)
+        if isinstance(value, list):
+            validate_is_node(value)
+        else:
+            validate_is_bytes(value)
+    elif len(node) == 17:
+        validate_is_bytes(node[16])
+        for sub_node in node[:16]:
+            if sub_node == BLANK_NODE:
+                continue
+            elif isinstance(sub_node, list):
+                if len(sub_node) != 2:
+                    raise ValidationError("Invalid branch subnode: {0}".format(sub_node))
+                sub_node_key, sub_node_value = sub_node
+                validate_is_bytes(sub_node_key)
+                if isinstance(sub_node_value, list):
+                    validate_is_node(sub_node_value)
+                else:
+                    validate_is_bytes(sub_node_value)
+            else:
+                validate_is_bytes(sub_node)
+                validate_length(sub_node, 32)
+    else:
+        raise ValidationError("Invalid Node: {0}".format(node))