In [1]:
import warnings
warnings.filterwarnings('ignore')

import logging
import os.path
import time
from collections import OrderedDict
import sys
import pickle

import numpy as np
import mne

from braindecode.models.deep4 import Deep4Net
from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import (
    LossMonitor,
    MisclassMonitor,
    RuntimeMonitor,
)
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.datautil.signalproc import (
    bandpass_cnt,
    exponential_running_standardize,
)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

from braindecode.mne_ext.signalproc import resample_cnt

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt

from TA_CSPNN import *

In [2]:
class model_config(object):
    """Hyperparameters handed to the ``TA_CSPNN`` model constructor.

    Every argument defaults to the value this notebook was written with, so
    ``model_config()`` yields exactly the original configuration. Pass keyword
    arguments to override individual settings instead of editing the class.

    Attributes:
        channels: number of input channels (forwarded as ``Channels=``).
        timesamples: samples per trial (forwarded as ``Timesamples=``).
        timeKernelLen: forwarded as ``timeKernelLen=``.
        num_classes: number of output classes.
        Ft, Fs: forwarded as ``Ft=`` / ``Fs=``; presumably the temporal and
            spatial filter counts -- confirm against the TA_CSPNN definition.
    """

    def __init__(self, channels=22, timesamples=250, timeKernelLen=64,
                 num_classes=4, Ft=8, Fs=2):
        self.channels = channels
        self.timesamples = timesamples
        self.timeKernelLen = timeKernelLen
        self.num_classes = num_classes
        self.Ft = Ft  # individual model is 8
        self.Fs = Fs  # individual model is 2

In [3]:
# Directory that holds the recordings; the loading cell looks for
# "A{id:02d}T.gdf" / "A{id:02d}E.gdf" and the matching ".mat" label files here.
data_folder = './bci_2a'

# Subjects to pool. The order given here is the order in which the per-subject
# arrays are appended and later concatenated.
subject_id_list = [9, 3]

In [4]:
# Per-subject splits are collected in these lists (one entry per subject, in
# subject_id_list order).
x_train_list = []
y_train_list = []

x_val_list = []
y_val_list = []

x_test_list = []
y_test_list = []

# --- Settings shared by every subject (loop-invariant, so assigned once) ----
low_cut_hz = 4    # lower band-pass edge passed to bandpass_cnt
high_cut_hz = 40  # upper band-pass edge passed to bandpass_cnt
# Trial window handed to create_signal_target_from_raw_mne
# (presumably milliseconds relative to the cue marker -- TODO confirm).
ival = [500, 2500]
factor_new = 1e-3
init_block_size = 1000
valid_set_fraction = 0.1  # share of the training trials held out for validation
sampling_rate = 124.5 # using 125 results in 251 samples

# Not used by the code in this cell; kept because other cells may read them.
max_epochs = 500
max_increase_epochs = 160
batch_size = 60

# Marker code -> class; the position in this mapping presumably becomes the
# integer label (0..3) -- confirm against create_signal_target_from_raw_mne.
marker_def = OrderedDict([
        ("Left Hand", [1]),
        ("Right Hand", [2]),
        ("Foot", [3]),
        ("Tongue", [4]),
    ])
# Fixed one-hot width. Without it to_categorical sizes the encoding from the
# largest label present in each split, so a split that happens to miss the last
# class would come out narrower and break the later concatenation / training.
n_marker_classes = len(marker_def)


def _preprocess_cnt(cnt):
    """Apply the shared continuous-signal preprocessing to one loaded recording.

    Steps, in order: drop the three EOG channels, multiply by 1e6, band-pass
    between ``low_cut_hz`` and ``high_cut_hz``, exponential running
    standardization, resample to ``sampling_rate``. Returns the processed
    recording.
    """
    cnt = cnt.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
    assert len(cnt.ch_names) == 22
    # Scale by 1e6 for numerical stability of the next operations
    # (volts -> microvolts, assuming the loader yields volts).
    cnt = mne_apply(lambda a: a * 1e6, cnt)
    cnt = mne_apply(
        lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, cnt.info["sfreq"],
                               filt_order=3, axis=1),
        cnt,
    )
    # exponential_running_standardize is fed the transposed array, hence the
    # transpose before and after.
    cnt = mne_apply(
        lambda a: exponential_running_standardize(
            a.T, factor_new=factor_new, init_block_size=init_block_size,
            eps=1e-4).T,
        cnt,
    )
    return resample_cnt(cnt, sampling_rate)


for subject_id in subject_id_list:
    train_filename = "A{:02d}T.gdf".format(subject_id)
    test_filename = "A{:02d}E.gdf".format(subject_id)
    train_filepath = os.path.join(data_folder, train_filename)
    test_filepath = os.path.join(data_folder, test_filename)
    train_label_filepath = train_filepath.replace(".gdf", ".mat")
    test_label_filepath = test_filepath.replace(".gdf", ".mat")

    train_loader = BCICompetition4Set2A(
        train_filepath, labels_filename=train_label_filepath
    )
    test_loader = BCICompetition4Set2A(
        test_filepath, labels_filename=test_label_filepath
    )
    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    # Identical preprocessing for the training and the evaluation session.
    train_cnt = _preprocess_cnt(train_cnt)
    test_cnt = _preprocess_cnt(test_cnt)

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
    train_set, valid_set = split_into_two_sets(
        train_set, first_set_fraction=1 - valid_set_fraction
    )

    # Insert a singleton axis at position 1: (trials, 1, channels, samples).
    x_train = train_set.X[:, None, :, :]
    y_train = to_categorical(train_set.y, num_classes=n_marker_classes)

    x_val = valid_set.X[:, None, :, :]
    y_val = to_categorical(valid_set.y, num_classes=n_marker_classes)

    x_test = test_set.X[:, None, :, :]
    y_test = test_set.y # dont need y_test to be one hot, predictions are class numbers

    x_train_list.append(x_train)
    y_train_list.append(y_train)

    x_val_list.append(x_val)
    y_val_list.append(y_val)

    x_test_list.append(x_test)
    y_test_list.append(y_test)

Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A09T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 673327  =      0.000 ...  2693.308 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A09E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 675097  =      0.000 ...  2700.388 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']


W0314 14:10:28.908883 139960005736256 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=335317
    Range : 0 ... 335316 =      0.000 ...  2693.301 secs
Ready.


W0314 14:10:33.339093 139960005736256 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=336198
    Range : 0 ... 336197 =      0.000 ...  2700.378 secs
Ready.
Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A03T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  2642.116 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A03E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 648774  =      0.000 ...  2595.096 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']


W0314 14:10:41.570224 139960005736256 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=328943
    Range : 0 ... 328942 =      0.000 ...  2642.104 secs
Ready.


W0314 14:10:45.542816 139960005736256 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=323089
    Range : 0 ... 323088 =      0.000 ...  2595.084 secs
Ready.


In [5]:
# Pool all subjects: stack the per-subject arrays along the first (trial) axis,
# keeping the order in which they were appended.
x_train, y_train = (np.concatenate(parts, axis=0) for parts in (x_train_list, y_train_list))
x_val, y_val = (np.concatenate(parts, axis=0) for parts in (x_val_list, y_val_list))
x_test, y_test = (np.concatenate(parts, axis=0) for parts in (x_test_list, y_test_list))

In [6]:
# Sanity check: one line per split (train, validation, test) showing the
# feature-array shape followed by the label-array shape.
for split_x, split_y in ((x_train, y_train), (x_val, y_val), (x_test, y_test)):
    print(split_x.shape, split_y.shape)

(518, 1, 22, 250) (518, 4)
(58, 1, 22, 250) (58, 4)
(576, 1, 22, 250) (576,)


In [7]:
# A fresh TA_CSPNN is built, trained and evaluated num_models times; the pooled
# test accuracy is averaged over the runs.
num_models = 10
config = model_config()
test_acc = 0

# subject_data[subject_id, run] = test accuracy of model `run` on that subject.
# Rows are indexed directly by subject id (row 0 stays unused); columns follow
# num_models so that raising num_models above 10 no longer raises IndexError.
# For the original settings this is still a (10, 10) array.
n_subject_rows = max(10, max(subject_id_list) + 1)
subject_data = np.zeros((n_subject_rows, num_models))

for i in range(num_models):

    model = TA_CSPNN(config.num_classes, Channels=config.channels, Timesamples=config.timesamples,
                    timeKernelLen = config.timeKernelLen, Ft=config.Ft, Fs=config.Fs)

    model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics = ['accuracy'])

    # Stop once the validation loss has not improved for 50 consecutive epochs.
    es = EarlyStopping(monitor='val_loss', mode='min', verbose=0, patience=50)

    history = model.fit(x_train, y_train, epochs=500, validation_data=((x_val, y_val)), callbacks=[es],
                        verbose=0,shuffle=True)

    # model.predict returns one score per class; argmax turns that into class
    # numbers (0,1,2,3), directly comparable with the integer labels in y_test.
    y_pred = np.argmax(model.predict(x_test), axis=1)

    # accuracy over the pooled test set (all subjects, all classes)
    acc = np.sum(y_pred == y_test) / len(y_test)
    test_acc += acc
    print(f"accuracy: {acc}")

    # accuracy for each subject, evaluated on that subject's own test trials
    # NOTE: y_pred is rebound here, exactly as before, so after the cell it
    # holds the last model's predictions for the last subject only.
    for subject_num, subject_x, subject_y in zip(subject_id_list, x_test_list, y_test_list):
        y_pred = np.argmax(model.predict(subject_x), axis=1)

        sub_acc = np.sum(y_pred == subject_y) / len(y_pred)
        subject_data[subject_num, i] = sub_acc

# mean pooled test accuracy across all trained models
test_acc /= num_models
print(test_acc)

W0314 14:10:46.563585 139960005736256 deprecation.py:506] From /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0314 14:10:51.448818 139960005736256 deprecation.py:323] From /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1205: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


accuracy: 0.8020833333333334
accuracy: 0.7690972222222222
accuracy: 0.8229166666666666
accuracy: 0.8003472222222222
accuracy: 0.8055555555555556
accuracy: 0.8020833333333334
accuracy: 0.7708333333333334
accuracy: 0.7795138888888888
accuracy: 0.7760416666666666
accuracy: 0.8003472222222222
0.7928819444444445
