<table align="center">
   <td align="center"><a target="_blank" href="https://colab.research.google.com/github/ds5110/summer-2021/blob/master/09b_naive-bayes-digits.ipynb">
<img src="https://github.com/ds5110/summer-2021/raw/master/colab.png"  style="padding-bottom:5px;" />Run in Google Colab</a></td>
</table>

# 09b -- Naive Bayes with the digits dataset

References:

* [MNIST digits](https://en.wikipedia.org/wiki/MNIST_database) -- wikipedia
* [Section 5.05 from VanderPlas](https://jakevdp.github.io/PythonDataScienceHandbook/05.05-naive-bayes.html) -- github
* [Let's Try t-SNE](https://observablehq.com/@mbostock/lets-try-t-sne) (Mike Bostock) -- observable

In [None]:
# Load scikit-learn's bundled handwritten-digits dataset (8x8 images).
from sklearn.datasets import load_digits

digits = load_digits()
digits.images.shape  # (n_samples, 8, 8)

In [None]:
# Print the dataset description that ships with load_digits().
print(digits["DESCR"])

In [None]:
# Feature matrix: one row per image, one column per pixel.
X = digits["data"]
X.shape

In [None]:
# Target vector: the digit label for each row of X.
y = digits["target"]
y.shape

# Visualize the data

* MNIST images are 28-by-28 pixels
* [sklearn digits](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) are 8-by-8 pixels
* [Let's Try t-SNE](https://observablehq.com/@mbostock/lets-try-t-sne) (Mike Bostock) uses the 28-by-28 MNIST images

In [None]:
import matplotlib.pyplot as plt

In [None]:
# digits.images is the 3-D stack of 8x8 images: show its shape, then the
# first image as a 2-D array of pixel values.
print("digits.images:", digits["images"].shape)
digits.images[0, :, :]

In [None]:
# digits.data holds the same images flattened to one row per sample:
# show its shape, then the first (flattened) row.
print("digits.data:", digits["data"].shape)
digits.data[0, :]

In [None]:
# Display the first image; the title shows its true label.
# The trailing ';' suppresses the AxesImage/Text repr, matching the other plot cells.
plt.imshow(digits.images[0], cmap='binary')
plt.title("label: " + str(digits.target[0]));

In [None]:
# digits.data stores each image as a flat vector, while digits.images stores
# the same pixels as an 8x8 array. Reshaping the flat vector recovers the image.
print(digits.data[0].shape)
print(digits.images[0].shape)

flat_as_image = digits.data[0].reshape(8, 8)
assert (flat_as_image == digits.images[0]).all(), "data row should match the 8x8 image"
plt.imshow(flat_as_image, cmap='binary');

In [None]:
# Show the first 100 images in a 10x10 grid, each annotated with its true label.
fig, axes = plt.subplots(10, 10, figsize=(8, 8),
                         subplot_kw=dict(xticks=[], yticks=[]),
                         gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

for ax, image, label in zip(axes.flat, digits.images, digits.target):
    ax.imshow(image, cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(label), transform=ax.transAxes, color='green')

# Classifying digits

In [None]:
# Train/test split; random_state pins the shuffle so the split is reproducible.
from sklearn.model_selection import train_test_split

(Xtrain, Xtest,
 ytrain, ytest) = train_test_split(X, y, random_state=0)

In [None]:
# Fit a Gaussian naive Bayes classifier on the training split,
# then predict a label for every image in the test split.
from sklearn.naive_bayes import GaussianNB

model = GaussianNB().fit(Xtrain, ytrain)
y_model = model.predict(Xtest)

In [None]:
# Assess model performance: fraction of test images whose predicted label
# equals the true label.
from sklearn.metrics import accuracy_score

accuracy_score(y_true=ytest, y_pred=y_model)

# Confusion matrix

<img src="https://github.com/rasbt/python-machine-learning-book-3rd-edition/raw/master/ch06/images/06_08.png" width="300"/>

Figure credit: [Raschka, *Python Machine Learning* (3rd ed.), chapter 6](https://github.com/rasbt/python-machine-learning-book-3rd-edition/blob/master/ch06/ch06.ipynb) -- github

### References 

* [confusion_matrix](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html) -- scikit-learn.org
* [matplotlib heatmap](https://matplotlib.org/stable/gallery/images_contours_and_fields/image_annotated_heatmap.html) -- matplotlib.org
* [seaborn heatmap](https://seaborn.pydata.org/generated/seaborn.heatmap.html) -- seaborn.pydata.org

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Rows of `mat` are true labels and columns are predicted labels,
# matching the axis labels below.
mat = confusion_matrix(ytest, y_model)

# fmt='d' prints the integer counts verbatim. Seaborn's default '.2g' would
# render counts of 100 or more in scientific notation (e.g. "1e+02").
sns.heatmap(mat, square=True, annot=True, fmt='d', cmap="YlGnBu", cbar=True)
plt.xlabel('predicted value')
plt.ylabel('true value');

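# Stratified train/test split

Repeat the analysis with a stratified split (`stratify=y`). Each digit class then appears in the train and test sets in the same proportions as in the full dataset.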



In [None]:
# Stratify on the labels so train and test keep the same class proportions as y.
# (This does not rebalance the classes themselves.)
# NOTE: random_state=1 differs from the earlier split (random_state=0), so any
# change in results reflects both stratification and a different shuffle.
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, random_state=1,
                                                stratify=y)

model = GaussianNB()
model.fit(Xtrain, ytrain)
y_model = model.predict(Xtest)

# Report accuracy so this run can be compared with the unstratified one above.
print("accuracy:", accuracy_score(ytest, y_model))

# Rows are true labels, columns are predicted labels; fmt='d' keeps the counts as integers.
mat = confusion_matrix(ytest, y_model)

sns.heatmap(mat, square=True, annot=True, fmt='d', cmap="YlGnBu", cbar=True)
plt.xlabel('predicted value')
plt.ylabel('true value');

# Visualizing performance

Show the first 100 test images, each labeled with the model's prediction. Correct predictions are labeled in blue and mistakes in red.

In [None]:
# Grid of the first 100 test images, each annotated with the predicted label:
# blue when the prediction matches the true label, red when it is a mistake.
fig, axes = plt.subplots(10, 10, figsize=(8, 8),
                         subplot_kw=dict(xticks=[], yticks=[]),
                         gridspec_kw={'hspace': 0.1, 'wspace': 0.1})

# Each row of Xtest is a flattened 8x8 image; restore the 2-D layout for imshow.
test_images = Xtest.reshape(-1, 8, 8)

for ax, image, predicted, actual in zip(axes.flat, test_images, y_model, ytest):
    label_color = 'blue' if predicted == actual else 'red'
    ax.imshow(image, cmap='binary', interpolation='nearest')
    ax.text(0.05, 0.05, str(predicted), transform=ax.transAxes, color=label_color)