# **Import Libraries**

In [671]:
import mne

# Utility
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import seaborn as sns
import os

%matplotlib inline
import warnings
warnings.filterwarnings('ignore')
from tqdm.auto import tqdm

from utilities import read_xdf, epoching
from models import CNNModel

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

from sklearn.metrics import confusion_matrix, classification_report

np.random.seed(42)

# **Load Data**

In [672]:
annotations_des = {
    '1': 'Left cue start',
    '2': 'Left stimuli start',
    '3': 'Left blank start',
    '4': 'Right cue start',
    '5': 'Right stimuli start',
    '6': 'Right blank start',
}

In [673]:
def get_epoch(filenames):
    # Define temporal variables
    X, F, t, y = [],[],[],[]
    for filename in filenames:
        ##--------Get First 5 second--------##
        raw = read_xdf(filename, show_plot=False, show_psd=False, verbose=False)
        epochs = epoching(raw, show_psd=False,
            show_eeg=False,  # set True if wanna see preview of all epoch
            tmax=5)         # tmax=5 means set epoch duration 5 second
        # Pick only event 2: Left stimuli, 5: Right stimuli
        epochs = epochs['2','5']
        X.append((  epochs.get_data())[:,:,:1250]) # select only first 1250 frames(approx 5 second)
        F.append(   epochs.compute_psd(method='welch', fmax=30).get_data())
        t.append((  epochs.times)[:1250])                # select only first 1250 frames(approx 5 second)
        y.append(   epochs.events[:, -1])
        ##--------Get Last 5 second--------##
        epochs = epoching(raw, show_psd=False,
            show_eeg=False,  # set True if wanna see preview of all epoch
            baseline=(5,10),
            tmin=5,
            tmax=10)         # tmax=5 means set epoch duration 5 second
        # Pick only event 2: Left stimuli, 5: Right stimuli
        epochs = epochs['2','5']
        X.append((  epochs.get_data())[:,:,:1250]) # select only first 1250 frames(approx 5 second)
        F.append(   epochs.compute_psd(method='welch', fmax=30).get_data())
        t.append((  epochs.times)[:1250])                # select only first 1250 frames(approx 5 second)
        y.append(   epochs.events[:, -1])
    # Concatenate all data
    X = np.concatenate(X)
    F = np.concatenate(F)
    t = np.concatenate(t)
    y = np.concatenate(y)
    return X, F, t, y

## **Load Left Data**

In [674]:
filenames = {'data/Pipo_6Hz_18_05.xdf'}

Xl, Fl, tl, yl = get_epoch(filenames)

Creating RawArray with float64 data, n_channels=8, n_times=51160
    Range : 0 ... 51159 =      0.000 ...   204.541 secs
Ready.
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1252 original time points ...
0 bad epochs dropped
Effective window size : 5.006 (s)
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1251 original time points ...
1 bad epochs dropped
Effective window size : 5.002 (s)


In [675]:
# Map label to 0
yl = np.where(yl==2, 0, yl)
yl = np.where(yl==5, 0, yl)
yl

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

## **Load Right Data**

In [676]:
filenames = {'data/Pipo_10Hz_18_05.xdf'}

Xr, Fr, tr, yr = get_epoch(filenames)

Creating RawArray with float64 data, n_channels=8, n_times=52650
    Range : 0 ... 52649 =      0.000 ...   210.499 secs
Ready.
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1252 original time points ...
0 bad epochs dropped
Effective window size : 5.006 (s)
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1251 original time points ...
0 bad epochs dropped
Effective window size : 5.002 (s)


In [677]:
# Map label to 0
yr = np.where(yr==2, 1, yr)
yr = np.where(yr==5, 1, yr)
yr

array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

## **Load Free-State Data**

In [678]:
filenames = {'data/Pipo_unknowsignal_18_05.xdf'}

Xfs, Ffs, tfs, yfs = get_epoch(filenames)

Creating RawArray with float64 data, n_channels=8, n_times=49780
    Range : 0 ... 49779 =      0.000 ...   199.048 secs
Ready.
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1251 original time points ...
1 bad epochs dropped
Effective window size : 5.002 (s)
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1252 original time points ...
1 bad epochs dropped
Effective window size : 5.006 (s)


In [679]:
# Map label of y to 2
yfs[yfs==5] = 2
yfs

array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [680]:
Xfs.shape, Ffs.shape, tfs.shape, yfs.shape

((32, 5, 1250), (32, 5, 151), (2500,), (32,))

## **Load Mixed Data**

In [681]:
filenames = {'data/Pipo_mix_18_05_1.xdf','data/Pipo_mix_18_05_2.xdf'}

Xm, Fm, tm, ym = get_epoch(filenames)

Creating RawArray with float64 data, n_channels=8, n_times=50270
    Range : 0 ... 50269 =      0.000 ...   201.005 secs
Ready.
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1251 original time points ...
1 bad epochs dropped
Effective window size : 5.002 (s)
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from preloaded Raw for 48 events and 1252 original time points ...
1 bad epochs dropped
Effective window size : 5.006 (s)
Creating RawArray with float64 data, n_channels=8, n_times=50660
    Range : 0 ... 50659 =      0.000 ...   202.544 secs
Ready.
Used Annotations descriptions: ['1', '2', '3', '4', '5', '6']
Not setting metadata
48 matching events found
Applying bas

In [682]:
# Map label of 2 to 0 and 5 to 1
ym[ym==2] = 0
ym[ym==5] = 1
ym

array([0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1,
       1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1,
       1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1])

## **Concatenate Data**

In [683]:
X = np.concatenate([Xl, Xr, Xfs, Xm])
F = np.concatenate([Fl, Fr, Ffs, Fm])
t = np.concatenate([tl, tr, tfs, tm])
y = np.concatenate([yl, yr, yfs, ym])

X.shape, F.shape, t.shape, y.shape

((160, 5, 1250), (160, 5, 151), (12500,), (160,))

# **Preprocessing**

In [684]:
enc = OneHotEncoder()
Y = enc.fit_transform(y.reshape(-1, 1)).toarray()

In [685]:
X.shape, F.shape, t.shape, Y.shape

((160, 5, 1250), (160, 5, 151), (12500,), (160, 3))

## **Training pipeline**

### Train Test Split

In [686]:
X_train_CNN, X_test_CNN, y_train_CNN, y_test_CNN = train_test_split(X, Y, test_size=0.2, random_state=1)

X_train_CNN = tf.convert_to_tensor(X_train_CNN, dtype=tf.float32)
X_test_CNN = tf.convert_to_tensor(X_test_CNN, dtype=tf.float32)
y_train_CNN = tf.convert_to_tensor(y_train_CNN, dtype=tf.int32)
y_test_CNN = tf.convert_to_tensor(y_test_CNN, dtype=tf.int32)

# **Model Pipeline**

## **Load Model**

In [687]:
CNN = CNNModel()
CNN.model.summary()

Model: "CNNModel"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_layer (InputLayer)    [(None, 5, 1250)]         0         
                                                                 
 conv1d_30 (Conv1D)          (None, 3, 128)            480128    
                                                                 
 batch_norm_layer1 (BatchNor  (None, 3, 128)           512       
 malization)                                                     
                                                                 
 maxpool_layer1 (MaxPooling1  (None, 1, 128)           0         
 D)                                                              
                                                                 
 dropout_layer1 (Dropout)    (None, 1, 128)            0         
                                                                 
 dense_layer1 (Dense)        (None, 1, 64)             825

## **Train Model**

In [688]:
CNN.model_train(X_train_CNN, y_train_CNN)
# CNN.load_weights('CNNN_Mod_mix.h5')

## **Evaluate the model**

In [690]:
prop, predictions = CNN.model_predict_classes(X_test_CNN)
predictions



array([0, 1, 0, 2, 2, 0, 1, 0, 1, 1, 1, 2, 2, 0, 1, 2, 2, 1, 0, 0, 0, 1,
       1, 0, 1, 0, 0, 0, 1, 2, 0, 0], dtype=int64)

In [691]:
# # Classification report
print(classification_report(np.argmax(y_test_CNN, axis=1), predictions))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00        14
           1       0.82      1.00      0.90         9
           2       1.00      0.78      0.88         9

    accuracy                           0.94        32
   macro avg       0.94      0.93      0.92        32
weighted avg       0.95      0.94      0.94        32



## **Save Model**

In [692]:
# CNN.model.save('CNNN_Mod_mix.h5')