In [1]:
from typing import List, Tuple
import math
from enum import Enum

In [2]:
def is_power_of_2(number):
    return math.log2(number).is_integer()

class Side(Enum):
    LEFT=0,
    RIGHT=1,

class MerkleTree:
    def __init__(self, values:List[str]):
        self.raw_leafs = values
        self.leafs = [self.calculate_hash(i) for i in self.raw_leafs]
        self.build_tree()

    def __repr__(self):
        return str(self.raw_leafs)

    def calculate_hash(self,a:str,b:str='') -> str:
        return a+b

    def build_tree(self):
        while not is_power_of_2(len(self.leafs)):
            # add "" until it becomes a full tree
            self.leafs.append("")

        self.tree = self.leafs[::] # deep copy
        nodes_to_process = self.tree[::]
        while len(nodes_to_process) > 1:
            #print ('nodes_to_process',nodes_to_process)
            left = nodes_to_process.pop(0)
            right = nodes_to_process.pop(0)
            new_hash = self.calculate_hash(left,right)
            self.tree.append(new_hash)
            nodes_to_process.append(new_hash)

    def get_proof(self, index):
        root_index = self.get_root_index()
        exit(1)
        target = index
        proof = []
        while target != root_index:
            symmetric = self.get_symmetric(target)
            side = Side.LEFT if symmetric < target else Side.RIGHT
            target = self.get_parent_index(target)
            # move to parent
            proof.append((symmetric,side))
            
            #print ('moving to target ', target)
        return proof

    def get_root_index(self):
        n = len(self.leafs)
        num_elements = 0
        n_levels = int(math.log2(n) + 1)
        for i in range(n_levels):
            new_elements = (n // 2**(i))
            #print ('new ', new_elements)
            num_elements += new_elements
        return num_elements - 1 #0-based indexing
        
    def verify(self, index: int, proof:List[Tuple[int, Side]]) -> bool:
        hash_to_verify = self.tree[index]
        while proof:
            (next_index, side) = proof.pop(0)
            print ('applying ', (self.tree[next_index],side))
            if side == Side.RIGHT:
                hash_to_verify = self.calculate_hash(hash_to_verify, self.tree[next_index])
            else:
                hash_to_verify = self.calculate_hash(self.tree[next_index], hash_to_verify)     
            print ('final hash ', hash_to_verify)
        return hash_to_verify == self.tree[-1]
    
    def get_symmetric(self, index):
        # we get the other child of the parent
        if index % 2 == 0:
            return index + 1
        return index - 1
    
    def get_parent_index(self, index):
        if not self.tree:
            self.build_tree()

        # go to right node instead of left node
        if index % 2 == 0:
            index += 1

        delta = len(self.leafs) - 1 - (index//2)
        return index + delta
        

In [3]:
#mt = MerkleTree(['a','b','c','d','e','f','g','h','i'])
mt = MerkleTree(['a','b','c','d','e','f','g','h'])
print(mt.tree)
print(mt.leafs)

['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'ab', 'cd', 'ef', 'gh', 'abcd', 'efgh', 'abcdefgh']
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']


In [4]:
mt.get_parent_index(12)

14

In [5]:
mt.get_root_index()

14

In [6]:
mt.get_proof(0)

[(1, <Side.RIGHT: (1,)>), (9, <Side.RIGHT: (1,)>), (13, <Side.RIGHT: (1,)>)]

In [8]:
index = 9
proof = mt.get_proof(index)
print ('proof ',proof)
mt.verify(index, proof)

proof  [(8, <Side.LEFT: (0,)>), (13, <Side.RIGHT: (1,)>)]
applying  ('ab', <Side.LEFT: (0,)>)
final hash  abcd
applying  ('efgh', <Side.RIGHT: (1,)>)
final hash  abcdefgh


True