In [1]:
from src.data.load_data import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from tsfresh import extract_features
from tsfresh.feature_extraction import MinimalFCParameters

In [4]:
ach_at_02 = load_MEA_data(folder = "data/raw/Ach-AT-subject-02")
# take middle horizontal strip of electrodes
label_MEA_data(ach_at_02, output='ach_at_subject_02')

In [2]:
df = pd.read_hdf("data/processed/ach_at_subject_02.h5")

In [3]:
y = (df[['window_id','y']]
     .drop_duplicates()
     .set_index('window_id')
     .T
     .squeeze()
     .sort_index(0))

cols = ['t', 'window_id', 15, 14, 13,  4, 57, 48, 47, 46]
df_middle = df[cols]
df = df.drop(columns = ['y'])

In [6]:
X = extract_features(df, column_id='window_id', column_sort='t', default_fc_parameters=MinimalFCParameters())
X.to_hdf('achat_02_min.h5', key = 'features', complevel = 9)

Feature Extraction: 100%|██████████| 10/10 [00:04<00:00,  2.01it/s]


In [7]:
X_middle = extract_features(df_middle, column_id='window_id', column_sort='t', default_fc_parameters=MinimalFCParameters())
X_middle.to_hdf('achat_02_min_middle.h5', key = 'features', complevel = 9)

Feature Extraction: 100%|██████████| 10/10 [00:00<00:00, 19.79it/s]


In [30]:
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.ensemble import RandomForestClassifier

from tsfresh import select_features
from tsfresh.utilities.dataframe_functions import impute

In [6]:
X_middle = pd.read_hdf('achat_02_min_middle.h5')

In [28]:
# first try with just the middle strip
# impute(X_middle)
# X_middle_filtered = select_features(X_middle, y)
X_middle_train, X_middle_test, y_train, y_test = train_test_split(X_middle_filtered, y, test_size=.4)
tree_middle = tree.DecisionTreeClassifier()
tree_middle.fit(X_middle_train, y_train)
print(classification_report(y_test, tree_middle.predict(X_middle_test)))
print(confusion_matrix(y_test, tree_middle.predict(X_middle_test)))


              precision    recall  f1-score   support

           0       1.00      0.82      0.90        11
           1       0.78      0.82      0.80        17
           2       0.73      0.80      0.76        10

    accuracy                           0.82        38
   macro avg       0.84      0.81      0.82        38
weighted avg       0.83      0.82      0.82        38

[[ 9  2  0]
 [ 0 14  3]
 [ 0  2  8]]


In [54]:
clf = RandomForestClassifier(n_estimators = 200, max_depth=10, random_state=0)
clf.fit(X_middle_train, y_train)
print(classification_report(y_test, clf.predict(X_middle_test)))
print(confusion_matrix(y_test, clf.predict(X_middle_test)))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        11
           1       1.00      0.82      0.90        17
           2       0.77      1.00      0.87        10

    accuracy                           0.92        38
   macro avg       0.92      0.94      0.92        38
weighted avg       0.94      0.92      0.92        38

[[11  0  0]
 [ 0 14  3]
 [ 0  0 10]]


In [56]:
X = pd.read_hdf('achat_02_min.h5')

In [80]:
# next try with the entire electrode array
# impute(X)
# X_filtered = select_features(X, y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.4)
tree_full = tree.DecisionTreeClassifier()
tree_full.fit(X_train, y_train)
print(classification_report(y_test, tree_full.predict(X_test)))

              precision    recall  f1-score   support

           0       1.00      0.91      0.95        11
           1       0.75      0.90      0.82        10
           2       0.94      0.88      0.91        17

    accuracy                           0.89        38
   macro avg       0.90      0.90      0.89        38
weighted avg       0.91      0.89      0.90        38



In [82]:
clf = RandomForestClassifier(n_estimators = 200, max_depth=10, random_state=0)
clf.fit(X_train, y_train)
print(classification_report(y_test, clf.predict(X_test)))
print(confusion_matrix(y_test, clf.predict(X_test)))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        11
           1       0.90      0.90      0.90        10
           2       0.94      0.94      0.94        17

    accuracy                           0.95        38
   macro avg       0.95      0.95      0.95        38
weighted avg       0.95      0.95      0.95        38

[[11  0  0]
 [ 0  9  1]
 [ 0  1 16]]
