# Decision Trees
![](figures/loan_tree.gif)

Binary splitting makes a decision tree extremely efficient: in a well-constructed tree, each question cuts the number of options roughly in half, quickly narrowing the choice even among a large number of classes.
The trick, of course, comes in deciding which questions to ask at each step.
In machine learning implementations of decision trees, the questions generally take the form of axis-aligned splits in the data: that is, each node in the tree splits the data into two groups using a cutoff value within one of the features.
Let's now look at an example of this.
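
Before turning to scikit-learn, the idea of an axis-aligned split can be sketched in plain Python. The snippet below is only an illustration under stated assumptions: the helper names `gini` and `best_split` and the toy data are made up for this sketch, and the exhaustive midpoint search is not scikit-learn's actual implementation. It tries every cutoff on a single feature and keeps the one with the lowest weighted Gini impurity.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a list of class labels: 1 - sum(p_k ** 2)."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split(values, labels):
    """Return (threshold, weighted_impurity) for the best cutoff on one feature."""
    pairs = sorted(zip(values, labels))
    best = (None, gini(labels))  # start from "no split at all"
    for i in range(1, len(pairs)):
        lo, hi = pairs[i - 1][0], pairs[i][0]
        if lo == hi:
            continue  # identical values cannot be separated by a cutoff
        threshold = (lo + hi) / 2  # midpoint between neighbouring values
        left = [lab for val, lab in pairs if val <= threshold]
        right = [lab for val, lab in pairs if val > threshold]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(pairs)
        if score < best[1]:
            best = (threshold, score)
    return best


# Made-up toy data (not taken from the iris file): one feature, two classes.
feature = [0.25, 0.5, 0.75, 1.25, 1.5, 1.75]
classes = ["a", "a", "a", "b", "b", "b"]

threshold, impurity = best_split(feature, classes)
print(threshold, impurity)  # prints: 1.0 0.0
```

A full tree repeats this search over every feature at every node and recurses into the two resulting groups.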

In [2]:
import pandas as pd
import altair as alt
from sklearn import tree
from sklearn.tree import export_text
from sklearn.tree import export_graphviz

In [3]:
iris = pd.read_csv("data/iris.csv")
iris.head()

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa


In [4]:
from sklearn.datasets import load_iris

In [5]:
X = iris[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]
y = iris['species']

In [6]:
dtree = tree.DecisionTreeClassifier()

In [7]:
dtree.fit(X, y)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')

In [8]:
y_fit = dtree.predict(X)

In [9]:
sum(y_fit == y) / len(y)

1.0

In [10]:
from sklearn.metrics import accuracy_score, confusion_matrix

accuracy_score(y, y_fit)  # scikit-learn metrics expect (y_true, y_pred)

In [11]:
confusion_matrix(y, y_fit)  # rows are true classes, columns are predictions

array([[50,  0,  0],
       [ 0, 50,  0],
       [ 0,  0, 50]])

## Shall we celebrate?

In [12]:
print(export_text(dtree, feature_names=list(X.columns)))

|--- petal_width <= 0.80
|   |--- class: setosa
|--- petal_width >  0.80
|   |--- petal_width <= 1.75
|   |   |--- petal_length <= 4.95
|   |   |   |--- petal_width <= 1.65
|   |   |   |   |--- class: versicolor
|   |   |   |--- petal_width >  1.65
|   |   |   |   |--- class: virginica
|   |   |--- petal_length >  4.95
|   |   |   |--- petal_width <= 1.55
|   |   |   |   |--- class: virginica
|   |   |   |--- petal_width >  1.55
|   |   |   |   |--- sepal_length <= 6.95
|   |   |   |   |   |--- class: versicolor
|   |   |   |   |--- sepal_length >  6.95
|   |   |   |   |   |--- class: virginica
|   |--- petal_width >  1.75
|   |   |--- petal_length <= 4.85
|   |   |   |--- sepal_width <= 3.10
|   |   |   |   |--- class: virginica
|   |   |   |--- sepal_width >  3.10
|   |   |   |   |--- class: versicolor
|   |   |--- petal_length >  4.85
|   |   |   |--- class: virginica



If the text rendering above does not work, export the tree to Graphviz format and render it online using [webgraphviz](http://www.webgraphviz.com/).

In [13]:
print(export_graphviz(dtree, feature_names=list(X.columns)))

digraph Tree {
node [shape=box] ;
0 [label="petal_width <= 0.8\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]"] ;
1 [label="gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="petal_width <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="petal_length <= 4.95\ngini = 0.168\nsamples = 54\nvalue = [0, 49, 5]"] ;
2 -> 3 ;
4 [label="petal_width <= 1.65\ngini = 0.041\nsamples = 48\nvalue = [0, 47, 1]"] ;
3 -> 4 ;
5 [label="gini = 0.0\nsamples = 47\nvalue = [0, 47, 0]"] ;
4 -> 5 ;
6 [label="gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]"] ;
4 -> 6 ;
7 [label="petal_width <= 1.55\ngini = 0.444\nsamples = 6\nvalue = [0, 2, 4]"] ;
3 -> 7 ;
8 [label="gini = 0.0\nsamples = 3\nvalue = [0, 0, 3]"] ;
7 -> 8 ;
9 [label="sepal_length <= 6.95\ngini = 0.444\nsamples = 3\nvalue = [0, 2, 1]"] ;
7 -> 9 ;
10 [label="gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]

## We should test the model on unseen data

In [14]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.5)
dtree = tree.DecisionTreeClassifier()
dtree.fit(X_train, y_train)
yfit_test = dtree.predict(X_test)
accuracy_score(y_test, yfit_test)

0.96

In [15]:
confusion_matrix(y_test, yfit_test)

array([[21,  0,  0],
       [ 0, 29,  1],
       [ 0,  2, 22]])
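
The accuracy of 0.96 can be read directly off this confusion matrix. The sketch below copies the matrix by hand and assumes scikit-learn's usual convention: rows are true classes, columns are predicted classes, and classes appear in sorted label order (setosa, versicolor, virginica).

```python
# Confusion matrix copied from the output above.
# Rows: true class, columns: predicted class.
cm = [
    [21, 0, 0],
    [0, 29, 1],
    [0, 2, 22],
]

total = sum(sum(row) for row in cm)
correct = sum(cm[i][i] for i in range(len(cm)))  # the diagonal
accuracy = correct / total

# Recall per class: correct predictions divided by the row total.
recall = [cm[i][i] / sum(cm[i]) for i in range(len(cm))]

print(total, correct, round(accuracy, 2))  # prints: 75 72 0.96
print([round(r, 3) for r in recall])       # prints: [1.0, 0.967, 0.917]
```

All three errors are confusions between versicolor and virginica; setosa is separated perfectly.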

In [16]:
print(export_text(dtree, feature_names=list(X.columns)))

|--- petal_length <= 2.35
|   |--- class: setosa
|--- petal_length >  2.35
|   |--- petal_length <= 5.05
|   |   |--- petal_width <= 1.75
|   |   |   |--- class: versicolor
|   |   |--- petal_width >  1.75
|   |   |   |--- sepal_width <= 3.10
|   |   |   |   |--- class: virginica
|   |   |   |--- sepal_width >  3.10
|   |   |   |   |--- class: versicolor
|   |--- petal_length >  5.05
|   |   |--- class: virginica
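
The 0.96 above comes from one particular 50/50 split, so it depends on `random_state`. A common way to get a steadier estimate is k-fold cross-validation. The following is a minimal sketch, not part of the original notebook: it assumes scikit-learn's bundled iris data (`load_iris`) as a stand-in for `data/iris.csv`, and the names `X_cv`, `y_cv`, and `model` are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Assumption: the bundled iris data stands in for data/iris.csv.
iris_bunch = load_iris()
X_cv, y_cv = iris_bunch.data, iris_bunch.target

# random_state fixes the tree's internal tie-breaking so reruns agree.
model = DecisionTreeClassifier(random_state=0)

# For a classifier, cv=5 means stratified 5-fold splits:
# fit on four folds, score on the fifth, and rotate.
scores = cross_val_score(model, X_cv, y_cv, cv=5)
print(scores.mean(), scores.std())
```

The spread across folds gives a rough sense of how much a single train/test split can vary.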



## Prediction vs explanations

In [17]:
dtree = tree.DecisionTreeClassifier(max_depth=2)
dtree.fit(X_train, y_train)
yfit_test = dtree.predict(X_test)
accuracy_score(y_test, yfit_test)

0.8933333333333333

In [18]:
print(export_text(dtree, feature_names=list(X.columns)))

|--- petal_width <= 0.75
|   |--- class: setosa
|--- petal_width >  0.75
|   |--- petal_length <= 5.05
|   |   |--- class: versicolor
|   |--- petal_length >  5.05
|   |   |--- class: virginica
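
A tree this shallow is small enough to rewrite by hand, which is what makes it useful as an explanation. The function below is a hand-written copy of the three rules printed above; the name `predict_species` is made up for this sketch.

```python
def predict_species(petal_length, petal_width):
    """Hand-written copy of the depth-2 tree printed above."""
    if petal_width <= 0.75:
        return "setosa"
    if petal_length <= 5.05:
        return "versicolor"
    return "virginica"


# First row of iris.head(): petal_length=1.4, petal_width=0.2
print(predict_species(1.4, 0.2))  # prints: setosa
```

The full-depth tree scored higher on the test set (0.96 versus 0.893), but its rules are much harder to state in a sentence.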



In [19]:
print(export_graphviz(dtree, feature_names=list(X.columns)))

digraph Tree {
node [shape=box] ;
0 [label="petal_width <= 0.75\ngini = 0.659\nsamples = 75\nvalue = [29, 20, 26]"] ;
1 [label="gini = 0.0\nsamples = 29\nvalue = [29, 0, 0]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="petal_length <= 5.05\ngini = 0.491\nsamples = 46\nvalue = [0, 20, 26]"] ;
0 -> 2 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
3 [label="gini = 0.165\nsamples = 22\nvalue = [0, 20, 2]"] ;
2 -> 3 ;
4 [label="gini = 0.0\nsamples = 24\nvalue = [0, 0, 24]"] ;
2 -> 4 ;
}


## Model parsimony

In [20]:
dtree.feature_importances_  

array([0.        , 0.        , 0.41421017, 0.58578983])
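
These importances can be reproduced from the node statistics in the `export_graphviz` output above. For the Gini criterion, a feature's importance is the total impurity decrease of the splits that use it, weighted by the share of samples reaching each node and normalized to sum to 1. The sketch below copies the class counts by hand; `gini`, `splits`, and the other names are illustrative.

```python
def gini(counts):
    """Gini impurity from per-class sample counts."""
    n = sum(counts)
    return 1.0 - sum((c / n) ** 2 for c in counts)


# Copied from the depth-2 tree's export_graphviz output:
# (split feature, counts at the node, counts in left child, counts in right child)
splits = [
    ("petal_width", [29, 20, 26], [29, 0, 0], [0, 20, 26]),
    ("petal_length", [0, 20, 26], [0, 20, 2], [0, 0, 24]),
]
n_total = 75  # samples at the root

raw = {}
for feature, parent, left, right in splits:
    n, n_left, n_right = sum(parent), sum(left), sum(right)
    # Weighted impurity decrease contributed by this split.
    decrease = (
        n * gini(parent) - n_left * gini(left) - n_right * gini(right)
    ) / n_total
    raw[feature] = raw.get(feature, 0.0) + decrease

norm = sum(raw.values())
importances = {name: value / norm for name, value in raw.items()}
print(importances)
```

The result is about 0.586 for `petal_width` and 0.414 for `petal_length`, matching the array above. The two sepal features are never used in a split, so their importance is 0.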

In [21]:
X = iris[['petal_length']]
y = iris['species']

In [22]:
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.5)
dtree = tree.DecisionTreeClassifier(max_depth=2)
dtree.fit(X_train, y_train)
yfit_test = dtree.predict(X_test)
accuracy_score(y_test, yfit_test)

0.8933333333333333

In [23]:
print(export_text(dtree, feature_names=list(X.columns)))

|--- petal_length <= 2.35
|   |--- class: setosa
|--- petal_length >  2.35
|   |--- petal_length <= 5.05
|   |   |--- class: versicolor
|   |--- petal_length >  5.05
|   |   |--- class: virginica



# Example: Titanic

In [40]:
data = pd.read_csv("data/titanic.csv").dropna()

In [41]:
data.head()

Unnamed: 0,survived,pclass,sex,age,sibsp,parch,fare,embarked,class,who,adult_male,deck,embark_town,alive,alone
1,1,1,female,38.0,1,0,71.2833,C,First,woman,False,C,Cherbourg,yes,False
3,1,1,female,35.0,1,0,53.1,S,First,woman,False,C,Southampton,yes,False
6,0,1,male,54.0,0,0,51.8625,S,First,man,True,E,Southampton,no,True
10,1,3,female,4.0,1,1,16.7,S,Third,child,False,G,Southampton,yes,False
11,1,1,female,58.0,0,0,26.55,S,First,woman,False,C,Southampton,yes,True


In [71]:
X = data[['pclass', 'age', 'adult_male', 'sibsp', 'fare', 'alone']]
y = data['survived']

In [81]:
# The held-out split is left disabled, so the accuracy below is measured on the training data.
#X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0, train_size=0.75)

dtree = tree.DecisionTreeClassifier(max_depth=3)
dtree.fit(X, y)
yfit = dtree.predict(X)
accuracy_score(y, yfit)

0.8186813186813187

In [82]:
print(export_text(dtree, feature_names=list(X.columns)))

|--- adult_male <= 0.50
|   |--- fare <= 10.48
|   |   |--- class: 0
|   |--- fare >  10.48
|   |   |--- fare <= 11.49
|   |   |   |--- class: 1
|   |   |--- fare >  11.49
|   |   |   |--- class: 1
|--- adult_male >  0.50
|   |--- age <= 43.00
|   |   |--- fare <= 7.85
|   |   |   |--- class: 0
|   |   |--- fare >  7.85
|   |   |   |--- class: 1
|   |--- age >  43.00
|   |   |--- age <= 47.50
|   |   |   |--- class: 0
|   |   |--- age >  47.50
|   |   |   |--- class: 0

