In [4]:
#
# Gibran Fuentes-Pineda <gibranfp@unam.mx>
# IIMAS, UNAM
# 2019
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
#from tensorflow.keras.regularizers import l2
#from keras.utils import np_utils



def SepConv1D(Chans = 8, Samples = 257, Filters = 32):
    eeg_input    = Input(shape = (Samples, Chans))
    padded       = ZeroPadding1D(padding = 4)(eeg_input)
    block1       = SeparableConv1D(Filters, 16, strides = 8,
                                   padding = 'valid',
                                   data_format = 'channels_last',
                                   kernel_initializer = 'glorot_uniform',
                                   bias_initializer = 'zeros',
                                   use_bias = True)(padded)
    block1       = Activation('tanh')(block1)
    flatten      = Flatten(name = 'flatten')(block1)
    prediction   = Dense(1, activation = 'sigmoid')(flatten)
    return Model(inputs = eeg_input, outputs = prediction, name='SepConv1D')


In [5]:
import mne
filepath = "/workspace/data/EEG/data/"

ep1 = mne.read_epochs(filepath + "epochs/A01-epo.fif", preload=True)

Reading /workspace/data/EEG/data/epochs/A01-epo.fif ...
    Found the data of interest:
        t =    -199.22 ...     800.78 ms
        0 CTF compensation matrices available
Not setting metadata
Not setting metadata
4200 matching events found
No baseline correction applied
0 projection items activated


In [6]:
labels = ep1.events[:, -1]

In [7]:
X = ep1.get_data() * 1000
y = labels-1

In [8]:
# take 50/25/25 percent of the data to train/validate/test
X_train      = X[0: int(0.6*len(y)),]
Y_train      = y[0: int(0.6*len(y)),]
X_validate   = X[int(0.6*len(y)) : int(0.8*len(y)) ,]
Y_validate   = y[int(0.6*len(y)) : int(0.8*len(y)) ,]
X_test       = X[int(0.8*len(y)) : len(y) ,]
Y_test       = y[int(0.8*len(y)) : len(y) ,]

In [37]:
#Y_train      = np_utils.to_categorical(Y_train-1)
#Y_validate   = np_utils.to_categorical(Y_validate-1)
#Y_test       = np_utils.to_categorical(Y_test-1)

In [9]:
kernels, chans, samples = 1, 8, 257


In [10]:
X_train      = X_train.reshape(X_train.shape[0], samples, chans)
X_validate   = X_validate.reshape(X_validate.shape[0], samples, chans)
X_test       = X_test.reshape(X_test.shape[0], samples, chans)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

X_train shape: (2520, 257, 8)
2520 train samples
840 test samples


In [13]:
model = SepConv1D()

2022-03-09 23:21:06.597366: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-03-09 23:21:07.366945: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 30988 MB memory:  -> device: 0, name: Tesla V100-SXM3-32GB, pci bus id: 0000:bc:00.0, compute capability: 7.0


In [41]:
model.compile(loss='categorical_crossentropy', optimizer='adam', 
              metrics = ['mse'])

# count number of parameters in the model
numParams    = model.count_params()

In [42]:
numParams

1441

In [43]:
from tensorflow.keras.callbacks import ModelCheckpoint

checkpointer = ModelCheckpoint(filepath='/workspace/data/EEG/models/checkpoint.h5', verbose=1,
                               save_best_only=True)

In [44]:
class_weights = {0:1, 1:5}

################################################################################
# fit the model. Due to very small sample sizes this can get
# pretty noisy run-to-run, but most runs should be comparable to xDAWN + 
# Riemannian geometry classification (below)
################################################################################
fittedModel = model.fit(X_train, Y_train, batch_size = 256, epochs = 1000,
                        verbose = 2, callbacks=[checkpointer], validation_data=(X_validate, Y_validate), class_weight = class_weights)

Epoch 1/1000

Epoch 1: val_loss improved from inf to 0.00000, saving model to /workspace/data/EEG/models/checkpoint.h5
10/10 - 0s - loss: 0.0000e+00 - mse: 0.3845 - val_loss: 0.0000e+00 - val_mse: 0.3821 - 497ms/epoch - 50ms/step
Epoch 2/1000

Epoch 2: val_loss did not improve from 0.00000
10/10 - 0s - loss: 0.0000e+00 - mse: 0.3861 - val_loss: 0.0000e+00 - val_mse: 0.3833 - 68ms/epoch - 7ms/step
Epoch 3/1000

Epoch 3: val_loss did not improve from 0.00000
10/10 - 0s - loss: 0.0000e+00 - mse: 0.3865 - val_loss: 0.0000e+00 - val_mse: 0.3821 - 72ms/epoch - 7ms/step
Epoch 4/1000

Epoch 4: val_loss did not improve from 0.00000
10/10 - 0s - loss: 0.0000e+00 - mse: 0.3893 - val_loss: 0.0000e+00 - val_mse: 0.3798 - 64ms/epoch - 6ms/step
Epoch 5/1000

Epoch 5: val_loss did not improve from 0.00000
10/10 - 0s - loss: 0.0000e+00 - mse: 0.3881 - val_loss: 0.0000e+00 - val_mse: 0.3798 - 77ms/epoch - 8ms/step
Epoch 6/1000

Epoch 6: val_loss did not improve from 0.00000
10/10 - 0s - loss: 0.0000e+00

In [45]:
model.load_weights('/workspace/data/EEG/models/checkpoint.h5')

In [46]:
import numpy as np
probs       = model.predict(X_test)
preds       = probs.argmax(axis = -1)
acc         = np.mean(preds == Y_test.argmax(axis=-1))
print("Classification accuracy: %f " % (acc))

Classification accuracy: 0.000000 


In [47]:
preds

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [48]:
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, classification_report
print("Test Accuracy: ", accuracy_score(Y_test, preds))
print("Test f1 Score: ", f1_score(Y_test, preds))
print("Test Confusion matrix: ", confusion_matrix(Y_test, preds))
print(classification_report(Y_test, preds, labels= [0, 1]))

Test Accuracy:  0.8333333333333334
Test f1 Score:  0.0
Test Confusion matrix:  [[700   0]
 [140   0]]
              precision    recall  f1-score   support

           0       0.83      1.00      0.91       700
           1       0.00      0.00      0.00       140

    accuracy                           0.83       840
   macro avg       0.42      0.50      0.45       840
weighted avg       0.69      0.83      0.76       840



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [49]:
Y_test


array([0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,
       0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,