# Chapter 3 - Classification
## MNIST dataset 
### Multiclass Classification

There are two strategies for performing multiclass classification with multiple binary classifiers:
- **OvR strategy** (One versus Rest). Create a system that classifies images into 10 classes (from 0 to 9) by training one binary classifier per digit: a 0-detector, a 1-detector, a 2-detector, and so on.
    - To classify an image, you get the decision score from each classifier and choose the class whose classifier outputs the highest score.
- **OvO strategy** (One versus One). Train a binary classifier for every pair of digits.
    - If there are N classes, you need to train `N * (N - 1) / 2` classifiers (45 for the 10 digits), but each one only has to be trained on the part of the training set containing the two classes it must distinguish.
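
To make the two strategies concrete, here is a small NumPy sketch. The scores and pairwise winners are made-up numbers, and the majority vote used to combine the OvO classifiers is the usual decision rule, stated here as an assumption since the text above does not say how OvO picks the final class.

```python
import numpy as np

def n_ovo_classifiers(n_classes):
    # OvO trains one binary classifier per unordered pair of classes: N * (N - 1) / 2
    return n_classes * (n_classes - 1) // 2

def ovr_predict(scores):
    # scores[k] is the decision score of the "class k vs. rest" detector;
    # OvR picks the class whose detector is most confident
    return int(np.argmax(scores))

def ovo_predict(pair_winners, n_classes):
    # pair_winners maps each pair (i, j) to the class chosen by the i-vs-j classifier;
    # each duel casts one vote and the class with the most votes wins
    # (np.argmax breaks ties toward the lowest class index)
    votes = np.zeros(n_classes, dtype=int)
    for winner in pair_winners.values():
        votes[winner] += 1
    return int(np.argmax(votes))

print(n_ovo_classifiers(10))                               # 45 classifiers for the 10 digits
print(ovr_predict([-1.2, 2.5, 0.3]))                       # 1
print(ovo_predict({(0, 1): 1, (0, 2): 2, (1, 2): 1}, 3))   # 1
```

With 10 digits, OvR trains 10 classifiers on the full training set, while OvO trains 45 smaller ones, each on two digits only.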
    
Scikit-Learn detects when you try to use a binary classification algorithm for a multiclass classification task, and it automatically runs OvR or OvO, depending on the algorithm.
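
One way to check which strategy Scikit-Learn picked is to inspect the fitted models. The sketch below uses the small 8x8 digits dataset bundled with Scikit-Learn (`load_digits`) as a stand-in for MNIST, an assumption made so it runs in seconds without a download. The shapes reflect that `SGDClassifier` is trained OvR (one weight vector per class), while `SVC` is trained OvO (one score per pair of classes when `decision_function_shape="ovo"`).

```python
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.svm import SVC

# Stand-in for MNIST: 1,797 images of 8x8 = 64 pixels, 10 classes
X_small, y_small = load_digits(return_X_y=True)

# SGDClassifier is a binary linear model; with 10 classes it is trained OvR,
# so it learns one weight vector (one row of coef_) per class
sgd = SGDClassifier(random_state=42).fit(X_small, y_small)
print(sgd.coef_.shape)                                # (10, 64)

# SVC uses OvO internally; decision_function_shape="ovo" exposes
# one score per pair of classes instead of aggregating them into 10 columns
svc_ovo = SVC(decision_function_shape="ovo").fit(X_small, y_small)
print(svc_ovo.decision_function(X_small[:1]).shape)   # (1, 45)
```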

In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

We download the dataset with `fetch_openml`:

In [2]:
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, as_frame=False)  # NumPy arrays, so X[0] below is the first image
mnist.keys()

dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details', 'categories', 'url'])

Let's create the following variables:
- `X`: the full dataset (70,000 images of 784 pixels each)
- `y`: the labels
- `X_train`: the training set (first 60,000 images)
- `X_test`: the test set (last 10,000 images)
- `y_train`: the training set labels
- `y_test`: the test set labels

In [3]:
import numpy as np
X, y = mnist["data"], mnist["target"]
y = y.astype(np.uint8)
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

some_digit = X[0]

##### Support Vector Machine Classifier (`sklearn.svm.SVC`)

The steps are the following:

    from sklearn.svm import SVC

    svm_clf = SVC()
    svm_clf.fit(X_train, y_train)
    svm_clf.predict([some_digit])


> We can also see the scores for each class by running `svm_clf.decision_function([some_digit])`.
> See also the result of `svm_clf.classes_`, which lists the class labels in the same order as those scores.
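
The link between the scores and the prediction can be checked directly. This sketch again assumes `load_digits` as a small stand-in for MNIST, and it sets `break_ties=True` (not used above) so that `predict()` is guaranteed to return the class with the highest `decision_function` score, even when the underlying OvO votes are tied.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC

# Stand-in for MNIST: the small 8x8 digits dataset that ships with Scikit-Learn
X_small, y_small = load_digits(return_X_y=True)

# break_ties=True: predict() returns the argmax of decision_function() even on vote ties
svm_clf = SVC(break_ties=True).fit(X_small, y_small)

some_digit = X_small[0]
scores = svm_clf.decision_function([some_digit])  # shape (1, 10): one score per class
print(svm_clf.classes_)                           # [0 1 2 3 4 5 6 7 8 9]

# The predicted label is the class whose score is highest
best = svm_clf.classes_[np.argmax(scores)]
print(best, svm_clf.predict([some_digit])[0])
```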


You can also choose the strategy manually, for example by wrapping the classifier in `OneVsRestClassifier` to force OvR:


    from sklearn.svm import SVC
    from sklearn.multiclass import OneVsRestClassifier

    ovr_svm_clf = OneVsRestClassifier(SVC())
    ovr_svm_clf.fit(X_train, y_train)
    ovr_svm_clf.predict([some_digit])

    len(ovr_svm_clf.estimators_)


It can also be done with the OvO strategy (`OneVsOneClassifier`).
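
Both wrappers can be compared on the same small stand-in dataset (`load_digits`, assumed here instead of MNIST to keep training fast). The number of fitted estimators matches the counts given for OvR and OvO at the start of this section.

```python
from sklearn.datasets import load_digits
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from sklearn.svm import SVC

# Stand-in for MNIST with the same 10 classes
X_small, y_small = load_digits(return_X_y=True)

ovr_clf = OneVsRestClassifier(SVC()).fit(X_small, y_small)
ovo_clf = OneVsOneClassifier(SVC()).fit(X_small, y_small)

print(len(ovr_clf.estimators_))   # 10: one "digit vs. rest" classifier per class
print(len(ovo_clf.estimators_))   # 45: one classifier per pair of digits
```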

Now, let's do it with an SGD classifier (`SGDClassifier`):

In [None]:
from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train)

In [None]:
sgd_clf.decision_function([some_digit])

Then, we evaluate it with cross-validation:

In [None]:
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
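
To see what `cross_val_score` does behind the scenes, here is a hand-rolled version of 3-fold cross-validation on `load_digits` (a stand-in for MNIST). It relies on the documented default that an integer `cv` for a classifier means an unshuffled `StratifiedKFold`; each fold trains a fresh `clone` of the model and measures accuracy on the held-out fold.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

X_small, y_small = load_digits(return_X_y=True)
sgd_clf = SGDClassifier(random_state=42)

# Manual equivalent of cross_val_score(..., cv=3, scoring="accuracy")
manual_scores = []
for train_idx, test_idx in StratifiedKFold(n_splits=3).split(X_small, y_small):
    fold_clf = clone(sgd_clf)                      # fresh, unfitted copy for each fold
    fold_clf.fit(X_small[train_idx], y_small[train_idx])
    y_pred = fold_clf.predict(X_small[test_idx])
    manual_scores.append(np.mean(y_pred == y_small[test_idx]))  # fold accuracy

cv_scores = cross_val_score(sgd_clf, X_small, y_small, cv=3, scoring="accuracy")
print(np.array(manual_scores))
print(cv_scores)
```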

Next, we scale the inputs to get better accuracy:

In [None]:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
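
`StandardScaler` subtracts each feature's mean and divides by its standard deviation; here each feature is one pixel. The sketch below reproduces this on a made-up 3-pixel dataset and checks it against `StandardScaler`. The zero-variance guard mirrors how Scikit-Learn uses a scale of 1 for constant features (such as always-black border pixels) instead of dividing by zero.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Made-up "images": 3 samples x 3 pixels; the last pixel is always 0
X_toy = np.array([[  0.0, 10.0, 0.0],
                  [128.0, 20.0, 0.0],
                  [255.0, 30.0, 0.0]])

mean = X_toy.mean(axis=0)
std = X_toy.std(axis=0)      # population std (ddof=0), as StandardScaler uses
std[std == 0] = 1.0          # constant features: avoid dividing by zero
X_manual = (X_toy - mean) / std

X_sklearn = StandardScaler().fit_transform(X_toy)
print(np.allclose(X_manual, X_sklearn))   # True
```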

### Error analysis

Looking at the confusion matrix, we can see which classes get confused most often:

In [None]:
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
conf_mx
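
In Scikit-Learn's confusion matrix, each row is an actual class and each column a predicted class. The sketch below builds one by hand from a few made-up labels and checks it against `confusion_matrix`; the `labels` argument fixes the order of the rows and columns.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Made-up actual and predicted labels
y_true = np.array([5, 5, 3, 8, 3, 5])
y_pred = np.array([5, 3, 3, 8, 8, 5])

labels = np.array([3, 5, 8])
index = {label: i for i, label in enumerate(labels)}

# Count each (actual, predicted) pair: row = actual class, column = predicted class
manual = np.zeros((len(labels), len(labels)), dtype=int)
for t, p in zip(y_true, y_pred):
    manual[index[t], index[p]] += 1

print(manual)
print(np.array_equal(manual, confusion_matrix(y_true, y_pred, labels=labels)))  # True
```

Here the off-diagonal 1 in row 5, column 3 records the single 5 that was predicted as a 3.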

Let's display it with Matplotlib's `matshow()` function to make it easier to analyze:

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

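It looks pretty good, since most images are on the main diagonal, which means they were classified correctly. The 5s look slightly darker than the other digits, which could mean either that there are fewer images of 5s in the dataset or that the classifier does not perform as well on 5s as on the other digits.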

Let's focus on the errors. 

- First, we divide each value by the number of images in the corresponding class, so that we can compare error rates instead of absolute numbers of errors.
- Then, we fill the diagonal with 0s to keep only the errors.
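
Here is the arithmetic of these two steps on a made-up 3-class confusion matrix. Class 2 has twice as many misclassified images as class 1 in absolute terms (20 vs. 10), but after normalization class 1 turns out to have the higher error rate, because it has fewer images overall.

```python
import numpy as np

# Made-up confusion matrix: rows are actual classes, columns predicted classes
conf_toy = np.array([[90,  4,   6],
                     [ 5, 40,   5],
                     [10, 10, 180]])

row_sums = conf_toy.sum(axis=1, keepdims=True)  # images per actual class: 100, 50, 200
norm_toy = conf_toy / row_sums                  # each row now sums to 1
np.fill_diagonal(norm_toy, 0)                   # drop correct predictions, keep errors

print(norm_toy)
```

Row 1 ends up with an error rate of 0.1 in each off-diagonal cell, double the 0.05 of row 2.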

In [None]:
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums

np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()

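The plot clearly shows that 8 is the most confusing digit. Keep in mind that rows represent actual classes and columns represent predicted classes: a bright column means that many images get misclassified as that digit, while a bright row means that images of that digit are often misclassified.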
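
Instead of reading the brightest cell off the plot, it can be located programmatically. `most_confused` is a hypothetical helper written for this sketch, and the normalized matrix is made up; in the notebook you would pass `norm_conf_mx` instead. Since columns are predicted classes, the largest column sum points at the class that other digits are most often misclassified as.

```python
import numpy as np

def most_confused(norm_conf_mx):
    """Hypothetical helper: summarize a row-normalized confusion matrix
    whose diagonal has already been filled with zeros."""
    # Cell with the highest error rate, as (actual class, predicted class)
    actual, predicted = np.unravel_index(np.argmax(norm_conf_mx), norm_conf_mx.shape)
    # Sum each column: total error rate of images wrongly predicted as that class
    column_errors = norm_conf_mx.sum(axis=0)
    return int(actual), int(predicted), int(np.argmax(column_errors))

# Made-up 3-class example (diagonal already zeroed)
norm_toy = np.array([[0.00, 0.02, 0.08],
                     [0.03, 0.00, 0.12],
                     [0.01, 0.04, 0.00]])
print(most_confused(norm_toy))   # (1, 2, 2)
```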