# K-Nearest Neighbour Classification

In [None]:
import jax
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split, StratifiedKFold
from scipy.stats import mode
import pandas as pd
import matplotlib.pyplot as plt

jnp = jax.numpy

In [None]:
# Load the digits data set and preview 25 randomly chosen images, each
# titled with its ground-truth label.
X, y = load_digits(return_X_y=True)
fig, ax = plt.subplots(nrows=5, ncols=5, figsize=(5, 5))
ax = ax.ravel()
# Sample without replacement so the same image is never shown twice.
ix = np.random.choice(X.shape[0], size=ax.size, replace=False)
for axis, image, label in zip(ax, X[ix], y[ix]):
    # Each row of X is reshaped to an 8x8 grid for display.
    axis.imshow(image.reshape(8, 8), cmap=plt.cm.gray)
    axis.set_title(str(label), fontsize=8)
    axis.set_axis_off()
# Keep the small titles from overlapping the neighbouring images.
fig.tight_layout()

## How does a KNN classifier work?

In [None]:
def np_predict(x_test, X, y, k=5):
    """Predict a label for every row of ``x_test`` with brute-force k-NN.

    Args:
        x_test: 2-D array of query points, one per row.
        X: 2-D array of reference points, one per row.
        y: 1-D array of labels, indexed like the rows of ``X``.
        k: Number of nearest reference points that vote.

    Returns:
        1-D array holding, for each query row, the most common label among
        its ``k`` nearest reference points (squared Euclidean distance).
    """
    # Broadcast to all pairwise differences: (n_test, n_train, n_features).
    deltas = x_test[:, np.newaxis, :] - X[np.newaxis, :, :]
    sq_dist = np.square(deltas).sum(axis=-1)
    # Column indices of the k smallest distances in each row.
    neighbours = sq_dist.argsort(axis=1)[:, :k]
    votes = y[neighbours]
    return mode(votes, axis=1).mode.ravel()

### Doing a quick test

In [None]:
from sklearn.model_selection import StratifiedKFold
import pandas as pd
from sklearn.metrics import f1_score

In [None]:
# Quick sanity check: 5-fold stratified cross-validation, printing the
# accuracy obtained on each held-out fold.
X, y = load_digits(return_X_y=True)

splitter = StratifiedKFold(n_splits=5)
for trix, tsix in splitter.split(X, y):
    xtrain, ytrain = X[trix], y[trix]
    xtest, ytest = X[tsix], y[tsix]

    y_pred = np_predict(xtest, xtrain, ytrain)
    fold_accuracy = accuracy_score(ytest, y_pred)
    print(fold_accuracy)

In [None]:
# Too good to be true? Check the class balance: show the fraction of
# samples that belongs to each label.
pd.Series(y).value_counts(normalize=True)

In [None]:
# Let's do a better test.
# Micro-averaged F1 is identical to accuracy for single-label multiclass
# problems, so it would just repeat the numbers printed above. The macro
# average (unweighted mean of the per-class F1 scores) is sensitive to
# how well each individual class is predicted.
for trix, tsix in StratifiedKFold(n_splits=5).split(X, y):
    xtrain, xtest = X[trix], X[tsix]
    ytrain, ytest = y[trix], y[tsix]

    y_pred = np_predict(xtest, xtrain, ytrain)
    print(f1_score(ytest, y_pred, average='macro'))

## Time performance of the classifier

In [None]:
# Fixed, stratified split (1200 training samples, seeded) used for the
# timing comparison.
xtrain, xtest, ytrain, ytest = train_test_split(
    X,
    y,
    train_size=1200,
    stratify=y,
    random_state=42,
)

In [None]:
%%timeit
# Baseline: time the NumPy implementation on the held-out split.
ypred = np_predict(xtest, xtrain, ytrain)

## Acceleration with JAX

In [None]:
# Rebuild the same seeded, stratified split and convert each of the four
# parts into a JAX array.
X, y = load_digits(return_X_y=True)
xtrain, xtest, ytrain, ytest = (
    jnp.array(part)
    for part in train_test_split(
        X, y, train_size=1200, stratify=y, random_state=42
    )
)


def jax_predict(x_test, X, y, k=5):
    """Predict a label for every row of ``x_test`` with brute-force k-NN in JAX.

    Args:
        x_test: 2-D array of query points, one per row.
        X: 2-D array of reference points, one per row.
        y: 1-D array of labels, indexed like the rows of ``X``.
        k: Number of nearest reference points that vote. It is used as a
            slice length, so it must be a concrete Python int when traced.

    Returns:
        1-D array holding, for each query row, the most common label among
        its ``k`` nearest reference points (squared Euclidean distance).
        Ties go to the smallest label, since ``jnp.unique`` returns sorted
        values and ``argmax`` picks the first maximum.
    """
    # Broadcast to all pairwise differences: (n_test, n_train, n_features).
    diff = x_test[:, None, :] - X[None, :, :]
    distance = jnp.sum(diff ** 2, axis=-1)
    closest = jnp.argsort(distance, axis=1)[:, :k]

    def _mode(labels):
        # A fixed ``size`` gives ``jnp.unique`` a static output shape so it
        # can be traced under vmap/jit; padding slots carry a zero count
        # (per the jnp.unique documentation) and therefore never win.
        values, counts = jnp.unique(labels, return_counts=True, size=labels.size)
        return values[counts.argmax()]

    # Vote over the neighbours' labels, not over their row indices.
    return jax.vmap(_mode)(jnp.asarray(y)[closest])

In [None]:
%%timeit
# JAX dispatches work asynchronously; block until the result is ready so
# the timing covers the computation rather than just the dispatch.
jax_predict(xtest, xtrain, ytrain).block_until_ready()

In [None]:
# ``k`` is used as a slice length inside ``jax_predict``, so it has to be a
# static (compile-time) argument; without this, passing ``k`` explicitly to
# the jitted function fails at trace time.
jitted = jax.jit(jax_predict, static_argnames="k")

In [None]:
%%timeit
# JAX dispatches work asynchronously; block until the result is ready so
# the timing covers the computation rather than just the dispatch.
jitted(xtest, xtrain, ytrain).block_until_ready()