# K Nearest Neighbors
A straightforward classifier that serves as an easily interpretable baseline model.

## Load python modules and data

In [None]:
# %% Initialization
import sys
from pathlib import Path
import os

# Repository root, taken as two directory levels above the current working
# directory (assumes the notebook is launched from its own folder).
basefolder_loc = Path.cwd().parents[1]

import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

sys.path.append(str(basefolder_loc))
from utils import create_metrics

import matplotlib.pyplot as plt

from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
from scipy import interp

In [None]:
# Folder holding the processed train/test TSV files (kept as a plain string).
data_dir = str(basefolder_loc / "2.process-data" / "data")


def load_data(
    file_name="train_processed.tsv.gz",
    *,
    directory=None,
    metadata_columns=("cell_code", "cell_id", "plate", "well"),
    target_column="target",
):
    """Load one processed tab-separated split and return features and labels.

    Parameters
    ----------
    file_name : str
        Name of the TSV file to read (pandas infers gzip compression from a
        ``.gz`` extension).
    directory : str or path-like, optional
        Folder containing ``file_name``. Defaults to the module-level
        ``data_dir``.
    metadata_columns : sequence of str
        Identifier columns that are not model features; they are dropped.
    target_column : str
        Column holding the class label; returned as ``Y`` and excluded from ``X``.

    Returns
    -------
    X : numpy.ndarray
        Feature matrix: every remaining column, in the file's column order.
    Y : numpy.ndarray
        Label vector taken from ``target_column``.
    """
    folder = data_dir if directory is None else directory
    data = pd.read_csv(os.path.join(folder, file_name), sep="\t")
    # Features keep the file's column order, so train and test files must
    # list their feature columns in the same order for X to line up.
    X = data.drop(columns=[*metadata_columns, target_column]).to_numpy()
    Y = data[target_column].to_numpy()
    return X, Y


# %% Training and validation data
# The train split is used to fit the model; the test split to evaluate it.
(X, Y), (valX, valY) = (
    load_data(split_file)
    for split_file in ("train_processed.tsv.gz", "test_processed.tsv.gz")
)

## Train and evaluate KNN classifier

In [None]:
# 25-nearest-neighbour classifier; neighbours vote with equal weight.
KNNclf = KNeighborsClassifier(n_neighbors=25, weights="uniform")

In [None]:
# Fit on the training split (the returned estimator is shown as cell output).
KNNclf.fit(X=X, y=Y)

In [None]:
# Hard class predictions for the held-out test split.
prediction = KNNclf.predict(X=valX)

In [None]:
# Write evaluation metrics for the test-set predictions to ./results.
# NOTE(review): predictions are passed before the true labels here; confirm
# this matches the parameter order of utils.create_metrics.
create_metrics(
    prediction,
    valY,
    str(Path(os.path.abspath('')) / "results"),
)

For comparison, the most common label (*dopaminereceptor*) accounts for only 10% of the test set, so the model performs about three times better than a baseline that always predicts the majority class (computed below).

    from collections import Counter
    max(Counter(valY).values())/len(valY)


## ROC curve
See https://stackoverflow.com/questions/52910061/implementing-roc-curves-for-k-nn-machine-learning-algorithm-using-python-and-sci/52910821 and https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html#sphx-glr-auto-examples-model-selection-plot-roc-py

In [None]:
# Per-class probability estimates for the test split (one column per class).
prediction_proba = KNNclf.predict_proba(X=valX)

In [None]:
# predict_proba orders its columns by KNNclf.classes_. Binarize the true
# labels against that same list, rather than against the labels that happen
# to occur in valY, so column i always refers to the same class. Otherwise a
# class seen in training but absent from the test set would shift every
# later column.
valYbin = label_binarize(valY, classes=KNNclf.classes_)
if len(KNNclf.classes_) == 2:
    # For two classes label_binarize returns a single column (positive =
    # classes_[1]); expand it to match predict_proba's two columns.
    valYbin = np.hstack([1 - valYbin, valYbin])

In [None]:
# One-vs-rest ROC curve and AUC for every class column (integer keys), plus a
# micro-average that pools every (sample, class) decision into one curve.
fpr, tpr, roc_auc = {}, {}, {}
for i in range(prediction_proba.shape[1]):
    is_class_i = valYbin[:, i]
    score_i = prediction_proba[:, i]
    fpr[i], tpr[i], _ = roc_curve(is_class_i, score_i)
    roc_auc[i] = auc(fpr[i], tpr[i])

pooled_truth = valYbin.ravel()
pooled_scores = prediction_proba.ravel()
fpr["micro"], tpr["micro"], _ = roc_curve(pooled_truth, pooled_scores)
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

In [None]:
# Macro-average ROC: put every per-class curve on a common FPR grid (the
# union of all per-class FPRs), average the TPRs, then integrate.
n_classes = prediction_proba.shape[1]
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    # np.interp replaces scipy.interp, a deprecated alias of np.interp that
    # recent SciPy releases no longer provide.
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot the micro/macro averages and every per-class curve.
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
         label=f'micro-average ROC curve (area = {roc_auc["micro"]:0.2f})',
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label=f'macro-average ROC curve (area = {roc_auc["macro"]:0.2f})',
         color='navy', linestyle=':', linewidth=4)

# plt.get_cmap replaces plt.cm.get_cmap, which Matplotlib 3.9 removed.
colors = plt.get_cmap('tab20')
for i in range(n_classes):
    # tab20 has 20 entries; cycle through them rather than letting every
    # class past the 20th fall onto the same (last) colour. Column i of
    # predict_proba corresponds to KNNclf.classes_[i], so label by name.
    plt.plot(fpr[i], tpr[i], color=colors(i % colors.N), lw=2,
             label=f'ROC curve of class {KNNclf.classes_[i]} '
                   f'(area = {roc_auc[i]:0.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=2)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()

In [None]:
# Show all AUCs: integer keys are per-class column indices (matching
# KNNclf.classes_ order), plus the "micro" and "macro" averages.
roc_auc