In [1]:
import os
import zipfile
import numpy as np

In [2]:
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import KFold, cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

In [3]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import gzip
import struct

In [4]:
# IDX data-type byte -> NumPy dtype. Multi-byte element types are stored big-endian.
_IDX_DTYPES = {
    0x08: np.dtype(np.uint8),   # unsigned byte
    0x09: np.dtype(np.int8),    # signed byte
    0x0B: np.dtype('>i2'),      # short (2 bytes)
    0x0C: np.dtype('>i4'),      # int (4 bytes)
    0x0D: np.dtype('>f4'),      # float (4 bytes)
    0x0E: np.dtype('>f8'),      # double (8 bytes)
}


def read_idx(filename):
    """Read a gzip-compressed IDX file into a NumPy array.

    Header layout: two zero magic bytes, one data-type byte, one byte giving
    the number of dimensions, then one big-endian uint32 per dimension. The
    element payload follows.

    Parameters
    ----------
    filename : str or path-like
        Path to a gzip-compressed IDX file.

    Returns
    -------
    numpy.ndarray
        Array with the shape declared in the header and the element type
        declared by the data-type byte, in native byte order.

    Raises
    ------
    ValueError
        If the magic bytes are not zero, the data-type code is unknown, or
        the payload size does not match the declared shape.
    """
    with gzip.open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        if zero != 0:
            raise ValueError(
                f'{filename}: not an IDX file (magic bytes must be zero, got {zero:#06x})')
        if data_type not in _IDX_DTYPES:
            raise ValueError(f'{filename}: unsupported IDX data-type code {data_type:#04x}')
        dtype = _IDX_DTYPES[data_type]
        # One big-endian uint32 per dimension follows the 4-byte magic header.
        shape = struct.unpack(f'>{dims}I', f.read(4 * dims))
        payload = f.read()

    expected_bytes = int(np.prod(shape, dtype=np.int64)) * dtype.itemsize
    if len(payload) != expected_bytes:
        raise ValueError(
            f'{filename}: payload has {len(payload)} bytes, '
            f'expected {expected_bytes} for shape {shape}')

    data = np.frombuffer(payload, dtype=dtype).reshape(shape)
    # Convert to native byte order; for single-byte types this is a no-op (no copy).
    return data.astype(dtype.newbyteorder('='), copy=False)

In [5]:
 # Load the EMNIST dataset (the 'digits' split) from gzip-compressed IDX files.
EMNIST_DIR = 'emnist_data/gzip'
EMNIST_SPLIT = 'digits'


def emnist_path(subset, kind):
    """Return the path of one EMNIST IDX file.

    subset : 'train' or 'test'
    kind   : 'images' (stored as an idx3 file) or 'labels' (stored as an idx1 file)
    """
    idx_rank = {'images': 3, 'labels': 1}[kind]
    return f'{EMNIST_DIR}/emnist-{EMNIST_SPLIT}-{subset}-{kind}-idx{idx_rank}-ubyte.gz'


X_train = read_idx(emnist_path('train', 'images'))
y_train = read_idx(emnist_path('train', 'labels'))
X_test = read_idx(emnist_path('test', 'images'))
y_test = read_idx(emnist_path('test', 'labels'))

# Every image must have exactly one label.
assert len(X_train) == len(y_train), 'train images/labels count mismatch'
assert len(X_test) == len(y_test), 'test images/labels count mismatch'


In [6]:
 # Check the shapes of the loaded data: one line each for train images,
# test images, train labels and test labels.
print(*(arr.shape for arr in (X_train, X_test, y_train, y_test)), sep='\n')

(240000, 28, 28)
(40000, 28, 28)
(240000,)
(40000,)


In [7]:
# Flatten every 28x28 image into a single 784-pixel row so each sample is a
# plain feature vector, as KNeighborsClassifier expects 2-D input.
X_train = X_train.reshape((len(X_train), 28 * 28))
X_test = X_test.reshape((len(X_test), 28 * 28))

In [8]:
# Build a k-nearest-neighbours classifier with k=5 (scikit-learn's default)
# and fit it on the flattened training images.
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)

In [None]:
# Keep the ground-truth test labels next to the model's predictions so the
# evaluation cells below can compare them directly.
expected = y_test
predicted = knn.predict(X_test)

In [None]:
# Print accuracy: the fraction of test samples whose predicted label matches
# the true label. Computed from `predicted` (previous cell) instead of calling
# knn.score(X_test, y_test), which would run a second full, expensive KNN
# prediction pass over the entire test set to get the same number.
accuracy = np.mean(predicted == expected)
print(f'{accuracy:.2%}')

In [None]:
# Build the confusion matrix with an explicit, data-derived label set so the
# matrix is always square and its rows/columns line up with the DataFrame axes.
# A hard-coded range(10) breaks if some class is missing from both arrays.
class_labels = np.union1d(knn.classes_, expected)
confusion = confusion_matrix(y_true=expected, y_pred=predicted, labels=class_labels)
# Rows are true labels, columns are predicted labels.
confusion_df = pd.DataFrame(confusion, index=class_labels, columns=class_labels)

In [None]:
# Plot the confusion matrix: rows are true digits, columns are predicted digits.
# fmt='d' prints raw integer counts; seaborn's default fmt ('.2g') would show
# counts in the thousands in scientific notation (e.g. 4e+03).
fig, ax = plt.subplots(figsize=(7, 6))
sns.heatmap(confusion_df, annot=True, fmt='d', cmap=plt.cm.nipy_spectral_r, ax=ax)
ax.set(title='KNN confusion matrix (EMNIST digits, test set)',
       xlabel='Predicted label', ylabel='True label')
plt.show()

In [None]:
# 10-fold cross-validation on the training set. Shuffling uses a fixed seed so
# the folds, and therefore the scores, are reproducible across runs.
# NOTE: expensive. cross_val_score refits a clone of `knn` on each fold and
# runs a full KNN prediction on that fold's held-out part.
N_SPLITS = 10
CV_SEED = 11
kfold = KFold(n_splits=N_SPLITS, random_state=CV_SEED, shuffle=True)
scores = cross_val_score(estimator=knn, X=X_train, y=y_train, cv=kfold)

In [None]:
# Summarise the per-fold cross-validation accuracies.
mean_accuracy, std_accuracy = scores.mean(), scores.std()
print(f'Mean accuracy: {mean_accuracy:.2%}')
print(f'Accuracy Standard Deviation: {std_accuracy:.2%}')