In [1]:
import numpy as np
import os
os.chdir("..")

# Load the data
data = np.loadtxt(('data/preprocessed-dataset.csv'), delimiter=',', skiprows=1)
os.chdir("methods")

# Split the data into training data and test set
x = data[:,1:35]
y = data[:,35]

In [4]:
from sklearn.model_selection import train_test_split, cross_val_score
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=1, test_size=0.2)

In [5]:
from sklearn.neighbors import KNeighborsClassifier

# Create a kNN classifier
knn = KNeighborsClassifier(5) # The number of neighbors

# Use cross-validation to evaluate the performance of the model
scores = cross_val_score(knn, x_train, y_train, cv=10)
print("Cross-validation scores:", scores)
print("Mean score:", scores.mean())

# Train the model on the entire training set
knn.fit(x_train, y_train)

# Make predicitons on the test set
knn_pred = knn.predict(x_test)

Cross-validation scores: [0.83333333 0.82017544 0.77631579 0.77973568 0.84140969 0.78414097
 0.79295154 0.76651982 0.76651982 0.84581498]
Mean score: 0.8006917072416725


In [6]:
from sklearn import metrics
def model_info(model_pred):
    print("Accuracy:", metrics.accuracy_score(y_test, model_pred))
    print("Precision:", metrics.precision_score(y_test, model_pred))
    print("Recall:", metrics.recall_score(y_test, model_pred))

# Print the performance metrics to the set k-value
model_info(knn_pred)

Accuracy: 0.8101933216168717
Precision: 0.7928994082840237
Recall: 0.8758169934640523


In [7]:
# Search for the best k-value for accuracy, precision, recall and AUC-ROC

best_acc = 0
best_acc_for = 0
best_prec = 0
best_prec_for = 0
best_recall = 0
best_recall_for = 0
best_roc_auc=0
best_roc_auc_for=0

for i in range(1,150):
    knn = KNeighborsClassifier(i) # Iterating over the number of neighbors
    knn.fit(x_train, y_train)
    knn_pred = knn.predict(x_test)
    if metrics.roc_auc_score(y_test, knn_pred) > best_roc_auc:
        best_roc_auc_for = i
        best_roc_auc = metrics.roc_auc_score(y_test, knn_pred)
        print("Found a better ROC-AUC (", metrics.roc_auc_score(y_test, knn_pred),") for n = ",best_roc_auc_for)
    if metrics.accuracy_score(y_test, knn_pred) > best_acc:
        best_acc_for = i
        best_acc = metrics.accuracy_score(y_test, knn_pred)
        print("Found a better accuracy (", metrics.accuracy_score(y_test, knn_pred),") for n = ",best_acc_for)
    if metrics.precision_score(y_test, knn_pred) > best_prec:
        best_prec_for = i
        best_prec = metrics.precision_score(y_test, knn_pred)
        print("Found a better precision (",metrics.precision_score(y_test, knn_pred),") for n = ",best_prec_for)
    if metrics.recall_score(y_test, knn_pred) > best_recall:
        best_recall_for = i
        best_recall = metrics.recall_score(y_test, knn_pred)
        print("Found a better recall (",metrics.recall_score(y_test, knn_pred),") for n = ",best_recall_for)
        
print("-----------\nHere are the best values for n:","\naccuracy:",best_acc_for,"\nprecision:",best_prec_for,"\nrecall:",best_recall_for,"\nROC-AUC:",best_roc_auc_for)

Found a better ROC-AUC ( 0.7604438480081512 ) for n =  1
Found a better accuracy ( 0.7627416520210897 ) for n =  1
Found a better precision ( 0.7731629392971247 ) for n =  1
Found a better recall ( 0.7908496732026143 ) for n =  1
Found a better precision ( 0.831275720164609 ) for n =  2
Found a better ROC-AUC ( 0.8173041079549691 ) for n =  3
Found a better accuracy ( 0.820738137082601 ) for n =  3
Found a better recall ( 0.8627450980392157 ) for n =  3
Found a better precision ( 0.8385964912280702 ) for n =  4
Found a better recall ( 0.8758169934640523 ) for n =  5
Found a better recall ( 0.8888888888888888 ) for n =  9
Found a better recall ( 0.8986928104575164 ) for n =  19
Found a better recall ( 0.9084967320261438 ) for n =  21
Found a better accuracy ( 0.8224956063268892 ) for n =  31
Found a better recall ( 0.9117647058823529 ) for n =  31
Found a better recall ( 0.9150326797385621 ) for n =  33
Found a better ROC-AUC ( 0.8176333904918114 ) for n =  40
Found a better accuracy ( 