In [1]:
import logging
import os.path
import time
from collections import OrderedDict
import sys

import pandas as pd
import numpy as np
import torch.nn.functional as F
from torch import optim
import torch as th
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sn

from braindecode.models.deep4 import Deep4Net
from braindecode.models.util import to_dense_prediction_model
from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.experiments.experiment import Experiment
from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
    RuntimeMonitor, CroppedTrialMisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
from braindecode.torch_ext.util import set_random_seeds, np_to_var
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.datautil.signalproc import (bandpass_cnt,
                                             exponential_running_standardize)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne

# Notebook-wide logger; its output format, level and stream are set by the
# logging.basicConfig call in the configuration cell.
log = logging.getLogger(__name__)

# Running Shallow Model

In [2]:
# Send all log records (DEBUG and above) to stdout so they show up as cell output.
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.DEBUG, stream=sys.stdout)
# Should contain both .gdf files and .mat-labelfiles from competition.
# Set the BCICIV_2A_DATA_FOLDER environment variable to point the notebook at
# another location without editing this cell; the previous hard-coded path
# stays the default.
data_folder = os.environ.get('BCICIV_2A_DATA_FOLDER',
                             '/home/david/data/BCICIV_2a_gdf/')
subject_id = 1  # 1-9
low_cut_hz = 4  # 0 or 4
# Architecture name; NOTE: the model-building cell rebinds `model` to the
# constructed network, so re-run this cell before re-running that one.
model = 'shallow'  # 'shallow' or 'deep'
cuda = True

# Load the recordings, preprocess them, and build the model

In [3]:
# --- Epoching / training hyperparameters -----------------------------------
# Trial window handed to create_signal_target_from_raw_mne. Presumably in
# milliseconds relative to the trial marker: the resulting trials have 1125
# samples, which matches 4500 ms at the 250 Hz rate of the loaded recordings.
ival = [-500, 4000]
# Samples per network input; passed to the model and to the crop iterator.
input_time_length = 1125
# Stop criteria: hard epoch limit, and number of epochs handed to
# NoDecrease('valid_misclass', ...).
max_epochs = 800
max_increase_epochs = 80
batch_size = 60
# Upper edge of the band-pass filter (the lower edge is low_cut_hz).
high_cut_hz = 38
# Arguments forwarded to exponential_running_standardize.
factor_new = 1e-3
init_block_size = 1000
# Fraction of the training trials held out as validation set.
valid_set_fraction = 0.2

# --- File locations --------------------------------------------------------
train_filename = 'A{:02d}T.gdf'.format(subject_id)
test_filename = 'A{:02d}E.gdf'.format(subject_id)
train_filepath = os.path.join(data_folder, train_filename)
test_filepath = os.path.join(data_folder, test_filename)
# Swap only the file extension for the label files. A plain
# str.replace('.gdf', '.mat') would also rewrite a '.gdf' that happens to
# occur inside a directory name of data_folder.
train_label_filepath = os.path.splitext(train_filepath)[0] + '.mat'
test_label_filepath = os.path.splitext(test_filepath)[0] + '.mat'

# --- Load the continuous recordings ---------------------------------------
train_loader = BCICompetition4Set2A(
    train_filepath, labels_filename=train_label_filepath)
test_loader = BCICompetition4Set2A(
    test_filepath, labels_filename=test_label_filepath)
train_cnt = train_loader.load()
test_cnt = test_loader.load()

# Preprocessing

def _preprocess_cnt(cnt):
    """Run the shared preprocessing chain on one continuous recording.

    Steps, in order: drop the stimulus channel and the three EOG channels,
    scale the signal by 1e6, band-pass filter between ``low_cut_hz`` and
    ``high_cut_hz``, then apply exponential running standardization.
    The notebook-level settings ``low_cut_hz``, ``high_cut_hz``,
    ``factor_new`` and ``init_block_size`` are read when the function is
    called.

    Parameters
    ----------
    cnt
        Recording as returned by ``BCICompetition4Set2A.load()``; it must
        provide ``drop_channels``, ``ch_names`` and ``info['sfreq']``.

    Returns
    -------
    The preprocessed recording produced by the last ``mne_apply`` call.

    Raises
    ------
    AssertionError
        If the number of channels left after dropping is not 22.
    """
    cnt = cnt.drop_channels(['STI 014', 'EOG-left',
                             'EOG-central', 'EOG-right'])
    assert len(cnt.ch_names) == 22
    # Scale by 1e6 for numerical stability of the next operations. The
    # original note called this "millivolt"; a factor of 1e6 corresponds to
    # volts -> microvolts (assuming the loader returns volts; TODO confirm).
    cnt = mne_apply(lambda a: a * 1e6, cnt)
    sfreq = cnt.info['sfreq']
    # axis=1 is the filtered axis; presumably `a` is (channels, samples).
    cnt = mne_apply(
        lambda a: bandpass_cnt(a, low_cut_hz, high_cut_hz, sfreq,
                               filt_order=3,
                               axis=1), cnt)
    # The standardization is applied to the transposed array and the result
    # is transposed back.
    cnt = mne_apply(
        lambda a: exponential_running_standardize(a.T, factor_new=factor_new,
                                                  init_block_size=init_block_size,
                                                  eps=1e-4).T,
        cnt)
    return cnt


# Train and test recordings go through exactly the same steps.
train_cnt = _preprocess_cnt(train_cnt)
test_cnt = _preprocess_cnt(test_cnt)

# Class name -> list of event codes marking a trial of that class.
# Presumably the insertion order determines the integer class label (0-3)
# assigned by create_signal_target_from_raw_mne; confirm against that helper.
marker_def = OrderedDict()
marker_def['Left Hand'] = [1]
marker_def['Right Hand'] = [2]
marker_def['Foot'] = [3]
marker_def['Tongue'] = [4]

# Cut both continuous recordings into trials (train first, then test) using
# the same marker definition and trial window.
train_set, test_set = (
    create_signal_target_from_raw_mne(recording, marker_def, ival)
    for recording in (train_cnt, test_cnt)
)

# Hold out the trailing part of the training trials as validation set; the
# first set receives the complementary fraction.
train_set, valid_set = split_into_two_sets(
    train_set,
    first_set_fraction=1 - valid_set_fraction,
)

# Seed the random number generators before the network weights are created.
set_random_seeds(seed=20190706, cuda=cuda)

n_classes = 4
n_chans = int(train_set.X.shape[1])
# `model` holds the architecture name at this point and is rebound to the
# constructed network below.
if model == 'shallow':
    model = ShallowFBCSPNet(n_chans, n_classes, input_time_length=input_time_length,
                        final_conv_length=30).create_network()
elif model == 'deep':
    model = Deep4Net(n_chans, n_classes, input_time_length=input_time_length,
                        final_conv_length=2).create_network()
else:
    # Without this branch an unknown name, or a second run of this cell
    # (where `model` is already a network), would fall through silently and
    # to_dense_prediction_model would be applied to the wrong object or a
    # second time.
    raise ValueError(
        "model must be the string 'shallow' or 'deep', got {!r}; re-run the "
        "configuration cell if this cell was already executed".format(
            model if isinstance(model, str) else type(model)))


# Converts the network in place for dense (cropped) prediction.
to_dense_prediction_model(model)
if cuda:
    model.cuda()

Extracting EDF parameters from /home/david/data/BCICIV_2a_gdf/A01T.gdf...
GDF file detected
Overlapping events detected. Use find_edf_events for the original events.
Setting channel info structure...
Interpolating stim channel. Events may jitter.
Creating raw.info structure...
Channel names are not unique, found duplicates for: {'EEG'}. Applying running numbers for duplicates.
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


Extracting EDF parameters from /home/david/data/BCICIV_2a_gdf/A01E.gdf...
GDF file detected
Overlapping events detected. Use find_edf_events for the original events.
Setting channel info structure...
Interpolating stim channel. Events may jitter.
Creating raw.info structure...
Channel names are not unique, found duplicates for: {'EEG'}. Applying running numbers for duplicates.
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')
  raw_edf = mne.io.read_raw_edf(self.filename, stim_channel='auto')


2018-10-01 19:52:06,144 INFO : Trial per class:
Counter({'Tongue': 72, 'Foot': 72, 'Right Hand': 72, 'Left Hand': 72})
2018-10-01 19:52:06,377 INFO : Trial per class:
Counter({'Left Hand': 72, 'Right Hand': 72, 'Foot': 72, 'Tongue': 72})


In [4]:
# Array shape of the training inputs that remain after the train/validation split.
train_set.X.shape

(230, 22, 1125)

In [5]:
# Array shape of the test inputs (the test set is not split further).
test_set.X.shape

(288, 22, 1125)

In [6]:
# Write the network architecture to the log.
log.info("Model: \n{:s}".format(str(model)))

# Push a single training trial through the network to find out how many
# predictions it emits per input (size of output dimension 2).
first_trial = train_set.X[:1, :, :, None]
dummy_input = np_to_var(first_trial)
if cuda:
    dummy_input = dummy_input.cuda()
out = model(dummy_input)
out_shape = out.cpu().data.numpy().shape
n_preds_per_input = out_shape[2]

optimizer = optim.Adam(model.parameters())

iterator = CropsFromTrialsIterator(
    batch_size=batch_size,
    input_time_length=input_time_length,
    n_preds_per_input=n_preds_per_input,
)

# Stop at max_epochs, or via NoDecrease on 'valid_misclass' with
# max_increase_epochs, whichever triggers first.
stop_criterion = Or([
    MaxEpochs(max_epochs),
    NoDecrease('valid_misclass', max_increase_epochs),
])

monitors = [
    LossMonitor(),
    MisclassMonitor(col_suffix='sample_misclass'),
    CroppedTrialMisclassMonitor(input_time_length=input_time_length),
    RuntimeMonitor(),
]

model_constraint = MaxNormDefaultConstraint()


def loss_function(preds, targets):
    """Negative log-likelihood of the predictions averaged over dimension 2."""
    mean_preds = th.mean(preds, dim=2, keepdim=False)
    return F.nll_loss(mean_preds, targets)


experiment_kwargs = dict(
    iterator=iterator,
    loss_function=loss_function,
    optimizer=optimizer,
    model_constraint=model_constraint,
    monitors=monitors,
    stop_criterion=stop_criterion,
    remember_best_column='valid_misclass',
    run_after_early_stop=True,
    cuda=cuda,
)
exp = Experiment(model, train_set, valid_set, test_set, **experiment_kwargs)

2018-10-01 19:52:08,153 INFO : Model: 
Sequential(
  (dimshuffle): Expression(expression=_transpose_time_to_spat)
  (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
  (conv_spat): Conv2d(40, 40, kernel_size=(1, 22), stride=(1, 1), bias=False)
  (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv_nonlin): Expression(expression=square)
  (pool): AvgPool2d(kernel_size=(75, 1), stride=(1, 1), padding=0)
  (pool_nonlin): Expression(expression=safe_log)
  (drop): Dropout(p=0.5)
  (conv_classifier): Conv2d(40, 4, kernel_size=(30, 1), stride=(1, 1), dilation=(15, 1))
  (softmax): LogSoftmax()
  (squeeze): Expression(expression=_squeeze_final_output)
)


In [7]:
# Shape of the test inputs as stored on the experiment (exp.datasets['test']).
print("test feature_vars shape")
print(exp.datasets['test'].X.shape)

test feature_vars shape
(288, 22, 1125)


In [8]:
# Print every integer class target of the test set held by the experiment.
print("test target_vars")
print(exp.datasets['test'].y)

test target_vars
[0 1 1 0 1 0 1 2 1 3 0 2 1 0 3 3 3 3 3 0 2 1 0 0 2 3 0 2 2 2 0 1 0 1 1 0 1
 2 1 2 2 3 2 2 3 3 3 3 3 2 1 0 0 1 2 3 1 2 0 0 0 3 1 1 0 0 2 0 1 3 3 2 0 3
 3 1 3 3 1 0 1 2 2 2 3 2 0 3 1 2 1 2 3 1 2 0 0 0 3 1 0 2 0 2 1 3 0 2 2 0 2
 1 3 3 3 2 0 3 1 3 1 0 2 1 0 2 2 0 2 3 3 1 0 1 3 1 3 2 1 1 1 2 3 0 1 3 0 2
 2 3 0 0 2 1 3 3 3 1 0 2 1 3 0 3 2 1 3 3 0 1 1 2 3 1 0 0 3 1 0 2 1 1 2 0 3
 2 2 2 2 0 1 0 1 0 0 2 2 1 2 3 0 3 0 0 1 3 2 1 3 2 3 2 3 1 1 3 0 1 1 1 2 3
 0 3 0 2 0 3 0 2 0 1 2 2 3 0 1 3 1 2 2 0 3 1 3 0 0 2 2 1 3 1 1 0 1 3 3 1 1
 1 1 3 3 2 3 0 1 2 1 0 3 0 3 0 0 0 0 2 2 3 1 2 2 2 3 2 0 2]


In [9]:
# Number of test targets; one entry per test trial is expected.
print("test target_vars shape:", exp.datasets['test'].y.shape)

test target_vars shape: (288,)


In [10]:
# Echo the crop-iterator settings that the experiment was constructed with.
print("batch_size:", exp.iterator.batch_size)
print("input_time_length:", exp.iterator.input_time_length)
print("n_preds_per_input:", exp.iterator.n_preds_per_input)

batch_size: 60
input_time_length: 1125
n_preds_per_input: 592


In [11]:
# Receptive field size: input samples needed for one prediction, derived from
# the input length and the number of predictions per input.
crop_iterator = exp.iterator
test_trials = exp.datasets['test'].X
n_receptive_field = (crop_iterator.input_time_length
                     - crop_iterator.n_preds_per_input) + 1

# One start index per test trial, all equal to n_receptive_field - 1.
i_trial_starts = [n_receptive_field - 1 for _trial in test_trials]
# Stop index per trial: the length of the trial's second axis.
i_trial_stops = [single_trial.shape[1] for single_trial in test_trials]

print("n_receptive_field:", n_receptive_field)

n_receptive_field: 534


In [12]:
# Collect the ground-truth targets of every test batch, in iteration order
# (shuffle=False), after a round trip numpy -> tensor -> (GPU) -> numpy.
all_actual = []
dataset = 'test'
for batch in exp.iterator.get_batches(exp.datasets[dataset], shuffle=False):

    print("feature_vars shape:", np.array(batch[0]).shape)
    print("target_vars shape:", np.array(batch[1]).shape)

    feature_vars = np_to_var(batch[0], pin_memory=exp.pin_memory)
    target_vars = np_to_var(batch[1], pin_memory=exp.pin_memory)

    if exp.cuda:
        feature_vars = feature_vars.cuda()
        target_vars = target_vars.cuda()

    if hasattr(target_vars, 'cpu'):
        target_vars = target_vars.cpu().data.numpy()
    else:
        # assume it is iterable
        target_vars = [o.cpu().data.numpy() for o in target_vars]

    # np.asarray leaves an ndarray untouched and also accepts the list built
    # by the fallback branch; calling .tolist() on that list directly would
    # raise AttributeError.
    all_actual += np.asarray(target_vars).tolist()

feature_vars shape: (58, 22, 1125, 1)
target_vars shape: (58,)
feature_vars shape: (58, 22, 1125, 1)
target_vars shape: (58,)
feature_vars shape: (58, 22, 1125, 1)
target_vars shape: (58,)
feature_vars shape: (57, 22, 1125, 1)
target_vars shape: (57,)
feature_vars shape: (57, 22, 1125, 1)
target_vars shape: (57,)


In [13]:
# View the targets gathered from the test batches as a single array.
np.array(all_actual)

array([0, 1, 1, 0, 1, 0, 1, 2, 1, 3, 0, 2, 1, 0, 3, 3, 3, 3, 3, 0, 2, 1,
       0, 0, 2, 3, 0, 2, 2, 2, 0, 1, 0, 1, 1, 0, 1, 2, 1, 2, 2, 3, 2, 2,
       3, 3, 3, 3, 3, 2, 1, 0, 0, 1, 2, 3, 1, 2, 0, 0, 0, 3, 1, 1, 0, 0,
       2, 0, 1, 3, 3, 2, 0, 3, 3, 1, 3, 3, 1, 0, 1, 2, 2, 2, 3, 2, 0, 3,
       1, 2, 1, 2, 3, 1, 2, 0, 0, 0, 3, 1, 0, 2, 0, 2, 1, 3, 0, 2, 2, 0,
       2, 1, 3, 3, 3, 2, 0, 3, 1, 3, 1, 0, 2, 1, 0, 2, 2, 0, 2, 3, 3, 1,
       0, 1, 3, 1, 3, 2, 1, 1, 1, 2, 3, 0, 1, 3, 0, 2, 2, 3, 0, 0, 2, 1,
       3, 3, 3, 1, 0, 2, 1, 3, 0, 3, 2, 1, 3, 3, 0, 1, 1, 2, 3, 1, 0, 0,
       3, 1, 0, 2, 1, 1, 2, 0, 3, 2, 2, 2, 2, 0, 1, 0, 1, 0, 0, 2, 2, 1,
       2, 3, 0, 3, 0, 0, 1, 3, 2, 1, 3, 2, 3, 2, 3, 1, 1, 3, 0, 1, 1, 1,
       2, 3, 0, 3, 0, 2, 0, 3, 0, 2, 0, 1, 2, 2, 3, 0, 1, 3, 1, 2, 2, 0,
       3, 1, 3, 0, 0, 2, 2, 1, 3, 1, 1, 0, 1, 3, 3, 1, 1, 1, 1, 3, 3, 2,
       3, 0, 1, 2, 1, 0, 3, 0, 3, 0, 0, 0, 0, 2, 2, 3, 1, 2, 2, 2, 3, 2,
       0, 2])

In [14]:
# Total number of targets gathered across all test batches.
np.array(all_actual).shape

(288,)

In [15]:
# Prepare the experiment, then run a single epoch over exp.datasets. The log
# output of this cell reports training-update time followed by the epoch's
# loss/misclass values for train, valid and test.
# remember_best=False: presumably skips best-model bookkeeping (TODO confirm).
exp.setup_training()
exp.run_one_epoch(exp.datasets, remember_best=False)

2018-10-01 19:52:09,110 INFO : Time only for training updates: 0.11s
2018-10-01 19:52:09,284 INFO : Epoch 0
2018-10-01 19:52:09,284 INFO : train_loss                6.16400
2018-10-01 19:52:09,284 INFO : valid_loss                6.76005
2018-10-01 19:52:09,285 INFO : test_loss                 6.17513
2018-10-01 19:52:09,285 INFO : train_sample_misclass     0.74288
2018-10-01 19:52:09,285 INFO : valid_sample_misclass     0.77633
2018-10-01 19:52:09,286 INFO : test_sample_misclass      0.74974
2018-10-01 19:52:09,286 INFO : train_misclass            0.74348
2018-10-01 19:52:09,286 INFO : valid_misclass            0.77586
2018-10-01 19:52:09,287 INFO : test_misclass             0.75000
2018-10-01 19:52:09,287 INFO : runtime                   0.00000
2018-10-01 19:52:09,287 INFO : 


In [16]:
# Walk over the training batches in dataset order and record their targets.
batch_generator = exp.iterator.get_batches(
    exp.datasets['train'], shuffle=False)
start_train_epoch_time = time.time()

all_targets = []
for inputs, targets in batch_generator:
    # Apply the experiment's batch modifier when one is configured.
    modifier = exp.batch_modifier
    if modifier is not None:
        inputs, targets = modifier.process(inputs, targets)
    # Only non-empty batches contribute targets.
    if len(inputs):
        all_targets.extend(targets.tolist())
    print("inputs shape", inputs.shape, "targets", targets, "==========",
          sep="\n")

inputs shape
(58, 22, 1125, 1)
targets
[3 2 1 0 0 1 2 3 1 2 0 0 0 3 1 1 0 0 2 0 1 3 3 2 0 3 3 1 3 3 1 0 1 2 2 2 3
 2 0 3 1 2 1 2 3 1 2 0 0 0 3 1 0 2 0 2 1 3]
inputs shape
(58, 22, 1125, 1)
targets
[0 2 2 0 2 1 3 3 3 2 0 3 1 3 1 0 2 1 0 2 2 0 2 3 3 1 0 1 3 1 3 2 1 1 1 2 3
 0 1 3 0 2 2 3 0 0 2 1 3 3 3 1 0 2 1 3 0 3]
inputs shape
(57, 22, 1125, 1)
targets
[2 1 3 3 0 1 1 2 3 1 0 0 3 1 0 2 1 1 2 0 3 2 2 2 2 0 1 0 1 0 0 2 2 1 2 3 0
 3 0 0 1 3 2 1 3 2 3 2 3 1 1 3 0 1 1 1 2]
inputs shape
(57, 22, 1125, 1)
targets
[3 0 3 0 2 0 3 0 2 0 1 2 2 3 0 1 3 1 2 2 0 3 1 3 0 0 2 2 1 3 1 1 0 1 3 3 1
 1 1 1 3 3 2 3 0 1 2 1 0 3 0 3 0 0 0 0 2]


In [17]:
# View the targets gathered from the training batches as a single array.
np.array(all_targets)

array([3, 2, 1, 0, 0, 1, 2, 3, 1, 2, 0, 0, 0, 3, 1, 1, 0, 0, 2, 0, 1, 3,
       3, 2, 0, 3, 3, 1, 3, 3, 1, 0, 1, 2, 2, 2, 3, 2, 0, 3, 1, 2, 1, 2,
       3, 1, 2, 0, 0, 0, 3, 1, 0, 2, 0, 2, 1, 3, 0, 2, 2, 0, 2, 1, 3, 3,
       3, 2, 0, 3, 1, 3, 1, 0, 2, 1, 0, 2, 2, 0, 2, 3, 3, 1, 0, 1, 3, 1,
       3, 2, 1, 1, 1, 2, 3, 0, 1, 3, 0, 2, 2, 3, 0, 0, 2, 1, 3, 3, 3, 1,
       0, 2, 1, 3, 0, 3, 2, 1, 3, 3, 0, 1, 1, 2, 3, 1, 0, 0, 3, 1, 0, 2,
       1, 1, 2, 0, 3, 2, 2, 2, 2, 0, 1, 0, 1, 0, 0, 2, 2, 1, 2, 3, 0, 3,
       0, 0, 1, 3, 2, 1, 3, 2, 3, 2, 3, 1, 1, 3, 0, 1, 1, 1, 2, 3, 0, 3,
       0, 2, 0, 3, 0, 2, 0, 1, 2, 2, 3, 0, 1, 3, 1, 2, 2, 0, 3, 1, 3, 0,
       0, 2, 2, 1, 3, 1, 1, 0, 1, 3, 3, 1, 1, 1, 1, 3, 3, 2, 3, 0, 1, 2,
       1, 0, 3, 0, 3, 0, 0, 0, 0, 2])

In [18]:
# Total number of targets gathered across all training batches.
np.array(all_targets).shape

(230,)

In [19]:
# Targets stored on the experiment's training set, for comparison with the
# targets gathered batch by batch.
exp.datasets['train'].y

array([3, 2, 1, 0, 0, 1, 2, 3, 1, 2, 0, 0, 0, 3, 1, 1, 0, 0, 2, 0, 1, 3,
       3, 2, 0, 3, 3, 1, 3, 3, 1, 0, 1, 2, 2, 2, 3, 2, 0, 3, 1, 2, 1, 2,
       3, 1, 2, 0, 0, 0, 3, 1, 0, 2, 0, 2, 1, 3, 0, 2, 2, 0, 2, 1, 3, 3,
       3, 2, 0, 3, 1, 3, 1, 0, 2, 1, 0, 2, 2, 0, 2, 3, 3, 1, 0, 1, 3, 1,
       3, 2, 1, 1, 1, 2, 3, 0, 1, 3, 0, 2, 2, 3, 0, 0, 2, 1, 3, 3, 3, 1,
       0, 2, 1, 3, 0, 3, 2, 1, 3, 3, 0, 1, 1, 2, 3, 1, 0, 0, 3, 1, 0, 2,
       1, 1, 2, 0, 3, 2, 2, 2, 2, 0, 1, 0, 1, 0, 0, 2, 2, 1, 2, 3, 0, 3,
       0, 0, 1, 3, 2, 1, 3, 2, 3, 2, 3, 1, 1, 3, 0, 1, 1, 1, 2, 3, 0, 3,
       0, 2, 0, 3, 0, 2, 0, 1, 2, 2, 3, 0, 1, 3, 1, 2, 2, 0, 3, 1, 3, 0,
       0, 2, 2, 1, 3, 1, 1, 0, 1, 3, 3, 1, 1, 1, 1, 3, 3, 2, 3, 0, 1, 2,
       1, 0, 3, 0, 3, 0, 0, 0, 0, 2])

In [20]:
# Number of targets stored on the experiment's training set.
exp.datasets['train'].y.shape

(230,)