In [None]:
import numpy as np
import pandas as pd

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import Lasso, LogisticRegression
from sklearn.feature_selection import SelectFromModel

from sklearn.metrics import accuracy_score, cohen_kappa_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.preprocessing import MinMaxScaler

import mne
from mne.io import read_raw_gdf
from mne.decoding import CSP, SPoC, TimeFrequency

from scipy.signal import welch, butter, lfilter, periodogram

import time

import pywt

from spectrum import pyule

In [None]:
def hjorth(a):
    """Return the Hjorth parameters (activity, mobility, complexity) of a 1-D signal.

    Power is measured as the mean of the squared samples (not the
    mean-removed variance) for the signal and for its first and second
    discrete differences.

    Parameters
    ----------
    a : numpy.ndarray
        1-D signal.

    Returns
    -------
    tuple
        ``(activity, mobility, complexity)`` where activity is the mean
        squared signal, mobility is ``sqrt(power(d1) / power(a))`` and
        complexity is ``sqrt(power(d2) / power(d1)) / mobility``.
    """
    signal_power = np.mean(a ** 2)
    slope_power = np.mean(np.square(np.diff(a)))
    curvature_power = np.mean(np.square(np.diff(a, 2)))

    mobility = np.sqrt(slope_power / signal_power)
    complexity = np.sqrt(curvature_power / slope_power) / mobility

    return signal_power, mobility, complexity

In [None]:
def hurst(time_series, max_lag=20):
    """Estimate a Hurst-style exponent from the growth of lagged differences.

    For every lag in ``2 .. max_lag - 1`` the standard deviation of
    ``time_series[lag:] - time_series[:-lag]`` is computed; the returned
    value is the slope of a straight-line fit of log(std) against log(lag).

    Parameters
    ----------
    time_series : array-like
        1-D series, longer than ``max_lag``.
    max_lag : int, optional
        Exclusive upper bound of the lags used (default 20).

    Returns
    -------
    float
        Slope of the log-log regression.
    """
    lag_values = list(range(2, max_lag))

    spreads = []
    for lag in lag_values:
        delta = np.subtract(time_series[lag:], time_series[:-lag])
        spreads.append(np.std(delta))

    # First coefficient of a degree-1 fit is the slope.
    slope, _intercept = np.polyfit(np.log(lag_values), np.log(spreads), 1)
    return slope

In [None]:
def tdp(x, i):
    """Log mean-absolute value of a signal and of its first ``i`` differences.

    Parameters
    ----------
    x : array-like
        1-D signal.
    i : int
        Number of successive discrete differences to include.

    Returns
    -------
    list
        ``i + 1`` values: ``log(mean(|x|))`` for the signal itself followed
        by the same quantity for each successive ``np.diff`` of it.
    """
    current = x
    features = [np.log(np.mean(np.abs(current)))]
    for _ in range(i):
        current = np.diff(current)
        features.append(np.log(np.mean(np.abs(current))))
    return features

In [None]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Band edges, in the same unit as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Order passed to ``scipy.signal.butter`` (default 5).

    Returns
    -------
    tuple
        ``(b, a)`` numerator and denominator coefficients.
    """
    nyquist = 0.5 * fs
    # butter expects the band edges as fractions of the Nyquist frequency.
    band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` with a Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.lfilter`` (a single forward pass, along lfilter's default
    axis).

    Parameters
    ----------
    data : array-like
        Signal(s) to filter.
    lowcut, highcut : float
        Band edges, in the same unit as ``fs``.
    fs : float
        Sampling frequency.
    order : int, optional
        Butterworth order (default 5).

    Returns
    -------
    numpy.ndarray
        Filtered data, same shape as ``data``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)

In [None]:
def yule_walker(data):
    """Return the PSD of ``data`` as estimated by ``spectrum.pyule``.

    ``pyule`` is called with 4 as its second positional argument
    (presumably the autoregressive model order - confirm against the
    ``spectrum`` documentation), ``sampling=512`` and ``NFFT=5``.

    NOTE(review): other code in this notebook passes 250 as the sampling
    rate to the scipy routines; confirm which rate the recordings use.
    """
    estimator = pyule(data, 4, sampling=512, NFFT=5)
    return estimator.psd

In [None]:
def wavelet_features(epoch):
    """Channel-averaged statistics of a single-level db4 wavelet decomposition.

    Each channel (row) of ``epoch`` is decomposed with ``pywt.dwt`` into
    approximation (cA) and detail (cD) coefficients. For each coefficient
    set the per-channel mean, standard deviation and energy (sum of
    squares) are computed and then averaged over channels.

    The previous implementation only looked at the first 22 channels and
    divided one statistic by 22 and the remaining five by 14. All channels
    present in ``epoch`` are now used and every statistic is averaged over
    that same channel count.

    Parameters
    ----------
    epoch : array-like
        2-D array with channels on axis 0 and samples on axis 1.

    Returns
    -------
    tuple
        ``(cA_mean, cA_std, cD_mean, cD_std, cA_energy, cD_energy)``,
        each averaged across channels.

    Raises
    ------
    ValueError
        If ``epoch`` is not 2-D or contains no channels.
    """
    epoch = np.asarray(epoch)
    if epoch.ndim != 2 or epoch.shape[0] == 0:
        raise ValueError(
            "epoch must be a non-empty 2-D array (channels x samples), got shape {}".format(epoch.shape)
        )

    # One decomposition call handles every channel: the transform runs along
    # the last (time) axis.
    cA, cD = pywt.dwt(epoch, 'db4', axis=-1)

    def _channel_average(per_channel):
        # Mean over channels of a per-channel statistic.
        return np.mean(per_channel)

    return (
        _channel_average(np.mean(cA, axis=-1)),
        _channel_average(np.std(cA, axis=-1)),
        _channel_average(np.mean(cD, axis=-1)),
        _channel_average(np.std(cD, axis=-1)),
        _channel_average(np.sum(np.square(cA), axis=-1)),
        _channel_average(np.sum(np.square(cD), axis=-1)),
    )

In [None]:
def preprocessing(sub, run):
    """Load one recording run and return band-passed epochs with their labels.

    Steps: read the GDF file for subject ``sub`` / run ``run``, keep only the
    first 61 channels, band-pass 8-30 Hz, cut epochs from 1 s to 4 s after
    each of the four annotated event codes, and drop every epoch that
    contains NaN samples.

    Fixes compared with the previous version:

    * Labels are taken from ``epochs.events`` rather than from the raw
      ``events`` array, so they stay aligned with the data even when MNE
      drops an epoch.
    * NaN handling removes exactly the epochs that contain NaN. The old
      slice ``[:nan_idx-1]`` turned into ``[:-1]`` when the first epoch was
      the bad one (keeping the NaN data) and threw away all later clean
      epochs otherwise.

    Parameters
    ----------
    sub : int
        Subject number used in the file path.
    run : int
        Run number used in the file path.

    Returns
    -------
    epochs : numpy.ndarray
        Epoch data as returned by ``mne.Epochs.get_data`` with NaN epochs
        removed.
    label : numpy.ndarray
        Class label (1-4) for each returned epoch.
    """
    # Annotation description -> integer class label.
    event_id = {'1537':1, '1538':2, '1540':3, '1541':4}

    # load data
    raw = read_raw_gdf('data/complex_MI/S{}_MI/motorimagination_subject{}_run{}.gdf'.format(sub, sub, run), preload=True)

    # select eeg channels: everything from index 61 onwards is dropped
    raw.drop_channels(raw.info['ch_names'][61:])

    # band-pass filter
    raw.filter(8., 30., fir_design='firwin')

    # get epoch
    events, _ = mne.events_from_annotations(raw, event_id=event_id)
    epochs = mne.Epochs(raw, events, event_id=event_id, tmin=1, tmax=4, baseline=None, preload=True)
    # Use the events that survived epoching so data and labels have equal length.
    label = epochs.events[:, 2]
    data = epochs.get_data()

    # remove nan: keep only epochs with no NaN in any channel or sample
    clean = ~np.isnan(data).any(axis=(1, 2))
    return data[clean], label[clean]

In [None]:
def _select_with_lasso(X, y):
    """Keep the columns of ``X`` chosen by ``SelectFromModel(Lasso(alpha=.01))``."""
    selector = SelectFromModel(Lasso(alpha=.01))
    selector.fit(X, y)
    return X[:, selector.get_support()]


def feature_extraction(feature, X, y, fs=250):
    """Turn epoched data into a 2-D feature matrix using the named method.

    Parameters
    ----------
    feature : str
        One of ``"csp"``, ``"scsp"``, ``"spoc"`` (spatial), ``"tdp"``,
        ``"hjorth"``, ``"hurst"`` (time), ``"welch"``, ``"periodogram"``,
        ``"yulewalker"`` (spectral), ``"smfcsp"`` (spatial + frequency) or
        ``"dwt"`` (time + frequency).
    X : numpy.ndarray
        Epochs, indexed as (epoch, channel, sample).
    y : array-like
        One label per epoch; used by the supervised methods.
    fs : float, optional
        Sampling rate handed to the scipy spectral estimators and the
        sub-band filters. Defaults to 250, the value that used to be
        hard-coded. NOTE(review): ``yule_walker`` in this notebook uses
        ``sampling=512`` - confirm the true rate of the recordings.

    Returns
    -------
    numpy.ndarray
        Feature matrix with one row per epoch.

    Raises
    ------
    ValueError
        If ``feature`` is not a supported name. Previously the integer 1
        was returned, which only failed later inside the classifier.

    Notes
    -----
    The supervised extractors (CSP, SPoC, Lasso selection) are fitted on all
    of ``X``/``y`` passed in. If the result is then cross-validated, the
    held-out folds have already influenced the features.
    """
    n_epochs = X.shape[0]

    # spatial domain
    if feature == "csp":
        X = CSP(n_components=6).fit_transform(X, y)
    elif feature == "scsp":
        X = _select_with_lasso(CSP(n_components=6).fit_transform(X, y), y)
    elif feature == "spoc":
        X = SPoC(n_components=6).fit_transform(X, y)

    # time domain
    elif feature == "tdp":
        X = np.apply_along_axis(tdp, 2, X, i=2).reshape(n_epochs, -1)
    elif feature == "hjorth":
        X = np.apply_along_axis(hjorth, 2, X).reshape(n_epochs, -1)
    elif feature == "hurst":
        X = np.apply_along_axis(hurst, 2, X).reshape(n_epochs, -1)

    # spectral domain
    elif feature == "welch":
        _, Pxx_den = welch(X, fs, nperseg=6)
        X = Pxx_den.reshape(n_epochs, -1)
    elif feature == "periodogram":
        # NOTE(review): nfft=6 is far shorter than an epoch; confirm that
        # this very coarse estimate is intended.
        _, Pxx_den = periodogram(X, fs, nfft=6)
        X = Pxx_den.reshape(n_epochs, -1)
    elif feature == "yulewalker":
        X = np.apply_along_axis(yule_walker, 2, X).reshape(n_epochs, -1)

    # spatial+frequency domain
    elif feature == "smfcsp":
        # CSP filters are learned once on the input and then applied to
        # each sub-band filtered copy; the band order fixes the column order.
        sub_bands = [(8., 13.), (8., 10.), (10., 13.), (13., 30.),
                     (13., 18.), (18., 23.), (23., 30.)]
        csp = CSP(n_components=6)
        csp.fit(X, y)
        stacked = np.hstack([
            csp.transform(butter_bandpass_filter(X, low, high, fs=fs))
            for low, high in sub_bands
        ])
        X = _select_with_lasso(stacked, y)

    # time+frequency domain
    elif feature == "dwt":
        # One row of wavelet statistics per epoch, built in a single pass.
        X = np.asarray([wavelet_features(epoch) for epoch in X])

    else:
        raise ValueError("Unknown feature extraction method: {!r}".format(feature))

    return X

In [None]:
%%capture
# Per-subject dataset cache: d['x<i>'] holds the stacked epochs of subject i
# and d['y<i>'] the matching labels.
d={}
for i in range(1, 16):

    # Load runs 1-10 of subject i. The earlier version read runs 2-10 from
    # subject 1 for every subject, so most of each subject's data was wrong.
    run_epochs = []
    run_labels = []
    for run in range(1, 11):
        temp_e, temp_l = preprocessing(i, run)
        run_epochs.append(temp_e)
        run_labels.append(temp_l)

    # concatenate the runs once instead of re-copying inside the loop
    epochs = np.concatenate(run_epochs)
    label = np.concatenate(run_labels)
    d['x{}'.format(i)] = epochs
    d['y{}'.format(i)] = label

In [None]:
# Results table: the evaluation cell adds '<feature>_acc' and '<feature>_time'
# columns to it, one row per subject.
df=pd.DataFrame()

In [None]:
%%capture

# A single scaler object; fit_transform re-fits it for every subject.
scaler = MinMaxScaler()

for fe in ["dwt"]:
    # Per-subject results for this feature set.
    acc, tm = [], []

    for sub in range(1, 16):
        epochs = d['x{}'.format(sub)]
        label = d['y{}'.format(sub)]

        # scale data: flatten to 2-D with the last axis as columns, min-max
        # scale each column, then restore the original shape
        original_shape = epochs.shape
        flattened = epochs.reshape(-1, original_shape[-1])
        epochs = scaler.fit_transform(flattened).reshape(original_shape)

        # feature extraction, timed on its own
        start = time.time()
        X = feature_extraction(fe, epochs, label)
        end = time.time()

        # SVM hyper-parameter grid searched with 5-fold cross-validation
        grid = {
            "C": [1e-2, 1e-1, 1, 1e+1, 1e+2],
            "kernel": ["rbf", "linear", "poly"],
        }
        clf = GridSearchCV(SVC(), param_grid=grid, cv=5, refit=True, n_jobs=-1)
        clf.fit(X, label)

        # keep the best mean CV score and the extraction time (seconds)
        acc.append(clf.best_score_)
        tm.append(end - start)

    df['{}_acc'.format(fe)] = acc
    df['{}_time'.format(fe)] = tm

In [None]:
# Best cross-validated accuracy (GridSearchCV.best_score_) per subject for the DWT features.
df["dwt_acc"].values

In [None]:
# Wall-clock time in seconds spent in feature_extraction per subject for the DWT features.
df["dwt_time"].values

array([0.4442246 , 0.46575427, 0.46626019, 0.49849606, 0.46276236,
       0.46226859, 0.46475673, 0.46775913, 0.47025871, 0.46226883,
       0.46475887, 0.46375918, 0.46682501, 0.4757278 , 0.4843297 ])