In [2]:
import mne
from mne.io import concatenate_raws

# 5,6,7,10,13,14 are codes for executed and imagined hands/feet
subject_id = 1
event_codes = [5, 6, 9, 10, 13, 14]

# This will download the files if you don't have them yet,
# and then return the paths to the files.
physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)

# Load each of the files
parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto',
                             verbose='WARNING')
         for path in physionet_paths]

# Concatenate them
raw = concatenate_raws(parts)

# Find the events in this dataset
events, _ = mne.events_from_annotations(raw)

# Use only EEG channels
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False,
                                  eog=False,
                                  exclude='bads')

# Extract trials, only using EEG channels
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1,
                     proj=False, picks=eeg_channel_inds,
                     baseline=None, preload=True)

Used Annotations descriptions: ['T0', 'T1', 'T2']
90 matching events found
No baseline correction applied
Not setting metadata
Loading data for 90 events and 497 original time points ...
0 bad epochs dropped


In [15]:
import numpy as np
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.datautil.splitters import split_into_two_sets
# Convert data from volt to millivolt
# Pytorch expects float32 for input and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1

train_set = SignalAndTarget(X[:60], y=y[:60])
test_set = SignalAndTarget(X[60:], y=y[60:])

# train_set, valid_set = split_into_two_sets(train_set,
#                                            first_set_fraction=0.8)
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds
from braindecode.models.util import to_dense_prediction_model

# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)

# This will determine how many crops are processed in parallel
input_time_length = 450
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                        input_time_length=input_time_length,
                        final_conv_length=12).create_network()
to_dense_prediction_model(model)

if cuda:
    model.cuda()


In [6]:
from torch import optim

optimizer = optim.Adam(model.parameters())

from braindecode.torch_ext.util import np_to_var
# determine output size
test_input = np_to_var(
    np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))


187 predictions per input/trial


In [7]:
from braindecode.experiments.experiment import Experiment
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, \
    CroppedTrialMisclassMonitor, MisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th
from braindecode.torch_ext.modules import Expression

In [25]:
# Iterator is used to iterate over datasets both for training
# and evaluation
iterator = CropsFromTrialsIterator(batch_size=32,
                                   input_time_length=input_time_length,
                                   n_preds_per_input=n_preds_per_input)

# Loss function takes predictions as they come out of the network and the targets
# and returns a loss
class LossFunction:
    def __call__(self, preds, targets):
        return F.nll_loss(th.mean(preds, dim=2, keepdim=False), targets)
# loss_function = lambda preds, targets: F.nll_loss(
#     th.mean(preds, dim=2, keepdim=False), targets)

In [None]:
# Could be used to apply some constraint on the models, then should be object
# with apply method that accepts a module
model_constraint = None
# Monitors log the training progress
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length),
            RuntimeMonitor(), ]
# Stop criterion determines when the first stop happens
stop_criterion = MaxEpochs(4)
exp = Experiment(model, train_set, valid_set, test_set, iterator,
                 loss_function, optimizer, model_constraint,
                 monitors, stop_criterion,
                 remember_best_column='valid_misclass',
                 run_after_early_stop=True, batch_modifier=None, cuda=cuda)


In [16]:
from skorch.classifier import NeuralNet
from skorch.dataset import CVSplit

In [12]:
from torch.utils.data import Dataset

In [50]:
class EEGDataSet(Dataset):
    def __init__(self, X, y):
        self.X = X.reshape(-1, 64, 497, 1)
        self.y = y
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [54]:
from skorch.callbacks import EpochScoring

In [51]:
train_set = EEGDataSet(train_set.X, train_set.y)
valid_set = EEGDataSet(valid_set.X, valid_set.y)
test_set = EEGDataSet(test_set.X, test_set.y)

In [None]:
def my_score(net, X=None, y=None):
    losses = net.history[-1, 'batches', :, 'my_score']
    batch_sizes = net.history[-1, 'batches', :, 'valid_batch_size']
    return np.average(losses, weights=batch_sizes)

In [68]:
skorch_model = NeuralNet(model, LossFunction, optim.Adam, train_split=CVSplit(cv=0.5))
#, callbacks=[('accuracy', EpochScoring('accuracy', lower_is_better=False))])

In [69]:
skorch_model.fit(train_set)

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1        [36m1.6107[0m        [32m1.1428[0m  0.8415
      2        [36m0.7609[0m        1.9469  0.7551
      3        0.8753        2.3047  0.6911
      4        0.7719        2.5271  0.7515
      5        [36m0.5634[0m        2.3662  0.8819
      6        [36m0.4354[0m        2.0589  0.6537
      7        [36m0.3999[0m        1.7215  0.7479
      8        [36m0.3861[0m        1.4272  0.7418
      9        [36m0.3513[0m        1.2497  0.8589
     10        [36m0.2934[0m        1.1590  0.7061


<class 'skorch.net.NeuralNet'>[initialized](
  module_=Sequential(
    (dimshuffle): Expression(expression=_transpose_time_to_spat)
    (conv_time): Conv2d(1, 40, kernel_size=(25, 1), stride=(1, 1))
    (conv_spat): Conv2d(40, 40, kernel_size=(1, 64), stride=(1, 1), bias=False)
    (bnorm): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv_nonlin): Expression(expression=square)
    (pool): AvgPool2d(kernel_size=(75, 1), stride=(1, 1), padding=0)
    (pool_nonlin): Expression(expression=safe_log)
    (drop): Dropout(p=0.5, inplace=False)
    (conv_classifier): Conv2d(40, 2, kernel_size=(12, 1), stride=(1, 1), dilation=(15, 1))
    (softmax): LogSoftmax()
    (squeeze): Expression(expression=_squeeze_final_output)
  ),
)

In [103]:
F.softmax(torch.Tensor(np.mean(skorch_model.predict(test_set.X), axis=2, keepdims=False)), dim=1).numpy()

array([[9.63512123e-01, 3.64879258e-02],
       [5.23885548e-01, 4.76114452e-01],
       [9.90216255e-01, 9.78377461e-03],
       [7.01407790e-01, 2.98592210e-01],
       [9.64998662e-01, 3.50013748e-02],
       [9.72419262e-01, 2.75806896e-02],
       [9.99376595e-01, 6.23407948e-04],
       [9.57187772e-01, 4.28122766e-02],
       [9.91731703e-01, 8.26831348e-03],
       [8.05356026e-01, 1.94643915e-01],
       [9.55760002e-01, 4.42399830e-02],
       [9.43263650e-01, 5.67363240e-02],
       [9.64158475e-01, 3.58414575e-02],
       [9.88989055e-01, 1.10110017e-02],
       [3.28940988e-01, 6.71059012e-01],
       [1.61786646e-01, 8.38213384e-01],
       [6.85525835e-01, 3.14474195e-01],
       [3.44549865e-01, 6.55450106e-01],
       [9.95098531e-01, 4.90145013e-03],
       [9.56322134e-01, 4.36779149e-02],
       [8.98032606e-01, 1.01967342e-01],
       [9.87291157e-01, 1.27088986e-02],
       [9.64792788e-01, 3.52071822e-02],
       [8.82885337e-01, 1.17114596e-01],
       [9.487329

In [102]:
F.softmax(torch.Tensor(np.mean(skorch_model.predict(test_set.X), axis=2, keepdims=False)), dim=1).numpy()

array([[9.63512123e-01, 3.64879258e-02],
       [5.23885548e-01, 4.76114452e-01],
       [9.90216255e-01, 9.78377461e-03],
       [7.01407790e-01, 2.98592210e-01],
       [9.64998662e-01, 3.50013748e-02],
       [9.72419262e-01, 2.75806896e-02],
       [9.99376595e-01, 6.23407948e-04],
       [9.57187772e-01, 4.28122766e-02],
       [9.91731703e-01, 8.26831348e-03],
       [8.05356026e-01, 1.94643915e-01],
       [9.55760002e-01, 4.42399830e-02],
       [9.43263650e-01, 5.67363240e-02],
       [9.64158475e-01, 3.58414575e-02],
       [9.88989055e-01, 1.10110017e-02],
       [3.28940988e-01, 6.71059012e-01],
       [1.61786646e-01, 8.38213384e-01],
       [6.85525835e-01, 3.14474195e-01],
       [3.44549865e-01, 6.55450106e-01],
       [9.95098531e-01, 4.90145013e-03],
       [9.56322134e-01, 4.36779149e-02],
       [8.98032606e-01, 1.01967342e-01],
       [9.87291157e-01, 1.27088986e-02],
       [9.64792788e-01, 3.52071822e-02],
       [8.82885337e-01, 1.17114596e-01],
       [9.487329

In [96]:
F.softmax(torch.Tensor(np.mean(skorch_model.predict(test_set.X), axis=2, keepdims=False)[:, 0]))

  """Entry point for launching an IPython kernel.


tensor([0.0398, 0.0213, 0.0411, 0.0212, 0.0400, 0.0401, 0.0415, 0.0368, 0.0412,
        0.0290, 0.0375, 0.0377, 0.0400, 0.0410, 0.0106, 0.0065, 0.0239, 0.0131,
        0.0404, 0.0395, 0.0351, 0.0409, 0.0397, 0.0334, 0.0360, 0.0411, 0.0403,
        0.0263, 0.0404, 0.0247])

In [91]:
F.sigmoid(torch.Tensor([1]))



tensor([0.7311])

In [67]:
skorch_model.history[-1]

{'batches': [{'train_loss': 1.6362584829330444, 'train_batch_size': 30},
  {'valid_loss': 1.9583241939544678, 'valid_batch_size': 30}],
 'epoch': 1,
 'train_batch_count': 1,
 'valid_batch_count': 1,
 'dur': 0.7669932842254639,
 'train_loss': 1.6362584829330444,
 'train_loss_best': True,
 'valid_loss': 1.9583241939544678,
 'valid_loss_best': True}