# MNIST handwritten digits classification with linear methods

In this notebook, we'll classify handwritten digits using linear classifiers and [scikit-learn](https://scikit-learn.org/).

First, the needed imports. 

In [None]:
%matplotlib inline

from pml_utils import get_mnist

import sklearn
from sklearn import svm
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from packaging.version import Version
assert(Version(sklearn.__version__) >= Version("0.20")), "Version >= 0.20 of sklearn is required."

## MNIST digit data

Then we will load the MNIST data. The first time, it will download the data over the network, which can take a while.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test', X_test.shape)
print('y_test', y_test.shape)

The training data `X_train` is a matrix of size 60000x784, i.e., it consists of 60000 images expressed as vectors of length 784. These vectors are in fact "flattened" 28x28 images, where each component corresponds to the grayscale value of one pixel (0 = black, 0.5 = middle gray, 1 = white).

`y_train` is a 60000-dimensional vector containing the correct classes ("0", "1", ..., "9") for each training sample.
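The raw MNIST files store each pixel as an unsigned byte between 0 and 255, so the 0 to 1 range described above corresponds to dividing by 255. We assume `get_mnist` has already applied this scaling (you can check with `X_train.min()` and `X_train.max()`). The sketch below only illustrates the conversion on a made-up 2x2 patch:

```python
import numpy as np

# A made-up 2x2 patch of raw 8-bit pixel values (0 = black, 255 = white)
raw = np.array([[0, 128], [255, 64]], dtype=np.uint8)

# Convert to floating point and scale into the [0, 1] range
scaled = raw.astype(np.float32) / 255.0

print(scaled.min(), scaled.max())
```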


### Plotting images

Let's take a closer look at the MNIST images. Here are the first 10 training digits plotted as images together with the correct class label:

In [None]:
pltsize=1
plt.figure(figsize=(10*pltsize, pltsize))

for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(X_train[i,:].reshape(28,28), cmap="gray")
    plt.title('Class: '+str(y_train[i]))

Note that for each digit we use `reshape(28,28)` to transform the 784-element vector into a 28x28 image matrix.
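The round trip between the 784-element vector and the 28x28 matrix can be checked on a stand-in array. This assumes the images were flattened row by row (C order), which is NumPy's default and what `reshape(28,28)` relies on; under that assumption, pixel (row `r`, column `c`) sits at index `28*r + c`:

```python
import numpy as np

# A stand-in "image": a 28x28 array whose entries are 0, 1, ..., 783
image = np.arange(28 * 28).reshape(28, 28)

# Flatten row by row (C order) into a vector like one row of X_train
vector = image.reshape(-1)
print(vector.shape)  # (784,)

# Pixel (r, c) of the image is element 28*r + c of the vector
r, c = 5, 17
assert vector[28 * r + c] == image[r, c]

# reshape(28, 28) undoes the flattening exactly
assert np.array_equal(vector.reshape(28, 28), image)
```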

## Using scikit-learn

In this course we will be mostly relying on [scikit-learn, a machine learning framework for Python](https://scikit-learn.org/stable/index.html).  

In scikit-learn all machine learning models follow the same pattern:

1. First create a model object with the appropriate constructor for the method you are using.  Here you can also specify _hyperparameters_ for the method:
```
clf = SomeModel(param1=a, param2=b)
```


2. Next, fit your model to the training set (i.e., train your classifier):
```
clf.fit(X_train, y_train)
```


3. Finally, use the trained classifier for inference, i.e., predict the classes of new, unseen items:
```
y_predicted_test = clf.predict(X_test)
```
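To see what the three steps require of a model, here is a toy classifier written to follow the same construct / `fit` / `predict` pattern. `CentroidClassifierSketch` and its hyperparameter `p` (the exponent in the distance) are hypothetical and not part of scikit-learn; a real scikit-learn estimator would also subclass `BaseEstimator` and `ClassifierMixin` from `sklearn.base`. It simply assigns each sample to the class with the nearest mean vector:

```python
import numpy as np

class CentroidClassifierSketch:
    """Hypothetical toy classifier following the scikit-learn
    constructor / fit / predict pattern (not a scikit-learn class)."""

    def __init__(self, p=2):
        # Hyperparameters are set in the constructor
        self.p = p

    def fit(self, X, y):
        # Learn one mean vector ("centroid") per class from the training data
        self.classes_ = np.unique(y)
        self.centroids_ = np.array([X[y == k].mean(axis=0) for k in self.classes_])
        return self  # returning self allows chaining, e.g. clf.fit(...).predict(...)

    def predict(self, X):
        # Sum of |x - centroid|**p over features, for every (sample, class) pair
        dists = (np.abs(X[:, None, :] - self.centroids_[None, :, :]) ** self.p).sum(axis=2)
        return self.classes_[dists.argmin(axis=1)]

X_toy = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [0.9, 1.0]])
y_toy = np.array(["0", "0", "1", "1"])

clf = CentroidClassifierSketch(p=2)   # 1. construct, with hyperparameters
clf.fit(X_toy, y_toy)                 # 2. fit to the training data
pred = clf.predict(np.array([[0.05, 0.1], [0.95, 0.8]]))  # 3. inference
print(pred)  # ['0' '1']
```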


## Logistic regression

Let's start by trying logistic regression with a stochastic gradient descent algorithm.  The corresponding scikit-learn class is [LogisticRegression](http://scikit-learn.org/stable/modules/linear_model.html#logistic-regression).

### Learning

We'll use only the first 10,000 training samples, as the method is rather slow. We use the "sag" (Stochastic Average Gradient) solver, a variant of SGD, and the one-versus-rest strategy for multi-class classification.

In [None]:
%%time

clf_lr = OneVsRestClassifier(LogisticRegression(solver='sag'))
clf_lr.fit(X_train[:10000,:], y_train[:10000])
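The one-versus-rest strategy trains one binary classifier per class ("this digit" vs. "all other digits") and predicts the class whose classifier gives the highest score. The NumPy sketch below illustrates that idea under our own simplifying assumptions, not what scikit-learn actually runs: the helpers `fit_binary_logreg`, `fit_ovr` and `predict_ovr` are hypothetical, the binary models are fitted with plain batch gradient descent instead of the SAG solver, and there is no regularization term (scikit-learn's `LogisticRegression` applies an L2 penalty by default):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_binary_logreg(X, t, lr=0.5, n_iter=500):
    """Binary logistic regression (targets t in {0, 1}) by batch gradient descent."""
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(n_iter):
        p = sigmoid(X @ w + b)        # predicted probability of the positive class
        grad = p - t                  # derivative of the log-loss w.r.t. the logit
        w -= lr * X.T @ grad / len(t)
        b -= lr * grad.mean()
    return w, b

def fit_ovr(X, y):
    """One-vs-rest: fit one 'class k vs. all others' classifier per class."""
    classes = np.unique(y)
    params = [fit_binary_logreg(X, (y == k).astype(float)) for k in classes]
    W = np.array([w for w, _ in params])    # shape (n_classes, n_features)
    b = np.array([bk for _, bk in params])  # shape (n_classes,)
    return classes, W, b

def predict_ovr(X, classes, W, b):
    # Pick the class whose binary classifier gives the highest score
    return classes[np.argmax(X @ W.T + b, axis=1)]

# Toy data: three well-separated 2-D clusters labelled "0", "1", "2"
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]])
X_toy = np.vstack([c + 0.3 * rng.standard_normal((50, 2)) for c in centers])
y_toy = np.repeat(np.array(["0", "1", "2"]), 50)

classes, W, b = fit_ovr(X_toy, y_toy)
acc = (predict_ovr(X_toy, classes, W, b) == y_toy).mean()
print("training accuracy:", acc)
```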

### Inference

As the decision boundaries are linear, prediction with logistic regression is fast:

In [None]:
%%time 

pred_lr = clf_lr.predict(X_test)
print('Predicted', len(pred_lr), 'digits with accuracy:', accuracy_score(y_test, pred_lr))
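Prediction is fast because, for a linear multi-class model, it boils down to one matrix product followed by an argmax over classes. Fitted `LogisticRegression` and `LinearSVC` models store the weight matrix and intercepts as `coef_` and `intercept_`; the values below are made up for illustration. For MNIST, scoring 10,000 test images against 10 classes with 784 features is about 10000 x 784 x 10 = 78.4 million multiply-adds, which is cheap:

```python
import numpy as np

# Made-up learned parameters for 3 classes and 4 features
W = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0]])   # shape (n_classes, n_features)
b = np.array([0.0, 0.5, -1.0])         # one intercept per class

X_new = np.array([[2.0, 0.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0, 1.0]])

# One matrix product gives every class score for every sample
scores = X_new @ W.T + b               # shape (n_samples, n_classes)
pred = scores.argmax(axis=1)           # highest-scoring class per sample
print(pred)  # [0 1 2]
```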

#### Confusion matrix

We can compute the confusion matrix to see which digits get mixed the most:

In [None]:
labels=[str(i) for i in range(10)]
print('Confusion matrix (rows: true classes; columns: predicted classes):'); print()
cm=confusion_matrix(y_test, pred_lr, labels=labels)
print(cm); print()
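The counting behind a confusion matrix is simple enough to write by hand. The helper `confusion_matrix_sketch` is hypothetical; it uses the same convention as scikit-learn's `confusion_matrix` (rows: true classes, columns: predicted classes) and is shown on a tiny made-up set of labels:

```python
import numpy as np

def confusion_matrix_sketch(y_true, y_pred, labels):
    """Rows: true classes; columns: predicted classes."""
    index = {label: i for i, label in enumerate(labels)}
    cm = np.zeros((len(labels), len(labels)), dtype=int)
    for t, p in zip(y_true, y_pred):
        cm[index[t], index[p]] += 1   # one count per (true, predicted) pair
    return cm

y_true = ["0", "0", "1", "1", "2", "2"]
y_pred = ["0", "1", "1", "1", "2", "0"]
cm_toy = confusion_matrix_sketch(y_true, y_pred, labels=["0", "1", "2"])
print(cm_toy)
# [[1 1 0]
#  [0 2 0]
#  [1 0 1]]
```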

Plotting it as an image makes the structure easier to see. The matrix looks quite good: most images fall on the diagonal, meaning they were classified correctly.

In [None]:
plt.matshow(cm, cmap=plt.cm.gray)
plt.xticks(range(10))
plt.yticks(range(10))
plt.grid(None)
plt.show()

#### Accuracy, precision and recall

Classification accuracy for each class (the fraction of test images of each true class that were classified correctly):

In [None]:
# Diagonal = correctly classified samples; row sums = samples per true class
for i, acc in enumerate(cm.diagonal() / cm.sum(axis=1)):
    print("%d: %.4f" % (i, acc))
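The per-class accuracy computed above divides the diagonal of the confusion matrix (correct predictions) by its row sums (number of samples of each true class), which is exactly the per-class recall. Overall accuracy is the diagonal sum divided by the total count. A small made-up matrix shows the arithmetic:

```python
import numpy as np

# A small made-up confusion matrix (rows: true classes; columns: predicted)
cm = np.array([[1, 1, 0],
               [0, 2, 0],
               [1, 0, 1]])

# Fraction of each true class classified correctly (= per-class recall)
per_class_acc = cm.diagonal() / cm.sum(axis=1)
print(per_class_acc)  # per-class values 0.5, 1.0, 0.5

# Overall accuracy: all correct predictions over all samples (4 / 6 here)
overall_acc = cm.trace() / cm.sum()
print(overall_acc)
```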

Precision and recall for each class:

In [None]:
print(classification_report(y_test, pred_lr, labels=labels))
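In terms of the confusion matrix, recall divides the diagonal by the row sums, while precision divides it by the column sums, i.e., by how many samples were *predicted* as each class. The F1-score reported by `classification_report` is the harmonic mean of the two, and its `support` column is the row sum. Using a small made-up matrix (this sketch does not guard against classes that are never predicted, which would give a zero column sum):

```python
import numpy as np

# A small made-up confusion matrix (rows: true classes; columns: predicted)
cm = np.array([[1, 1, 0],
               [0, 2, 0],
               [1, 0, 1]])

tp = cm.diagonal()                 # correctly classified samples per class
precision = tp / cm.sum(axis=0)    # correct / all predicted as that class
recall = tp / cm.sum(axis=1)       # correct / all truly in that class
f1 = 2 * precision * recall / (precision + recall)
support = cm.sum(axis=1)           # number of true samples per class

for k in range(len(tp)):
    print("%d: precision=%.2f recall=%.2f f1=%.2f support=%d"
          % (k, precision[k], recall[k], f1[k], support[k]))
```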

## Linear SVM

### Learning

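Next we'll try a linear SVM. Let's use the [`LinearSVC`](http://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html) class, as it is specialized for linear SVMs. `C` is the penalty parameter of the error term: larger values penalize training errors more heavily, i.e., they apply weaker regularization.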

In [None]:
%%time

C = 1.0
clf_lsvm = svm.LinearSVC(C=C, multi_class='ovr')
clf_lsvm.fit(X_train[:10000,:], y_train[:10000])

Training a linear SVM is rather fast, so more training data could easily be used.
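To make the role of `C` concrete, here is a sketch of the binary objective that `LinearSVC` minimizes with its default settings (`penalty='l2'`, `loss='squared_hinge'`): half the squared norm of the weights plus `C` times the summed squared hinge loss. With `multi_class='ovr'`, one such problem is solved per digit. The helper `linear_svc_objective` and the toy data are our own, and the sketch ignores the slightly different way liblinear treats the intercept:

```python
import numpy as np

def linear_svc_objective(w, b, X, y, C=1.0):
    """Binary linear SVM objective: L2 penalty + C * squared hinge loss.
    Labels y must be -1 / +1. Sketch only; liblinear's intercept handling differs."""
    margins = y * (X @ w + b)                    # >= 1 means correct with enough margin
    hinge = np.maximum(0.0, 1.0 - margins) ** 2  # squared hinge loss per sample
    return 0.5 * w @ w + C * hinge.sum()

X = np.array([[2.0, 0.0], [0.5, 0.0], [-1.0, 0.0]])
y = np.array([1, 1, -1])
w = np.array([1.0, 0.0])
b = 0.0

# The penalty term stays 0.5; only the weight on the margin violation changes
for C in (0.1, 1.0, 10.0):
    print(C, linear_svc_objective(w, b, X, y, C=C))
```

Only the second sample violates the margin (its margin is 0.5), so raising `C` makes that single violation dominate the objective, pushing the optimizer toward fitting the training points more tightly.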

### Inference

Again, prediction with linear functions is fast:

In [None]:
%%time

pred_lsvm = clf_lsvm.predict(X_test)
print('Predicted', len(pred_lsvm), 'digits with accuracy:', accuracy_score(y_test, pred_lsvm))

#### Confusion matrix

We can compute the confusion matrix to see which digits get mixed the most, and look at classification accuracies separately for each class:

In [None]:
labels=[str(i) for i in range(10)]
print('Confusion matrix (rows: true classes; columns: predicted classes):'); print()
cm=confusion_matrix(y_test, pred_lsvm, labels=labels)
print(cm); print()

Plotting it as an image makes the structure easier to see. The matrix looks quite good: most images fall on the diagonal, meaning they were classified correctly.

In [None]:
plt.matshow(cm, cmap=plt.cm.gray)
plt.xticks(range(10))
plt.yticks(range(10))
plt.grid(None)
plt.show()

#### Accuracy, precision and recall

Classification accuracy for each class (the fraction of test images of each true class that were classified correctly):

In [None]:
# Diagonal = correctly classified samples; row sums = samples per true class
for i, acc in enumerate(cm.diagonal() / cm.sum(axis=1)):
    print("%d: %.4f" % (i, acc))

Precision and recall for each class:

In [None]:
print(classification_report(y_test, pred_lsvm, labels=labels))

## Model tuning

Study the scikit-learn documentation of [LogisticRegression](http://scikit-learn.org/stable/modules/linear_model.html#logistic-regression) and [LinearSVC](http://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html). Experiment with different hyperparameter values. You can also try [SGDClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html) which does pure SGD (one sample at a time).
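As a reference point for what "pure SGD" means, here is a minimal per-sample SGD loop for a binary linear SVM in NumPy. It assumes hinge loss with an L2 penalty weighted by `alpha` (the loss and penalty `SGDClassifier` uses by default), but uses a constant learning rate `eta` instead of scikit-learn's default learning-rate schedule; the function `sgd_hinge` and its parameters are hypothetical:

```python
import numpy as np

def sgd_hinge(X, y, alpha=1e-4, eta=0.01, n_epochs=10, seed=0):
    """Per-sample SGD for a binary linear SVM (labels -1 / +1).
    Objective: alpha/2 * ||w||^2 + mean(max(0, 1 - y * (w.x + b)))."""
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    b = 0.0
    for _ in range(n_epochs):
        for i in rng.permutation(len(y)):   # one sample at a time, random order
            margin = y[i] * (X[i] @ w + b)  # computed with the current weights
            w *= 1.0 - eta * alpha          # gradient step on the L2 penalty
            if margin < 1.0:                # hinge loss active: step on it too
                w += eta * y[i] * X[i]
                b += eta * y[i]
    return w, b

# Toy data: two well-separated 2-D clusters with labels +1 and -1
rng = np.random.default_rng(1)
X_pos = rng.normal([2.0, 2.0], 0.5, size=(50, 2))
X_neg = rng.normal([-2.0, -2.0], 0.5, size=(50, 2))
X = np.vstack([X_pos, X_neg])
y = np.array([1] * 50 + [-1] * 50)

w, b = sgd_hinge(X, y)
acc = (np.sign(X @ w + b) == y).mean()
print("training accuracy:", acc)
```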

Can you improve on the accuracy or make training faster?

Report the highest classification accuracy you manage to obtain. Also write down the parameters you used, so that others can try to reproduce your results.
