# Load packages and prepare data

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sktime.transformations.panel.padder import PaddingTransformer
from sktime.transformations.panel.catch22 import Catch22

In [3]:
# Load the cleaned, per-subject combined dataset.
# Forward slashes are used because they resolve on Windows, Linux and macOS,
# whereas backslash separators only work on Windows.
# NOTE(review): unpickling executes arbitrary code - only load this file if it
# was produced by a trusted step of this project.
df_combined_subject = pd.read_pickle("../cleanedData/df_combined_subject.pkl")

In [4]:
# Perform the train/test split by subject: all rows of one subject end up on
# the same side of a split, so no subject appears in both train and test.
# Five folds are produced for cross-validation; fold i is stored at index i of
# each of the four lists below.
from sklearn.model_selection import GroupKFold

X_train, X_test, y_train, y_test = [], [], [], []
gss = GroupKFold(n_splits=5)
resp_features = df_combined_subject[["normalised_resp"]]
difficulty_labels = df_combined_subject["difficulty"].astype("string")
for train, test in gss.split(
    resp_features, difficulty_labels, groups=df_combined_subject["subject"]
):
    # split() yields positional row numbers, so select with .iloc. Label-based
    # .loc would pick the wrong rows (or raise) whenever the frame's index is
    # not the default 0..n-1 range, e.g. after filtering or concatenation.
    X_train.append(resp_features.iloc[train])
    X_test.append(resp_features.iloc[test])
    y_train.append(difficulty_labels.iloc[train])
    y_test.append(difficulty_labels.iloc[test])

## Prepare Catch22 dataset

In [8]:
%%time
catch22 = PaddingTransformer() * Catch22()
X_train0_catch22 = catch22.fit_transform(X_train[0])
X_test0_catch22 = catch22.transform(X_test[0])

CPU times: total: 39min 3s
Wall time: 42min 18s


In [30]:
# Attach the fold-0 training labels as a "difficulty" column so the frame can
# be handed to PyCaret with a named target. `.values` drops the Series index,
# so the labels are matched to the feature rows by position.
fold0_train_labels = y_train[0].values
X_train0_catch22["difficulty"] = fold0_train_labels

In [34]:
# Persist the fold-0 Catch22 feature tables so the long feature-extraction
# step does not have to be repeated.
# Forward slashes are used because they resolve on Windows, Linux and macOS;
# with backslash separators a POSIX system would write a file literally named
# "..\cleanedData\..." into the working directory.
# NOTE(review): if the label-attachment cell has already run, the training CSV
# includes the "difficulty" column while the test CSV carries no labels -
# confirm this asymmetry is intended.
X_train0_catch22.to_csv("../cleanedData/X_train0_catch22.csv")
X_test0_catch22.to_csv("../cleanedData/X_test0_catch22.csv")

# Modeling

In [32]:
# The wildcard import is kept because later cells call PyCaret functions
# (e.g. compare_models) by their bare names.
from pycaret.classification import *

# Initialise the PyCaret classification experiment on the fold-0 Catch22
# features, with "difficulty" as the target and a fixed session id (42).
# NOTE(review): the recorded output shows setup() carving out its own 272/117
# train/test partition; presumably that split is not subject-aware - confirm,
# and consider the fold_strategy / fold_groups options if subject leakage
# matters here.
pycaret_class = setup(
    data=X_train0_catch22,
    target="difficulty",
    session_id=42,
)

Unnamed: 0,Description,Value
0,Session id,42
1,Target,difficulty
2,Target type,Multiclass
3,Target mapping,"000: 0, 01B: 1, 02B: 2, 03B: 3, 04B: 4"
4,Original data shape,"(389, 23)"
5,Transformed data shape,"(389, 23)"
6,Transformed train set shape,"(272, 23)"
7,Transformed test set shape,"(117, 23)"
8,Numeric features,22
9,Rows with missing values,1.3%


In [33]:
# Evaluate PyCaret's candidate classifiers on the experiment configured by
# setup() and keep the value returned by compare_models() (per the PyCaret
# docs, the top-ranked model). In the recorded run the Random Forest
# classifier heads the table with an accuracy of 0.3832.
best = compare_models()

Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC,TT (Sec)
rf,Random Forest Classifier,0.3832,0.6457,0.3832,0.3808,0.3667,0.2252,0.2299,0.065
lightgbm,Light Gradient Boosting Machine,0.3562,0.6373,0.3562,0.3563,0.344,0.1897,0.1932,0.193
dt,Decision Tree Classifier,0.3537,0.5915,0.3537,0.3434,0.3368,0.1876,0.192,0.012
et,Extra Trees Classifier,0.3496,0.6275,0.3496,0.346,0.3401,0.1832,0.1865,0.036
gbc,Gradient Boosting Classifier,0.3421,0.6436,0.3421,0.3469,0.328,0.1712,0.1758,0.103
knn,K Neighbors Classifier,0.2976,0.5918,0.2976,0.2943,0.2813,0.1174,0.1218,0.292
lda,Linear Discriminant Analysis,0.261,0.5656,0.261,0.2798,0.2528,0.0692,0.0713,0.012
ada,Ada Boost Classifier,0.2537,0.5863,0.2537,0.253,0.2412,0.0564,0.0591,0.021
qda,Quadratic Discriminant Analysis,0.2501,0.5316,0.2501,0.2797,0.2326,0.0502,0.0526,0.012
svm,SVM - Linear Kernel,0.228,0.0,0.228,0.163,0.1401,0.0197,0.0352,0.011


Processing:   0%|          | 0/61 [00:00<?, ?it/s]