# MNIST handwritten digits classification with naive Bayes 

In this notebook, we'll use [naive Bayes classifiers](https://scikit-learn.org/stable/modules/naive_bayes.html) to classify MNIST digits using scikit-learn (version 1.0 or later required).

First, the needed imports.

In [None]:
%matplotlib inline

from pml_utils import get_mnist, show_failures

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets, __version__
from sklearn.naive_bayes import GaussianNB, BernoulliNB
from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from packaging.version import Version
# GaussianNB.var_ (used below) was introduced in scikit-learn 1.0 (it was previously named sigma_)
assert Version(__version__) >= Version("1.0"), "Version >= 1.0 of sklearn is required."

Then we load the MNIST data. The first time, the data need to be downloaded, which can take a while.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

print('MNIST data loaded: train:',len(X_train),'test:',len(X_test))
print('X_train:', X_train.shape)
print('y_train:', y_train.shape)
print('X_test:', X_test.shape)
print('y_test:', y_test.shape)

The training data (`X_train`) is a matrix of size (60000, 784), i.e., it consists of 60,000 digits, each expressed as a 784-dimensional vector (a 28x28 image flattened to 1D). `y_train` is a 60,000-element vector containing the correct class ("0", "1", ..., "9") of each training digit.
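The flattening is a plain row-major reshape, and reshaping a row back gives the image again (as the plotting code below does). A minimal sketch, using scikit-learn's bundled 8x8 `load_digits` dataset as a stand-in for MNIST so that it runs without a download; for MNIST the same idea gives 28x28 images and 784 columns.

```python
import numpy as np
from sklearn.datasets import load_digits

# Stand-in for MNIST: 1797 grayscale 8x8 digit images bundled with scikit-learn
digits = load_digits()
images = digits.images            # shape (n_samples, 8, 8)
n_samples = images.shape[0]

# Flatten each 2D image into a 1D feature vector (row-major order)
X_flat = images.reshape(n_samples, -1)
print(X_flat.shape)               # (1797, 64)

# Reshaping a single row back recovers the original 2D image
first_image = X_flat[0].reshape(8, 8)
```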

Let's take a closer look. Here are the first 10 training digits:

In [None]:
pltsize=1
plt.figure(figsize=(10*pltsize, pltsize))

for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(X_train[i,:].reshape(28, 28), cmap="gray")
    plt.title('Class: '+y_train[i])

## Naive Bayes classifiers

Naive Bayes classifiers are a family of simple classifiers based on Bayes' theorem. They are called "naive" because they make the strong assumption that the features are conditionally independent given the class. Although this assumption rarely holds exactly (neighboring pixels in a digit image, for example, are strongly correlated), naive Bayes classifiers often work reasonably well in practice. They are also simple and fast compared to many more sophisticated methods.

The classification rule for naive Bayes is
\begin{equation}
\hat{y} = \arg\max_yP(y)\prod_{i=1}^nP(x_i|y)
\end{equation}
where $P(y)$ is the prior probability of class $y$ and $P(x_i|y)$ is the class-conditional likelihood of feature $i$.
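In practice, a product of hundreds of small probabilities underflows to zero, so implementations usually maximize the equivalent sum of logarithms, $\log P(y) + \sum_i \log P(x_i|y)$. A minimal NumPy sketch of the rule, with made-up priors and likelihoods for two classes and three binary features (all numbers are hypothetical):

```python
import numpy as np

# Hypothetical model: 2 classes, 3 binary features
prior = np.array([0.6, 0.4])                 # P(y)
# p_on[y, i] = P(x_i = 1 | y)
p_on = np.array([[0.9, 0.2, 0.7],
                 [0.1, 0.8, 0.5]])

x = np.array([1, 0, 1])                      # observation to classify

# P(x_i | y) for the observed value of each feature, shape (2, 3)
lik = np.where(x == 1, p_on, 1 - p_on)

# Direct product form of the rule ...
score_prod = prior * lik.prod(axis=1)
# ... and the numerically safer log-sum form (same argmax)
score_log = np.log(prior) + np.log(lik).sum(axis=1)

y_hat = int(np.argmax(score_log))
print(y_hat)                                 # 0
```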

## Gaussian naive Bayes

In Gaussian naive Bayes, the likelihood of the features is assumed to be Gaussian
\begin{equation}
P(x_i|y) = \mathcal{N}(x_i\,|\,\mu_{iy},\sigma_{iy}^2) 
\end{equation}
where $\mu_{iy}$ and $\sigma_{iy}^2$ are the mean and variance, respectively, of feature $i$ in objects of class $y$.
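The Gaussian likelihood of a single feature value can be written out directly from the formula. The sketch below checks a hand-written version against `scipy.stats.norm.pdf`; note that SciPy's `scale` argument is the standard deviation $\sigma$, not the variance. The values of $\mu$, $\sigma$ and the pixel values are arbitrary.

```python
import numpy as np
from scipy.stats import norm

def gaussian_pdf(x, mu, var):
    """N(x | mu, var), written out from the formula."""
    return np.exp(-(x - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

mu, sigma = 192.0, 32.0
x = np.array([0.0, 128.0, 192.0, 255.0])     # example pixel values

p_manual = gaussian_pdf(x, mu, sigma ** 2)
p_scipy = norm.pdf(x, loc=mu, scale=sigma)   # scale = standard deviation
```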

In [None]:
mu = 192.
sigma = 32.
x = np.arange(255.)
plt.plot(x, 1/(sigma * np.sqrt(2 * np.pi)) * np.exp( - (x - mu)**2 / (2 * sigma**2)),
         lw=3)
plt.xticks([0,127,255])
plt.title(r'Gaussian distribution with $\mu={}$ and $\sigma={}$'.format(mu, sigma));

By default, the prior probabilities $P(y)$ are estimated from the class frequencies in the training data.
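With the default `priors=None`, `GaussianNB` sets its `class_prior_` attribute to the relative class frequencies of the training labels, while passing `priors` fixes them instead. A small sketch on the bundled `load_digits` data, used here as a stand-in for MNIST:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.naive_bayes import GaussianNB

X, y = load_digits(return_X_y=True)

# Default: priors estimated as class frequencies in the training data
clf = GaussianNB().fit(X, y)
freqs = np.bincount(y) / len(y)

# Alternative: fix uniform priors over the 10 classes
clf_uniform = GaussianNB(priors=np.full(10, 0.1)).fit(X, y)
```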

### Learning

Training a naive Bayes classifier is fast:

In [None]:
%%time

clf_gnb = GaussianNB()
clf_gnb.fit(X_train, y_train)

We can take a look at the learned mean (`theta_`) and variance (`var_`) of each feature for each class.

In [None]:
pltsize=1
plt.figure(figsize=(10*pltsize, pltsize))
plt.suptitle('Mean of each feature', y=1.3)

for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(clf_gnb.theta_[i,:].reshape(28, 28), cmap="gray")
    plt.title('Class: '+str(i))

plt.figure(figsize=(10*pltsize, pltsize))
plt.suptitle('Variance of each feature', y=1.1)
for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(clf_gnb.var_[i,:].reshape(28, 28), cmap="gray")

### Inference

Evaluating a naive Bayes classifier is also fast:

In [None]:
%%time

predictions = clf_gnb.predict(X_test)

The accuracy of the classifier:

In [None]:
print('Predicted', len(predictions), 'digits with accuracy:', accuracy_score(y_test, predictions))

We can also inspect the results in more detail. Let's use the `show_failures()` helper function to show the wrongly classified test digits.

In [None]:
show_failures(predictions, y_test, X_test)

In [None]:
show_failures(predictions, y_test, X_test, trueclass='5')

We can observe that the classifier makes many mistakes that would be easy for a human to avoid.
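A complementary, aggregate view of the errors is a confusion matrix, where entry $(i, j)$ counts test digits of true class $i$ predicted as class $j$. A sketch using `sklearn.metrics.confusion_matrix`, with `load_digits` and a random train/test split standing in for the MNIST data:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, confusion_matrix

X, y = load_digits(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

pred = GaussianNB().fit(X_tr, y_tr).predict(X_te)
cm = confusion_matrix(y_te, pred)     # rows: true class, columns: predicted class

# Off-diagonal entries are the errors; find the most frequent confusion
errors = cm.copy()
np.fill_diagonal(errors, 0)
true_c, pred_c = np.unravel_index(np.argmax(errors), errors.shape)
print(f"most common error: true {true_c} predicted as {pred_c} ({errors[true_c, pred_c]} times)")
```

The diagonal holds the correct predictions, so its sum divided by the total equals the accuracy.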

## Bernoulli naive Bayes

Gaussian naive Bayes assumes that each feature is normally distributed within a class, which is a poor assumption for the MNIST digits: most pixels are either nearly black (0) or nearly white (255), so their distribution is far from a single bell curve. Let's therefore try another approach and model each feature as a binary variable ("black" or "white"). A suitable distribution in this case is the Bernoulli
\begin{equation}
P(x_i|y) = \mathrm{Ber}(x_i\,|\,\theta_{iy}) 
\end{equation}
where $\theta_{iy}$ is the probability that feature $i$ is "white" in objects of class $y$.
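For $x_i \in \{0, 1\}$, the Bernoulli likelihood can be written compactly as $P(x_i|y) = \theta_{iy}^{x_i}(1-\theta_{iy})^{1-x_i}$. A quick check of this form against `scipy.stats.bernoulli` (the value of $\theta$ is arbitrary):

```python
import numpy as np
from scipy.stats import bernoulli

theta = 0.4
x = np.array([0, 1])

# theta^x * (1 - theta)^(1 - x): gives 1 - theta for x = 0 and theta for x = 1
p_manual = theta ** x * (1 - theta) ** (1 - x)
p_scipy = bernoulli.pmf(x, theta)
```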

In [None]:
theta = 0.4
plt.bar([0, 255], [1-theta, theta], width=16.)
plt.xticks([0,127,255])
plt.title('Bernoulli distribution with $\\theta={}$'.format(theta));

### Learning

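Bernoulli naive Bayes assumes binary data, so we binarize the digits with a threshold in the middle of the 0 to 255 intensity range: with `binarize=128.`, pixel values above 128 become 1 ("white") and all others 0.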

In [None]:
%%time

clf_bnb = BernoulliNB(binarize=128.)
clf_bnb.fit(X_train, y_train)
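The `binarize` parameter thresholds the features inside both `fit` and `predict`: values strictly greater than the threshold become 1 and the rest 0. The sketch below, on `load_digits` (pixel values 0 to 16, so 8 is the midpoint, standing in for MNIST), checks that this is equivalent to binarizing the data by hand and passing `binarize=None`, which tells `BernoulliNB` the input is already binary:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.naive_bayes import BernoulliNB

X, y = load_digits(return_X_y=True)
threshold = 8.0

# Let BernoulliNB binarize internally ...
clf_a = BernoulliNB(binarize=threshold).fit(X, y)
# ... or binarize manually and declare the input already binary
X_bin = (X > threshold).astype(float)
clf_b = BernoulliNB(binarize=None).fit(X_bin, y)
```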

We can take a look at the learned probability of each feature being "white" for each class. `feature_log_prob_` stores log-probabilities, hence the `np.exp` below.

In [None]:
pltsize=1
plt.figure(figsize=(10*pltsize, pltsize))
plt.suptitle('Probability of each feature', y=1.3)

for i in range(10):
    plt.subplot(1,10,i+1)
    plt.axis('off')
    plt.imshow(np.exp(clf_bnb.feature_log_prob_[i,:]).reshape(28, 28), cmap="gray")
    plt.title('Class: '+str(i))
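With the default additive (Laplace) smoothing `alpha=1.0`, scikit-learn estimates $\theta_{iy} = (N_{iy} + \alpha) / (N_y + 2\alpha)$, where $N_{iy}$ is the number of training digits of class $y$ with pixel $i$ set and $N_y$ the number of training digits of class $y$. The smoothing keeps every $\theta_{iy}$ strictly between 0 and 1, so no single pixel can rule out a class. A sketch checking this on `load_digits` (a stand-in for MNIST, thresholded at 8):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.naive_bayes import BernoulliNB

X, y = load_digits(return_X_y=True)
X_bin = (X > 8.0).astype(float)
alpha = 1.0

clf = BernoulliNB(alpha=alpha, binarize=None).fit(X_bin, y)

# Count how often each pixel is "on" per class, then apply add-alpha smoothing
n_on = np.stack([X_bin[y == c].sum(axis=0) for c in clf.classes_])   # N_iy
n_class = np.bincount(y).astype(float)                                # N_y
theta = (n_on + alpha) / (n_class[:, None] + 2 * alpha)
```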

### Inference

In [None]:
%%time

predictions = clf_bnb.predict(X_test)
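As with the Gaussian model, prediction reduces to summing per-feature log-likelihoods. For binary $x_i$, $\log P(x_i|y) = x_i \log\theta_{iy} + (1-x_i)\log(1-\theta_{iy})$, so the joint log-likelihood of all samples is a pair of matrix products. A sketch reproducing `BernoulliNB.predict` from `feature_log_prob_` and `class_log_prior_` (on `load_digits` as a stand-in for MNIST; the helper `bnb_joint_log_likelihood` is hypothetical):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.naive_bayes import BernoulliNB

X, y = load_digits(return_X_y=True)
X_bin = (X > 8.0).astype(float)
clf = BernoulliNB(binarize=None).fit(X_bin, y)

def bnb_joint_log_likelihood(clf, X_bin):
    """log P(y) + sum_i [x_i log theta_iy + (1 - x_i) log(1 - theta_iy)]."""
    log_p = clf.feature_log_prob_                   # log theta, (n_classes, n_features)
    log_1mp = np.log1p(-np.exp(log_p))              # log(1 - theta)
    return X_bin @ log_p.T + (1 - X_bin) @ log_1mp.T + clf.class_log_prior_

manual_pred = clf.classes_[np.argmax(bnb_joint_log_likelihood(clf, X_bin), axis=1)]
```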

In [None]:
print('Predicted', len(predictions), 'digits with accuracy:', accuracy_score(y_test, predictions))

In [None]:
show_failures(predictions, y_test, X_test)

In [None]:
show_failures(predictions, y_test, X_test, trueclass='5')