# Intro to decision trees

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib import cm
# Uncomment the following once you have installed graphviz (both the Python module and the package itself)
# https://www.graphviz.org/download/
# from graphviz import Digraph

<font color = 'green'> __Nonparametric model__: <font color = 'red'> **A model where the parameters (number, meaning) are not fixed before training. For example, in a curve fit, we specify the model function, coefficients, and what the coefficients mean before doing anything else. In a nonparametric model, we don't know the full structure of the model in advance.**
    
**This can offer a lot of flexibility, but can also increase the variance of our models and the risk of overfitting.**

<font color = 'green'> __Entropy:__ <font color = 'red'> **A measure of uncertainty applied to a probability distribution. Calculated using the formula**

<font color = 'black'>
$$ H(X) = -\sum_i p_i \log_2(p_i) $$

<font color = 'red'>
    
**where $p_i$ are the probabilities of the outcomes/values $x_i$ of $X$ (terms with $p_i = 0$ are taken to contribute 0). It's named for an analogous concept in statistical physics/thermodynamics, which measures the disorder in a collection of particles using a similar mathematical function. The concept originated in the theory of communication and message encoding developed by Claude Shannon and others at Bell Labs in the 1940s and later, which is why you sometimes hear it called Shannon entropy.**
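
As a quick sanity check on the formula, here is a minimal sketch that evaluates it for a few hand-picked distributions. The helper name `entropy_bits` is made up for this illustration and is separate from the data frame utilities defined later in the notebook.

```python
import numpy as np

def entropy_bits(probs):
    """Shannon entropy (in bits) of a discrete distribution given as a list of probabilities."""
    probs = np.asarray(probs, dtype=float)
    nonzero = probs[probs > 0]  # outcomes with probability 0 contribute nothing
    # sum of p * log2(1/p) is the same quantity as -sum of p * log2(p)
    return float(np.sum(nonzero * np.log2(1.0 / nonzero)))

fair_coin = entropy_bits([0.5, 0.5])    # maximal uncertainty for two outcomes
biased_coin = entropy_bits([0.9, 0.1])  # mostly predictable, so lower entropy
certain = entropy_bits([1.0, 0.0])      # no uncertainty at all

print(fair_coin, biased_coin, certain)
```

The fair coin gives exactly 1 bit, the biased coin roughly 0.47 bits, and the certain outcome 0 bits, which matches the reading of entropy as a measure of uncertainty.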

<font color = 'green'> __Conditional entropy:__ <font color = 'red'> **Like probabilities, we can update the calculation of entropy when we observe a related variable. If the observed variable is categorical, we can think of it as breaking the data into subsets. We calculate the entropy of the target variable within each subset normally, then take a weighted average according to the frequency of each subset.**
    
<font color = 'black'> 
$$ H(X | Y) = \sum_j P(\mbox{subset j}) H(\mbox{subset j}) $$

<font color = 'red'>     
**or, more formally (this is what you'll probably see if you look the definition up):**

<font color = 'black'> 
$$ H(X | Y) = \sum_j P(Y = y_j) H(X | Y = y_j) $$

<font color = 'green'> __Mutual information:__ <font color = 'red'> **If $X$ and $Y$ are two variables, the mutual information is**
    
<font color = 'black'>
$$I(X; Y) = H(X) - H(X | Y)$$
    
<font color = 'red'> 
    
**or the "gap" between the entropy of a variable and the conditional entropy after observing a related variable. Since it measures how much uncertainty is reduced by the observation, this is also sometimes called "information gain." Note that $I(X; Y) = I(Y; X)$ even though, in general, $H(X | Y) \neq H(Y | X)$.**
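
To make the symmetry claim concrete, the sketch below computes both conditional entropies and both versions of the mutual information from a small joint probability table. The table values are invented for illustration, and the helper `H` is a hypothetical name, not one of the notebook's utilities.

```python
import numpy as np

def H(probs):
    """Entropy in bits of a discrete distribution (zero-probability terms are skipped)."""
    probs = np.asarray(probs, dtype=float).ravel()
    nz = probs[probs > 0]
    return float(np.sum(nz * np.log2(1.0 / nz)))

# Made-up joint distribution P(X = x_i, Y = y_j): rows index X, columns index Y
joint = np.array([[0.30, 0.10],
                  [0.05, 0.25],
                  [0.10, 0.20]])

p_x = joint.sum(axis=1)  # marginal distribution of X
p_y = joint.sum(axis=0)  # marginal distribution of Y

# H(X | Y) = sum_j P(Y = y_j) * H(X | Y = y_j), and likewise with the roles swapped
H_x_given_y = sum(p_y[j] * H(joint[:, j] / p_y[j]) for j in range(joint.shape[1]))
H_y_given_x = sum(p_x[i] * H(joint[i, :] / p_x[i]) for i in range(joint.shape[0]))

I_xy = H(p_x) - H_x_given_y
I_yx = H(p_y) - H_y_given_x

print('H(X|Y) =', H_x_given_y, ' H(Y|X) =', H_y_given_x)
print('I(X;Y) =', I_xy, ' I(Y;X) =', I_yx)
```

For this table the two conditional entropies differ noticeably, while the two mutual information values agree up to floating-point rounding.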

<font color = 'green'> __Cost function / benefit function:__ <font color = 'red'> **A cost function (also called a loss function) is a measure of goodness of fit that we want to minimize -- i.e. something that measures the error in our model. Example: sum of squared residuals in curve fitting. A benefit function is the opposite: a measure of goodness of fit that we try to maximize. Here, mutual information is serving as a benefit function.**
    
**Note you can generally translate between cost / benefit functions by multiplying by -1; this is sometimes done for computational or mathematical convenience.**

In [None]:
# Utility function for estimating the entropy of a target variable from a data frame.

def entropy(df, target):
    '''
    Computes the entropy of the target variable in the data frame.
    Parameters:
        df: a DataFrame containing labeled data with a categorical target variable
        target: the name of the target variable (str)
    Returns:
        H, the entropy of the target variable (float)
    '''
    labels = {label for label in df[target]}
    probs = [len(df[df[target] == label]) / len(df) for label in labels]
    return -sum(p * np.log2(p) for p in probs if p != 0)

# Utility function for estimating conditional entropy from a data frame.

def cond_entropy(df, pred, target):
    '''
    Computes the conditional entropy of a target variable given a certain predictor.
    Parameters:
        df: a DataFrame containing labeled data with categorical predictors and a categorical target
        pred: the name of the attribute being used as a predictor (str)
        target: the name of the target variable (str)
    Returns:
        H, the conditional entropy (float)
    '''
    # These are set comprehensions. Like list comprehensions but they make sets.
    # We use sets so that there are no duplicates; we want to consider each category/label only once
    categories = {cat for cat in df[pred]}
    labels = {label for label in df[target]}
    # Make a list of subsets of the data frame, broken down by the predictor
    subsets = [df[df[pred] == cat] for cat in categories]
    H = 0
    for subset in subsets:
        if len(subset) > 0: # guard against division by zero on an empty subset
            # Calculate the label probabilities within the subset, put them in a list
            probs = [len(subset[subset[target] == label]) / len(subset) for label in labels]
            # Calculate the contribution of this subset to the conditional entropy and add it to H
            H += (len(subset) / len(df)) * (-sum(p * np.log2(p) for p in probs if p != 0))
    return H

In [None]:
# Data from the 60s on survival after surgery for breast cancer
surv = pd.read_csv('survival.csv')
surv

In this data frame, a `Survival` value of 1 means the patient survived 5 or more years; a value of 2 means the patient died within 5 years.

In [None]:
entropy(surv, 'Survival')

## Finding the best variable to split on

We want to formulate a decision tree as a sequence of yes/no questions about the predictors. If the predictors are categorical, especially binary categorical, it is simple to formulate the questions.

All of our predictors in this data set are numerical, though. So our questions should look like:
* is age > 35?
* is nodes = 0?
* etc.

At each step we split the data into two pieces, and eventually make an estimate based on that.

There are many such questions we can ask at any point. How do we pick one? We choose the question whose answer has the highest mutual information (information gain) with the target.

#### Exercise

Write a function called `mutual_information` that takes a data frame, a predictor variable, a cutoff value for that predictor variable, and a target variable. The function returns the mutual information of the target variable and the binary categorical variable that records whether `predictor > cutoff` or `predictor <= cutoff`.

**Hint:** create a temporary copy of the data frame, expand it by adding a new column representing the binary categorical variable, then use the utility functions `entropy` and `cond_entropy` defined above.

In [None]:
def mutual_information(df, predictor, cutoff, target):
    pass

In [None]:
plt.plot(sorted(set(surv['Age'])), 
         [mutual_information(surv, 'Age', x, 'Survival') for x in sorted(set(surv['Age']))])

This plot tells us that if we ask about age, the question should be "is `age > 40`?"

In [None]:
plt.plot(sorted(set(surv['Year'])), 
         [mutual_information(surv, 'Year', x, 'Survival') for x in sorted(set(surv['Year']))])

So the year has very little mutual information with the survival rate.

In [None]:
plt.plot(sorted(set(surv['Nodes'])), 
         [mutual_information(surv, 'Nodes', x, 'Survival') for x in sorted(set(surv['Nodes']))])

Now we can settle on something. The best first question to ask is: is `nodes > 4`? In fact, a 1983 study [https://www.ncbi.nlm.nih.gov/pubmed/6352003] groups patients into 0 nodes, 1-3 nodes, and 4+ nodes, so our approach suggests a similar grouping, which is encouraging.

Let's see what the survival rates are in these two subgroups.

In [None]:
surv_a = surv[surv['Nodes'] <= 4]
surv_b = surv[surv['Nodes'] > 4]

In [None]:
sum(surv_a['Survival'] == 1) / len(surv_a)

In [None]:
sum(surv_b['Survival'] == 1) / len(surv_b)

#### Exercise

Apply the same approach as above to split each of the data frames `surv_a` and `surv_b` into two. (Think about which variables you might want to use for a second split, then plot their mutual information with Survival. Choose a best variable and a cutoff.)

Now let's make our estimates.
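
One way to compute the estimate for each leaf is to label every row with the leaf it falls into and then take the fraction of survivors per leaf. The sketch below does this on a small made-up data frame (not the real `survival.csv`), using the cutoffs that appear in the tree drawn in the next section; the helper `leaf_label` and the leaf names are assumptions for this illustration.

```python
import pandas as pd

# Made-up toy rows (NOT the real survival data), with the same column names as `surv`
toy = pd.DataFrame({
    'Nodes':    [0, 1, 3, 2, 0, 10, 7, 5, 12, 6],
    'Age':      [40, 55, 62, 70, 35, 50, 61, 65, 70, 45],
    'Survival': [1, 1, 1, 2, 1, 2, 1, 2, 2, 1],
})

def leaf_label(row):
    """Assign a row to one of four leaves: first split on Nodes, then on Age."""
    if row['Nodes'] <= 4:
        return 'AA' if row['Age'] <= 59 else 'AB'
    return 'BA' if row['Age'] <= 63 else 'BB'

toy['Leaf'] = toy.apply(leaf_label, axis=1)

# The estimate in each leaf is the fraction of rows with Survival == 1
estimates = (toy['Survival'] == 1).groupby(toy['Leaf']).mean()
print(estimates)
```

On the real data, the same pattern applied to `surv` with your chosen cutoffs gives the prognosis for each leaf.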

## Visualizing the tree

In [None]:
# Uncomment this if you have graphviz working
'''
tree = Digraph()
tree.node('top', 'Start')
tree.node('A', 'Nodes <= 4')
tree.node('B', 'Nodes > 4')
tree.node('AA', 'Age <= 59')
tree.node('AB', 'Age > 59')
tree.node('BA', 'Age <= 63')
tree.node('BB', 'Age > 63')

tree.node('end_AA', 'Prognosis: 0.835')
tree.node('end_AB', 'Prognosis: 0.772')
tree.node('end_BA', 'Prognosis: 0.507')
tree.node('end_BB', 'Prognosis: 0.286')

tree.edge('top', 'A')
tree.edge('top', 'B')
tree.edge('A', 'AA')
tree.edge('A', 'AB')
tree.edge('B', 'BA')
tree.edge('B', 'BB')

tree.edge('AA', 'end_AA')
tree.edge('AB', 'end_AB')
tree.edge('BA', 'end_BA')
tree.edge('BB', 'end_BB')

tree
'''
pass

## A conceptual digression

From a more theoretical perspective, what a tree does is:
* partition the space of predictors into subsets
* fit a very simple model (a constant) on each subset

Let's visualize the prognosis as a function of age and nodes:

In [None]:
fig = plt.figure(figsize=(12,8))
# fig.gca(projection='3d') no longer works in recent Matplotlib versions; add_subplot does
ax = fig.add_subplot(projection='3d')

# Make data
Y = np.arange(25, 85, 0.25)
X = np.arange(0, 40, 0.25)
X, Y = np.meshgrid(X, Y)

# Quick and dirty prediction function
def predict(nodes, age):
    if nodes <= 4:
        if age <= 59:
            return 0.835
        else:
            return 0.772
    else:
        if age <= 63:
            return 0.507
        else:
            return 0.286

Z = np.array([predict(x, y) for [x,y] in np.c_[X.ravel(), Y.ravel()]])
Z = Z.reshape(Y.shape)

# Plot the surface.
surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                       linewidth=0.1, edgecolor = 'black', antialiased=True)

ax.set_xlabel('Nodes')
ax.set_ylabel('Age')
ax.set_zlabel('Prognosis')
# Customize the z axis.
ax.set_zlim(0, 1)

plt.show()

Notice the shape of the flat regions. Because of the way we structure our tree, these areas are always rectangles -- i.e., decision boundaries are always parallel to a coordinate axis. This is a limitation of these trees. In theory, any shape of decision boundary could be approximated by these rectangles, but that would mean a very complex tree and a high risk of overfitting.
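
To get a feel for how expensive that approximation is, here is a rough sketch under simple assumptions: the true boundary is the diagonal `y > x` on the unit square, and we approximate it with a "staircase" of `k` axis-aligned steps, which a tree could represent with about `2k` leaves. The function `staircase_predict` is a hypothetical stand-in for such a tree, not a fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)
pts = rng.uniform(0, 1, size=(20000, 2))  # random points in the unit square
truth = pts[:, 1] > pts[:, 0]             # true class: above the diagonal

def staircase_predict(pts, k):
    """Axis-aligned approximation: cut x into k equal bins, one y-threshold per bin."""
    bins = np.minimum((pts[:, 0] * k).astype(int), k - 1)
    thresholds = (bins + 0.5) / k  # bin midpoint, where the diagonal crosses on average
    return pts[:, 1] > thresholds

ks = [1, 2, 4, 8, 16]
errors = [float(np.mean(staircase_predict(pts, k) != truth)) for k in ks]

for k, err in zip(ks, errors):
    print(f'{k:2d} steps (about {2 * k} leaves): error rate {err:.3f}')
```

The error shrinks only in proportion to `1/k`, so each halving of the error roughly doubles the number of leaves needed for a boundary that a single diagonal line would capture exactly.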

## Ok, but how good is this really?

We committed a sin by using our entire data set to build the model instead of holding out some data for validation. Next time, we repent our crimes. We'll discover some of the weaknesses of decision trees and how they may be fixed in practice.