# Bayesian CART

In [1]:
import jax
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Two toy pytrees (nested dicts) used to try out the jax tree utilities.
# A tree consisting of a single leaf stored under key 1.
base_tree = {1: 1}
# A root (key 0) holding two leaves, stored under keys 1 and 2.
tree = {
    0: {1: 1, 2: 2},
}

In [3]:
# List the leaves of the toy tree. The jax.tree_util spelling is used because
# the top-level jax.tree_leaves alias is deprecated/removed in newer JAX.
jax.tree_util.tree_leaves(tree)

[1, 2]

In [4]:
from numpy.random import rand, seed, randint

In [5]:
def partition(leaf_value, alpha):
    """Randomly decide whether a leaf is split into two new open leaves.

    Parameters
    ----------
    leaf_value
        Current value of the leaf. Anything that compares equal to ``False``
        marks a terminal leaf and is never split.
    alpha : float
        Split threshold: the leaf is split when ``rand()`` falls below it.

    Returns
    -------
    dict or bool
        ``{left_id: True, right_id: True}`` with two distinct integer ids
        drawn by ``randint(0, 1000)`` when the leaf is split, otherwise
        ``False``.
    """
    # `== False` (rather than `is False`) is kept on purpose: it preserves the
    # original treatment of non-bool leaves (e.g. the integer leaf 1 is open,
    # 0 is terminal). Terminal leaves return before any random number is drawn.
    if leaf_value == False:  # noqa: E712
        return False
    if rand() < alpha:
        left_id = randint(0, 1000)
        right_id = randint(0, 1000)
        # Two equal draws would collapse into a single dict key and yield a
        # node with one child; redraw the second id until it differs. When the
        # first two draws already differ, no extra random numbers are consumed.
        while right_id == left_id:
            right_id = randint(0, 1000)
        return {left_id: True, right_id: True}
    return False

In [6]:
# Split threshold handed to partition(): a leaf is split when rand() < alpha.
alpha = 0.6

In [7]:
seed(3141)
# First growth pass: offer the single leaf of base_tree to partition().
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
next_tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), base_tree)
print(next_tree)

{1: {74: True, 790: True}}


In [8]:
seed(3141)
# Second growth pass over the leaves produced so far.
# NOTE: this reads and rebinds next_tree, so re-running the cell grows the
# tree one more level instead of reproducing the same output.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
next_tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), next_tree)
print(next_tree)

{1: {74: {74: True, 790: True}, 790: {725: True, 217: True}}}


In [9]:
seed(3141)
# Third growth pass; leaves that partition() declines to split become False.
# NOTE: this reads and rebinds next_tree, so re-running the cell grows the
# tree one more level instead of reproducing the same output.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
next_tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), next_tree)
print(next_tree)

{1: {74: {74: {74: True, 790: True}, 790: {725: True, 217: True}}, 790: {217: False, 725: {546: True, 520: True}}}}


In [10]:
from functools import partial

In [11]:
def sample_tree(alpha, max_depth=None):
    """Sample a random binary tree by repeatedly splitting its open leaves.

    The tree is a nested dict keyed by random integer node ids. Leaf values
    carry the split state:

    True: You are free to split
    False: You cannot split any further

    Each pass hands every leaf to ``partition``; sampling stops once no
    ``True`` leaf is left.

    Parameters
    ----------
    alpha : float
        Forwarded unchanged to ``partition`` for every leaf.
    max_depth : int or None, optional
        Maximum number of splitting passes (levels below the root). When the
        limit is reached, every leaf that is still open is closed (set to
        ``False``). ``None`` (default) keeps the original unbounded loop,
        which does not terminate if ``partition`` keeps splitting, e.g. for
        ``alpha >= 1`` where ``rand() < alpha`` always holds.

    Returns
    -------
    dict
        ``{root_id: subtree}`` where every leaf is ``False``.
    """
    root_id = randint(low=0, high=1000)
    tree = {root_id: True}
    passes = 0
    # jax.tree_util.* is used instead of the deprecated top-level
    # jax.tree_map / jax.tree_leaves aliases.
    while any(jax.tree_util.tree_leaves(tree)):
        if max_depth is not None and passes >= max_depth:
            # Depth budget exhausted: close every leaf and stop growing.
            tree = jax.tree_util.tree_map(lambda leaf: False, tree)
            break
        tree = jax.tree_util.tree_map(lambda leaf: partition(leaf, alpha), tree)
        passes += 1
    return tree

In [12]:
seed(3141)
# NOTE(review): this rebinds the name `sample_tree` from the function to the
# dict it returns. The cell therefore cannot be re-run without first
# re-running the cell that defines the function, and the later cells read
# `sample_tree` as the sampled dict.
sample_tree = sample_tree(0.5)
sample_tree

{360: {507: {217: False,
   629: {44: False,
    690: {484: {775: {83: False, 680: False},
      994: {138: False,
       444: {29: {409: {793: {69: {529: {656: False, 913: False},
            672: {465: False, 611: False}},
           872: {725: False, 802: {228: False, 956: False}}},
          830: False}},
        204: {843: {101: {548: False, 603: False}, 911: False}, 954: False}}}},
     875: False}}},
  790: False}}

In [13]:
import pydot

In [14]:
# Here `sample_tree` is the sampled dict; it has a single top-level entry,
# the (root id, root subtree) pair.
sample_tree.items()

dict_items([(360, {507: {217: False, 629: {44: False, 690: {484: {775: {83: False, 680: False}, 994: {138: False, 444: {29: {409: {793: {69: {529: {656: False, 913: False}, 672: {465: False, 611: False}}, 872: {725: False, 802: {228: False, 956: False}}}, 830: False}}, 204: {843: {101: {548: False, 603: False}, 911: False}, 954: False}}}}, 875: False}}}, 790: False})])

In [31]:

def draw(parent_name, child_name):
    """Add an edge between ``parent_name`` and ``child_name`` to the global ``graph``.

    ``graph`` is looked up at call time, so it must be defined before the
    first call.
    """
    graph.add_edge(pydot.Edge(parent_name, child_name))

def visit(node, parent=None):
    """Walk a nested-dict tree depth-first and draw its edges via ``draw``.

    Parameters
    ----------
    node : dict
        Maps a node id to either a child dict (internal node) or a non-dict
        value (terminal leaf).
    parent : hashable or None, optional
        Id of the node that owns ``node``; ``None`` for the top-level call,
        where the entries have no parent to connect to.

    For every entry an edge ``parent -> id`` is drawn (unless there is no
    parent). Terminal leaves additionally get an edge to a ``"Stop_<id>"``
    marker node.

    NOTE: node ids are used directly as graph node names, so a tree that
    reuses an id in several places will have those nodes merged in the plot.
    """
    for node_id, subtree in node.items():
        # We start with the root node whose parent is None;
        # we don't want to graph the None node. The check is against None
        # explicitly because 0 is a valid node id and must keep its edges.
        if parent is not None:
            draw(parent, node_id)
        if isinstance(subtree, dict):
            visit(subtree, node_id)
        else:
            # drawing the label using a distinct name
            draw(node_id, f"Stop_{node_id}")

# Global graph object that draw() appends edges to; it must exist before
# visit() is called.
graph = pydot.Dot(graph_type="graph")
# Populate the graph from the sampled tree (here `sample_tree` is the dict).
visit(sample_tree)
# Write the rendered graph to a PNG at this relative path.
graph.write_png("example1_graph.png")

In [32]:
# NOTE(review): shell escape; `open` is presumably the macOS launcher, so this
# cell is platform-specific — confirm before running elsewhere.
!open example1_graph.png