forked from braindecode/braindecode
-
Notifications
You must be signed in to change notification settings - Fork 0
/
base.py
312 lines (278 loc) · 10.9 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
# Authors: Pierre Guetschel
# Maciej Sliwowski
#
# License: BSD-3
import warnings
from typing import Dict, Iterable, List, Optional, Tuple
from collections import OrderedDict
import numpy as np
import torch
from docstring_inheritance import NumpyDocstringInheritanceInitMeta
from torchinfo import ModelStatistics, summary
def deprecated_args(obj, *old_new_args):
out_args = []
for old_name, new_name, old_val, new_val in old_new_args:
if old_val is None:
out_args.append(new_val)
else:
warnings.warn(
f'{obj.__class__.__name__}: {old_name!r} is depreciated. Use {new_name!r} instead.'
)
if new_val is not None:
raise ValueError(
f'{obj.__class__.__name__}: Both {old_name!r} and {new_name!r} were specified.'
)
out_args.append(old_val)
return out_args
class EEGModuleMixin(metaclass=NumpyDocstringInheritanceInitMeta):
"""
Mixin class for all EEG models in braindecode.
Parameters
----------
n_outputs : int
Number of outputs of the model. This is the number of classes
in the case of classification.
n_chans : int
Number of EEG channels.
chs_info : list of dict
Information about each individual EEG channel. This should be filled with
``info["chs"]``. Refer to :class:`mne.Info` for more details.
n_times : int
Number of time samples of the input window.
input_window_seconds : float
Length of the input window in seconds.
sfreq : float
Sampling frequency of the EEG recordings.
add_log_softmax: bool
Whether to use log-softmax non-linearity as the output function.
LogSoftmax final layer will be removed in the future.
Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!
Check the documentation of the torch.nn loss functions:
https://pytorch.org/docs/stable/nn.html#loss-functions.
Raises
------
ValueError: If some input signal-related parameters are not specified
and can not be inferred.
FutureWarning: If add_log_softmax is True, since LogSoftmax final layer
will be removed in the future.
Notes
-----
If some input signal-related parameters are not specified,
there will be an attempt to infer them from the other parameters.
"""
def __init__(
self,
n_outputs: Optional[int] = None,
n_chans: Optional[int] = None,
chs_info: Optional[List[Dict]] = None,
n_times: Optional[int] = None,
input_window_seconds: Optional[float] = None,
sfreq: Optional[float] = None,
add_log_softmax: Optional[bool] = False,
):
if (
n_chans is not None and
chs_info is not None and
len(chs_info) != n_chans
):
raise ValueError(f'{n_chans=} different from {chs_info=} length')
if (
n_times is not None and
input_window_seconds is not None and
sfreq is not None and
n_times != int(input_window_seconds * sfreq)
):
raise ValueError(
f'{n_times=} different from '
f'{input_window_seconds=} * {sfreq=}'
)
self._n_outputs = n_outputs
self._n_chans = n_chans
self._chs_info = chs_info
self._n_times = n_times
self._input_window_seconds = input_window_seconds
self._sfreq = sfreq
self._add_log_softmax = add_log_softmax
super().__init__()
@property
def n_outputs(self):
if self._n_outputs is None:
raise ValueError('n_outputs not specified.')
return self._n_outputs
@property
def n_chans(self):
if self._n_chans is None and self._chs_info is not None:
return len(self._chs_info)
elif self._n_chans is None:
raise ValueError(
'n_chans could not be inferred. Either specify n_chans or chs_info.'
)
return self._n_chans
@property
def chs_info(self):
if self._chs_info is None:
raise ValueError('chs_info not specified.')
return self._chs_info
@property
def n_times(self):
if (
self._n_times is None and
self._input_window_seconds is not None and
self._sfreq is not None
):
return int(self._input_window_seconds * self._sfreq)
elif self._n_times is None:
raise ValueError(
'n_times could not be inferred. '
'Either specify n_times or input_window_seconds and sfreq.'
)
return self._n_times
@property
def input_window_seconds(self):
if (
self._input_window_seconds is None and
self._n_times is not None and
self._sfreq is not None
):
return self._n_times / self._sfreq
elif self._input_window_seconds is None:
raise ValueError(
'input_window_seconds could not be inferred. '
'Either specify input_window_seconds or n_times and sfreq.'
)
return self._input_window_seconds
@property
def sfreq(self):
if (
self._sfreq is None and
self._input_window_seconds is not None and
self._n_times is not None
):
return self._n_times / self._input_window_seconds
elif self._sfreq is None:
raise ValueError(
'sfreq could not be inferred. '
'Either specify sfreq or input_window_seconds and n_times.'
)
return self._sfreq
@property
def add_log_softmax(self):
if self._add_log_softmax:
warnings.warn("LogSoftmax final layer will be removed! " +
"Please adjust your loss function accordingly (e.g. CrossEntropyLoss)!")
return self._add_log_softmax
@property
def input_shape(self) -> Tuple[int]:
"""Input data shape."""
return (1, self.n_chans, self.n_times)
def get_output_shape(self) -> Tuple[int]:
"""Returns shape of neural network output for batch size equal 1.
Returns
-------
output_shape: Tuple[int]
shape of the network output for `batch_size==1` (1, ...)
"""
with torch.inference_mode():
try:
return tuple(self.forward(
torch.zeros(
self.input_shape,
dtype=next(self.parameters()).dtype,
device=next(self.parameters()).device
)).shape)
except RuntimeError as exc:
if str(exc).endswith(
("Output size is too small",
"Kernel size can't be greater than actual input size")
):
msg = (
"During model prediction RuntimeError was thrown showing that at some "
f"layer `{str(exc).split('.')[-1]}` (see above in the stacktrace). This "
"could be caused by providing too small `n_times`/`input_window_seconds`. "
"Model may require longer chunks of signal in the input than "
f"{self.input_shape}."
)
raise ValueError(msg) from exc
raise exc
mapping = None
def load_state_dict(self, state_dict, *args, **kwargs):
mapping = self.mapping if self.mapping else {}
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k in mapping:
new_state_dict[mapping[k]] = v
else:
new_state_dict[k] = v
return super().load_state_dict(new_state_dict, *args, **kwargs)
def to_dense_prediction_model(self, axis: Tuple[int] = (2, 3)) -> None:
"""
Transform a sequential model with strides to a model that outputs
dense predictions by removing the strides and instead inserting dilations.
Modifies model in-place.
Parameters
----------
axis: int or (int,int)
Axis to transform (in terms of intermediate output axes)
can either be 2, 3, or (2,3).
Notes
-----
Does not yet work correctly for average pooling.
Prior to version 0.1.7, there had been a bug that could move strides
backwards one layer.
"""
if not hasattr(axis, "__len__"):
axis = [axis]
assert all([ax in [2, 3] for ax in axis]), "Only 2 and 3 allowed for axis"
axis = np.array(axis) - 2
stride_so_far = np.array([1, 1])
for module in self.modules():
if hasattr(module, "dilation"):
assert module.dilation == 1 or (module.dilation == (1, 1)), (
"Dilation should equal 1 before conversion, maybe the model is "
"already converted?"
)
new_dilation = [1, 1]
for ax in axis:
new_dilation[ax] = int(stride_so_far[ax])
module.dilation = tuple(new_dilation)
if hasattr(module, "stride"):
if not hasattr(module.stride, "__len__"):
module.stride = (module.stride, module.stride)
stride_so_far *= np.array(module.stride)
new_stride = list(module.stride)
for ax in axis:
new_stride[ax] = 1
module.stride = tuple(new_stride)
def get_torchinfo_statistics(
self,
col_names: Optional[Iterable[str]] = (
"input_size",
"output_size",
"num_params",
"kernel_size",
),
row_settings: Optional[Iterable[str]] = ("var_names", "depth"),
) -> ModelStatistics:
"""Generate table describing the model using torchinfo.summary.
Parameters
----------
col_names : tuple, optional
Specify which columns to show in the output, see torchinfo for details, by default
("input_size", "output_size", "num_params", "kernel_size")
row_settings : tuple, optional
Specify which features to show in a row, see torchinfo for details, by default
("var_names", "depth")
Returns
-------
torchinfo.ModelStatistics
ModelStatistics generated by torchinfo.summary.
"""
return summary(
self,
input_size=(1, self.n_chans, self.n_times),
col_names=col_names,
row_settings=row_settings,
verbose=0,
)
def __str__(self) -> str:
return str(self.get_torchinfo_statistics())