# Prepare data

This notebook is used to generate training/testing data that consists of
* features : binary trees
* targets  : inorder traversals of binary trees

Each binary tree is a directed acyclic graph (DAG) that is being exported to the PyG-compatible format.

In [2]:
from typing import Optional

### Tree functions

In [3]:
class TreeNode:
    def __init__(self, value):
        self.value = value
        self.left  = None
        self.right = None

In [4]:
import random

def list_to_binary_tree(nums: list[int], balanced=True) -> Optional[TreeNode]:
    """ Convert a list of numbers to a binary tree. """
    
    if not nums : return None

    if balanced : mid = len(nums) // 2
    else        : mid = random.randint(0, len(nums)-1)

    root = TreeNode(nums[mid])
    root.left  = list_to_binary_tree(nums[:mid], balanced)
    root.right = list_to_binary_tree(nums[mid+1:], balanced)

    return root

In [5]:
# Different input graphs can be fed into GNN:
# 1. BSTs (i.e., representing a sorted list)
# 2. Random trees
# 3. Balanced vs unbalcanced trees

# Add:
# - We can also shuffle node indices to make the task harder
# - A model trained on a BST can be tested on a random tree
# - A model trained with original indices can be tested on shuffled indices
# - A model trained on balanced trees can be tested on unbalanced trees and vice versa

In [6]:
def inorder_traversal(node, nodes, edges):
    """ 
        This function performs inorder traversal of a binary tree
        and stores its nodes and edges. Edges are tuples of the 
        form (parent, child, direction) where 'parent' and 'child' 
        are node indices in the 'nodes' list and 'direction' is 
        either 0 (left) or 1 (right).
    """

    if node.left  : _, _, l_index = inorder_traversal(node.left, nodes, edges)
    
    n_index = len(nodes)      # index of the current node in 'nodes'
    nodes.append(node.value)

    if node.right : _, _, r_index = inorder_traversal(node.right, nodes, edges)
    
    if node.left  : edges.append((n_index, l_index, 0))
    if node.right : edges.append((n_index, r_index, 1))

    return nodes, edges, n_index

### BST to PyG graph

In [10]:
import torch
from torch_geometric.data import Data

In [11]:
def shuffle_graph(nodes, edges):
    """ This function shuffles node indices and returns the new
        node indices and the corresponding edge indices. """

    n = len(nodes)
    perm = torch.randperm(n)
    sh_nodes = nodes[perm]
    sh_edges = edges.clone()
    for i in range(n) : sh_edges[edges==perm[i]] = i

    return sh_nodes, sh_edges

In [27]:
def binary_tree_to_pyg(tree: TreeNode, shuffle: bool = False) -> Data:

    nodes, edges, _ = inorder_traversal(tree, [], [])
    
    nodes_x = torch.tensor(nodes) 
    nodes_y = nodes_x.clone()
    edges = torch.tensor(edges)
    edge_index, edge_attr = edges[:,0:2].clone(), edges[:,2:3].clone()

    if shuffle : nodes_x, edge_index = shuffle_graph(nodes_x, edge_index)

    data = Data(x=nodes_x, edge_index=edge_index, edge_attr=edge_attr, y=nodes_y)

    return data

In [28]:
from torch_geometric.data import InMemoryDataset

In [29]:
dataset_size = 5
min_value, max_value = -1000, 1000
min_nodes, max_nodes = 10, 100

dataset_name = "balanced-shuffled"
path_dataset = f"./data/{dataset_name}"


In [30]:
# hyperparameters: balanced, shuffle nodes
dataset = []
for _ in range(dataset_size):
    
    n_nodes = random.randint(min_nodes, max_nodes)
    nums = random.sample(range(min_value, max_value + 1), n_nodes)
    binary_tree = list_to_binary_tree(nums, balanced=True)
    sample = binary_tree_to_pyg(binary_tree, shuffle=True)
    dataset.append(sample)

In [31]:
dataset

[Data(x=[66], edge_index=[65, 2], edge_attr=[65, 1], y=[66]),
 Data(x=[47], edge_index=[46, 2], edge_attr=[46, 1], y=[47]),
 Data(x=[11], edge_index=[10, 2], edge_attr=[10, 1], y=[11]),
 Data(x=[91], edge_index=[90, 2], edge_attr=[90, 1], y=[91]),
 Data(x=[17], edge_index=[16, 2], edge_attr=[16, 1], y=[17])]