In [None]:
#imports
import torch
import numpy as np
import os
import pprint
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Bidirectional, Flatten, GRU, Conv1D, MaxPooling1D
import matplotlib.pyplot as plt
from tensorflow.keras.utils import plot_model


In [None]:
#load dataset
# The train/test EEG datasets were prepared outside this notebook and are loaded as-is.
# DATA_DIR replaces the former name `dir`, which shadowed Python's built-in dir().
DATA_DIR = 'path to main directory'
# NOTE(review): torch.load unpickles the file, so only load dataset files you trust.
data_train = torch.load(os.path.join(DATA_DIR, 'eeg train dataset file'))
data_test = torch.load(os.path.join(DATA_DIR, 'eeg test dataset file'))
# Each loaded object is indexed by 'dataset'; that entry is iterated below as a
# sequence of per-sample dicts (with 'eeg' and 'label' keys).
ds_train = data_train['dataset']
ds_test = data_test['dataset']
# Show one raw sample so its structure can be inspected.
pprint.pprint(ds_train[0])

In [None]:
#preprocessing steps
# Number of columns (axis 1) kept per EEG array. Shorter recordings are dropped
# and longer ones are cut. This must match the second entry of the models'
# input_shape=(128, 480).
MAX_COLS = 480

X_train = []
X_test = []

def preprocess_X(X, ds) -> None:
    """Append each sample's 'eeg' tensor, converted to a NumPy array, to X.

    Args:
        X: list that is extended in place.
        ds: iterable of sample dicts, each holding an 'eeg' torch tensor.
    """
    for sample in ds:
        X.append(sample['eeg'].numpy())

preprocess_X(X_train, ds_train)
preprocess_X(X_test, ds_test)

# NOTE(review): assumes each 'label' is an integer class id in [0, 40), as
# required by sparse_categorical_crossentropy + Dense(40) in the model cells.
Y_train = np.array([sample['label'] for sample in ds_train])
Y_test = np.array([sample['label'] for sample in ds_test])

#function to trim eeg tensors to a common shape
def trim(X, max_cols, Y=None):
    """Trim 2-D arrays to a common number of columns and stack them.

    Arrays with fewer than `max_cols` columns are dropped. The remaining arrays
    are cut to their first `max_cols` columns.

    Args:
        X: sequence of 2-D arrays.
        max_cols: number of columns to keep from each array.
        Y: optional labels aligned index-for-index with X. When given, the
           same samples are dropped from Y, so X and Y stay aligned.

    Returns:
        The stacked trimmed array. When Y is given, returns the tuple
        (X_trimmed, Y_trimmed) instead.

    Raises:
        ValueError: if Y is given and its length differs from len(X).
    """
    keep = np.array([arr.shape[1] >= max_cols for arr in X], dtype=bool)
    X_trimmed = np.array([arr[:, :max_cols] for arr, k in zip(X, keep) if k])
    if Y is None:
        return X_trimmed
    Y = np.asarray(Y)
    if len(Y) != len(keep):
        raise ValueError(f"X has {len(keep)} samples but Y has {len(Y)}")
    return X_trimmed, Y[keep]

# Trim samples and labels together. Filtering X alone would shift every label
# after a dropped sample onto the wrong recording.
X_train, Y_train = trim(X_train, MAX_COLS, Y_train)
X_test, Y_test = trim(X_test, MAX_COLS, Y_test)

The next five cells define the encoder architectures to be compared:
- LSTM
- GRU
- 1D CNN
- Bi-GRU (bidirectional GRU)
- Bi-LSTM (bidirectional LSTM)

In [None]:
#LSTM encoder
# Three stacked LSTM layers return their full output sequences. These are
# flattened into a small dense head that ends in a 40-way softmax.
# NOTE(review): input_shape=(128, 480) makes axis 0 of each sample the recurrent
# (time-step) axis. trim() keeps the first 480 entries of axis 1, so if axis 0
# holds EEG channels, the LSTM steps over channels rather than time. Confirm.
lstm = Sequential([
    LSTM(128, return_sequences=True, input_shape=(128, 480)),
    LSTM(128, return_sequences=True),
    LSTM(64, return_sequences=True),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(40, activation='softmax'),
])

# Integer class labels -> sparse categorical cross-entropy.
lstm.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

lstm.summary()


In [None]:
#GRU encoder
# Same layout as the LSTM encoder, with GRU cells: three stacked GRUs that keep
# full sequences, then Flatten -> Dense(64) -> 40-way softmax.
# NOTE(review): axis 0 of input_shape=(128, 480) is treated as the time axis. Confirm
# this matches the data layout produced by trim() (which cuts axis 1 to 480).
gru = Sequential([
    GRU(128, return_sequences=True, input_shape=(128, 480)),
    GRU(128, return_sequences=True),
    GRU(64, return_sequences=True),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(40, activation='softmax'),
])

# Integer class labels -> sparse categorical cross-entropy.
gru.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

gru.summary()

In [None]:
#1D CNN encoder
# Three Conv1D (kernel 3, ReLU) + MaxPooling1D (pool 2) stages, each halving the
# step axis, then Flatten -> Dense(64) -> 40-way softmax.
# NOTE(review): Conv1D convolves along axis 0 of input_shape=(128, 480) and treats
# the 480 entries of axis 1 as input channels. Confirm this is the intended orientation.
cnn = Sequential([
    Conv1D(filters=128, kernel_size=3, activation='relu', input_shape=(128, 480)),
    MaxPooling1D(pool_size=2),
    Conv1D(filters=128, kernel_size=3, activation='relu'),
    MaxPooling1D(pool_size=2),
    Conv1D(filters=64, kernel_size=3, activation='relu'),
    MaxPooling1D(pool_size=2),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(40, activation='softmax'),
])

# Integer class labels -> sparse categorical cross-entropy.
cnn.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

cnn.summary()

In [None]:
#Bidirectional GRU encoder
# Three stacked GRUs, each wrapped in Bidirectional so the sequence is read in
# both directions, then Flatten -> Dense(64) -> 40-way softmax.
# NOTE(review): axis 0 of input_shape=(128, 480) is the recurrent axis. Confirm it
# is time, not EEG channels.
bigru = Sequential([
    Bidirectional(GRU(128, return_sequences=True), input_shape=(128, 480)),
    Bidirectional(GRU(128, return_sequences=True)),
    Bidirectional(GRU(64, return_sequences=True)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(40, activation='softmax'),
])

# Integer class labels -> sparse categorical cross-entropy.
bigru.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

bigru.summary()

In [None]:
#Bidirectional LSTM encoder
# Three stacked LSTMs, each wrapped in Bidirectional so the sequence is read in
# both directions, then Flatten -> Dense(64) -> 40-way softmax.
# This is the model trained, evaluated and saved by the cells below.
# NOTE(review): axis 0 of input_shape=(128, 480) is the recurrent axis. Confirm it
# is time, not EEG channels.
bilstm = Sequential([
    Bidirectional(LSTM(128, return_sequences=True), input_shape=(128, 480)),
    Bidirectional(LSTM(128, return_sequences=True)),
    Bidirectional(LSTM(64, return_sequences=True)),
    Flatten(),
    Dense(64, activation='relu'),
    Dense(40, activation='softmax'),
])

# Integer class labels -> sparse categorical cross-entropy.
bilstm.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

bilstm.summary()

In [None]:
#training loop
# Stop once validation accuracy has not improved for 8 consecutive epochs, and
# roll the model back to its best-scoring weights. 500 epochs is only an upper bound.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_accuracy',
    patience=8,
    restore_best_weights=True,
)

# NOTE(review): per the Keras docs, validation_split holds out the *last* 10% of
# X_train/Y_train, selected before shuffling. If the samples are ordered (e.g. by
# class or subject), that split is not representative. Consider shuffling first.
history = bilstm.fit(
    X_train,
    Y_train,
    batch_size=128,
    epochs=500,
    validation_split=0.1,
    callbacks=[early_stopping],
)

In [None]:
#testing
# With metrics=['accuracy'] at compile time, evaluate() returns [loss, accuracy].
test_scores = bilstm.evaluate(X_test, Y_test, verbose=0)
model_acc = test_scores[1]
print(model_acc)

In [None]:
#loss curve plots
# Training vs. validation loss per epoch, as recorded by bilstm.fit() above.
fig, ax = plt.subplots()
ax.plot(history.history['loss'], color='teal', label='loss')
ax.plot(history.history['val_loss'], color='orange', label='val_loss')
fig.suptitle('Loss', fontsize=20)
# Axis labels so the figure can be read on its own.
ax.set_xlabel('epoch')
ax.set_ylabel('sparse categorical cross-entropy')
ax.legend(loc="upper left")
plt.show()

In [None]:
#accuracy curve plots
# Training vs. validation accuracy per epoch, as recorded by bilstm.fit() above.
fig, ax = plt.subplots()
ax.plot(history.history['accuracy'], color='teal', label='accuracy')
ax.plot(history.history['val_accuracy'], color='orange', label='val_accuracy')
fig.suptitle('Accuracy', fontsize=20)
# Axis labels so the figure can be read on its own.
ax.set_xlabel('epoch')
ax.set_ylabel('accuracy')
ax.legend(loc="upper left")
plt.show()

In [None]:
# Render the Bi-LSTM layer graph as an image (Keras requires pydot and graphviz for this).
plot_model(model=bilstm)

In [None]:
# Save the whole Bi-LSTM model to disk: architecture, weights and optimizer state,
# not just the weights. The '.h5' extension selects the HDF5 format; newer Keras
# treats that format as legacy and prefers '.keras'.
bilstm.save(filepath='model name.h5')