In [4]:
import pandas as pd
import warnings
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score, GridSearchCV
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
warnings.filterwarnings("ignore", category=FutureWarning)

In [5]:
tr_features = pd.read_csv('data/train_features.csv')
tr_labels = pd.read_csv('data/train_labels.csv')

te_features = pd.read_csv('data/test_features.csv')
te_labels = pd.read_csv('data/test_labels.csv')

In [6]:
def print_results(results):
    print(f'BEST PARAMS: {results.best_params_}\n')

    means = sorted(results.cv_results_['mean_test_score'], reverse=True)
    stds = results.cv_results_['std_test_score']
    for mean, std, params in zip(means, stds, results.cv_results_['params']):
        print(f'{round(mean,3)} (+/- {round(std * 2, 3)}) for {params}')

In [7]:
knn = KNeighborsClassifier()
rfscores = cross_val_score(knn, tr_features, tr_labels.values.ravel(), cv=5, n_jobs=32)
print(rfscores)

[0.28046666 0.29402491 0.29197541 0.29150244 0.29670503]


In [8]:
knnparams = {
    'n_neighbors': [1,3,5],
    'weights': [['uniform', None], ['distance', None]],
    'algorithm': ['auto', 'ball_tree', 'kd_tree', 'brute'],
    'leaf_size': [10,20,30],
    'p': [1, 2, 3],
    'metric': ['minkowski', 'precomputed']
}

cv = GridSearchCV(knn, knnparams, cv=5, n_jobs=16)
cv.fit(tr_features, tr_labels.values.ravel())

print_results(cv)

ValueError: X should be a square kernel matrix

In [None]:
knn.fit(tr_features, tr_labels.values.ravel())

In [None]:
y_pred = knn.predict(te_features)
accuracy = round(accuracy_score(te_labels, y_pred), 8)
precision = round(precision_score(te_labels, y_pred, average='weighted'), 8)
recall = round(recall_score(te_labels, y_pred, average='weighted'), 8)
f1 = round(f1_score(te_labels, y_pred, average='weighted'), 8)
print(f'LEAF SIZE: {knn.leaf_size} / NEAREST NEIGHBOURS: {knn.n_neighbors} / A: {accuracy} / P: {precision} / R: {recall} / F1: {f1}')

LEAF SIZE: 30 / NEAREST NEIGHBOURS: 5 / A: 0.29940724 / P: 0.30193629 / R: 0.29940724 / F1: 0.29542616
