In [1]:
import scipy
import sklearn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from scipy.spatial import distance
from collections import Counter
import pprint
import ipdb
import math
from multiprocessing import Pool

In [2]:
mnist = fetch_openml("mnist_784", as_frame=False)
data = mnist["data"]
labels = mnist["target"]

In [3]:
def knn_search(train, train_labels, query, k):
    distances = np.array([distance.euclidean(train_sample, query) for train_sample in train])
    k_nearest_indices = np.argsort(distances)[:k]
    k_nearest_labels = train_labels[k_nearest_indices]
    chosen_label_item, = Counter(k_nearest_labels).most_common(1)
    chosen_label, _ = chosen_label_item
    
    return chosen_label

In [4]:
idx = np.random.RandomState(0).choice(70000, 11000)
train = data[idx[:1000], :].astype(int)
train_labels = labels[idx[:1000]]
test = data[idx[10000:], :].astype(int)
test_labels = labels[idx[10000:]]

In [6]:
knn_search(train, train_labels, test[0], 2)

'9'

In [7]:
prediction = np.array([knn_search(train, train_labels, query, k=10) for query in test])

In [8]:
acc = sum(prediction == test_labels) / len(test)
f"{acc:%}"

'85.800000%'

In [8]:
labels

array(['5', '0', '4', ..., '4', '5', '6'], dtype=object)

In [5]:
def knn_accuracy(train, train_labels, test, test_labels, k):
    prediction = np.array([knn_search(train, train_labels, query=query, k=k) for query in test])
    acc = sum(prediction == test_labels) / len(test_labels)
    
    return acc

In [None]:
def plot_acc_as_k(train, train_labels, test, test_labels):
    k_ranges = range(1, 1001)
    items = ((train, train_labels, test, test_labels, k) for k in k_ranges)
    results = [knn_accuracy(*item) for item in items]
    plt.scatter(k_ranges, results, s=50)
    plt.xlabel("k value")
    plt.grid(True)

plot_acc_as_k(train, train_labels, test, test_labels)

In [None]:
def plot_acc_as_n(train, train_labels, test, test_labels):
    n_ranges = range(100, 5001, 100)
    items = ((data[idx[:n], :].astype(int), labels[idx[:n]], test, test_labels, 1) for n in n_ranges)
    results = [knn_accuracy(*item) for item in items]
    plt.scatter(n_ranges, results, s=50)
    plt.grid(True)

plot_acc_as_n(train, train_labels, test, test_labels)