In [1]:
%load_ext autoreload
%autoreload 2
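import os
import sys
# Make a local braindecode checkout importable; adjust this path to your own checkout
# (or drop the line if braindecode is installed as a package).
sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')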

# Using the Experiment Class

Braindecode provides a convenience `Experiment` class that removes the need to write your own training loop. It expects a training, a validation and a test set, and trains as follows:
1. Train on the training set until a given stop criterion is fulfilled.
2. Reset to the best epoch, i.e. reset the parameters of the model and the optimizer to their state at the best epoch ("best" according to a given criterion).
3. Continue training on the combined training + validation set until the loss on the validation set is as low as the training-set loss was at the best epoch (or until the ConvNet has been trained for twice as many epochs as the best epoch, to prevent infinite training).
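The three steps can be sketched in plain Python as below. This is a simplified model of the schedule, not braindecode's implementation: the function names are hypothetical, actual training is replaced by made-up per-epoch numbers, and the way the step-3 limit is counted (stop once the total epoch count reaches twice the best epoch) is an assumption.

```python
from typing import Sequence


def select_best_epoch(valid_misclass: Sequence[float]) -> int:
    """Step 2: index of the epoch with the lowest validation misclassification.
    In this sketch, ties go to the earliest such epoch."""
    return min(range(len(valid_misclass)), key=lambda i: valid_misclass[i])


def run_second_phase(best_epoch: int, loss_to_reach: float,
                     valid_losses: Sequence[float]) -> int:
    """Step 3: keep training on train + valid, starting from the best epoch.

    valid_losses[k] is the validation loss after the k-th extra epoch.
    Stop once it is at or below loss_to_reach (the training loss at the best
    epoch) or once the total epoch count reaches twice the best epoch.
    Returns the epoch at which training stopped (or the last simulated epoch)."""
    epoch = best_epoch
    for valid_loss in valid_losses:
        epoch += 1
        if valid_loss <= loss_to_reach or epoch >= 2 * best_epoch:
            break
    return epoch


# Made-up per-epoch numbers for the first phase (step 1 ran for 5 epochs)
phase1_valid_misclass = [0.42, 0.33, 0.30, 0.25, 0.30]
phase1_train_loss = [0.98, 0.60, 0.45, 0.35, 0.30]

best = select_best_epoch(phase1_valid_misclass)
stopped_at = run_second_phase(best, phase1_train_loss[best], [0.50, 0.34, 0.30])
print(best, stopped_at)  # 3 5
```

Here epoch 3 is best, so the target loss is 0.35; the second phase stops at epoch 5, the first epoch whose validation loss (0.34) reaches it.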

<div class='alert alert-warning'>

It is not necessary to use the Experiment class to use other functionality of Braindecode. Feel free to ignore it :)

</div>

## Load data

In [2]:
%%capture
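import mne
from mne.io import concatenate_raws

# Subject 1, runs 5, 6, 9, 10, 13, 14: executed and imagined movements of both hands vs. both feet
physionet_paths = mne.datasets.eegbci.load_data(1, [5, 6, 9, 10, 13, 14])

parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto')
         for path in physionet_paths]

# Concatenate the runs into one continuous recording
raw = concatenate_raws(parts)
# Use only the EEG channels, skipping channels marked as bad
eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                                  exclude='bads')
events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

# Cut trials from 1 s to 4.1 s after each cue; event code 2 = hands, 3 = feet
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1,
                     proj=False, picks=eeg_channel_inds,
                     baseline=None, preload=True)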
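To get a feel for the trial length, the arithmetic below converts `tmin`/`tmax` into samples. It assumes a sampling rate of 160 Hz for this dataset and MNE's convention of including both window endpoints; `n_times` is only an illustrative name. Under those assumptions each trial in `epoched` has 497 time samples, a bit more than the 450-sample crops used for the model.

```python
# Assumption: the PhysioNet EEGBCI recordings are sampled at 160 Hz.
sfreq = 160.0
tmin, tmax = 1.0, 4.1  # seconds relative to each cue, as passed to mne.Epochs above

# MNE epochs include both the tmin and the tmax sample, so the window runs
# from sample round(tmin * sfreq) to sample round(tmax * sfreq), inclusive.
first_sample = int(round(tmin * sfreq))
last_sample = int(round(tmax * sfreq))
n_times = last_sample - first_sample + 1
print(n_times, "samples per trial")  # 497 samples per trial
```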

## Convert data to Braindecode Format

In [3]:
import numpy as np
# Convert data from volts to microvolts.
# PyTorch expects float32 for inputs and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
y = (epoched.events[:,2] - 2).astype(np.int64) # 2, 3 -> 0, 1

from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.datautil.splitters import split_into_two_sets

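# First 60 trials for training (and validation), the remaining trials for testing
train_set = SignalAndTarget(X[:60], y=y[:60])
test_set = SignalAndTarget(X[60:], y=y[60:])

# Put 80% of the training trials into the final training set and 20% into the validation set
train_set, valid_set = split_into_two_sets(train_set, first_set_fraction=0.8)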
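The conversion and splitting above can be mimicked on dummy data. The sketch assumes 90 trials in total and that `split_into_two_sets` keeps the first 80% of trials, in their original order, for the first set; `Trials` and `sequential_split` are hypothetical stand-ins, not braindecode classes. The resulting sizes of 48, 12 and 30 trials are consistent with the misclassification rates in the log below (0.5 = 24/48, 0.41667 = 5/12, 0.43333 = 13/30), and int64 class indices 0 and 1 are the target format PyTorch's `nll_loss` expects.

```python
import numpy as np
from dataclasses import dataclass


@dataclass
class Trials:
    """Hypothetical stand-in for braindecode's SignalAndTarget."""
    X: np.ndarray
    y: np.ndarray


def sequential_split(dataset, first_set_fraction):
    """Hypothetical version of split_into_two_sets: assumes the first
    fraction of trials (in order, unshuffled) forms the first set."""
    n_first = int(round(len(dataset.X) * first_set_fraction))
    return (Trials(dataset.X[:n_first], dataset.y[:n_first]),
            Trials(dataset.X[n_first:], dataset.y[n_first:]))


n_trials = 90  # assumed total number of hands/feet trials
# Dummy signals in volts (tiny shape for illustration) and event codes 2/3
data_volts = np.full((n_trials, 2, 5), 2e-5)
event_codes = np.tile([2, 3], n_trials // 2)

X = (data_volts * 1e6).astype(np.float32)   # volts -> microvolts, float32 for PyTorch
y = (event_codes - 2).astype(np.int64)      # 2, 3 -> class indices 0, 1

train_set = Trials(X[:60], y[:60])
test_set = Trials(X[60:], y[60:])
train_set, valid_set = sequential_split(train_set, first_set_fraction=0.8)
print(len(train_set.X), len(valid_set.X), len(test_set.X))  # 48 12 30
```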


## Create the model

In [4]:
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from torch import nn
from braindecode.torch_ext.util import set_random_seeds, to_dense_prediction_model

# Set if you want to use GPU
# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)

# This will determine how many crops are processed in parallel
input_time_length = 450
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=64, n_classes=2, input_time_length=input_time_length,
                        final_conv_length=12).create_network()
to_dense_prediction_model(model)

if cuda:
    model.cuda()

In [5]:
from torch import optim

optimizer = optim.Adam(model.parameters())

In [6]:
from braindecode.torch_ext.util import np_to_var
# Determine the output size by passing a dummy batch of 2 inputs through the model
test_input = np_to_var(np.ones((2, 64, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))

187 predictions per input/trial
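The printed number also reveals the receptive field of the network. Assuming that after `to_dense_prediction_model` consecutive outputs are exactly one input sample apart, each output sees `receptive_field` consecutive samples, and the count follows from sliding that window over the crop. `receptive_field` is an illustrative name, and the 160 Hz rate used for the conversion to seconds is an assumption about the dataset.

```python
input_time_length = 450   # samples per crop, as set above
n_preds_per_input = 187   # as printed above for this model

# With one prediction per one-sample shift, a crop of input_time_length samples
# yields input_time_length - receptive_field + 1 predictions, so:
receptive_field = input_time_length - n_preds_per_input + 1
print(receptive_field)                    # 264
# Assumption: 160 Hz sampling rate
print(f"{receptive_field / 160:.2f} s")   # 1.65 s
```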


In [7]:
from braindecode.datautil.iterators import CropsFromTrialsIterator
iterator = CropsFromTrialsIterator(batch_size=32,input_time_length=input_time_length,
                                  n_preds_per_input=n_preds_per_input)
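`CropsFromTrialsIterator` cuts each trial into overlapping crops of `input_time_length` samples so that every time step after the first receptive field receives a prediction. The hypothetical helper below sketches one way to place such crops (braindecode's exact placement may differ): step forward by `n_preds_per_input` samples and align the last crop with the end of the trial. With 497-sample trials (assumed from the epoch window at 160 Hz), two crops per trial suffice.

```python
import math


def crop_starts(n_times, input_time_length, n_preds_per_input):
    """Hypothetical helper: start indices of the crops needed so that every
    sample from receptive_field - 1 onward receives a prediction."""
    receptive_field = input_time_length - n_preds_per_input + 1
    # Samples that can be predicted: from receptive_field - 1 to n_times - 1
    n_predictable = n_times - receptive_field + 1
    n_crops = math.ceil(n_predictable / n_preds_per_input)
    starts = [i * n_preds_per_input for i in range(n_crops)]
    # Shift the last crop back so it ends exactly at the end of the trial
    starts[-1] = n_times - input_time_length
    return starts


print(crop_starts(497, 450, 187))  # [0, 47]
```

The two crops predict samples 263 to 449 and 310 to 496, so the overlapping predictions for samples 310 to 449 have to be handled when results are combined per trial.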

In [8]:
from braindecode.experiments.experiment import Experiment

In [9]:
from braindecode.experiments.monitors import RuntimeMonitor, LossMonitor, CroppedTrialMisclassMonitor, MisclassMonitor
from braindecode.experiments.stopcriteria import MaxEpochs
import torch.nn.functional as F
import torch as th

# The network outputs log-probabilities for every time step of each crop;
# average them over the time axis (dim 2) and apply the negative log-likelihood loss.
loss_function = lambda preds, targets: F.nll_loss(th.mean(preds, dim=2)[:,:,0], targets)

model_constraint = None
monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
            CroppedTrialMisclassMonitor(input_time_length), RuntimeMonitor(),]
# End the first training phase at epoch 20
stop_criterion = MaxEpochs(20)
# remember_best_column selects the "best" epoch; run_after_early_stop enables the
# second phase on the combined training + validation set
exp = Experiment(model, train_set, valid_set, test_set, iterator, loss_function, optimizer, model_constraint,
                 monitors, stop_criterion, remember_best_column='valid_misclass',
                 run_after_early_stop=True, batch_modifier=None, cuda=cuda)
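To see what the loss computes, here is a NumPy analogue on random data. It works on a plain `(batch, n_classes, n_preds)` array and leaves out the extra `[:,:,0]` indexing of the original; `log_softmax` and `cropped_nll` are illustrative helpers, not braindecode functions. Because the log-probabilities are averaged over the time axis first, each crop in the batch contributes a single loss term, however many time steps it has predictions for.

```python
import numpy as np


def log_softmax(x, axis):
    """Numerically stable log-softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))


def cropped_nll(log_probs, targets):
    """NumPy analogue of the loss above for log_probs of shape
    (batch, n_classes, n_preds) and integer targets of shape (batch,)."""
    mean_log_probs = log_probs.mean(axis=2)                    # average over time -> (batch, n_classes)
    picked = mean_log_probs[np.arange(len(targets)), targets]  # log-prob of each target class
    return -picked.mean()                                      # mean over the batch, like F.nll_loss


rng = np.random.default_rng(0)
log_probs = log_softmax(rng.normal(size=(4, 2, 187)), axis=1)
targets = np.array([0, 1, 1, 0])
print(cropped_nll(log_probs, targets))
```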

In [10]:
import logging
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                    level=logging.DEBUG, stream=sys.stdout)
exp.run()

2017-07-03 13:49:54,684 INFO : Run until first stop...
2017-07-03 13:49:55,020 INFO : Epoch 0
2017-07-03 13:49:55,022 INFO : train_loss                0.98100
2017-07-03 13:49:55,024 INFO : valid_loss                0.75838
2017-07-03 13:49:55,025 INFO : test_loss                 0.76847
2017-07-03 13:49:55,027 INFO : train_sample_misclass     0.47839
2017-07-03 13:49:55,028 INFO : valid_sample_misclass     0.39416
2017-07-03 13:49:55,030 INFO : test_sample_misclass      0.39439
2017-07-03 13:49:55,032 INFO : train_misclass            0.50000
2017-07-03 13:49:55,033 INFO : valid_misclass            0.41667
2017-07-03 13:49:55,035 INFO : test_misclass             0.43333
2017-07-03 13:49:55,036 INFO : runtime                   0.00000
2017-07-03 13:49:55,038 INFO : 
2017-07-03 13:49:55,040 INFO : New best valid_misclass: 0.416667
2017-07-03 13:49:55,041 INFO : 
2017-07-03 13:49:56,538 INFO : Epoch 1
2017-07-03 13:49:56,544 INFO : train_loss                3.05195
2017-07-03 13:49:56,546
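The log reports two kinds of error rate: `*_sample_misclass` (from `MisclassMonitor` with the `sample_misclass` suffix) and `*_misclass` (from `CroppedTrialMisclassMonitor`). Roughly, the first scores every per-time-step prediction on its own, while the second averages the predictions of a trial before deciding. The NumPy sketch below illustrates that difference on made-up numbers; it approximates the monitors' behavior (for example, it ignores how overlapping crops are merged), and the helper names are hypothetical.

```python
import numpy as np


def sample_misclass(log_probs, targets):
    """Fraction of individual per-time-step predictions that are wrong.
    log_probs: (n_trials, n_classes, n_preds), targets: (n_trials,)"""
    pred_labels = log_probs.argmax(axis=1)                 # (n_trials, n_preds)
    return float(np.mean(pred_labels != targets[:, None]))


def trial_misclass(log_probs, targets):
    """Average over time first, then make one decision per trial."""
    pred_labels = log_probs.mean(axis=2).argmax(axis=1)    # (n_trials,)
    return float(np.mean(pred_labels != targets))


# Two made-up trials with three time steps each: [trial, class, time]
probs = np.array([[[0.9, 0.8, 0.4], [0.1, 0.2, 0.6]],    # trial 0, true class 0
                  [[0.3, 0.6, 0.2], [0.7, 0.4, 0.8]]])   # trial 1, true class 1
targets = np.array([0, 1])
log_probs = np.log(probs)
print(sample_misclass(log_probs, targets))  # 2 of 6 time steps wrong
print(trial_misclass(log_probs, targets))   # 0.0, both trials correct after averaging
```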

In this case, we again arrive at 76.6% misclass. Training stops once the validation loss drops below 0.03722, the training loss at the best epoch.

## Dataset References


This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:

    Schalk G, McFarland DJ, Hinterberger T, Birbaumer N, Wolpaw JR. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.

[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community, and is further described in:

    Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220.