In [61]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.callbacks import EarlyStopping
from tensorflow.python.keras.layers import Dropout, Dense, LSTM
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split, KFold
from keras import backend as K
import pandas as pd
from tensorflow.python.keras.models import Sequential
import bnci_utils as utils

In [62]:
# All the datasets that can be run with this notebook
#   entire_dataset - all data
#   female_dataset - data from female subjects
#   male_dataset   - data from male subjects
datasets = {
    'entire_dataset': 'entire_dataset.npz',
    'female_dataset': 'dataset_female_gender.npz',
    'male_dataset': 'dataset_male_gender.npz'
}

# Key of the dataset to run. This is the only value that has to be edited to switch
# datasets: the input path and every output name below are derived from it, so they
# can no longer get out of sync with each other.
dataset_name = 'entire_dataset'

# Input file; the datasets are read from the dataset_result folder
dataset_path = os.path.join('dataset_result', datasets[dataset_name])

# Output folder for the statistics from the simulation
data_output_folder = f'{dataset_name}_output'
# File name of the excel file with the data from each iteration
iteration_data_file_name = f'lstm_{dataset_name}.xlsx'
# File name of the excel file with the statistics from the simulation
# (i.e. max and average accuracy, max and average recall...)
iteration_stats_file_name = f'lstm_{dataset_name}_stats.xlsx'

In [63]:
# Load the dataset from the configured path; the loader returns the features
# together with their labels
dataset = utils.load_dataset(dataset_path)
features, labels = dataset

f'Features shape: {features.shape}, labels shape: {labels.shape}'

'Features shape: (2976, 14, 36, 10), labels shape: (2976,)'

In [64]:
# Reshape the dataset for TensorFlow only
# Keep the first two axes and flatten everything behind them, which turns the
# (2976, 14, 36, 10) input into (2976, 14, 360). The size of the second axis is
# read from the array instead of being hard-coded, and the statement is a no-op
# when the cell is executed again on the already reshaped array.
features = features.reshape((features.shape[0], features.shape[1], -1))

# One-hot encode the labels only while they are still a flat vector. Without this
# guard a second run of the cell would flatten the (n, 2) one-hot matrix into 2n
# rows and encode those again, silently corrupting the labels.
if labels.ndim == 1:
    labels = OneHotEncoder().fit_transform(labels.reshape((-1, 1))).toarray()

f'features shape: {features.shape}, labels_shape: {labels.shape}'


'features shape: (2976, 14, 360), labels_shape: (2976, 2)'

In [65]:
# Fix the seed of both random number generators (TensorFlow and NumPy)
# so that repeated runs produce a consistent result
seed = 1
tf.random.set_seed(seed)
np.random.seed(seed)

In [66]:
# Function to create the LSTM model
def lstm_model(input_shape=(14, 360), num_classes=2):
    """Build the (uncompiled) stacked-LSTM classifier.

    The network consists of two LSTM layers with 124 units each (the first one
    returns the whole sequence so that the second one can consume it), a dense
    layer with 64 units and a softmax output layer named 'output_layer'.
    Dropout layers (0.4, 0.3 and 0.2) follow the two LSTM layers and the hidden
    dense layer.

    Args:
        input_shape: shape of a single sample, without the batch axis. The
            default (14, 360) is the shape the notebook reshapes the features to.
        num_classes: number of units of the softmax output layer. The default
            of 2 matches the width of the one-hot encoded labels.

    Returns:
        A new Sequential model; compiling and fitting are left to the caller.
    """
    # NOTE(review): relu replaces the LSTM default activation (tanh) in both
    # recurrent layers - confirm that this is intended.
    model = Sequential([
        LSTM(124, input_shape=input_shape, activation=tf.nn.relu, return_sequences=True),
        Dropout(0.4),
        LSTM(124, activation=tf.nn.relu),
        Dropout(0.3),
        Dense(64, activation=tf.nn.relu),
        Dropout(0.2),
        Dense(num_classes, activation=tf.nn.softmax, name='output_layer')
    ])

    return model

# Show the layer overview of a freshly built model with the default arguments
lstm_model().summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm (LSTM)                  (None, 14, 124)           240560    
_________________________________________________________________
dropout (Dropout)            (None, 14, 124)           0         
_________________________________________________________________
lstm_1 (LSTM)                (None, 124)               123504    
_________________________________________________________________
dropout_1 (Dropout)          (None, 124)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                8000      
_________________________________________________________________
dropout_2 (Dropout)          (None, 64)                0         
_________________________________________________________________
output_layer (Dense)         (None, 2)                 1

In [67]:
def run_network(model, train, valid, test, iteration, epochs=30):
    """Compile and train `model`, then compute its metrics on the test data.

    Args:
        model: Keras model to compile and fit.
        train: (features, labels) pair used for fitting.
        valid: (features, labels) pair passed to `fit` as validation data.
        test: (features, labels) pair handed to utils.get_metrics_keras.
        iteration: number of the current run; only used to build the label
            that is passed to utils.get_metrics_keras.
        epochs: maximum number of training epochs; early stopping can end the
            training sooner.

    Returns:
        dict with the keys 'accuracy', 'precision', 'recall', 'f1' and
        'confusion_matrix', holding the five values returned by
        utils.get_metrics_keras in that order.
    """
    train_x, train_y = train[0], train[1]
    valid_x, valid_y = valid[0], valid[1]
    test_x, test_y = test[0], test[1]

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.losses.BinaryCrossentropy(),
        metrics=['accuracy'],
    )

    # Stop once 8 epochs in a row bring no improvement and roll the model back
    # to the best weights seen during training
    early_stopping = EarlyStopping(patience=8, verbose=1, restore_best_weights=True)

    # Fit on the training data, validating on the validation data after each epoch
    model.fit(
        train_x,
        train_y,
        epochs=epochs,
        validation_data=(valid_x, valid_y),
        callbacks=[early_stopping],
    )

    # Evaluate the trained model on the test data
    run_label = f'{iteration}. LSTM '
    accuracy, precision, recall, f1, confusion_matrix = utils.get_metrics_keras(
        model, test_x, test_y, run_label
    )

    return dict(
        accuracy=accuracy,
        precision=precision,
        recall=recall,
        f1=f1,
        confusion_matrix=confusion_matrix,
    )

In [68]:
results = []
num_splits = 10

# Hold out 25 % of the data as a fixed test set. random_state pins the shuffle to
# the notebook seed, so every execution of this cell yields the same split instead
# of depending on how much of NumPy's global random state was consumed before.
x_train, x_test, y_train, y_test = train_test_split(
    features, labels, test_size=0.25, shuffle=True, random_state=seed
)

# Run 10-fold CV in the same manner as in the case of the CNN
folds = KFold(n_splits=num_splits).split(x_train)
for iteration, (train_idx, val_idx) in enumerate(folds, start=1):
    x_train_curr, y_train_curr = x_train[train_idx], y_train[train_idx] # get the current training data
    x_val, y_val = x_train[val_idx], y_train[val_idx] # get the current validation data

    model = lstm_model() # create a fresh LSTM model for this fold

    # Run the network and save the results
    result = run_network(model, (x_train_curr, y_train_curr), (x_val, y_val), (x_test, y_test), iteration)
    results.append(result)

    # Release the model before the next fold. The session is cleared through
    # tf.keras.backend so that the Keras state bundled with TensorFlow (which the
    # models are built from) is reset, rather than that of the standalone keras package.
    tf.keras.backend.clear_session()
    del model

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
1. LSTM : accuracy = 50.134408602150536%, precision = 0.5125553914327917, recall = 0.8943298969072165, f1 = 0.6516431924882629
Confusion matrix:
[[ 26 330]
 [ 41 347]]
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30
2. LSTM : accuracy = 48.924731182795696%, precision = 0.525, recall = 0.21649484536082475, f1 = 0.30656934306569344
Confusion matr

In [69]:
# Create pandas dataframe with stats from each iteration
metric_columns = ['accuracy', 'precision', 'recall', 'f1']
df = pd.DataFrame({
    # Number the rows after the results that were actually collected, so that a
    # partial run (fewer results than num_splits) still produces a valid frame
    # instead of failing on columns of different lengths.
    'iterations': list(range(1, len(results) + 1)),
    **{name: [run[name] for run in results] for name in metric_columns},
})

df

Unnamed: 0,iterations,accuracy,precision,recall,f1
0,1,0.501344,0.512555,0.89433,0.651643
1,2,0.489247,0.525,0.216495,0.306569
2,3,0.52957,0.534672,0.755155,0.626068
3,4,0.525538,0.547945,0.515464,0.531208
4,5,0.514785,0.518519,0.974227,0.676813
5,6,0.502688,0.571429,0.185567,0.280156
6,7,0.471774,0.491803,0.386598,0.4329
7,8,0.525538,0.527646,0.860825,0.654261
8,9,0.524194,0.532819,0.71134,0.609272
9,10,0.482527,0.526316,0.07732,0.134831


In [70]:
# Make sure the output folder exists before anything is written into it
os.makedirs(data_output_folder, exist_ok=True)

# Write the per-iteration dataframe to its excel file
iteration_data_path = os.path.join(data_output_folder, iteration_data_file_name)
df.to_excel(iteration_data_path)

In [71]:
# Build a one-row dataframe with the summary statistics. Every entry of the table
# below is (resulting column, source column of df, pandas reduction applied to it).
df_stats = pd.DataFrame({
    column: [getattr(df[metric], reduction)()]
    for column, metric, reduction in (
        ('average_accuracy', 'accuracy', 'mean'),
        ('max_accuracy', 'accuracy', 'max'),
        ('accuracy_std', 'accuracy', 'std'),
        ('average_precision', 'precision', 'mean'),
        ('max_precision', 'precision', 'max'),
        ('average_recall', 'recall', 'mean'),
        ('max_recall', 'recall', 'max'),
        ('average_f1', 'f1', 'mean'),
        ('max_f1', 'f1', 'max'),
    )
})

df_stats

Unnamed: 0,average_accuracy,max_accuracy,accuracy_std,average_precision,max_precision,average_recall,max_recall,average_f1,max_f1
0,0.50672,0.52957,0.020433,0.52887,0.571429,0.557732,0.974227,0.490372,0.676813


In [72]:
# Write the summary statistics to their excel file inside the output folder
stats_path = os.path.join(data_output_folder, iteration_stats_file_name)
df_stats.to_excel(stats_path)

In [73]:
# Print the confusion matrices of all iterations collected in `results`
# NOTE(review): the LSTM results are passed through the `ann` keyword - presumably
# the keyword only selects the heading printed by the helper; confirm against bnci_utils.
utils.print_confusion_matrices(ann=results)

Confusion matrices for the ANN:
[[ 26 330]
 [ 41 347]] 

[[280  76]
 [304  84]] 

[[101 255]
 [ 95 293]] 

[[191 165]
 [188 200]] 

[[  5 351]
 [ 10 378]] 

[[302  54]
 [316  72]] 

[[201 155]
 [238 150]] 

[[ 57 299]
 [ 54 334]] 

[[114 242]
 [112 276]] 

[[329  27]
 [358  30]] 

