TL;DR torch-like implementation of convolutional layer blocks over binary trees; can be used to efficiently encode trees

💡 Idea

Convolution over binary trees lies between the conventional CNNs used for images and graph-based CNNs. Because each node in a binary tree has at most two children, the data can be laid out so that a 1-dimensional CNN processes it efficiently while still respecting the tree’s structure. Such layers allow the structure of trees to be taken into account when encoding them, which simplifies the task of modelling dependencies on them.

🧐 Why is BinaryTreeConvolution useful?

That the proposed convolutions extract useful features can be verified by direct comparison. On the task of predicting the execution time of requests from their plans, we observe the following pattern: adding the BTCNN block makes the dependency approximation problem easier for the FCNN.


📦 Setup

python -m pip install --upgrade pip
python3 -m venv venv
source venv/bin/activate
pip install -e .
pytest --cov=. --cov-report=term-missing

🚀 How To

How to create a Dataset / DataLoader, configure the architecture, run training and manage trained models is demonstrated in the notebook.


🧩 Interface


Our layers process objects using the following representation:

  • vertices - 3D tensor of shape [batch_size, n_channels, max_length_in_batch]
  • edges - 4D tensor of shape [batch_size, 1, max_length_in_batch, 3], where the last dimension contains three indices representing the node’s 1-hop neighborhood ([parent_id, left_child_id, right_child_id])
def forward(self, vertices: "Tensor", edges: "Tensor") -> "Tensor":

P.S. Currently implemented layers are: BinaryTreeActivation, BinaryTreeAdaptivePooling, BinaryTreeConv, BinaryTreeLayerNorm, BinaryTreeInstanceNorm

P.P.S. To work with this format, zero padding is used a) to stand in for missing children and b) to align the lengths of trees within a batch.
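As a rough illustration of point b), the sketch below pads per-tree vertex lists to a common length and transposes each tree into the [n_channels, max_length_in_batch] layout. `pad_batch` is a hypothetical helper written for this example, not a function of the library, and it assumes that padding rows are appended at the end and that index 0 of every tree is the padding node used for missing children.

```python
from typing import List


def pad_batch(trees: List[List[List[float]]]) -> List[List[List[float]]]:
    """Pad [length, n_channels] vertex lists with zero rows, then transpose
    each tree to [n_channels, max_length] (hypothetical helper)."""
    max_length = max(len(tree) for tree in trees)
    n_channels = len(trees[0][0])
    batch = []
    for tree in trees:
        # Zero rows align shorter trees with the longest tree in the batch.
        padded = tree + [[0.0] * n_channels for _ in range(max_length - len(tree))]
        # Transpose [length, n_channels] -> [n_channels, length].
        batch.append([[node[c] for node in padded] for c in range(n_channels)])
    return batch


# Assumption: index 0 of every tree is the zero padding node.
tree_a = [[0.0, 0.0], [1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]
tree_b = [[0.0, 0.0], [2.0, 3.0]]
batch = pad_batch([tree_a, tree_b])
print(len(batch), len(batch[0]), len(batch[0][0]))  # 2 2 5
print(batch[1])  # [[0.0, 2.0, 0.0, 0.0, 0.0], [0.0, 3.0, 0.0, 0.0, 0.0]]
```

Converting `batch` with `torch.tensor(batch)` would give a tensor of shape [batch_size, n_channels, max_length_in_batch].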


Since layers must always remember the structure behind the vertices (which is stored in edges), we built a dedicated module for stacking layers, BinaryTreeSequential:

class BinaryTreeSequential(nn.Module):
    def forward(self, vertices: "Tensor", edges: "Tensor") -> "Tensor":
        for layer in self.layers:
            vertices = layer(vertices, edges)
        return vertices

By combining a BTCNN block with an FCNN, it is possible to solve prediction problems. Inference is split into two parts: encoding the tree into a vector that takes its structure into account (the btcnn part), and then running a fully-connected network on that vector (the fcnn part). The two are put together in the BinaryTreeRegressor module:

class BinaryTreeRegressor(nn.Module):
    def forward(self, vertices: "Tensor", edges: "Tensor") -> "Tensor":
        return self.fcnn(self.btcnn(vertices=vertices, edges=edges))

🔢 Pipeline

Step 1. Vectorize the binary tree.
                [1.0, 1.0]
                 /       \
        [1.0, -1.0]     *None*
            /    \                 
  [-1.0, -1.0]   [1.0, 1.0]
Step 2. Add padding nodes for all incomplete nodes.
                [1.0, 1.0]
                 /       \
        [1.0, -1.0]   [0.0, 0.0]   # padding node
            /    \                 
  [-1.0, -1.0]   [1.0, 1.0]
Step 3. Construct tensors for vertices and edges using a tree traversal.
# vertices 
[[0, 0], [1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]

# edges in the form `[node_id, left_child_id, right_child_id]`
[[1, 2, 0], [2, 3, 4], [3, 0, 0], [4, 0, 0]]
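Steps 1 to 3 can be sketched in plain Python as follows. The `Node` class and the `flatten` helper are hypothetical names used for illustration and are not part of the library. The sketch assumes a pre-order traversal and reserves index 0 for the padding node, which reproduces the tensors shown above.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Node:
    """Hypothetical tree node: a feature vector plus optional children."""
    features: List[float]
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def flatten(root: Node) -> Tuple[List[List[float]], List[List[int]]]:
    """Pre-order traversal; index 0 is reserved for the zero padding node."""
    vertices = [[0.0] * len(root.features)]  # padding node at index 0
    edges: List[List[int]] = []

    def visit(node: Optional[Node]) -> int:
        if node is None:
            return 0  # a missing child points at the padding node
        node_id = len(vertices)
        vertices.append(node.features)
        row = [node_id, 0, 0]
        edges.append(row)  # append before recursing to keep pre-order
        row[1] = visit(node.left)
        row[2] = visit(node.right)
        return node_id

    visit(root)
    return vertices, edges


tree = Node([1.0, 1.0], left=Node([1.0, -1.0], Node([-1.0, -1.0]), Node([1.0, 1.0])))
vertices, edges = flatten(tree)
print(vertices)  # [[0.0, 0.0], [1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]
print(edges)     # [[1, 2, 0], [2, 3, 4], [3, 0, 0], [4, 0, 0]]
```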
Step 4. Convolve over the binary tree neighborhoods.

To account for the binary tree structure, we’ll convolve over the parent, left child, and right child nodes. This can be visualized as a filter moving across the tree structure:

       [θ_11, θ_12]
         /       \
[θ_21, θ_22]   [θ_31, θ_32]

🪄 Trick: after padding, every node has exactly two child slots, so the entire tree can be stretched into a tensor of size 3 * tree_length, one (parent, left child, right child) triple per node. A one-dimensional CNN with stride=3 can then capture each node’s neighborhood, leveraging efficient convolution implementations while maintaining the tree’s geometry.
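A minimal sketch of this trick in PyTorch, assuming the vertices and edges from Step 3 and an arbitrary example filter. It uses only `torch` indexing and `torch.nn.functional.conv1d`; it illustrates the idea and is not the library's own implementation.

```python
import torch
import torch.nn.functional as F

# Tree from Step 3: index 0 is the padding node, two channels per node.
vertices = torch.tensor([[0.0, 0.0], [1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]])
edges = torch.tensor([[1, 2, 0], [2, 3, 4], [3, 0, 0], [4, 0, 0]])

# Gather every neighborhood and lay the triples end to end.
# Shape after transpose and unsqueeze: [batch=1, n_channels=2, 3 * n_nodes=12].
stretched = vertices[edges.flatten()].T.unsqueeze(0)

# One example filter of shape [out_channels=1, in_channels=2, kernel_size=3];
# the three columns weight the parent, left child and right child.
weight = torch.tensor([[[1.0, -1.0, 1.0], [-1.0, -1.0, 1.0]]])

# stride=3 makes each window cover exactly one neighborhood.
out = F.conv1d(stretched, weight, stride=3)
print(out.shape)  # torch.Size([1, 1, 4]): one value per node
print(out)        # values per node: 0, 6, 0, 0
```

The filter values here are chosen for illustration: parent weights [1.0, -1.0], left-child weights [-1.0, -1.0] and right-child weights [1.0, 1.0].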

Step 5. Apply point-wise Activation and Adaptive Pooling.

After applying several convolutional layers (along with point-wise non-linear functions and normalization layers), we can use an adaptive pooling method to reduce the tree to a fixed-size vector.

                [a, e]
                /    \
            [b, f]   [e, k]
             /  \                 
        [c, g]  [d, h]

# after the `AdaptiveMaxPooling` layer, the tree becomes a vector whose size equals the number of channels in the tree
vector = [max(a, b, c, d, e), max(e, f, g, h, k)]
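The pooling step can be sketched as a channel-wise maximum over the nodes. `adaptive_max_pool` is a hypothetical helper, and the numeric values stand in for the letters in the diagram (a=1, b=2, c=3, d=4, e=5, f=6, g=7, h=8, k=9). The sketch pools over real nodes only; how the library treats padding nodes is not specified here.

```python
from typing import List


def adaptive_max_pool(nodes: List[List[float]]) -> List[float]:
    """Channel-wise maximum over nodes; `nodes` has shape [n_nodes, n_channels]."""
    n_channels = len(nodes[0])
    return [max(node[c] for node in nodes) for c in range(n_channels)]


# [a, e], [b, f], [e, k], [c, g], [d, h] with the substitution given above.
nodes = [[1.0, 5.0], [2.0, 6.0], [5.0, 9.0], [3.0, 7.0], [4.0, 8.0]]
vector = adaptive_max_pool(nodes)
print(vector)  # [5.0, 9.0]
```

The output length depends only on the number of channels, so trees of any size map to vectors of the same size.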

Normalisation Layer

To simplify the optimisation problem, it is useful to add normalisation layers within the convolution blocks. Of the options we tried, instance normalisation worked best.



Batch Normalisation. Aggregation is performed across all trees in the batch.

               [10000]                       [100]
               /     \                      /     \
            [100]   [100]               [10]     [10]
            /  \     /  \               /  \     
         [10] [10] [10] [10],        [2]   [5] 

batch_vertices = [
    [[.0], [10000], [100], [100], [10], [10], [10], [10]],
    [[.0], [100],   [10],  [2],   [5],  [10], [.0], [.0]],
]
batch_edges = [[[1,2,5], [2,3,4], [3,0,0], [4,0,0], [5,0,0]]]
batch_vertices_mean = mean(batch_vertices)  # per position, padding slot 0 omitted: [5050, 55, 51, 7.5, 10, 5, 5]

Batch normalisation is a poor fit here for the same reason it is a poor fit for any network over sequences: objects in a batch may hold representations of completely different information at the same position. Aggregating across the objects in a batch therefore mixes, for example, statistics of the roots of trees of different heights. Given the semantics of these statistics, that is inappropriate, because the characteristic order of magnitude of cardinalities grows with tree height.

Layer Normalisation. Aggregation is performed independently for each tree.

                [1.0, 1.0]
                 /     \
        [1.0, -1.0]   *None*
            /  \                 
 [-1.0,-1.0]   [1.0,1.0]

tree_mean = mean([1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0])  # 0.25
tree_std = std([1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0])  # 0.9682458365518543

Instance Normalisation. Aggregation is performed independently for each tree and each channel.

                [1.0, 1.0]
                 /     \
        [1.0, -1.0]   *None*
            /  \                 
 [-1.0,-1.0]   [1.0,1.0]

tree_mean = [mean([1.0, 1.0, -1.0, 1.0]), mean([1.0, -1.0, -1.0, 1.0])]  # [0.5, .0]
tree_std = [std([1.0, 1.0, -1.0, 1.0]),  std([1.0, -1.0, -1.0, 1.0])]  # [0.8660254037844386, 1.0]
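The statistics quoted for layer and instance normalisation can be checked with the standard library. This sketch assumes the population standard deviation (`statistics.pstdev`) and aggregates over the four real nodes of the example tree, ignoring padding. It only computes the statistics and does not reproduce the library's layers.

```python
from statistics import fmean, pstdev

# Real (non-padding) nodes of the example tree, shape [n_nodes, n_channels].
nodes = [[1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]

# Layer normalisation: one mean/std per tree, over all nodes and channels.
flat = [x for node in nodes for x in node]
layer_mean, layer_std = fmean(flat), pstdev(flat)

# Instance normalisation: one mean/std per tree and per channel.
channels = list(zip(*nodes))
instance_mean = [fmean(ch) for ch in channels]
instance_std = [pstdev(ch) for ch in channels]

print(layer_mean, round(layer_std, 4))                   # 0.25 0.9682
print(instance_mean, [round(s, 4) for s in instance_std])  # [0.5, 0.0] [0.866, 1.0]
```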

📝 Completed Example


First, a convolution with the filter is performed independently for each neighbourhood. An example of neighbourhood convolution on the root:

# tree
                [1.0, 1.0]
                 /     \
        [1.0, -1.0]   *None*
            /  \                 
 [-1.0,-1.0]   [1.0,1.0]

# filter
       [1.0,-1.0]
         /      \
[-1.0,-1.0]   [1.0,1.0]

# root's neighborhood convolution
                [1.0,1.0]                [1.0,-1.0]
                 /     \        *         /      \            =      [0.0]
        [1.0,-1.0]   [0.0,0.0]    [-1.0,-1.0]   [1.0,1.0]
# (1.0 * 1.0 + 1.0 * -1.0) + (1.0 * -1.0 + -1.0 * -1.0) + (0.0 * 1.0 + 0.0 * 1.0) = 0.0

Second, normalisation and activation layers are applied. Third, a dynamic pooling layer maps the tree to a fixed-length vector. Taking the structure of the tree into account, the following happens to it throughout the process:

                 # tree                  # filter                # after Conv            # after Norm & ReLU    # after AdaptiveMaxPooling

                [1.0,1.0]                                             [0.0]                      [0.0]                                     
                 /     \                [1.0,-1.0]                   /     \                     /   \
        [1.0,-1.0]   *None*  *           /      \        ->       [6.0]     *None*  ->     [1.73]  *None*  ->  [1.73]
            /  \                [-1.0,-1.0]   [1.0,1.0]           /  \                       /  \                                       
 [-1.0,-1.0]   [1.0,1.0]                                     [0.0]   [0.0]              [0.0]  [0.0]
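The numbers in the diagram above can be reproduced with a few lines of plain Python. The sketch assumes normalisation with the population standard deviation over the single output channel, with no epsilon and no learnable affine parameters. It mirrors the walkthrough and is not the library's implementation.

```python
from statistics import fmean, pstdev

# Index 0 is the padding node; edges rows are [node_id, left_child_id, right_child_id].
vertices = [[0.0, 0.0], [1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]
edges = [[1, 2, 0], [2, 3, 4], [3, 0, 0], [4, 0, 0]]
# Filter weights for (parent, left child, right child), two channels each.
kernel = [[1.0, -1.0], [-1.0, -1.0], [1.0, 1.0]]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Convolution: one output value per node neighborhood.
conv = [sum(dot(vertices[i], w) for i, w in zip(hood, kernel)) for hood in edges]

# Normalisation over the single output channel, then ReLU.
mean, std = fmean(conv), pstdev(conv)
activated = [max(0.0, (x - mean) / std) for x in conv]

# Adaptive max pooling collapses the tree to one value per channel.
pooled = max(activated)

print(conv)              # [0.0, 6.0, 0.0, 0.0]
print(round(pooled, 2))  # 1.73
```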

👁️⃤ Intuition. After normalizing and applying the ReLU activation, the left child of the root becomes prominent. This happens because the values in its neighborhood closely match the filter weights. This prominence indicates the similarity of the substructure to the filter. When training multiple filters simultaneously and combining convolutional blocks, we begin to capture more complex structures, such as subtrees of height 2, 3, and beyond. BTCNN effectively identifies key substructures in the tree, and then an FCNN assesses their presence.