k-Nearest Neighbors and Handwritten Digit Classification

Using k-NN to classify 8x8 pixel images of handwritten digits.  The k-NN classifier is part of scikit-learn:

[http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier)

In [None]:
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets, neighbors, preprocessing
import numpy as np

The dataset consists of 1,797 images, each 8 pixels by 8 pixels.  The "target" field has the label, telling us the true digit the image represents.

In [None]:
# The digits dataset
digits = datasets.load_digits()

In [None]:
digits.images.shape

In [None]:
digits.target

Here, we define a function that takes an image and the true label and plots it for us:

In [None]:
def plot_handwritten_digit(the_image, label): # in R: plot_handwritten_digit <- function(the_image, label)
    # show one 8x8 image in inverted grayscale, with its true label as the title
    plt.axis('off')
    plt.imshow(the_image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Label: %i' % label)

In [None]:
# this will show us the pixel values
image_num = 30
digits.images[image_num]

In [None]:
# and then we can plot them
plot_handwritten_digit(digits.images[image_num], digits.target[image_num])

Instead of each image being 8x8 pixels, we flatten it to just be a single row of 64 numbers:

In [None]:
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, 64))
labels = digits.target

In [None]:
data.shape

If we need to standardize the features (make them all have mean zero and standard deviation one), this is how we do it:

In [None]:
data_scaled = preprocessing.scale(data)
data_scaled

In [None]:
data.mean(axis=0)

In [None]:
data_scaled.mean(axis=0)
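
To make the standardization step concrete, here is a minimal sketch of the same computation done by hand with NumPy: subtract each column's mean and divide by its standard deviation. The names `col_mean`, `col_std`, `safe_std` and `manual_scaled` are introduced here for illustration only. The guard for zero-variance columns is an assumption about how to handle pixels that never change; it is meant to mirror what `preprocessing.scale` does, not to replace it.

```python
import numpy as np
from sklearn import datasets, preprocessing

digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), 64))

# per-pixel (per-column) mean and standard deviation
col_mean = data.mean(axis=0)
col_std = data.std(axis=0)

# a pixel that has the same value in every image has std 0;
# divide by 1 instead so we do not divide by zero
safe_std = np.where(col_std == 0, 1.0, col_std)

manual_scaled = (data - col_mean) / safe_std

# compare against scikit-learn's result
print(np.allclose(manual_scaled, preprocessing.scale(data)))
```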

Make a training set and a test set.  We'll use the nearest neighbors from the training set to classify each image from the test set.

In [None]:
n_train = int(0.9*n_samples)

X_train = data[:n_train]
y_train = labels[:n_train]
X_test = data[n_train:]
# re-shape this back so we can plot it again as an image
test_images = X_test.reshape((len(X_test), 8, 8))
y_test = labels[n_train:]

In [None]:
X_train.shape
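
The split above takes the first 90% of the images in file order. As an alternative that this notebook does not use, a shuffled and stratified split can be sketched with `train_test_split` from `sklearn.model_selection`. The `_s` suffix on the variable names and the choice of `random_state=0` are assumptions made for this example.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), 64))
labels = digits.target

# shuffle before splitting, and keep the digit proportions
# roughly equal in the training and test sets
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    data, labels, test_size=0.1, random_state=0, stratify=labels)

print(X_train_s.shape, X_test_s.shape)
```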

Scikit-learn classifiers generally share a standard programming interface.  You construct the classifier object:

In [None]:
knn = neighbors.KNeighborsClassifier(n_neighbors=5)

You fit it to your data:

In [None]:
knn.fit(X_train, y_train)

And you predict on new data:

In [None]:
pred_labels = knn.predict(X_test)
pred_labels
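
To show what the classifier is doing conceptually, here is a rough from-scratch sketch of k-NN prediction: measure the Euclidean distance from a new point to every training point, take the k closest, and let them vote. The function `knn_predict` and the toy arrays are hypothetical names for this example. This is not how scikit-learn implements it internally, and ties may be broken differently.

```python
import numpy as np

def knn_predict(X_train, y_train, X_new, k=5):
    """Classify each row of X_new by majority vote among its k nearest training rows.

    Assumes the labels in y_train are non-negative integers.
    """
    preds = []
    for x in X_new:
        # Euclidean distance from x to every training point
        dists = np.sqrt(((X_train - x) ** 2).sum(axis=1))
        # indices of the k smallest distances
        nearest = np.argsort(dists)[:k]
        # count the votes per label and keep the most common one
        votes = np.bincount(y_train[nearest])
        preds.append(votes.argmax())
    return np.array(preds)

# toy example: two well-separated clusters
X_toy = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                  [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
y_toy = np.array([0, 0, 0, 1, 1, 1])
toy_pred = knn_predict(X_toy, y_toy, np.array([[0.2, 0.2], [4.8, 4.9]]), k=3)
print(toy_pred)
```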

In [None]:
pred_probs = knn.predict_proba(X_test)
pred_probs
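
With the default uniform weights, the "probabilities" from `predict_proba` are simply the fraction of the 5 nearest neighbors that carry each label. The sketch below rebuilds them from `kneighbors` to illustrate this. The names `neighbor_idx`, `neighbor_labels` and `vote_fractions` are introduced here for illustration only.

```python
import numpy as np
from sklearn import datasets, neighbors

digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), 64))
labels = digits.target
n_train = int(0.9 * len(data))

knn = neighbors.KNeighborsClassifier(n_neighbors=5)
knn.fit(data[:n_train], labels[:n_train])

# indices (into the training set) of the 5 nearest neighbors of each test image
_, neighbor_idx = knn.kneighbors(data[n_train:])
neighbor_labels = labels[:n_train][neighbor_idx]  # shape (n_test, 5)

# for each class, the fraction of the 5 neighbors with that label
vote_fractions = np.stack(
    [(neighbor_labels == c).mean(axis=1) for c in knn.classes_], axis=1)

print(np.allclose(vote_fractions, knn.predict_proba(data[n_train:])))
```

So with k=5 every entry is a multiple of 0.2, and a row such as 0.8 for one digit and 0.2 for another means four neighbors voted one way and one voted the other.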

In [None]:
test_num = 41
plot_handwritten_digit(test_images[test_num], y_test[test_num])
print("true label is %s" % y_test[test_num])
print("predicted label is %s" % pred_labels[test_num])
print("predicted probabilities are %s" % pred_probs[test_num])

Let's find examples where the predicted label is wrong:

In [None]:
np.where(pred_labels != y_test)
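
Note that `np.where` returns a tuple of index arrays, so the indices themselves are its first element. A possible follow-up, sketched below, pulls those indices out, lists each mistake, and computes the overall test accuracy. The names `wrong_idx` and `accuracy` are assumptions for this example; the cell refits the classifier so that it can run on its own.

```python
import numpy as np
from sklearn import datasets, neighbors

digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), 64))
labels = digits.target
n_train = int(0.9 * len(data))
X_train, y_train = data[:n_train], labels[:n_train]
X_test, y_test = data[n_train:], labels[n_train:]

knn = neighbors.KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
pred_labels = knn.predict(X_test)

# np.where returns a tuple; element 0 holds the test indices that were misclassified
wrong_idx = np.where(pred_labels != y_test)[0]

# fraction of test images classified correctly
accuracy = np.mean(pred_labels == y_test)

for i in wrong_idx:
    print("test image %d: true label %d, predicted %d" % (i, y_test[i], pred_labels[i]))
print("accuracy: %.3f" % accuracy)
```

Any index printed here can be passed as `test_num` in the plotting cell above to look at the misclassified image.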