# What are merkle trees and how could I use them?

In this notebook I look into how merkle trees work and how I might use them.

Let me first create some batches that I want to hash.

In [30]:
import hashlib
from typing import List, Tuple, Callable

In [31]:
def sha256(data: bytes) -> bytes:
    """Default SHA-256 hash function."""
    return hashlib.sha256(data).digest()

In [32]:
llmBatch = [
{
id: 1,
"prompt": "Analyze sentiment of user feedback",
"model": "gpt-4-turbo",
"recipient": "0xUser1Address..."
},
{
id: 2,
"prompt": "Translate user instructions to French",
"model": "gpt-4-turbo",
"recipient": "0xUser2Address..."
},
{
id: 3,
"prompt": "Summarize the latest research on AI",
"model": "gpt-4-turbo",
"recipient": "0xUser3Address..."
},
{
id: 4,
"prompt": "Generate a creative story based on user input",
"model": "gpt-4-turbo",
"recipient": "0xUser4Address..."
}
]

Now that I have my batches, I can create a merkle tree. In the merkle tree we first encode each of the batches into a byte string, then we hash each of these strings to create the leaf nodes of the tree. The hash once again producing a byte string.

In [33]:
leaf_hashes = [sha256(str(batch).encode('utf-8')) for batch in llmBatch]
print("Leaf hashes:", [h.hex() for h in leaf_hashes])

Leaf hashes: ['e4ac014889c5a7b061b729b59e1f4f6bab840b9766810eea46478091e3d3ff70', '7d29b8b8bf00bbb427723f52923700e931b9dc25d8afc726dba7c59da6b645bd', 'edfe39aab2e1245499dda9c0609c7c7586795d27319f2683f5a1ec92277d74d4', '4f8df7b3e932b626c144ad7a30331c04a9874f1d1a48c20281c59d48a521cd3e']


Now that I have the four leafs I can create the parent nodes. The parent node is the hash of the binary sum of the two child nodes of its two child nodes.

Finally, I can create the root node, which is the hash of the binary sum of its two child nodes.

In [34]:
# create parent nodes
parent_nodes = []
for i in range(0, len(leaf_hashes), 2):
    parent_node = sha256(leaf_hashes[i] + leaf_hashes[i + 1])
    parent_nodes.append(parent_node)
print("Parent nodes:", [h.hex() for h in parent_nodes])

root_node = sha256(parent_nodes[0] + parent_nodes[1])
print("Root node:", root_node.hex())

Parent nodes: ['b17f891f734c9cf93a52bc49ab761c00515d576ed21f2142cbcd2359b67a0e55', 'b6faa6fe3f173b78bdd48633b14fa8ae7af47dfa2e9e84299535bae65a904f69']
Root node: 8a5f8d59a67a78a0979119eda1e2c326cf18df3c7b42235cc74e4e67f56b931d


we will later most likely work with `merkletreejs` in javascript. So let us write a simple wrapper with a similiar API.

In [35]:



class MerkleTree:
    def __init__(self, hash_func: Callable[[bytes], bytes] = sha256):
        """
        Initialize a Merkle tree with an optional hash function.

        Args:
            hash_func: A function that takes bytes and returns a hash digest as bytes.
        """
        self.hash_func = hash_func
        self.leaves: List[bytes] = []
        self.levels: List[List[bytes]] = []

    def add_leaf(self, data: str) -> None:
        """
        Add a leaf node to the tree.

        Args:
            data: Input string to be hashed and added as a leaf.
        """
        self.leaves.append(self.hash_func(data.encode()))

    def make_tree(self) -> None:
        """
        Construct the Merkle tree from the current leaves.
        """
        if not self.leaves:
            raise ValueError("No leaves to build the tree.")
        self.levels = [self.leaves.copy()]
        current = self.leaves
        while len(current) > 1:
            next_level: List[bytes] = []
            for i in range(0, len(current), 2):
                left = current[i]
                right = current[i + 1] if i + 1 < len(current) else left
                combined = self.hash_func(left + right)
                next_level.append(combined)
            self.levels.append(next_level)
            current = next_level

    def get_root(self) -> str:
        """
        Get the Merkle root as a hex string.

        Returns:
            Hex-encoded Merkle root.

        Raises:
            ValueError: If the tree has not been built yet.
        """
        if not self.levels:
            raise ValueError("Tree not built yet. Call make_tree() first.")
        return self.levels[-1][0].hex()

    def get_proof(self, index: int) -> List[Tuple[str, str]]:
        """
        Generate a Merkle proof for a leaf at the given index.

        Args:
            index: Index of the target leaf in the original list.

        Returns:
            A list of tuples (position, sibling_hash) where position is 'left' or 'right'.

        Raises:
            ValueError: If the tree has not been built.
        """
        if not self.levels:
            raise ValueError("Tree not built yet. Call make_tree() first.")
        proof: List[Tuple[str, str]] = []
        for level in self.levels[:-1]:
            sibling_index = index ^ 1
            if sibling_index < len(level):
                sibling_pos = "left" if sibling_index < index else "right"
                proof.append((sibling_pos, level[sibling_index].hex()))
            index //= 2
        return proof

    def verify_proof(
        self, proof: List[Tuple[str, str]], target_leaf: str, root: str
    ) -> bool:
        """
        Verify a Merkle proof against the given root.

        Args:
            proof: A list of (position, sibling_hash) tuples.
            target_leaf: Original input string for the leaf.
            root: Expected Merkle root (hex string).

        Returns:
            True if the proof is valid, False otherwise.
        """
        current_hash = self.hash_func(target_leaf.encode())
        for position, sibling_hash_hex in proof:
            sibling_hash = bytes.fromhex(sibling_hash_hex)
            if position == "left":
                current_hash = self.hash_func(sibling_hash + current_hash)
            else:
                current_hash = self.hash_func(current_hash + sibling_hash)
        return current_hash.hex() == root
    
    def print_tree(self, short: bool = True) -> None:
        """
        Print the Merkle tree level by level from root to leaves.

        Args:
            short: If True, abbreviates hashes to the first 8 chars.
        """
        if not self.levels:
            raise ValueError("Tree not built yet. Call make_tree() first.")
        
        def fmt(h: bytes) -> str:
            return h.hex()[:8] if short else h.hex()

        print("\nMerkle Tree:")
        total_levels = len(self.levels)
        for i, level in enumerate(reversed(self.levels)):
            level_index = total_levels - i - 1
            label = "Root" if i == 0 else ("Leaf" if i == total_levels - 1 else f"Level {level_index}")
            hashes = " | ".join(fmt(h) for h in level)
            print(f"{label}: {hashes}")


with this API we can once again try to create the tree.

In [36]:
# create a tree from the batch
tree = MerkleTree()
for batch in llmBatch:
    tree.add_leaf(str(batch))
tree.make_tree()
print("Merkle root:", tree.get_root())
proof = tree.get_proof(0)
print("Merkle proof for first leaf:", proof)
print("Proof verification:", tree.verify_proof(proof, str(llmBatch[0]), tree.get_root()))

Merkle root: 8a5f8d59a67a78a0979119eda1e2c326cf18df3c7b42235cc74e4e67f56b931d
Merkle proof for first leaf: [('right', '7d29b8b8bf00bbb427723f52923700e931b9dc25d8afc726dba7c59da6b645bd'), ('right', 'b6faa6fe3f173b78bdd48633b14fa8ae7af47dfa2e9e84299535bae65a904f69')]
Proof verification: True


In [37]:
tree.print_tree(short=True)


Merkle Tree:
Root: 8a5f8d59
Level 1: b17f891f | b6faa6fe
Leaf: e4ac0148 | 7d29b8b8 | edfe39aa | 4f8df7b3
