# Classifying Drug Administration on Subject 2 (Ach-AT)

## Loading and Processing Data
Using the functions written in `src/data/load_data.py`, we convert the raw signals into a format containing discrete time windows.  
This format is suitable for tsfresh to then be used to extract relevant features.

In [1]:
from src.data.load_data import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


from tsfresh import extract_features
from tsfresh.feature_extraction import MinimalFCParameters

In [4]:
ach_at_02 = load_MEA_data(folder = "data/raw/Ach-AT-subject-02")
# take middle horizontal strip of electrodes
label_MEA_data(ach_at_02, output='ach_at_subject_02')

In [2]:
df = pd.read_hdf("data/processed/ach_at_subject_02.h5")

In [3]:
y = (df[['window_id','y']]
     .drop_duplicates()
     .set_index('window_id')
     .T
     .squeeze()
     .sort_index(0))

cols = ['t', 'window_id', 15, 14, 13,  4, 57, 48, 47, 46]
df_middle = df[cols]
df = df.drop(columns = ['y'])

In [6]:
X = extract_features(df, column_id='window_id', column_sort='t', default_fc_parameters=MinimalFCParameters())
X.to_hdf('achat_02_min.h5', key = 'features', complevel = 9)

Feature Extraction: 100%|██████████| 10/10 [00:04<00:00,  2.01it/s]


In [7]:
X_middle = extract_features(df_middle, column_id='window_id', column_sort='t', default_fc_parameters=MinimalFCParameters())
X_middle.to_hdf('achat_02_min_middle.h5', key = 'features', complevel = 9)

Feature Extraction: 100%|██████████| 10/10 [00:00<00:00, 19.79it/s]


## Model Fitting

Compare the performance of decision trees (individual, interpretable) and random forests (ensemble).
Fit both types of model on each set of features and identify which combination results in the highest classification accuracy.

In [14]:
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

from tsfresh import select_features
from tsfresh.utilities.dataframe_functions import impute

### Feature set 1: Middle strip of electrodes from MEA

In [56]:
X_middle = pd.read_hdf('trial_data/achat_02_min_middle.h5')
# first try with just the middle strip
impute(X_middle)
X_middle_filtered = select_features(X_middle, y)

In [62]:
## Evaluation Method 1: split training and test data
X_middle_train, X_middle_test, y_train, y_test = train_test_split(X_middle_filtered, y, test_size=.4)

# TREE
tree_middle = tree.DecisionTreeClassifier()
tree_middle.fit(X_middle_train, y_train)
print(classification_report(y_test, tree_middle.predict(X_middle_test)))
print(confusion_matrix(y_test, tree_middle.predict(X_middle_test)))

# RANDOM FOREST
rf = RandomForestClassifier(n_estimators = 200, max_depth=3, random_state=0)
rf.fit(X_middle_train, y_train)
print(classification_report(y_test, rf.predict(X_middle_test)))
print(confusion_matrix(y_test, rf.predict(X_middle_test)))

## Evaluation Method 2: k-fold cross validation
# TREE
print('\n TREE \n Mean 3-fold cross-validation score = '+str(np.mean(cross_val_score(tree_middle, X_middle, y, cv=3))))

# RANDOM FOREST
print('\n RANDOM FOREST \n Mean 3-fold cross-validation score = '+str(np.mean(cross_val_score(rf, X_middle, y, cv=3))))

precision    recall  f1-score   support

           0       0.91      1.00      0.95        10
           1       0.93      0.87      0.90        15
           2       0.85      0.85      0.85        13

    accuracy                           0.89        38
   macro avg       0.89      0.90      0.90        38
weighted avg       0.90      0.89      0.89        38

[[10  0  0]
 [ 0 13  2]
 [ 1  1 11]]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       0.92      0.80      0.86        15
           2       0.80      0.92      0.86        13

    accuracy                           0.89        38
   macro avg       0.91      0.91      0.90        38
weighted avg       0.90      0.89      0.89        38

[[10  0  0]
 [ 0 12  3]
 [ 0  1 12]]

 TREE 
 Mean 3-fold cross-validation score = 0.8172043010752689

 RANDOM FOREST 
 Mean 3-fold cross-validation score = 0.9032258064516129


### Feature set 2: All electrodes from MEA

In [64]:
X = pd.read_hdf('trial_data/achat_02_min.h5')
# next try with the entire electrode array
impute(X)
X_filtered = select_features(X, y)

In [68]:
## Evaluation Method 1: split training and test data
X_train, X_test, y_train, y_test = train_test_split(X_filtered, y, test_size=.4)

# TREE
tree_full = tree.DecisionTreeClassifier()
tree_full.fit(X_train, y_train)
print('\n TREE \n')
print(classification_report(y_test, tree_full.predict(X_test)))
print(confusion_matrix(y_test, tree_full.predict(X_test)))

# RANDOM FOREST
print('\n RANDOM FOREST \n')
rf = RandomForestClassifier(n_estimators = 200, max_depth=3, random_state=0)
rf.fit(X_train, y_train)
print(classification_report(y_test, rf.predict(X_test)))
print(confusion_matrix(y_test, rf.predict(X_test)))

## Evaluation Method 2: k-fold cross validation
# TREE
print('\n TREE \n Mean 3-fold cross-validation score = '+str(np.mean(cross_val_score(tree_full, X_filtered, y, cv=3))))

# RANDOM FOREST
print('\n RANDOM FOREST \n Mean 3-fold cross-validation score = '+str(np.mean(cross_val_score(rf, X_filtered, y, cv=3))))


 TREE 

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        12
           1       0.75      0.86      0.80        14
           2       0.80      0.67      0.73        12

    accuracy                           0.84        38
   macro avg       0.85      0.84      0.84        38
weighted avg       0.84      0.84      0.84        38

[[12  0  0]
 [ 0 12  2]
 [ 0  4  8]]

 RANDOM FOREST 

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        12
           1       0.93      0.93      0.93        14
           2       0.92      0.92      0.92        12

    accuracy                           0.95        38
   macro avg       0.95      0.95      0.95        38
weighted avg       0.95      0.95      0.95        38

[[12  0  0]
 [ 0 13  1]
 [ 0  1 11]]

 TREE 
 Mean 3-fold cross-validation score = 0.8494623655913979

 RANDOM FOREST 
 Mean 3-fold cross-validation score = 0.935483870967742

We observe that using the set of features generated from all of the electrodes provides slightly higher classification accuracy than only using the middle strip of electrodes. This confirms to me that we should be using all of the data when fitting a model to the whole dataset.  
Unsurprisingly, random forests perform better than individual trees. 3-fold cross-validation provides a less biased estimate of out-of-sample model accuracy. Additionally, we observe from the confusion matrices that each of the target variables share similar classification accuracies hence overall accuracy is not heavily biased. 