# MNIST handwritten digits classification with nearest neighbors 

In this notebook, we'll use [nearest-neighbor classifiers](http://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors-classification) to classify MNIST digits using scikit-learn (version 0.20 or later required).

First, the needed imports. 

In [None]:
%matplotlib inline

from pml_utils import get_mnist, show_failures

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import neighbors, __version__
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from packaging.version import Version
assert Version(__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."

Then we load the MNIST data. The first time, the data needs to be downloaded, which can take a while.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

print('MNIST data loaded: train:', len(X_train), 'test:', len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test:', X_test.shape)
print('y_test:', y_test.shape)

The training data (`X_train`) is a matrix of size (60000, 784), i.e. it consists of 60000 digits, each expressed as a 784-dimensional vector (a 28x28 image flattened to 1D). `y_train` is a 60000-dimensional vector containing the correct class of each training digit as a string ("0", "1", ..., "9").
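To make the flattening concrete, here is a small NumPy sketch in which a random array stands in for the real images (the values are synthetic, not MNIST digits). Each 28x28 image becomes one 784-dimensional row in row-major order, and `reshape` converts between the two views.

```python
import numpy as np

# Synthetic stand-in for the image data: 5 "images" of 28x28 pixels (random values, not real digits)
rng = np.random.default_rng(0)
images = rng.random((5, 28, 28))

# Flatten each 28x28 image into a 784-dimensional row vector (row-major order)
flat = images.reshape(len(images), -1)
print(flat.shape)  # (5, 784)

# Pixel (row r, column c) of an image ends up at index r*28 + c of its vector
r, c = 3, 17
assert flat[0, r * 28 + c] == images[0, r, c]

# Reshaping a single row back gives the original 2D image, as in the plotting cell below
restored = flat[0].reshape(28, 28)
assert np.array_equal(restored, images[0])
```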

Let's take a closer look. Here are the first 10 training digits:

In [None]:
pltsize=1
plt.figure(figsize=(10*pltsize, pltsize))

for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(X_train[i,:].reshape(28, 28), cmap="gray")
    plt.title('Class: '+y_train[i])

## 1-NN classifier

### Initialization

First, let's create a 1-NN classifier. Note that nearest-neighbor classifiers have no internal (parameterized) model, so no learning is required. Instead, calling `fit()` simply stores the training samples in a suitable data structure.

In [None]:
%%time

n_neighbors = 1
clf_nn = neighbors.KNeighborsClassifier(n_neighbors)
clf_nn.fit(X_train, y_train)
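To illustrate why `fit()` is fast while prediction is slow, here is a minimal brute-force 1-NN sketch in NumPy. The class `SimpleOneNN` is hypothetical and is not how scikit-learn implements `KNeighborsClassifier`, which, depending on its `algorithm` parameter, may also build tree-based search structures. It only shows the basic idea: `fit()` stores the data, and `predict()` compares each query with every stored sample.

```python
import numpy as np

class SimpleOneNN:
    """Hypothetical brute-force 1-NN classifier, for illustration only."""

    def fit(self, X, y):
        # "Training" just stores the samples; nothing is learned
        self.X = np.asarray(X, dtype=float)
        self.y = np.asarray(y)
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        # Squared Euclidean distances between every query and every stored sample,
        # using ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
        d2 = (
            (X ** 2).sum(axis=1)[:, None]
            - 2 * X @ self.X.T
            + (self.X ** 2).sum(axis=1)[None, :]
        )
        # The label of the closest stored sample is the prediction
        return self.y[np.argmin(d2, axis=1)]


# Tiny usage example with 2D points and string labels (like the MNIST labels)
X_toy = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
y_toy = np.array(['0', '0', '1'])
clf = SimpleOneNN().fit(X_toy, y_toy)
print(clf.predict([[0.2, 0.1], [4.0, 6.0]]))  # ['0' '1']
```

The `predict()` step builds a distance matrix of size (number of queries) x (number of stored samples), which is where the cost lies.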

### Inference

Next, let's classify the first 200 test samples with it.

In [None]:
%%time

pred_nn = clf_nn.predict(X_test[:200,:])

We observe that the classifier is rather slow, and classifying the whole test set would take quite some time. What is the reason for this?
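One way to approach this question is to count the work involved. A brute-force 1-NN search compares every test digit with every stored training digit, so the amount of computation grows with the product of the number of test digits, the number of training digits and the number of pixels. The arithmetic below is only a rough estimate that ignores constant factors and any speedups from the search algorithm scikit-learn picks.

```python
# Rough count of the pixel-wise comparisons needed by brute-force 1-NN
n_train = 60000   # training digits stored by fit()
n_features = 784  # 28x28 pixels per digit

def pixel_comparisons(n_test, n_train=n_train, n_features=n_features):
    # Each test digit is compared with every training digit, pixel by pixel
    return n_test * n_train * n_features

print(f"{pixel_comparisons(200):.2e}")    # 200 test digits: 9.41e+09
print(f"{pixel_comparisons(10000):.2e}")  # 10000 test digits: 4.70e+11
print(pixel_comparisons(10000) // pixel_comparisons(200))  # 50
```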

The accuracy of the classifier:

In [None]:
print('Predicted', len(pred_nn), 'digits with accuracy:',
      accuracy_score(y_test[:len(pred_nn)], pred_nn))

## Faster 1-NN classifier

### Initialization

One way to make our 1-NN classifier faster is to use less training data. Here we keep only the first 1024 training samples:

In [None]:
%%time

n_neighbors = 1
n_data = 1024
clf_nn_fast = neighbors.KNeighborsClassifier(n_neighbors)
clf_nn_fast.fit(X_train[:n_data,:], y_train[:n_data])
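Taking the first `n_data` samples, as above, assumes that the training set is not ordered by class. A possible alternative (an assumption for illustration, not what this notebook does) is to draw a random subset that preserves the class proportions, using `train_test_split` with its `stratify` argument. Synthetic data stands in for `X_train` and `y_train` so the snippet runs on its own, and its labels are deliberately sorted by class to show what can go wrong with taking the first samples.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-ins: 6000 samples, 784 features, 10 string classes sorted by class
rng = np.random.default_rng(0)
X_all = rng.random((6000, 784))
y_all = np.repeat([str(i) for i in range(10)], 600)

n_subset = 1024
# Because y_all is sorted, the first n_subset samples cover only classes "0" and "1"
print(np.unique(y_all[:n_subset]))

# A random, class-stratified subset of n_subset samples; the remainder is discarded
X_small, _, y_small, _ = train_test_split(
    X_all, y_all, train_size=n_subset, stratify=y_all, random_state=0
)
print(X_small.shape)
# Each class now gets roughly n_subset / 10 samples
print(np.unique(y_small, return_counts=True))
```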

### Inference

Now we can use the classifier fitted on the reduced data to classify the whole test set in a reasonable amount of time.

In [None]:
%%time

pred_nn_fast = clf_nn_fast.predict(X_test)

However, the classification accuracy is now lower:

In [None]:
print('Predicted', len(pred_nn_fast), 'digits with accuracy:',
      accuracy_score(y_test, pred_nn_fast))

#### Confusion matrix

We can compute the confusion matrix to see which digits get mixed up most often:

In [None]:
labels = [str(i) for i in range(10)]
print('Confusion matrix (rows: true classes; columns: predicted classes):')
print()
cm = confusion_matrix(y_test, pred_nn_fast, labels=labels)
print(cm)
print()

Plotted as an image:

In [None]:
plt.matshow(cm, cmap=plt.cm.gray)
plt.xticks(range(10))
plt.yticks(range(10))
plt.grid(None)
plt.show()
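To check the row/column convention of `confusion_matrix` (rows are true classes and columns are predicted classes, both in the order given by `labels`), here is a toy example with made-up labels:

```python
from sklearn.metrics import confusion_matrix

# Toy true and predicted labels (strings, like the MNIST labels)
y_true = ['0', '0', '1', '1', '1', '2']
y_pred = ['0', '1', '1', '1', '2', '2']

cm_toy = confusion_matrix(y_true, y_pred, labels=['0', '1', '2'])
print(cm_toy)
# [[1 1 0]
#  [0 2 1]
#  [0 0 1]]
# Row i = true class i, column j = predicted class j:
# cm_toy[0, 1] == 1 means that one true "0" was predicted as "1"
```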

#### Accuracy, precision and recall

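Classification accuracy for each class, i.e. the fraction of test digits of each true class that were classified correctly (this is the same quantity as the per-class recall):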

In [None]:
for i, acc in enumerate(cm.diagonal() / cm.sum(axis=1)):
    print("%d: %.4f" % (i, acc))

Precision and recall for each class:

In [None]:
print(classification_report(y_test, pred_nn_fast, labels=labels))
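Precision and recall can also be read directly off the confusion matrix: recall divides the diagonal by the row sums (true classes), and precision divides it by the column sums (predicted classes). The sketch below uses made-up toy labels and checks the results against scikit-learn's `precision_score` and `recall_score`.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

toy_labels = ['0', '1', '2']
y_true = ['0', '0', '1', '1', '1', '2']
y_pred = ['0', '1', '1', '1', '2', '2']
cm_toy = confusion_matrix(y_true, y_pred, labels=toy_labels)

# Recall: correct predictions / number of samples of each TRUE class (row sums).
# This is the per-class accuracy computed above.
recall = cm_toy.diagonal() / cm_toy.sum(axis=1)
# Precision: correct predictions / number of samples PREDICTED as each class (column sums)
precision = cm_toy.diagonal() / cm_toy.sum(axis=0)

print(np.round(recall, 3))
print(np.round(precision, 3))

# Same values as scikit-learn's per-class scores
assert np.allclose(recall, recall_score(y_true, y_pred, labels=toy_labels, average=None))
assert np.allclose(precision, precision_score(y_true, y_pred, labels=toy_labels, average=None))
```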

#### Failure analysis

We can also inspect the results in more detail. Let's use the `show_failures()` helper function (defined in `pml_utils.py`) to show the wrongly classified test digits.

Its signature is:

```
show_failures(predictions, y_test, X_test, trueclass=None, predictedclass=None, maxtoshow=10)
```

where:
- `predictions`: a vector with the predicted class of each test set image
- `y_test`: the _correct_ classes of the test set images
- `X_test`: the test set images
- `trueclass`: if set, show only images whose correct (true) class is this class
- `predictedclass`: if set, show only images that were predicted as this class
- `maxtoshow`: the maximum number of images to show


In [None]:
show_failures(pred_nn_fast, y_test, X_test)
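The kind of filtering that `show_failures()` offers can be sketched with NumPy boolean masks. This is a hypothetical re-creation for illustration only: `failure_indices` is not part of `pml_utils.py`, and the actual helper may be implemented differently. It returns the indices of the matching misclassified samples instead of plotting them.

```python
import numpy as np

def failure_indices(predictions, y_test, trueclass=None, predictedclass=None, maxtoshow=10):
    """Hypothetical helper: indices of misclassified samples, optionally filtered."""
    predictions = np.asarray(predictions)
    y_test = np.asarray(y_test)
    mask = predictions != y_test                   # wrongly classified samples
    if trueclass is not None:
        mask &= y_test == trueclass                # ...whose true class matches
    if predictedclass is not None:
        mask &= predictions == predictedclass      # ...that were predicted as this class
    return np.flatnonzero(mask)[:maxtoshow]


y_true = np.array(['0', '5', '5', '0', '3'])
y_pred = np.array(['0', '3', '5', '2', '8'])
print(failure_indices(y_pred, y_true))                                     # [1 3 4]
print(failure_indices(y_pred, y_true, trueclass='5'))                      # [1]
print(failure_indices(y_pred, y_true, trueclass='0', predictedclass='2'))  # [3]
```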

The optional arguments of `show_failures()` let us narrow down the failures further. For example:

* show failures in which the true class was "5":

In [None]:
show_failures(pred_nn_fast, y_test, X_test, trueclass='5')

* show failures in which the prediction was "0":

In [None]:
show_failures(pred_nn_fast, y_test, X_test, predictedclass='0')

* show failures in which the true class was "0" and the prediction was "2":

In [None]:
show_failures(pred_nn_fast, y_test, X_test, trueclass='0', predictedclass='2')

Many of these mistakes look rather "easy" to a human, which suggests that there is room for improvement.

## Model tuning

Try to improve the accuracy of the nearest-neighbor classifier while keeping the runtime for classifying the whole test set reasonable. Things to try include using more than one neighbor (with or without distance weighting) or increasing the amount of training data. See the documentation for [KNeighborsClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn-neighbors-kneighborsclassifier).

See also the [scikit-learn user guide on nearest-neighbor classification](http://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors-classification) for more information.
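As a possible starting point, here is a sketch of a small search over `n_neighbors` and `weights`. To keep it self-contained and fast, it uses scikit-learn's bundled 8x8 `load_digits` dataset as a stand-in for MNIST, so the scores will not carry over directly. The settings are chosen on a validation split carved out of the training data, and the test split is used only once at the end, which avoids an optimistic estimate from tuning on the test set.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

# Stand-in data: scikit-learn's bundled 8x8 digits (not MNIST)
X, y = load_digits(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0, stratify=y)
# Hold out part of the training data for choosing the hyperparameters
X_fit, X_val, y_fit, y_val = train_test_split(
    X_tr, y_tr, test_size=0.25, random_state=0, stratify=y_tr
)

results = {}
for k in [1, 3, 5, 7]:
    for weights in ['uniform', 'distance']:
        clf = KNeighborsClassifier(n_neighbors=k, weights=weights)
        clf.fit(X_fit, y_fit)
        results[(k, weights)] = accuracy_score(y_val, clf.predict(X_val))
        print(f"k={k}, weights={weights}: validation accuracy {results[(k, weights)]:.4f}")

# Pick the setting with the best validation accuracy
best_k, best_weights = max(results, key=results.get)
# Refit the chosen setting on all training data and evaluate once on the test split
final = KNeighborsClassifier(n_neighbors=best_k, weights=best_weights).fit(X_tr, y_tr)
test_acc = accuracy_score(y_te, final.predict(X_te))
print("best:", best_k, best_weights, "test accuracy:", round(test_acc, 4))
```

The same loop can be applied to `X_train` and `y_train`, although with the full MNIST data each configuration takes much longer to evaluate.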