In [3]:
import pandas as pd

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

from tsfresh import select_features
from tsfresh.utilities.dataframe_functions import impute

In [4]:
# load feature matrix
X_min = pd.read_hdf('0_0315_4_electrodes_min.h5')
X_eff = pd.read_hdf('0_0315_4_electrodes_eff.h5')

# load target dataframe
y = pd.read_hdf('0_0315_4_electrodes_y.h5')
y = y.drop_duplicates()
y = y.set_index('window_id')
y = y.T.squeeze()

In [5]:
impute(X_min)
impute(X_eff)
X_min_filt = select_features(X_min, y)
X_eff_filt = select_features(X_eff, y)

ValueError: Index of X and y must be identical if provided

In [None]:
X_eff_train, X_eff_test, y_train, y_test = train_test_split(X_eff_filt, y, test_size=.4)
X_min_train, X_min_test, y_train, y_test = train_test_split(X_min_filt, y, test_size=.4)

In [27]:
tree_eff = DecisionTreeClassifier()
tree_eff.fit(X_eff_train, y_train)
print(classification_report(y_test, tree_eff.predict(X_eff_test)))

              precision    recall  f1-score   support

           0       0.40      0.43      0.41        14
           1       0.18      0.22      0.20         9
           2       0.55      0.43      0.48        14

    accuracy                           0.38        37
   macro avg       0.38      0.36      0.36        37
weighted avg       0.40      0.38      0.39        37



In [29]:
tree_min = DecisionTreeClassifier()
tree_min.fit(X_min_train, y_train)
print(classification_report(y_test, tree_min.predict(X_min_test)))

              precision    recall  f1-score   support

           0       0.87      0.93      0.90        14
           1       0.57      0.89      0.70         9
           2       0.88      0.50      0.64        14

    accuracy                           0.76        37
   macro avg       0.77      0.77      0.74        37
weighted avg       0.80      0.76      0.75        37



In [6]:
X_eff.head()

variable,27__abs_energy,27__absolute_sum_of_changes,"27__agg_autocorrelation__f_agg_""mean""__maxlag_40","27__agg_autocorrelation__f_agg_""median""__maxlag_40","27__agg_autocorrelation__f_agg_""var""__maxlag_40","27__agg_linear_trend__f_agg_""max""__chunk_len_10__attr_""intercept""","27__agg_linear_trend__f_agg_""max""__chunk_len_10__attr_""rvalue""","27__agg_linear_trend__f_agg_""max""__chunk_len_10__attr_""slope""","27__agg_linear_trend__f_agg_""max""__chunk_len_10__attr_""stderr""","27__agg_linear_trend__f_agg_""max""__chunk_len_50__attr_""intercept""",...,57__symmetry_looking__r_0.9,57__symmetry_looking__r_0.9500000000000001,57__time_reversal_asymmetry_statistic__lag_1,57__time_reversal_asymmetry_statistic__lag_2,57__time_reversal_asymmetry_statistic__lag_3,57__value_count__value_-1,57__value_count__value_0,57__value_count__value_1,57__variance,57__variance_larger_than_standard_deviation
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,10287450.0,542.815232,0.998717,0.999183,1.833299e-06,7.51359,-0.201483,-0.047415,0.009426,9.796812,...,1.0,1.0,5183.474528,10392.257797,15626.176176,0.0,0.0,0.0,23533.261792,1.0
1,20545810.0,547.881431,0.999416,0.999689,4.959195e-07,-51.927007,0.377193,0.122607,0.01231,-49.58558,...,1.0,1.0,-4527.816554,-9039.693869,-13535.445137,0.0,0.0,0.0,20199.819801,1.0
2,20853340.0,678.509085,0.999419,0.999745,6.071601e-07,21.136099,-0.318151,-0.10616,0.012936,22.652361,...,1.0,1.0,-237.79933,-478.691372,-722.701491,0.0,0.0,0.0,13787.033912,1.0
3,11876370.0,490.528182,0.994833,0.995257,1.188132e-05,30.552197,-0.171437,-0.040132,0.009431,31.539385,...,1.0,1.0,182.838563,364.719348,545.64404,0.0,0.0,0.0,4249.001457,1.0
4,14783260.0,464.445814,0.997295,0.997514,3.244559e-06,-47.088775,0.780305,0.211839,0.006943,-45.465,...,1.0,1.0,11.874008,24.552356,38.052601,0.0,0.0,0.0,2889.73497,1.0


In [7]:
X_min.head()

variable,27__length,27__maximum,27__mean,27__median,27__minimum,27__standard_deviation,27__sum_values,27__variance,34__length,34__maximum,...,4__sum_values,4__variance,57__length,57__maximum,57__mean,57__median,57__minimum,57__standard_deviation,57__sum_values,57__variance
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,6000.0,64.282079,-7.094168,-16.079742,-98.425293,40.795184,-42565.009585,1664.247023,6000.0,53.979924,...,-80267.818929,25313.868698,6000.0,290.661661,-17.225708,-64.398988,-204.070514,153.405547,-103354.247551,23533.261792
1,6000.0,57.034841,-15.617231,-5.97545,-108.570478,56.395064,-93703.384438,3180.403191,6000.0,82.043767,...,-113697.182993,21397.711372,6000.0,274.290223,-17.849974,-57.85444,-181.483743,142.12607,-107099.845392,20199.819801
2,5999.0,84.817419,-11.171965,-0.181217,-116.684697,57.89061,-67020.617478,3351.322706,5999.0,84.238884,...,117103.672237,16018.828457,5999.0,216.386658,23.001735,20.278273,-150.945904,117.418201,137987.408109,13787.033912
3,6000.0,68.549734,18.164751,31.838341,-62.821892,40.613261,108988.504633,1649.436948,6000.0,75.38661,...,78736.316409,5668.518185,6000.0,136.237402,16.418633,25.737367,-93.334961,65.184365,98511.797509,4249.001457
4,6000.0,111.529645,16.008868,-0.210763,-62.825766,46.985019,96053.206487,2207.592015,6000.0,66.437765,...,-9176.557947,3514.979177,6000.0,92.255836,7.437976,13.568825,-98.530345,53.756255,44627.855964,2889.73497
