In [1]:
"""
 Sample script using EEGNet to classify Event-Related Potential (ERP) EEG data
 from a four-class classification task, using the sample dataset provided in
 the MNE [1, 2] package:
     https://martinos.org/mne/stable/manual/sample_dataset.html#ch-sample-data
   
 The four classes used from this dataset are:
     LA: Left-ear auditory stimulation
     RA: Right-ear auditory stimulation
     LV: Left visual field stimulation
     RV: Right visual field stimulation

 The code to process, filter and epoch the data are originally from Alexandre
 Barachant's PyRiemann [3] package, released under the BSD 3-clause. A copy of 
 the BSD 3-clause license has been provided together with this software to 
 comply with software licensing requirements. 
 
 When you first run this script, MNE will download the dataset and prompt you
 to confirm the download location (defaults to ~/mne_data). Follow the prompts
 to continue. The dataset size is approx. 1.5GB download. 
 
 For comparative purposes you can also compare EEGNet performance to using 
 Riemannian geometric approaches with xDAWN spatial filtering [4-8] using 
 PyRiemann (code provided below).

 [1] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck,
     L. Parkkonen, M. Hämäläinen, MNE software for processing MEG and EEG data, 
     NeuroImage, Volume 86, 1 February 2014, Pages 446-460, ISSN 1053-8119.

 [2] A. Gramfort, M. Luessi, E. Larson, D. Engemann, D. Strohmeier, C. Brodbeck, 
     R. Goj, M. Jas, T. Brooks, L. Parkkonen, M. Hämäläinen, MEG and EEG data 
     analysis with MNE-Python, Frontiers in Neuroscience, Volume 7, 2013.

 [3] https://github.com/alexandrebarachant/pyRiemann. 

 [4] A. Barachant, M. Congedo, "A Plug&Play P300 BCI Using Information Geometry",
     arXiv:1409.0107.

 [5] M. Congedo, A. Barachant, A. Andreev ,"A New generation of Brain-Computer 
     Interface Based on Riemannian Geometry", arXiv: 1310.8115.

 [6] A. Barachant and S. Bonnet, "Channel selection procedure using riemannian 
     distance for BCI applications," in 2011 5th International IEEE/EMBS 
     Conference on Neural Engineering (NER), 2011, 348-351.

 [7] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, “Multiclass 
     Brain-Computer Interface Classification by Riemannian Geometry,” in IEEE 
     Transactions on Biomedical Engineering, vol. 59, no. 4, p. 920-928, 2012.

 [8] A. Barachant, S. Bonnet, M. Congedo and C. Jutten, “Classification of 
     covariance matrices using a Riemannian-based kernel for BCI applications“, 
     in NeuroComputing, vol. 112, p. 172-178, 2013.


 Portions of this project are works of the United States Government and are not
 subject to domestic copyright protection under 17 USC Sec. 105.  Those 
 portions are released world-wide under the terms of the Creative Commons Zero 
 1.0 (CC0) license.  
 
 Other portions of this project are subject to domestic copyright protection 
 under 17 USC Sec. 105.  Those portions are licensed under the Apache 2.0 
 license.  The complete text of the license governing this material is in 
 the file labeled LICENSE.TXT that is a part of this project's official 
 distribution. 
"""

In [2]:
import PyQt5
%config InlineBackend.figure_format = 'retina'
%matplotlib qt5

import numpy as np

# mne imports
# import mne
# from mne import io
# from mne.datasets import sample

# EEGNet-specific imports
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K

# PyRiemann imports
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from pyriemann.utils.viz import plot_confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression

# tools for plotting confusion matrices
from matplotlib import pyplot as plt

# while the default tensorflow ordering is 'channels_last' we set it here
# to be explicit in case if the user has changed the default ordering
K.set_image_data_format('channels_last')

In [3]:
def loadmeta():
    """Return the fixed constants shared by the PD analysis cells.

    Returns
    -------
    Fs : int
        Sampling rate (Hz)
    t : numpy array
        Time axis in seconds, starting at 0 and stepping by 1/Fs up to
        (but not including) 30
    S : int
        Number of entries for PD patients (sizes the 'off' arrays)
    Sc : int
        Number of entries for control subjects (sizes the 'C' arrays)
    Smed : int
        Number of entries for medicated patients (sizes the 'on' arrays)
    flo : 2-element tuple
        frequency limits of the beta range (Hz)
    fhi : 2-element tuple
        frequency limits for the high gamma range (Hz)
    """
    sampling_rate = 500  # Hz
    duration = 30  # seconds
    time_axis = np.arange(0, duration, 1 / sampling_rate)

    # All three groups currently hold the same number of entries.
    n_off = 270
    n_control = 270
    n_on = 270

    beta_band = (13, 30)
    high_gamma_band = (50, 150)

    return (sampling_rate, time_axis, n_off, n_control, n_on,
            beta_band, high_gamma_band)

In [4]:
def _blankeeg(dtype=object):
    """Allocate a zero-filled 1-D array per subject group.

    Parameters
    ----------
    dtype : data-type, optional
        dtype of the allocated arrays (defaults to ``object``).

    Returns
    -------
    dict
        Keys 'off', 'on' and 'C', sized by S, Smed and Sc from ``loadmeta``.
    """
    _, _, n_off, n_control, n_on, _, _ = loadmeta()
    group_sizes = (('off', n_off), ('on', n_on), ('C', n_control))
    return {group: np.zeros(size, dtype=dtype) for group, size in group_sizes}

In [5]:
def _list_sorted_files(directory):
    """Return the sorted full paths of every entry found in ``directory``."""
    return sorted(os.path.join(directory, name) for name in os.listdir(directory))


def _load_group(paths, dest):
    """Store the 'newdataMat' variable of each .mat file in ``paths`` into ``dest[i]``."""
    if len(paths) > len(dest):
        # Fail before any file is read instead of part-way through the loop.
        raise IndexError(
            "found %d files but only %d slots are allocated; "
            "update the counts in loadmeta()" % (len(paths), len(dest)))
    for i, path in enumerate(paths):
        data = io.loadmat(path, struct_as_record=False, squeeze_me=True)
        dest[i] = data['newdataMat']


def loadPD(data_root='../../UNM_Dataset/EEGNet_data/intra-patient/'):
    """Load the pre-processed recordings of the three subject groups.

    Every file found in the ``off_med``, ``on_med`` and ``control``
    sub-folders of ``data_root`` is read with ``scipy.io.loadmat`` and its
    ``'newdataMat'`` variable is stored, in sorted-path order, in the matching
    group array. Slots beyond the number of files found keep their initial
    zero value.

    Parameters
    ----------
    data_root : string, optional
        Folder containing the ``off_med``, ``on_med`` and ``control``
        sub-folders. Defaults to the location used by this notebook.

    Returns
    -------
    eeg : dict
        Pre-processed voltage traces, one object array per group
        'off' : entries for PD patients OFF medication
        'on' : entries for PD patients ON medication
        'C' : entries for control subjects

    rejects : dict
        Rejection indices; currently all-zero integer arrays
        'off' : S zeros
        'on' : Smed zeros
        'C' : Sc zeros

    Raises
    ------
    IndexError
        If a folder contains more files than the group size from ``loadmeta``.
    """
    group_dirs = (('off', 'off_med'), ('on', 'on_med'), ('C', 'control'))

    # List every folder first so a missing directory fails before any loading.
    group_files = {group: _list_sorted_files(os.path.join(data_root, subdir))
                   for group, subdir in group_dirs}

    eeg = _blankeeg()
    for group, files in group_files.items():
        _load_group(files, eeg[group])

    Fs, t, S, Sc, Smed, flo, fhi = loadmeta()
    rejects = {
        'off': np.zeros(S, dtype=int),
        'on': np.zeros(Smed, dtype=int),
        'C': np.zeros(Sc, dtype=int),
    }

    return eeg, rejects

In [6]:
from __future__ import division
import numpy as np
import scipy as sp
from scipy import io
from scipy import signal
import os

In [7]:
# Bring the shared constants into notebook scope, then load the recordings.
# loadPD returns two dicts keyed by 'off', 'on' and 'C'.
# NOTE(review): the trailing "EO" remark refers to a name that does not appear
# anywhere in this notebook -- confirm it still applies or drop it.
Fs, t, S, Sc, Smed, flo, fhi = loadmeta() 
eeg,rejects = loadPD() # EO means Eyes Opened

In [8]:
# Constant label row vectors of shape (1, n), one per subject group.
# NOTE(review): value 3 is paired with the 'on'-group size (Smed) and value 2
# with the control-group size (Sc) -- confirm this pairing is intended.
cl_B    = np.ones((1,S))  # every entry is 1.0; length S
cl_C    = np.full((1,Smed),3)  # every entry is 3; length Smed
cl_E    = np.full((1,Sc),2)  # every entry is 2; length Sc

In [9]:
def make_dataset(signals, subjects, channels, samples):
    """Pack per-subject signals into one (subjects, channels, samples) array.

    Parameters
    ----------
    signals : sequence
        Entry ``i`` is written to row ``i``; it must be assignable to a
        (channels, samples) slice (a matching array or a scalar).
    subjects, channels, samples : int
        Dimensions of the returned array.

    Returns
    -------
    numpy array
        Float array of zeros with the first ``len(signals)`` rows filled in.
    """
    stacked = np.zeros((subjects, channels, samples))
    for row, signal in enumerate(signals):
        stacked[row] = signal
    return stacked

In [10]:
# kernels: trailing singleton axis used for the NHWC reshape;
# chans / samples: size of every recording handed to make_dataset.
kernels, chans, samples = 1, 60, 2900

# Two-class problem: PD patients OFF medication vs ON medication.
X_off = make_dataset(eeg['off'],S,chans,samples)
X_on  = make_dataset(eeg['on'],Smed,chans,samples)
# Backward-compatible alias: despite its name, X_ctl has always held the
# ON-medication recordings, not the control group.
X_ctl = X_on

X = np.concatenate([X_off, X_on], axis=0)

# Labels: 1.0 = OFF medication, 2.0 = ON medication. They are built from the
# actual array lengths so y stays aligned with X even if the group sizes
# returned by loadmeta() stop being equal.
y = np.concatenate([np.full(len(X_off), 1.0), np.full(len(X_on), 2.0)])

In [11]:
################## Split the data into train / validate / test #################

# take 50/25/25 percent of the data to train/validate/test:
# the first call holds out half of the data, the second call splits that
# held-out half evenly into the test and validation sets.
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train, Y_test = train_test_split(X, y, test_size=0.5, random_state=156156)
# NOTE: X_test / Y_test are rebound here, so this line must run right after the one above.
X_test, X_validate, Y_test, Y_validate = train_test_split(X_test, Y_test, test_size=0.5, random_state=42)

In [12]:
############################# EEGNet portion ##################################

# convert labels to one-hot encodings (labels are shifted by -1 first so the
# smallest label maps to column 0).
# NOTE(review): the Y_* names are rebound in place, so re-running this cell
# without re-running the split would encode already-encoded labels.
Y_train      = np_utils.to_categorical(Y_train-1)
Y_validate   = np_utils.to_categorical(Y_validate-1)
Y_test       = np_utils.to_categorical(Y_test-1)

# convert data to NHWC (trials, channels, samples, kernels) format. The sizes
# come from chans, samples and kernels, which are defined where X is built
# (60 channels, 2900 samples and 1 kernel there).
X_train      = X_train.reshape(X_train.shape[0], chans, samples, kernels)
X_validate   = X_validate.reshape(X_validate.shape[0], chans, samples, kernels)
X_test       = X_test.reshape(X_test.shape[0], chans, samples, kernels)
   
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# configure the EEGNet-8,2,16 model for two classes with a kernel length of
# 16 samples (other model configurations may do better, but this is a good
# starting point)
model = EEGNet(nb_classes = 2, Chans = chans, Samples = samples, 
               dropoutRate = 0.5, kernLength = 16, F1 = 8, D = 2, F2 = 16, 
               dropoutType = 'Dropout')

# compile the model and set the optimizers
model.compile(loss='categorical_crossentropy', optimizer='adam', 
              metrics = ['accuracy'])

# count number of parameters in the model
numParams    = model.count_params()    

# set a valid path for your system to record model checkpoints; with
# save_best_only=True the file is only overwritten when the monitored
# quantity improves
checkpointer = ModelCheckpoint(filepath='checkpoint.h5', verbose=1,
                               save_best_only=True)

X_train shape: (270, 60, 2900, 1)
270 train samples
135 test samples


In [13]:
# print the layer-by-layer architecture with output shapes and parameter counts
model.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 60, 2900, 1)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 60, 2900, 8)       128       
_________________________________________________________________
batch_normalization (BatchNo (None, 60, 2900, 8)       32        
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 1, 2900, 16)       960       
_________________________________________________________________
batch_normalization_1 (Batch (None, 1, 2900, 16)       64        
_________________________________________________________________
activation (Activation)      (None, 1, 2900, 16)       0         
_________________________________________________________________
average_pooling2d (AveragePo (None, 1, 725, 16)       

In [14]:
###############################################################################
# if the classification task was imbalanced (significantly more trials in one
# class versus the others) you can assign a weight to each class during 
# optimization to balance it out. This data is approximately balanced so we 
# don't need to do this, but is shown here for illustration/completeness. 
###############################################################################

# the syntax is {class_index: weight, ...}, with one entry per column of the
# one-hot labels. The model in this notebook is built with nb_classes = 2, so
# there are exactly two entries; here just setting the weights all to be 1.
# To apply them, pass class_weight=class_weights to model.fit.
class_weights = {0: 1, 1: 1}

In [15]:

################################################################################
# fit the model. Due to very small sample sizes this can get
# pretty noisy run-to-run, but most runs should be comparable to xDAWN + 
# Riemannian geometry classification (below)
#
# NOTE(review): no class_weight argument is passed here, so the class_weights
# dict defined earlier in the notebook has no effect on training.
################################################################################
fittedModel = model.fit(X_train, Y_train, batch_size = 32, epochs = 100, 
                        verbose = 2, validation_data=(X_validate, Y_validate),
                        callbacks=[checkpointer])

# load optimal weights (the file written by the checkpoint callback above)
model.load_weights('checkpoint.h5')

Epoch 1/100

Epoch 00001: val_loss improved from inf to 0.69274, saving model to checkpoint.h5
9/9 - 1s - loss: 0.7077 - accuracy: 0.5111 - val_loss: 0.6927 - val_accuracy: 0.4667
Epoch 2/100

Epoch 00002: val_loss improved from 0.69274 to 0.68396, saving model to checkpoint.h5
9/9 - 1s - loss: 0.6534 - accuracy: 0.6519 - val_loss: 0.6840 - val_accuracy: 0.4741
Epoch 3/100

Epoch 00003: val_loss did not improve from 0.68396
9/9 - 1s - loss: 0.6435 - accuracy: 0.6333 - val_loss: 0.6866 - val_accuracy: 0.5111
Epoch 4/100

Epoch 00004: val_loss did not improve from 0.68396
9/9 - 1s - loss: 0.6031 - accuracy: 0.7630 - val_loss: 0.6897 - val_accuracy: 0.5185
Epoch 5/100

Epoch 00005: val_loss improved from 0.68396 to 0.68008, saving model to checkpoint.h5
9/9 - 1s - loss: 0.5847 - accuracy: 0.7704 - val_loss: 0.6801 - val_accuracy: 0.5556
Epoch 6/100

Epoch 00006: val_loss improved from 0.68008 to 0.67682, saving model to checkpoint.h5
9/9 - 1s - loss: 0.5646 - accuracy: 0.7444 - val_loss: 

Epoch 49/100

Epoch 00049: val_loss improved from 0.33407 to 0.31532, saving model to checkpoint.h5
9/9 - 1s - loss: 0.0803 - accuracy: 0.9963 - val_loss: 0.3153 - val_accuracy: 0.8889
Epoch 50/100

Epoch 00050: val_loss did not improve from 0.31532
9/9 - 1s - loss: 0.0955 - accuracy: 0.9741 - val_loss: 0.3166 - val_accuracy: 0.8815
Epoch 51/100

Epoch 00051: val_loss improved from 0.31532 to 0.29850, saving model to checkpoint.h5
9/9 - 1s - loss: 0.0742 - accuracy: 0.9926 - val_loss: 0.2985 - val_accuracy: 0.8963
Epoch 52/100

Epoch 00052: val_loss improved from 0.29850 to 0.29698, saving model to checkpoint.h5
9/9 - 1s - loss: 0.0952 - accuracy: 0.9815 - val_loss: 0.2970 - val_accuracy: 0.9037
Epoch 53/100

Epoch 00053: val_loss improved from 0.29698 to 0.28474, saving model to checkpoint.h5
9/9 - 1s - loss: 0.0693 - accuracy: 0.9889 - val_loss: 0.2847 - val_accuracy: 0.9037
Epoch 54/100

Epoch 00054: val_loss did not improve from 0.28474
9/9 - 1s - loss: 0.0658 - accuracy: 0.9889 - 

In [16]:
###############################################################################
# you can alternatively use the weights provided in the repo. If so, it should
# get you 93% accuracy. Change the WEIGHTS_PATH variable to wherever it is on
# your system.
###############################################################################

# WEIGHTS_PATH = "EEGNet-8-2-weights.h5"
# model.load_weights(WEIGHTS_PATH)


In [17]:
###############################################################################
# make prediction on test set.
###############################################################################

# probs presumably holds one score per class for every test trial -- TODO confirm
probs       = model.predict(X_test)
# index of the highest-scoring class for each trial
preds       = probs.argmax(axis = -1)  
# Y_test is one-hot encoded at this point, so argmax recovers the class index
acc         = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))

Classification accuracy: 0.940741 


In [18]:
############################# PyRiemann Portion ##############################

# code is taken from PyRiemann's ERP sample script, which is decoding in 
# the tangent space with a logistic regression

n_components = 2  # pick some components

# set up sklearn pipeline
clf = make_pipeline(XdawnCovariances(n_components),
                    TangentSpace(metric='riemann'),
                    LogisticRegression())

# reshape back to (trials, channels, samples)
X_train      = X_train.reshape(X_train.shape[0], chans, samples)
X_test       = X_test.reshape(X_test.shape[0], chans, samples)

# labels need to be back in single-column format (class index per trial)
true_labels  = Y_test.argmax(axis = -1)

# train a classifier with xDAWN spatial filtering + Riemannian Geometry (RG)
clf.fit(X_train, Y_train.argmax(axis = -1))
preds_rg     = clf.predict(X_test)

# Printing the results
acc2         = np.mean(preds_rg == true_labels)
print("Classification accuracy: %f " % (acc2))

# plot the confusion matrices for both classifiers.
# plot_confusion_matrix expects (targets, predictions, target_names): the true
# labels go first, otherwise the 'True label' / 'Predicted label' axes of the
# plot are swapped.
# names        = ['audio left', 'audio right', 'vis left', 'vis right']
names        = ['off', 'on']
plt.figure(0)
plot_confusion_matrix(true_labels, preds, names, title = 'EEGNet-8,2')

plt.figure(1)
plot_confusion_matrix(true_labels, preds_rg, names, title = 'xDAWN + RG')

Classification accuracy: 0.607407 


<AxesSubplot:title={'center':'xDAWN + RG'}, xlabel='Predicted label', ylabel='True label'>