In [1]:
import numpy as np
import pandas as pd

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Lasso, LogisticRegression
from sklearn.feature_selection import SelectFromModel, SelectKBest, chi2

from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import MinMaxScaler

import mne
from mne.io import read_raw_gdf
from mne.decoding import CSP, SPoC, TimeFrequency

from scipy.signal import welch, butter, lfilter, periodogram

import time
import pywt

from spectrum import pyule

In [2]:
def hjorth(a):
    first_deriv = np.diff(a)
    second_deriv = np.diff(a,2)

    var_zero = np.mean(a ** 2)
    var_d1 = np.mean(first_deriv ** 2)
    var_d2 = np.mean(second_deriv ** 2)

    activity = var_zero
    morbidity = np.sqrt(var_d1 / var_zero)
    complexity = np.sqrt(var_d2 / var_d1) / morbidity

    return activity, morbidity, complexity

In [3]:
def hurst(time_series, max_lag=20):
    lags = range(2, max_lag)

    tau = [np.std(np.subtract(time_series[lag:], time_series[:-lag])) for lag in lags]

    reg = np.polyfit(np.log(lags), np.log(tau), 1)

    return reg[0]

In [4]:
def tdp(x, i):
    p=[]
    p.append(np.log(np.mean(np.abs(x))))
    for dev in range(i):  
        x=np.diff(x)
        p.append(np.log(np.mean(np.abs(x))))
    return p

In [5]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y

In [6]:
def yule_walker(data):
    p=pyule(data, 4, sampling=250, NFFT=5)
    return p.psd

In [10]:
def preprocessing(sub):

    # load data
    raw=read_raw_gdf('data/simple_MI/A0{}T.gdf'.format(sub), preload=True, verbose=False)

    # select eeg channels
    ex_ch=raw.info['ch_names'][22:]
    raw.drop_channels(ex_ch)

    # band-pass filter
    raw.filter(8., 30., fir_design='firwin')

    # get epoch
    events, _=mne.events_from_annotations(raw, event_id={'769':1, '770':2, '771':3, '772':4})
    epochs = mne.Epochs(raw, events, event_id={'769':1, '770':2, '771':3, '772':4}, tmin=1, tmax=4, baseline=None, preload=True, verbose=False)
    label=events[:, 2]
    epochs=epochs.get_data()
    
    return epochs, label

In [20]:
def preprocessing(sub):

    # load data
    raw=read_raw_gdf('data/simple_MI/B0{}T.gdf'.format(sub), preload=True, verbose=False)

    # band-pass filter
    raw.filter(8., 30., fir_design='firwin')

    # get epoch
    events, _=mne.events_from_annotations(raw, event_id={'769':1, '770':2, '771':3, '772':4})
    epochs = mne.Epochs(raw, events, event_id={'769':1, '770':2, '771':3, '772':4}, tmin=1, tmax=4, baseline=None, preload=True, verbose=False)
    label=events[:, 2]
    epochs=epochs.get_data()
    
    return epochs, label

In [8]:
def feature_extraction(feature, X, y):
    # spatial domain
    csp=CSP(n_components=6)
    spoc=SPoC(n_components=6)

    if feature=="csp":
        X=csp.fit_transform(X, y)
    elif feature=="scsp":
        X=csp.fit_transform(X, y)
        sfm=SelectFromModel(Lasso(alpha=.01))
        sfm.fit(X, y)
        X=X[:, sfm.get_support()]
    elif feature=="spoc":
        X=spoc.fit_transform(X, y)

    # time domain
    elif feature=="tdp":
        X=np.apply_along_axis(tdp, 2, X, i=2).reshape(X.shape[0], -1)
    elif feature=="hjorth":
        X=np.apply_along_axis(hjorth, 2, X).reshape(X.shape[0], -1)
    elif feature=="hurst":
        X=np.apply_along_axis(hurst, 2, X).reshape(X.shape[0], -1)

    # spectral domain
    elif feature=="welch":
        f, Pxx_den = welch(X, 250, nperseg=6)
        X=Pxx_den.reshape(X.shape[0], -1)
    elif feature=="periodogram":
        f, Pxx_den=periodogram(X, 250, nfft=6)
        X=Pxx_den.reshape(X.shape[0], -1)
    elif feature=="yulewalker":
        X=np.apply_along_axis(yule_walker, 2, X).reshape(X.shape[0], -1)

    # spatial+frequency domain
    elif feature=="smfcsp":
        csp.fit(X, y)
        X=np.hstack([csp.transform(butter_bandpass_filter(X, 8., 13., fs=250)),
                       csp.transform(butter_bandpass_filter(X, 8., 10., fs=250)),
                       csp.transform(butter_bandpass_filter(X, 10., 13., fs=250)),
                       csp.transform(butter_bandpass_filter(X, 13., 30., fs=250)),
                       csp.transform(butter_bandpass_filter(X, 13., 18., fs=250)),
                       csp.transform(butter_bandpass_filter(X, 18., 23., fs=250)),
                       csp.transform(butter_bandpass_filter(X, 23., 30., fs=250))])
        sfm=SelectFromModel(Lasso(alpha=.01))
        sfm.fit(X, y)
        X=X[:, sfm.get_support()]

    # time+frequency domain
    elif feature=="dwt":
        temp1=np.asarray(wavelet_features(X[0, :, :]))
        for i in range(1, X.shape[0]):
            temp2=np.asarray(wavelet_features(X[i, :, :]))
            temp1=np.concatenate((temp1, temp2), axis=0)
        X=temp1.reshape(X.shape[0], -1)
        
    # spatial+time domain
    else:
        X=1

    return X

In [9]:
df=pd.DataFrame()

In [31]:
%%capture

scaler = MinMaxScaler()

for fe in ["csp", "scsp", "spoc", "tdp", "hjorth", "hurst", "welch", "periodogram", "yulewalker"]:
    acc=[]
    tm=[]

    for sub in range(1, 4):
        epochs, label=preprocessing(sub)


        # scale data
        epochs=scaler.fit_transform(epochs.reshape(-1, epochs.shape[-1])).reshape(epochs.shape)

        # define feature extraction
        start=time.time()
        X=feature_extraction(fe, epochs, label)
        end=time.time()

        # define grid
        grid={"solver": ['lsqr'], 'shrinkage': ['auto']}

        clf=GridSearchCV(LinearDiscriminantAnalysis(), param_grid=grid, cv=5, refit=True, n_jobs=-1)
        clf.fit(X, label)
        acc.append(clf.best_score_)
        tm.append(end-start)

    df['{}_acc'.format(fe)]=acc
    df['{}_time'.format(fe)]=tm    

In [32]:
df.head()

Unnamed: 0,csp_acc,csp_time,scsp_acc,scsp_time,spoc_acc,spoc_time,tdp_acc,tdp_time,hjorth_acc,hjorth_time,hurst_acc,hurst_time,welch_acc,welch_time,periodogram_acc,periodogram_time,yulewalker_acc,yulewalker_time
0,0.955556,2.927267,0.955556,3.2906,0.855556,1.659683,0.755556,0.48124,0.816667,0.46376,0.622222,6.180391,0.772222,0.84522,0.35,0.00302,0.766667,11.918978
1,0.658333,3.425543,0.658333,3.763122,0.616667,1.312125,0.383333,0.339376,0.475,0.307777,0.408333,3.951676,0.441667,0.554712,0.283333,0.001995,0.475,7.914884
2,0.8,3.68374,0.783333,4.202655,0.791667,0.984423,0.416667,0.32559,0.458333,0.308996,0.291667,3.906852,0.466667,0.527607,0.291667,0.001995,0.408333,7.980083


In [33]:
df.to_csv('result/simple_lda(data2).csv')