In [2]:
# import all libraries
import pandas as pd
import numpy as np
import glob
import os
import mne

from sklearn.svm import SVC
from sklearn.metrics import mean_squared_error
from sklearn.svm import LinearSVC
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
from xgboost import XGBClassifier

from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, classification_report

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation
from tensorflow.keras.optimizers import SGD

In [3]:
# all functions
# example 1_al_ciplv_theta_1.npy
def parse_filename(filename):
    s = filename.split("_")
    pId = s[0]
    label = s[1]
    method = s[2]
    freq = s[-2]
    epoch = s[-1].split(".")[0]
    
    return {"pId": pId, "label": label, "method": method, "freq": freq, "epoch": epoch}
    # return {"pId": pId, "label": label, "method": method, "freq": freq}

def read_file(filename):
    return np.load(filename)

# get all the files in the folder output/
def get_files(folder, filter = None):
    '''
    filter is a dictionary, has 2 keys: method and freq to filter the files
    '''
    files = glob.glob(os.path.join(folder, "*.npy"))
    ret_files = []
    if filter:
        for f in files:
            # get the filename
            filename = os.path.basename(f)
            f_info = parse_filename(filename)
            # remove epoch 2 if no overlap
            if f_info["epoch"] == "2":
                continue
            if (f_info["method"] == filter["method"]) and (f_info["freq"] == filter["freq"]):
                ret_files.append(f)
    else:
        ret_files = files

    return ret_files

def feature_extraction(data):
    """
    get lower part of diagonal matrix
    """
    data_lower = data[np.tril_indices(data.shape[0], k=-1)]
    return data_lower

In [4]:
# logging to airtable
from airtable.airtable import AirTableClient
from airtable.config import config

atc = AirTableClient(**config["airtable_sensor"])

In [5]:
# constants
# get the labels for electrodes
# fs_dir = mne.datasets.fetch_fsaverage(verbose=True)
# subjects_dir = os.path.dirname(fs_dir)
# labels = mne.read_labels_from_annot('fsaverage', parc='aparc',
#                                     subjects_dir=subjects_dir)
# labels.pop(-1)
# label_colors = [label.color for label in labels]
data_source = "no_split"
labels = {"al": 0, "fa":1}

methods = ['wpli2']
freqs = ['delta', 'theta', 'alpha', 'beta', 'gamma']

# grid search params for different models
models = {
    "svm": {
        "model": SVC(kernel='linear', C=1),
        "params": {
            "C": [0.1, 1, 10, 100, 1000],
            "kernel": ['linear', 'rbf']
        }
    }
}

In [6]:
def run_grid_search(model_name, model, params, X, y):
    """
    run grid search for a model
    """
    grid_search = GridSearchCV(model, params, cv=5, return_train_score=False)
    grid_search.fit(X, y)
    # print("Best params: ", grid_search.best_params_)
    # print("Best score: ", grid_search.best_score_)
    return grid_search.best_params_, grid_search.best_score_

def get_metrics(y_true, y_pred):
    """
    get the metrics for the model
    """
    full_classification_report = classification_report(y_true, y_pred, output_dict=True)
    specificity = round(full_classification_report["0"]["recall"], 3)
    sensitivity = round(full_classification_report["1"]["recall"], 3)
    accuracy = round(full_classification_report["accuracy"], 3)
    
    return specificity, sensitivity, accuracy

In [10]:
for method in methods:
    for freq in freqs:
        print("Processing: method: {}, freq: {}".format(method, freq))
        filter = {"method": method, "freq": freq}
        files = get_files("output_sensor/no_split/", filter)
        X = []
        Y = []
        for f in files:
            data = read_file(f)
            X.append(feature_extraction(data))
            Y.append(labels.get(parse_filename(os.path.basename(f))["label"]))

        X = np.array(X)
        Y = np.array(Y)
            
        # for ifs in [False, True]:
        is_feature_selection = True
        if is_feature_selection:
            lsvc = LinearRegression()
            model = SelectFromModel(lsvc)
            X_new = model.fit_transform(X, Y)
        else:
            X_new = X


        X_new = np.array(X_new)
        Y = np.array(Y)
        print(X_new.shape)

        X_train, X_test, Y_train, Y_test = train_test_split(X_new, Y, test_size=0.5, random_state=12)
        for m in models.keys():
            print("Processing: method: {}, freq: {}, model: {}, feature selection: {}".format(method, freq, m, is_feature_selection))
            best_params, best_score = run_grid_search(m, models[m]["model"], models[m]["params"], X_new, Y)
            # run best model
            model = models[m]["model"].set_params(**best_params)
            model.fit(X_train, Y_train)
            Y_pred = model.predict(X_test)

            specificity, sensitivity, accuracy = get_metrics(Y_test, Y_pred)

            # print classification report
            print(classification_report(Y_test, Y_pred))
            
            full_classification_report = classification_report(Y_test, Y_pred)

            res = {
                "data": data_source,
                "method": method,
                "frequency": freq,
                "model": m,
                "feature selection": str(is_feature_selection),
                "best_params": str(best_params),
                "accuracy": str(accuracy),
                "specificity": str(specificity),
                "sensitivity": str(sensitivity),
                "full accuracy report": full_classification_report
            }
            atc.add_row(res)


Processing: method: wpli2, freq: delta
(96, 211)
Processing: method: wpli2, freq: delta, model: svm, feature selection: True
              precision    recall  f1-score   support

           0       0.78      0.67      0.72        27
           1       0.64      0.76      0.70        21

    accuracy                           0.71        48
   macro avg       0.71      0.71      0.71        48
weighted avg       0.72      0.71      0.71        48

Processing: method: wpli2, freq: theta
(96, 208)
Processing: method: wpli2, freq: theta, model: svm, feature selection: True
              precision    recall  f1-score   support

           0       0.95      0.70      0.81        27
           1       0.71      0.95      0.82        21

    accuracy                           0.81        48
   macro avg       0.83      0.83      0.81        48
weighted avg       0.85      0.81      0.81        48

Processing: method: wpli2, freq: alpha
(96, 204)
Processing: method: wpli2, freq: alpha, model: 