# 1. Build an MNIST classifier that reaches 97% accuracy

In [12]:
from sklearn.datasets import fetch_openml
from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import Pipeline

In [6]:
# Download (or load from scikit-learn's local cache) the MNIST dataset.
# as_frame=True makes the return type explicit (pandas DataFrame / Series):
# scikit-learn < 0.24 defaulted to NumPy arrays, on which pandas-only calls
# such as .iloc fail.
mnist = fetch_openml('mnist_784', version=1, as_frame=True)

In [3]:
# split into training and test sets: the first 60000 rows are used for
# training, every remaining row for testing

X = mnist["data"]
y = mnist["target"]
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

In [None]:
# Try a KNeighborsClassifier, grid-searching over `weights` and `n_neighbors`

In [17]:
# Grid search over KNeighborsClassifier hyper-parameters (5-fold CV, accuracy).
# THIS TAKES A LONG TIME TO RUN! Set RUN_GRID_SEARCH = True if you have time;
# the outcome of a previous run is recorded in a later cell.
# A flag is used instead of wrapping the code in a string literal, which
# Jupyter would echo back as the cell's output.
RUN_GRID_SEARCH = False

if RUN_GRID_SEARCH:
    pipeline_kn = Pipeline(steps=[("kn_class", KNeighborsClassifier())])

    param_grid = [
        {'kn_class__n_neighbors': [1, 2, 3],
         'kn_class__weights': ["uniform", "distance"]},
    ]

    grid_search = GridSearchCV(pipeline_kn, param_grid, cv=5,
                               scoring="accuracy",
                               return_train_score=True)

    grid_search.fit(X_train, y_train)

    print("Best parameter (CV score=%0.3f):" % grid_search.best_score_)
    print(grid_search.best_params_)

Best parameter (CV score=0.971):
{'kn_class__n_neighbors': 3, 'kn_class__weights': 'distance'}


In [None]:
# Outcome of a previous run of the (slow) KNN grid search, recorded as data so
# the best hyper-parameters can be reused without re-running the search.
# The score is the mean 5-fold CV accuracy, printed rounded to 3 decimals.
KNN_GRID_SEARCH_BEST_SCORE = 0.971
KNN_GRID_SEARCH_BEST_PARAMS = {'kn_class__n_neighbors': 3, 'kn_class__weights': 'distance'}
KNN_GRID_SEARCH_BEST_PARAMS

# 2. Augment the training data by shifting each image one pixel in each direction

In [None]:
# `shift` lives in the scipy.ndimage namespace; the old
# scipy.ndimage.interpolation submodule is deprecated.
from scipy.ndimage import shift
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Stack every row of X (train + test) into one 28x28 image per row; -1 lets
# NumPy infer the number of images instead of hard-coding 70000.
all_digits_images = np.asarray(X).reshape(-1, 28, 28)

In [None]:
# make new datasets with shifts [0, del_y, del_x]: axis 0 indexes the images
# and is never shifted; pixels shifted in from outside the border get cval=0.
# The shifts are whole pixels, so order=0 moves pixels exactly. The default
# order=3 would run a cubic-spline prefilter over the whole (n_images, 28, 28)
# stack for no benefit and can leave floating-point round-off in the pixels.

all_digits_images_shifted_xp1 = shift(all_digits_images, [0, 0, 1], cval=0, order=0)
all_digits_images_shifted_xm1 = shift(all_digits_images, [0, 0, -1], cval=0, order=0)
all_digits_images_shifted_yp1 = shift(all_digits_images, [0, 1, 0], cval=0, order=0)
all_digits_images_shifted_ym1 = shift(all_digits_images, [0, -1, 0], cval=0, order=0)

In [None]:
# the labels should still be the same

X_train_augmented = np.concatenate((all_digits_images_shifted_xp1,
                                    all_digits_images_shifted_xm1,
                                    all_digits_images_shifted_yp1,
                                    all_digits_images_shifted_ym1), axis=0)

In [None]:
# flatten again

X_train_augmented = X_train_augmented.reshape(280000,784)