# Modelling

- Decision Tree
- SVC
- Random Forest
- K Neighbors

0 = amusement
1 = baseline
2 = stress

In [363]:
import pandas as pd
import numpy as np

from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split, GridSearchCV

from sklearn.metrics import accuracy_score, balanced_accuracy_score,  f1_score, precision_score, recall_score, confusion_matrix, r2_score

In [364]:
df = pd.read_csv("../data/combined_subjects.csv")

In [365]:
df.head()

Unnamed: 0.1,Unnamed: 0,net_acc_mean,net_acc_std,net_acc_min,net_acc_max,EDA_phasic_mean,EDA_phasic_std,EDA_phasic_min,EDA_phasic_max,EDA_smna_mean,...,ACC_z_min,ACC_z_max,0_mean,0_std,0_min,0_max,BVP_peak_freq,TEMP_slope,subject,label
0,0,1.331891,0.153556,1.014138,1.678399,2.247876,1.112076,0.367977,4.459367,1.592308,...,-2.6e-05,6e-05,0.027558,0.013523,0.0,0.087383,0.080556,-0.000102,2,1
1,1,1.218994,0.090108,1.014138,1.4858,1.781323,1.203991,0.232625,4.459367,1.34775,...,-2.6e-05,6e-05,0.02342,0.01531,0.0,0.087383,0.144444,-0.000424,2,1
2,2,1.143312,0.110987,0.948835,1.4858,1.173169,1.285422,0.00695,4.459367,0.752335,...,-1.5e-05,4.9e-05,0.018759,0.012604,0.0,0.071558,0.102778,-0.000814,2,1
3,3,1.020669,0.135308,0.81109,1.239944,0.311656,0.27865,0.00695,1.303071,0.198576,...,-5e-06,3.7e-05,0.022888,0.01218,0.000688,0.054356,0.108333,-0.000524,2,1
4,4,0.887458,0.116048,0.727406,1.125306,0.163826,0.110277,0.00695,0.369298,0.11808,...,2e-06,3.7e-05,0.028105,0.010415,0.002752,0.054356,0.147222,-0.000165,2,1


## Prepare data for training

In [366]:
df_feat = df[["net_acc_std", "net_acc_max", "EDA_tonic_mean", "EDA_tonic_min", "EDA_tonic_max", "label"]]
df_feat.head()

Unnamed: 0,net_acc_std,net_acc_max,EDA_tonic_mean,EDA_tonic_min,EDA_tonic_max,label
0,0.153556,1.678399,0.608263,-1.213173,2.55475,1
1,0.090108,1.4858,0.731985,-1.213173,2.477276,1
2,0.110987,1.4858,1.110242,-1.213173,2.037179,1
3,0.135308,1.239944,1.598995,0.959752,2.037179,1
4,0.116048,1.125306,1.342085,0.945946,2.037179,1


In [367]:
df_feat.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2091 entries, 0 to 2090
Data columns (total 6 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   net_acc_std     2091 non-null   float64
 1   net_acc_max     2091 non-null   float64
 2   EDA_tonic_mean  2091 non-null   float64
 3   EDA_tonic_min   2091 non-null   float64
 4   EDA_tonic_max   2091 non-null   float64
 5   label           2091 non-null   int64  
dtypes: float64(5), int64(1)
memory usage: 98.1 KB


### Merge amusement into baseline

In [368]:
df_feat_merged_amusement = df_feat.copy()
df_feat_merged_amusement["label"] = df_feat_merged_amusement["label"].replace([0], 1)

In [369]:
df_feat_merged_amusement["label"].unique()

array([1, 2])

In [370]:
y_merged_amusement = np.array(df_feat_merged_amusement.pop('label'))
X_merged_amusement = np.array(df_feat_merged_amusement)

In [371]:
X_train_merged_amusement, X_test_merged_amusement, y_train_merged_amusement, y_test_merged_amusement = train_test_split(X_merged_amusement, y_merged_amusement, test_size=0.25)

### Remove amusement

In [372]:
df_feat_no_amusement = df_feat[df_feat["label"] != 0]

In [373]:
df_feat_no_amusement["label"].unique()

array([1, 2])

In [374]:
y_no_amusement = np.array(df_feat_no_amusement.pop('label'))
X_no_amusement = np.array(df_feat_no_amusement)

In [375]:
X_train_no_amusement, X_test_no_amusement, y_train_no_amusement, y_test_no_amusement = train_test_split(X_no_amusement, y_no_amusement, test_size=0.25)

# Training

## Decision Tree

In [376]:
parameters = dict(
    criterion=("gini", "entropy", "log_loss"),
    splitter=("best", "random"),
    max_depth=(3, 5, 7, 9, 11),
)

In [377]:
tree = DecisionTreeClassifier()

### Merged amusement

In [378]:
clf_tree_merged_amusement = GridSearchCV(estimator=tree, param_grid=parameters)

In [379]:
clf_tree_merged_amusement.fit(X_train_merged_amusement, y_train_merged_amusement)

In [380]:
clf_tree_merged_amusement.best_estimator_

### No amusement

In [381]:
clf_tree_no_amusement = GridSearchCV(estimator=tree, param_grid=parameters)

In [382]:
clf_tree_no_amusement.fit(X_train_no_amusement, y_train_no_amusement)

In [383]:
clf_tree_no_amusement.best_estimator_

### Evaluation

### Merged amusement

In [384]:
y_pred_merged_amusement = clf_tree_merged_amusement.predict(X_test_merged_amusement)

In [385]:
accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9196940726577438

In [386]:
balanced_accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.898023176550784

In [387]:
r2_score(y_test_merged_amusement, y_pred_merged_amusement)

0.6256646216768916

In [388]:
f1_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9424657534246577

In [389]:
precision_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9297297297297298

In [390]:
recall_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9555555555555556

In [391]:
confusion_matrix(y_test_merged_amusement, y_pred_merged_amusement, labels=[1, 2])

array([[344,  16],
       [ 26, 137]])

### No amusement

In [392]:
y_pred_no_amusement = clf_tree_no_amusement.predict(X_test_no_amusement)

In [393]:
accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.9386363636363636

In [394]:
balanced_accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.9350485942381117

In [395]:
r2_score(y_test_no_amusement, y_pred_no_amusement)

0.742276987157237

In [396]:
f1_score(y_test_no_amusement, y_pred_no_amusement)

0.9497206703910615

In [397]:
precision_score(y_test_no_amusement, y_pred_no_amusement)

0.9479553903345725

In [398]:
recall_score(y_test_no_amusement, y_pred_no_amusement)

0.9514925373134329

In [399]:
confusion_matrix(y_test_no_amusement, y_pred_no_amusement, labels=[1, 2])

array([[255,  13],
       [ 14, 158]])

## SVM

In [400]:
parameters = dict(
    C=(.2, .5, 1, 2, 3, 4, 5),
    kernel=("linear", "poly", "rbf", "sigmoid"),
    gamma=("scale", "auto")
)

In [401]:
svc = SVC()

### Merged amusement

In [402]:
clf_svc_merged_amusement = GridSearchCV(estimator=svc, param_grid=parameters)

In [403]:
clf_svc_merged_amusement.fit(X_train_merged_amusement, y_train_merged_amusement)

In [404]:
clf_svc_merged_amusement.best_estimator_

#### No amusement

In [405]:
clf_svc_no_amusement = GridSearchCV(estimator=svc, param_grid=parameters)

In [406]:
clf_svc_no_amusement.fit(X_train_no_amusement, y_train_no_amusement)

In [407]:
clf_svc_no_amusement.best_estimator_

### Evaluation

### Merged amusement

In [408]:
y_pred_merged_amusement = clf_svc_merged_amusement.predict(X_test_merged_amusement)

In [409]:
accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.8948374760994264

In [410]:
balanced_accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.8749318336741649

In [411]:
r2_score(y_test_merged_amusement, y_pred_merged_amusement)

0.5097989093387867

In [412]:
f1_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9239280774550485

In [413]:
precision_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9201101928374655

In [414]:
recall_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9277777777777778

In [415]:
confusion_matrix(y_test_merged_amusement, y_pred_merged_amusement, labels=[1, 2])

array([[334,  26],
       [ 29, 134]])

### No amusement

In [416]:
y_pred_no_amusement = clf_svc_no_amusement.predict(X_test_no_amusement)

In [417]:
accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.9068181818181819

In [418]:
balanced_accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.8912270045123221

In [419]:
r2_score(y_test_no_amusement, y_pred_no_amusement)

0.6086428323498785

In [420]:
f1_score(y_test_no_amusement, y_pred_no_amusement)

0.926391382405745

In [421]:
precision_score(y_test_no_amusement, y_pred_no_amusement)

0.8927335640138409

In [422]:
recall_score(y_test_no_amusement, y_pred_no_amusement)

0.9626865671641791

In [423]:
confusion_matrix(y_test_no_amusement, y_pred_no_amusement, labels=[1, 2])

array([[258,  10],
       [ 31, 141]])

## Random Forest

In [424]:
parameters = dict(
    n_estimators=(25, 50, 75, 100, 125, 150),
    criterion=("gini", "entropy", "log_loss"),
    max_depth=(2, 3, 5, 7, 9, 11)
)

In [425]:
random = RandomForestClassifier()

### Merged amusement

In [426]:
clf_random_merged_amusement = GridSearchCV(estimator=random, param_grid=parameters)

In [427]:
clf_random_merged_amusement.fit(X_train_merged_amusement, y_train_merged_amusement)

In [428]:
clf_random_merged_amusement.best_estimator_

#### No amusement

In [429]:
clf_random_no_amusement = GridSearchCV(estimator=random, param_grid=parameters)

In [430]:
clf_random_no_amusement.fit(X_train_no_amusement, y_train_no_amusement)

In [431]:
clf_random_no_amusement.best_estimator_

### Evaluation

### Merged amusement

In [432]:
y_pred_merged_amusement = clf_random_merged_amusement.predict(X_test_merged_amusement)

In [433]:
accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9254302103250478

In [434]:
balanced_accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9089042263122018

In [435]:
r2_score(y_test_merged_amusement, y_pred_merged_amusement)

0.6524028629856851

In [436]:
f1_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9462068965517241

In [437]:
precision_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9397260273972603

In [438]:
recall_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9527777777777777

In [439]:
confusion_matrix(y_test_merged_amusement, y_pred_merged_amusement, labels=[1, 2])

array([[343,  17],
       [ 22, 141]])

### No amusement

In [440]:
y_pred_no_amusement = clf_random_no_amusement.predict(X_test_no_amusement)

In [441]:
accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.9454545454545454

In [442]:
balanced_accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.9396043040610899

In [443]:
r2_score(y_test_no_amusement, y_pred_no_amusement)

0.7709128774730996

In [444]:
f1_score(y_test_no_amusement, y_pred_no_amusement)

0.9557195571955719

In [445]:
precision_score(y_test_no_amusement, y_pred_no_amusement)

0.9452554744525548

In [446]:
recall_score(y_test_no_amusement, y_pred_no_amusement)

0.9664179104477612

In [447]:
confusion_matrix(y_test_no_amusement, y_pred_no_amusement, labels=[1, 2])

array([[259,   9],
       [ 15, 157]])

## K Neighbours

In [448]:
parameters = dict(
    n_neighbors=(2, 3, 5, 7, 9, 11),
    weights=("uniform", "distance"),
    algorithm=("ball_tree", "kd_tree", "brute")
)

In [449]:
neighbor = KNeighborsClassifier()

### Merged amusement

In [450]:
clf_neighbor_merged_amusement = GridSearchCV(estimator=neighbor, param_grid=parameters)

In [451]:
clf_neighbor_merged_amusement.fit(X_train_merged_amusement, y_train_merged_amusement)

In [452]:
clf_neighbor_merged_amusement.best_estimator_

#### No amusement

In [453]:
clf_neighbor_no_amusement = GridSearchCV(estimator=neighbor, param_grid=parameters)

In [454]:
clf_neighbor_no_amusement.fit(X_train_no_amusement, y_train_no_amusement)

In [455]:
clf_neighbor_no_amusement.best_estimator_

### Evaluation

### Merged amusement

In [456]:
y_pred_merged_amusement = clf_neighbor_merged_amusement.predict(X_test_merged_amusement)

In [457]:
accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9521988527724665

In [458]:
balanced_accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9451346284935243

In [459]:
r2_score(y_test_merged_amusement, y_pred_merged_amusement)

0.7771813224267212

In [460]:
f1_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9652294853963839

In [461]:
precision_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9665738161559888

In [462]:
recall_score(y_test_merged_amusement, y_pred_merged_amusement)

0.9638888888888889

In [463]:
confusion_matrix(y_test_merged_amusement, y_pred_merged_amusement, labels=[1, 2])

array([[347,  13],
       [ 12, 151]])

### No amusement

In [464]:
y_pred_no_amusement = clf_neighbor_no_amusement.predict(X_test_no_amusement)

In [465]:
accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.95

In [466]:
balanced_accuracy_score(y_test_no_amusement, y_pred_no_amusement)

0.9433356473446719

In [467]:
r2_score(y_test_no_amusement, y_pred_no_amusement)

0.790003471017008

In [468]:
f1_score(y_test_no_amusement, y_pred_no_amusement)

0.9595588235294118

In [469]:
precision_score(y_test_no_amusement, y_pred_no_amusement)

0.9456521739130435

In [470]:
recall_score(y_test_no_amusement, y_pred_no_amusement)

0.9738805970149254

In [473]:
confusion_matrix(y_test_no_amusement, y_pred_no_amusement, labels=[1, 2])

array([[261,   7],
       [ 15, 157]])