# KNeighbors

In [3]:
from joblib import dump

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import shap

from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, GridSearchCV

from sklearn.metrics import accuracy_score, balanced_accuracy_score,  f1_score, precision_score, recall_score, confusion_matrix

In [4]:
# Load the combined per-subject feature table (path is relative to this notebook).
# NOTE(review): the file carries an "Unnamed: 0" column, presumably a saved index — confirm.
df = pd.read_csv(filepath_or_buffer="../../../data/combined_subjects.csv")

In [5]:
# Preview the first rows of the raw table.
df.head()

Unnamed: 0.1,Unnamed: 0,net_acc_mean,net_acc_std,net_acc_min,net_acc_max,EDA_phasic_mean,EDA_phasic_std,EDA_phasic_min,EDA_phasic_max,EDA_smna_mean,...,ACC_z_min,ACC_z_max,0_mean,0_std,0_min,0_max,BVP_peak_freq,TEMP_slope,subject,label
0,0,1.331891,0.153556,1.014138,1.678399,2.247876,1.112076,0.367977,4.459367,1.592308,...,-2.6e-05,6e-05,0.027558,0.013523,0.0,0.087383,0.080556,-0.000102,2,1
1,1,1.218994,0.090108,1.014138,1.4858,1.781323,1.203991,0.232625,4.459367,1.34775,...,-2.6e-05,6e-05,0.02342,0.01531,0.0,0.087383,0.144444,-0.000424,2,1
2,2,1.143312,0.110987,0.948835,1.4858,1.173169,1.285422,0.00695,4.459367,0.752335,...,-1.5e-05,4.9e-05,0.018759,0.012604,0.0,0.071558,0.102778,-0.000814,2,1
3,3,1.020669,0.135308,0.81109,1.239944,0.311656,0.27865,0.00695,1.303071,0.198576,...,-5e-06,3.7e-05,0.022888,0.01218,0.000688,0.054356,0.108333,-0.000524,2,1
4,4,0.887458,0.116048,0.727406,1.125306,0.163826,0.110277,0.00695,0.369298,0.11808,...,2e-06,3.7e-05,0.028105,0.010415,0.002752,0.054356,0.147222,-0.000165,2,1


In [6]:
# Column names, non-null counts and dtypes of the raw table.
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2091 entries, 0 to 2090
Data columns (total 45 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   Unnamed: 0       2091 non-null   int64  
 1   net_acc_mean     2091 non-null   float64
 2   net_acc_std      2091 non-null   float64
 3   net_acc_min      2091 non-null   float64
 4   net_acc_max      2091 non-null   float64
 5   EDA_phasic_mean  2091 non-null   float64
 6   EDA_phasic_std   2091 non-null   float64
 7   EDA_phasic_min   2091 non-null   float64
 8   EDA_phasic_max   2091 non-null   float64
 9   EDA_smna_mean    2091 non-null   float64
 10  EDA_smna_std     2091 non-null   float64
 11  EDA_smna_min     2091 non-null   float64
 12  EDA_smna_max     2091 non-null   float64
 13  EDA_tonic_mean   2091 non-null   float64
 14  EDA_tonic_std    2091 non-null   float64
 15  EDA_tonic_min    2091 non-null   float64
 16  EDA_tonic_max    2091 non-null   float64
 17  BVP_mean      

## Data Preparation

In [7]:
# Columns kept for modelling, grouped by signal. The last entry, "label",
# is the target column and is split off from the inputs further down.
features = (
    ["net_acc_std", "net_acc_max"]
    + ["EDA_tonic_mean", "EDA_tonic_min", "EDA_tonic_max"]
    + ["EDA_smna_mean", "EDA_smna_std", "EDA_smna_min", "EDA_smna_max"]
    + ["EDA_phasic_min"]
    + ["label"]
)

In [8]:
# Restrict the raw table to the selected columns (all rows, target included).
df_feat = df.loc[:, features]
df_feat.head(n=5)

Unnamed: 0,net_acc_std,net_acc_max,EDA_tonic_mean,EDA_tonic_min,EDA_tonic_max,EDA_smna_mean,EDA_smna_std,EDA_smna_min,EDA_smna_max,EDA_phasic_min,label
0,0.153556,1.678399,0.608263,-1.213173,2.55475,1.592308,2.645333,3.096905e-08,17.418821,0.367977,1
1,0.090108,1.4858,0.731985,-1.213173,2.477276,1.34775,2.666659,3.096905e-08,17.418821,0.232625,1
2,0.110987,1.4858,1.110242,-1.213173,2.037179,0.752335,1.958546,3.096905e-08,17.418821,0.00695,1
3,0.135308,1.239944,1.598995,0.959752,2.037179,0.198576,0.413802,3.30999e-08,2.788862,0.00695,1
4,0.116048,1.125306,1.342085,0.945946,2.037179,0.11808,0.237575,2.787285e-08,1.30081,0.00695,1


In [9]:
# Confirm the reduced table's columns, non-null counts and dtypes.
df_feat.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2091 entries, 0 to 2090
Data columns (total 11 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   net_acc_std     2091 non-null   float64
 1   net_acc_max     2091 non-null   float64
 2   EDA_tonic_mean  2091 non-null   float64
 3   EDA_tonic_min   2091 non-null   float64
 4   EDA_tonic_max   2091 non-null   float64
 5   EDA_smna_mean   2091 non-null   float64
 6   EDA_smna_std    2091 non-null   float64
 7   EDA_smna_min    2091 non-null   float64
 8   EDA_smna_max    2091 non-null   float64
 9   EDA_phasic_min  2091 non-null   float64
 10  label           2091 non-null   int64  
dtypes: float64(10), int64(1)
memory usage: 179.8 KB


### Merged Amusement

In [10]:
# Variant 1: relabel class 0 as class 1 so only labels 1 and 2 remain.
# `assign` returns a new frame, leaving `df_feat` untouched.
df_feat_merged_amusement = df_feat.assign(
    label=df_feat["label"].replace([0], 1)
)

In [11]:
# Sanity check: list the distinct labels left after the merge.
pd.unique(df_feat_merged_amusement["label"])

array([1, 2])

In [12]:
# Split the merged-amusement frame into a target vector and a feature matrix.
# The label column is read and then dropped on a new frame instead of being
# popped in place: the frame is left unchanged, so this cell can be re-run
# without a KeyError on "label".
y_merged_amusement = np.array(df_feat_merged_amusement["label"])
X_merged_amusement = np.array(df_feat_merged_amusement.drop(columns="label"))

In [13]:
# Hold out 25% of the rows for testing; the seed is fixed for repeatability.
# NOTE(review): the split is row-wise, and the raw table has a `subject`
# column, so rows of one subject can land on both sides — confirm intended.
(
    X_train_merged_amusement,
    X_test_merged_amusement,
    y_train_merged_amusement,
    y_test_merged_amusement,
) = train_test_split(
    X_merged_amusement,
    y_merged_amusement,
    test_size=0.25,
    random_state=42,
)

### Dropped Amusement

In [14]:
# Variant 2: discard every row whose label is 0.
keep_row = df_feat["label"] != 0
df_feat_no_amusement = df_feat.loc[keep_row]

In [15]:
# Sanity check: list the distinct labels left after the filter.
df_feat_no_amusement["label"].unique()

array([1, 2])

In [16]:
# Split the filtered frame into a target vector and a feature matrix.
# The label column is read and then dropped on a new frame instead of being
# popped in place: the frame is left unchanged, so this cell can be re-run
# without a KeyError on "label".
y_no_amusement = np.array(df_feat_no_amusement["label"])
X_no_amusement = np.array(df_feat_no_amusement.drop(columns="label"))

In [17]:
# Same 75/25 split and seed as the merged variant.
# NOTE(review): the split is row-wise, and the raw table has a `subject`
# column, so rows of one subject can land on both sides — confirm intended.
(
    X_train_no_amusement,
    X_test_no_amusement,
    y_train_no_amusement,
    y_test_no_amusement,
) = train_test_split(
    X_no_amusement,
    y_no_amusement,
    test_size=0.25,
    random_state=42,
)

## Training

In [18]:
# Hyper-parameter grid shared by both grid searches below.
# n_neighbors has a single candidate, so only the weighting scheme and the
# neighbour-search algorithm are actually varied.
parameters = {
    "n_neighbors": (2,),
    "weights": ("uniform", "distance"),
    "algorithm": ("ball_tree", "kd_tree", "brute"),
}

In [19]:
# Base estimator with default settings; it is passed as `estimator=` to both
# grid searches below together with `parameters`.
neighbor = KNeighborsClassifier()

### Merged Amusement

In [20]:
# Grid search over `parameters` for the merged-amusement variant
# (cross-validation and scoring are left at GridSearchCV's defaults).
clf_neighbor_merged_amusement = GridSearchCV(neighbor, parameters)

In [21]:
# Run the search on the training portion only.
clf_neighbor_merged_amusement.fit(
    X=X_train_merged_amusement,
    y=y_train_merged_amusement,
)

In [22]:
# Show the estimator selected by the grid search.
clf_neighbor_merged_amusement.best_estimator_

In [23]:
# Persist the selected estimator next to the notebook.
dump(
    value=clf_neighbor_merged_amusement.best_estimator_,
    filename="clf_neighbor_merged_amusement_top_10_feat.joblib",
)

['clf_neighbor_merged_amusement_top_10_feat.joblib']

### Dropped Amusement

In [24]:
# Grid search over `parameters` for the dropped-amusement variant
# (cross-validation and scoring are left at GridSearchCV's defaults).
clf_neighbor_no_amusement = GridSearchCV(neighbor, parameters)

In [25]:
# Run the search on the training portion only.
clf_neighbor_no_amusement.fit(
    X=X_train_no_amusement,
    y=y_train_no_amusement,
)

In [26]:
# Show the estimator selected by the grid search.
clf_neighbor_no_amusement.best_estimator_

In [27]:
# Persist the selected estimator next to the notebook.
dump(
    value=clf_neighbor_no_amusement.best_estimator_,
    filename="clf_neighbor_no_amusement_top_10_feat.joblib",
)

['clf_neighbor_no_amusement_top_10_feat.joblib']

## Evaluation

### Merged Amusement

In [28]:
# Predict the held-out rows with the fitted grid search.
y_pred_merged_amusement = clf_neighbor_merged_amusement.predict(
    X=X_test_merged_amusement
)

In [29]:
# Fraction of held-out rows classified correctly.
accuracy_score(
    y_true=y_test_merged_amusement,
    y_pred=y_pred_merged_amusement,
)

0.988527724665392

In [30]:
# Accuracy balanced across the two classes.
balanced_accuracy_score(
    y_true=y_test_merged_amusement,
    y_pred=y_pred_merged_amusement,
)

0.9836488499691145

In [31]:
# F1 for class 1. pos_label=1 is the library default, spelled out here so it
# is clear that the score describes label 1 rather than label 2.
f1_score(
    y_true=y_test_merged_amusement,
    y_pred=y_pred_merged_amusement,
    pos_label=1,
)

0.9920634920634921

In [32]:
# Precision for class 1 (pos_label=1 is the library default, made explicit).
precision_score(
    y_true=y_test_merged_amusement,
    y_pred=y_pred_merged_amusement,
    pos_label=1,
)

0.9894459102902374

In [33]:
# Recall for class 1 (pos_label=1 is the library default, made explicit).
recall_score(
    y_true=y_test_merged_amusement,
    y_pred=y_pred_merged_amusement,
    pos_label=1,
)

0.9946949602122016

In [34]:
# Confusion matrix: rows are true labels, columns are predicted labels,
# both ordered as [1, 2].
confusion_matrix(
    y_true=y_test_merged_amusement,
    y_pred=y_pred_merged_amusement,
    labels=[1, 2],
)

array([[375,   2],
       [  4, 142]])

### Dropped Amusement

In [35]:
# Predict the held-out rows with the fitted grid search.
y_pred_no_amusement = clf_neighbor_no_amusement.predict(
    X=X_test_no_amusement
)

In [36]:
# Fraction of held-out rows classified correctly.
accuracy_score(
    y_true=y_test_no_amusement,
    y_pred=y_pred_no_amusement,
)

0.975

In [37]:
# Accuracy balanced across the two classes.
balanced_accuracy_score(
    y_true=y_test_no_amusement,
    y_pred=y_pred_no_amusement,
)

0.9706034706034705

In [38]:
# F1 for class 1. pos_label=1 is the library default, spelled out here so it
# is clear that the score describes label 1 rather than label 2.
f1_score(
    y_true=y_test_no_amusement,
    y_pred=y_pred_no_amusement,
    pos_label=1,
)

0.9815126050420168

In [39]:
# Precision for class 1 (pos_label=1 is the library default, made explicit).
precision_score(
    y_true=y_test_no_amusement,
    y_pred=y_pred_no_amusement,
    pos_label=1,
)

0.9798657718120806

In [40]:
# Recall for class 1 (pos_label=1 is the library default, made explicit).
recall_score(
    y_true=y_test_no_amusement,
    y_pred=y_pred_no_amusement,
    pos_label=1,
)

0.9831649831649831

In [41]:
# Confusion matrix: rows are true labels, columns are predicted labels,
# both ordered as [1, 2].
confusion_matrix(
    y_true=y_test_no_amusement,
    y_pred=y_pred_no_amusement,
    labels=[1, 2],
)

array([[292,   5],
       [  6, 137]])

## XAI

### Merged Amusement

In [42]:
# Disabled: SHAP KernelExplainer for the merged-amusement model, using the
# test matrix both as background data and as the rows to explain.
# explainer = shap.KernelExplainer(
#     clf_neighbor_merged_amusement.best_estimator_.predict,
#     X_test_merged_amusement
# )

# shap_values = explainer.shap_values(X_test_merged_amusement)

In [43]:
# Disabled: SHAP summary plot for the merged-amusement model.
# NOTE(review): `features` also contains "label", so it holds 11 names while
# the test matrix has 10 columns; drop the last name before re-enabling.
# shap.summary_plot(shap_values=shap_values, features=X_test_merged_amusement, feature_names=features)

### Dropped Amusement

In [44]:
# Disabled: SHAP KernelExplainer for the dropped-amusement model, using the
# test matrix both as background data and as the rows to explain.
# explainer = shap.KernelExplainer(
#     clf_neighbor_no_amusement.best_estimator_.predict,
#     X_test_no_amusement
# )

# shap_values = explainer.shap_values(X_test_no_amusement)

In [45]:
# Disabled: SHAP summary plot for the dropped-amusement model.
# NOTE(review): `features` also contains "label", so it holds 11 names while
# the test matrix has 10 columns; drop the last name before re-enabling.
# shap.summary_plot(shap_values=shap_values, features=X_test_no_amusement, feature_names=features)