In [39]:
import os
import numpy as np
import tensorflow as tf
import keras
import nengo_dl
from tensorflow.python.keras import Input, Model
import nengo
from tensorflow.python.keras.callbacks import EarlyStopping
from tensorflow.python.keras.layers import Conv2D, Dropout, AveragePooling2D, Flatten, Dense, BatchNormalization, \
    Conv3D, MaxPooling2D, LSTM
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split, KFold
from keras import backend as K
import pandas as pd
from sklearn import metrics
from tensorflow.python.keras.models import Sequential

import bnci_utils as utils

In [40]:
# All the datasets that can be run with this notebook
#   entire_dataset - all data
#   female_dataset - data from female subjects
#   male_dataset   - data from male subjects
datasets = {
    'entire_dataset': 'entire_dataset.npz',
    'female_dataset': 'dataset_female_gender.npz',
    'male_dataset': 'dataset_male_gender.npz'
}

# Single switch for the whole run: must be one of the keys of `datasets`.
# The dataset path, the output folder and both Excel file names below are all
# derived from it, so selecting another dataset means editing only this line.
dataset_name = 'entire_dataset'

# Dataset files are read from the dataset_result/ folder
dataset_path = os.path.join('dataset_result', datasets[dataset_name])

# Output path for statistics from the simulation
data_output_folder = f'{dataset_name}_output'

# NOTE(review): the 'cnn_' prefix is kept unchanged for backward compatibility,
# although this notebook trains an LSTM - confirm these names do not collide
# with the files written by the CNN notebook.
# File name of the Excel file with data from each iteration
iteration_data_file_name = f'cnn_10_fold_{dataset_name}.xlsx'
# File name for statistics from the simulation (i.e. max and average accuracy,
# max and average recall...)
iteration_stats_file_name = f'cnn_10_fold_{dataset_name}_stats.xlsx'

'Features shape: (2976, 14, 36, 10), labels shape: (2976,)'

In [None]:
# Get features and labels
# `utils.load_dataset` is a helper from the project module `bnci_utils`; its
# result is unpacked into exactly two values here.
# NOTE(review): both are presumably NumPy arrays, since `.shape`/`.reshape` are
# used on them - confirm against bnci_utils.
features, labels = utils.load_dataset(dataset_path)

# Last expression of the cell: shows both shapes as a quick sanity check.
f'Features shape: {features.shape}, labels shape: {labels.shape}'

In [41]:
# Reshape the dataset for TensorFlow only: keep the sample axis, use 14 as the
# second axis and flatten everything else into a single feature axis.
# Re-running this line on already reshaped features is a no-op.
features = features.reshape((features.shape[0], 14, -1))

# One-hot encode the labels only while they are still a flat label vector (or a
# single label column). Without this guard, re-running the cell would flatten
# the already encoded matrix into one long column and encode it again, silently
# changing the number of label rows so they no longer line up with `features`.
if labels.ndim == 1 or labels.shape[1] == 1:
    labels = labels.reshape((-1, 1))
    labels = OneHotEncoder().fit_transform(labels).toarray()

# Last expression of the cell: shows the shapes after the transformation.
f'features shape: {features.shape}, labels_shape: {labels.shape}'

'features shape: (2976, 14, 360), labels_shape: (2976, 2)'

In [42]:
# set seed to produce a consistent result
# The same value is handed to NumPy's global seeding function and to
# TensorFlow's global seeding function. Python's built-in `random` module is
# not seeded in this cell.
seed = 0
np.random.seed(seed)
tf.random.set_seed(seed)

In [43]:
# Function to create the LSTM model
def lstm_model(input_shape=(14, 360), num_classes=2):
    """Build the (uncompiled) stacked-LSTM classifier.

    The network is two LSTM layers with 124 units each (the first one returns
    the full sequence so the second can consume it), each followed by dropout
    and batch normalization, then a 64-unit dense layer with dropout and a
    softmax output layer named ``'output_layer'``.

    Args:
        input_shape: Shape of one sample, without the batch axis, passed to the
            first LSTM layer. Defaults to ``(14, 360)``, the shape produced by
            this notebook's reshape step.
        num_classes: Number of units in the softmax output layer. Defaults to
            ``2``.

    Returns:
        A ``Sequential`` model that has not been compiled yet.
    """
    model = Sequential([
        LSTM(124, input_shape=input_shape, activation=tf.nn.relu, return_sequences=True),
        Dropout(0.4),
        BatchNormalization(),
        LSTM(124, activation=tf.nn.relu),
        Dropout(0.3),
        BatchNormalization(),
        Dense(64, activation=tf.nn.relu),
        Dropout(0.2),
        Dense(num_classes, activation=tf.nn.softmax, name='output_layer')
    ])

    return model

In [44]:
def run_network(model, train, valid, test, iteration, epochs=30):
    """Compile, train and evaluate ``model`` for one cross-validation fold.

    Args:
        model: Keras model to train. It is compiled inside this function with
            the Adam optimizer and categorical cross-entropy.
        train: Indexable pair ``(x, y)`` with the training data.
        valid: Indexable pair ``(x, y)`` used as validation data during
            fitting (and therefore by early stopping).
        test: Indexable pair ``(x, y)`` the final metrics are computed on.
        iteration: Fold number; it is only used to build the label
            ``'<iteration>. LSTM '`` passed to ``utils.get_metrics_keras``.
        epochs: Maximum number of training epochs. Training may end earlier
            because of early stopping (patience of 8 epochs, best weights are
            restored).

    Returns:
        dict with the keys ``'accuracy'``, ``'precision'``, ``'recall'``,
        ``'f1'`` and ``'confusion_matrix'``, holding the five values returned
        by ``utils.get_metrics_keras`` in that order.
    """
    x_train, y_train = train[0], train[1]
    x_val, y_val = valid[0], valid[1]
    x_test, y_test = test[0], test[1]

    # The targets are one-hot rows and the output layer is a softmax, which is
    # the setup categorical cross-entropy is meant for. For exactly two classes
    # it yields the same value as the previously used binary cross-entropy, but
    # unlike that one it stays correct for more than two classes.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.losses.CategoricalCrossentropy(),
        metrics=['accuracy']
    )

    # Train the model and validate on the validation data
    model.fit(x_train, y_train, epochs=epochs, validation_data=(x_val, y_val),
              callbacks=[EarlyStopping(patience=8, verbose=1, restore_best_weights=True)]
              )

    # Get the statistics on the held-out test data
    accuracy, precision, recall, f1, confusion_matrix = utils.get_metrics_keras(model, x_test, y_test, f'{iteration}. LSTM ')

    return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': confusion_matrix
        }


In [45]:
results = []
num_splits = 10
iteration = 1

# Run 10-fold CV in the same manner as in the case of the CNN:
# 25 % of the data is held out once as a fixed test set, and the remaining 75 %
# is split into `num_splits` folds; each fold serves once as validation data.
# `random_state=seed` pins the shuffle, so re-running this cell on its own
# reproduces the same hold-out split instead of depending on how much of the
# global NumPy random stream has already been consumed.
x_train, x_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.25, shuffle=True, random_state=seed
)
for train_idx, val_idx in KFold(n_splits=num_splits).split(x_train):
    x_train_curr, y_train_curr = x_train[train_idx], y_train[train_idx] # get the current training data
    x_val, y_val = x_train[val_idx], y_train[val_idx] # get the current validation data

    model = lstm_model() # create a fresh LSTM model for every fold

    # Run the network and save the results
    result = run_network(model, (x_train_curr, y_train_curr), (x_val, y_val), (x_test, y_test), iteration)
    results.append(result)

    # Delete the model
    # NOTE(review): `K` is the backend of the standalone `keras` package, while
    # the model is built from `tensorflow.python.keras` - confirm that this
    # call clears the state of the Keras implementation actually in use.
    K.clear_session()
    del model

    iteration += 1

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Restoring model weights from the end of the best epoch.
Epoch 00017: early stopping
1. LSTM : accuracy = 49.73118279569893%, precision = 0.4875, recall = 0.6464088397790055, f1 = 0.5558194774346793
Confusion matrix:
[[136 246]
 [128 234]]
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Restoring model weights from the end of the best epoch.
Epoch 00013: early stopping
2. LSTM : accuracy = 48.79032258064516%, precision = 0.47845804988662133, recall = 0.5828729281767956, f1 = 0.5255292652552928
Confusion matrix:
[[152 230]
 [151 211]]
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Restoring model weights from the end of the 

In [46]:
# Create pandas dataframe with stats from each iteration.
# The iteration numbers are derived from the number of collected results rather
# than from `num_splits`, so the frame can still be built (with fewer rows)
# when the cross-validation run was interrupted before all folds finished.
# The confusion matrices stay in `results` and are not copied into the frame.
df = pd.DataFrame({
    'iterations': list(range(1, len(results) + 1)),
    'accuracy': [x['accuracy'] for x in results],
    'precision': [x['precision'] for x in results],
    'recall': [x['recall'] for x in results],
    'f1': [x['f1'] for x in results],
})

df

Unnamed: 0,iterations,accuracy,precision,recall,f1
0,1,0.497312,0.4875,0.646409,0.555819
1,2,0.487903,0.478458,0.582873,0.525529
2,3,0.521505,0.511905,0.356354,0.420195
3,4,0.517473,0.50655,0.320442,0.392555
4,5,0.482527,0.48,0.762431,0.589114
5,6,0.467742,0.433594,0.30663,0.359223
6,7,0.46371,0.419214,0.265193,0.324873
7,8,0.471774,0.476692,0.875691,0.617332
8,9,0.477151,0.474576,0.696133,0.56439
9,10,0.485215,0.469027,0.439227,0.453638


In [47]:
# TODO(review): unfinished cell - as written this creates a *set* containing a
# single empty string, not a dict of statistics. The intended content is not
# visible here.
df_stats = {
    ''
}