In [1]:
from src.data_preparation.data_preparation import read_eeg_file
from scipy import signal
from scipy import linalg
from scipy.integrate import simps
from scipy import stats
from sklearn.model_selection import StratifiedKFold
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import accuracy_score
import pywt
import pyriemann.utils.mean as rie_mean
import numpy as np



Declaration of important values used during the pipeline 

In [2]:
TIME_LENGTH = 300
TIME_WINDOW = 300
EPOCH_SIZE = None
DATA_FOLDER = "data/bci-csp-based/bci-iii-dataset-iv-a"
CSP_COMPONENTS = 8
FS = 100
WAVELET = "coif1"
K_FOLD = 10

Create the subjects object and a dictionary to store the accuracies

In [3]:
subjects = range(1, 6)
subjects_set = set(subjects)
accuracies = {
    "GNB": np.zeros((len(subjects), K_FOLD)),
    "SVM": np.zeros((len(subjects), K_FOLD)),
    "LDA": np.zeros((len(subjects), K_FOLD))
}

Define the bandpass filter

In [4]:
sos = signal.cheby2(10, 50, [7, 30], analog=False, btype="band", output="sos", fs=FS)

Define the function to generate the common spatial filter's based on the test data 

In [5]:
def compute_spatial_filters(left_data, right_data):
    n_channels = left_data.shape[2]
    cov_shape = (n_channels, n_channels)
            
    # Estimate the covariance matrix of every trial
    n_left_trials = left_data.shape[0]
    cov = np.zeros((n_left_trials, *cov_shape))
    for n_trial in range(n_left_trials):
        trial = signal.sosfilt(sos, left_data[n_trial], axis=0)
        cov[n_trial] = np.cov(np.transpose(trial))

    # calculate average of covariance matrix
    cov_1 = rie_mean.mean_covariance(cov, metric="riemann")
    
    # Estimate the covariance matrix of every trial
    n_right_trials = right_data.shape[0]
    cov = np.zeros((n_right_trials, *cov_shape))
    for n_trial in range(n_right_trials):
        trial = signal.sosfilt(sos, right_data[n_trial], axis=0)
        cov[n_trial] = np.cov(np.transpose(trial))

    # calculate average of covariance matrix
    cov_2 = rie_mean.mean_covariance(cov, metric="riemann")
    
    # Solve the generalized eigenvalue problem
    n_pairs = CSP_COMPONENTS//2
    w, vr = linalg.eig(cov_1, cov_2, right=True)
    w = np.abs(w)
    sorted_indexes = np.argsort(w)
    chosen_indexes = np.zeros(2*n_pairs).astype(int)
    chosen_indexes[0:n_pairs] = sorted_indexes[0:n_pairs]
    chosen_indexes[n_pairs:2*n_pairs] = sorted_indexes[-n_pairs:]
    
    return vr[:, chosen_indexes]

In [6]:
def extract_features(X):
    trials = X.shape[0]
    F = np.zeros((trials, 2, CSP_COMPONENTS))
    for n_trial in range(trials):
        trial = X[n_trial]
        z = np.dot(np.transpose(W), np.transpose(trial))
        z = signal.sosfilt(sos, z, axis=1)
        
        # Calculate the wavelet features
        for n_feature in range(CSP_COMPONENTS):
            alpha_band, beta_band = pywt.dwt(z[n_feature], WAVELET)
            F[n_trial, 0, n_feature] = np.sum(np.abs(beta_band) ** 2)

        # Calculate the frequency-domain features
        psd_window_size = 100
        psd_window_overlap = psd_window_size//2
        low, high = 13, 30
        for n_feature in range(CSP_COMPONENTS):
            freqs, psd = signal.welch(z[n_feature], fs=FS, window="hann",
                                     nperseg=psd_window_size, noverlap=psd_window_overlap)
            beta_freqs = np.logical_and(freqs >= low, freqs <= high)
            freq_res = freqs[1] - freqs[0]
            F[n_trial, 1, n_feature] = simps(psd[beta_freqs], dx=freq_res)
        
    return F

Iterate on the subjects applying the algorithm, validating the results using the technique 10-fold cross-validation

In [7]:
for subject in subjects:
    print("========= Subject: ", subject)

    # Load data
    left_data_file = f"{DATA_FOLDER}/left-hand-subject-{subject}.csv"
    right_data_file = f"{DATA_FOLDER}/right-hand-subject-{subject}.csv"
    data = read_eeg_file(left_data_file, right_data_file, TIME_LENGTH, TIME_WINDOW, EPOCH_SIZE)
    
    W = compute_spatial_filters(data.X[data.labels == 0], data.X[data.labels == 1])
    
    subject_index = subject - 1
    cv = StratifiedKFold(n_splits=K_FOLD, shuffle=False)
    for (k, (train_index, test_index)) in enumerate(cv.split(data.X, data.labels)):
        X_train, X_test = data.X[train_index], data.X[test_index]
        y_train, y_test = data.labels[train_index], data.labels[test_index]
        
        # Feature extraction
        features_train = extract_features(X_train)
        features_test = extract_features(X_test)
    
        len_features = features_train.shape[1] * features_train.shape[2]
        features_train = np.reshape(features_train, newshape=(features_train.shape[0], len_features))
        features_test = np.reshape(features_test, newshape=(features_test.shape[0], len_features))

        # Feature normalization
        features_train = stats.zscore(features_train, axis=0)
        features_test = stats.zscore(features_test, axis=0)
        
        # GNB classifier
        gnb = GaussianNB(priors=[.5, .5], var_smoothing=1.0)
        gnb.fit(features_train, y_train)
        gnb_predictions = gnb.predict(features_test)
        gnb_accuracy = accuracy_score(y_test, gnb_predictions)
        print(f"GNB accuracy: {gnb_accuracy:.4f}")
        accuracies["GNB"][subject_index][k] = gnb_accuracy

        # SVM classifier
        svm = SVC(C=.8, kernel="rbf")
        svm.fit(features_train, y_train)
        svm_predictions = svm.predict(features_test)
        svm_accuracy = accuracy_score(y_test, svm_predictions)
        print(f"SVM accuracy: {svm_accuracy:.4f}")
        accuracies["SVM"][subject_index][k] = svm_accuracy

        # LDA classifier
        lda = LinearDiscriminantAnalysis()
        lda.fit(features_train, y_train)
        lda_predictions = lda.predict(features_test)
        lda_accuracy = accuracy_score(y_test, lda_predictions)
        print(f"LDA accuracy: {lda_accuracy:.4f}")
        accuracies["LDA"][subject_index][k] = lda_accuracy

        print()
    print()

for classifier in accuracies:
    print(classifier)
    for subject, cv_accuracies in enumerate(accuracies[classifier]):
        acc_mean = np.mean(cv_accuracies)*100
        acc_std = np.std(cv_accuracies)*100
        print(f"\tSubject {subject+1} average accuracy: {acc_mean:.4f} +/- {acc_std:.4f}")
    average_acc_mean = np.mean(accuracies[classifier])*100
    average_acc_std = np.std(accuracies[classifier])*100
    print(f"\tAverage accuracy: {average_acc_mean:.4f} +/- {average_acc_std:.4f}")

GNB accuracy: 0.8571
SVM accuracy: 0.8214
LDA accuracy: 0.7143

GNB accuracy: 0.8214
SVM accuracy: 0.7857
LDA accuracy: 0.8214

GNB accuracy: 0.8214
SVM accuracy: 0.8214
LDA accuracy: 0.7857

GNB accuracy: 0.8214
SVM accuracy: 0.9286
LDA accuracy: 0.8929

GNB accuracy: 0.8929
SVM accuracy: 0.8571
LDA accuracy: 0.7500

GNB accuracy: 1.0000
SVM accuracy: 1.0000
LDA accuracy: 0.9643

GNB accuracy: 0.8571
SVM accuracy: 0.8929
LDA accuracy: 0.9643

GNB accuracy: 0.7143
SVM accuracy: 0.7500
LDA accuracy: 0.7857

GNB accuracy: 0.8571
SVM accuracy: 0.8929
LDA accuracy: 0.8929

GNB accuracy: 0.6786
SVM accuracy: 0.7857
LDA accuracy: 0.8571


GNB accuracy: 0.8929
SVM accuracy: 0.9286
LDA accuracy: 0.9643

GNB accuracy: 0.9643
SVM accuracy: 1.0000
LDA accuracy: 1.0000

GNB accuracy: 0.9286
SVM accuracy: 0.9643
LDA accuracy: 0.9286

GNB accuracy: 0.9286
SVM accuracy: 0.8929
LDA accuracy: 0.9643

GNB accuracy: 0.9286
SVM accuracy: 1.0000
LDA accuracy: 1.0000

GNB accuracy: 0.7857
SVM accuracy: 0.85