### Tree Data Structures

Decision tree construction requires that we think recursively.  Why?   As we expand each node, we are making the same decision over and over again - what is the next best split given the instances that I have at this point in the tree?   To construct the tree we recursively construct trees on the left and right branches.

In [1]:
# Recursive functions are functions that call themselves

def factorial(n):
    """Return n! computed recursively; any n <= 1 yields 1."""
    # Base case (n <= 1) ends the recursion; otherwise n! = n * (n-1)!
    return 1 if n <= 1 else n * factorial(n - 1)
    
# 20! -- as the last expression in the cell, its value is shown as the cell output
factorial(20)

2432902008176640000

In [2]:
# Binary search on a SORTED list
# Create a sorted list of numbers (in linear O(n) time)

import random as rnd

# Start at 1 and repeatedly add a random step of 1 or 2, so each element is
# strictly larger than its predecessor: the list is sorted by construction.
N = 1000
L = [1]
while len(L) < N:
    L.append(L[-1] + rnd.randrange(1, 3))

print(L)


[1, 2, 4, 5, 6, 7, 9, 11, 13, 14, 15, 17, 19, 21, 22, 23, 24, 26, 28, 29, 31, 32, 33, 35, 36, 38, 40, 41, 42, 44, 46, 48, 50, 52, 54, 55, 57, 59, 61, 62, 63, 65, 67, 69, 70, 71, 73, 75, 77, 78, 79, 80, 82, 84, 86, 88, 90, 92, 94, 95, 97, 99, 101, 103, 105, 107, 108, 109, 110, 112, 114, 116, 118, 120, 121, 123, 124, 125, 127, 129, 131, 132, 134, 136, 138, 139, 140, 142, 143, 144, 146, 148, 149, 150, 152, 153, 154, 155, 157, 158, 160, 161, 163, 164, 165, 167, 168, 169, 170, 172, 173, 175, 177, 178, 180, 181, 182, 184, 185, 187, 188, 190, 191, 192, 194, 196, 198, 200, 202, 203, 204, 206, 207, 208, 210, 212, 213, 215, 216, 217, 219, 221, 223, 225, 227, 228, 230, 232, 234, 236, 237, 238, 240, 242, 244, 245, 247, 249, 251, 253, 254, 255, 256, 257, 259, 261, 262, 264, 266, 268, 270, 271, 272, 274, 275, 277, 278, 280, 281, 283, 285, 286, 288, 290, 292, 293, 294, 296, 297, 298, 299, 301, 303, 305, 306, 308, 309, 311, 312, 314, 315, 316, 318, 319, 321, 323, 324, 325, 326, 327, 329, 330, 331, 332

In [3]:
# Binary search

def bsearch(L, x):
    """Search the sorted list L for x.

    Convenience entry point that starts the recursive binary search over
    the full index range of L and returns whatever binsearch returns
    (the index of x, or None when x is absent).
    """
    first, last = 0, len(L) - 1
    return binsearch(L, x, first, last)


def binsearch(L, x, lower, upper):
    """Recursively search the sorted list L for x within L[lower..upper].

    Prints the (lower, mid, upper) window examined at each step so the
    halving of the search range is visible.

    Returns the index of x, or None (after printing a message) when the
    window becomes empty without finding x.
    """
    # An empty window means x cannot be in the list.
    if lower > upper:
        print(x, "not found.")
        return None

    mid = (lower + upper) // 2
    pivot = L[mid]
    print(lower, mid, upper)

    if x == pivot:
        return mid

    # Discard the half that cannot contain x and recurse on the other half.
    if x > pivot:
        next_lower, next_upper = mid + 1, upper
    else:
        next_lower, next_upper = lower, mid - 1
    return binsearch(L, x, next_lower, next_upper)
    
    

In [4]:
# In the run captured below, 235 is absent: the window shrinks to nothing and None is returned
bsearch(L, 235)
# 234 is present; as the last expression, its index is shown as the cell output
bsearch(L, 234)

0 499 999
0 249 498
0 124 248
125 186 248
125 155 185
125 139 154
140 147 154
148 151 154
148 149 150
148 148 148
235 not found.
0 499 999
0 249 498
0 124 248
125 186 248
125 155 185
125 139 154
140 147 154
148 151 154
148 149 150
148 148 148


148

In [5]:
# We are already familiar with recursive data structures.
# They are called DICTIONARIES AND TUPLES!
# 
#
#                     5
#                    / \
#                   3   9
#                  /\    \
#                 1  4    11
#
from pprint import pprint

# Every node is a (key, left_subtree, right_subtree) tuple; None marks a
# missing branch. Build the two subtrees first, then hang them off the root.
left_subtree = (3, (1, None, None), (4, None, None))
right_subtree = (9, None, (11, None, None))
T = (5, left_subtree, right_subtree)

pprint(T, width=30)




(5,
 (3,
  (1, None, None),
  (4, None, None)),
 (9, None, (11, None, None)))


In [83]:
# Create a tree where each node is of the form (key, left_tree, right_tree)
# THIS IS NOT THE SIMPLEST APPROACH BUT IT WILL GET YOU INTO THE RIGHT
# FRAME OF MIND AND IS WORTH STUDYING

def insert(T, x):
    """Return a tree equal to T with the value x added.

    Nodes are (key, left_tree, right_tree) tuples. Tuples are immutable, so
    every node on the path from the root to the insertion point is rebuilt.
    Values less than or equal to a node's key go into its left subtree,
    larger values into its right subtree.
    """
    # Empty spot reached: x becomes a new node with no branches.
    if T is None:
        return (x, None, None)

    node_key, smaller, larger = T
    if x <= node_key:
        return (node_key, insert(smaller, x), larger)
    return (node_key, smaller, insert(larger, x))
            

def tree(L, T=None):
    """Create a tree from the values in the list L.

    The values are inserted into T (initially empty) in list order, one per
    recursive call; the finished tree is returned once L is exhausted.
    """
    if len(L) > 0:
        head, rest = L[0], L[1:]
        # Grow the tree by the first value, then recurse on the remainder.
        return tree(rest, insert(T, head))
    return T
    
def walk(T):
    """Print the keys of T with an in-order traversal (left, node, right)."""
    if T is None:
        return

    node_key, smaller, larger = T
    walk(smaller)
    print(node_key)
    walk(larger)

    
    

In [87]:
Lsmall = [5, 3, 1, 4, 9, 11]
Lsorted = sorted(Lsmall)  # sorted() returns a new list; Lsmall keeps its order

T = tree(Lsmall)  # values are inserted in list order, so 5 ends up at the root
pprint(T, width=30)
walk(T)  # in-order traversal: keys come out in ascending order

(5,
 (3,
  (1, None, None),
  (4, None, None)),
 (9, None, (11, None, None)))
1
3
4
5
9
11


In [88]:
# A much improved (SIMPLER) approach!
# More in line with the splitting we do to construct a decision tree
# where instances are split based on the best condition we can find.

def tree(L):
    """Build a binary tree of (key, left, right) tuples from the list L.

    The first value becomes the node's key; the remaining values are split
    into those <= key (left subtree) and those > key (right subtree), and
    each side is built recursively. An empty or None list yields None.
    """
    # Stopping condition: nothing left to split.
    if L is None or len(L) == 0:
        return None

    pivot = L[0]
    smaller, larger = [], []
    for value in L[1:]:
        if value <= pivot:
            smaller.append(value)  # Left split
        elif value > pivot:
            larger.append(value)   # Right split
    return (pivot, tree(smaller), tree(larger))

# Same values as the insert-based example; the output below shows the same tree
T = tree([5, 3, 1, 4, 9, 11])
pprint(T, width=30)
walk(T)

(5,
 (3,
  (1, None, None),
  (4, None, None)),
 (9, None, (11, None, None)))
1
3
4
5
9
11


In [89]:
# Notice what happens if your data is already sorted (bad):
# every remaining value is larger than the current key, so each left branch
# is None and the tree degenerates into a single right-leaning chain.
Tsorted = tree(Lsorted)
pprint(Tsorted, width=30)

(1,
 None,
 (3,
  None,
  (4,
   None,
   (5,
    None,
    (9,
     None,
     (11, None, None))))))


In [91]:
# Some helper functions
def key(T):
    """Return the key stored at the root of T, or None for an empty tree."""
    if T is None:
        return None
    return T[0]


def left(T):
    """Return the left subtree of T, or None for an empty tree."""
    if T is None:
        return None
    return T[1]
    

def right(T):
    """Return the right subtree of T, or None for an empty tree."""
    if T is None:
        return None
    return T[2]



# Searching a binary tree
# You'll need to do something like this when predicting classes in a decision tree...identify the right leaf
# And determine the majority class of that leaf....

def btree_search(T, x):
    """Return True if x is a key in the binary search tree T, else False."""
    # Ran off the bottom of the tree: x is not present.
    if T is None:
        return False

    root_key = key(T)
    if x == root_key:
        return True

    # Smaller values live in the left subtree, larger ones in the right.
    branch = left(T) if x < root_key else right(T)
    return btree_search(branch, x)

In [94]:
# Probe every integer from 0 to 11; only values stored as keys in T report True
for i in range(12):
    print(i, btree_search(T, i))

0 False
1 True
2 False
3 True
4 True
5 True
6 False
7 False
8 False
9 True
10 False
11 True


In [95]:
# What is the depth of the tree?
# The depth is the Maximum number of "steps" to any leaf
def depth(T):
    """Return the maximum depth of the tree.

    Depth counts the steps from the root down to the deepest leaf. An empty
    tree reports -1 so that a single node comes out at depth 0.
    """
    if T is None:
        return -1
    deepest_child = max(depth(left(T)), depth(right(T)))
    return deepest_child + 1
    

# Report the depth of the balanced tree, the degenerate tree and a single node.
for sample in (T, Tsorted, (5, None, None)):
    print(depth(sample), ":", sample)


2 : (5, (3, (1, None, None), (4, None, None)), (9, None, (11, None, None)))
5 : (1, None, (3, None, (4, None, (5, None, (9, None, (11, None, None))))))
0 : (5, None, None)


In [96]:
def nodes(T):
    """Return the number of nodes in the tree (internal or leaf)."""
    if T is None:
        return 0
    # Count this node plus everything below it on either side.
    in_left = nodes(left(T))
    in_right = nodes(right(T))
    return 1 + in_left + in_right
    
# Report the node count of the balanced tree, the degenerate tree and a single node.
for sample in (T, Tsorted, (5, None, None)):
    print(nodes(sample), ":", sample)


6 : (5, (3, (1, None, None), (4, None, None)), (9, None, (11, None, None)))
6 : (1, None, (3, None, (4, None, (5, None, (9, None, (11, None, None))))))
1 : (5, None, None)


In [97]:
def leaves(T):
    """Return the number of leaves (nodes with no branches) in the tree."""
    if T is None:
        return 0

    lhs, rhs = left(T), right(T)
    # A node with neither branch is itself a leaf.
    if lhs is None and rhs is None:
        return 1
    return leaves(lhs) + leaves(rhs)
    
# Report the leaf count of the balanced tree, the degenerate tree and a single node.
for sample in (T, Tsorted, (5, None, None)):
    print(leaves(sample), ":", sample)




3 : (5, (3, (1, None, None), (4, None, None)), (9, None, (11, None, None)))
1 : (1, None, (3, None, (4, None, (5, None, (9, None, (11, None, None))))))
1 : (5, None, None)


In [98]:
# Internal (non-leaf) nodes
def internal(T):
    """Return the number of internal (non-leaf) nodes in the tree."""
    # Every node is either a leaf or internal, so subtract the leaves.
    total = nodes(T)
    leaf_count = leaves(T)
    return total - leaf_count

# Report the internal-node count of the balanced tree, the degenerate tree and a single node.
for sample in (T, Tsorted, (5, None, None)):
    print(internal(sample), ":", sample)


3 : (5, (3, (1, None, None), (4, None, None)), (9, None, (11, None, None)))
5 : (1, None, (3, None, (4, None, (5, None, (9, None, (11, None, None))))))
0 : (5, None, None)


We are gradually building up the tools necessary to construct a decision tree.
Let's change the tree node representation.  Instead of (key, left, right),
let's carry along the values that reach that node.   This is akin to storing the
instances that reach each node of our decision tree. (We may want to have a stopping
condition based on the number of remaining instances or the maximum depth of the tree.)
These "hyperparameters" can help us limit the complexity of our decision tree and prevent overfitting.

In [99]:
def vals(T):
    """Return the values that reach the node T, or None for an empty tree.

    Expects the four-element node form (key, left, right, values), where the
    last element is the list of values that arrived at that node.
    """
    # The expression must be returned explicitly; without `return` the
    # function would evaluate it and then fall through, yielding None.
    return T[3] if T is not None else None


In [100]:
def tree(L):
    """Modified tree builder - store values at each node.

    Nodes are (key, left, right, values) tuples, where values is the list L
    that reached the node. The first value is the key; the rest are split
    into <= key (left) and > key (right). Empty or None input yields None.
    """
    # Stopping condition: nothing left to split.
    if L is None or len(L) == 0:
        return None

    pivot = L[0]
    smaller, larger = [], []
    for value in L[1:]:
        if value <= pivot:
            smaller.append(value)  # Left split
        elif value > pivot:
            larger.append(value)   # Right split
    # Same shape as before, plus the values that reached this node.
    return (pivot, tree(smaller), tree(larger), L)

In [101]:
T2 = tree(Lsmall)  # each node is now (key, left, right, values-that-reached-it)
pprint(T2)

(5,
 (3, (1, None, None, [1]), (4, None, None, [4]), [3, 1, 4]),
 (9, None, (11, None, None, [11]), [9, 11]),
 [5, 3, 1, 4, 9, 11])


Now we'll modify our tree builder with a new stopping condition: the minimum number of values.
If we don't have at least  *min_vals*  values (instances) we won't bother to split.  The default value
for this parameter will still be 1

In [153]:
def tree(L, min_vals=1):
    """Build a binary tree, but stop splitting small nodes.

    Nodes are (key, left, right, values) tuples. A node that receives fewer
    than min_vals values is not split: it becomes a leaf keyed by its first
    value and holding all the values that reached it. Empty or None input
    yields None.
    """
    if L is None or len(L) == 0:
        return None

    pivot = L[0]
    # Additional stopping condition: too few values to bother splitting.
    if len(L) < min_vals:
        return (pivot, None, None, L)

    smaller, larger = [], []
    for value in L[1:]:
        if value <= pivot:
            smaller.append(value)
        elif value > pivot:
            larger.append(value)
    return (pivot, tree(smaller, min_vals), tree(larger, min_vals), L)

In [159]:
# With min_vals=4 only the root (6 values) is split; its children receive
# 3 and 2 values, so both are kept as unsplit leaves.
T3 = tree(Lsmall, min_vals=4)
pprint(T3, width=30)

(5,
 (3, None, None, [3, 1, 4]),
 (9, None, None, [9, 11]),
 [5, 3, 1, 4, 9, 11])


In [144]:
# Display the input list: it still holds the original values in their original order
Lsmall

[5, 3, 1, 4, 9, 11]