# Prepare BST data

This notebook is used to generate training/testing data that consists of
* features : binary search trees (BSTs)
* targets  : inorder traversals of the BSTs

Each BST is a directed acyclic graph (DAG) that is being exported to the PyG-compatible format.

In [7]:
from typing import Optional

### Tree functions

In [6]:
class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left  = None
        self.right = None

In [113]:
def list_to_binary_tree(nums: list[int]) -> Optional[TreeNode]:
    if not nums : return None

    mid = len(nums) // 2
    root = TreeNode(nums[mid])
    root.left  = sorted_list_to_bst(nums[:mid])
    root.right = sorted_list_to_bst(nums[mid+1:])

    return root

In [16]:
# Different input graphs can be fed into GNN:
# 1. BSTs (i.e., representing a sorted list)
# 2. Random trees

# Add:
# - We can also shuffle node indices to make the task harder
# - A model trained on a BST can be tested on a random tree
# - A model trained with original indices can be tested on shuffled indices

In [58]:
def inorder_traversal(node, nodes, edges):
    """ This function performs inorder traversal of a binary tree
        and stores its nodes and edges. Edges are tuples of the 
        form (parent, child, direction) where 'parent' and 'child' 
        are node indices in the 'nodes' list and 'direction' is 
        either 0 (left) or 1 (right). """

    if node.left  : _, _, l_index = inorder_traversal(node.left, nodes, edges)
    
    n_index = len(nodes)      # index of the current node in 'nodes'
    nodes.append(node.value)

    if node.right : _, _, r_index = inorder_traversal(node.right, nodes, edges)
    
    if node.left  : edges.append((n_index, l_index, 0))
    if node.right : edges.append((n_index, r_index, 1))

    return nodes, edges, n_index

In [115]:
# Example usage
nums = [1, 2, 3, 4, 5, 6, 7]
bst  = list_to_binary_tree(sorted(nums))
nodes, edges, _ = inorder_traversal(bst, [], [])
nodes, edges

([1, 2, 3, 4, 5, 6, 7],
 [(1, 0, 0), (1, 2, 1), (5, 4, 0), (5, 6, 1), (3, 1, 0), (3, 5, 1)])

### BST to PyG graph

In [105]:
import torch
from torch_geometric.data import Data

In [106]:
def shuffle_graph(nodes, edges):
    """ This function shuffles node indices and returns the new
        node indices and the corresponding edge indices. """

    n = len(nodes)
    perm = torch.randperm(n)
    sh_nodes = nodes_x[perm]
    sh_edges = edges.clone()
    for i in range(n) : sh_edges[edges==perm[i]] = i

    return sh_nodes, sh_edges

In [112]:
def convert_bst_to_pyg(bst: TreeNode, shuffle: bool = False) -> Data:

    nodes, edges, _ = inorder_traversal(bst, [], [])
    
    nodes_x = torch.tensor(nodes) 
    nodes_y = nodes_x.clone()
    edges = torch.tensor(edges)
    edge_index, edge_attr = edges[:,0:2].clone(), edges[:,2:3].clone()

    if shuffle : nodes_x, edge_index = shuffle_graph(nodes_x, edge_index)

    data = Data(x=nodes_x, edge_index=edge_index, edge_attr=edge_attr, y=nodes_y)

    return data

In [116]:
nums = [1, 2, 3, 4, 5, 6, 7]
bst  = list_to_binary_tree(nums)
data = convert_bst_to_pyg(bst, shuffle=True)

In [118]:
data.x, data.edge_index, data.edge_attr, data.y

(tensor([1, 7, 4, 3, 6, 2, 5]),
 tensor([[5, 0],
         [5, 3],
         [4, 6],
         [4, 1],
         [2, 5],
         [2, 4]]),
 tensor([[0],
         [1],
         [0],
         [1],
         [0],
         [1]]),
 tensor([1, 2, 3, 4, 5, 6, 7]))