Normalise models BCI #488

Merged: 73 commits into master from normalise-models, Sep 7, 2023

Commits
cff263f  Add EEGModuleMixin (PierreGtch, Sep 4, 2023)
1011bf9  Add tests for EEGModuleMixin (PierreGtch, Sep 4, 2023)
b7dd55b  Handle double inheritance (PierreGtch, Sep 4, 2023)
a93f44d  Update braindecode/models/base.py (PierreGtch, Sep 4, 2023)
f63ae63  Fix lines too long (PierreGtch, Sep 4, 2023)
fdd624a  Replace channel_names by ch_names (PierreGtch, Sep 4, 2023)
b322373  Replace input_window_samples by n_times (PierreGtch, Sep 4, 2023)
283ed3b  Add n_outputs parameter (PierreGtch, Sep 4, 2023)
7f0b7ff  Remove test class (PierreGtch, Sep 4, 2023)
3891b10  Merge branch 'master' into normalise-models (bruAristimunha, Sep 5, 2023)
060d042  Add docstring_inheritance to environment.yml (PierreGtch, Sep 4, 2023)
88fc5e8  Rename n_channels to n_chans (PierreGtch, Sep 4, 2023)
09139d3  Test dummy submodule initialisation (PierreGtch, Sep 5, 2023)
0c69fee  Add depreciated_args helper (PierreGtch, Sep 5, 2023)
f790c8e  Add docstring_inheritance to requirements.txt (PierreGtch, Sep 5, 2023)
850fa29  Update eegnet.py (PierreGtch, Sep 5, 2023)
4b0d0fb  Add docstring_inheritance to setup.py (PierreGtch, Sep 5, 2023)
f43ff12  Use ValueError instead of AttributeError to escape pytorch's __getattr__ (PierreGtch, Sep 5, 2023)
692ce25  Update whats_new.rst (PierreGtch, Sep 5, 2023)
837192a  Fix docstring (PierreGtch, Sep 5, 2023)
217be7e  Add parameter to docstring (PierreGtch, Sep 5, 2023)
ec556e4  Merge branch 'master' into normalise-models (bruAristimunha, Sep 5, 2023)
d239384  Update braindecode/models/base.py (PierreGtch, Sep 5, 2023)
c134f0b  Update docs/whats_new.rst (PierreGtch, Sep 5, 2023)
8db9496  Use else instead of continue in deprecated_args (PierreGtch, Sep 5, 2023)
d9f1173  Update deprecated_args usage (PierreGtch, Sep 5, 2023)
61883e8  Update atcnet.py (PierreGtch, Sep 5, 2023)
c4f52e2  Update deep4.py (PierreGtch, Sep 5, 2023)
c98eafd  Add all parameters explicitly to submodules (PierreGtch, Sep 5, 2023)
ea89239  Fix atcnet.py (PierreGtch, Sep 5, 2023)
8fe1d75  Fix deep4.py (PierreGtch, Sep 5, 2023)
40bf357  Fix eegnet.py (PierreGtch, Sep 5, 2023)
2ac55ec  Fix docstring base (PierreGtch, Sep 5, 2023)
a3ddafe  Fix base using wrong inheritance class (PierreGtch, Sep 5, 2023)
c0fa3f5  Update deepsleepnet.py (PierreGtch, Sep 5, 2023)
3436424  Update eegconformer.py (PierreGtch, Sep 5, 2023)
d8a05b4  Merge branch 'master' into normalise-models (PierreGtch, Sep 5, 2023)
cf064d3  Merge branch 'master' into normalise-models (PierreGtch, Sep 5, 2023)
d258688  Replace AttributeError by ValueError in docstring (PierreGtch, Sep 6, 2023)
c9cc6ae  Add EEGModuleMixin to doc (PierreGtch, Sep 6, 2023)
379347f  Add description for deprecated parameters (PierreGtch, Sep 6, 2023)
f35795c  Update eeginception.py (PierreGtch, Sep 6, 2023)
aa3b645  Update eeginception_erp.py and eeginception_mi.py (PierreGtch, Sep 6, 2023)
ca2ae8c  Update eegitnet.py (PierreGtch, Sep 6, 2023)
b32affc  Update eegresnet.py (PierreGtch, Sep 6, 2023)
153901c  Update shallow_fbcsp.py (PierreGtch, Sep 6, 2023)
f4d17aa  Update sleep_stager_blanco_2020.py (PierreGtch, Sep 6, 2023)
3376dca  Update sleep_stager_chambon_2018.py (PierreGtch, Sep 6, 2023)
9162d1d  Update sleep_stager_eldele_2021.py (PierreGtch, Sep 6, 2023)
975b68b  Update tcn.py (PierreGtch, Sep 6, 2023)
112eb85  Update tidnet.py (PierreGtch, Sep 6, 2023)
af049c5  Update usleep.py (PierreGtch, Sep 6, 2023)
e3723b7  Restore default EEGInception (PierreGtch, Sep 6, 2023)
42cca58  Update plot_sleep_staging_chambon2018.py (PierreGtch, Sep 6, 2023)
66d84ad  Update plot_sleep_staging_usleep.py (PierreGtch, Sep 6, 2023)
1d404ee  Update plot_sleep_staging_eldele2021.py (PierreGtch, Sep 6, 2023)
85491c9  Update plot_relative_positioning.py (PierreGtch, Sep 6, 2023)
781e5b5  Merge branch 'master' into normalise-models (PierreGtch, Sep 6, 2023)
3b0a9ee  Fix plot_sleep_staging_usleep.py (PierreGtch, Sep 6, 2023)
5e5842c  added model summary with torchinfo as __str__ method in EEGModuleMixin (sliwy, Sep 6, 2023)
daf97fb  added printing model tables into examples, change to return torchinfo… (sliwy, Sep 6, 2023)
7c03fad  added tests for generating table description from torchinfo (sliwy, Sep 6, 2023)
4116d83  added torchinfo to requirements (sliwy, Sep 6, 2023)
9dd1778  added torchinfo to conda yml and setup.py (sliwy, Sep 6, 2023)
86a9445  Rename args to old_new_args (PierreGtch, Sep 6, 2023)
2eeb4cc  fix python3.8 failing due to | in typing (sliwy, Sep 6, 2023)
24b9fe1  flake8 cleaning (sliwy, Sep 6, 2023)
8f1545e  Merge pull request #1 from sliwy/add_models_viz (PierreGtch, Sep 6, 2023)
3c3b36f  Merge branch 'master' into normalise-models (PierreGtch, Sep 6, 2023)
ff64449  Merge branch 'master' into normalise-models (PierreGtch, Sep 7, 2023)
a9323ac  Change ch_names to chs_info (PierreGtch, Sep 7, 2023)
1d53fca  Merge branch 'master' into normalise-models (PierreGtch, Sep 7, 2023)
f4877ef  Merge branch 'master' into normalise-models (PierreGtch, Sep 7, 2023)
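
Taken together, these commits give every model the same signal-related constructor arguments (n_chans, n_outputs, chs_info, and n_times, or equivalently input_window_seconds plus sfreq). As a quick orientation, a minimal usage sketch of the normalised interface (assuming braindecode after this PR; the dataset values are illustrative):

```python
from braindecode.models import ATCNet

model = ATCNet(
    n_chans=22,                # was `n_channels` before this PR
    n_outputs=4,               # was `n_classes`
    input_window_seconds=4.5,  # was `input_size_s`
    sfreq=250,
)
# Per the sliwy commits above, EEGModuleMixin.__str__ should return a
# torchinfo summary table, so printing the model shows its layer table.
print(model)
```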
167 changes: 95 additions & 72 deletions braindecode/models/atcnet.py
@@ -8,20 +8,17 @@
from einops.layers.torch import Rearrange

from .modules import Ensure4d, MaxNormLinear, CausalConv1d
+from .base import EEGModuleMixin, deprecated_args


-class ATCNet(nn.Module):
+class ATCNet(EEGModuleMixin, nn.Module):
"""ATCNet model from [1]_

Pytorch implementation based on official tensorflow code [2]_.

Parameters
----------
-    n_channels : int
-        Number of EEG channels.
-    n_classes : int
-        Number of target classes.
-    input_size_s : float, optional
+    input_window_seconds : float, optional
Time length of inputs, in secods. Defaults to 4.5 s, as in BCI-IV 2a
dataset.
sfreq : int, optional
Expand Down Expand Up @@ -79,12 +76,18 @@ class ATCNet(nn.Module):
concat : bool
When ``True``, concatenates each slidding window embedding before
feeding it to a fully-connected layer, as done in [1]_. When ``False``,
-        maps each slidding window to `n_classes` logits and average them.
+        maps each slidding window to `n_outputs` logits and average them.
Defaults to ``False`` contrary to what is reported in [1]_, but
matching what the official code does [2]_.
max_norm_const : float
Maximum L2-norm constraint imposed on weights of the last
fully-connected layer. Defaults to 0.25.
+    n_channels:
+        Alias for n_chans.
+    n_classes:
+        Alias for n_outputs.
+    input_size_s:
+        Alias for input_window_seconds.

References
----------
@@ -94,36 +97,53 @@ class ATCNet(nn.Module):
2022, doi: 10.1109/TII.2022.3197419.
.. [2] https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
"""

    def __init__(
-            self,
-            n_channels,
-            n_classes,
-            input_size_s=4.5,
-            sfreq=250,
-            conv_block_n_filters=16,
-            conv_block_kernel_length_1=64,
-            conv_block_kernel_length_2=16,
-            conv_block_pool_size_1=8,
-            conv_block_pool_size_2=7,
-            conv_block_depth_mult=2,
-            conv_block_dropout=0.3,
-            n_windows=5,
-            att_head_dim=8,
-            att_num_heads=2,
-            att_dropout=0.5,
-            tcn_depth=2,
-            tcn_kernel_size=4,
-            tcn_n_filters=32,
-            tcn_dropout=0.3,
-            tcn_activation=nn.ELU(),
-            concat=False,
-            max_norm_const=0.25,
+        self,
+        n_chans=None,
+        n_outputs=None,
+        input_window_seconds=4.5,
+        sfreq=250.,
+        conv_block_n_filters=16,
+        conv_block_kernel_length_1=64,
+        conv_block_kernel_length_2=16,
+        conv_block_pool_size_1=8,
+        conv_block_pool_size_2=7,
+        conv_block_depth_mult=2,
+        conv_block_dropout=0.3,
+        n_windows=5,
+        att_head_dim=8,
+        att_num_heads=2,
+        att_dropout=0.5,
+        tcn_depth=2,
+        tcn_kernel_size=4,
+        tcn_n_filters=32,
+        tcn_dropout=0.3,
+        tcn_activation=nn.ELU(),
+        concat=False,
+        max_norm_const=0.25,
+        chs_info=None,
+        n_times=None,
+        n_channels=None,
+        n_classes=None,
+        input_size_s=None,
    ):
-        super().__init__()
-        self.n_channels = n_channels
-        self.n_classes = n_classes
-        self.input_size_s = input_size_s
-        self.sfreq = sfreq
+        n_chans, n_outputs, input_window_seconds = deprecated_args(
+            self,
+            ('n_channels', 'n_chans', n_channels, n_chans),
+            ('n_classes', 'n_outputs', n_classes, n_outputs),
+            ('input_size_s', 'input_window_seconds', input_size_s, input_window_seconds),
+        )
+        super().__init__(
+            n_outputs=n_outputs,
+            n_chans=n_chans,
+            chs_info=chs_info,
+            n_times=n_times,
+            input_window_seconds=input_window_seconds,
+            sfreq=sfreq,
+        )
+        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
+        del n_channels, n_classes, input_size_s
self.conv_block_n_filters = conv_block_n_filters
self.conv_block_kernel_length_1 = conv_block_kernel_length_1
self.conv_block_kernel_length_2 = conv_block_kernel_length_2
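
The `deprecated_args` helper used above maps legacy keyword arguments onto their new names with a warning. A minimal sketch of such a helper, inferred from the call shape in this diff (not necessarily the PR's exact implementation):

```python
import warnings

def deprecated_args(obj, *old_new_args):
    """Resolve (old_name, new_name, old_val, new_val) tuples.

    Sketch only: assumes the new argument defaults to None whenever its
    deprecated alias may still be used, so a non-None old value is unambiguous.
    """
    out_args = []
    for old_name, new_name, old_val, new_val in old_new_args:
        if old_val is None:
            out_args.append(new_val)  # deprecated alias unused; keep new value
        else:
            if new_val is not None:
                raise ValueError(
                    f"{type(obj).__name__}: both {old_name!r} and "
                    f"{new_name!r} were specified."
                )
            warnings.warn(
                f"{type(obj).__name__}: {old_name!r} is deprecated, "
                f"use {new_name!r} instead.",
                DeprecationWarning,
            )
            out_args.append(old_val)
    return out_args
```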
Expand All @@ -149,7 +169,7 @@ def __init__(
self.dimshuffle = Rearrange("batch C T 1 -> batch 1 T C")

self.conv_block = _ConvBlock(
-            n_channels=n_channels,  # input shape: (batch_size, 1, T, C)
+            n_channels=self.n_chans,  # input shape: (batch_size, 1, T, C)
n_filters=conv_block_n_filters,
kernel_length_1=conv_block_kernel_length_1,
kernel_length_2=conv_block_kernel_length_2,
@@ -160,8 +180,8 @@
)

self.F2 = int(conv_block_depth_mult * conv_block_n_filters)
-        self.Tc = int(input_size_s * sfreq / (
-            conv_block_pool_size_1 * conv_block_pool_size_2))
+        self.Tc = int(self.input_window_seconds * self.sfreq / (
+            conv_block_pool_size_1 * conv_block_pool_size_2))
self.Tw = self.Tc - self.n_windows + 1

self.attention_blocks = nn.ModuleList([
@@ -181,7 +201,7 @@ def __init__(
n_filters=tcn_n_filters,
dropout=tcn_dropout,
activation=tcn_activation,
-                    dilation=2**i
+                    dilation=2 ** i
) for i in range(tcn_depth)]
) for _ in range(self.n_windows)
])
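
As a sanity check on the temporal dimensions computed above, with this signature's defaults (input_window_seconds=4.5, sfreq=250, pool sizes 8 and 7, n_windows=5) the arithmetic works out as follows (a worked check, not code from the PR):

```python
Tc = int(4.5 * 250 / (8 * 7))  # int(1125 / 56) = 20 time steps after pooling
Tw = Tc - 5 + 1                # each of the 5 sliding windows spans 16 steps
assert (Tc, Tw) == (20, 16)
```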
@@ -190,15 +210,15 @@ def __init__(
self.max_norm_linears = nn.ModuleList([
MaxNormLinear(
in_features=self.F2 * self.n_windows,
-                out_features=self.n_classes,
+                out_features=self.n_outputs,
max_norm_val=self.max_norm_const
)
])
else:
self.max_norm_linears = nn.ModuleList([
MaxNormLinear(
in_features=self.F2,
-                out_features=self.n_classes,
+                out_features=self.n_outputs,
max_norm_val=self.max_norm_const
) for _ in range(self.n_windows)
])
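
The `concat` branch above selects between two classification-head layouts. A self-contained sketch of the shape difference, using plain `nn.Linear` in place of `MaxNormLinear` and hypothetical sizes:

```python
import torch
import torch.nn as nn

F2, n_windows, n_outputs, batch = 32, 5, 4, 8
embeddings = [torch.randn(batch, F2) for _ in range(n_windows)]

# concat=True: concatenate all window embeddings, one shared head
head = nn.Linear(F2 * n_windows, n_outputs)
logits_concat = head(torch.cat(embeddings, dim=1))

# concat=False (default): one head per window, then average the logits
heads = nn.ModuleList([nn.Linear(F2, n_outputs) for _ in range(n_windows)])
logits_avg = torch.stack(
    [h(e) for h, e in zip(heads, embeddings)]
).mean(dim=0)

assert logits_concat.shape == logits_avg.shape == (batch, n_outputs)
```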
@@ -270,16 +290,17 @@ class _ConvBlock(nn.Module):
Brain-Computer Interfaces.
arXiv preprint arXiv:1611.08024.
"""

    def __init__(
-            self,
-            n_channels,
-            n_filters=16,
-            kernel_length_1=64,
-            kernel_length_2=16,
-            pool_size_1=8,
-            pool_size_2=7,
-            depth_mult=2,
-            dropout=0.3,
+        self,
+        n_channels,
+        n_filters=16,
+        kernel_length_1=64,
+        kernel_length_2=16,
+        pool_size_1=8,
+        pool_size_2=7,
+        depth_mult=2,
+        dropout=0.3,
):
super().__init__()

@@ -368,12 +389,13 @@ class _AttentionBlock(nn.Module):
.. [2] Vaswani, A. et al., "Attention is all you need",
in Advances in neural information processing systems, 2017.
"""

    def __init__(
-            self,
-            in_shape=32,
-            head_dim=8,
-            num_heads=2,
-            dropout=0.5,
+        self,
+        in_shape=32,
+        head_dim=8,
+        num_heads=2,
+        dropout=0.5,
):
super().__init__()
self.in_shape = in_shape
@@ -442,14 +464,15 @@ class _TCNResidualBlock(nn.Module):
"An empirical evaluation of generic convolutional and recurrent
networks for sequence modeling", 2018.
"""

    def __init__(
-            self,
-            in_channels,
-            kernel_size=4,
-            n_filters=32,
-            dropout=0.3,
-            activation=nn.ELU(),
-            dilation=1
+        self,
+        in_channels,
+        kernel_size=4,
+        n_filters=32,
+        dropout=0.3,
+        activation=nn.ELU(),
+        dilation=1
):
super().__init__()
self.activation = activation
@@ -516,12 +539,12 @@ def forward(self, X):

class _MHA(nn.Module):
    def __init__(
-            self,
-            input_dim: int,
-            head_dim: int,
-            output_dim: int,
-            num_heads: int,
-            dropout: float = 0.,
+        self,
+        input_dim: int,
+        head_dim: int,
+        output_dim: int,
+        num_heads: int,
+        dropout: float = 0.,
):
"""Multi-head Attention

@@ -564,10 +587,10 @@ def __init__(
self.dropout = nn.Dropout(dropout)

    def forward(
-            self,
-            Q: torch.Tensor,
-            K: torch.Tensor,
-            V: torch.Tensor
+        self,
+        Q: torch.Tensor,
+        K: torch.Tensor,
+        V: torch.Tensor
) -> torch.Tensor:
""" Compute MHA(Q, K, V)

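
For orientation, a hypothetical usage of this private attention block, assuming the usual (batch, sequence, features) conventions for Q, K, and V:

```python
import torch
from braindecode.models.atcnet import _MHA  # private helper; import path assumed

mha = _MHA(input_dim=32, head_dim=8, output_dim=32, num_heads=2)
x = torch.randn(4, 20, 32)  # (batch, time steps, features), e.g. Tc = 20
out = mha(x, x, x)          # self-attention: Q = K = V
assert out.shape == (4, 20, 32)
```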