In [1]:
import warnings
warnings.filterwarnings('ignore')

import logging
import os.path
import time
from collections import OrderedDict
import sys
import pickle

import numpy as np
import mne

from braindecode.models.deep4 import Deep4Net
from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import (
    LossMonitor,
    MisclassMonitor,
    RuntimeMonitor,
)
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.datautil.signalproc import (
    bandpass_cnt,
    exponential_running_standardize,
)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

from braindecode.mne_ext.signalproc import resample_cnt

from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt

from TA_CSPNN import *

In [2]:
class model_config:
    """Hyperparameter bundle used to build a TA_CSPNN model.

    Every argument defaults to the value that was previously hard-coded, so
    ``model_config()`` yields exactly the same configuration as before, while
    individual settings can now be overridden per experiment, e.g.
    ``model_config(dropout=0.5)``.

    Attributes:
        channels: number of input channels (forwarded to TA_CSPNN as ``Channels``).
        timesamples: samples per trial (forwarded as ``Timesamples``).
        timeKernelLen: temporal kernel length (forwarded as ``timeKernelLen``).
        num_classes: number of target classes.
        Ft: forwarded as ``Ft`` -- presumably the temporal filter count; confirm in TA_CSPNN.
        Fs: forwarded as ``Fs`` -- presumably the spatial filter count; confirm in TA_CSPNN.
        dropout: dropout rate (forwarded as ``dropOut``).
    """

    def __init__(self, channels=22, timesamples=250, timeKernelLen=64,
                 num_classes=4, Ft=8, Fs=2, dropout=0.25):
        self.channels = channels
        self.timesamples = timesamples
        self.timeKernelLen = timeKernelLen
        self.num_classes = num_classes
        self.Ft = Ft
        self.Fs = Fs
        self.dropout = dropout

In [None]:
# Train and evaluate `models_per_subject` TA_CSPNN models for every subject of
# the BCI Competition IV 2a data found in `data_folder`, recording each model's
# test accuracy in `subject_data` and the per-subject mean in `subject_acc`.
data_folder = './bci_2a'
subject_acc = {}
models_per_subject = 10
# Rows are indexed directly by subject_id (1..9), so row 0 stays unused.
# The column count follows models_per_subject instead of a hard-coded 10.
subject_data = np.zeros((10, models_per_subject))

# Hyperparameters are identical for every subject, so they are set once here
# rather than being re-assigned on each loop iteration.
low_cut_hz = 4
high_cut_hz = 40
ival = [500, 2500]
max_epochs = 500
# NOTE(review): the next two values are never read below -- EarlyStopping uses
# patience=50 and model.fit is called without batch_size. They are kept only so
# other cells that might reference them keep working; wiring them in would
# change the training behaviour.
max_increase_epochs = 160
batch_size = 60
factor_new = 1e-3
init_block_size = 1000
valid_set_fraction = 0.1
sampling_rate = 124.5 # using 125 results in 251 samples

# Opt-in replacement for the previously commented-out checkpointing code:
# when True, the best model seen so far for a subject is written to
# models/subject_<id>.
save_best_model = False

config = model_config()

marker_def = OrderedDict([
        ("Left Hand", [1]),
        ("Right Hand", [2]),
        ("Foot", [3]),
        ("Tongue", [4]),
    ])


def _preprocess_cnt(cnt):
    """Apply the shared preprocessing chain to one continuous recording.

    Steps: drop the three EOG channels, scale by 1e6, band-pass filter between
    `low_cut_hz` and `high_cut_hz`, exponential running standardisation, and
    resampling to `sampling_rate`. The same chain is used for the training and
    the evaluation recording so the two cannot drift apart.
    """
    cnt = cnt.drop_channels(["EOG-left", "EOG-central", "EOG-right"])
    assert len(cnt.ch_names) == 22
    # Scale by 1e6 for numerical stability of the next operations. This is a
    # conversion to microvolts (not millivolts) if the loader yields volts --
    # TODO confirm the loader's unit.
    cnt = mne_apply(lambda a: a * 1e6, cnt)
    # Read the sampling frequency before filtering so the lambda does not
    # depend on when `cnt` gets rebound.
    sfreq = cnt.info["sfreq"]
    cnt = mne_apply(lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz,
                                           sfreq, filt_order=3, axis=1), cnt)
    cnt = mne_apply(lambda a: exponential_running_standardize(
        a.T, factor_new=factor_new, init_block_size=init_block_size, eps=1e-4).T, cnt)
    return resample_cnt(cnt, sampling_rate)


for subject_id in range(1,10):
    # Subject 4 is deliberately excluded (reason not recorded in this notebook).
    if subject_id == 4:
        continue

    train_filename = "A{:02d}T.gdf".format(subject_id)
    test_filename = "A{:02d}E.gdf".format(subject_id)
    train_filepath = os.path.join(data_folder, train_filename)
    test_filepath = os.path.join(data_folder, test_filename)
    train_label_filepath = train_filepath.replace(".gdf", ".mat")
    test_label_filepath = test_filepath.replace(".gdf", ".mat")

    train_loader = BCICompetition4Set2A(train_filepath, labels_filename=train_label_filepath)
    test_loader = BCICompetition4Set2A(test_filepath, labels_filename=test_label_filepath)

    train_cnt = train_loader.load()
    test_cnt = test_loader.load()

    train_cnt = _preprocess_cnt(train_cnt)
    test_cnt = _preprocess_cnt(test_cnt)

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
    train_set, valid_set = split_into_two_sets(
        train_set, first_set_fraction=1 - valid_set_fraction
    )

    # Insert a singleton axis at position 1 to get the 4-D input fed to the model.
    x_train = train_set.X[:,None,:,:]
    x_val = valid_set.X[:,None,:,:]
    x_test = test_set.X[:,None,:,:]

    # Pin the one-hot width: without num_classes, to_categorical infers it from
    # the largest label present, so a small validation split that happens to
    # lack the highest class would produce a target with too few columns.
    y_train = to_categorical(train_set.y, num_classes=config.num_classes)
    y_val = to_categorical(valid_set.y, num_classes=config.num_classes)
    # Test labels stay as integer class ids; they are compared against argmax.
    y_test = test_set.y

    test_acc = 0
    best_acc = 0
    for i in range(models_per_subject):
        model = TA_CSPNN(config.num_classes, Channels=config.channels, Timesamples=config.timesamples,
                        timeKernelLen = config.timeKernelLen, Ft=config.Ft, Fs=config.Fs, dropOut=config.dropout)
        model.compile(loss='categorical_crossentropy', optimizer='Adam', metrics = ['accuracy'])

        # Stop once the validation loss has not improved for 50 epochs.
        es = EarlyStopping(monitor='val_loss', mode='min', verbose=0, patience=50)

        history = model.fit(x_train, y_train, epochs=max_epochs, validation_data=(x_val, y_val),
                            callbacks=[es],verbose=0)

        y_pred = np.argmax(model.predict(x_test), axis=1)

        # Fraction of test trials whose predicted class matches the label.
        acc = np.mean(y_pred == y_test)
        test_acc += acc

        print(f"accuracy: {acc}")
        subject_data[subject_id,i] = acc

        if save_best_model and acc > best_acc:
            print(f"saving model with acc: {acc}")
            model.save(f"models/subject_{subject_id}")
            best_acc = acc

    test_acc /= models_per_subject
    print(f"subject_id {subject_id} : accuracy {test_acc}")
    subject_acc[subject_id] = test_acc

Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A01E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']


W0313 21:38:27.776702 140531481552704 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=334918
    Range : 0 ... 334917 =      0.000 ...  2690.096 secs
Ready.


W0313 21:38:31.671991 140531481552704 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=342126
    Range : 0 ... 342125 =      0.000 ...  2747.992 secs
Ready.


W0313 21:38:32.427098 140531481552704 deprecation.py:506] From /opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


accuracy: 0.7048611111111112
accuracy: 0.7256944444444444
accuracy: 0.6666666666666666
accuracy: 0.7222222222222222
accuracy: 0.7256944444444444
accuracy: 0.71875
accuracy: 0.7534722222222222
accuracy: 0.7291666666666666
accuracy: 0.7569444444444444
accuracy: 0.7048611111111112
subject_id 1 : accuracy 0.7208333333333334
Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A02T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '769', '770', '771', '772']
Extracting EDF parameters from /home/ncarnigl/TA-CSPNN-Research/bci_2a/A02E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 662665  =      0.000 ...  2650.660 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


Used Annotations descriptions: ['1023', '1072', '276', '277', '32766', '768', '783']


W0313 21:45:20.457428 140531481552704 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=337230
    Range : 0 ... 337229 =      0.000 ...  2708.667 secs
Ready.


W0313 21:45:23.958452 140531481552704 signalproc.py:55] This is not causal, uses future data....


Creating RawArray with float64 data, n_channels=22, n_times=330007
    Range : 0 ... 330006 =      0.000 ...  2650.651 secs
Ready.
accuracy: 0.3229166666666667
accuracy: 0.34375
accuracy: 0.3472222222222222
accuracy: 0.2777777777777778
accuracy: 0.3611111111111111


In [None]:
# Summarise the run: one "subject : mean accuracy" line per evaluated subject.
for subject, mean_accuracy in subject_acc.items():
    print(f"{subject} : {mean_accuracy}")

In [None]:
# Print each subject's mean accuracy and persist its per-model accuracies as
# data/<place>/subject_<id>.npy (np.save appends the .npy extension).
place = "functions/elu"
# np.save does not create missing parent directories; make sure the target
# folder exists so the results of a long training run are not lost here.
os.makedirs(f"data/{place}", exist_ok=True)
for i in range(1,10):
    # Subject 4 is skipped, mirroring the training loop.
    if i == 4:
        continue
    a = subject_data[i]
    print(np.mean(a))
    np.save(f"data/{place}/subject_{i}", a)
    