In [1]:
import logging
import os.path
import time
from collections import OrderedDict
import sys

import numpy as np
import torch.nn.functional as F
from torch import optim

from braindecode.models.deep4 import Deep4Net
from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, RuntimeMonitor
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.datautil.iterators import BalancedBatchSizeIterator
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.util import set_random_seeds, np_to_var
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.datautil.signalproc import (bandpass_cnt,
                                             exponential_running_standardize)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

import mne
from scipy.io import loadmat
log = logging.getLogger(__name__)

In [None]:
raw_edf = mne.io.read_raw_edf("C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A01T.gdf", stim_channel='auto')
raw_edf.load_data()
data = raw_edf.get_data()

In [None]:
raw_edf.ch_names

In [None]:
class BCICompetition4Set2A(object):
    def __init__(self, filename, load_sensor_names=None,
                 labels_filename=None):
        assert load_sensor_names is None
        self.__dict__.update(locals())
        del self.self

    def load(self):
        cnt = self.extract_data()
        events, artifact_trial_mask = self.extract_events(cnt)
        cnt.info['events'] = events
        cnt.info['artifact_trial_mask'] = artifact_trial_mask
        return cnt

    def extract_data(self):
        raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
        raw_edf.load_data()
        # correct nan values

        data = raw_edf.get_data()

        # do not correct stimulus channel
        assert raw_edf.ch_names[-1] == 'STI 014'
        for i_chan in range(data.shape[0] - 1):
            # first set to nan, than replace nans by nanmean.
            this_chan = data[i_chan]
            data[i_chan] = np.where(this_chan == np.min(this_chan),
                                    np.nan, this_chan)
            mask = np.isnan(data[i_chan])
            chan_mean = np.nanmean(data[i_chan])
            data[i_chan, mask] = chan_mean
        gdf_events = raw_edf.find_edf_events()
        raw_edf = mne.io.RawArray(data, raw_edf.info, verbose='WARNING')
        # remember gdf events
        raw_edf.info['gdf_events'] = gdf_events
        return raw_edf

    def extract_events(self, raw_edf):
        # all events
        events = np.array(list(zip(
            raw_edf.info['gdf_events'][1],
            raw_edf.info['gdf_events'][2])))

        # only trial onset events
        trial_mask = [ev_code in [769, 770, 771, 772, 783]
                      for ev_code in events[:,1]]
        trial_events = events[trial_mask]
        assert (len(trial_events) == 288), ("Got {:d} markers".format(len(trial_events)))
        # event markers 769,770 -> 1,2
        trial_events[:, 1] = trial_events[:, 1] - 768
        # possibly overwrite with markers from labels file
        if self.labels_filename is not None:
            classes = loadmat(self.labels_filename)['classlabel'].squeeze()
            trial_events[:, 1] = classes
        unique_classes = np.unique(trial_events[:, 1])
        assert np.array_equal([1, 2, 3 ,4], unique_classes), (
            "Expect 1,2,3,4 as class labels, got {:s}".format(
                str(unique_classes))
        )
        # now also create 0-1 vector for rejected trials
        trial_start_events = events[events[:, 1] == 768]
        assert len(trial_start_events) == len(trial_events)
        artifact_trial_mask = np.zeros(len(trial_events), dtype=np.uint8)
        artifact_events = events[events[:, 1] == 1023]
        for artifact_time in artifact_events[:, 0]:
            i_trial = trial_start_events[:, 0].tolist().index(artifact_time)
            artifact_trial_mask[i_trial] = 1

        # mne expects events with 3 ints each:
        events = np.zeros((len(trial_events), 3), dtype=np.int32)
        events[:,0] = trial_events[:,0]
        events[:,2] = trial_events[:,1]

        return events, artifact_trial_mask

In [2]:
def run_exp(data_folder, subject_id, low_cut_hz, model, cuda):
    ival = [-500, 4000]
    max_epochs = 1600
    max_increase_epochs = 160
    batch_size = 60
    high_cut_hz = 38
    factor_new = 1e-3
    init_block_size = 1000
    valid_set_fraction = 0.2

    number_of_patients = len(subject_id)
    
    train_filenames = ['A{:02d}T.gdf'.format(x) for x in subject_id]
    test_filenames = ['A{:02d}E.gdf'.format(x) for x in subject_id]
    train_filepaths = []
    test_filepaths = []
    train_label_filepaths=[]
    test_label_filepaths=[]
    train_cnts=[]
    test_cnts = []
    
    for i in range(len(train_filenames)):
        train_filepaths.append(os.path.join(data_folder, train_filenames[i]))
        test_filepaths.append(os.path.join(data_folder, test_filenames[i]))
        #test_filepath[i] = os.path.join(data_folder, test_filename[i])
        train_label_filepaths.append(train_filepaths[i].replace('.gdf', '.mat'))
        test_label_filepaths.append(test_filepaths[i].replace('.gdf', '.mat'))
        train_loader = BCICompetition4Set2A(train_filepaths[i], labels_filename=train_label_filepaths[i]) 
        test_loader = BCICompetition4Set2A(test_filepaths[i], labels_filename=test_label_filepaths[i])
        train_cnts.append(train_loader.load())
        test_cnts.append(test_loader.load())
        
    train_cnt = mne.io.concatenate_raws(train_cnts)
    test_cnt = mne.io.concatenate_raws(test_cnts)
    # Preprocessing

    train_cnt = train_cnt.drop_channels(['STI 014', 'EOG-left',
                                         'EOG-central', 'EOG-right'])
    assert len(train_cnt.ch_names) == 22
    # lets convert to millvolt for numerical stability of next operations
    train_cnt = mne_apply(lambda a: a * 1e6, train_cnt)
    train_cnt = mne_apply(lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, train_cnt.info['sfreq'],
                               filt_order=3,
                               axis=1), train_cnt)
    train_cnt = mne_apply(lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
                                                  init_block_size=init_block_size,
                                                  eps=1e-4).T, train_cnt)

    test_cnt = test_cnt.drop_channels(['STI 014', 'EOG-left',
                                       'EOG-central', 'EOG-right'])
    assert len(test_cnt.ch_names) == 22
    test_cnt = mne_apply(lambda a: a * 1e6, test_cnt)
    test_cnt = mne_apply(lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, test_cnt.info['sfreq'],
                               filt_order=3,
                               axis=1), test_cnt)
    test_cnt = mne_apply(lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
                                                  init_block_size=init_block_size,
                                                  eps=1e-4).T,test_cnt)
 

    marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],),
                              ('Foot', [3]), ('Tongue', [4])])

    train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
    test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)

    train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=1-valid_set_fraction)
         

    set_random_seeds(seed=20190706, cuda=cuda)

    n_classes = 4
    n_chans = int(train_set.X.shape[1])
    input_time_length = train_set.X.shape[2]
    if model == 'shallow':
        model = ShallowFBCSPNet(n_chans, n_classes, input_time_length=input_time_length,
                            final_conv_length='auto').create_network()
    elif model == 'deep':
        model = Deep4Net(n_chans, n_classes, input_time_length=input_time_length,
                            final_conv_length='auto').create_network()
    if cuda:
        model.cuda()
    log.info("Model: \n{:s}".format(str(model)))

    optimizer = optim.Adam(model.parameters())

    iterator = BalancedBatchSizeIterator(batch_size=batch_size)

    stop_criterion = Or([MaxEpochs(max_epochs),
                         NoDecrease('valid_misclass', max_increase_epochs)])

    monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]

    model_constraint = MaxNormDefaultConstraint()

    exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
                     loss_function=F.nll_loss, optimizer=optimizer,
                     model_constraint=model_constraint,
                     monitors=monitors,
                     stop_criterion=stop_criterion,
                     remember_best_column='valid_misclass',
                     run_after_early_stop=True, cuda=cuda)
    exp.run()
    return exp

In [None]:
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.DEBUG, stream=sys.stdout)
# Should contain both .gdf files and .mat-labelfiles from competition

data_folder = 'C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/'
subject_id = [1,2,3,4] # 1-9
low_cut_hz = 4 # 0 or 4
model = 'deep' #'shallow' or 'deep'
cuda = False
exp = run_exp(data_folder, subject_id, low_cut_hz, model, cuda)
log.info("Last 10 epochs")
log.info("\n" + str(exp.epochs_df.iloc[-10:]))

Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A01T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A01E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A02T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A02E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 662665  =      0.000 ...  2650.660 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A03T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  2642.116 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A03E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 648774  =      0.000 ...  2595.096 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A04T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 600914  =      0.000 ...  2403.656 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Extracting EDF parameters from C:/Users/Fariah Hayee/EEG/data/BCICIV_2a_gdf/A04E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 660046  =      0.000 ...  2640.184 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


2019-02-14 00:34:03,360 INFO : Trial per class:
Counter({'Tongue': 72, 'Foot': 72, 'Right Hand': 72, 'Left Hand': 72})
2019-02-14 00:34:04,951 INFO : Trial per class:
Counter({'Left Hand': 72, 'Right Hand': 72, 'Foot': 72, 'Tongue': 72})
2019-02-14 00:34:06,068 INFO : Model: 
Sequential(
  (dimshuffle): Expression(expression=_transpose_time_to_spat)
  (conv_time): Conv2d(1, 25, kernel_size=(10, 1), stride=(1, 1))
  (conv_spat): Conv2d(25, 25, kernel_size=(1, 22), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=elu)
  (pool): MaxPool2d(kernel_size=(3, 1), stride=(3, 1), padding=0, dilation=1, ceil_mode=False)
  (pool_nonlin): Expression(expression=identity)
  (drop_2): Dropout(p=0.5)
  (conv_2): Conv2d(25, 50, kernel_size=(10, 1), stride=(1, 1), bias=False)
  (bnorm_2): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (nonlin_2): Expression(expres

2019-02-14 00:39:01,498 INFO : Time only for training updates: 17.45s
2019-02-14 00:39:16,958 INFO : Epoch 9
2019-02-14 00:39:16,960 INFO : train_loss                2.79279
2019-02-14 00:39:16,961 INFO : valid_loss                3.38916
2019-02-14 00:39:16,962 INFO : test_loss                 2.85612
2019-02-14 00:39:16,964 INFO : train_misclass            0.72174
2019-02-14 00:39:16,965 INFO : valid_misclass            0.77586
2019-02-14 00:39:16,967 INFO : test_misclass             0.73958
2019-02-14 00:39:16,968 INFO : runtime                   33.04409
2019-02-14 00:39:16,969 INFO : 
2019-02-14 00:39:16,987 INFO : New best valid_misclass: 0.775862
2019-02-14 00:39:16,990 INFO : 
2019-02-14 00:39:34,362 INFO : Time only for training updates: 17.37s
2019-02-14 00:39:49,314 INFO : Epoch 10
2019-02-14 00:39:49,316 INFO : train_loss                2.66008
2019-02-14 00:39:49,317 INFO : valid_loss                3.25812
2019-02-14 00:39:49,318 INFO : test_loss                 2.72474
2

2019-02-14 00:45:45,395 INFO : Epoch 21
2019-02-14 00:45:45,397 INFO : train_loss                1.94165
2019-02-14 00:45:45,398 INFO : valid_loss                2.65721
2019-02-14 00:45:45,399 INFO : test_loss                 2.19269
2019-02-14 00:45:45,400 INFO : train_misclass            0.62609
2019-02-14 00:45:45,402 INFO : valid_misclass            0.72414
2019-02-14 00:45:45,403 INFO : test_misclass             0.67361
2019-02-14 00:45:45,404 INFO : runtime                   32.32350
2019-02-14 00:45:45,405 INFO : 
2019-02-14 00:45:45,426 INFO : New best valid_misclass: 0.724138
2019-02-14 00:45:45,427 INFO : 
2019-02-14 00:46:02,722 INFO : Time only for training updates: 17.29s
2019-02-14 00:46:17,649 INFO : Epoch 22
2019-02-14 00:46:17,650 INFO : train_loss                1.60962
2019-02-14 00:46:17,652 INFO : valid_loss                2.27105
2019-02-14 00:46:17,653 INFO : test_loss                 1.89374
2019-02-14 00:46:17,654 INFO : train_misclass            0.59130
2019-

2019-02-14 00:52:10,630 INFO : test_misclass             0.60069
2019-02-14 00:52:10,631 INFO : runtime                   32.30451
2019-02-14 00:52:10,632 INFO : 
2019-02-14 00:52:27,678 INFO : Time only for training updates: 17.04s
2019-02-14 00:52:42,566 INFO : Epoch 34
2019-02-14 00:52:42,568 INFO : train_loss                1.49401
2019-02-14 00:52:42,569 INFO : valid_loss                2.41008
2019-02-14 00:52:42,570 INFO : test_loss                 1.86776
2019-02-14 00:52:42,571 INFO : train_misclass            0.53913
2019-02-14 00:52:42,573 INFO : valid_misclass            0.70690
2019-02-14 00:52:42,574 INFO : test_misclass             0.57292
2019-02-14 00:52:42,575 INFO : runtime                   32.11862
2019-02-14 00:52:42,576 INFO : 
2019-02-14 00:52:42,592 INFO : New best valid_misclass: 0.706897
2019-02-14 00:52:42,594 INFO : 
2019-02-14 00:52:59,562 INFO : Time only for training updates: 16.97s
2019-02-14 00:53:14,688 INFO : Epoch 35
2019-02-14 00:53:14,690 INFO : t

2019-02-14 00:59:07,454 INFO : runtime                   32.02067
2019-02-14 00:59:07,455 INFO : 
2019-02-14 00:59:24,506 INFO : Time only for training updates: 17.05s
2019-02-14 00:59:39,529 INFO : Epoch 47
2019-02-14 00:59:39,531 INFO : train_loss                1.16847
2019-02-14 00:59:39,532 INFO : valid_loss                2.33677
2019-02-14 00:59:39,533 INFO : test_loss                 1.70048
2019-02-14 00:59:39,535 INFO : train_misclass            0.42609
2019-02-14 00:59:39,536 INFO : valid_misclass            0.62069
2019-02-14 00:59:39,538 INFO : test_misclass             0.52778
2019-02-14 00:59:39,539 INFO : runtime                   31.82379
2019-02-14 00:59:39,540 INFO : 
2019-02-14 00:59:39,557 INFO : New best valid_misclass: 0.620690
2019-02-14 00:59:39,558 INFO : 
2019-02-14 00:59:56,595 INFO : Time only for training updates: 17.03s
2019-02-14 01:00:11,619 INFO : Epoch 48
2019-02-14 01:00:11,622 INFO : train_loss                0.94672
2019-02-14 01:00:11,623 INFO : v