In [7]:
import pywt
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from random import randint
from sklearn import svm
from mrmr import mrmr_classif
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.pipeline import make_pipeline

from mne import read_epochs, set_log_level, compute_rank, concatenate_epochs
from mne.decoding import Scaler 

from pyriemann.estimation import Covariances, Kernels
from pyriemann.utils.distance import distance
from pyriemann.classification import MDM, FgMDM, KNearestNeighbor
from pyriemann.tangentspace import TangentSpace
from sklearn_rvm import EMRVC

from jupyterthemes.stylefx import set_nb_theme
set_nb_theme('gruvboxd')

In [8]:
set_log_level('warning')
epochs = read_epochs('ica_epo.fif').pick('eeg').filter(0,240)
epochs.drop_channels(epochs.info['bads'])
epochs.apply_baseline((-1.4,-0.4))

0,1
Number of events,120
Events,left: 24 r_pinch: 25 r_stop: 25 rest: 21 right: 25
Time range,-2.000 – 7.999 sec
Baseline,-1.400 – -0.400 sec


In [9]:
le = LabelEncoder()
scaler = Scaler(info=epochs.info)
tangent_space = TangentSpace()
mdm = MDM(n_jobs=1)
fmdm = FgMDM(n_jobs=2)
knn = KNearestNeighbor(n_neighbors=4, n_jobs=1)
rvm = EMRVC(kernel="rbf", gamma="auto")
svm_rbf = svm.SVC(kernel="rbf")
lda = LinearDiscriminantAnalysis(solver='lsqr',shrinkage='auto')
pca = PCA(n_components=0.95)

In [10]:
estimators = [ 
    "cov-sch"
]

left vs rest

In [11]:
fmax = 35
fmin = 18
conditions = ['left','rest']
subset = epochs[conditions].copy()
subset.drop_channels(subset.info['bads'])
subset = subset.pick(['eeg'])
subset = subset.apply_baseline((-1.4,-0.4))
y = le.fit_transform(subset.events[:,2])
train_data = subset.copy().crop(0.4,1.2)
dwt_data = get_dwt_coeff(train_data.get_data(),3,5, 'db26')
train_data = train_data.filter(fmin,fmax).get_data()
time_config = (3,0.4,300,100)

(45, 150, 99)


In [12]:
chance = np.mean(y == y[0])
chance = max(chance, 1. - chance)
chance

0.5333333333333333

In [13]:
find_accuracy_psd_estimators('riemannian_left_rest.npy')

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.78it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [00:59<00:00,  3.94s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.19it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [01:02<00:00,  4.16s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:07<00:00,  1.33it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [00:47<00:00,  3.15s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.26it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [01:12<00:00,  4.81s/it]
100%|███████████████████████████████████

cov-sch  : 
lda  0.8  ; mdm 0.49777777777777776  ; knn  0.49111111111111116  ; dwt+lda  0.8488888888888889 ; dwt+svm 0.8266666666666668


In [9]:
find_accuracy_psd_estimators()

cov-lwf  : 
lda  0.5111111111111112  ; mdm 0.47777777777777786  ; cff+lda  0.5377777777777778
cov-sch  : 
lda  0.5088888888888888  ; mdm 0.48888888888888893  ; cff+lda  0.6177777777777779
ker-rbf  : 
lda  0.6733333333333335  ; mdm 0.6844444444444445  ; cff+lda  0.5311111111111111
ker-polynomial  : 
lda  0.6977777777777778  ; mdm 0.6400000000000001  ; cff+lda  0.6955555555555558


left vs right

In [14]:
fmax = 60
fmin = 35
conditions = ['left','right']
subset = epochs[conditions].copy()
subset.drop_channels(subset.info['bads'])
subset = subset.pick(['eeg'])
subset = subset.apply_baseline((-1.4,-0.4))
y = le.fit_transform(subset.events[:,2])
train_data = subset.copy().crop(0.4,1.2)
dwt_data = get_dwt_coeff(train_data.get_data(),5,6, 'db30')
train_data = train_data.filter(fmin,fmax).get_data()
time_config = (3,0.4,300,100)

(49, 150, 70)


In [15]:
chance = np.mean(y == y[0])
chance = max(chance, 1. - chance)
chance

0.5102040816326531

In [16]:
find_accuracy_psd_estimators('riemannian_left_right.npy')

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.94it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [01:01<00:00,  4.13s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:04<00:00,  2.21it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [01:02<00:00,  4.14s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:05<00:00,  1.79it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [00:56<00:00,  3.80s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:06<00:00,  1.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [01:03<00:00,  4.22s/it]
100%|███████████████████████████████████

cov-sch  : 
lda  0.8110294117647059  ; mdm 0.586642156862745  ; knn  0.609436274509804  ; dwt+lda  0.8325980392156863 ; dwt+svm 0.7700980392156863


In [11]:
find_accuracy_psd_estimators()

cov-lwf  : 
lda  0.5268382352941177  ; mdm 0.5080882352941176  ; cff+lda  0.5223039215686275
cov-sch  : 
lda  0.511764705882353  ; mdm 0.4997549019607843  ; cff+lda  0.5166666666666667
ker-rbf  : 
lda  0.5404411764705883  ; mdm 0.5099264705882354  ; cff+lda  0.5384803921568628
ker-polynomial  : 
lda  0.5262254901960784  ; mdm 0.5197303921568628  ; cff+lda  0.5333333333333333


In [6]:
def find_accuracy_psd_estimators(file_name):
    data = scaler.fit_transform(train_data, y)
    
    for est in estimators:
        mdm_score = []
        lda_score = []
        dwt_score = []
        svm_dwt_score = []
        knn_score = []
        for train_rep in range(10):
            cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=randint(15,25) + train_rep)
            
            mrmr_features = None
            mrmr_coeff_features = None
            cv_split = cv.split(data, y)
            for train_idx, test_idx in cv_split:
                y_train, y_test = y[train_idx], y[test_idx]
                
                est_class, est_param = est.split('-')
                x_train = []
                x_test = []
                x_coeff_train = []
                x_coeff_test = []
                
                if est_class == "ker":
                    krn = Kernels(metric=est_param)
                    x_train = krn.fit_transform(data[train_idx], y_train)
                    x_test = krn.transform(data[test_idx])
                    x_coeff_train = krn.fit_transform(dwt_data[train_idx], y_test)
                    x_coeff_test = krn.transform(dwt_data[test_idx])
                else:
                    psd = Covariances(estimator=est_param)
                    x_train = psd.fit_transform(data[train_idx], y_train)
                    x_test = psd.transform(data[test_idx])
                    x_coeff_train = psd.fit_transform(dwt_data[train_idx], y_train)
                    x_coeff_test = psd.transform(dwt_data[test_idx])
                    
                mdm.fit(x_train, y_train)
                mdm_score.append(np.median(mdm.score(x_test, y_test)))
                knn.fit(x_train, y_train)
                knn_score.append(np.median(knn.score(x_test, y_test)))

                x_train = tangent_space.fit_transform(x_train)
                x_test = tangent_space.transform(x_test)
                if mrmr_features is None:
                    x_pd = pd.DataFrame(x_train)
                    mrmr_features = mrmr_classif(X=x_pd, y=y_train, K=10)
   
                lda.fit(x_train[:, mrmr_features], y_train)
                lda_score.append(np.median(lda.score(x_test[:, mrmr_features], y_test)))
        
                x_coeff_train = tangent_space.fit_transform(x_coeff_train)
                x_coeff_test = tangent_space.transform(x_coeff_test)
                 
                if mrmr_coeff_features is None:
                    x_pd = pd.DataFrame(x_coeff_train)
                    mrmr_coeff_features = mrmr_classif(X=x_pd, y=y_train, K=15)
  
                lda.fit(x_coeff_train[:, mrmr_coeff_features], y_train)
                dwt_score.append(np.median(lda.score(x_coeff_test[:, mrmr_coeff_features], y_test)))
                svm_rbf.fit(x_coeff_train[:, mrmr_coeff_features], y_train)
                svm_dwt_score.append(np.median(svm_rbf.score(x_coeff_test[:, mrmr_coeff_features], y_test)))
  
        print(est, ' : ')
        print('lda ', np.mean(lda_score) ,
              " ; mdm", np.mean(mdm_score), ' ; knn ', np.mean(knn_score),
              ' ; dwt+lda ', np.mean(dwt_score),
             '; dwt+svm', np.mean(svm_dwt_score))
        np.save(file_name,np.array([np.mean(lda_score),np.mean(mdm_score),
               np.mean(knn_score), np.mean(dwt_score), np.mean(svm_dwt_score)]))

def dwt_det_coeff(x, db='db2'):
    aprx, det = pywt.dwt(x,db)
    return det

def dwt_aprox_coeff(x, db='db2'):
    aprx, det = pywt.dwt(x,db)
    return aprx

def get_dwt_coeff(train_data, lvl, lvl1, db='db4'):
    x_aprox_coeff = train_data
          
    detail_coeffs = []
    aprox_coeffs = []
    for dwt_lvl in range(lvl1+1):
        x_det_coeff = np.apply_along_axis(dwt_det_coeff, 2, x_aprox_coeff, db=db)
        x_aprox_coeff = np.apply_along_axis(dwt_aprox_coeff, 2, x_aprox_coeff, db=db)

        detail_coeffs.append(scaler.fit_transform(x_det_coeff.copy(),y))
        aprox_coeffs.append(scaler.fit_transform(x_aprox_coeff.copy(),y))
        
    coeffs_shape = detail_coeffs[lvl].shape
    
    lvl1_coeffs = np.zeros(coeffs_shape)
    lvl1_coeffs[:,:,:detail_coeffs[lvl1].shape[2]] = detail_coeffs[lvl1]
   
    aprox0_coeffs = np.zeros(coeffs_shape)
    aprox0_coeffs[:,:,:aprox_coeffs[-1].shape[2]] = aprox_coeffs[-1]

    dwt_data = np.concatenate([detail_coeffs[lvl], lvl1_coeffs, aprox0_coeffs], axis=1)
    print(dwt_data.shape)
    return dwt_data     
    

In [17]:
def dwt_det_coeff(x, db='db2'):
    aprx, det = pywt.dwt(x,db)
    return det

def dwt_aprx_coeff(x, db='db2'):
    aprx, det = pywt.dwt(x,db)
    return aprx

def get_dwt_coeff(x, db, lvl):
    aprox_coeff = x
    for lvl in range(lvl):
        det_coeff = np.apply_along_axis(dwt_det_coeff, 2, aprox_coeff, db=db)
        aprox_coeff = np.apply_along_axis(dwt_aprx_coeff, 2, aprox_coeff, db=db)
    
    return det_coeff

In [18]:
def concat_index(train_data):
    epochs, channels, points = train_data.shape
    indexed = np.zeros([epochs,channels,points+1])
    for e in range(epochs):
        for ch in range(channels):
            indexed[e,ch,0] = ch 
            indexed[e,ch,1:] = train_data[e,ch]

    return indexed