In [1]:
import os
import numpy as np
import tensorflow as tf
import keras
import nengo_dl
import random
from tensorflow.python.keras import Input, Model, Sequential
import nengo
from tensorflow.python.keras.layers import Conv2D, Dropout, AveragePooling2D, Flatten, Dense, BatchNormalization, \
    Conv3D, LSTM
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split, KFold
from keras import backend as K

In [2]:
dataset_file = os.path.join('dataset_result', 'bci_dataset.npz')

# np.load on an .npz returns an NpzFile that holds the archive open; indexing it
# reads each array into memory, so the handle can be closed right afterwards.
with np.load(dataset_file) as dataset:
    features, labels = dataset['features'], dataset['labels']

# Reshape every sample to 14 rows; -1 lets NumPy infer the remaining axis.
# NOTE(review): 14 presumably is the number of EEG channels — confirm.
features = features.reshape((features.shape[0], 14, -1))

# OneHotEncoder expects a 2-D column of labels; .toarray() densifies its sparse result.
labels = labels.reshape((-1, 1))
labels = OneHotEncoder().fit_transform(labels).toarray()
assert len(features) == len(labels), 'features and labels must have one row per sample'

# Seed the Python, NumPy and TensorFlow RNGs so weight initialisation and dropout
# are repeatable across runs (GPU kernels may still be nondeterministic).
seed = 2
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

In [3]:
def lstm_model(input_shape=None, num_classes=2):
    """Build the (uncompiled) stacked two-layer LSTM classifier.

    Args:
        input_shape: per-sample input shape ``(timesteps, features_per_step)``
            for the first LSTM layer. When ``None`` it falls back to
            ``features.shape[1:]`` taken from the global ``features`` array
            loaded earlier in the notebook.
        num_classes: number of units in the softmax output layer.

    Returns:
        A ``Sequential`` model; it must be compiled before training.
    """
    if input_shape is None:
        input_shape = features.shape[1:]

    # NOTE: a non-default LSTM activation ('relu') means TensorFlow cannot use
    # the fused cuDNN kernel, so training on GPU is noticeably slower.
    return Sequential([
        # First LSTM returns the full sequence so the second LSTM can consume it.
        LSTM(360, input_shape=input_shape, activation='relu', return_sequences=True),
        Dropout(0.2),
        BatchNormalization(),
        # Second LSTM returns only its final hidden state.
        LSTM(360, activation='relu'),
        Dropout(0.2),
        BatchNormalization(),
        Dense(64, activation='relu'),
        Dropout(0.2),
        Dense(num_classes, activation='softmax', name='output_layer')
    ])

In [4]:
def run_network(model, x_train, y_train, x_test, y_test, iteration, epochs=30, verbose=1):
    """Compile, train and evaluate ``model`` on a single train/test split.

    Args:
        model: an uncompiled Keras model ending in a softmax layer.
        x_train, y_train: training inputs and one-hot targets.
        x_test, y_test: held-out inputs and one-hot targets.
        iteration: fold number; used only in the printed message.
        epochs: number of training epochs.
        verbose: Keras verbosity for ``fit``/``evaluate`` (e.g. 2 = one line per epoch).

    Returns:
        Test-set accuracy as a fraction in [0, 1].
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        # The targets are one-hot and the output is a softmax, so categorical
        # cross-entropy is the matching loss (equal to the binary one for two
        # classes); 'accuracy' then resolves to categorical accuracy.
        loss=tf.losses.CategoricalCrossentropy(),
        metrics=['accuracy']
    )

    model.fit(x_train, y_train, epochs=epochs, verbose=verbose)
    # With exactly one compiled metric, evaluate returns [loss, accuracy].
    scores = model.evaluate(x_test, y_test, verbose=verbose)
    accuracy = scores[1]
    print('{}. Accuracy: '.format(iteration), accuracy * 100)
    return accuracy

In [5]:
results = []

num_splits = 10
num_epochs = 60

# NOTE(review): KFold is not shuffled, so every fold is a contiguous slice of the
# data. If samples are ordered by class or recording session the folds can be
# unbalanced (fold 1 scored ~49%). Shuffling is deliberately not enabled here:
# overlapping windows from the same recording could then leak between train and
# test. Confirm the sample ordering before changing the split strategy.
folds = KFold(n_splits=num_splits).split(features)
for iteration, (train_idx, test_idx) in enumerate(folds, start=1):
    x_train, y_train = features[train_idx], labels[train_idx]
    x_test, y_test = features[test_idx], labels[test_idx]

    model = lstm_model()
    results.append(run_network(model, x_train, y_train, x_test, y_test, iteration, num_epochs))
    # Reset Keras global state before the next fold builds a fresh model.
    # NOTE(review): K comes from the standalone `keras` package while the layers
    # come from `tensorflow.python.keras`; verify this actually clears their state.
    K.clear_session()

print('RESULTS:')
for result in results:
    print(result)
print('Mean accuracy: {:.2f}% (std {:.2f}%)'.format(np.mean(results) * 100, np.std(results) * 100))


Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Epoch 23/60
Epoch 24/60
Epoch 25/60
Epoch 26/60
Epoch 27/60
Epoch 28/60
Epoch 29/60
Epoch 30/60
Epoch 31/60
Epoch 32/60
Epoch 33/60
Epoch 34/60
Epoch 35/60
Epoch 36/60
Epoch 37/60
Epoch 38/60
Epoch 39/60
Epoch 40/60
Epoch 41/60
Epoch 42/60
Epoch 43/60
Epoch 44/60
Epoch 45/60
Epoch 46/60
Epoch 47/60
Epoch 48/60
Epoch 49/60
Epoch 50/60
Epoch 51/60
Epoch 52/60
Epoch 53/60
Epoch 54/60
Epoch 55/60
Epoch 56/60
Epoch 57/60
Epoch 58/60
Epoch 59/60
Epoch 60/60
1. Accuracy:  48.99328947067261
Epoch 1/60
Epoch 2/60
Epoch 3/60
Epoch 4/60
Epoch 5/60
Epoch 6/60
Epoch 7/60
Epoch 8/60
Epoch 9/60
Epoch 10/60
Epoch 11/60
Epoch 12/60
Epoch 13/60
Epoch 14/60
Epoch 15/60
Epoch 16/60
Epoch 17/60
Epoch 18/60
Epoch 19/60
Epoch 20/60
Epoch 21/60
Epoch 22/60
Ep