# Using Decision trees and Random forests to predict Hearts diseases

In [44]:
import pandas as pd
import numpy as np
from pprint import pprint, pformat
from anytree import Node, RenderTree #To construct trees.
import copy #Will use to deepcopy instances of the Node class.
from sklearn.model_selection import KFold #For the K-fold cross validation.
import logging
logging.basicConfig(level=logging.DEBUG, format='%(message)s')

We will build a simple logging function to make debugging easier.

In [45]:
global logger
logger = logging.getLogger()
def log(*args, separator=' ', logger=logger):
    '''
    Logs args into the console using logger, separating them by separator.
    A logger must be defined externally.
    '''
    str_args = [str(arg) for arg in args]
    string = separator.join(str_args)
    logger.debug(string)

In [46]:
logger.disabled = False
#Example:
a = np.eye(2)
b = 1
log(a,b)

[[1. 0.]
 [0. 1.]] 1


## Importing the data

In [47]:
heart_data = pd.read_csv('Heart.csv')
heart_data.describe()

Unnamed: 0.1,Unnamed: 0,Age,Sex,RestBP,Chol,Fbs,RestECG,MaxHR,ExAng,Oldpeak,Slope,Ca
count,303.0,303.0,303.0,303.0,303.0,303.0,303.0,303.0,303.0,303.0,303.0,299.0
mean,152.0,54.438944,0.679868,131.689769,246.693069,0.148515,0.990099,149.607261,0.326733,1.039604,1.60066,0.672241
std,87.612784,9.038662,0.467299,17.599748,51.776918,0.356198,0.994971,22.875003,0.469794,1.161075,0.616226,0.937438
min,1.0,29.0,0.0,94.0,126.0,0.0,0.0,71.0,0.0,0.0,1.0,0.0
25%,76.5,48.0,0.0,120.0,211.0,0.0,0.0,133.5,0.0,0.0,1.0,0.0
50%,152.0,56.0,1.0,130.0,241.0,0.0,1.0,153.0,0.0,0.8,2.0,0.0
75%,227.5,61.0,1.0,140.0,275.0,0.0,2.0,166.0,1.0,1.6,2.0,1.0
max,303.0,77.0,1.0,200.0,564.0,1.0,2.0,202.0,1.0,6.2,3.0,3.0


In [48]:
heart_data.head()

Unnamed: 0.1,Unnamed: 0,Age,Sex,ChestPain,RestBP,Chol,Fbs,RestECG,MaxHR,ExAng,Oldpeak,Slope,Ca,Thal,AHD
0,1,63,1,typical,145,233,1,2,150,0,2.3,3,0.0,fixed,No
1,2,67,1,asymptomatic,160,286,0,2,108,1,1.5,2,3.0,normal,Yes
2,3,67,1,asymptomatic,120,229,0,2,129,1,2.6,2,2.0,reversable,Yes
3,4,37,1,nonanginal,130,250,0,0,187,0,3.5,3,0.0,normal,No
4,5,41,0,nontypical,130,204,0,2,172,0,1.4,1,0.0,normal,No


## Decision trees

1. Make a tree constructor based on the algorithm from fig. 18.5 of [Norvig] and/or algorithm 8.1 of [Gareth].
2. Allow for both categorical and continuous variables: maybe ask the user to specify what variables are continuous and how many splits to perform on each step. Sklearn does not handle categorical variables super easily, so this can actually be useful. Decide on a stopping criterion, maybe even allow for different ones.
3. Understand why pruning works as [Hastie] says and using [Breiman], and then implement it as [Hastie] describes.
4. Apply to the data. Maybe even apply on only two variables first and plot the regions as [Gareth] does for the Hitters data.

### Regression trees [Gareth]

In [49]:
def binary_split(examples, predictor_to_split_idx, cutpoint):
    '''
    Performs a binary split on the examples array.
    ------------
    Parameters:
    examples is a (# examples, p+1) numpy array where p is the number of predictors. The last column contains the responses.
    Splits the predictor_to_split in two at cutpoint.
    ------------
    Returns a tuple with two (, p+1) numpy arrays of examples:
        1: examples whose predictor_to_split is < cutpoint.
        2: examples whose predictor_to_split is >= cutpoint.
    '''
    predictors_to_split = examples[:, predictor_to_split_idx] #Column vector with the value for the predictor to split for all the examples.
    mask = predictors_to_split < cutpoint #Boolean mask with True values where predictor to split < cutpoint.
    set1 = examples[mask, :]
    set2 = examples[np.logical_not(mask), :] #Selects the examples where predictor to split >= cutpoint.
    return set1, set2

Small test:

In [50]:
exs = np.array([[1,2,3],[1,5,6]])
j = 1 #index of the split variable
cutpoints = [1, 3, 5]
for s in cutpoints:
    set1, set2 = binary_split(exs, j, s)
    xj_set1 = set1[:, j] #All the xjs in set1. These should all be < s.
    xj_set2 = set2[:, j] #All the xjs in set2. These should all be >= s.
    print('xj values from in set1: {}, cutpoint = {}, xj values from in set1: {}'.format(xj_set1, s, xj_set2))

xj values from in set1: [], cutpoint = 1, xj values from in set1: [2 5]
xj values from in set1: [2], cutpoint = 3, xj values from in set1: [5]
xj values from in set1: [2], cutpoint = 5, xj values from in set1: [5]


In [51]:
exs[0,:-1]

array([1, 2])

It's working!

In [52]:
def optimal_binary_split(examples, predictors_to_split_indices, grid_step_number=10, debug_prints=False):
    '''
    Splits the examples array at the best cutpoint and using the best predictor.
    ------------
    Parameters:
    examples is a (# examples, p+1) numpy array where p is the number of predictors. The last column contains the responses.
    predictors_to_split_indices is a list containing the indices of the predictors we want to split. It can also be the string 'all', in which case we split all the predictors.
    ------------
    Returns:
    Tuple (j, s, smallest_cost):
        1) index of the predictor that was split.
        2) chosen cutpoint for the split.
        3) cost of the chosen (thus optimal) split.
    '''
    ###Set up logger
    global logger
    if debug_prints:
        logger.disabled = False
    else:
        logger.disabled = True #Must write this explicitly to not have problems when running inside other functions.
    ###
    p = len(examples[0, :-1]) #Number of predictors.
    smallest_cost = 100000
    best_split = (0, 0, smallest_cost) #placeholder for best_split.
    for j in predictors_to_split_indices:
        #Construct array of cutpoints:
        max_cutpoint =  max(examples[:, j]) 
        step = (max_cutpoint - min(examples[:,j]))/grid_step_number #grid step
        min_cutpoint =  min(examples[:, j]) + step  #We do not want to include min(examples[:,j]) itself, since that would lead to cases with no points 'on the left' because of how binary_split() was defined.
        if min_cutpoint == max_cutpoint: #Examples have the same jth predictor value, so no split can be made, so we abort the split.
            log('Skip j = {}'.format(j))
            continue #Goes to the top of this loop again.
        cutpoints = np.linspace(min_cutpoint, max_cutpoint, grid_step_number) 
        log('Cutpoints for j={}: {}, step: {}'.format(j, cutpoints, step))
        for s in cutpoints:
            set1, set2 = binary_split(examples, j, s)
            log('s,j: {}, {} \n set1 --- set2: {} --- {}'.format(s,j, set1, set2))
            y_1 = set1[:, p]; y_2 = set2[:, p] #Extract the responses.
            y1_estimate = np.average(y_1); y2_estimate = np.average(y_2) #Estimates will simply be the averages.
            cost = np.sum(np.square(y_1 - y1_estimate)) + np.sum(np.square(y_2 - y2_estimate))
            if cost < smallest_cost:
                log('NEW COST: {}'.format(cost))
                smallest_cost = cost
                best_split = (j, s, smallest_cost) #Store info about this iteration.
    return best_split

In [53]:
#Example.
exs = np.array([[1,4,7],[1,5,2],[1,2,8]])
print('Examples: \n {}'.format(exs))
p = len(exs[0, :-1]) #Number of predictors.
predictors_to_split_indices = range(p) #List with indices of predictors to split.
optimal_binary_split(exs, predictors_to_split_indices, debug_prints=True)

Skip j = 0
Cutpoints for j=1: [2.3 2.6 2.9 3.2 3.5 3.8 4.1 4.4 4.7 5. ], step: 0.3
s,j: 2.3, 1 
 set1 --- set2: [[1 2 8]] --- [[1 4 7]
 [1 5 2]]
NEW COST: 12.5
s,j: 2.5999999999999996, 1 
 set1 --- set2: [[1 2 8]] --- [[1 4 7]
 [1 5 2]]
s,j: 2.9, 1 
 set1 --- set2: [[1 2 8]] --- [[1 4 7]
 [1 5 2]]
s,j: 3.2, 1 
 set1 --- set2: [[1 2 8]] --- [[1 4 7]
 [1 5 2]]
s,j: 3.5, 1 
 set1 --- set2: [[1 2 8]] --- [[1 4 7]
 [1 5 2]]
s,j: 3.8, 1 
 set1 --- set2: [[1 2 8]] --- [[1 4 7]
 [1 5 2]]
s,j: 4.1, 1 
 set1 --- set2: [[1 4 7]
 [1 2 8]] --- [[1 5 2]]
NEW COST: 0.5
s,j: 4.4, 1 
 set1 --- set2: [[1 4 7]
 [1 2 8]] --- [[1 5 2]]
s,j: 4.7, 1 
 set1 --- set2: [[1 4 7]
 [1 2 8]] --- [[1 5 2]]
s,j: 5.0, 1 
 set1 --- set2: [[1 4 7]
 [1 2 8]] --- [[1 5 2]]


Examples: 
 [[1 4 7]
 [1 5 2]
 [1 2 8]]


(1, 4.1, 0.5)

Notice that the function behaves as expected, clumping together the first and third rows, which indeed have the closest response values 7 and 8.

In [54]:
def recursive_binary_split(examples, max_leaf_population, predictors_to_split_indices='all', grid_step_number=10, debug_prints=False, deep_debug_prints = False):
    '''
    Recursively splits the examples array, building a decision tree with binary splits and stopping when every leaf contains less than max_leaf_population examples.
    ------------
    Parameters:
    examples is a (# examples, p+1) numpy array where p is the number of predictors. The last column contains the responses.
    predictors_to_split_indices is a list containing the indices of the predictors we want to split. It can also be the string 'all', in which case we split all the predictors.
    ------------
    Returns:
    tree_root: the root node of the resulting tree.
    '''
    ###Set up logger
    global logger
    if debug_prints:
        logger.disabled = False
    else:
        logger.disabled = True #Must write this explicitly to not have problems when running inside other functions.
    logger_bool = logger.disabled
    ###
    p = len(examples[0, :-1]) #Number of predictors.
    if predictors_to_split_indices == 'all':
        predictors_to_split_indices = range(p) #List with indices of predictors to split.
    #Initialize variables.
    tree_root = Node(examples, id='root'); tree_root.region = [] #Initialize tree with the root node.
    top_leaf_population = len(examples[:,0])
    while top_leaf_population >= max_leaf_population:
        #Create list of region leaf nodes.
        regions = tree_root.leaves
        #Select best region to make split, and split parameters.
        splits_info = [((idx, region), optimal_binary_split(region.name, predictors_to_split_indices, grid_step_number=grid_step_number, debug_prints=deep_debug_prints)) for (idx,region) in enumerate(regions) if region.name.shape[0]>0] #The if statement ensures there's at least one example. region.name is the array of examples in the region node.
        logger.disabled = logger_bool #Restore logger.disable (may have been changed by optimal_binary_split()).
        #Find best split.
        costs = [info[1][2] for info in splits_info] #info[1] is tuple (j,s,smallest_cost)
        min_idx = np.argmin(costs)
        best_split = splits_info[min_idx] #Tuple ((idx, region), (j,s,cost))
        #We now make the actual split.
        idx_region, region = best_split[0]
        j = best_split[1][0]
        s = best_split[1][1]
        R1, R2 = binary_split(region.name, j, s)
        #Create new leaves, update tree and leaf pop:
        leaf1 = Node(R1, id='x_{} < {}'.format(j,s)); leaf1.region = region.region + [(j, s, True)] #Add new constraints to the parent's contraints.
        leaf2 = Node(R2, id='x_{} >= {}'.format(j,s)); leaf2.region = region.region + [(j, s, False)]
        region.children += (leaf1, leaf2,)
        top_leaf_population = max([len(leaf.name[:,0]) for leaf in tree_root.leaves])
        #Logs for debugging:
        log('--------------------')
        log('Two representations of the current tree:'); log('')
        log(RenderTree(tree_root).by_attr('id'));  log('')
        log(RenderTree(tree_root).by_attr('name'));  log('')
        log('Current max leaf population: {}'.format(top_leaf_population))
    return tree_root

In [55]:
#Example.
exs = np.array([[1,4,7],[1,5,2],[1,2,8]])
print('Examples: \n {}'.format(exs))
tree_root = recursive_binary_split(exs, 2, debug_prints=True)
print('\n The final tree: \n', RenderTree(tree_root).by_attr('name'))
print('\n The final tree (encoded) region of the first leaf: \n', tree_root.leaves[0].region)

--------------------
Two representations of the current tree:

root
├── x_1 < 4.1
└── x_1 >= 4.1

[[1 4 7]
 [1 5 2]
 [1 2 8]]
├── [[1 4 7]
│    [1 2 8]]
└── [[1 5 2]]

Current max leaf population: 2
--------------------
Two representations of the current tree:

root
├── x_1 < 4.1
│   ├── x_1 < 2.2
│   └── x_1 >= 2.2
└── x_1 >= 4.1

[[1 4 7]
 [1 5 2]
 [1 2 8]]
├── [[1 4 7]
│    [1 2 8]]
│   ├── [[1 2 8]]
│   └── [[1 4 7]]
└── [[1 5 2]]

Current max leaf population: 1


Examples: 
 [[1 4 7]
 [1 5 2]
 [1 2 8]]

 The final tree: 
 [[1 4 7]
 [1 5 2]
 [1 2 8]]
├── [[1 4 7]
│    [1 2 8]]
│   ├── [[1 2 8]]
│   └── [[1 4 7]]
└── [[1 5 2]]

 The final tree (encoded) region of the first leaf: 
 [(1, 4.1, True), (1, 2.2, True)]


#### Tree prunning [Hastie]

The tree $T_0$ created using recursive binary split will probably overfit. The strategy is to simplify the model by finding an appropriate subtree $T$ of $T_0$ by pruning $T$ using *cost-complexity-pruning*, which we'll briefly explain here.

The *cost complexity criterion* is defined by
$$
C_\alpha(T) = \sum_{m=1}^{\vert T \vert} \sum_{i\in I_m} (y_i - \hat{y}_{R_m})^2 + \alpha \vert T \vert
$$


where $\vert T \vert$ is the number of leaves of $T$, $I_m$ is the indexing set of the region $R_m$ (*i.e.* $R_m = \{x_i \in \text{examples}: i\in I_m\}$), $\hat{y}_{R_m}$ is the predicted response in $R_m$ (in our case $\hat{y}_{R_m}$ is just the mean $\mu_{R_m}$), and $\alpha\in \mathbb{R}^+$ controls the size of the tree.

We now want to find the subtree $T_\alpha\subseteq T_0$ that minimizes $C_\alpha(T)$. Notice that for $\alpha = 0$ the minimizing subtree is $T_0$ itself, thus justifying the notation.

It can be shown [Breiman] that for all $\alpha\in\mathbb{R}^+$ there is a unique smallest subtree $T_\alpha$ that minimizes the cost complexity criterion, and to find it one can use *weakest link pruning*: starting from the bottom (the leaves) of the tree $T_0$, undo the split which has less impact (decreases the least) on the $\sum_{m=1}^{\vert T \vert} \sum_{i\in I_m} (y_i - \hat{y}_{R_m})^2$ part of the cost complexity criterion, obtaining a subtree; keep doing this until you're left only with the root of the tree.
This gives us a sequence of subtrees of $T_0$, and it turns out [Breiman] that this sequence must contain $T_\alpha$.

This means that we can simply implement weakest link pruning to obtain a sequence of subtrees and find the one which minimizes $C_\alpha$. That subtree must be $T_\alpha$.

In [56]:
def cost_complexity_criterion(alpha, regions):
    '''
    Computes the cost complexity criterion using the average as the in-region prediction.
    ------------
    Parameters:
    alpha >= 0.
    regions is a list of 'regions', each 'region' being an array of examples.
    ------------
    Returns:
    cost: real positive number.
    '''
    cost = 0
    leaf_number = len(regions)
    for region in regions:
        ys = region[:,-1] #Extract responses.
        pred = np.mean(ys) #Decision trees usually predict using the in-region mean.
        sq_dev = np.sum(np.square( ys - pred ))
        cost += sq_dev + alpha*leaf_number
    return cost

In [57]:
#Let's test this on the example from before:
alphas = [0, 0.5, 1, 2, 10]
example_leaves = recursive_binary_split(exs, 2).leaves
leaves = [leaf.name for leaf in example_leaves]
print('leaves: ', leaves)
for alpha in alphas:
    print('alpha: {}, cost: {}'.format(alpha, cost_complexity_criterion(alpha, leaves)) )

# But what if we allow for 2 examples per region when splitting?:
example_leaves = recursive_binary_split(exs, 3).leaves
leaves = [leaf.name for leaf in example_leaves]
print('leaves: ', leaves)
for alpha in alphas:
    print('alpha: {}, cost: {}'.format(alpha, cost_complexity_criterion(alpha, leaves)) )

leaves:  [array([[1, 2, 8]]), array([[1, 4, 7]]), array([[1, 5, 2]])]
alpha: 0, cost: 0.0
alpha: 0.5, cost: 4.5
alpha: 1, cost: 9.0
alpha: 2, cost: 18.0
alpha: 10, cost: 90.0
leaves:  [array([[1, 4, 7],
       [1, 2, 8]]), array([[1, 5, 2]])]
alpha: 0, cost: 0.5
alpha: 0.5, cost: 2.5
alpha: 1, cost: 4.5
alpha: 2, cost: 8.5
alpha: 10, cost: 40.5


Notice that the results make sense: at $\alpha=0$ the cost is zero (so of course minimum) for the more complex tree (the first one). But for the other values of alpha one sees that the cost of the simpler tree (the second one) gives a lower cost! This is precisely the kind of behavior we wanted.

In order to define weakest link pruning, we will use the fact that each pair of siblings in the tree correspond exactly to a split:

In [58]:
#Using again the above example:
print('Tree: \n', RenderTree(tree_root).by_attr('name'), '\n\nDescendants of the root and their siblings:')
for descendant in tree_root.descendants:
    print('descendant:\n{}\n   sibling:{}'.format(descendant.name, [sib.name for sib in descendant.siblings]))

Tree: 
 [[1 4 7]
 [1 5 2]
 [1 2 8]]
├── [[1 4 7]
│    [1 2 8]]
│   ├── [[1 2 8]]
│   └── [[1 4 7]]
└── [[1 5 2]] 

Descendants of the root and their siblings:
descendant:
[[1 4 7]
 [1 2 8]]
   sibling:[array([[1, 5, 2]])]
descendant:
[[1 2 8]]
   sibling:[array([[1, 4, 7]])]
descendant:
[[1 4 7]]
   sibling:[array([[1, 2, 8]])]
descendant:
[[1 5 2]]
   sibling:[array([[1, 4, 7],
       [1, 2, 8]])]


So the algorithm of the weakest link pruning (acting on the tree $T$) will:
   1. prune away one pair of sibling leaves, creating a temporary subtree $T'$ of $T$.
   2. compute the complexity cost criterion of $T'$.
   3. repeat 1,2 for all pairs of sibling leaves.
   4. select the subtree $T'$ with the lowest cost, and do $T=T'$.
   5. repeat 1-4 until $T= \{\mathrm{root}\}$ -- or equivalently until the height of $T$ is 0.

In [59]:
def get_sibling_pairs(tree_root):
    '''
    '''
    leaves = tree_root.leaves
    sib_pairs = []
    for leaf in tree_root.leaves:
        if any(leaf in pair for pair in sib_pairs): #We don't want duplicates, so we move on to the next leaf if this one is already in a pair.
            continue
        sib_pairs.append((leaf, leaf.siblings[0])) #Each leaf will have a unique sibling.
    return sib_pairs

def prune_siblings(sibs):
    '''
    sibs is tuple.
    '''
    parent = sibs[0].parent
    parent.children = [] #This erases the siblings.
    return 
    
    

def wl_prun_step(tree_root, alpha): #Contains steps 1-4 of the algorithm.
    '''
    '''
    cost = 100000000
    sib_pairs = get_sibling_pairs(tree_root)
    for i in range(len(sib_pairs)): #Will go through all leaf sibling pairs. Must do this this way to allow for deep copies on every iteration.
        tree_temp = copy.deepcopy(tree_root) #Copy the tree for manipulation.
        #Select and prune pair:
        sib_pairs_temp = get_sibling_pairs(tree_temp)
        pair = sib_pairs_temp[i]
        prune_siblings(pair)
        #Compute cost of T':
        leaves_nodes = tree_temp.leaves
        leaves = [leaf.name for leaf in example_leaves] #We must feed arrays to the cost function, not nodes.
        cost_temp = cost_complexity_criterion(alpha, leaves)
        if cost_temp < cost:
            cost = cost_temp
            best_tree_temp = copy.deepcopy(tree_temp)
    tree_root = copy.deepcopy(best_tree_temp) #Step 4 of the algorithm.
    return tree_root
    
def weakest_link_pruning(tree_root, alpha):
    '''
    '''
    pruning_log = [( copy.deepcopy(tree_root), cost_complexity_criterion(alpha, [leaf.name for leaf in tree_root.leaves]) )]
    itr = 0
    print('iter: {}\n{}\n----------'.format(itr, RenderTree(tree_root).by_attr('id')))
    while tree_root.height > 0:
        itr += 1
        tree_root = wl_prun_step(tree_root, alpha)
        print('iter: {}\n{}\n----------'.format(itr, RenderTree(tree_root).by_attr('id')))
        pruning_log.append(( copy.deepcopy(tree_root), cost_complexity_criterion(alpha, [leaf.name for leaf in tree_root.leaves]) ))
    return pruning_log

In [60]:
#Test the first and second functions with the example from before:
tree_root = recursive_binary_split(exs, 2)
tree_og = copy.deepcopy(tree_root)
print(RenderTree(tree_root).by_attr('id'))
print('\n Pairs:\n', get_sibling_pairs(tree_root))
pair = get_sibling_pairs(tree_root)[0]
prune_siblings(pair)
print('\nDeleting the first pair: \n', RenderTree(tree_root).by_attr('id'))

root
├── x_1 < 4.1
│   ├── x_1 < 2.2
│   └── x_1 >= 2.2
└── x_1 >= 4.1

 Pairs:
 [(Node('/[[1 4 7]\n [1 5 2]\n [1 2 8]]/[[1 4 7]\n [1 2 8]]/[[1 2 8]]', id='x_1 < 2.2', region=[(1, 4.1, True), (1, 2.2, True)]), Node('/[[1 4 7]\n [1 5 2]\n [1 2 8]]/[[1 4 7]\n [1 2 8]]/[[1 4 7]]', id='x_1 >= 2.2', region=[(1, 4.1, True), (1, 2.2, False)])), (Node('/[[1 4 7]\n [1 5 2]\n [1 2 8]]/[[1 5 2]]', id='x_1 >= 4.1', region=[(1, 4.1, False)]), Node('/[[1 4 7]\n [1 5 2]\n [1 2 8]]/[[1 4 7]\n [1 2 8]]', id='x_1 < 4.1', region=[(1, 4.1, True)]))]

Deleting the first pair: 
 root
├── x_1 < 4.1
└── x_1 >= 4.1


In [61]:
#Test the pruning:
tree_root = recursive_binary_split(exs, 2)

alpha = 0.8
pruning_log = weakest_link_pruning(tree_root, alpha)

#See if the log has all the subtress that we expect:
print('----------- Log ----------')
for idx,(tree,cost) in enumerate(pruning_log):
    print('iter: {}, cost: {}\n{}\n'.format(idx, cost, RenderTree(tree).by_attr('id')))

iter: 0
root
├── x_1 < 4.1
│   ├── x_1 < 2.2
│   └── x_1 >= 2.2
└── x_1 >= 4.1
----------
iter: 1
root
├── x_1 < 4.1
└── x_1 >= 4.1
----------
iter: 2
root
----------
----------- Log ----------
iter: 0, cost: 7.200000000000001
root
├── x_1 < 4.1
│   ├── x_1 < 2.2
│   └── x_1 >= 2.2
└── x_1 >= 4.1

iter: 1, cost: 3.7
root
├── x_1 < 4.1
└── x_1 >= 4.1

iter: 2, cost: 21.466666666666665
root



Lastly, we want to select, from the log, the tree with the smallest cost.

In [62]:
def cost_complexity_pruning(tree_root, alpha):
    '''
    Returns: tuple (best_tree, min_cost)
    '''
    pruning_log = weakest_link_pruning(tree_root, alpha) #The entries are of type (tree_root, cost)
    best_tree, min_cost = min(pruning_log, key=lambda p: p[1]) #  p[1]=cost.
    return best_tree, min_cost

In [66]:
#Bck to the same example as before:
tree_root = recursive_binary_split(exs, 2)

alpha = 0.8
best_tree, min_cost = cost_complexity_pruning(tree_root, alpha)
print('cost: ', min_cost, '\nTree Talpha with alpha=0.8:\n', RenderTree(best_tree).by_attr('id'))

iter: 0
root
├── x_1 < 4.1
│   ├── x_1 < 2.2
│   └── x_1 >= 2.2
└── x_1 >= 4.1
----------
iter: 1
root
├── x_1 < 4.1
└── x_1 >= 4.1
----------
iter: 2
root
----------
cost:  3.7 
Tree Talpha with alpha=0.8:
 root
├── x_1 < 4.1
└── x_1 >= 4.1


We see that we obtained the correct $T_\alpha$.

## Selecting $\alpha$

In order to select a good value for alpha we use $K$-fold cross-validation.

In [33]:
# 3-folding in our example:
kf = KFold(n_splits=3)
for train_idx, test_idx in kf.split(exs):
    print(train_idx, test_idx)

[1 2] [0]
[0 2] [1]
[0 1] [2]


In [65]:
def find_region(tree_root, x):
    '''
    x is a one dimensional array.
    Returns: leaf (node) corresponding to the region where x is.
    '''
    leaves = tree_root.leaves
    for leaf in leaves:
        encoded_region = leaf.region  # [(j,s,smaller?),...,(j,s,smaller?)]
        #for j,s,smaller in encoded_region:
         #   if (x[j] < s) == smaller: # True if xj<s and smaller=True or if xj>=s and smaller=False. False otherwise.
        bools = [ (x[j] < s) == smaller for j,s,smaller in encoded_region ]
        if all(bools):
            return leaf
    return 'x not in the domain?!'

In [74]:
#Test find_region
print(find_region(best_tree, [1, 4.2]).id)
print(find_region(best_tree, [1, 4.1]).id)
print(find_region(best_tree, [1, 4.0]).id)

x_1 >= 4.1
x_1 >= 4.1
x_1 < 4.1


In [65]:
                
def prediction_from_tree(tree_root, xs):
    '''
    '''
    for x in xs:
        #Get the region:
        region_leaf = find_region(tree_root, x)
        region_examples = region_leaf
    pass 
    

def choose_alpha(examples, alphas, K, max_leaf_population=3,):
    '''
    alphas: list of values to test for alpha.
    '''
    kf = KFold(n_splits=3)
    for train_idx, test_idx in kf.split(exs): #Will have K iterations.
        train, test = examples[train_idx], examples[test_idx]
        tree_root = recursive_binary_split(train, max_leaf_population) #Construct the complex tree T.
        #Construct sequence of best subtrees, one for every alpha:
        tree_sequence = []
        for alpha in alphas:
            best_tree, cost = cost_complexity_pruning(tree_root, alpha)
            tree_sequence.append((best_tree, cost))
        #Compute the errors of these trees on the test set (still one for each alpha).
        predictions = []
        errors = []
    #Average the errors. One average for each alpha.
    pass
    

## References:

1. [Hastie]  
2. [Norvig]
3. [Gareth]
4. [Breiman]