In [1]:
from treeviz import list2tree, tree2ascii, ascii_draw, TreeNode

## Dataset description

In [1]:
# https://arxiv.org/html/2402.11917v2

  - In our experimental setup, we generate training samples by generating binary trees $
  T = (V, E)$ uniformly at random from the set of all trees with 16 nodes, i.e. $|V| = 16$
  - For each tree, a leaf node is randomly selected as the target node.


  - The training dataset consists of 150,000 generated trees. 
  - The edge lists of these trees are shuffled to prevent the model from learning simple heuristics and encourage structural understanding of trees
  - For simplification, our tokenization distinguishes tokens representing source and target nodes of each edge, such as [15] and [→15].


In [12]:
import random




def random_binary_tree(n):
    node_names = list(range(n))
    random.shuffle(node_names)

    def _gen(n):
        if n == 0:
            return None
    
        node_val = node_names.pop()
    
        if n == 1:
            return TreeNode(val=node_val)
        else:
            left_size = random.randint(0, n-1)
            right_size = n - 1 - left_size
            left_subtree = _gen(left_size)
            right_subtree = _gen(right_size)
            return TreeNode(val=node_val, left=left_subtree, right=right_subtree)
        
    return _gen(n)




# Generate a random binary tree with 16 nodes
tree = random_binary_tree(16)


In [199]:
def inorder_knuth(root):
    # print('Tribute to Knuth.')
    # print("(1969). Fundamental Algorithms. The Art of Computer Programming. Vol. 1")

    res = []
    def Visit(P):
        if P.left is None and P.right is None:
            res.append(P.val)

    # T1. [Initialize.] Set stack A empty, and set link variable P <- T
    A = []
    P = root
    while True:
        # T2. [P = NULL?] if P == NULL, go to step T4.
        if P is None:
            # T4. [P <= Stack] If stack A is empty, algorithm terminates
            if len(A)==0: break
            # Otherwise, set P <= A
            P = A.pop()

            # T5. [Visit P.] Visit NODE(P).
            Visit(P)
            # Then, set P <- RLINK(P) and return to T2.
            P = P.right
        else:
            # T3. [Stack <= P.] (Now P points to a nonempty binary tree that is to be traversed.)
            # Set A <= P; That is, push the P onto the stack A
            A.append(P)
            # Then, set P <- LLINK(P) and return to T2.
            P = P.left
    return res

def get_leaf_vals(root): return inorder_knuth(root)

In [172]:
tree = random_binary_tree(4)
leaf_vals = inorder_knuth(tree)
print(f'{leaf_vals=}')
tree

leaf_vals=[2, 3]


 0
 |
 +---+
     |
     1
     |
   +-+-+
   |   |
   2   3

In [173]:
def tree_to_edges(root):
    """ Convert a binary tree to a list of edges. """
    edges = []
    def traverse(node):
        if node.left:
            edges.append((node.val, node.left.val))
            traverse(node.left)
        if node.right:
            edges.append((node.val, node.right.val))
            traverse(node.right)
    traverse(root)
    return edges

def edges_to_tree(edges):
    """ Create a binary tree from a list of edges. """
    if not edges:
        return None
    
    # Create nodes and find the root
    nodes = {}
    children = set()
    
    for parent, child in edges:
        if parent not in nodes:
            nodes[parent] = TreeNode(parent)
        if child not in nodes:
            nodes[child] = TreeNode(child)
        if child in nodes:
            if nodes[parent].left is None:
                nodes[parent].left = nodes[child]
            else:
                nodes[parent].right = nodes[child]
        children.add(child)
        
    # Find the root (node not in children set)
    root_node = next(node for node in nodes if node not in children)
    return nodes[root_node]


In [177]:
edges = tree_to_edges(tree)
print(f'{edges=}')

edges_to_tree(edges)

edges=[(0, 1), (1, 2), (1, 3)]


       0
       |
   +---+
   |
   1
   |
 +-+-+
 |   |
 2   3

In [182]:
import collections
children = collections.defaultdict(list)
parent = dict()

for from_node, to_node in edges:
    children[from_node].append(to_node)
    parent[to_node] = from_node

children

defaultdict(list, {0: [1], 1: [2, 3]})

In [183]:
parent

{1: 0, 2: 1, 3: 1}

In [196]:
def get_parent_dict(root):
    parent = dict()
    stack = [(root)]
    
    while stack:
        root = stack.pop()
        if root is None: continue
        
        if root.left:
            parent[root.left.val] = root.val
            stack.append(root.left)
        if root.right:
            parent[root.right.val] = root.val
            stack.append(root.right)
            
    return parent


def from_root_to_node(tree, node_val):
    parent = get_parent_dict(tree)
    path = []
    
    while node_val is not None:
        path.append(node_val)
        node_val = parent.get(node_val)
        
    path = list(reversed(path))
    return path



In [197]:
print(tree)

from_root_to_node(tree, 2)

 0
 |
 +---+
     |
     1
     |
   +-+-+
   |   |
   2   3


[0, 1, 2]

In [208]:
([1,2,3])

1

In [245]:


def edges_to_tokens(edges): return [token for from_node, to_node in edges for token in (str(from_node),f'→{to_node}', ',')][:-1]


def tree2tokens(tree):
    edges = tree_to_edges(tree)
    random.shuffle(edges)
    
    edge_tokens = edges_to_tokens(edges)
    root = tree.val
    leafs = get_leaf_vals(tree)

    goal = random.choice(leafs)

    input_tokens = [*edge_tokens, '|', goal, ':', root]
    
    target_path = from_root_to_node(tree, goal)
    target_tokens = [f'→{node_val}' for node_val in target_path]
    
    return input_tokens, target_tokens

In [250]:
N = 5

def generate_datapoint(N=N):
    tree = random_binary_tree(N)
    return tree2tokens(tree)


def input_tokens_to_tree(input_tokens):
    idx = input_tokens.index('|')
    edge_tokens = ''.join(input_tokens[:idx])
    edges = [edge.split('→') for edge in edge_tokens.split(',')]
    tree = edges_to_tree(edges)
    return tree

In [263]:
input_tokens, target_tokens = generate_datapoint(16)

In [264]:
input_tokens

['8',
 '→10',
 ',',
 '9',
 '→15',
 ',',
 '0',
 '→13',
 ',',
 '1',
 '→9',
 ',',
 '3',
 '→4',
 ',',
 '15',
 '→12',
 ',',
 '3',
 '→1',
 ',',
 '6',
 '→5',
 ',',
 '1',
 '→8',
 ',',
 '12',
 '→6',
 ',',
 '8',
 '→2',
 ',',
 '6',
 '→11',
 ',',
 '9',
 '→0',
 ',',
 '4',
 '→7',
 ',',
 '0',
 '→14',
 '|',
 10,
 ':',
 3]

In [265]:
target_tokens

['→3', '→1', '→8', '→10']

In [266]:
input_tokens_to_tree(input_tokens)

        3
        |
     +--+-----------------------------+
     |                                |
     4                                1
     |                                |
  +--+                    +-----------+-----+
  |                       |                 |
  7                       9                 8
                          |                 |
                       +--+-----+        +--+--+
                       |        |        |     |
                      15        0       10     2
                       |        |
                    +--+     +--+--+
                    |        |     |
                   12       13    14
                    |
              +-----+
              |
              6
              |
           +--+--+
           |     |
           5    11