In [1]:
import sys
from pathlib import Path

# Make the project root importable so the `from src.… import …` cells work.
# Assumes the kernel's working directory is this notebook's folder (Jupyter's
# default), so '..' is the repo root — TODO confirm if the notebook is moved.
# Resolving to an absolute path pins it against later os.chdir() calls, and the
# membership check keeps re-running this cell from stacking duplicate entries.
PROJECT_ROOT = str(Path('..').resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from src.tree import list2tree, tree2ascii, ascii_draw, TreeNode
from src.tree_dataset import random_binary_tree, get_leaf_vals, tree_to_edges, edges_to_tree

## Dataset description

In [3]:
Source: [arXiv:2402.11917v2](https://arxiv.org/html/2402.11917v2)

  - In our experimental setup, we generate training samples by generating binary trees $T = (V, E)$ uniformly at random from the set of all trees with 16 nodes, i.e. $|V| = 16$.
  - For each tree, a leaf node is randomly selected as the target node.


  - The training dataset consists of 150,000 generated trees.
  - The edge lists of these trees are shuffled to prevent the model from learning simple heuristics and to encourage structural understanding of the trees.
  - For simplification, our tokenization distinguishes tokens representing source and target nodes of each edge, such as [15] and [→15].


In [4]:
# Sample one tree at the paper's size (|V| = 16). The walrus both binds `tree`
# and leaves it as the cell's displayed value, rendered as an ASCII drawing.
(tree := random_binary_tree(16))

                                   6
                                   |
  +--------------------------------+--+
  |                                   |
  3                                  15
  |                                   |
  +-----+                             +--------+
        |                                      |
        8                                     10
        |                                      |
     +--+--------+                       +-----+
     |           |                       |
     0          14                       9
                 |                       |
           +-----+--------------+        +--+
           |                    |           |
           2                    1           4
           |                    |
           +--+        +--------+
              |        |
              5       11
                       |
                    +--+-----+
                    |        |
                   12       13
                  

In [5]:
# A tiny 4-node tree is easy to verify by eye: list its leaf values, then draw it.
# (`tree` is rebound here on purpose; the round-trip cell below reads it.)
tree = random_binary_tree(4)
leaf_vals = get_leaf_vals(tree)
print('leaf_vals=' + repr(leaf_vals))
tree

leaf_vals=[2, 1]


     0
     |
 +---+-+
 |     |
 3     1
 |
 +-+
   |
   2

In [6]:
# Round trip: tree -> edge list -> tree.
edges = tree_to_edges(tree)
print(f'{edges=}')

rebuilt_tree = edges_to_tree(edges)
# Each edge is a (parent, child) pair. The list does not appear to record which
# side a child hangs on: in the saved output, node 2 moves from the right child
# of 3 to the left child. So the drawings can differ even when the trees agree.
# Check what *is* preserved: the set of parent->child edges.
assert sorted(tree_to_edges(rebuilt_tree)) == sorted(edges), 'edge round-trip changed the tree'
rebuilt_tree

edges=[(0, 3), (3, 2), (0, 1)]


     0
     |
   +-+-+
   |   |
   3   1
   |
 +-+
 |
 2

In [7]:
from src.tree_dataset import generate_datapoint, input_tokens_to_tree, TreeDataset

In [8]:
# Draw one full-size (16-node) sample; the next cells inspect the prompt and
# the expected continuation separately.
(input_tokens, target_tokens) = generate_datapoint(16)

In [9]:
# Prompt tokens: each edge is a source token plus a `→target` token, edges are
# separated by ','; then '|' and the query (in the saved run: '|', leaf 4, ':',
# root 14 — compare the tree drawn a few cells below).
# NOTE(review): those last two query entries (4 and 14) come back as ints while
# every other token is a str. This looks like generate_datapoint doesn't
# stringify the query nodes; confirm before building a vocabulary from these.
input_tokens

['15',
 '→2',
 ',',
 '9',
 '→10',
 ',',
 '13',
 '→1',
 ',',
 '9',
 '→6',
 ',',
 '14',
 '→0',
 ',',
 '14',
 '→5',
 ',',
 '5',
 '→9',
 ',',
 '3',
 '→15',
 ',',
 '7',
 '→11',
 ',',
 '1',
 '→8',
 ',',
 '13',
 '→4',
 ',',
 '6',
 '→7',
 ',',
 '0',
 '→13',
 ',',
 '1',
 '→3',
 ',',
 '10',
 '→12',
 '|',
 4,
 ':',
 14]

In [10]:
# Expected continuation: one `→node` token per step of the root-to-leaf path
# (in the saved run 14 → 0 → 13 → 4, ending at the query leaf 4).
target_tokens

['→14', '→0', '→13', '→4']

In [11]:
# Draw the tree encoded by the prompt tokens, to check the target path above by eye.
input_tokens_to_tree(input_tokens)

                         14
                          |
                       +--+--------------------+
                       |                       |
                       0                       5
                       |                       |
                 +-----+           +-----------+
                 |                 |
                13                 9
                 |                 |
     +-----------+--+           +--+--------+
     |              |           |           |
     1              4          10           6
     |                          |           |
  +--+--------+              +--+        +--+
  |           |              |           |
  8           3             12           7
              |                          |
           +--+                       +--+
           |                          |
          15                         11
           |
        +--+
        |
        2

In [12]:
%%time

for _ in range(10000):
    input_tokens, target_tokens = generate_datapoint(16)

CPU times: user 229 ms, sys: 1.42 ms, total: 231 ms
Wall time: 231 ms


In [13]:
from torch.utils.data import IterableDataset, DataLoader

In [14]:
# A small-tree dataset keeps the padded sequences short enough to read below.
# NOTE(review): the 5 is presumably the node count per tree, mirroring
# generate_datapoint(16) above. Confirm against src.tree_dataset.TreeDataset.
dataset = TreeDataset(5)
dataloader = DataLoader(dataset=dataset, batch_size=8)

In [15]:
# Take the first sample the dataset yields. With `for ...: break`, an empty
# dataset would silently leave `i` bound to a stale value from an earlier run;
# here it fails loudly instead. The name `i` is kept because the next cells read it.
i = next(iter(dataset), None)
assert i is not None, 'TreeDataset yielded no samples'

In [16]:
# One sample: a dict of token ids (`input_idx`) and a per-position `task_mask`
# of the same length.
# NOTE(review): in the saved output, task_mask is 1 on positions 0–14 and 0 from
# position 15 on. That 0 region includes the trailing 0 ids (presumably padding).
# If 0 means "compute loss here", padding is not excluded. Confirm against the
# training loop.
i

{'input_idx': tensor([ 6,  5,  1, 10,  9,  1,  4, 11,  1,  6, 13,  2, 12,  3,  6,  7, 13,  0,
          0,  0,  0,  0,  0,  0]),
 'task_mask': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0.])}

In [17]:
# Token ids only; the trailing zeros are presumably padding (length 24 in the saved output).
i['input_idx']

tensor([ 6,  5,  1, 10,  9,  1,  4, 11,  1,  6, 13,  2, 12,  3,  6,  7, 13,  0,
         0,  0,  0,  0,  0,  0])

In [18]:
# Pull a single batch for inspection. With `for ...: break`, an empty loader
# would silently keep a stale `batch` from an earlier run; here it fails loudly.
batch = next(iter(dataloader), None)
assert batch is not None, 'dataloader yielded no batches'

In [19]:
# Batched token ids: one row per sample (batch_size=8 above), every row padded to
# the same length (24 in the saved output) so they stack into a single tensor.
batch['input_idx']

tensor([[ 4, 13,  1,  6,  5,  1, 12, 11,  1,  6,  9,  2, 10,  3,  6,  7,  5, 13,
         11,  0,  0,  0,  0,  0],
        [10, 13,  1,  4,  9,  1, 12,  7,  1, 10,  5,  2,  8,  3, 10, 11,  5,  9,
          0,  0,  0,  0,  0,  0],
        [ 8, 11,  1,  4,  9,  1,  6, 13,  1,  4,  7,  2, 12,  3,  4,  5,  7, 13,
          0,  0,  0,  0,  0,  0],
        [ 8,  7,  1, 10,  9,  1, 12,  5,  1, 12, 11,  2,  6,  3, 12, 13, 11,  9,
          7,  0,  0,  0,  0,  0],
        [ 6,  5,  1,  6, 13,  1, 12,  9,  1,  8, 11,  2,  4,  3,  6,  7,  5,  0,
          0,  0,  0,  0,  0,  0],
        [ 4, 11,  1,  6,  9,  1, 12,  5,  1, 10,  7,  2,  8,  3, 12, 13,  5, 11,
          7,  9,  0,  0,  0,  0],
        [ 4,  9,  1,  6,  5,  1, 12,  7,  1, 10, 13,  2,  8,  3, 10, 11, 13,  7,
          5,  9,  0,  0,  0,  0],
        [ 4, 11,  1,  8, 13,  1, 12,  7,  1,  6,  5,  2, 10,  3,  8,  9, 13,  7,
          5, 11,  0,  0,  0,  0]])