In [13]:
import pandas as pd

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

from tsfresh import select_features
from tsfresh.utilities.dataframe_functions import impute

In [24]:
# load feature matrix
X_min = pd.read_hdf('0_0315_4_electrodes_min.h5')
X_eff = pd.read_hdf('0_0315_4_electrodes_eff.h5')

# load target dataframe
y = pd.read_hdf('0_0315_4_electrodes_y.h5')
y = y.drop_duplicates()
y = y.set_index('window_id')
y = y.T.squeeze()
y = y.sort_index(0)

In [25]:
impute(X_min)
impute(X_eff)
X_min_filt = select_features(X_min, y)
X_eff_filt = select_features(X_eff, y)

In [37]:
X_eff_train, X_eff_test, y_train, y_test = train_test_split(X_eff_filt, y, test_size=.4)
X_min_train, X_min_test, y_train, y_test = train_test_split(X_min_filt, y, test_size=.4)

In [38]:
tree_eff = DecisionTreeClassifier()
tree_eff.fit(X_eff_train, y_train)
print(classification_report(y_test, tree_eff.predict(X_eff_test)))

              precision    recall  f1-score   support

           0       0.25      0.15      0.19        13
           1       0.39      0.54      0.45        13
           2       0.36      0.36      0.36        11

    accuracy                           0.35        37
   macro avg       0.33      0.35      0.34        37
weighted avg       0.33      0.35      0.33        37



In [49]:
tree_min = DecisionTreeClassifier()
tree_min.fit(X_min_train, y_train)
print(classification_report(y_test, tree_min.predict(X_min_test)))

              precision    recall  f1-score   support

           0       1.00      0.92      0.96        13
           1       0.69      0.85      0.76        13
           2       0.78      0.64      0.70        11

    accuracy                           0.81        37
   macro avg       0.82      0.80      0.81        37
weighted avg       0.82      0.81      0.81        37



In [29]:
X_min_filt.head()

variable,34__maximum,4__variance,4__standard_deviation,34__variance,34__standard_deviation,57__standard_deviation,57__variance,4__maximum,4__minimum,34__minimum,57__maximum,27__variance,27__standard_deviation,57__minimum,27__minimum,27__maximum
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
0,53.979924,25313.868698,159.103327,1493.920412,38.651267,153.405547,23533.261792,295.423675,-205.294628,-90.785602,290.661661,1664.247023,40.795184,-204.070514,-98.425293,64.282079
1,82.043767,21397.711372,146.279566,2654.881296,51.52554,142.12607,20199.819801,277.318398,-197.619225,-111.661413,274.290223,3180.403191,56.395064,-181.483743,-108.570478,57.034841
2,84.238884,16018.828457,126.565511,3315.906394,57.583907,117.418201,13787.033912,224.252631,-167.800456,-105.058553,216.386658,3351.322706,57.89061,-150.945904,-116.684697,84.817419
3,75.38661,5668.518185,75.289562,1703.184739,41.269659,65.184365,4249.001457,142.916404,-112.513313,-67.2385,136.237402,1649.436948,40.613261,-93.334961,-62.821892,68.549734
4,66.437765,3514.979177,59.28726,1035.644492,32.181431,53.756255,2889.73497,104.061364,-94.316387,-51.435484,92.255836,2207.592015,46.985019,-98.530345,-62.825766,111.529645


In [55]:
df_min = pd.concat([X_min_filt, y], axis = 1)
# FEATURE IMPORTANCE
# PLOTS NEEDED
# distribution of each feature for different y


ModuleNotFoundError: No module named 'ggplot'