# MNIST handwritten digits visualization with scikit-learn

In this notebook, we'll apply several popular embedding techniques to visualize MNIST digits in two dimensions.  This notebook is based on the scikit-learn [embedding examples](http://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html).

First, the needed imports.

In [None]:
%matplotlib inline

from time import time

from pml_utils import get_mnist

import numpy as np
import sklearn
from sklearn import random_projection, decomposition, manifold, __version__

import matplotlib.pyplot as plt

from packaging.version import Version
assert Version(__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."

Then we load the MNIST data. The first time this cell runs, it downloads the data, which can take a while.

In this notebook, we use only the first 1024 samples of the training data.  This reduces the time needed to compute the visualizations and makes them less crowded.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

# Let's inspect only the first 1024 training samples in this notebook
X = X_train[:1024]
y = y_train[:1024]
print()
print('MNIST data loaded:')
print('X:', X.shape)
print('y:', y.shape)

Let's start by inspecting our data.  For such a small dataset, we can actually draw all the samples at once:

In [None]:
n_img_per_row = 32 # 32*32=1024
img = np.zeros((28 * n_img_per_row, 28 * n_img_per_row))

for i in range(n_img_per_row):
    ix = 28 * i
    for j in range(n_img_per_row):    
        iy = 28 * j
        img[ix:ix + 28, iy:iy + 28] = X[i * n_img_per_row + j,:].reshape(28,28)
img = np.max(img) - img  # invert the pixel intensities

plt.figure(figsize=(9, 9))
plt.imshow(img, cmap='gray')
plt.title('First 1024 MNIST digits')
ax=plt.axis('off')
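
The same mosaic can also be assembled without explicit loops. The following is a minimal sketch, assuming the samples are stored one flattened 28x28 image per row, as `X` is above; `make_mosaic` and `fake_X` are hypothetical names used only for this illustration.

```python
import numpy as np

def make_mosaic(samples, n_per_row, side=28):
    """Tile n_per_row**2 flattened side x side images into one square image."""
    grid = samples[:n_per_row * n_per_row].reshape(n_per_row, n_per_row, side, side)
    # Axes are (grid row, grid col, pixel row, pixel col). Swapping the two
    # middle axes puts pixel rows next to grid rows so they can be merged.
    return grid.transpose(0, 2, 1, 3).reshape(n_per_row * side, n_per_row * side)

# Stand-in data with the same layout as X (random values, not real digits).
rng = np.random.default_rng(0)
fake_X = rng.random((1024, 784))
mosaic = make_mosaic(fake_X, 32)
```

The tile in grid row `i` and grid column `j` holds sample `i * n_per_row + j`, which matches the indexing of the nested loops.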

Let's define a helper function to plot the different visualizations:

In [None]:
def plot_embedding(X, title=None, time=None, show_digits=True):
    # Plot a 2-D embedding X, colored by the global labels y.
    # `time` is the elapsed time in seconds; inside this helper it
    # shadows the imported time() function.
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)  # rescale each axis to [0, 1]

    plt.figure(figsize=(9,6))
    plt.axis('off')
    if show_digits:
        # Draw each sample as its digit label
        for i in range(X.shape[0]):
            plt.text(X[i, 0], X[i, 1], str(y[i]),
                     color=plt.cm.Set1(int(y[i]) / 10.),
                     fontdict={'weight': 'bold', 'size': 9})
    else:
        # Draw each sample as a colored point
        plt.scatter(X[:, 0], X[:, 1],
                    color=[plt.cm.Set1(int(yi) / 10.) for yi in y])

    if title is not None:
        if time is not None:
            plt.title("%s (%.2fs)" % (title, time))
        else:
            plt.title(title)
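
The helper rescales each embedding axis to the range [0, 1] before plotting. As a small worked illustration of that step, here is a sketch with a hypothetical `min_max_scale` function. Mapping a constant column to 0 is an assumption of this sketch and is not something the helper above does.

```python
import numpy as np

def min_max_scale(points):
    """Rescale each column of `points` to the range [0, 1]."""
    points = np.asarray(points, dtype=float)
    lo, hi = points.min(axis=0), points.max(axis=0)
    span = hi - lo
    # Assumption of this sketch: a constant column would divide by zero,
    # so its span is set to 1 and the column maps to 0.
    span[span == 0] = 1.0
    return (points - lo) / span

pts = np.array([[2.0, 10.0],
                [4.0, 10.0],
                [6.0, 10.0]])
scaled = min_max_scale(pts)
```

The first column is spread evenly over [0, 1], while the constant second column becomes all zeros.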

## 1. Random projection

A simple first visualization is a [random projection](http://scikit-learn.org/stable/modules/random_projection.html#random-projection) of the data into two dimensions.

In [None]:
t0 = time()
rp = random_projection.SparseRandomProjection(n_components=2, random_state=42)
X_projected = rp.fit_transform(X)
t = time() - t0

plot_embedding(X_projected, "Random projection", t)

The data can also be plotted with points instead of digit labels by setting `show_digits=False`:

In [None]:
plot_embedding(X_projected, "Random projection", t, show_digits=False)
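
To make the idea concrete, a random projection is just a multiplication by a random matrix. The sketch below uses a dense Gaussian matrix in plain NumPy, which is a simpler variant than the sparse matrix that `SparseRandomProjection` builds; the function name, the scaling, and the stand-in data are assumptions for illustration.

```python
import numpy as np

def gaussian_random_projection(data, n_components=2, seed=42):
    """Project `data` onto `n_components` random directions."""
    rng = np.random.default_rng(seed)
    n_features = data.shape[1]
    # Dense Gaussian projection matrix; the 1/sqrt(n_components) scale is
    # a choice made for this sketch.
    R = rng.normal(0.0, 1.0 / np.sqrt(n_components),
                   size=(n_features, n_components))
    return data @ R

# Stand-in data shaped like 100 flattened 28x28 images (random values).
rng = np.random.default_rng(0)
data = rng.random((100, 784))
projected = gaussian_random_projection(data)
```

Because the projection matrix ignores the data entirely, the method is fast, but it makes no attempt to preserve class structure.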

## 2. PCA

[Principal component analysis](http://scikit-learn.org/stable/modules/decomposition.html#pca) (PCA) is a standard method for decomposing a high-dimensional dataset into a set of successive orthogonal components that each explain a maximum amount of the remaining variance. Here we project the data onto the first two principal components, which have the maximal possible variance under the orthogonality constraint.

In [None]:
t0 = time()
pca = decomposition.PCA(n_components=2)
X_pca = pca.fit_transform(X)
t = time() - t0

plot_embedding(X_pca, "PCA projection", t)
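
For intuition, the projection onto the leading principal components can be sketched with a singular value decomposition of the centered data. This is a minimal NumPy illustration on synthetic stand-in data, not the exact routine scikit-learn runs; `pca_svd` is a hypothetical name.

```python
import numpy as np

def pca_svd(data, n_components=2):
    """Return scores, principal axes and explained-variance ratios."""
    centered = data - data.mean(axis=0)
    # Rows of Vt are orthonormal directions ordered by decreasing variance.
    U, S, Vt = np.linalg.svd(centered, full_matrices=False)
    scores = U[:, :n_components] * S[:n_components]
    explained_variance = S ** 2 / (data.shape[0] - 1)
    ratio = explained_variance / explained_variance.sum()
    return scores, Vt[:n_components], ratio[:n_components]

# Stand-in data: 200 points stretched most along the first axis.
rng = np.random.default_rng(0)
data = rng.normal(size=(200, 5)) * np.array([10.0, 3.0, 1.0, 1.0, 1.0])
scores, components, ratio = pca_svd(data)
```

The `ratio` values show how much of the total variance each retained component explains, which is a useful check on how much structure a two-dimensional plot can capture.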

## 3. MDS

[Multidimensional scaling](http://scikit-learn.org/stable/modules/manifold.html#multidimensional-scaling) (MDS) seeks a low-dimensional representation of the data in which the pairwise distances match the distances in the original high-dimensional space as closely as possible.

In [None]:
t0 = time()
mds = manifold.MDS(n_components=2, max_iter=500)
X_mds = mds.fit_transform(X)
t = time() - t0

plot_embedding(X_mds, "MDS embedding", t)
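
The distance-preserving goal can be illustrated with classical (Torgerson) MDS, which has a closed-form solution via an eigendecomposition. Note that this is a different, simpler variant than the iterative optimization `manifold.MDS` performs; the sketch and its function names are assumptions for illustration only.

```python
import numpy as np

def pairwise_distances(points):
    """Euclidean distance between every pair of rows."""
    diff = points[:, None, :] - points[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

def classical_mds(data, n_components=2):
    """Classical MDS: eigendecompose the double-centered squared distances."""
    n = data.shape[0]
    sq_dists = pairwise_distances(data) ** 2
    J = np.eye(n) - np.ones((n, n)) / n          # centering matrix
    B = -0.5 * J @ sq_dists @ J                  # Gram matrix of centered data
    eigvals, eigvecs = np.linalg.eigh(B)         # eigenvalues in ascending order
    top = np.argsort(eigvals)[::-1][:n_components]
    return eigvecs[:, top] * np.sqrt(np.maximum(eigvals[top], 0.0))

# Stand-in data: 2-D points embedded in 10 dimensions by a map with
# orthonormal columns, so all pairwise distances are unchanged.
rng = np.random.default_rng(0)
flat = rng.normal(size=(50, 2))
Q, _ = np.linalg.qr(rng.normal(size=(10, 2)))
high_dim = flat @ Q.T
embedding = classical_mds(high_dim)
```

Since the stand-in data is intrinsically two-dimensional, the embedding reproduces the original distances up to rotation and reflection. Real MNIST data is not this well behaved, so some distortion is unavoidable there.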

## 4. t-SNE

[t-distributed Stochastic Neighbor Embedding](http://scikit-learn.org/stable/modules/manifold.html#t-sne) (t-SNE) is a popular tool for visualizing high-dimensional data.  t-SNE is particularly sensitive to local structure and can often reveal clusters in the data.

t-SNE has an important tunable parameter called `perplexity`, which can have a large effect on the resulting visualization, depending on the data.  Typical values for perplexity are between 5 and 50.

In [None]:
t0 = time()
perplexity=30
tsne = manifold.TSNE(n_components=2, perplexity=perplexity)
X_tsne = tsne.fit_transform(X)
t = time() - t0

plot_embedding(X_tsne, "t-SNE embedding with perplexity=%d" % perplexity, t)
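
Perplexity can be read as a smooth measure of how many neighbors each point effectively considers. In the t-SNE formulation it is 2 raised to the Shannon entropy (in bits) of a point's Gaussian neighbor distribution. The sketch below computes that quantity for one point and a given bandwidth `sigma`; the function name and the distances are hypothetical values for illustration.

```python
import numpy as np

def conditional_perplexity(sq_dists_to_others, sigma):
    """Perplexity of the Gaussian neighbor distribution around one point."""
    logits = -np.asarray(sq_dists_to_others, dtype=float) / (2.0 * sigma ** 2)
    logits = logits - logits.max()          # stabilize the exponentials
    p = np.exp(logits)
    p /= p.sum()
    nonzero = p[p > 0]
    entropy_bits = -np.sum(nonzero * np.log2(nonzero))
    return 2.0 ** entropy_bits

# Hypothetical squared distances from one point to 100 other points.
sq_d = np.linspace(0.1, 10.0, 100) ** 2
narrow = conditional_perplexity(sq_d, sigma=0.5)   # only the closest points matter
wide = conditional_perplexity(sq_d, sigma=50.0)    # nearly all points matter
uniform = conditional_perplexity(np.ones(8), sigma=1.0)
```

A small bandwidth concentrates the distribution on a few close neighbors and gives a low perplexity, while a large bandwidth approaches the number of other points. With eight equidistant neighbors the distribution is uniform and the perplexity equals 8.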

## 5. Further visualizations

Take a look at the original scikit-learn [embedding examples](http://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html) for more visualizations.  Try some of these (for example, LLE and Isomap) on the MNIST data.
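
As a starting point, the following sketch shows how the Isomap and LLE estimators in `sklearn.manifold` might be called. It runs on a small synthetic swiss-roll dataset so that it is self-contained; the `n_neighbors` values are untuned assumptions, and for MNIST you would pass `X` instead of `points`.

```python
from sklearn import manifold
from sklearn.datasets import make_swiss_roll

# Stand-in data: 300 points on a 3-D swiss roll (not MNIST).
points, _ = make_swiss_roll(n_samples=300, random_state=0)

# Isomap: preserves geodesic distances along a nearest-neighbor graph.
isomap = manifold.Isomap(n_neighbors=10, n_components=2)
points_isomap = isomap.fit_transform(points)

# LLE: preserves the local linear reconstruction weights of each point.
lle = manifold.LocallyLinearEmbedding(n_neighbors=10, n_components=2,
                                      method='standard', random_state=0)
points_lle = lle.fit_transform(points)
```

Both results are arrays with one two-dimensional row per sample, so on MNIST they can be passed to `plot_embedding` in the same way as the embeddings above.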