# Bayesian CART

In [1]:
import jax
import pydot
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Toy pytrees: nested dicts whose non-dict values are treated as leaves.
base_tree = {1: 1}
tree = {0: {1: 1, 2: 2}}

In [3]:
# Flatten the nested dict into its leaves. jax.tree_leaves is a deprecated
# alias; jax.tree_util.tree_leaves is the stable spelling.
jax.tree_util.tree_leaves(tree)

[1, 2]

In [4]:
from numpy.random import rand, seed, randint

In [33]:
def partition(leaf_value, alpha):
    """
    Try to split one leaf of a tree.

    A leaf whose value is falsy (``False``) cannot be split: a split has
    already been attempted for it. Any other leaf is split with
    probability ``alpha`` into two new, splittable children.

    Node ids come from the module-level counter ``ix``, which must be set
    (e.g. ``ix = 1``) before the first call; every successful split takes
    the next two ids and advances ``ix`` by two.

    Parameters
    ----------
    leaf_value : bool or int
        Current leaf value; any truthy value means "splittable".
    alpha : float
        Probability of splitting a splittable leaf.

    Returns
    -------
    dict or bool
        ``{ix + 1: True, ix + 2: True}`` (using ``ix`` as it was before the
        call) when the leaf is split, otherwise ``False`` so the leaf is
        never retried.
    """
    global ix
    # Closed leaves stay closed (truthiness test instead of ``== False``).
    if not leaf_value:
        return False
    draw = rand()
    if draw < alpha:
        children = {ix + 1: True, ix + 2: True}
        ix += 2
        return children
    return False

In [34]:
# Per-leaf split probability used by the step-by-step cells below.
# NOTE: a split creates two splittable leaves, so for alpha > 0.5 a leaf
# produces more than one splittable child on average (unbounded growth
# has positive probability).
alpha = 0.6

In [36]:
seed(3141)
ix = 1  # reset the node-id counter read and advanced by partition
base_tree = {1: 1}
# One round of splitting over every leaf. jax.tree_util.tree_map replaces
# the deprecated jax.tree_map alias.
next_tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), base_tree)
print(next_tree)

{1: {2: True, 3: True}}


In [37]:
# NOTE(review): re-seeding with the same value every round reuses the same
# random stream, so rounds are reproducible but not independent.
seed(3141)
# Next round: leaves already set to False stay closed, True leaves may split.
next_tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), next_tree)
print(next_tree)

{1: {2: {4: True, 5: True}, 3: False}}


In [38]:
# NOTE(review): same seed as the previous rounds, see the note above this
# cell's predecessor in the notebook; rounds share one random stream.
seed(3141)
# Next round: leaves already set to False stay closed, True leaves may split.
next_tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), next_tree)
print(next_tree)

{1: {2: {4: {6: True, 7: True}, 5: False}, 3: False}}


In [10]:
from functools import partial

In [46]:
def sample_tree(alpha, max_depth=None):
    """
    Sample a binary tree by repeatedly partitioning its splittable leaves.

    Leaf states:
        True: you are free to split
        False: you cannot split any further

    The tree starts as a single splittable root whose id is the current
    value of the module-level counter ``ix`` (reset it, e.g. ``ix = 1``,
    before calling to get root id 1). Each round passes every leaf through
    ``partition`` until no splittable leaf is left.

    Parameters
    ----------
    alpha : float
        Split probability forwarded to ``partition``.
    max_depth : int, optional
        Maximum number of splitting rounds. When reached, leaves that are
        still splittable are closed (set to ``False``). ``None`` (default)
        means no limit. Each splittable leaf yields on average
        ``2 * alpha`` splittable children, so without a limit the loop has
        a positive probability of never terminating when ``alpha > 0.5``.

    Returns
    -------
    dict
        Nested dict mapping node ids to child dicts or to ``False`` leaves.
    """
    tree = {ix: True}
    depth = 0
    # jax.tree_util is the stable home of tree_map / tree_leaves; the
    # top-level jax.tree_map / jax.tree_leaves aliases are deprecated.
    while any(jax.tree_util.tree_leaves(tree)):
        if max_depth is not None and depth >= max_depth:
            # Depth budget exhausted: close the remaining frontier.
            tree = jax.tree_util.tree_map(lambda leaf: False, tree)
            break
        tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), tree)
        depth += 1
    return tree

In [103]:
# Fresh node-id counter (root id 1), then a reproducible draw from the prior.
ix = 1
seed(31415926)
tree = sample_tree(0.5)
tree

{1: {2: False,
  3: {4: {6: {10: False, 11: False}, 7: {12: False, 13: False}},
   5: {8: {14: {16: False,
      17: {18: {20: False, 21: False}, 19: {22: False, 23: False}}},
     15: False},
    9: False}}}}

In [105]:
# Fresh node-id counter (root id 1), then a reproducible draw from the prior.
ix = 1
seed(31)
tree = sample_tree(0.5)
tree

{1: {2: False, 3: False}}

In [106]:
def draw(parent_name, child_name):
    """Add an edge between ``parent_name`` and ``child_name`` to the module-level pydot ``graph``."""
    graph.add_edge(pydot.Edge(parent_name, child_name))

def visit(node, parent=None, add_edge=None):
    """
    Recursively emit the edges of a nested-dict tree.

    Each key of ``node`` is a node id. A dict value holds that node's
    children; any other value marks the node as a leaf, which also gets an
    edge to a distinct label node named ``"Stop(<id>)"``.

    Parameters
    ----------
    node : dict
        Subtree mapping node ids to child dicts or leaf values.
    parent : hashable, optional
        Id of the node owning ``node``; ``None`` for the root level, whose
        nodes get no incoming edge.
    add_edge : callable, optional
        ``add_edge(parent_name, child_name)`` is called once per edge.
        Defaults to ``draw``, which adds the edge to the global pydot
        ``graph``.
    """
    if add_edge is None:
        add_edge = draw
    for k, v in node.items():
        # Root-level nodes have no parent (and must not be linked to a
        # "None" node). Compare with None explicitly so that a parent id
        # of 0 still produces its edge.
        if parent is not None:
            add_edge(parent, k)
        if isinstance(v, dict):
            visit(v, k, add_edge)
        else:
            # Draw the leaf's label through a distinct node name.
            add_edge(k, f"Stop({k})")

# Render the most recently sampled `tree` to a PNG. draw() (called by
# visit) adds edges to this module-level `graph`, so it must be created
# before visit runs.
graph = pydot.Dot(graph_type="graph")
visit(tree)
# NOTE(review): write_png presumably shells out to Graphviz — confirm it is
# installed where this notebook runs.
graph.write_png("example1_graph.png")

## Example of a prior tree with $\alpha=0.5$

![](example1_graph.png)

In [107]:
!open .

## References
* https://stackoverflow.com/questions/13688410/dictionary-object-to-decision-tree-in-pydot