In [4]:
# import all libraries
import pandas as pd
import numpy as np
import glob
import os
import mne

from sklearn.svm import SVC
from sklearn.metrics import mean_squared_error
from sklearn.svm import LinearSVC
from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation
from tensorflow.keras.optimizers import SGD

In [5]:
# all functions
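# example: 1_al_ciplv_theta_<suffix>.npy  (parse_filename reads freq from the second-to-last "_" field, so a trailing field after the frequency is assumed - TODO confirm the real naming)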
def parse_filename(filename):
    """
    Break an underscore-delimited file name into its metadata fields.

    The name is split on "_". Fields 0, 1 and 2 are returned as "pId",
    "label" and "method"; the second-to-last field is returned as "freq".
    Nothing is stripped from the fields, so a five-field name such as
    "1_al_ciplv_theta_0.npy" yields freq "theta", whereas a four-field name
    such as "1_al_ciplv_theta.npy" yields freq "ciplv" (the same field as
    method).

    Raises IndexError when the name has fewer than three "_"-separated fields.
    """
    parts = filename.split("_")
    # (result key, position in the split name), looked up in this order
    field_positions = (("pId", 0), ("label", 1), ("method", 2), ("freq", -2))
    return {key: parts[pos] for key, pos in field_positions}

def read_file(filename):
    """Load the file at `filename` with ``np.load`` and return the result."""
    loaded = np.load(filename)
    return loaded

# get all the files in the folder output/
def get_files(folder, filter = None):
    '''
    Return the paths of the ``*.npy`` files directly inside `folder`, sorted.

    filter is a dictionary, has 2 keys: method and freq to filter the files.
    When it is given (and non-empty), a file is kept only if both the
    "method" and the "freq" fields parsed from its base name by
    parse_filename equal the filter's values. Without a filter every
    ``*.npy`` path is returned.

    The paths are sorted because glob.glob gives no ordering guarantee;
    callers that split the resulting samples with a fixed random seed would
    otherwise get different splits on different machines or filesystems.
    '''
    # Deterministic order regardless of how the OS lists the directory.
    files = sorted(glob.glob(os.path.join(folder, "*.npy")))
    if not filter:
        return files

    ret_files = []
    for f in files:
        # Only the base name carries the metadata fields.
        f_info = parse_filename(os.path.basename(f))
        if (f_info["method"] == filter["method"]) and (f_info["freq"] == filter["freq"]):
            ret_files.append(f)
    return ret_files

def feature_extraction(data):
    """
    Collect the entries of `data` that lie strictly below the main diagonal.

    The index pairs come from ``np.tril_indices(data.shape[0], k=-1)``, so the
    diagonal itself is excluded and the values are returned in that index
    order. For a 2-D square matrix with n rows this is a 1-D array of
    n * (n - 1) / 2 values.
    """
    n_rows = data.shape[0]
    row_idx, col_idx = np.tril_indices(n_rows, k=-1)
    return data[row_idx, col_idx]

In [7]:
# logging to airtable
# NOTE(review): these imports sit outside the notebook's main import cell, and
# `airtable` is presumably a project-local package (its code is not visible
# here) - confirm.
from airtable.airtable import AirTableClient
from airtable.config import config

# Build the client from the keyword settings stored under config["airtable"].
# The training cell calls `atc.add_row(...)` once per trained model to log a result row.
atc = AirTableClient(**config["airtable"])

In [15]:
# constants
# Disabled snippet, kept for reference: get the labels for electrodes from the
# fsaverage 'aparc' parcellation via mne. If re-enabled as written it would bind
# `labels` before the class-id mapping below rebinds it.
# fs_dir = mne.datasets.fetch_fsaverage(verbose=True)
# subjects_dir = os.path.dirname(fs_dir)
# labels = mne.read_labels_from_annot('fsaverage', parc='aparc',
#                                     subjects_dir=subjects_dir)
# labels.pop(-1)
# label_colors = [label.color for label in labels]

# Dataset tag; the training cell puts the global `data` into the "data" field
# of each logged result row.
data = "split_10s"

# Class id for each label field parsed from a file name.
labels = dict([("al", 0), ("fa", 1)])

# Values matched against the method / freq fields of the file names.
methods = [
    'pli',
    'wpli2',
    'ciplv',
]
freqs = [
    'delta',
    'theta',
    'alpha',
    'beta',
    'gamma',
]


In [16]:
# Hold-out settings shared by every (method, freq) combination.
_TEST_SIZE = 0.3
_RANDOM_STATE = 12


def _load_dataset(method, freq):
    """Build (X, Y) arrays from the files in output/all/ matching method and freq.

    X holds one feature vector per file (see feature_extraction); Y holds the
    class id looked up in `labels` from the file name's label field. An
    unknown label raises KeyError here instead of putting None into Y.
    """
    # Named `file_filter` so the builtin `filter` is not shadowed.
    file_filter = {"method": method, "freq": freq}
    features, targets = [], []
    for path in get_files("output/all/", file_filter):
        # Use a dedicated local name: rebinding the global `data` here would
        # replace the dataset tag that is logged in the "data" result field.
        matrix = read_file(path)
        features.append(feature_extraction(matrix))
        label_name = parse_filename(os.path.basename(path))["label"]
        targets.append(labels[label_name])
    return np.array(features), np.array(targets)


def _score_predictions(Y_test, Y_pred):
    """Threshold predictions at 0.5 and return (accuracy rounded to 2 dp, report text)."""
    # ravel() turns the (n, 1) network output into the same 1-D shape as Y_test.
    Y_pred_binary = np.where(np.ravel(Y_pred) > 0.5, 1, 0)
    # Builtin round() works whether accuracy_score returns a Python float or a
    # NumPy scalar; a plain float has no .round() method.
    acc = round(float(accuracy_score(Y_test, Y_pred_binary)), 2)
    return acc, classification_report(Y_test, Y_pred_binary)


def _evaluate_svc(X_train, X_test, Y_train, Y_test):
    """Fit a linear-kernel SVC on the training split and score it on the test split."""
    model = SVC(kernel='linear')
    model.fit(X_train, Y_train)
    return _score_predictions(Y_test, model.predict(X_test))


def _evaluate_fcn(X_train, X_test, Y_train, Y_test):
    """Fit a 256-128-1 fully connected network and score it on the test split.

    NOTE(review): no TensorFlow seed is set, so these results can vary
    between runs.
    """
    model = Sequential()
    model.add(Dense(units=256, input_dim=X_train.shape[1]))
    model.add(Activation('relu'))
    model.add(Dropout(0.3))
    model.add(Dense(units=128))
    model.add(Activation('relu'))
    model.add(Dropout(0.3))
    model.add(Dense(units=1, activation='sigmoid'))

    model.compile(loss='binary_crossentropy', optimizer=SGD(learning_rate=0.01), metrics=['accuracy'])
    # validation_data is evaluated each epoch for monitoring only; no callback uses it.
    model.fit(X_train, Y_train, epochs=100, batch_size=8, verbose=0, validation_data=(X_test, Y_test))
    return _score_predictions(Y_test, model.predict(X_test))


# (model name logged to airtable, evaluation function), run in this order.
_EVALUATORS = (("SVC", _evaluate_svc), ("FCN", _evaluate_fcn))

for method in methods:
    for freq in freqs:
        print("Processing: method: {}, freq: {}".format(method, freq))
        X, Y = _load_dataset(method, freq)
        if len(X) == 0:
            # Nothing to train on; skip instead of failing inside sklearn.
            print("No files found for method: {}, freq: {} - skipping".format(method, freq))
            continue

        # Split once, before any fitting. The split depends only on the number
        # of samples and the seed, so both feature-selection settings are
        # evaluated on the same train/test samples.
        X_train_all, X_test_all, Y_train, Y_test = train_test_split(
            X, Y, test_size=_TEST_SIZE, random_state=_RANDOM_STATE)

        for is_feature_selection in [False, True]:
            if is_feature_selection:
                # Fit the selector on the training split only, so the test
                # samples do not influence which features are kept.
                lsvc = LinearSVC(C=0.001, penalty="l2", dual=False).fit(X_train_all, Y_train)
                selector = SelectFromModel(lsvc, prefit=True)
                X_train = selector.transform(X_train_all)
                X_test = selector.transform(X_test_all)
            else:
                X_train, X_test = X_train_all, X_test_all

            for m, evaluate in _EVALUATORS:
                acc, full_classification_report = evaluate(X_train, X_test, Y_train, Y_test)

                res = {
                    "data": data,
                    "method": method,
                    "frequency": freq,
                    "model": m,
                    "feature selection": str(is_feature_selection),
                    "accuracy": str(acc),
                    "full accuracy report": full_classification_report
                }
                atc.add_row(res)


Processing: method: pli, freq: delta
MSE: 0.5862
MSE: 0.1724
Processing: method: pli, freq: theta
MSE: 0.4828
MSE: 0.1034
Processing: method: pli, freq: alpha
MSE: 0.4483
MSE: 0.2759
Processing: method: pli, freq: beta
MSE: 0.3448
MSE: 0.1034
Processing: method: pli, freq: gamma
MSE: 0.4483


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


MSE: 0.0690
Processing: method: wpli2, freq: delta
MSE: 0.5862
MSE: 0.2069
Processing: method: wpli2, freq: theta
MSE: 0.5172
MSE: 0.2414
Processing: method: wpli2, freq: alpha
MSE: 0.6207
MSE: 0.2759
Processing: method: wpli2, freq: beta
MSE: 0.4483
MSE: 0.1034
Processing: method: wpli2, freq: gamma
MSE: 0.4483
MSE: 0.1034
Processing: method: ciplv, freq: delta
MSE: 0.5172
MSE: 0.1724
Processing: method: ciplv, freq: theta
MSE: 0.5172
MSE: 0.1034
Processing: method: ciplv, freq: alpha
MSE: 0.4828
MSE: 0.1379
Processing: method: ciplv, freq: beta
MSE: 0.3103
MSE: 0.0690
Processing: method: ciplv, freq: gamma
MSE: 0.4138
MSE: 0.1034
