# MNIST handwritten digits classification with support vector machines 

In this notebook, we'll use [support vector machines (SVMs)](http://scikit-learn.org/stable/modules/svm.html#svm-classification) and related algorithms to classify MNIST digits using scikit-learn.

First, the needed imports. 

In [None]:
%matplotlib inline

from pml_utils import get_mnist, show_failures

import sklearn
from sklearn import svm
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from packaging.version import Version
assert Version(sklearn.__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."

Then we load the MNIST data. The first time this cell is run, it downloads the data, which can take a while.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test', X_test.shape)
print('y_test', y_test.shape)

## Linear SVM 

### Learning

Our first classifier is a linear SVM trained on a subset of the training data.  Let's use the [`LinearSVC`](http://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html#sklearn.svm.LinearSVC) class, as it is specialized for linear SVMs. `C` is the penalty parameter: the strength of the regularization is inversely proportional to `C`.  (The general `SVC` class has a `kernel='linear'` option that can also be used.  A third option is to use `SGDClassifier`.)

In [None]:
%%time

C = 1.0
clf_lsvm = svm.LinearSVC(C=C)
print(clf_lsvm.fit(X_train[:10000,:], y_train[:10000]))

Training a linear SVM is rather fast, so more of the training data could easily be used.
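The three linear options mentioned above (`LinearSVC`, `SVC` with a linear kernel, and `SGDClassifier`) can be compared side by side. The following is only a rough sketch: it assumes scikit-learn's small bundled 8x8 digits dataset as a stand-in for MNIST so that it runs quickly, and the split and parameter values are illustrative choices, not the ones used in this notebook.

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC, LinearSVC

# Assumption: the bundled 8x8 digits set stands in for MNIST.
X, y = load_digits(return_X_y=True)
X = X / 16.0  # pixel values in this dataset range from 0 to 16
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

linear_models = {
    "LinearSVC": LinearSVC(C=1.0, max_iter=5000, random_state=0),
    "SVC(kernel='linear')": SVC(kernel="linear", C=1.0),
    # loss='hinge' makes SGDClassifier optimize a linear SVM objective
    "SGDClassifier(loss='hinge')": SGDClassifier(loss="hinge", random_state=0),
}

linear_scores = {}
for name, model in linear_models.items():
    model.fit(X_tr, y_tr)
    linear_scores[name] = accuracy_score(y_te, model.predict(X_te))
    print(f"{name}: {linear_scores[name]:.4f}")
```

The three classes optimize closely related objectives with different solvers, so their accuracies are usually similar while their training times can differ considerably on larger datasets.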

Note also that the default multiclass strategy of `LinearSVC` is one-vs-rest.
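One way to see the one-vs-rest strategy is to look at the fitted model: there is one weight vector per class, and the predicted class is the one with the highest decision value. A minimal sketch, again assuming the small bundled digits dataset instead of MNIST:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import LinearSVC

# Assumption: the bundled 8x8 digits set (64 features) stands in for MNIST.
X, y = load_digits(return_X_y=True)
X = X / 16.0

ovr_clf = LinearSVC(C=1.0, max_iter=5000, random_state=0).fit(X, y)

# One row of weights per class: shape (n_classes, n_features)
print(ovr_clf.coef_.shape)

# One decision value per class for each sample
scores = ovr_clf.decision_function(X[:5])
print(scores.shape)

# Picking the class with the largest decision value reproduces predict()
manual_pred = ovr_clf.classes_[np.argmax(scores, axis=1)]
print(manual_pred, ovr_clf.predict(X[:5]))
```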

### Inference

As the decision boundaries are linear, prediction with linear SVMs is fast:

In [None]:
pred_lsvm = clf_lsvm.predict(X_test)
print('Predicted', len(pred_lsvm), 'digits with accuracy:', accuracy_score(y_test, pred_lsvm))

## Kernel SVM

In addition to linear classification, SVMs can be used for non-linear classification by implicitly mapping the input features into high-dimensional feature spaces.  This is sometimes called the *kernel trick*, as the implicit mapping is often computationally cheaper than explicitly operating in the high-dimensional space.
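To make the implicit mapping concrete, here is a small sketch for a toy case: the homogeneous 2nd degree polynomial kernel k(x, z) = (x · z)² on 2D inputs. The feature map `phi` below is written out by hand for this specific kernel and is purely illustrative; it is not something scikit-learn exposes.

```python
import numpy as np

def phi(v):
    """Explicit feature map for the kernel k(x, z) = (x . z)**2 on 2D inputs."""
    x1, x2 = v
    return np.array([x1 * x1, np.sqrt(2.0) * x1 * x2, x2 * x2])

x = np.array([1.0, 2.0])
z = np.array([3.0, -1.0])

explicit = phi(x) @ phi(z)  # inner product in the 3D feature space
implicit = (x @ z) ** 2     # kernel evaluated directly in the 2D input space

print(explicit, implicit)
```

Both values agree up to floating-point rounding, but the kernel never constructs the higher-dimensional vectors. For a 3rd degree polynomial on 784 pixel features, the explicit feature space would be far too large to work in directly.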

### Learning

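Let's train a *3rd degree polynomial kernel SVM*. We set `decision_function_shape='ovr'`, which gives a decision function with one column per class. Note that `SVC` always trains *one-vs-one* classifiers internally and only constructs the one-vs-rest-shaped decision values from them. A Gaussian kernel, that is `kernel='rbf'`, is another common choice.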

In [None]:
%%time

clf_ksvm = svm.SVC(decision_function_shape='ovr', kernel='poly', degree=3)
print(clf_ksvm.fit(X_train[:10000,:], y_train[:10000]))
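The effect of `decision_function_shape` can be checked directly. The sketch below assumes the small bundled digits dataset rather than MNIST, and the subset sizes are arbitrary. It compares the `'ovr'` and `'ovo'` settings: only the shape of the decision function differs, while the predictions are the same.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC

# Assumption: the bundled 8x8 digits set stands in for MNIST.
X, y = load_digits(return_X_y=True)
X = X / 16.0
X_fit, y_fit, X_new = X[:1000], y[:1000], X[1000:1100]

svc_ovr = SVC(kernel="poly", degree=3, decision_function_shape="ovr").fit(X_fit, y_fit)
svc_ovo = SVC(kernel="poly", degree=3, decision_function_shape="ovo").fit(X_fit, y_fit)

# 'ovr' gives one column per class, 'ovo' one column per pair of classes
shape_ovr = svc_ovr.decision_function(X_new).shape
shape_ovo = svc_ovo.decision_function(X_new).shape
print(shape_ovr, shape_ovo)

# The underlying one-vs-one model is the same, so predictions agree
same_predictions = np.array_equal(svc_ovr.predict(X_new), svc_ovo.predict(X_new))
print(same_predictions)
```

With 10 classes there are 10 * 9 / 2 = 45 pairwise classifiers, which is the number of columns in the `'ovo'` output.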

### Inference

Despite the kernel trick, prediction of new samples is noticeably slower than with the linear SVM, because each prediction has to evaluate the kernel against every support vector.  The classification accuracy, on the other hand, is improved.

In [None]:
%%time

pred_ksvm = clf_ksvm.predict(X_test)
print('Predicted', len(pred_ksvm), 'digits with accuracy:', accuracy_score(y_test, pred_ksvm))
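To see where the prediction cost of a kernel SVM comes from, the decision function can be recomputed by hand from the support vectors. This is a sketch for a binary problem (digits 3 vs. 8 from the small bundled digits dataset, an assumption made to keep it short); `gamma` is fixed explicitly so that the kernel can be reproduced with NumPy.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC

# Assumption: a binary subset of the bundled 8x8 digits set, not MNIST.
X, y = load_digits(return_X_y=True)
X = X / 16.0
mask = (y == 3) | (y == 8)
X_bin, y_bin = X[mask], y[mask]

gamma, coef0, degree = 0.1, 0.0, 3
bin_clf = SVC(kernel="poly", degree=degree, gamma=gamma, coef0=coef0).fit(X_bin, y_bin)

sv = bin_clf.support_vectors_
print("support vectors:", sv.shape[0], "of", X_bin.shape[0], "training samples")

X_new = X_bin[:20]
# Polynomial kernel between every new sample and every support vector
K = (gamma * X_new @ sv.T + coef0) ** degree
# Weighted sum over support vectors plus the intercept
manual = K @ bin_clf.dual_coef_.ravel() + bin_clf.intercept_[0]
library = bin_clf.decision_function(X_new)
print(np.allclose(manual, library))
```

The kernel matrix `K` has one column per support vector, so the work per prediction grows with the number of support vectors, whereas a linear model needs only a single dot product per class.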

#### Confusion matrix

We can compute the confusion matrix to see which digits get mixed up the most:

In [None]:
labels = [str(i) for i in range(10)]
print('Confusion matrix (rows: true classes; columns: predicted classes):')
print()
cm = confusion_matrix(y_test, pred_ksvm, labels=labels)
print(cm)
print()
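How to read the matrix is easiest to see on a tiny hand-made example. The labels and predictions below are invented for illustration only; the sketch also shows one possible way to locate the most frequent confusion by ignoring the diagonal.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Invented toy data, using string labels as in the cell above
y_true = ['0', '1', '1', '2', '2', '2']
y_hat = ['0', '1', '2', '2', '2', '1']
toy_labels = ['0', '1', '2']

# Rows are true classes, columns are predicted classes
toy_cm = confusion_matrix(y_true, y_hat, labels=toy_labels)
print(toy_cm)

# Zero out the correct predictions and find the largest remaining entry
off_diag = toy_cm.copy()
np.fill_diagonal(off_diag, 0)
true_idx, pred_idx = np.unravel_index(np.argmax(off_diag), off_diag.shape)
print(f"most frequent confusion: true {toy_labels[true_idx]} "
      f"predicted as {toy_labels[pred_idx]}")
```

If several confusions are tied, `np.argmax` reports the first one in row-major order.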

If we plot it as an image, the structure is easier to see.  The matrix looks quite good, as most images are on the diagonal, meaning they were classified correctly.

In [None]:
plt.matshow(cm, cmap=plt.cm.gray)
plt.show()

#### Accuracy, precision and recall

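Classification accuracy for each class, that is, the fraction of images of each true digit that were classified correctly (this is the same quantity as per-class recall):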

In [None]:
for i, j in enumerate(cm.diagonal() / cm.sum(axis=1)):
    print("%d: %.4f" % (i, j))
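The relation between the confusion matrix and the per-class metrics can be verified on a small example. The toy labels below are invented for illustration; the sketch checks that dividing the diagonal by the row sums gives recall and dividing it by the column sums gives precision.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

# Invented toy data
y_true = ['0', '1', '1', '2', '2', '2']
y_hat = ['0', '1', '2', '2', '2', '1']
toy_labels = ['0', '1', '2']
toy_cm = confusion_matrix(y_true, y_hat, labels=toy_labels)

# Row sums count the samples of each true class: diagonal / row sum = recall
recall_from_cm = toy_cm.diagonal() / toy_cm.sum(axis=1)
# Column sums count the samples predicted as each class: diagonal / column sum = precision
precision_from_cm = toy_cm.diagonal() / toy_cm.sum(axis=0)

recall_ref = recall_score(y_true, y_hat, labels=toy_labels, average=None)
precision_ref = precision_score(y_true, y_hat, labels=toy_labels, average=None)

print(recall_from_cm, recall_ref)
print(precision_from_cm, precision_ref)
```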

Precision and recall for each class:

In [None]:
print(classification_report(y_test, pred_ksvm, labels=labels))

#### Failure analysis

We can also do some failure analysis.  Let's check the first 10 wrongly predicted digits.

In [None]:
show_failures(pred_ksvm, y_test, X_test)
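The implementation of `show_failures` lives in `pml_utils` and is not shown here. As a rough idea of the selection step only, the following sketch uses a hypothetical helper `first_failures` (not part of `pml_utils`) that returns the indices of the first wrongly predicted samples; the toy arrays are invented.

```python
import numpy as np

def first_failures(predictions, truth, max_failures=10):
    """Hypothetical helper: indices of the first misclassified samples."""
    predictions = np.asarray(predictions)
    truth = np.asarray(truth)
    wrong = np.flatnonzero(predictions != truth)
    return wrong[:max_failures]

# Invented toy data
toy_truth = np.array(['3', '5', '8', '1', '9'])
toy_pred = np.array(['3', '6', '8', '7', '9'])

failed = first_failures(toy_pred, toy_truth)
for i in failed:
    print(f"sample {i}: true {toy_truth[i]}, predicted {toy_pred[i]}")
```

The returned indices could then be used to select the corresponding rows of `X_test` for plotting.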

## Model tuning

Study the scikit-learn documentation of the linear and kernel [SVMs](http://scikit-learn.org/stable/modules/svm.html#svm) and the available SVM classes ([`SVC`](http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html#sklearn.svm.SVC), [`NuSVC`](http://scikit-learn.org/stable/modules/generated/sklearn.svm.NuSVC.html#sklearn.svm.NuSVC) and [`LinearSVC`](http://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html#sklearn.svm.LinearSVC)). Experiment with different hyperparameter values.

Report the highest classification accuracy you manage to obtain. Also mark down the parameters you used, so others can try to reproduce your results.
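One possible way to organize such experiments is a cross-validated grid search. The sketch below is only a starting point: it assumes the small bundled digits dataset in place of MNIST so that it finishes quickly, and the parameter grid is an arbitrary example rather than a recommended one.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Assumption: the bundled 8x8 digits set stands in for MNIST.
X, y = load_digits(return_X_y=True)
X = X / 16.0
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

# Arbitrary example grid; extend it with gamma, degree, etc. as needed
param_grid = {"kernel": ["poly", "rbf"], "C": [1.0, 10.0]}

# 3-fold cross-validation on the training split only
search = GridSearchCV(SVC(), param_grid, cv=3)
search.fit(X_tr, y_tr)

print("best parameters:", search.best_params_)
print("cross-validated accuracy:", round(search.best_score_, 4))
print("held-out test accuracy:", round(search.score(X_te, y_te), 4))
```

Selecting hyperparameters by cross-validation on the training data keeps the test set untouched, so the final test accuracy remains a fair estimate. On the full MNIST data, kernel SVM training is much slower, so a smaller grid or a training subset may be needed.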
