# MNIST handwritten digits anomaly detection

In this notebook, we'll test some anomaly detection methods to detect outliers within MNIST digits data using scikit-learn.

First, we import the libraries we need.

In [None]:
%matplotlib inline

from pml_utils import get_mnist, show_anomalies

import numpy as np
from sklearn import __version__
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor

import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

from packaging.version import Version
assert Version(__version__) >= Version("0.20"), "Version >= 0.20 of sklearn is required."

Then we load the MNIST data. The first time this cell is run, the data is downloaded, which can take a while.

To speed up the computations, let's use only the first 10000 digits in this notebook.

In [None]:
X_train, y_train, X_test, y_test = get_mnist('MNIST')

# Copy the slice so that the outliers created below do not modify X_train
X = X_train[:10000].copy()
y = y_train[:10000]
print()
print('MNIST data loaded:')
print('X:', X.shape)
print('y:', y.shape)

Next, let us create some outliers in our data by modifying the last three samples. We
* invert all pixels of one sample,
* shuffle all pixels of one sample, and
* add salt-and-pepper noise to 10% of the pixels of one sample.

You can also continue creating more outliers in a similar fashion. 

In [None]:
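# Outlier 1: invert all pixels of the last sample
X[9999,:] = 255 - X[9999,:]
# Outlier 2: shuffle the pixels of the second-to-last sample in place
np.random.shuffle(X[9998,:])
# Outlier 3: set 10% of the pixels, drawn without replacement, to black or white
noisy_pixels = np.random.choice(X.shape[1], int(X.shape[1]*0.1), replace=False)
for i in noisy_pixels:
    X[9997,i] = 0.0 if np.random.rand()<0.5 else 255.0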

Let's have a look at our outliers:

In [None]:
n_outliers = 3

pltsize = 5
plt.figure(figsize=(n_outliers*pltsize, pltsize))

for i in range(n_outliers):
    plt.subplot(1,n_outliers,i+1)
    plt.axis('off')
    plt.imshow(X[9999-i,:].reshape(28,28), cmap="gray")

## Isolation forest

[Isolation forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.IsolationForest.html#sklearn.ensemble.IsolationForest) is an outlier detection method based on an ensemble of randomly built trees. The idea is to isolate data items by repeatedly selecting a random feature and a random split value for that feature. Outliers are easier to isolate, so they tend to end up on shorter paths from the root of a tree, averaged over all trees.
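To make the idea of "easier to isolate" concrete, here is a minimal sketch on synthetic data rather than MNIST. The toy data set, the variable names, and the fixed `random_state` are assumptions made for illustration only. It uses `score_samples`, which returns lower values for samples that the trees isolate after fewer splits.

```python
import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.RandomState(0)

# Hypothetical toy data: 200 points in a tight 2-D cluster
# plus one point placed far away from it (index 200)
inliers = rng.normal(loc=0.0, scale=1.0, size=(200, 2))
far_point = np.array([[8.0, 8.0]])
data = np.vstack([inliers, far_point])

forest = IsolationForest(n_estimators=100, random_state=0).fit(data)

# Higher scores mean "more normal"; samples that are isolated
# after fewer random splits receive lower scores
scores = forest.score_samples(data)
most_anomalous = int(np.argmin(scores))
print('Index of the lowest score:', most_anomalous)
```

The same call, `if_model.score_samples(X)`, could be used on a fitted model in this notebook to rank the digits from most to least anomalous instead of only looking at the -1/1 labels.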

We set the number of trees with `n_estimators` and the assumed proportion of outliers in the data set with `if_contamination`, which is passed to the model as its `contamination` parameter.

In [None]:
%%time

n_estimators = 100
if_contamination = 0.001

if_model = IsolationForest(n_estimators=n_estimators, 
                           contamination=if_contamination)
if_pred = if_model.fit(X).predict(X)
print('Number of anomalies:', np.sum(if_pred==-1))

We use the helper function `show_anomalies` to take a look at the outliers that were found.

In [None]:
show_anomalies(if_pred, X)

## Local outlier factor

[Local outlier factor](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.LocalOutlierFactor.html#sklearn.neighbors.LocalOutlierFactor) is another method for outlier detection. It is based on k-nearest neighbors and compares the local density of each data point with the local densities of its neighbors. Outliers have a substantially lower local density than their neighbors.
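As a rough illustration of the density comparison, the following sketch fits the method on synthetic data instead of MNIST. The toy data and variable names are assumptions for illustration. After fitting, the attribute `negative_outlier_factor_` holds the negated factor for each training sample: values near -1 indicate a density similar to that of the neighbors, and strongly negative values indicate outliers.

```python
import numpy as np
from sklearn.neighbors import LocalOutlierFactor

rng = np.random.RandomState(0)

# Hypothetical toy data: a dense 2-D cluster of 200 points
# plus one isolated point far away from it (index 200)
dense = rng.normal(loc=0.0, scale=1.0, size=(200, 2))
isolated = np.array([[8.0, 8.0]])
data = np.vstack([dense, isolated])

lof = LocalOutlierFactor(n_neighbors=20)
labels = lof.fit_predict(data)  # -1 for outliers, 1 for inliers

# Flip the sign so that larger values mean "more outlying"
lof_scores = -lof.negative_outlier_factor_
most_outlying = int(np.argmax(lof_scores))
print('Index of the largest factor:', most_outlying)
print('Label of the isolated point:', labels[200])
```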

We set the number of neighbors considered with `n_neighbors` and the assumed proportion of outliers in the data set with `lof_contamination`, which is passed to the model as its `contamination` parameter.
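In practice, the contamination value mainly decides how many samples get the label -1, because it sets the threshold on the outlier scores. The sketch below shows this on random data; the data set and the two contamination values are assumptions chosen only for illustration.

```python
import numpy as np
from sklearn.neighbors import LocalOutlierFactor

rng = np.random.RandomState(0)

# Hypothetical data: 1000 random points with 5 features, no planted outliers
data = rng.normal(size=(1000, 5))

counts = {}
for c in (0.01, 0.05):
    model = LocalOutlierFactor(n_neighbors=20, contamination=c)
    labels = model.fit_predict(data)
    # The number of flagged samples is roughly c times the number of samples
    counts[c] = int(np.sum(labels == -1))

print('Flagged samples per contamination value:', counts)
```

Note that the model flags about this many samples even when the data contains no real outliers, so the value is best read as a budget for how many samples we want to inspect.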

In [None]:
%%time

n_neighbors = 20
lof_contamination = 0.001

lof_model = LocalOutlierFactor(n_neighbors=n_neighbors,
                               contamination=lof_contamination)
lof_pred = lof_model.fit_predict(X)
print('Number of anomalies:', np.sum(lof_pred==-1))

In [None]:
show_anomalies(lof_pred, X)

## Experiments

Experiment with different parameters for [isolation forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.IsolationForest.html#sklearn.ensemble.IsolationForest) and [local outlier factor](https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.LocalOutlierFactor.html#sklearn.neighbors.LocalOutlierFactor). Are the algorithms able to find all the generated outliers?
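To check whether the generated outliers were found, it can help to look up their labels directly instead of inspecting the images. The helper `report_detected` below is a hypothetical function written for this purpose, not part of `pml_utils` or scikit-learn, and the toy prediction vector merely stands in for `if_pred` or `lof_pred`.

```python
import numpy as np

def report_detected(pred, planted_idx):
    """Return those planted indices that were labeled as outliers (-1)."""
    pred = np.asarray(pred)
    planted_idx = np.asarray(planted_idx)
    return planted_idx[pred[planted_idx] == -1]

# Toy prediction vector: samples 3 and 9 are flagged as outliers
toy_pred = np.ones(10, dtype=int)
toy_pred[[3, 9]] = -1

# Suppose outliers were planted at indices 7, 8 and 9
detected = report_detected(toy_pred, [7, 8, 9])
print('Planted outliers detected:', detected)
```

In this notebook, the corresponding call would be `report_detected(if_pred, [9997, 9998, 9999])`, and similarly with `lof_pred`.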

You can also create more outliers in a similar fashion. 