# Decision Tree

Target encoding after relabelling: 0 = baseline, 1 = stressed.

In [None]:
import gc
from joblib import dump

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

from interpret import set_visualize_provider
from interpret.provider import InlineProvider
from interpret import show
from interpret.blackbox import LimeTabular, ShapKernel, PartialDependence, MorrisSensitivity

from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split, GridSearchCV
from explainerdashboard import ClassifierExplainer, ExplainerDashboard

from sklearn.metrics import accuracy_score, balanced_accuracy_score,  f1_score, precision_score, recall_score, confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Run a full garbage-collection pass before loading the data (housekeeping only).
gc.collect()

In [None]:
# Render `interpret` explanation visualisations inline in the notebook output.
set_visualize_provider(InlineProvider())

In [None]:
# Load the combined dataset; the path is relative to the notebook's working directory.
df = pd.read_csv("../../../data/combined_subjects.csv")

In [None]:
# Preview the first rows of the raw data.
df.head()

In [None]:
# Column dtypes, non-null counts and memory usage of the raw data.
df.info()

## Data Preparation

In [None]:
# Five model inputs (net_acc_* and EDA_tonic_* statistics) followed by the target
# column. The target must stay LAST: later cells use features[:-1] as the tree's
# feature names.
features = [
    "net_acc_std",
    "net_acc_max",
    "EDA_tonic_mean",
    "EDA_tonic_min",
    "EDA_tonic_max",
    "label",
]

In [None]:
# Restrict the data to the selected feature columns plus the target, then preview.
df_feat = df[features]
df_feat.head()

In [None]:
# Check dtypes / missing values of the reduced feature table.
df_feat.info()

### Merged Amusement

In [None]:
# Independent copy, so relabelling here never touches df_feat (reused below).
df_feat_merged_amusement = df_feat.copy()

# Map raw label 1 (baseline) to 0; raw label 0 (amusement) already sits in class 0,
# so amusement is merged into the baseline class.
df_feat_merged_amusement["label"] = df_feat_merged_amusement["label"].replace({1: 0})

In [None]:
# Map raw label 2 (stressed) to the positive class 1.
df_feat_merged_amusement["label"] = df_feat_merged_amusement["label"].replace({2: 1})

In [None]:
# Sanity check of the relabelling: expected to show only 0 and 1, assuming the raw
# labels are limited to 0/1/2 — verify against the dataset.
df_feat_merged_amusement["label"].unique()

In [None]:
# X is an alias of the labelled frame; pop() removes "label" from it in place and
# returns that column as the target y.
X_merged_amusement = df_feat_merged_amusement
y_merged_amusement = X_merged_amusement.pop('label')

In [None]:
# Hold out 25% of the rows for testing. Stratifying on the target keeps the
# baseline/stressed proportions identical in train and test, so the test metrics
# are not skewed by an unlucky split when the classes are imbalanced.
# NOTE(review): this is a row-level split; if rows from the same subject are
# correlated, consider a subject-wise split (e.g. GroupShuffleSplit) — TODO confirm.
X_train_merged_amusement, X_test_merged_amusement, y_train_merged_amusement, y_test_merged_amusement = train_test_split(
    X_merged_amusement,
    y_merged_amusement,
    test_size=0.25,
    random_state=42,
    stratify=y_merged_amusement,
)

### Dropped Amusement

In [None]:
# Drop amusement rows (raw label 0). The boolean mask alone yields a frame that
# pandas may treat as a view of df_feat; the later cells assign to its "label"
# column, which would raise SettingWithCopyWarning (chained assignment). .copy()
# makes it an independent frame so those assignments are unambiguous.
df_feat_no_amusement = df_feat[df_feat["label"] != 0].copy()

In [None]:
# Map raw label 1 (baseline) to class 0.
df_feat_no_amusement["label"] = df_feat_no_amusement["label"].replace({1: 0})

In [None]:
# Map raw label 2 (stressed) to the positive class 1.
df_feat_no_amusement["label"] = df_feat_no_amusement["label"].replace({2: 1})

In [None]:
# Sanity check of the relabelling: expected to show only 0 and 1, assuming the raw
# labels are limited to 0/1/2 — verify against the dataset.
df_feat_no_amusement["label"].unique()

In [None]:
# X is an alias of the labelled frame; pop() removes "label" from it in place and
# returns that column as the target y.
X_no_amusement = df_feat_no_amusement
y_no_amusement = X_no_amusement.pop('label')

In [None]:
# Hold out 25% of the rows for testing. Stratifying on the target keeps the
# baseline/stressed proportions identical in train and test, so the test metrics
# are not skewed by an unlucky split when the classes are imbalanced.
# NOTE(review): this is a row-level split; if rows from the same subject are
# correlated, consider a subject-wise split (e.g. GroupShuffleSplit) — TODO confirm.
X_train_no_amusement, X_test_no_amusement, y_train_no_amusement, y_test_no_amusement = train_test_split(
    X_no_amusement,
    y_no_amusement,
    test_size=0.25,
    random_state=42,
    stratify=y_no_amusement,
)

## Training

In [None]:
# Hyper-parameter grid searched by GridSearchCV for both dataset variants.
# "log_loss" is intentionally absent: for DecisionTreeClassifier it is an alias of
# "entropy" (same Shannon information-gain criterion), so listing both doubled the
# number of fits without being able to change the winner (ties resolve to the
# earlier "entropy" candidate). Omitting it also keeps the grid valid on
# scikit-learn < 1.1, which does not accept "log_loss".
parameters = dict(
    criterion=("gini", "entropy"),
    splitter=("best", "random"),
    max_depth=(3, 5, 7, 9, 11),
    random_state=(42,),  # fixed seed so "random" splitter results are reproducible
)

In [None]:
# Unfitted base estimator shared by both grid searches; GridSearchCV clones it for
# every candidate, so reusing the same instance is safe.
tree = DecisionTreeClassifier()

### Merged Amusement

In [None]:
# Exhaustive grid search over `parameters` using GridSearchCV's default CV/scoring.
clf_tree_merged_amusement = GridSearchCV(tree, parameters)

In [None]:
# Cross-validate every grid candidate on the training split; with GridSearchCV's
# default refit setting the best candidate is then refit on the whole training split.
clf_tree_merged_amusement.fit(X_train_merged_amusement, y_train_merged_amusement)

In [None]:
# Show the refit best tree and its selected hyper-parameters.
clf_tree_merged_amusement.best_estimator_

In [None]:
# Persist the best tree; the file name records the dataset variant and that the
# five input features from `features` were used.
dump(clf_tree_merged_amusement.best_estimator_, "clf_tree_merged_amusement_top_5_feat.joblib")

### Dropped Amusement

In [None]:
# Exhaustive grid search over `parameters` using GridSearchCV's default CV/scoring.
clf_tree_no_amusement = GridSearchCV(tree, parameters)

In [None]:
# Cross-validate every grid candidate on the training split; with GridSearchCV's
# default refit setting the best candidate is then refit on the whole training split.
clf_tree_no_amusement.fit(X_train_no_amusement, y_train_no_amusement)

In [None]:
# Show the refit best tree and its selected hyper-parameters.
clf_tree_no_amusement.best_estimator_

In [None]:
# Persist the best tree; the file name records the dataset variant and that the
# five input features from `features` were used.
dump(clf_tree_no_amusement.best_estimator_, "clf_tree_no_amusement_top_5_feat.joblib")

## Evaluation

### Merged Amusement

In [None]:
# Predict the held-out split with the refit best tree (the same model that
# GridSearchCV.predict delegates to).
y_pred_merged_amusement = clf_tree_merged_amusement.best_estimator_.predict(X_test_merged_amusement)

In [None]:
# Plain accuracy on the held-out test split.
accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

In [None]:
# Balanced accuracy (mean per-class recall); less sensitive to class imbalance.
balanced_accuracy_score(y_test_merged_amusement, y_pred_merged_amusement)

In [None]:
# F1 score of the positive class (default pos_label=1, i.e. stressed).
f1_score(y_test_merged_amusement, y_pred_merged_amusement)

In [None]:
# Precision of the positive class (default pos_label=1, i.e. stressed).
precision_score(y_test_merged_amusement, y_pred_merged_amusement)

In [None]:
# Recall (sensitivity) of the positive class (default pos_label=1, i.e. stressed).
recall_score(y_test_merged_amusement, y_pred_merged_amusement)

In [None]:
# Confusion matrix whose rows/columns follow the fitted tree's class order.
cm = confusion_matrix(
    y_test_merged_amusement,
    y_pred_merged_amusement,
    labels=clf_tree_merged_amusement.best_estimator_.classes_,
)
disp = ConfusionMatrixDisplay(
    cm,
    display_labels=clf_tree_merged_amusement.best_estimator_.classes_,
)
disp.plot()
plt.show()

### Dropped Amusement

In [None]:
# Predict the held-out split with the refit best tree (the same model that
# GridSearchCV.predict delegates to).
y_pred_no_amusement = clf_tree_no_amusement.best_estimator_.predict(X_test_no_amusement)

In [None]:
# Plain accuracy on the held-out test split.
accuracy_score(y_test_no_amusement, y_pred_no_amusement)

In [None]:
# Balanced accuracy (mean per-class recall); less sensitive to class imbalance.
balanced_accuracy_score(y_test_no_amusement, y_pred_no_amusement)

In [None]:
# F1 score of the positive class (default pos_label=1, i.e. stressed).
f1_score(y_test_no_amusement, y_pred_no_amusement)

In [None]:
# Precision of the positive class (default pos_label=1, i.e. stressed).
precision_score(y_test_no_amusement, y_pred_no_amusement)

In [None]:
# Recall (sensitivity) of the positive class (default pos_label=1, i.e. stressed).
recall_score(y_test_no_amusement, y_pred_no_amusement)

In [None]:
# Confusion matrix whose rows/columns follow the fitted tree's class order.
cm = confusion_matrix(
    y_test_no_amusement,
    y_pred_no_amusement,
    labels=clf_tree_no_amusement.best_estimator_.classes_,
)
disp = ConfusionMatrixDisplay(
    cm,
    display_labels=clf_tree_no_amusement.best_estimator_.classes_,
)
disp.plot()
plt.show()

## XAI

In [None]:
def plot_decision_tree(clf, feature_names, class_names=None, max_depth=3, figsize=(20, 8)):
    """Draw the top levels of a fitted decision tree on a new matplotlib figure.

    Parameters
    ----------
    clf : fitted sklearn DecisionTreeClassifier
        The tree to draw.
    feature_names : list of str
        Feature names in the column order the tree was trained on.
    class_names : sequence of str or bool, optional
        Display names for the classes, in ``clf.classes_`` order. Defaults to
        ``["baseline", "stress"]`` for labels 0 and 1.
    max_depth : int, default 3
        Number of tree levels drawn; deeper nodes are shown truncated.
    figsize : tuple of (float, float), default (20, 8)
        Size of the new figure in inches.
    """
    # None sentinel instead of a list default: a list default is created once and
    # shared across calls. (The previous default also misspelled "stress".)
    if class_names is None:
        class_names = ["baseline", "stress"]
    plt.figure(figsize=figsize)
    plot_tree(
        clf,
        feature_names=feature_names,
        class_names=class_names,
        max_depth=max_depth,
        fontsize=7,
        proportion=True,  # show class proportions instead of raw sample counts
        filled=True,
        rounded=True,
    )


### Merged Amusement

In [None]:
# Top levels of the best merged-amusement tree; features[:-1] drops the trailing "label".
plot_decision_tree(clf_tree_merged_amusement.best_estimator_, features[:-1])

In [None]:
# Hand explainerdashboard the fitted DecisionTreeClassifier rather than the
# GridSearchCV wrapper. It chooses its SHAP explainer and tree-specific views from
# the model type, and GridSearchCV is not a model type it recognises. This also
# matches the estimator that was dumped and plotted above.
explainer_merged_amusement = ClassifierExplainer(
    clf_tree_merged_amusement.best_estimator_,
    X_test_merged_amusement,
    y_test_merged_amusement,
    labels=["baseline", "stress"],  # display names for classes 0 and 1
)
# Serve the dashboard inline in the notebook on port 8765.
ExplainerDashboard(explainer_merged_amusement, mode="inline").run(8765)

### Dropped Amusement

In [None]:
# Top levels of the best no-amusement tree; features[:-1] drops the trailing "label".
plot_decision_tree(clf_tree_no_amusement.best_estimator_, features[:-1])

In [None]:
# Hand explainerdashboard the fitted DecisionTreeClassifier rather than the
# GridSearchCV wrapper. It chooses its SHAP explainer and tree-specific views from
# the model type, and GridSearchCV is not a model type it recognises. This also
# matches the estimator that was dumped and plotted above.
explainer_no_amusement = ClassifierExplainer(
    clf_tree_no_amusement.best_estimator_,
    X_test_no_amusement,
    y_test_no_amusement,
    labels=["baseline", "stress"],  # display names for classes 0 and 1
)
# Serve the dashboard inline in the notebook on port 8766 (8765 is used above).
ExplainerDashboard(explainer_no_amusement, mode="inline").run(8766)