In [1]:
import numpy as np
import mne
from EEGModels import EEGNet
from tensorflow.keras import utils as np_utils
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras import backend as K
from pyriemann.estimation import XdawnCovariances
from pyriemann.tangentspace import TangentSpace
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import os
import pickle

In [2]:
# Use the channels-last layout for Keras tensors; the EEG arrays in this
# notebook are reshaped to (trials, chans, samples, kernels), i.e. with the
# kernel axis last.
K.set_image_data_format('channels_last')

In [7]:
# Read the preprocessed arrays back from disk.
# The pickled object is unpacked into two items: the feature array X and the
# label array y.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load files that were produced by a trusted preprocessing step.
with open('preprocessed_data.pkl', 'rb') as data_file:
    X, y = pickle.load(data_file)

In [8]:
# Report the dimensions of the two loaded arrays.
for caption, loaded in (("X shape:", X), ("y shape:", y)):
    print(caption, loaded.shape)

X shape: (39569, 64, 3)
y shape: (19894,)


In [None]:
# Guard: X and y must have the same length along their first axis, otherwise
# the positional slicing used for the splits would pair samples with the
# wrong labels.
n_feature_rows = X.shape[0]
n_label_rows = y.shape[0]
if n_feature_rows != n_label_rows:
    raise ValueError("Mismatch between number of samples in X and y.")

ValueError: Mismatch between number of samples in X and y.

In [None]:
# Dimensions used for the EEGNet input: a single kernel axis, with the channel
# and sample counts read from axis 1 and axis 2 of X.
kernels = 1
chans = X.shape[1]
samples = X.shape[2]

In [None]:
# Split sizes: 50% of the samples for training, 25% for validation; whatever
# remains is used for testing (adjust the fractions as needed).
n_total = len(X)
train_size = int(n_total * 0.5)
validate_size = int(n_total * 0.25)

In [None]:
# Report the size of each split.
# The test size is computed from len(X), the same length train_size and
# validate_size were derived from, so the three numbers always add up to the
# number of samples in X (previously the remainder was taken from len(y)).
print("train_size:", train_size)
print("validate_size:", validate_size)
print("test_size:", len(X) - train_size - validate_size)

In [None]:
# Cut X and y into three consecutive segments: train, validate, test.
# NOTE(review): the split is purely positional (no shuffling) — confirm the
# samples are not ordered by class or recording before relying on it.
validate_end = train_size + validate_size
X_train, Y_train = X[:train_size], y[:train_size]
X_validate, Y_validate = X[train_size:validate_end], y[train_size:validate_end]
X_test, Y_test = X[validate_end:], y[validate_end:]

In [None]:
# Report the shape of every split before the labels are one-hot encoded.
split_report = (
    ("X_train shape:", X_train),
    ("Y_train shape:", Y_train),
    ("X_validate shape:", X_validate),
    ("Y_validate shape:", Y_validate),
    ("X_test shape:", X_test),
    ("Y_test shape:", Y_test),
)
for caption, split in split_report:
    print(caption, split.shape)

In [None]:
# Quick look at the full array shapes and the computed split sizes.
for value in (y.shape, X.shape, train_size, validate_size):
    print(value)


In [None]:
# One-hot encode the labels.
# to_categorical infers the number of columns from the largest label in the
# array it is given, so encoding each split independently can yield matrices
# of different widths when a split happens to lack the highest class. Derive
# the class count once from the full label vector and pass it explicitly so
# train, validate and test always share the same encoding.
num_classes = int(np.max(y)) + 1
Y_train = np_utils.to_categorical(Y_train, num_classes=num_classes)
Y_validate = np_utils.to_categorical(Y_validate, num_classes=num_classes)
Y_test = np_utils.to_categorical(Y_test, num_classes=num_classes)

In [None]:
# Give every split the 4-D layout passed to EEGNet:
# (trials, chans, samples, kernels).
X_train = X_train.reshape((len(X_train), chans, samples, kernels))
X_validate = X_validate.reshape((len(X_validate), chans, samples, kernels))
X_test = X_test.reshape((len(X_test), chans, samples, kernels))

In [None]:
# Summarise the reshaped training data and the train/test sample counts.
n_train_samples, n_test_samples = X_train.shape[0], X_test.shape[0]
print('X_train shape:', X_train.shape)
print(n_train_samples, 'train samples')
print(n_test_samples, 'test samples')

In [None]:
# Build the EEGNet model. The number of output classes is the width of the
# one-hot training labels; channel and sample counts come from the data.
# NOTE(review): the shapes recorded earlier in this notebook show X as
# (39569, 64, 3), which would make Samples=3 — far shorter than
# kernLength=32. Confirm the axis order / epoch length of the preprocessed
# data before training.
eegnet_config = dict(
    nb_classes=Y_train.shape[1],
    Chans=chans,
    Samples=samples,
    dropoutRate=0.5,
    kernLength=32,
    F1=8,
    D=2,
    F2=16,
    dropoutType='Dropout',
)
model = EEGNet(**eegnet_config)

In [None]:
# Configure training: categorical cross-entropy loss, Adam optimizer, and
# accuracy reported as the metric.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Checkpoint callback writing to a fixed file; save_best_only=True is set so
# the file is only overwritten when the monitored quantity improves.
checkpointer = ModelCheckpoint(
    filepath='/tmp/checkpoint.h5',
    verbose=1,
    save_best_only=True,
)

# Fit on the training split, evaluating on the validation split each epoch.
validation_pair = (X_validate, Y_validate)
fittedModel = model.fit(
    X_train,
    Y_train,
    batch_size=16,
    epochs=300,
    verbose=2,
    validation_data=validation_pair,
    callbacks=[checkpointer],
)


In [None]:
# Load best weights
# Reads back the file at the same path the ModelCheckpoint callback writes to
# during training (that callback is created with save_best_only=True), so the
# evaluation below uses the checkpointed weights rather than the last epoch's.
model.load_weights('/tmp/checkpoint.h5')

In [None]:
# Score EEGNet on the test split: the predicted class is the index of the
# largest model output, compared against the index of the one-hot label.
probs = model.predict(X_test)
preds = np.argmax(probs, axis=-1)
true_classes = np.argmax(Y_test, axis=-1)
acc = np.mean(preds == true_classes)
print("Classification accuracy: %f " % (acc))

In [None]:
# PyRiemann baseline (for comparison with EEGNet):
# xDAWN covariances -> tangent-space mapping -> logistic regression.
n_components = 2
xdawn_cov_step = XdawnCovariances(n_components)
tangent_step = TangentSpace(metric='riemann')
clf = make_pipeline(xdawn_cov_step, tangent_step, LogisticRegression())

# Drop the trailing kernel axis added for EEGNet so the pipeline receives
# 3-D (trials, chans, samples) arrays.
X_train_reshaped = X_train.reshape((X_train.shape[0], chans, samples))
X_test_reshaped = X_test.reshape((X_test.shape[0], chans, samples))

# The pipeline works on integer class labels, so undo the one-hot encoding.
train_classes = Y_train.argmax(axis=-1)
test_classes = Y_test.argmax(axis=-1)

clf.fit(X_train_reshaped, train_classes)
preds_rg = clf.predict(X_test_reshaped)
acc2 = np.mean(preds_rg == test_classes)
print("Classification accuracy: %f " % (acc2))

In [None]:
# Plot confusion matrices
# Tick labels passed as display_labels to both confusion-matrix plots below.
# NOTE(review): presumably position i names integer class i, and the data
# contains exactly these five classes — confirm against the preprocessing.
names = ['rest', 'left fist', 'right fist', 'both fists', 'both feet']

In [None]:
# Confusion matrix of the EEGNet predictions against the true test classes
# (recovered from the one-hot labels).
true_classes = Y_test.argmax(axis=-1)
cm_eegnet = confusion_matrix(true_classes, preds)
disp_eegnet = ConfusionMatrixDisplay(
    confusion_matrix=cm_eegnet,
    display_labels=names,
)
disp_eegnet.plot(cmap=plt.cm.Blues)
plt.title('EEGNet-8,2')
plt.show()

In [None]:
# Confusion matrix of the xDAWN + Riemannian-geometry baseline predictions
# against the true test classes (recovered from the one-hot labels).
true_classes = Y_test.argmax(axis=-1)
cm_rg = confusion_matrix(true_classes, preds_rg)
disp_rg = ConfusionMatrixDisplay(
    confusion_matrix=cm_rg,
    display_labels=names,
)
disp_rg.plot(cmap=plt.cm.Blues)
plt.title('xDAWN + RG')
plt.show()