In [2]:
import matplotlib.pyplot as plt
import numpy as np
# Scikit-Learn provides the popular datasets easily
from sklearn.datasets import fetch_openml

In [3]:
# Download MNIST from OpenML (dataset "mnist_784", version 1).
# as_frame=False pins the return type to NumPy arrays: the default value of
# `as_frame` differs between scikit-learn releases, so without it `data` and
# `target` may come back as ndarrays or as pandas objects depending on the
# installed version.
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
mnist.keys()

dict_keys(['data', 'target', 'frame', 'categories', 'feature_names', 'target_names', 'DESCR', 'details', 'url'])

In [4]:
# Pull the feature matrix and the target vector out of the fetched bunch.
data, labels = mnist["data"], mnist["target"]

In [5]:
# Rebind `labels` to an unsigned 8-bit integer version of itself.
# NOTE(review): presumably the raw OpenML targets are digit strings -- confirm.
labels = labels.astype(dtype=np.uint8)

In [6]:
# First 60000 rows are used for training, everything after that for testing.
X_train, X_test = data[:60000], data[60000:]
y_train, y_test = labels[:60000], labels[60000:]

In [13]:
# The data is already shuffled, so it can be used as-is.
# Model of choice: a k-nearest-neighbours classifier.
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

# Hyperparameter values to try: 4 neighbour counts x 2 weighting schemes.
param_grid = dict(
    n_neighbors=[2, 3, 4, 5],
    weights=["uniform", "distance"],
)

# GridSearchCV tries every combination in `param_grid` and uses 5-fold
# cross-validation (scored by accuracy) to evaluate each one.
grid_search = GridSearchCV(
    estimator=KNeighborsClassifier(),
    param_grid=param_grid,
    cv=5,
    scoring="accuracy",
    verbose=1,
)
grid_search.fit(X_train, y_train)

Fitting 5 folds for each of 8 candidates, totalling 40 fits


GridSearchCV(cv=5, estimator=KNeighborsClassifier(),
             param_grid={'n_neighbors': [2, 3, 4, 5],
                         'weights': ['uniform', 'distance']},
             scoring='accuracy', verbose=1)

In [14]:
# Show the hyperparameter combination that the grid search selected.
grid_search.best_params_

{'n_neighbors': 4, 'weights': 'distance'}

In [16]:
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import cross_val_predict

# 5-fold cross-validated predictions of the selected estimator over the
# training split.
knn_pred = cross_val_predict(
    estimator=grid_search.best_estimator_,
    X=X_train,
    y=y_train,
    cv=5,
)

# Macro averaging: all labels have the same importance.
f1_scores = f1_score(y_true=y_train, y_pred=knn_pred, average="macro")
acc_score = accuracy_score(y_true=y_train, y_pred=knn_pred)

In [22]:
# Report macro-F1 and accuracy, one value per line under the header.
print("Train Set", f1_scores, acc_score, sep="\n")

Train Set
0.9714326948770676
0.9716166666666667


In [19]:
# Take the estimator chosen by the grid search and predict the test split.
knn_model = grid_search.best_estimator_
knn_test_pred = knn_model.predict(X_test)

acc_score_test = accuracy_score(y_true=y_test, y_pred=knn_test_pred)
# Macro averaging: all labels have the same importance.
f1_scores_test = f1_score(y_true=y_test, y_pred=knn_test_pred, average="macro")

In [21]:
# Report macro-F1 and accuracy, one value per line under the header.
print("Test Set", f1_scores_test, acc_score_test, sep="\n")

Test Set
0.971224084176584
0.9714
