# MNIST handwritten digits classification with an ensemble of classifiers 

In this notebook, we'll use a [classifier ensemble](https://scikit-learn.org/stable/modules/ensemble.html#voting-classifier) to classify MNIST digits using scikit-learn (version 0.20 or later is required).

First, the needed imports. 

In [None]:
%matplotlib inline

from pml_utils import get_mnist, show_failures

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import __version__
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.naive_bayes import BernoulliNB
from sklearn.ensemble import VotingClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from packaging.version import Version
assert Version(__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."

Then we load the MNIST data. The first time, the data has to be downloaded, which can take a while.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test', X_test.shape)
print('y_test', y_test.shape)

The training data (`X_train`) is a matrix of size (60000, 784), i.e. it consists of 60000 digits, each expressed as a 784-dimensional vector (a 28x28 image flattened to 1D). `y_train` is a 60000-dimensional vector containing the correct class ("0", "1", ..., "9") for each training digit.
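
To illustrate the flattening, here is a minimal sketch that turns a 784-dimensional vector back into a 28x28 image. The vector `flat_digit` is a randomly generated stand-in for one row of `X_train`, and the row-major pixel order is an assumption about how the data was flattened.

```python
import numpy as np

# Hypothetical stand-in for one row of X_train: 784 pixel values in [0, 255].
rng = np.random.default_rng(0)
flat_digit = rng.integers(0, 256, size=784)

# Undo the flattening: 784 values -> 28 rows of 28 pixels.
image = flat_digit.reshape(28, 28)

# Assuming row-major order, pixel (row, col) is element row * 28 + col of the vector.
row, col = 3, 5
assert image[row, col] == flat_digit[row * 28 + col]
print(image.shape)
```

With the real data, something like `X_train[0].reshape(28, 28)` could be passed to `plt.imshow` to view a digit.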

## Individual classifiers

Let's first define and train a set of different classifiers.

### SGDClassifier

In [None]:
%%time

clf_sgd = SGDClassifier()
print(clf_sgd.fit(X_train, y_train))
pred_sgd = clf_sgd.predict(X_test)
print('Predicted', len(pred_sgd), 'digits with accuracy:', accuracy_score(y_test, pred_sgd))

### Decision tree

In [None]:
%%time

clf_dt = DecisionTreeClassifier()
print(clf_dt.fit(X_train, y_train))
pred_dt = clf_dt.predict(X_test)
print('Predicted', len(pred_dt), 'digits with accuracy:', accuracy_score(y_test, pred_dt))

### Bernoulli naive Bayes

In [None]:
%%time

clf_bnb = BernoulliNB(binarize=128.)
print(clf_bnb.fit(X_train, y_train))
pred_bnb = clf_bnb.predict(X_test)
print('Predicted', len(pred_bnb), 'digits with accuracy:', accuracy_score(y_test, pred_bnb))

## Ensemble classifier

The goal of ensemble methods is to combine the predictions of several base classifiers to improve generalizability and robustness.

### Learning

We use [`VotingClassifier`](https://scikit-learn.org/stable/modules/ensemble.html#voting-classifier) to combine the results of the individual classifiers.
The default mode is majority (`"hard"`) voting, where each classifier gets one vote and the final prediction is the class that receives the most votes.
Another option is to average the predicted class probabilities (`"soft"` voting), which requires that every individual classifier is able to predict class probabilities. Note that `SGDClassifier` with its default hinge loss does not provide `predict_proba`, so soft voting would not work with the classifiers above as they are.

In [None]:
%%time

clf_vote = VotingClassifier(estimators=[('sgd', clf_sgd),
                                        ('dt', clf_dt),
                                        ('bnb', clf_bnb)],
                            voting='hard')
clf_vote.fit(X_train, y_train)
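
To make the difference between hard and soft voting concrete, here is a small hand-computed sketch. The probabilities are made up for illustration and do not come from the classifiers trained above; the code only mimics the two voting rules and is not how `VotingClassifier` is implemented internally.

```python
import numpy as np

# Made-up class probabilities from three classifiers for a single sample
# (rows: classifiers, columns: classes 0 and 1).
probas = np.array([
    [0.90, 0.10],   # classifier A: very confident in class 0
    [0.40, 0.60],   # classifier B: leans towards class 1
    [0.45, 0.55],   # classifier C: leans towards class 1
])

# Hard voting: each classifier votes for its most probable class,
# and the class with the most votes wins.
votes = probas.argmax(axis=1)
hard_prediction = np.bincount(votes, minlength=probas.shape[1]).argmax()

# Soft voting: average the probabilities over the classifiers,
# then pick the class with the highest mean probability.
mean_proba = probas.mean(axis=0)
soft_prediction = mean_proba.argmax()

print('votes:', votes, '-> hard voting predicts', hard_prediction)
print('mean probabilities:', mean_proba.round(3), '-> soft voting predicts', soft_prediction)
```

In this toy case the two modes disagree: two of the three classifiers vote for class 1, but the single confident classifier pulls the averaged probability towards class 0.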

### Inference

The classification accuracy of the ensemble classifier:

In [None]:
pred_vote = clf_vote.predict(X_test)
print('Predicted', len(pred_vote), 'digits with accuracy:', accuracy_score(y_test, pred_vote))

#### Confusion matrix

We can compute the confusion matrix to see which digits are confused with each other most often:

In [None]:
labels = [str(i) for i in range(10)]
print('Confusion matrix (rows: true classes; columns: predicted classes):')
print()
cm = confusion_matrix(y_test, pred_vote, labels=labels)
print(cm)
print()
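
Reading the largest errors off a 10x10 matrix by eye can be tedious. The following sketch lists the most frequent confusions programmatically; it uses a small made-up matrix `cm_toy` so that it runs on its own, and the same steps should apply to `cm` above.

```python
import numpy as np

# Small made-up confusion matrix for three classes (rows: true, columns: predicted).
cm_toy = np.array([
    [50,  2,  1],
    [ 4, 45,  6],
    [ 0,  9, 40],
])

# Ignore the correct predictions on the diagonal.
errors = cm_toy.copy()
np.fill_diagonal(errors, 0)

# Sort the cells by error count, largest first, and keep the top three.
order = np.argsort(errors, axis=None)[::-1][:3]
top_pairs = [tuple(int(v) for v in np.unravel_index(idx, errors.shape)) for idx in order]

for true_cls, pred_cls in top_pairs:
    print('true %d predicted as %d: %d times' % (true_cls, pred_cls, errors[true_cls, pred_cls]))
```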

#### Accuracy, precision and recall

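Classification accuracy for each class, i.e. the fraction of the digits of each class that were classified correctly (this equals the per-class recall):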

In [None]:
for i, acc in enumerate(cm.diagonal() / cm.sum(axis=1)):
    print("%d: %.4f" % (i, acc))

Precision and recall for each class:

In [None]:
print(classification_report(y_test, pred_vote, labels=labels))
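
The per-class figures can all be derived from the confusion matrix: recall divides the diagonal by the row sums, precision divides it by the column sums. The sketch below shows this on a made-up 2x2 matrix (`cm_toy` is purely illustrative), assuming the same convention as above, with rows as true classes and columns as predicted classes.

```python
import numpy as np

# Small made-up confusion matrix (rows: true classes, columns: predicted classes).
cm_toy = np.array([
    [5, 1],
    [2, 4],
])

correct = cm_toy.diagonal()
recall = correct / cm_toy.sum(axis=1)      # correct / number of true samples per class
precision = correct / cm_toy.sum(axis=0)   # correct / number of predictions per class
accuracy = correct.sum() / cm_toy.sum()    # overall fraction of correct predictions

for i in range(len(correct)):
    print('class %d: precision %.4f, recall %.4f' % (i, precision[i], recall[i]))
print('overall accuracy: %.4f' % accuracy)
```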

#### Failure analysis

We can also do some failure analysis. Let's check the first 10 wrongly predicted digits.

In [None]:
show_failures(pred_vote, y_test, X_test)

## Model tuning

Try adding various classifiers covered in this course to the ensemble and experiment with different setups.

Report the highest classification accuracy you manage to obtain. Also note down the parameters you used, so that others can try to reproduce your results.
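
As one possible starting point, here is a sketch of a soft-voting setup with fixed random seeds and the parameters printed next to the result. To keep it self-contained and quick to run, it uses the small 8x8 digits dataset bundled with scikit-learn rather than MNIST. The choice of estimators, `SEED`, `max_iter` and the `binarize` threshold are illustrative assumptions, not recommended settings.

```python
from sklearn.datasets import load_digits
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import BernoulliNB
from sklearn.tree import DecisionTreeClassifier

SEED = 42  # fixed seed so that the run can be repeated

# Small 8x8 digits dataset bundled with scikit-learn (pixel values 0..16), not MNIST.
X, y = load_digits(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=SEED)

# All three estimators implement predict_proba, which soft voting needs.
estimators = [
    ('lr', LogisticRegression(max_iter=2000)),
    ('dt', DecisionTreeClassifier(random_state=SEED)),
    ('bnb', BernoulliNB(binarize=8.0)),
]
clf_soft = VotingClassifier(estimators=estimators, voting='soft')
clf_soft.fit(X_tr, y_tr)

acc = accuracy_score(y_te, clf_soft.predict(X_te))
print('Soft-voting accuracy on the held-out digits: %.4f' % acc)

# Record the settings that were used alongside the result.
for name, est in estimators:
    print(name, est.get_params())
```

For MNIST itself, the `binarize` threshold would need to match the 0..255 pixel range, as in the `BernoulliNB(binarize=128.)` call earlier in the notebook.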
