# Authors: Yonghao Song <eeyhsong@gmail.com>
#
# License: BSD (3-clause)
import warnings

import torch
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn, Tensor

from .base import EEGModuleMixin, deprecated_args


class EEGConformer(EEGModuleMixin, nn.Module):
"""EEG Conformer.
Convolutional Transformer for EEG decoding.
The paper and original code with more details about the methodological
choices are available at the [Song2022]_ and [ConformerCode]_.
This neural network architecture recieves a traditional braindecode input.
The input shape should be three-dimensional matrix representing the EEG
signals.
`(batch_size, n_channels, n_timesteps)`.
The EEG Conformer architecture is composed of three modules:
- PatchEmbedding
- TransformerEncoder
- ClassificationHead
Notes
-----
The authors recommend using data augmentation before using Conformer,
e.g. sementation and recombination,
Please refer to the original paper and code for more details.
The model was initially tuned on 4 seconds of 250 Hz data.
Please adjust the scale of the temporal convolutional layer,
and the pooling layer for better performance.
.. versionadded:: 0.8
We aggregate the parameters based on the parts of the models, or
when the parameters were used first, e.g. n_filters_time.
Parameters
----------
n_filters_time: int
Number of temporal filters, defines also embedding size.
filter_time_length: int
Length of the temporal filter.
pool_time_length: int
Length of temporal pooling filter.
pool_time_stride: int
Length of stride between temporal pooling filters.
drop_prob: float
Dropout rate of the convolutional layer.
att_depth: int
Number of self-attention layers.
att_heads: int
Number of attention heads.
att_drop_prob: float
Dropout rate of the self-attention layer.
final_fc_length: int | str
The dimension of the fully connected layer.
return_features: bool
If True, the forward method returns the features before the
last classification layer. Defaults to False.
n_classes :
Alias for n_outputs.
n_channels :
Alias for n_chans.
input_window_samples :
Alias for n_times.
References
----------
.. [Song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
conformer: Convolutional transformer for EEG decoding and visualization.
IEEE Transactions on Neural Systems and Rehabilitation Engineering,
31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
.. [ConformerCode] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
conformer: Convolutional transformer for EEG decoding and visualization.
https://github.com/eeyhsong/EEG-Conformer.
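
    Examples
    --------
    A minimal usage sketch (illustrative shapes; ``final_fc_length="auto"``
    infers the classifier input size from ``n_times``):

    >>> import torch
    >>> model = EEGConformer(n_outputs=4, n_chans=22, n_times=1000,
    ...                      final_fc_length="auto")
    >>> x = torch.randn(8, 22, 1000)  # (batch, n_chans, n_times)
    >>> y = model(x)  # log-probabilities of shape (8, 4)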
"""
    def __init__(
        self,
        n_outputs=None,
        n_chans=None,
        n_filters_time=40,
        filter_time_length=25,
        pool_time_length=75,
        pool_time_stride=15,
        drop_prob=0.5,
        att_depth=6,
        att_heads=10,
        att_drop_prob=0.5,
        final_fc_length=2440,
        return_features=False,
        n_times=None,
        chs_info=None,
        input_window_seconds=None,
        sfreq=None,
        n_classes=None,
        n_channels=None,
        input_window_samples=None,
        add_log_softmax=True,
    ):
        n_outputs, n_chans, n_times = deprecated_args(
            self,
            ('n_classes', 'n_outputs', n_classes, n_outputs),
            ('n_channels', 'n_chans', n_channels, n_chans),
            ('input_window_samples', 'n_times', input_window_samples, n_times),
        )
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
            add_log_softmax=add_log_softmax,
        )
        self.mapping = {
            'classification_head.fc.6.weight': 'final_layer.final_layer.0.weight',
            'classification_head.fc.6.bias': 'final_layer.final_layer.0.bias',
        }
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
        del n_classes, n_channels, input_window_samples
        if self.n_chans > 64:
            warnings.warn(
                "This model has only been tested on up to 64 channels; "
                "there is no guarantee it will work with more.",
                UserWarning,
            )
        self.patch_embedding = _PatchEmbedding(
            n_filters_time=n_filters_time,
            filter_time_length=filter_time_length,
            n_channels=self.n_chans,
            pool_time_length=pool_time_length,
            stride_avg_pool=pool_time_stride,
            drop_prob=drop_prob,
        )
        if final_fc_length == "auto":
            assert self.n_times is not None, (
                "n_times must be specified to infer final_fc_length"
            )
            final_fc_length = self.get_fc_size()
        self.transformer = _TransformerEncoder(
            att_depth=att_depth,
            emb_size=n_filters_time,
            att_heads=att_heads,
            att_drop=att_drop_prob,
        )
        self.fc = _FullyConnected(final_fc_length=final_fc_length)
        self.final_layer = _FinalLayer(
            n_classes=self.n_outputs,
            return_features=return_features,
            add_log_softmax=self.add_log_softmax,
        )
    def forward(self, x: Tensor) -> Tensor:
        # (batch, n_chans, n_times) -> (batch, 1, n_chans, n_times)
        x = torch.unsqueeze(x, dim=1)
        x = self.patch_embedding(x)
        x = self.transformer(x)
        x = self.fc(x)
        x = self.final_layer(x)
        return x
    def get_fc_size(self):
        # Run a dummy input through the patch embedding to infer the
        # flattened feature size expected by the fully connected layer.
        out = self.patch_embedding(
            torch.ones((1, 1, self.n_chans, self.n_times))
        )
        return out.shape[1] * out.shape[2]
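    # Example (illustrative): with n_chans=22 and n_times=1000, the dummy
    # forward pass yields a (1, 61, 40) embedding, so
    # final_fc_length = 61 * 40 = 2440 -- the constructor default.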


class _PatchEmbedding(nn.Module):
    """Patch Embedding.

    The authors used a convolution module to capture local features,
    instead of position embedding.

    Parameters
    ----------
    n_filters_time : int
        Number of temporal filters; also defines the embedding size.
    filter_time_length : int
        Length of the temporal filter.
    n_channels : int
        Number of channels to be used as number of spatial filters.
    pool_time_length : int
        Length of the temporal pooling filter.
    stride_avg_pool : int
        Length of the stride between temporal pooling filters.
    drop_prob : float
        Dropout rate of the convolutional layer.

    Returns
    -------
    x : torch.Tensor
        The output tensor of the patch embedding layer.
    """
    def __init__(
        self,
        n_filters_time,
        filter_time_length,
        n_channels,
        pool_time_length,
        stride_avg_pool,
        drop_prob,
    ):
        super().__init__()
        self.shallownet = nn.Sequential(
            nn.Conv2d(1, n_filters_time,
                      (1, filter_time_length), (1, 1)),
            nn.Conv2d(n_filters_time, n_filters_time,
                      (n_channels, 1), (1, 1)),
            nn.BatchNorm2d(num_features=n_filters_time),
            nn.ELU(),
            nn.AvgPool2d(
                kernel_size=(1, pool_time_length),
                stride=(1, stride_avg_pool),
            ),
            # pooling acts as slicing to obtain 'patch' along the
            # time dimension as in ViT
            nn.Dropout(p=drop_prob),
        )
        self.projection = nn.Sequential(
            nn.Conv2d(
                n_filters_time, n_filters_time, (1, 1), stride=(1, 1)
            ),  # transpose; the 1x1 conv could enhance fitting ability slightly
            Rearrange("b d_model 1 seq -> b seq d_model"),
        )
    def forward(self, x: Tensor) -> Tensor:
        x = self.shallownet(x)
        x = self.projection(x)
        return x


class _MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)
    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(
            self.queries(x), "b n (h d) -> b h n d", h=self.num_heads
        )
        keys = rearrange(
            self.keys(x), "b n (h d) -> b h n d", h=self.num_heads
        )
        values = rearrange(
            self.values(x), "b n (h d) -> b h n d", h=self.num_heads
        )
        energy = torch.einsum("bhqd, bhkd -> bhqk", queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)
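        # Note: the attention logits below are scaled by sqrt(emb_size),
        # the full embedding dimension, rather than the per-head dimension
        # sqrt(emb_size / num_heads) used in standard multi-head attention.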
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum("bhal, bhlv -> bhav", att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class _ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class _FeedForwardBlock(nn.Sequential):
    # Standard transformer MLP: emb_size -> expansion * emb_size -> emb_size.
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class _TransformerEncoderBlock(nn.Sequential):
    # Pre-norm residual blocks: LayerNorm is applied before the attention /
    # feed-forward sublayer inside each residual branch, as in ViT.
    def __init__(self, emb_size, att_heads, att_drop, forward_expansion=4):
        super().__init__(
            _ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(emb_size),
                    _MultiHeadAttention(emb_size, att_heads, att_drop),
                    nn.Dropout(att_drop),
                )
            ),
            _ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(emb_size),
                    _FeedForwardBlock(
                        emb_size, expansion=forward_expansion,
                        drop_p=att_drop,
                    ),
                    nn.Dropout(att_drop),
                )
            ),
        )


class _TransformerEncoder(nn.Sequential):
    """Transformer encoder module.

    Similar to the encoder layers used in ViT.

    Parameters
    ----------
    att_depth : int
        Number of transformer encoder blocks.
    emb_size : int
        Embedding size of the transformer encoder.
    att_heads : int
        Number of attention heads.
    att_drop : float
        Dropout probability for the attention layers.
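
    With the EEGConformer defaults (``att_depth=6``, ``emb_size=40``,
    ``att_heads=10``) this stacks six pre-norm blocks with a per-head
    dimension of 40 / 10 = 4.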
"""
def __init__(self, att_depth, emb_size, att_heads, att_drop):
super().__init__(
*[
_TransformerEncoderBlock(emb_size, att_heads, att_drop)
for _ in range(att_depth)
]
)


class _FullyConnected(nn.Module):
    def __init__(
        self,
        final_fc_length,
        drop_prob_1=0.5,
        drop_prob_2=0.3,
        out_channels=256,
        hidden_channels=32,
    ):
        """Fully-connected layers applied after the transformer encoder.

        Parameters
        ----------
        final_fc_length : int
            Length of the final fully connected layer.
        drop_prob_1 : float
            Dropout probability for the first dropout layer.
        drop_prob_2 : float
            Dropout probability for the second dropout layer.
        out_channels : int
            Number of output channels for the first linear layer.
        hidden_channels : int
            Number of output channels for the second linear layer.
        """
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(final_fc_length, out_channels),
            nn.ELU(),
            nn.Dropout(drop_prob_1),
            nn.Linear(out_channels, hidden_channels),
            nn.ELU(),
            nn.Dropout(drop_prob_2),
        )

    def forward(self, x):
        # Flatten (batch, seq, emb) to (batch, seq * emb) before the MLP.
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return out


class _FinalLayer(nn.Module):
    def __init__(
        self,
        n_classes,
        hidden_channels=32,
        return_features=False,
        add_log_softmax=True,
    ):
        """Classification head for the transformer encoder.

        Parameters
        ----------
        n_classes : int
            Number of classes for classification.
        hidden_channels : int
            Number of input features to the final linear layer.
        return_features : bool
            Whether to return the input features alongside the output.
        add_log_softmax : bool
            Whether to apply a LogSoftmax non-linearity to the output.
        """
        super().__init__()
        self.final_layer = nn.Sequential(
            nn.Linear(hidden_channels, n_classes),
        )
        self.return_features = return_features
        if add_log_softmax:
            self.classification = nn.LogSoftmax(dim=1)
        else:
            self.classification = nn.Identity()

    def forward(self, x):
        out = self.final_layer(x)
        if self.return_features:
            # Return the logits together with the pre-classifier features.
            return out, x
        return self.classification(out)