-
Notifications
You must be signed in to change notification settings - Fork 173
/
plot_skorch_crop_decoding.py
152 lines (127 loc) · 3.78 KB
/
plot_skorch_crop_decoding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""
Skorch Crop Decoding
=========================
Example using Skorch for crop decoding on a simpler dataset.
"""
# Authors: Lukas Gemein
# Robin Tibor Schirrmeister
# Alexandre Gramfort
# Maciej Sliwowski
#
# License: BSD-3
import mne
import numpy as np
import torch
from mne.io import concatenate_raws
from torch import optim
from braindecode.classifier import EEGClassifier
from braindecode.datasets.croppedxy import CroppedXyDataset
from braindecode.datautil.splitters import TrainTestSplit
from braindecode.losses import CroppedNLLLoss
from braindecode.models import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.scoring import CroppedTrialEpochScoring
from braindecode.util import set_random_seeds
# Subject of the Physionet EEGBCI dataset to load; carefully cherry-picked
# to give nice results on such limited data :)
subject_id = 22
# Runs to load (passed as the ``runs`` argument of eegbci.load_data):
# codes for executed and imagined hands/feet.
event_codes = [5, 6, 9, 10, 13, 14]

# Fetch the EDF files (downloading them if they are not present yet) and
# get back the paths to the local copies.
physionet_paths = mne.datasets.eegbci.load_data(
    subject_id, event_codes, update_path=False
)

# Read every run quietly and stitch them into one continuous recording.
raw = concatenate_raws(
    [
        mne.io.read_raw_edf(
            edf_path, preload=True, stim_channel="auto", verbose="WARNING"
        )
        for edf_path in physionet_paths
    ]
)

# Event array derived from the recording's annotations.
events, _ = mne.events_from_annotations(raw)
# EEG channels only, skipping channels flagged as bad.
picks = mne.pick_types(raw.info, meg=False, eeg=True, exclude="bads")

# Cut trials from 1 s to 4.1 s relative to each event of the two classes
# of interest (event ids 2 and 3), keeping only the EEG channels.
epochs = mne.Epochs(
    raw,
    events,
    event_id={"hands_or_left": 2, "feet_or_right": 3},
    tmin=1,
    tmax=4.1,
    proj=False,
    picks=picks,
    baseline=None,
    preload=True,
)

# Trial data scaled by 1e6 (MNE returns volts, so this yields microvolts)
# and cast to float32 for the network.
X = (1e6 * epochs.get_data()).astype(np.float32)
# Class labels: event ids 2 and 3 become classes 0 and 1.
y = epochs.events[:, 2].astype(np.int64) - 2
del epochs
# Set to True to train on a GPU. torch.cuda.is_available() can be used to
# check whether CUDA is available on this machine.
cuda = False
# Device string matching the ``cuda`` flag; used for every tensor created here.
device = "cuda" if cuda else "cpu"
# Seed once, after the ``cuda`` flag is known, so the CUDA RNGs are seeded
# as well when cuda=True.
set_random_seeds(seed=20200114, cuda=cuda)

n_classes = 2
# X is indexed as (trials, channels, samples): axis 1 gives the number of
# input channels, axis 2 the number of samples per trial.
in_chans = X.shape[1]
input_time_length = X.shape[2]

# final_conv_length="auto" sizes the last convolution so that the network as
# constructed yields a single output in time for a window of
# input_time_length samples. to_dense_prediction_model then converts it for
# cropped decoding, so it emits several predictions per window; the exact
# count is measured below.
model = ShallowFBCSPNet(
    in_chans=in_chans,
    n_classes=n_classes,
    input_time_length=input_time_length,
    final_conv_length="auto",
)
to_dense_prediction_model(model)
if cuda:
    model.cuda()

# Run one trial through the network to find how many predictions it emits
# per input window. The trailing ``None`` index appends a singleton fourth
# axis to the trial. The dummy input must be on the same device as the
# model, otherwise the forward pass fails when cuda=True.
with torch.no_grad():
    dummy_input = torch.tensor(
        X[:1, :, :input_time_length, None], device=device
    )
    n_preds_per_input = model(dummy_input).shape[2]
# The first n_train_trials trials are given to the classifier (which splits
# them further into train/valid via TrainTestSplit below); the remaining
# trials are held out for prediction.
n_train_trials = 70
train_set = CroppedXyDataset(
    X[:n_train_trials],
    y[:n_train_trials],
    input_time_length=input_time_length,
    n_preds_per_input=n_preds_per_input,
)
test_set = CroppedXyDataset(
    X[n_train_trials:],
    y[n_train_trials:],
    input_time_length=input_time_length,
    n_preds_per_input=n_preds_per_input,
)

# Trial-wise accuracy (aggregated over crops) on the training and validation
# parts of train_set; higher is better.
cropped_cb_train = CroppedTrialEpochScoring(
    "accuracy",
    on_train=True,
    name="train_trial_accuracy",
    lower_is_better=False,
)
cropped_cb_valid = CroppedTrialEpochScoring(
    "accuracy",
    on_train=False,
    name="valid_trial_accuracy",
    lower_is_better=False,
)

# skorch moves the module and the batches to its ``device`` (default "cpu"),
# so pass the device matching the ``cuda`` flag; otherwise training would
# silently run on the CPU even when cuda=True.
device = "cuda" if cuda else "cpu"
clf = EEGClassifier(
    model,
    criterion=CroppedNLLLoss,
    optimizer=optim.AdamW,
    # NOTE(review): train_size=40 presumably is the number of trials of
    # train_set used for training, the rest being used for validation.
    train_split=TrainTestSplit(
        train_size=40,
        input_time_length=input_time_length,
        n_preds_per_input=n_preds_per_input,
    ),
    optimizer__lr=0.0625 * 0.01,
    optimizer__weight_decay=0,
    batch_size=64,
    callbacks=[
        ("train_trial_accuracy", cropped_cb_train),
        ("valid_trial_accuracy", cropped_cb_valid),
    ],
    device=device,
)
clf.fit(train_set, y=None, epochs=4)
# Keep the predictions on the held-out trials rather than discarding them.
y_pred = clf.predict(test_set)