# What are Merkle trees and how could I use them?

In this notebook I look into how Merkle trees work and how I might use them.

Let me first create a batch of items that I want to hash.

In [None]:
import hashlib
import json
from typing import List, Tuple, Callable
from pprint import pprint

from pydantic import BaseModel
from datetime import datetime, timezone
from typing import List

In [172]:
class LlmBatchItem(BaseModel):
    """One record in an LLM batch; each record is serialized and hashed into one Merkle leaf.

    The declaration order of the fields below is the key order of
    ``model_dump_json()`` (used by ``stable_serialize``), so reordering or
    renaming fields changes every leaf hash and therefore the Merkle root.
    """
    id: int
    # ISO strings such as "2023-09-01T12:00:00Z" are parsed by pydantic into a
    # timezone-aware datetime (the displayed batch shows tzinfo=TzInfo(UTC)).
    timestamp: datetime
    tokenCount: int
    wallet: str

def stable_serialize(batch: LlmBatchItem) -> str:
    """Serialize a batch item to the exact JSON string that is hashed into its leaf.

    The string comes from pydantic's ``model_dump_json()``: compact JSON whose
    keys follow the model's field-declaration order (they are NOT sorted), e.g.
    ``{"id":1,"timestamp":"2023-09-01T12:00:00Z","tokenCount":1024,"wallet":"0xUser1Address"}``.
    For a fixed model definition the output is reproducible. Any other
    implementation (for example a JavaScript verifier) must reproduce this
    string byte-for-byte to obtain the same leaf hash.

    Args:
        batch: The item to serialize.

    Returns:
        The compact JSON representation of ``batch``.
    """
    return batch.model_dump_json()


def sha256(data: bytes) -> bytes:
    """Return the raw 32-byte SHA-256 digest of ``data``.

    This is the default hash for both leaves and interior nodes of the tree.
    """
    hasher = hashlib.new("sha256")
    hasher.update(data)
    return hasher.digest()

In [173]:
# Four example items; wallet 0xUser1Address owns two of them, which is used
# further down when collecting all proofs for one wallet.
llmBatch = [
    LlmBatchItem(id=item_id, timestamp=stamp, tokenCount=tokens, wallet=owner)
    for item_id, stamp, tokens, owner in (
        (1, "2023-09-01T12:00:00Z", 1024, "0xUser1Address"),
        (2, "2023-09-02T12:00:00Z", 1023, "0xUser2Address"),
        (3, "2023-09-03T12:00:00Z", 1022, "0xUser1Address"),
        (4, "2023-09-04T12:00:00Z", 1021, "0xUser4Address"),
    )
]

Now that I have my batch, I can create a Merkle tree. First we serialize each item into a string and encode it as bytes, then we hash each of these byte strings to create the leaf nodes of the tree. Each hash is again a byte string (a 32-byte SHA-256 digest).

In [174]:
# Leaf i is the SHA-256 digest of the UTF-8 bytes of item i's serialized JSON.
leaf_hashes = [
    sha256(serialized.encode("utf-8"))
    for serialized in map(stable_serialize, llmBatch)
]
print("Leaf hashes:", [digest.hex() for digest in leaf_hashes])

Leaf hashes: ['5bc1a99e7e8903c3368d84e5812ef2406caaea6ff9f4086a7da9f24d20e6ab5a', 'f2cb983c8269e2cfe0d304644ab728d498cbae86fbc5d4c9d3080e85b6065872', 'ca8dcdb03db2b2808cf38d4ba0f71eddb71e214d65b107838241afc7a2edd388', '66f53003874c6aea844469e7a78e5f259e7c8d2fdc7477089a12ec26d44d8de6']


Now that I have the four leaves, I can create the parent nodes. Each parent node is the hash of the concatenation of its two child hashes (left child's bytes first, then the right child's bytes).

Finally, I can create the root node, which is the hash of the concatenation of its two child nodes.

In [175]:
def _hash_pairs(nodes: List[bytes]) -> List[bytes]:
    """Hash adjacent pairs of ``nodes`` (left + right) into the next level up.

    If ``nodes`` has an odd length, the last node is paired with itself. This is
    the same rule MerkleTree.make_tree uses.
    """
    return [
        sha256(nodes[i] + (nodes[i + 1] if i + 1 < len(nodes) else nodes[i]))
        for i in range(0, len(nodes), 2)
    ]


if not leaf_hashes:
    raise ValueError("Need at least one leaf hash to build a tree.")

# create parent nodes (a single leaf is its own root, so it has no parent level)
parent_nodes = _hash_pairs(leaf_hashes) if len(leaf_hashes) > 1 else []
print("Parent nodes:", [h.hex() for h in parent_nodes])

# Keep combining levels until one node remains. With four leaves this is
# exactly sha256(parent_nodes[0] + parent_nodes[1]).
level = parent_nodes or leaf_hashes
while len(level) > 1:
    level = _hash_pairs(level)
root_node = level[0]
print("Root node:", root_node.hex())

Parent nodes: ['e867d7ea3b37ae8faf4387e507106a54837485a24990ce3aed26ea13488f21fe', '1cb983666708c3054522879fcb3416efaf451eed80f61795c1ee42c5f4f6f116']
Root node: afd62e33c636dbffa7a7ffb0553090225f1d206a6a535f99d4879dacfa338bff


We will most likely work with `merkletreejs` in JavaScript later, so let us write a simple wrapper class with a similar API.

In [176]:
class MerkleTree:
    """Minimal binary Merkle tree over string leaves.

    Each leaf is ``hash_func(data.encode())`` of a string passed to
    ``add_leaf``. Each interior node is ``hash_func(left + right)`` over the raw
    digest bytes. When a level has an odd number of nodes, the last node is
    paired with itself.

    NOTE(review): because of that duplication rule, leaf lists such as
    [a, b, c] and [a, b, c, c] produce the same root. Leaves and interior
    nodes are also hashed without any distinguishing prefix. Keep this in mind
    if roots are used as on-chain commitments. Also confirm that the JavaScript
    side (e.g. merkletreejs and its odd-node options) applies the same rules,
    otherwise the roots will not match.
    """

    def __init__(self, hash_func: Callable[[bytes], bytes] = sha256):
        """
        Initialize an empty Merkle tree with an optional hash function.

        Args:
            hash_func: A function that takes bytes and returns a hash digest as bytes.
        """
        self.hash_func = hash_func
        # Leaf hashes in insertion order.
        self.leaves: List[bytes] = []
        # levels[0] is a copy of the leaves, levels[-1] is [root]; filled by make_tree().
        self.levels: List[List[bytes]] = []

    def add_leaf(self, data: str) -> None:
        """
        Hash a string (UTF-8 encoded) and append it as a leaf.

        The tree is not rebuilt automatically; call make_tree() afterwards.

        Args:
            data: Input string to be hashed and added as a leaf.
        """
        self.leaves.append(self.hash_func(data.encode()))

    def make_tree(self) -> None:
        """
        Construct all levels of the Merkle tree from the current leaves.

        Raises:
            ValueError: If no leaves have been added.
        """
        if not self.leaves:
            raise ValueError("No leaves to build the tree.")
        self.levels = [self.leaves.copy()]
        current = self.leaves
        while len(current) > 1:
            next_level: List[bytes] = []
            for i in range(0, len(current), 2):
                left = current[i]
                # Odd-length level: the last node is hashed together with itself.
                right = current[i + 1] if i + 1 < len(current) else left
                combined = self.hash_func(left + right)
                next_level.append(combined)
            self.levels.append(next_level)
            current = next_level

    def get_root(self) -> str:
        """
        Get the Merkle root as a hex string.

        Returns:
            Hex-encoded (lowercase, no "0x" prefix) Merkle root.

        Raises:
            ValueError: If the tree has not been built yet.
        """
        if not self.levels:
            raise ValueError("Tree not built yet. Call make_tree() first.")
        return self.levels[-1][0].hex()

    def get_proof(self, index: int) -> List[Tuple[str, str]]:
        """
        Generate a Merkle proof for the leaf at the given index.

        Args:
            index: Index of the target leaf in the original list.

        Returns:
            A list of tuples (position, sibling_hash), ordered from the leaf level
            upward. position is 'left' or 'right': the side the sibling is on when
            hashing the pair.

        Raises:
            ValueError: If the tree has not been built.
            IndexError: If index does not refer to an existing leaf.
        """
        if not self.levels:
            raise ValueError("Tree not built yet. Call make_tree() first.")
        leaf_count = len(self.levels[0])
        if not 0 <= index < leaf_count:
            raise IndexError(f"Leaf index {index} out of range for {leaf_count} leaves.")
        proof: List[Tuple[str, str]] = []
        for level in self.levels[:-1]:
            # Pairs are (0,1), (2,3), ...; flipping the lowest bit gives the partner.
            sibling_index = index ^ 1
            if sibling_index < len(level):
                sibling_pos = "left" if sibling_index < index else "right"
                proof.append((sibling_pos, level[sibling_index].hex()))
            else:
                # Last node of an odd-length level: make_tree() hashed it with
                # itself, so the node is its own right-hand sibling. Omitting
                # this step would make verify_proof() fail for such leaves.
                proof.append(("right", level[index].hex()))
            index //= 2
        return proof

    def verify_proof(
        self, proof: List[Tuple[str, str]], target_leaf: str, root: str
    ) -> bool:
        """
        Verify a Merkle proof against the given root.

        Only self.hash_func is used, so the proof can be checked against any root.

        Args:
            proof: A list of (position, sibling_hash) tuples as returned by get_proof().
                Any position other than "left" is treated as "right".
            target_leaf: Original input string for the leaf (before hashing).
            root: Expected Merkle root as lowercase hex without a "0x" prefix.

        Returns:
            True if the proof is valid, False otherwise.
        """
        current_hash = self.hash_func(target_leaf.encode())
        for position, sibling_hash_hex in proof:
            sibling_hash = bytes.fromhex(sibling_hash_hex)
            if position == "left":
                current_hash = self.hash_func(sibling_hash + current_hash)
            else:
                current_hash = self.hash_func(current_hash + sibling_hash)
        return current_hash.hex() == root

    def print_tree(self, short: bool = True) -> None:
        """
        Print the Merkle tree level by level from root to leaves.

        Args:
            short: If True, abbreviates hashes to the first 8 hex chars.

        Raises:
            ValueError: If the tree has not been built yet.
        """
        if not self.levels:
            raise ValueError("Tree not built yet. Call make_tree() first.")

        def fmt(h: bytes) -> str:
            return h.hex()[:8] if short else h.hex()

        print("\nMerkle Tree:")
        total_levels = len(self.levels)
        for i, level in enumerate(reversed(self.levels)):
            level_index = total_levels - i - 1
            label = "Root" if i == 0 else ("Leaf" if i == total_levels - 1 else f"Level {level_index}")
            hashes = " | ".join(fmt(h) for h in level)
            print(f"{label}: {hashes}")


With this API we can build the tree again. The Merkle root matches the one we computed by hand above.

In [177]:
# Rebuild the same tree through the MerkleTree wrapper; add_leaf() does the
# leaf hashing itself, so it receives the serialized strings directly.
tree = MerkleTree()
for serialized in map(stable_serialize, llmBatch):
    tree.add_leaf(serialized)
tree.make_tree()
print("Merkle root:", tree.get_root())

Merkle root: afd62e33c636dbffa7a7ffb0553090225f1d206a6a535f99d4879dacfa338bff


In [178]:
# Abbreviated view (short=True is the default): first 8 hex chars per node.
tree.print_tree()


Merkle Tree:
Root: afd62e33
Level 1: e867d7ea | 1cb98366
Leaf: 5bc1a99e | f2cb983c | ca8dcdb0 | 66f53003


I plan to commit the Merkle root to the blockchain every so often. But users might want to verify that their leaves are properly included in the tree, so I need to provide a way to check that a leaf is part of the tree. For this, each user needs the proof for each of their leaves. For each leaf the user needs to see:

{
  "leaf": serialized,
  "proof": list of tuples,
  "root": hex root
}

In [179]:
# Display the raw items: the ISO timestamp strings were parsed into UTC-aware datetimes.
llmBatch

[LlmBatchItem(id=1, timestamp=datetime.datetime(2023, 9, 1, 12, 0, tzinfo=TzInfo(UTC)), tokenCount=1024, wallet='0xUser1Address'), LlmBatchItem(id=2, timestamp=datetime.datetime(2023, 9, 2, 12, 0, tzinfo=TzInfo(UTC)), tokenCount=1023, wallet='0xUser2Address'), LlmBatchItem(id=3, timestamp=datetime.datetime(2023, 9, 3, 12, 0, tzinfo=TzInfo(UTC)), tokenCount=1022, wallet='0xUser1Address'), LlmBatchItem(id=4, timestamp=datetime.datetime(2023, 9, 4, 12, 0, tzinfo=TzInfo(UTC)), tokenCount=1021, wallet='0xUser4Address')]

In [180]:
wallet = "0xUser1Address"

# For every leaf owned by `wallet`, collect what the user needs to verify it on
# their own: the exact serialized leaf, its proof, and the root it proves against.
user_leaves = [
    {
        "leaf": stable_serialize(batch),
        "proof": tree.get_proof(index),  # index must be the item's position in llmBatch
        "root": tree.get_root(),
    }
    for index, batch in enumerate(llmBatch)
    if batch.wallet == wallet
]

# print (not pprint) for the heading, so it is not shown as a quoted string repr.
print("User leaves and proofs:")
pprint(user_leaves)

'User leaves and proofs:'
[{'leaf': '{"id":1,"timestamp":"2023-09-01T12:00:00Z","tokenCount":1024,"wallet":"0xUser1Address"}',
  'proof': [('right',
             'f2cb983c8269e2cfe0d304644ab728d498cbae86fbc5d4c9d3080e85b6065872'),
            ('right',
             '1cb983666708c3054522879fcb3416efaf451eed80f61795c1ee42c5f4f6f116')],
  'root': 'afd62e33c636dbffa7a7ffb0553090225f1d206a6a535f99d4879dacfa338bff'},
 {'leaf': '{"id":3,"timestamp":"2023-09-03T12:00:00Z","tokenCount":1022,"wallet":"0xUser1Address"}',
  'proof': [('right',
             '66f53003874c6aea844469e7a78e5f259e7c8d2fdc7477089a12ec26d44d8de6'),
            ('left',
             'e867d7ea3b37ae8faf4387e507106a54837485a24990ce3aed26ea13488f21fe')],
  'root': 'afd62e33c636dbffa7a7ffb0553090225f1d206a6a535f99d4879dacfa338bff'}]


Now let us use each leaf's proof to verify that the leaf is part of the tree.

In [181]:
# Re-hash each serialized leaf up through its proof and compare against the root.
for entry in user_leaves:
    ok = tree.verify_proof(entry["proof"], entry["leaf"], entry["root"])
    print(f"Leaf: {entry['leaf'][:30]}... is valid: {ok}")

Leaf: {"id":1,"timestamp":"2023-09-0... is valid: True
Leaf: {"id":3,"timestamp":"2023-09-0... is valid: True


# Emulate submission

Now let us see what this would look like for Alice. She has a wallet and makes an LLM call. We add the call to the current batch and send back its leaf. The proof cannot be sent back yet, because the tree is only built once the batch is complete.

In [None]:
# Pending server-side batch, filled by send_llm_call() until the next tree is built.
# It must be bound to an actual list: a bare annotation creates no variable, and
# send_llm_call() would then fail with NameError on len(remoteLLMBatch).
remoteLLMBatch: list[LlmBatchItem] = []

In [None]:
def send_llm_call(wallet: str, prompt: str, token_count: int = 1024):
    """
    Simulate an LLM call: record it in the pending batch and return its leaf.

    No proof is returned, because the Merkle tree is only built once the batch
    is complete. The caller keeps the returned leaf string to request and
    verify a proof later.

    Args:
        wallet: The user's wallet address.
        prompt: The prompt for the LLM call (currently not stored in the batch item).
        token_count: Token count to record for the call (example value, default 1024).

    Returns:
        A dictionary with
          "leaf": the serialized item, i.e. the exact string that
                  MerkleTree.add_leaf()/verify_proof() hash, and
          "hash": the hex-encoded SHA-256 leaf hash.
    """
    item = LlmBatchItem(
        id=len(remoteLLMBatch) + 1,  # 1-based position in the pending batch
        timestamp=datetime.now(timezone.utc),
        tokenCount=token_count,
        wallet=wallet,
    )
    remoteLLMBatch.append(item)

    # Hash the leaf exactly like MerkleTree.add_leaf() with the default hash:
    # UTF-8 bytes of the serialized item -> SHA-256.
    leaf_serialized = stable_serialize(item)
    leaf_hash = sha256(leaf_serialized.encode("utf-8"))
    return {
        "leaf": leaf_serialized,
        "hash": leaf_hash.hex(),
    }

In [None]:
# Parameters for Alice's call.
wallet, prompt = "0xUser1Address", "Analyze sentiment of user feedback"

# What we expect to report back for the call (example token count).
tokenCount = 1024
