[<img src="http://www.primaryobjects.com/images/digitrecognition1.jpg" align="right">](http://www.primaryobjects.com/2014/01/09/classifying-handwritten-digits-with-machine-learning/)
# Classification with the MNIST dataset

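[mldata.org](http://mldata.org/) was a public repository for machine learning data, supported by the [PASCAL network](http://www.pascal-network.org/). The site is no longer online.

The sklearn.datasets package used to download data sets directly from that repository with the function [sklearn.datasets.fetch_mldata](http://scikit-learn.org/stable/modules/generated/sklearn.datasets.fetch_mldata.html#sklearn.datasets.fetch_mldata). That function was deprecated in scikit-learn 0.20 and removed in 0.22. The same MNIST data is now available from [OpenML](https://www.openml.org/) through `sklearn.datasets.fetch_openml`.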

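The [MNIST database](https://en.wikipedia.org/wiki/MNIST_database) contains 70,000 grayscale images of handwritten digits labeled 0 to 9. Each image is 28x28 pixels, stored as a flat row of 784 intensities from 0 to 255. By convention the first 60,000 examples form the training set and the last 10,000 the test set.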
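Because each example is a flat row of 784 pixels, viewing one as a picture only requires reshaping it to 28x28. The sketch below uses two hypothetical helpers (`as_image` and `ascii_digit`, not part of scikit-learn) to render a row as text. It assumes raw 0-255 intensities, i.e. rows taken before the standardization step further down, and uses a synthetic stroke in place of a real digit.

```python
import numpy as np

def as_image(row):
    """Reshape a flat MNIST row of 784 pixels into a 28x28 array."""
    row = np.asarray(row)
    if row.size != 28 * 28:
        raise ValueError(f"expected 784 pixels, got {row.size}")
    return row.reshape(28, 28)

def ascii_digit(row, threshold=128):
    """Render a digit as text: '#' where intensity >= threshold, '.' elsewhere.

    Assumes raw 0-255 intensities (i.e. before StandardScaler is applied).
    """
    image = as_image(row)
    return "\n".join(
        "".join("#" if pixel >= threshold else "." for pixel in image_row)
        for image_row in image
    )

# Synthetic example: a vertical stroke in column 14, standing in for a real "1"
fake_one = np.zeros((28, 28))
fake_one[4:24, 14] = 255
print(ascii_digit(fake_one.ravel()))
```

Once the data is loaded, `print(ascii_digit(X[0]))` shows the first training image the same way.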

In [None]:
from sklearn import metrics
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt
%matplotlib inline

## Load and preprocess the data

In [None]:
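from sklearn.datasets import fetch_openml
# fetch_mldata was removed in scikit-learn 0.22; OpenML hosts the same 70,000 images.
# as_frame=False returns NumPy arrays instead of a pandas DataFrame.
ds = fetch_openml('mnist_784', version=1, as_frame=False, data_home='./data')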

In [None]:
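X = ds.data
# Cast labels to integers (OpenML returns them as strings '0'..'9')
y = ds.target.astype(int)
# Standard split: first 60,000 images for training, last 10,000 for testing
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]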
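Before training, it is worth checking that every digit is reasonably represented in both splits. A minimal sketch with `np.bincount`, assuming integer labels 0 to 9; `class_counts` and `class_fractions` are illustrative helpers, and the demo labels are made up.

```python
import numpy as np

def class_counts(labels, n_classes=10):
    """Number of examples for each digit 0..n_classes-1."""
    labels = np.asarray(labels).astype(int)
    return np.bincount(labels, minlength=n_classes)

def class_fractions(labels, n_classes=10):
    """Fraction of examples per class; sums to 1."""
    counts = class_counts(labels, n_classes)
    return counts / counts.sum()

demo = np.array([0, 1, 1, 2, 9, 9, 9])
print(class_counts(demo))  # [1 2 1 0 0 0 0 0 0 3]
```

On the real data, comparing `class_fractions(y_train)` with `class_fractions(y_test)` shows whether the two splits have a similar mix of digits.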

In [None]:
# Fit the scaler on the training data only, so no test-set statistics leak in
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
# Apply the same training-set transformation to the test data
X_test = scaler.transform(X_test)
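`StandardScaler` subtracts each feature's training-set mean and divides by its training-set standard deviation, z = (x - mean) / std. Many MNIST border pixels are always 0 in the training data, and scikit-learn gives such zero-variance features a scale of 1 instead of dividing by zero. The NumPy sketch below reproduces that behaviour for illustration only; it is not the library's code, and `fit_standardizer` / `standardize` are hypothetical names.

```python
import numpy as np

def fit_standardizer(X_train):
    """Compute per-feature mean and scale from the training data only."""
    X_train = np.asarray(X_train, dtype=float)
    mean = X_train.mean(axis=0)
    scale = X_train.std(axis=0)   # population std (ddof=0)
    scale[scale == 0.0] = 1.0     # constant pixels: avoid division by zero
    return mean, scale

def standardize(X, mean, scale):
    """Apply the training-set statistics to any split (train or test)."""
    return (np.asarray(X, dtype=float) - mean) / scale

# Toy data: column 0 varies, column 1 is constant like an MNIST border pixel
X_tr = np.array([[0.0, 0.0], [255.0, 0.0]])
X_te = np.array([[127.5, 0.0]])
mean, scale = fit_standardizer(X_tr)
# First column becomes -1 and +1; the constant column stays 0
print(standardize(X_tr, mean, scale))
print(standardize(X_te, mean, scale))
```

Note that the test row is transformed with the training mean and scale, which mirrors calling `scaler.transform(X_test)` after fitting on `X_train`.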

## Build the model and train

In [None]:
net = MLPClassifier(solver='sgd',
                    hidden_layer_sizes=(50, ),
                    max_iter=4000)
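With `hidden_layer_sizes=(50, )` the network has 784 inputs, one hidden layer of 50 units and 10 outputs, one per digit. By default `MLPClassifier` uses ReLU hidden activations and a softmax output for multi-class problems; after fitting, the weights are in `net.coefs_` (shapes (784, 50) and (50, 10)) and the biases in `net.intercepts_`. The NumPy forward pass below sketches that computation, with random weights standing in for trained ones and random rows standing in for standardized images.

```python
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [784, 50, 10]  # input pixels, hidden units, digit classes

# Random stand-ins for net.coefs_ and net.intercepts_ after fitting
coefs = [rng.normal(0.0, 0.05, size=(n_in, n_out))
         for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
intercepts = [np.zeros(n_out) for n_out in layer_sizes[1:]]

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(X):
    """Class probabilities: ReLU hidden layer followed by a softmax output."""
    hidden = relu(X @ coefs[0] + intercepts[0])
    return softmax(hidden @ coefs[1] + intercepts[1])

n_params = sum(W.size for W in coefs) + sum(b.size for b in intercepts)
print(n_params)  # 39760 trainable parameters

X_batch = rng.normal(size=(5, 784))  # five fake standardized images
probs = forward(X_batch)
print(probs.argmax(axis=1))  # predicted digit for each row
```

The parameter count, 784*50 + 50 + 50*10 + 10 = 39,760, is the same for the fitted `net`, since it depends only on the layer sizes.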

In [None]:
net.fit(X_train, y_train)
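`fit` adjusts the weights to minimize the average cross-entropy (log-loss) between the softmax outputs and the true labels, plus a small L2 penalty controlled by `alpha` (1e-4 by default). The sketch below computes only the data term for a batch of predicted probabilities; `mean_cross_entropy` is an illustrative helper, while `sklearn.metrics.log_loss(y_test, net.predict_proba(X_test))` gives essentially the same quantity for the fitted model.

```python
import numpy as np

def mean_cross_entropy(probs, labels, eps=1e-15):
    """Average negative log-probability assigned to the true class.

    probs:  (n_samples, n_classes) rows summing to 1
    labels: (n_samples,) integer class indices
    """
    probs = np.clip(np.asarray(probs, dtype=float), eps, 1.0)  # avoid log(0)
    labels = np.asarray(labels, dtype=int)
    true_class_probs = probs[np.arange(len(labels)), labels]
    return -np.mean(np.log(true_class_probs))

probs = np.array([[0.5, 0.5],
                  [0.25, 0.75]])
labels = np.array([0, 1])
print(round(mean_cross_entropy(probs, labels), 4))  # 0.4904
```

Confident, correct predictions push the loss toward 0, while assigning low probability to the true digit is penalized heavily.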

## Analysis of the network

### Classification report

In [None]:
expected = y_test
predicted = net.predict(X_test)
print(metrics.classification_report(expected, predicted))
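For each digit, the report's precision is the fraction of predictions of that digit that were correct, recall is the fraction of true examples of that digit that were found, F1 is their harmonic mean, and support is the number of true examples. A NumPy sketch of those per-class numbers (illustrative, without the report's averaged rows; it reports 0 when a class is never predicted):

```python
import numpy as np

def per_class_scores(expected, predicted, n_classes=10):
    """Return precision, recall, F1 and support arrays, one entry per class."""
    expected = np.asarray(expected, dtype=int)
    predicted = np.asarray(predicted, dtype=int)
    precision = np.zeros(n_classes)
    recall = np.zeros(n_classes)
    f1 = np.zeros(n_classes)
    support = np.zeros(n_classes, dtype=int)
    for c in range(n_classes):
        tp = np.sum((predicted == c) & (expected == c))  # correct predictions of c
        predicted_c = np.sum(predicted == c)
        actual_c = np.sum(expected == c)
        precision[c] = tp / predicted_c if predicted_c else 0.0
        recall[c] = tp / actual_c if actual_c else 0.0
        denom = precision[c] + recall[c]
        f1[c] = 2 * precision[c] * recall[c] / denom if denom else 0.0
        support[c] = actual_c
    return precision, recall, f1, support

expected_demo = [0, 0, 1, 1, 2]
predicted_demo = [0, 1, 1, 1, 2]
p, r, f, s = per_class_scores(expected_demo, predicted_demo, n_classes=3)
for c in range(3):
    print(c, round(p[c], 2), round(r[c], 2), round(f[c], 2), s[c])
```

Applied to `expected` and `predicted` from the cell above, the rows should match the per-digit lines of the classification report up to rounding.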

### Confusion matrix

In [None]:
print(metrics.confusion_matrix(expected, predicted))
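In scikit-learn's convention, entry [i, j] of the confusion matrix counts test images whose true digit is i and whose predicted digit is j, so the diagonal holds the correct predictions. The sketch below builds the same kind of matrix with NumPy and picks out the most frequent mistake; `confusion` and `most_confused_pair` are hypothetical helpers, and the demo labels are made up.

```python
import numpy as np

def confusion(expected, predicted, n_classes=10):
    """cm[i, j] = number of samples with true label i predicted as j."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    rows = np.asarray(expected, dtype=int)
    cols = np.asarray(predicted, dtype=int)
    np.add.at(cm, (rows, cols), 1)  # unbuffered add handles repeated pairs
    return cm

def most_confused_pair(cm):
    """Return (true_label, predicted_label, count) for the largest off-diagonal cell."""
    off_diagonal = cm.copy()
    np.fill_diagonal(off_diagonal, 0)
    i, j = np.unravel_index(np.argmax(off_diagonal), off_diagonal.shape)
    return int(i), int(j), int(off_diagonal[i, j])

cm = confusion([4, 4, 9, 9, 9, 7], [4, 9, 9, 4, 4, 7])
print(most_confused_pair(cm))   # (9, 4, 2)
print(np.trace(cm) / cm.sum())  # overall accuracy
```

Calling `most_confused_pair(metrics.confusion_matrix(expected, predicted))` on the real predictions points to the digit pair the network mixes up most often.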

### Loss curve

In [None]:
plt.plot(net.loss_curve_);
plt.xlabel('Iterations');
plt.ylabel('Loss');
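Each point on the curve is the average training loss over one epoch (a full pass over the 60,000 training images), and the curve typically flattens long before `max_iter=4000`. With default settings, `MLPClassifier` stops once the training loss has failed to improve by at least `tol` (1e-4) for `n_iter_no_change` (10) consecutive epochs, so `net.n_iter_` is usually much smaller than `max_iter`. The function below is a simplified re-implementation of that rule for reading a loss curve; it is an approximation for illustration, not scikit-learn's code, and the library's exact bookkeeping may differ slightly.

```python
def stopping_epoch(loss_curve, tol=1e-4, n_iter_no_change=10):
    """Return the 1-based epoch at which training would stop, or None.

    Simplified rule (an approximation of MLPClassifier's default behaviour):
    count consecutive epochs whose loss is not at least `tol` below the best
    loss seen so far, and stop once that count exceeds n_iter_no_change.
    """
    best = float("inf")
    no_improvement = 0
    for epoch, loss in enumerate(loss_curve, start=1):
        if loss > best - tol:
            no_improvement += 1
        else:
            no_improvement = 0
        best = min(best, loss)
        if no_improvement > n_iter_no_change:
            return epoch
    return None

# Synthetic curve: one real improvement, then a long plateau
curve = [1.0, 0.5] + [0.5] * 20
print(stopping_epoch(curve))  # 13
```

For the trained network, `stopping_epoch(net.loss_curve_)` should roughly agree with `net.n_iter_`.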