## **Import Libraries**

In [173]:
import mne

# Utility
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import seaborn as sns
import os
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
from tqdm.auto import tqdm

from utilities import read_xdf, epoching

np.random.seed(42)

In [174]:
# Train model
from collections import OrderedDict
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from time import time
from datetime import timedelta
import keras
from keras.models import load_model
import os
from keras.utils import plot_model

# Build model
from keras.models import Sequential, Model
from keras.layers import Embedding, Reshape, Activation, Input, Dense,GRU,Reshape,TimeDistributed,Bidirectional,Dropout,Masking,LSTM, GlobalAveragePooling1D, Conv1D, MaxPooling1D, Flatten,GlobalMaxPooling1D
from keras.layers import Concatenate, Lambda, Reshape, RepeatVector,Multiply, BatchNormalization
from keras.optimizers import Adam
from keras import backend as K                                                          
from keras.callbacks import ModelCheckpoint, TensorBoard
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from pyriemann.estimation import Covariances, ERPCovariances, XdawnCovariances
from pyriemann.spatialfilters import CSP
from pyriemann.tangentspace import TangentSpace
from pyriemann.classification import MDM

# Classifications report
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay ,multilabel_confusion_matrix, accuracy_score
from sklearn.model_selection import cross_validate, cross_val_score, StratifiedShuffleSplit

## **Load Data**

### From XDF file

In [175]:
annotations_des = {
    '1': 'Left cue start',
    '2': 'Left stimuli start',
    '3': 'Left blank start',
    '4': 'Right cue start',
    '5': 'Right stimuli start',
    '6': 'Right blank start',
}

In [176]:
raw = read_xdf('Pipo_1_5_test1.xdf', show_plot=False, show_psd=False, bandpass=(4.0, 12.0))

Creating RawArray with float64 data, n_channels=8, n_times=49650
    Range : 0 ... 49649 =      0.000 ...   198.521 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 12 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 12.00 Hz
- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)
- Filter length: 413 samples (1.651 s)



[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished


In [177]:
from utilities import read_xdf, epoching

filenames = {'Pipo_1_5_test1.xdf', 'Pipo_1_5_test2.xdf', 'Pipo_1_5_test3.xdf'}

X, F, t, y = [],[],[],[]
for filename in filenames:

    raw = read_xdf(filename, show_plot=False, show_psd=False, bandpass=(4.0, 12.0))

    epochs = epoching(raw, show_psd=False,
        show_eeg=False,  # set True if wanna see preview of all epoch
        tmax=5)         # tmax=5 means set epoch duration 5 second

    # Pick only event 2: Left stimuli, 5: Right stimuli
    epochs = epochs['2','5']

    X.append((  epochs.get_data() * 1e6)[:,:,:1250]) # select only first 1250 frames(approx 5 second)
    F.append(   epochs.compute_psd(method='welch', fmax=30).get_data())
    t.append((  epochs.times)[:1250])                # select only first 1250 frames(approx 5 second)
    y.append(   epochs.events[:, -1])

    epochs = epoching(raw, show_psd=False,
        show_eeg=False,  # set True if wanna see preview of all epoch
        baseline=(5,10),
        tmin=5,
        tmax=10)         # tmax=5 means set epoch duration 5 second

    # Pick only event 2: Left stimuli, 5: Right stimuli
    epochs = epochs['2','5']

    X.append((  epochs.get_data() * 1e6)[:,:,:1250]) # select only first 1250 frames(approx 5 second)
    F.append(   epochs.compute_psd(method='welch', fmax=30).get_data())
    t.append((  epochs.times)[:1250])                # select only first 1250 frames(approx 5 second)
    y.append(   epochs.events[:, -1])

# Concatenate all data
X = np.concatenate(X)
F = np.concatenate(F)
t = np.concatenate(t)
y = np.concatenate(y)

Creating RawArray with float64 data, n_channels=8, n_times=52020
    Range : 0 ... 52019 =      0.000 ...   208.006 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 12 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 12.00 Hz
- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)
- Filter length: 413 samples (1.651 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1251 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished


Creating RawArray with float64 data, n_channels=8, n_times=49650
    Range : 0 ... 49649 =      0.000 ...   198.521 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 12 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 12.00 Hz
- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)
- Filter length: 413 samples (1.651 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1251 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished


Creating RawArray with float64 data, n_channels=8, n_times=50520
    Range : 0 ... 50519 =      0.000 ...   201.994 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 4 - 12 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 4.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 3.00 Hz)
- Upper passband edge: 12.00 Hz
- Upper transition bandwidth: 3.00 Hz (-6 dB cutoff frequency: 13.50 Hz)
- Filter length: 413 samples (1.651 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1252 

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   4 out of   4 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done   8 out of   8 | elapsed:    0.0s finished


In [178]:
X.shape, F.shape, t.shape, y.shape

((96, 5, 1250), (96, 5, 151), (7500,), (96,))

In [197]:
X

array([[[        0.        ,    124091.2310227 ,    323308.34791588,
         ...,  -4975084.78529112,  -3059847.69116265,
          -1047095.02941039],
        [        0.        ,    418609.94358085,    940088.79752828,
         ..., -10072812.18093242,  -8047836.52447013,
          -5561649.86459834],
        [        0.        ,    426263.43128802,    866387.8060909 ,
         ..., -10919335.50152572,  -8753718.34647588,
          -6169341.03457228],
        [        0.        ,    497176.15056236,   1008507.15964988,
         ...,  -1775864.65169888,   -310554.63368498,
           1083292.14041824],
        [        0.        ,    302119.78783984,    744102.79333782,
         ...,   8616879.07350387,   8855043.33130038,
           8761922.52372198]],

       [[        0.        ,   1292517.61559799,   2582534.37223891,
         ...,  -1652824.92502735,   -914425.33198587,
             98160.95319505],
        [        0.        ,   1390323.51531964,   2905248.66812871,
         ..

In [179]:
from mne.decoding import Scaler
scaler = Scaler(scalings='mean')
X_norm = scaler.fit_transform(X)

In [195]:
X_norm.shape

(96, 5, 1250)

In [196]:
X_norm

array([[[ 0.05485012,  0.06789522,  0.08883792, ..., -0.46815572,
         -0.2668164 , -0.05522575],
        [ 0.01377982,  0.0571767 ,  0.1112379 , ..., -1.03045872,
         -0.82053148, -0.56279095],
        [-0.0023031 ,  0.04502083,  0.09388361, ..., -1.21457198,
         -0.97414432, -0.68722575],
        [-0.03532738,  0.033687  ,  0.10466624, ..., -0.28183998,
         -0.07843631,  0.11504735],
        [-0.10100238, -0.05823762,  0.00432465, ...,  1.11870855,
          1.15242047,  1.13923931]],

       [[ 0.05485012,  0.19072605,  0.32633908, ..., -0.11890312,
         -0.04127885,  0.06516929],
        [ 0.01377982,  0.15791329,  0.3149641 , ...,  0.03522708,
          0.19376038,  0.38262349],
        [-0.0023031 ,  0.11901569,  0.25701423, ..., -0.0052507 ,
          0.06838231,  0.18770481],
        [-0.03532738,  0.09971957,  0.23093417, ..., -0.22590581,
         -0.26836877, -0.26896736],
        [-0.10100238, -0.05557173, -0.03462309, ..., -0.992695  ,
         -1.11

In [180]:
# One hot encoding
from sklearn.preprocessing import OneHotEncoder

# enc = OneHotEncoder()
# Y = enc.fit_transform(y.reshape(-1, 1)).toarray()
Y = y.copy()
mp = {2:0, 5:1}
for i in range(len(Y)):
    Y[i] = mp[Y[i]]

In [181]:
X_norm.shape, F.shape, t.shape, Y.shape

((96, 5, 1250), (96, 5, 151), (7500,), (96,))

## **Training pipeline**

### Train Test Split

In [182]:
from sklearn.model_selection import train_test_split

In [183]:
X_train_CNN, X_test_CNN, y_train_CNN, y_test_CNN = train_test_split(X_norm, Y, test_size=0.2, random_state=42)

X_train_CNN = tf.convert_to_tensor(X_train_CNN, dtype=tf.float32)
X_test_CNN = tf.convert_to_tensor(X_test_CNN, dtype=tf.float32)
y_train_CNN = tf.convert_to_tensor(y_train_CNN, dtype=tf.int32)
y_test_CNN = tf.convert_to_tensor(y_test_CNN, dtype=tf.int32)

### **Create a model**

In [184]:
from models import CNNModel

In [185]:
CNN = CNNModel()
CNN.model.summary()

Model: "CNNModel"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 5, 1250)]         0         
                                                                 
 conv1d_6 (Conv1D)           (None, 3, 128)            480128    
                                                                 
 batch_norm_layer1 (BatchNor  (None, 3, 128)           512       
 malization)                                                     
                                                                 
 maxpool_layer1 (MaxPooling1  (None, 1, 128)           0         
 D)                                                              
                                                                 
 dropout_layer1 (Dropout)    (None, 1, 128)            0         
                                                                 
 dense_layer1 (Dense)        (None, 1, 64)             825

In [186]:
CNN.model_train(X_train_CNN, y_train_CNN)
# CNN.load_weights('CNN.h5')

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [187]:
out = CNN.model_predict_classes(X_test_CNN)



In [188]:
out

[1, 0, 1, 0, -1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, -1, 0, 1, 1]

### **Evaluate the model**

In [189]:
predictions = CNN.model_predict(X_test_CNN)
predictions



array([[0.99522346],
       [0.00081641],
       [0.994997  ],
       [0.00021465],
       [0.6304759 ],
       [0.993386  ],
       [0.9630043 ],
       [0.98127407],
       [0.06330821],
       [0.05179942],
       [0.98874986],
       [0.9987779 ],
       [0.00293645],
       [0.00234299],
       [0.04159449],
       [0.17102401],
       [0.30824372],
       [0.01666183],
       [0.9965185 ],
       [0.9837855 ]], dtype=float32)

In [190]:
X_test_CNN

<tf.Tensor: shape=(20, 5, 1250), dtype=float32, numpy=
array([[[-1.3714234 , -1.4649614 , -1.4427805 , ...,  0.05267079,
         -0.08979908, -0.22919995],
        [-1.516764  , -1.5958303 , -1.5548593 , ..., -0.39493155,
         -0.5264928 , -0.6288733 ],
        [-1.0158865 , -1.0497293 , -1.0036805 , ..., -0.44601402,
         -0.54614794, -0.6139918 ],
        [-0.70208275, -0.7482932 , -0.73115104, ...,  0.1774256 ,
          0.16472016,  0.14220922],
        [-0.00002552, -0.08569576, -0.16352619, ...,  0.47738764,
          0.5504948 ,  0.601348  ]],

       [[ 0.05485012, -0.23348598, -0.5368935 , ...,  0.8320281 ,
          0.93650776,  0.9261473 ],
        [ 0.01377982, -0.3095312 , -0.63174075, ...,  0.18180186,
          0.31255105,  0.36465117],
        [-0.0023031 , -0.32386366, -0.6467315 , ..., -0.18981056,
         -0.07879283, -0.03477502],
        [-0.03532738, -0.14315726, -0.23711865, ...,  0.47978175,
          0.5569663 ,  0.56853944],
        [-0.10100238,  0.

In [191]:
np.set_printoptions(suppress=True)
print(predictions*100)

[[99.52235   ]
 [ 0.08164063]
 [99.4997    ]
 [ 0.02146477]
 [63.04759   ]
 [99.3386    ]
 [96.30043   ]
 [98.1274    ]
 [ 6.330821  ]
 [ 5.1799417 ]
 [98.874985  ]
 [99.87779   ]
 [ 0.29364452]
 [ 0.23429908]
 [ 4.1594496 ]
 [17.102402  ]
 [30.824371  ]
 [ 1.666183  ]
 [99.65185   ]
 [98.37855   ]]


In [192]:
# threshold = 0.5
predictions = np.where(predictions > 0.5, 1, 0)

In [193]:
# # Classification report
print(classification_report(y_test_CNN,predictions))

              precision    recall  f1-score   support

           0       1.00      0.91      0.95        11
           1       0.90      1.00      0.95         9

    accuracy                           0.95        20
   macro avg       0.95      0.95      0.95        20
weighted avg       0.96      0.95      0.95        20



In [194]:
CNN.model.save('CNNNorm.h5')