In [None]:
import logging
import os.path
import time
from collections import OrderedDict
import sys

import numpy as np


from braindecode.models.deep4 import Deep4Net
from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import (
    LossMonitor,
    MisclassMonitor,
    RuntimeMonitor,
)
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.datautil.signalproc import (
    bandpass_cnt,
    exponential_running_standardize,
)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

from braindecode.mne_ext.signalproc import resample_cnt

In [11]:
data_folder = './bci_2a'
subject_id = 1
low_cut_hz = 4

ival = [500, 2500]
max_epochs = 500
max_increase_epochs = 160
batch_size = 60
high_cut_hz = 40
factor_new = 1e-3
init_block_size = 1000
valid_set_fraction = 0.1
sampling_rate = 124.5 # using 125 results in 251 samples

train_filename = "A{:02d}T.gdf".format(subject_id)
test_filename = "A{:02d}E.gdf".format(subject_id)
train_filepath = os.path.join(data_folder, train_filename)
test_filepath = os.path.join(data_folder, test_filename)
train_label_filepath = train_filepath.replace(".gdf", ".mat")
test_label_filepath = test_filepath.replace(".gdf", ".mat")

train_loader = BCICompetition4Set2A(
    train_filepath, labels_filename=train_label_filepath
)
test_loader = BCICompetition4Set2A(
    test_filepath, labels_filename=test_label_filepath
)
train_cnt = train_loader.load()
test_cnt = test_loader.load()

Extracting EDF parameters from /Users/noahcarniglia/189project/bci_2a/A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_gdf = mne.io.read_raw_gdf(self.filename, stim_channel="auto")


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Extracting EDF parameters from /Users/noahcarniglia/189project/bci_2a/A01E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_gdf = mne.io.read_raw_gdf(self.filename, stim_channel="auto")


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']


In [12]:
train_cnt = train_cnt.drop_channels(
        ["EOG-left", "EOG-central", "EOG-right"]
    )
assert len(train_cnt.ch_names) == 22
# lets convert to millvolt for numerical stability of next operations
train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
train_cnt = mne_apply(
    lambda a: bandpass_cnt(
        a,
        low_cut_hz,
        high_cut_hz,
        train_cnt.info["sfreq"],
        filt_order=3,
        axis=1,
    ),
    train_cnt,
)

train_cnt = mne_apply(
    lambda a: exponential_running_standardize(
        a.T,
        factor_new=factor_new,
        init_block_size=init_block_size,
        eps=1e-4,
    ).T,
    train_cnt,
)

train_cnt = resample_cnt(train_cnt, sampling_rate)

test_cnt = test_cnt.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
assert len(test_cnt.ch_names) == 22
test_cnt = mne_apply(lambda a: a * 1e6, test_cnt)
test_cnt = mne_apply(
    lambda a: bandpass_cnt(
        a,
        low_cut_hz,
        high_cut_hz,
        test_cnt.info["sfreq"],
        filt_order=3,
        axis=1,
    ),
    test_cnt,
)
test_cnt = mne_apply(
    lambda a: exponential_running_standardize(
        a.T,
        factor_new=factor_new,
        init_block_size=init_block_size,
        eps=1e-4,
    ).T,
    test_cnt,
)

test_cnt = resample_cnt(test_cnt, sampling_rate)

marker_def = OrderedDict([
        ("Left Hand", [1]),
        ("Right Hand", [2]),
        ("Foot", [3]),
        ("Tongue", [4]),
    ])

train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

train_set, valid_set = split_into_two_sets(
    train_set, first_set_fraction=1 - valid_set_fraction
)

print(type(train_set.X))
print(type(train_set.y))
print(train_set.X.shape)

This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=334918
    Range : 0 ... 334917 =      0.000 ...  2690.096 secs
Ready.


This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=342126
    Range : 0 ... 342125 =      0.000 ...  2747.992 secs
Ready.
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
(259, 22, 250)


In [23]:
x_train = train_set.X[:,None,:,:]
y_train = train_set.y

x_val = valid_set.X[:,None,:,:]
y_val = valid_set.y

x_test = test_set.X[:,None,:,:]
y_test = test_set.y

In [28]:
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt

from TA_CSPNN import *

In [34]:
y_train = to_categorical(y_train)
y_val = to_categorical(y_val)
y_test = to_categorical(y_test)

In [35]:
class model_config(object):
    def __init__(self):
        self.channels = 22
        self.timesamples = 250
        self.timeKernelLen = 64
        self.num_classes = 4
        self.Ft = 8
        self.Fs = 2

In [36]:
config = model_config()
model = TA_CSPNN(config.num_classes, Channels=config.channels, Timesamples=config.timesamples,
                timeKernelLen = config.timeKernelLen, Ft=config.Ft, Fs=config.Fs)
opt = Adam(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics = ['accuracy'])

es = EarlyStopping(monitor='val_loss', mode='min', verbose=2, patience=50)

In [37]:
history = model.fit(x_train, y_train, 
                    epochs=500, 
                    validation_data=((x_val, y_val)), 
                    callbacks=[es],
                    verbose=2)

Train on 259 samples, validate on 29 samples
Epoch 1/500


InvalidArgumentError:  Default AvgPoolingOp only supports NHWC on device type CPU
	 [[node model_1/average_pooling2d_1/AvgPool (defined at <ipython-input-37-d2c1e67395f1>:5) ]] [Op:__inference_distributed_function_1666]

Errors may have originated from an input operation.
Input Source operations connected to node model_1/average_pooling2d_1/AvgPool:
 model_1/lambda_1/pow (defined at /Users/noahcarniglia/189project/TA_CSPNN.py:55)

Function call stack:
distributed_function


In [32]:
y_train

array([3, 2, 1, 0, 0, 1, 2, 3, 1, 2, 0, 0, 0, 3, 1, 1, 0, 0, 2, 0, 1, 3,
       3, 2, 0, 3, 3, 1, 3, 3, 1, 0, 1, 2, 2, 2, 3, 2, 0, 3, 1, 2, 1, 2,
       3, 1, 2, 0, 0, 0, 3, 1, 0, 2, 0, 2, 1, 3, 0, 2, 2, 0, 2, 1, 3, 3,
       3, 2, 0, 3, 1, 3, 1, 0, 2, 1, 0, 2, 2, 0, 2, 3, 3, 1, 0, 1, 3, 1,
       3, 2, 1, 1, 1, 2, 3, 0, 1, 3, 0, 2, 2, 3, 0, 0, 2, 1, 3, 3, 3, 1,
       0, 2, 1, 3, 0, 3, 2, 1, 3, 3, 0, 1, 1, 2, 3, 1, 0, 0, 3, 1, 0, 2,
       1, 1, 2, 0, 3, 2, 2, 2, 2, 0, 1, 0, 1, 0, 0, 2, 2, 1, 2, 3, 0, 3,
       0, 0, 1, 3, 2, 1, 3, 2, 3, 2, 3, 1, 1, 3, 0, 1, 1, 1, 2, 3, 0, 3,
       0, 2, 0, 3, 0, 2, 0, 1, 2, 2, 3, 0, 1, 3, 1, 2, 2, 0, 3, 1, 3, 0,
       0, 2, 2, 1, 3, 1, 1, 0, 1, 3, 3, 1, 1, 1, 1, 3, 3, 2, 3, 0, 1, 2,
       1, 0, 3, 0, 3, 0, 0, 0, 0, 2, 2, 3, 1, 2, 2, 2, 3, 2, 0, 2, 0, 3,
       1, 3, 3, 2, 3, 3, 2, 1, 3, 2, 0, 1, 1, 1, 2, 1, 3])