# MNIST handwritten digits classification with decision trees 

In this notebook, we'll use [decision trees](http://scikit-learn.org/stable/modules/tree.html) and [ensembles of trees](http://scikit-learn.org/stable/modules/ensemble.html) to classify MNIST digits using scikit-learn and [XGBoost](https://xgboost.readthedocs.io/en/latest/).

First, the needed imports. 

In [None]:
%matplotlib inline

from pml_utils import get_mnist, show_failures

import numpy as np
from sklearn import __version__
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

import matplotlib.pyplot as plt
import seaborn as sns
import graphviz
sns.set()

from packaging.version import Version
assert Version(__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."

Then we load the MNIST data. The first time this cell is run, the data needs to be downloaded, which can take a while.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test', X_test.shape)
print('y_test', y_test.shape)

## Decision tree

A decision tree is a model that predicts the value of a target variable by learning simple *if-then-else* decision rules inferred from the data features.

### Learning

Let's start by training a decision tree with default parameter values for classifying MNIST digits.

In [None]:
%%time

clf_dt = DecisionTreeClassifier()
clf_dt.fit(X_train, y_train)

### Inference

Classifying a new sample with a decision tree is fast, as it consists of following a single path in the tree until a leaf node is found.
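
To make the "single path" idea concrete, here is a minimal sketch in plain Python. The nested-dict tree, its thresholds, and the helper `predict_one()` are made up for illustration and are not how scikit-learn stores trees internally. The point is only that the number of feature tests per sample is bounded by the depth of the tree, not by its total size.

```python
# Hypothetical toy tree: internal nodes test one feature against a threshold,
# leaves store the predicted class label.
toy_tree = {
    "feature": 0, "threshold": 0.5,
    "left": {"leaf": "0"},
    "right": {
        "feature": 1, "threshold": 0.3,
        "left": {"leaf": "1"},
        "right": {"leaf": "7"},
    },
}

def predict_one(node, x):
    """Follow one root-to-leaf path; return (label, number of tests made)."""
    n_tests = 0
    while "leaf" not in node:
        # Go left if the tested feature is at or below the threshold, else right.
        branch = "left" if x[node["feature"]] <= node["threshold"] else "right"
        node = node[branch]
        n_tests += 1
    return node["leaf"], n_tests

label, n_tests = predict_one(toy_tree, [0.9, 0.1])
print("predicted", label, "after", n_tests, "tests")
```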

In [None]:
%%time

pred_dt = clf_dt.predict(X_test)
print('Predicted', len(pred_dt), 'digits with accuracy:', accuracy_score(y_test, pred_dt))

### Visualization

Decision trees are simple to understand and visualize.  Large trees can, however, be rather hard to inspect. 

The code below draws the trained decision tree classifier.  The resulting figure is huge, so it is better to save it as a separate file (`mydt.pdf`) and use a separate PDF viewer instead of drawing the figure into this notebook.  

To obtain a smaller tree that is better suited for visualization, try adding the option `max_depth=3` to the `DecisionTreeClassifier()` call above. 

In [None]:
export_graphviz(clf_dt, out_file="mydt.dot")
with open("mydt.dot") as f:
    dot_graph = f.read()
a=graphviz.Source(dot_graph)
print('Wrote PDF file:', a.render('mydt', view=False))

## Random forest

A random forest is an ensemble (or a group; hence the name *forest*) of decision trees, obtained by introducing randomness into the tree generation. The prediction of the random forest is obtained by *averaging* the predictions of the individual trees.
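
As a rough illustration of the averaging step, the sketch below combines per-tree class probabilities for a single sample and picks the class with the highest mean. The probabilities, the class names, and the helper `forest_predict()` are hypothetical and serve only to show the arithmetic.

```python
# Hypothetical class probabilities from three trees for one sample.
classes = ["3", "5", "8"]
tree_probas = [
    [0.7, 0.2, 0.1],  # tree 1
    [0.2, 0.6, 0.2],  # tree 2
    [0.1, 0.7, 0.2],  # tree 3
]

def forest_predict(tree_probas, classes):
    """Average the per-tree probabilities and return the most probable class."""
    n_trees = len(tree_probas)
    # zip(*...) groups the probabilities class by class across the trees.
    avg = [sum(col) / n_trees for col in zip(*tree_probas)]
    best = max(range(len(classes)), key=lambda k: avg[k])
    return classes[best], avg

label, avg = forest_predict(tree_probas, classes)
print("averaged probabilities:", [round(p, 3) for p in avg])
print("forest prediction:", label)
```

Note that the first tree on its own would have voted for "3"; the ensemble overrules it because the other two trees agree on "5".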

Random forest is a solid workhorse that almost always produces serviceable results without much tuning.

### Learning

Random forest classifiers are quick to train, quite robust to hyperparameter values, and often work relatively well.

In [None]:
%%time

n_estimators = 10
clf_rf = RandomForestClassifier(n_estimators=n_estimators)
clf_rf.fit(X_train, y_train)

### Inference

In [None]:
%%time

pred_rf = clf_rf.predict(X_test)
print('Predicted', len(pred_rf), 'digits with accuracy:', accuracy_score(y_test, pred_rf))

#### Failure analysis

The random forest classifier worked quite well, so let's take a closer look.

Here are the first 10 test digits that the random forest model assigned to the wrong class:

In [None]:
show_failures(pred_rf, y_test, X_test)

We can use `show_failures()` to inspect failures in more detail. For example:

* show failures in which the true class was "5":

In [None]:
show_failures(pred_rf, y_test, X_test, trueclass='5')

* show failures in which the prediction was "0":

In [None]:
show_failures(pred_rf, y_test, X_test, predictedclass='0')

* show failures in which the true class was "0" and the prediction was "2":

In [None]:
show_failures(pred_rf, y_test, X_test, trueclass='0', predictedclass='2')

#### Confusion matrix, accuracy, precision, and recall

We can also compute the confusion matrix to see which digits get mixed the most, and look at classification accuracies separately for each class:

In [None]:
labels = [str(i) for i in range(10)]
print('Confusion matrix (rows: true classes; columns: predicted classes):')
print()
cm = confusion_matrix(y_test, pred_rf, labels=labels)
print(cm)
print()

# Diagonal entries are correct predictions; row sums are the true counts per class.
print('Classification accuracy for each class:')
print()
for i, acc in enumerate(cm.diagonal() / cm.sum(axis=1)):
    print("%d: %.4f" % (i, acc))

Precision and recall for each class:

In [None]:
print(classification_report(y_test, pred_rf, labels=labels))
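
To see how these numbers relate to the confusion matrix, here is a small sketch using a made-up 3-class matrix (the counts are invented, not MNIST results). With rows as true classes and columns as predicted classes, recall for a class is the diagonal entry divided by its row sum, which is the same quantity as the per-class accuracy printed earlier. Precision is the diagonal entry divided by its column sum.

```python
# Hypothetical confusion matrix: rows are true classes, columns are predictions.
cm = [
    [9, 1, 0],
    [2, 7, 1],
    [0, 2, 8],
]

def per_class_metrics(cm):
    """Return a list of (precision, recall) tuples, one per class."""
    n = len(cm)
    metrics = []
    for k in range(n):
        tp = cm[k][k]                              # correctly predicted as k
        row_sum = sum(cm[k])                       # samples whose true class is k
        col_sum = sum(cm[i][k] for i in range(n))  # samples predicted as k
        recall = tp / row_sum if row_sum else 0.0
        precision = tp / col_sum if col_sum else 0.0
        metrics.append((precision, recall))
    return metrics

for k, (precision, recall) in enumerate(per_class_metrics(cm)):
    print("class %d: precision %.4f, recall %.4f" % (k, precision, recall))
```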

## Gradient boosted trees (XGBoost)

Gradient boosting is another way of constructing ensembles of decision trees, this time using the *boosting* framework.  Let's use a popular separate package, [XGBoost](http://xgboost.readthedocs.io/en/latest/) (short for extreme gradient boosting), to train gradient boosted trees to classify MNIST digits.  

XGBoost has been used to obtain record-breaking results in many machine learning competitions, but it has quite a lot of hyperparameters that need to be carefully tuned to get the best performance.
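
To give some intuition for boosting, the sketch below implements a deliberately simplified variant: squared-error regression on one-dimensional toy data, where each round fits a depth-1 tree (a "stump") to the residuals of the ensemble built so far. This is an assumption-laden illustration of the general idea, not XGBoost's actual algorithm, which among other things adds regularization and handles multiclass classification. The data and the helpers `fit_stump()` and `boost()` are made up for this example.

```python
def fit_stump(xs, residuals):
    """Find the single threshold split that best fits the residuals."""
    best = None
    # Try every distinct x except the largest, so both sides stay non-empty.
    for t in sorted(set(xs))[:-1]:
        left = [r for x, r in zip(xs, residuals) if x <= t]
        right = [r for x, r in zip(xs, residuals) if x > t]
        lmean, rmean = sum(left) / len(left), sum(right) / len(right)
        err = (sum((r - lmean) ** 2 for r in left)
               + sum((r - rmean) ** 2 for r in right))
        if best is None or err < best[0]:
            best = (err, t, lmean, rmean)
    _, t, lmean, rmean = best
    return lambda x: lmean if x <= t else rmean

def boost(xs, ys, n_rounds=20, learning_rate=0.5):
    """Start from the mean, then repeatedly fit stumps to the current residuals."""
    base = sum(ys) / len(ys)
    stumps = []
    preds = [base] * len(xs)
    for _ in range(n_rounds):
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        # Each new stump only nudges the prediction, scaled by the learning rate.
        preds = [p + learning_rate * stump(x) for x, p in zip(xs, preds)]
    return lambda x: base + learning_rate * sum(s(x) for s in stumps)

# Hypothetical toy data.
xs = [0, 1, 2, 3, 4, 5]
ys = [1.0, 1.2, 3.1, 2.9, 5.0, 5.2]
model = boost(xs, ys)

mean_y = sum(ys) / len(ys)
mse_mean = sum((mean_y - y) ** 2 for y in ys) / len(ys)
mse_boost = sum((model(x) - y) ** 2 for x, y in zip(xs, ys)) / len(ys)
print("MSE of constant mean prediction:", round(mse_mean, 4))
print("MSE after boosting:", round(mse_boost, 4))
```

Unlike a random forest, where trees are built independently, each tree here depends on the errors of all the trees before it.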

### Learning

Training an XGBoost classifier takes a bit more time, so let's start by using only a subset of the training data. 

XGBoost needs integer labels, not the strings "0", "1", "2", etc. that we have.

In [None]:
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
le.fit(y_train)
le.classes_
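
For intuition, the mapping performed by the label encoder can be sketched in a few lines of plain Python: the distinct labels are sorted, and each label is replaced by its position in that sorted list. The helper names and the sample labels below are hypothetical and only mimic the behavior.

```python
def fit_classes(labels):
    """Collect the distinct labels in sorted order."""
    return sorted(set(labels))

def transform(labels, classes):
    """Replace each label by its index in the sorted class list."""
    index = {c: i for i, c in enumerate(classes)}
    return [index[label] for label in labels]

def inverse_transform(codes, classes):
    """Map integer codes back to the original labels."""
    return [classes[i] for i in codes]

# Hypothetical labels; note that not all ten digits are present here.
y = ["5", "0", "4", "1", "9", "0"]
classes = fit_classes(y)
codes = transform(y, classes)
print(classes)
print(codes)
print(inverse_transform(codes, classes))
```

The code is an index into the class list, not the digit itself: in this small sample "5" becomes 3. When all ten digit strings are present, as in the full MNIST labels, the two coincide.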

In [None]:
%%time

n_data = 10000
clf_xgb = XGBClassifier()
clf_xgb.fit(X_train[:n_data,:], le.transform(y_train[:n_data]))

### Inference

At least with only a subset of the training data and default hyperparameter values, XGBoost does not reach the performance of random forest.

In [None]:
%%time

pred_xgb = clf_xgb.predict(X_test)
pred_xgb = le.inverse_transform(pred_xgb)  # convert back to our string labels
print('Predicted', len(pred_xgb), 'digits with accuracy:', accuracy_score(y_test, pred_xgb))

You can also use `show_failures()` to inspect the failures, and calculate the confusion matrix and other metrics as was done with the random forest above.

## Model tuning

Study the documentation of the different decision tree models used in this notebook ([decision trees](http://scikit-learn.org/stable/modules/tree.html), [tree ensembles](http://scikit-learn.org/stable/modules/ensemble.html), [XGBoost](https://xgboost.readthedocs.io/en/latest/)), and experiment with different hyperparameter values.  

Report the highest classification accuracy you manage to obtain for each model type.  Also mark down the parameters you used, so others can try to reproduce your results. 
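
One simple way to organize such experiments is to enumerate a grid of hyperparameter combinations and record the score together with the parameters that produced it. The sketch below shows only the bookkeeping. `grid_search()` and `toy_score()` are hypothetical helpers, and `toy_score()` is an arbitrary stand-in that says nothing about MNIST; in the notebook it would be replaced by a function that fits, say, `RandomForestClassifier(**params)` and returns `accuracy_score` on held-out data.

```python
from itertools import product

def grid_search(param_grid, evaluate):
    """Evaluate every parameter combination; return (score, params) best-first."""
    names = sorted(param_grid)
    results = []
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        results.append((evaluate(params), params))
    results.sort(key=lambda r: r[0], reverse=True)
    return results

def toy_score(params):
    # Arbitrary stand-in so the sketch runs on its own; NOT a real accuracy.
    return -(abs(params["max_depth"] - 10) + abs(params["n_estimators"] - 50) / 10)

param_grid = {
    "max_depth": [5, 10, 20],
    "n_estimators": [10, 50, 100],
}
results = grid_search(param_grid, toy_score)
best_score, best_params = results[0]
print("best parameters:", best_params)
```

Keeping the full `results` list, rather than only the best entry, makes it easy to report the parameters alongside each score so that others can reproduce the runs.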