In [None]:
# Reload edited modules automatically before each cell runs (IPython autoreload, mode 2).
%load_ext autoreload
%autoreload 2

In [None]:
"""
Cropped Decoding on BCIC IV 2a Competition Set
==============================================

"""

# Authors: Maciej Sliwowski
#          Robin Tibor Schirrmeister
#
# License: BSD-3

import os.path
from collections import OrderedDict

import torch
from torch import optim
from torch.utils.data import Dataset
import numpy as np

from braindecode.datasets.bcic_iv_2a import BCICompetition4Set2A
from braindecode.datautil import CropsDataLoader
from braindecode.datautil.signalproc import (
    bandpass_cnt,
    exponential_running_standardize,
)
from braindecode.datautil.trial_segment import create_signal_target_from_raw_mne
from braindecode.classifier import EEGClassifier
from braindecode.scoring import CroppedTrialEpochScoring
from braindecode.mne_ext.signalproc import mne_apply
from braindecode.models.deep4 import Deep4Net
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.util import set_random_seeds
from braindecode.losses import CroppedNLLLoss

In [None]:
from braindecode.datasets import MOABBDataset
from braindecode.datautil.windowers import EventWindower

In [None]:
# Experiment configuration.
subject_id = 1  # 1-9
low_cut_hz = 4  # 0 or 4
model_name = "shallow"  # 'shallow' or 'deep'

# Run on the GPU when one is available; set `cuda = False` here to force CPU.
cuda = torch.cuda.is_available()
device = "cuda" if cuda else "cpu"

# NOTE(review): most constants below (ival, max_epochs, max_increase_epochs,
# batch_size, high_cut_hz, factor_new, init_block_size, valid_set_fraction,
# and low_cut_hz above) are not referenced by the cells in this notebook; the
# classifier cell hard-codes its own batch size and epoch count.
ival = [-500, 4000]  # presumably trial interval in ms relative to the cue — TODO confirm
input_time_length = 1000  # length (time axis) of each model input window
max_epochs = 5
max_increase_epochs = 80
batch_size = 60
high_cut_hz = 38
factor_new = 1e-3
init_block_size = 1000
valid_set_fraction = 0.2
valid_set_fraction = 0.2

In [None]:
set_random_seeds(seed=20190706, cuda=cuda)

n_classes = 4
n_chans = 22
if model_name == "shallow":
    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_time_length=input_time_length,
        final_conv_length=30,
    )
elif model_name == "deep":
    model = Deep4Net(
        n_chans,
        n_classes,
        input_time_length=input_time_length,
        final_conv_length=2,
    )
else:
    # Fail here rather than with a confusing NameError on `model` further down.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected 'shallow' or 'deep'"
    )

# Make the model emit a sequence of predictions along the time axis; that
# sequence length is measured below as `n_preds_per_input`.
to_dense_prediction_model(model)

if cuda:
    model.cuda()

# Probe the number of predictions per input window with a dummy forward pass.
# Use eval mode for the probe so it cannot update any BatchNorm running
# statistics or apply dropout, then restore the previous mode.
was_training = model.training
model.eval()
with torch.no_grad():
    dummy_input = torch.ones(
        (1, n_chans, input_time_length, 1), dtype=torch.float32, device=device
    )
    n_preds_per_input = model(dummy_input).shape[2]
model.train(was_training)


In [None]:
# Load subject `subject_id` of MOABB's BNCI2014001 and window it around events.
# Windows must be exactly as long as the model input, so reuse
# `input_time_length` instead of repeating the literal. Successive windows are
# shifted by `n_preds_per_input` samples, presumably so the predictions of
# consecutive windows tile the trial without gaps — TODO confirm.
dataset = MOABBDataset(
    'BNCI2014001',
    subject=subject_id,
    raw_transformer=None,
    windower=EventWindower(
        window_size_samples=input_time_length,
        stride_samples=n_preds_per_input,
        drop_last_samples=False,
        tmin=0,
    ),
    transform_online=True,
)

In [None]:
class TrainTestRunSplit(object):
    """Split a run-wise concatenated dataset into train and validation subsets.

    The first ``train_runs`` runs go to the training subset and all remaining
    runs to the validation subset, so the split always falls on a run boundary.
    Assumes each sub-dataset of the concatenated dataset is one run. Can be
    called directly or passed as a skorch ``train_split``.

    Parameters
    ----------
    train_runs : int or float
        If int, the number of leading runs used for training. If float, the
        fraction of runs used for training (rounded down).
    """

    def __init__(self, train_runs):
        assert isinstance(train_runs, (int, float))
        self.train_runs = train_runs

    def __call__(self, dataset, y, **kwargs):
        """Return ``(train_subset, valid_subset)`` of ``dataset``.

        Parameters
        ----------
        dataset : ConcatDataset-like
            Must expose ``cumulative_sizes`` (one entry per run) and ``len()``.
        y : ignored
            Accepted for compatibility with skorch's ``train_split(X, y)`` API.

        Raises
        ------
        ValueError
            If the requested number of training runs is outside
            ``[0, number of runs]``.
        """
        # TODO: consider sklearn.model_selection.train_test_split(shuffle=False)
        # or keep this API, which guarantees run-wise splits.
        cumulative_sizes = dataset.cumulative_sizes
        n_runs = len(cumulative_sizes)
        if isinstance(self.train_runs, int):
            n_train_runs = self.train_runs
        else:
            # A float is a fraction of *runs*, not of windows.
            n_train_runs = int(self.train_runs * n_runs)
        if not 0 <= n_train_runs <= n_runs:
            raise ValueError(
                f"train_runs={self.train_runs!r} must select between 0 and "
                f"{n_runs} runs, got {n_train_runs}"
            )
        # cumulative_sizes[i] is the number of windows in runs 0..i inclusive,
        # so the first n_train_runs runs end at index n_train_runs - 1.
        n_train_windows = cumulative_sizes[n_train_runs - 1] if n_train_runs > 0 else 0
        n_total_windows = len(dataset)
        return (
            torch.utils.data.Subset(dataset, np.arange(n_train_windows)),
            torch.utils.data.Subset(dataset, np.arange(n_train_windows, n_total_windows)),
        )

In [None]:
# Run-wise split of `dataset` with train_runs=5, the same setting the
# classifier's train_split uses further down.
run_splitter = TrainTestRunSplit(5)
train_set, valid_set = run_splitter(dataset, y=None)

In [None]:
# Number of windows in the training subset.
len(train_set)

In [None]:
from skorch.callbacks import ProgressBar

In [None]:
# Trial-wise accuracy on the training and validation splits. The two scorers
# differ only in their name and in which split they score.
trial_scoring_kwargs = dict(lower_is_better=False, input_time_length=input_time_length)
cropped_cb_train = CroppedTrialEpochScoring(
    "accuracy", name="train_trial_accuracy", on_train=True, **trial_scoring_kwargs
)
cropped_cb_valid = CroppedTrialEpochScoring(
    "accuracy", name="valid_trial_accuracy", on_train=False, **trial_scoring_kwargs
)

# TODO: add MaxNormDefaultConstraint and early stopping to reproduce the
# training setup of previous braindecode versions.
# NOTE(review): batch_size=32 and epochs=20 are literals here and do not use
# the `batch_size` / `max_epochs` constants from the configuration cell.
callbacks = [
    ("train_trial_accuracy", cropped_cb_train),
    ("valid_trial_accuracy", cropped_cb_valid),
    ("progress_bar", ProgressBar()),
]
clf = EEGClassifier(
    model,
    criterion=CroppedNLLLoss,
    optimizer=optim.AdamW,
    optimizer__lr=0.0625 * 0.01,
    optimizer__weight_decay=0,
    train_split=TrainTestRunSplit(train_runs=5),
    batch_size=32,
    callbacks=callbacks,
    device=device,
)
clf.fit(dataset, y=None, epochs=20)

In [None]:
# NOTE(review): leftover interactive post-mortem debugging, not needed for
# Restart & Run All — TODO remove this cell.
%debug

In [None]:
# Number of predictions the model emits per input window (from the dummy forward pass).
n_preds_per_input

In [None]:
# Ratio of samples per trial to predictions per input window.
# NOTE(review): 1125 presumably equals 4.5 s (ival = [-500, 4000] ms) at
# 250 Hz — TODO confirm against the dataset's sampling rate.
n_samples_per_trial = 1125
n_samples_per_trial / n_preds_per_input

In [None]:
# NOTE(review): leftover interactive post-mortem debugging, not needed for
# Restart & Run All — TODO remove this cell.
%debug

In [None]:

# Reference skorch history (without the per-epoch 'dur' key) to compare the
# training run against: per-batch losses/sizes, per-epoch losses, "best"
# flags, and trial accuracies.
# NOTE(review): this lists 4 epochs with 3 train batches of 32 and one valid
# batch of 24, while the fit above runs epochs=20 on a full MOABB subject; it
# presumably comes from a smaller reference setup — TODO regenerate.
expected = [{'batches': [{'train_loss': 2.0750885009765625, 'train_batch_size': 32},
   {'train_loss': 3.09424090385437, 'train_batch_size': 32},
   {'train_loss': 1.079931616783142, 'train_batch_size': 32},
   {'valid_loss': 2.320780038833618, 'valid_batch_size': 24}],
  'epoch': 1,
  'train_batch_count': 3,
  'valid_batch_count': 1,
  'train_loss': 2.0830870072046914,
  'train_loss_best': True,
  'valid_loss': 2.320780038833618,
  'valid_loss_best': True,
  'train_trial_accuracy': 0.5,
  'train_trial_accuracy_best': True,
  'valid_trial_accuracy': 0.5,
  'valid_trial_accuracy_best': True},
 {'batches': [{'train_loss': 1.7862337827682495, 'train_batch_size': 32},
   {'train_loss': 1.410051941871643, 'train_batch_size': 32},
   {'train_loss': 1.1569499969482422, 'train_batch_size': 32},
   {'valid_loss': 1.4905306100845337, 'valid_batch_size': 24}],
  'epoch': 2,
  'train_batch_count': 3,
  'valid_batch_count': 1,
  'train_loss': 1.4510785738627117,
  'train_loss_best': True,
  'valid_loss': 1.4905306100845337,
  'valid_loss_best': True,
  'train_trial_accuracy': 0.5,
  'train_trial_accuracy_best': False,
  'valid_trial_accuracy': 0.5,
  'valid_trial_accuracy_best': False},
 {'batches': [{'train_loss': 1.1232541799545288, 'train_batch_size': 32},
   {'train_loss': 2.304981231689453, 'train_batch_size': 32},
   {'train_loss': 0.9293400049209595, 'train_batch_size': 32},
   {'valid_loss': 2.455669641494751, 'valid_batch_size': 24}],
  'epoch': 3,
  'train_batch_count': 3,
  'valid_batch_count': 1,
  'train_loss': 1.4525251388549805,
  'train_loss_best': False,
  'valid_loss': 2.455669641494751,
  'valid_loss_best': False,
  'train_trial_accuracy': 0.5,
  'train_trial_accuracy_best': False,
  'valid_trial_accuracy': 0.5,
  'valid_trial_accuracy_best': False},
 {'batches': [{'train_loss': 1.241913080215454, 'train_batch_size': 32},
   {'train_loss': 1.1696765422821045, 'train_batch_size': 32},
   {'train_loss': 0.9132626056671143, 'train_batch_size': 32},
   {'valid_loss': 0.9064457416534424, 'valid_batch_size': 24}],
  'epoch': 4,
  'train_batch_count': 3,
  'valid_batch_count': 1,
  'train_loss': 1.1082840760548909,
  'train_loss_best': True,
  'valid_loss': 0.9064457416534424,
  'valid_loss_best': True,
  'train_trial_accuracy': 0.5,
  'train_trial_accuracy_best': False,
  'valid_trial_accuracy': 0.5,
  'valid_trial_accuracy_best': False}]




def assert_deep_allclose(actual, expected, rtol=1e-5, atol=1e-8, _path="root"):
    """Recursively assert that two nested dict/list structures match.

    Dicts must have identical key sets and lists/tuples identical lengths;
    their items are compared recursively. Bools and non-numeric leaves are
    compared with ``==``; other numbers with ``np.testing.assert_allclose``
    using ``rtol``/``atol``.

    Raises
    ------
    AssertionError
        On the first mismatch; the message contains the path to the element.
    """
    if isinstance(expected, dict):
        assert isinstance(actual, dict), (
            f"{_path}: expected dict, got {type(actual).__name__}")
        missing = set(expected) - set(actual)
        unexpected = set(actual) - set(expected)
        assert not missing and not unexpected, (
            f"{_path}: keys differ; missing={sorted(missing)}, "
            f"unexpected={sorted(unexpected)}")
        for key in expected:
            assert_deep_allclose(actual[key], expected[key], rtol, atol,
                                 f"{_path}[{key!r}]")
    elif isinstance(expected, (list, tuple)):
        assert len(actual) == len(expected), (
            f"{_path}: length {len(actual)} != expected {len(expected)}")
        for i, (act_item, exp_item) in enumerate(zip(actual, expected)):
            assert_deep_allclose(act_item, exp_item, rtol, atol, f"{_path}[{i}]")
    elif isinstance(expected, (bool, np.bool_)):
        assert actual == expected, f"{_path}: {actual!r} != {expected!r}"
    elif isinstance(expected, (int, float, np.number)):
        np.testing.assert_allclose(actual, expected, rtol=rtol, atol=atol,
                                   err_msg=_path)
    else:
        assert actual == expected, f"{_path}: {actual!r} != {expected!r}"


# Epoch durations vary from run to run, so drop them before comparing.
history_without_dur = [{k: v for k, v in h.items() if k != 'dur'}
                       for h in clf.history]

assert_deep_allclose(history_without_dur, expected)

In [None]:
# Display the recorded training history (without per-epoch durations).
history_without_dur