From 4bf83054a8246e4882b1bf5b35437e943ab309da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Fri, 4 Nov 2022 21:26:28 +0100 Subject: [PATCH 01/13] Rename EEGInception for ERP --- braindecode/models/__init__.py | 2 +- braindecode/models/{eeginception.py => eeginception_erp.py} | 5 ++--- docs/api.rst | 2 +- test/unit_tests/models/test_models.py | 6 +++--- 4 files changed, 7 insertions(+), 8 deletions(-) rename braindecode/models/{eeginception.py => eeginception_erp.py} (98%) diff --git a/braindecode/models/__init__.py b/braindecode/models/__init__.py index 59cf43a82..c084ccf34 100644 --- a/braindecode/models/__init__.py +++ b/braindecode/models/__init__.py @@ -7,7 +7,7 @@ from .hybrid import HybridNet from .shallow_fbcsp import ShallowFBCSPNet from .eegresnet import EEGResNet -from .eeginception import EEGInception +from .eeginception_erp import EEGInceptionERP from .tcn import TCN from .sleep_stager_chambon_2018 import SleepStagerChambon2018 from .sleep_stager_blanco_2020 import SleepStagerBlanco2020 diff --git a/braindecode/models/eeginception.py b/braindecode/models/eeginception_erp.py similarity index 98% rename from braindecode/models/eeginception.py rename to braindecode/models/eeginception_erp.py index 8365ee187..0aa5d54ea 100644 --- a/braindecode/models/eeginception.py +++ b/braindecode/models/eeginception_erp.py @@ -14,10 +14,9 @@ def _transpose_to_b_1_c_0(x): return x.permute(0, 3, 1, 2) -class EEGInception(nn.Sequential): - """EEG Inception. +class EEGInceptionERP(nn.Sequential): + """EEG Inception for ERP-based classification - EEG Inception for ERP-based classification described in [Santamaria2020]_. The code for the paper and this model is also available at [Santamaria2020]_ and an adaptation for PyTorch [2]_. diff --git a/docs/api.rst b/docs/api.rst index ed965c49d..7e6883f5f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -56,7 +56,7 @@ Models ShallowFBCSPNet Deep4Net - EEGInception + EEGInceptionERP EEGITNet EEGNetv1 EEGNetv4 diff --git a/test/unit_tests/models/test_models.py b/test/unit_tests/models/test_models.py index b79e95361..5cd12c1c3 100644 --- a/test/unit_tests/models/test_models.py +++ b/test/unit_tests/models/test_models.py @@ -13,7 +13,7 @@ from braindecode.models import ( Deep4Net, EEGNetv4, EEGNetv1, HybridNet, ShallowFBCSPNet, EEGResNet, TCN, SleepStagerChambon2018, SleepStagerBlanco2020, SleepStagerEldele2021, USleep, - EEGITNet, EEGInception, TIDNet) + EEGITNet, EEGInceptionERP, TIDNet) from braindecode.util import set_random_seeds @@ -122,7 +122,7 @@ def test_eegitnet(input_sizes): def test_eeginception(input_sizes): - model = EEGInception( + model = EEGInceptionERP( n_classes=input_sizes['n_classes'], in_channels=input_sizes['n_channels'], input_window_samples=input_sizes['n_in_times']) @@ -134,7 +134,7 @@ def test_eeginception_n_params(): """Make sure the number of parameters is the same as in the paper when using the same architecture hyperparameters. """ - model = EEGInception( + model = EEGInceptionERP( in_channels=8, n_classes=2, input_window_samples=128, # input_time From 5f47e9f27a2c9a4dc3d37126e4c8b449693530c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Fri, 4 Nov 2022 21:27:08 +0100 Subject: [PATCH 02/13] New EEGInception for motor imagery --- braindecode/models/__init__.py | 2 +- braindecode/models/eeginception_mi.py | 271 ++++++++++++++++++++++++++ 2 files changed, 272 insertions(+), 1 deletion(-) create mode 100644 braindecode/models/eeginception_mi.py diff --git a/braindecode/models/__init__.py b/braindecode/models/__init__.py index c084ccf34..7bff63df1 100644 --- a/braindecode/models/__init__.py +++ b/braindecode/models/__init__.py @@ -8,6 +8,7 @@ from .shallow_fbcsp import ShallowFBCSPNet from .eegresnet import EEGResNet from .eeginception_erp import EEGInceptionERP +from .eeginception_mi import EEGInceptionMI from .tcn import TCN from .sleep_stager_chambon_2018 import SleepStagerChambon2018 from .sleep_stager_blanco_2020 import SleepStagerBlanco2020 @@ -16,4 +17,3 @@ from .usleep import USleep from .util import get_output_shape, to_dense_prediction_model from .modules import TimeDistributed - diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py new file mode 100644 index 000000000..bba4a005d --- /dev/null +++ b/braindecode/models/eeginception_mi.py @@ -0,0 +1,271 @@ +# Authors: Cedric Rommel +# +# License: BSD (3-clause) + +import torch +from torch import nn + +from .modules import Expression, Ensure4d +from .functions import transpose_time_to_spat + + +class EEGInceptionMI(nn.Module): + """EEG Inception for Motor Imagery, as proposed in [1]_ + + The model is strongly based on the original InceptionNet for computer + vision. The main goal is to extract features in parallel with different + scales. The network has two blocks made of 3 inception modules with a skip + connection. + + The model is fully described in [1]_. + + Notes + ----- + This implementation is not guaranteed to be correct, has not been checked + by original authors, only reimplemented bosed on the paper [1]_. + + Parameters + ---------- + in_channels : int + Number of EEG channels. + n_classes : int + Number of classes. + input_size_ms : int + Size of the input, in milliseconds. Set to 1000 in [Santamaria2020]_. + sfreq : float + EEG sampling frequency. + drop_prob : float + Dropout rate inside all the network. + scales_time: list(int) + Windows for inception block, must be a list with proportional values of + the input_size_ms. + According to the authors: temporal scale (ms) of the convolutions + on each Inception module. + This parameter determines the kernel sizes of the filters. + n_filters : int + Initial number of convolutional filters. Set to 8 in [Santamaria2020]_. + activation: nn.Module + Activation function, default: ELU activation. + batch_norm_alpha: float + Momentum for BatchNorm2d. + depth_multiplier: int + Depth multiplier for the depthwise convolution. + pooling_sizes: list(int) + Pooling sizes for the inception block. + + References + ---------- + .. [1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021). + EEG-inception: an accurate and robust end-to-end neural network for + EEG-based motor imagery classification. + Journal of Neural Engineering, 18(4), 046014. + """ + + def __init__( + self, + in_channels, + n_classes, + input_window_samples=750, + sfreq=128, + n_convs=5, + n_filters=48, + kernel_unit_s=0.1, + activation=nn.ReLU(), + ): + super().__init__() + + self.in_channels = in_channels + self.n_classes = n_classes + self.input_window_samples = input_window_samples + self.sfreq = sfreq + self.n_convs = n_convs + self.n_filters = n_filters + self.kernel_unit_s = kernel_unit_s + self.activation = activation + + self.ensuredims = Ensure4d() + self.dimshuffle = Expression(transpose_time_to_spat) + + # ======== Inception branches ======================== + + self.initial_inception_module = _InceptionModuleMI( + in_channels=self.in_channels, + n_filters=self.n_filters, + n_convs=self.n_convs, + kernel_unit_s=self.kernel_unit_s, + sfreq=self.sfreq, + activation=self.activation, + ) + + intermediate_in_channels = (self.n_convs + 1) * self.n_filters + + self.intermediate_inception_modules_1 = nn.ModuleList([ + _InceptionModuleMI( + in_channels=intermediate_in_channels, + n_filters=self.n_filters, + n_convs=self.n_convs, + kernel_unit_s=self.kernel_unit_s, + sfreq=self.sfreq, + activation=self.activation, + ) for _ in range(2) + ]) + + self.residual_block_1 = _ResidualBlockMI( + in_channels=self.in_channels, + n_filters=intermediate_in_channels, + activation=self.activation, + ) + + self.intermediate_inception_modules_2 = nn.ModuleList([ + _InceptionModuleMI( + in_channels=intermediate_in_channels, + n_filters=self.n_filters, + n_convs=self.n_convs, + kernel_unit_s=self.kernel_unit_s, + sfreq=self.sfreq, + activation=self.activation, + ) for _ in range(3) + ]) + + self.residual_block_2 = _ResidualBlockMI( + in_channels=intermediate_in_channels, + n_filters=intermediate_in_channels, + activation=self.activation, + ) + + # XXX The paper mentions a final average pooling but does not indicate + # the kernel size... The only info available is figure1 showing a + # final AveragePooling layer and the table3 indicating the spatial and + # channel dimensions are unchanged by this layer... + # My best guess is they use the same size as the MaxPooling, with + # stride 1 + # self.ave_pooling = nn.AvgPool1d() + + self.fc = nn.Linear( + in_features=self.input_window_samples * intermediate_in_channels, + out_features=self.n_classes, + bias=True, + ) + + def forward( + self, + X: torch.Tensor, + ) -> torch.Tensor: + res1 = self.residual_block_1(X) + + out = self.initial_inception_module(X) + for layer in self.intermediate_inception_modules_1: + out = layer(out) + + out = out + res1 + + res2 = self.residual_block_2(out) + + for layer in self.intermediate_inception_modules_2: + out = layer(out) + + out = res2 + out + + # out = self.ave_pooling(out) + return self.fc(out.flatten()) + + +class _InceptionModuleMI(nn.Module): + def __init__( + self, + in_channels, + n_filters, + n_convs, + kernel_unit_s=0.1, + sfreq=250, + activation=nn.ReLU(), + ): + super().__init__() + self.in_channels = in_channels + self.n_filters = n_filters + self.n_convs = n_convs + self.kernel_unit_s = kernel_unit_s + self.sfreq = sfreq + + self.bottleneck = nn.Conv1d( + in_channels=self.in_channels, + out_channels=self.n_filters, + kernel_size=1, + bias=True, + ) + + kernel_unit = self.kernel_unit_s * self.sfreq + + # XXX I wonder whether stride is correct here. This is how MaxPooling + # is usually used, but table3 in the paper indicate an unchanged + # output shape... Are they using stride=1? + self.pooling = nn.MaxPool1d( + kernel_size=kernel_unit, + stride=kernel_unit, + ) + + self.pooling_conv = nn.Conv1d( + in_channels=self.in_channels, + out_channels=self.n_filters, + kernel_size=1, + bias=True, + ) + + self.conv_list = nn.ModuleList([ + nn.Conv1d( + in_channels=self.n_filters, + out_channels=self.n_filters, + kernel_size=n_units * kernel_unit, + padding="same", + bias=True, + ) for n_units in range(1, self.n_convs + 1) + ]) + + self.bn = nn.BatchNorm1d() + + self.activation = activation + + def forward( + self, + X: torch.Tensor, + ) -> torch.Tensor: + X1 = self.bottleneck(X) + + X1 = [conv(X1) for conv in self.conv_list] + + X2 = self.pooling(X) + X2 = self.pooling_conv(X2) + + out = torch.cat(X1 + [X2], 1) + + out = self.bn(out) + return self.activation(out) + + +class _ResidualBlockMI(nn.Module): + def __init__( + self, + in_channels, + n_filters, + activation=nn.ReLU() + ): + super().__init__() + self.in_channels = in_channels + self.n_filters = n_filters + self.activation = activation + + self.bn = nn.BatchNorm1d() + self.conv = nn.Conv1d( + in_channels=self.in_channels, + out_channels=self.n_filters, + kernel_size=1, + bias=True, + ) + + def forward( + self, + X: torch.Tensor, + ) -> torch.Tensor: + out = self.conv(X) + out = self.bn(out) + return self.activation(out) From 2c4528b5007c77b85636262f51e4bb143b4ec33b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Fri, 4 Nov 2022 21:51:15 +0100 Subject: [PATCH 03/13] Fix docstrings --- braindecode/models/eeginception_erp.py | 49 +++++++++++++------------ braindecode/models/eeginception_mi.py | 50 +++++++++++++------------- 2 files changed, 51 insertions(+), 48 deletions(-) diff --git a/braindecode/models/eeginception_erp.py b/braindecode/models/eeginception_erp.py index 0aa5d54ea..09a85b443 100644 --- a/braindecode/models/eeginception_erp.py +++ b/braindecode/models/eeginception_erp.py @@ -43,28 +43,33 @@ class EEGInceptionERP(nn.Sequential): Number of EEG channels. n_classes : int Number of classes. - input_size_ms : int - Size of the input, in milliseconds. Set to 1000 in [Santamaria2020]_. - sfreq : float - EEG sampling frequency. - drop_prob : float - Dropout rate inside all the network. - scales_time: list(int) - Windows for inception block, must be a list with proportional values of - the input_size_ms. - According to the authors: temporal scale (ms) of the convolutions - on each Inception module. - This parameter determines the kernel sizes of the filters. - n_filters : int - Initial number of convolutional filters. Set to 8 in [Santamaria2020]_. - activation: nn.Module - Activation function, default: ELU activation. - batch_norm_alpha: float - Momentum for BatchNorm2d. - depth_multiplier: int - Depth multiplier for the depthwise convolution. - pooling_sizes: list(int) - Pooling sizes for the inception block. + input_window_samples : int, optional + Size of the input, in number of sampels. Set to 128 (1s) as in + [Santamaria2020]_. + sfreq : float, optional + EEG sampling frequency. Defaults to 128 as in [Santamaria2020]_. + drop_prob : float, optional + Dropout rate inside all the network. Defaults to 0.5 as in + [Santamaria2020]_. + scales_samples_s: list(float), optional + Windows for inception block. Temporal scale (s) of the convolutions on + each Inception module. This parameter determines the kernel sizes of + the filters. Defaults to 0.5, 0.25, 0.125 seconds, as in + [Santamaria2020]_. + n_filters : int, optional + Initial number of convolutional filters. Defaults to 8 as in + [Santamaria2020]_. + activation: nn.Module, optional + Activation function. Defaults to ELU activation as in + [Santamaria2020]_. + batch_norm_alpha: float, optional + Momentum for BatchNorm2d. Defaults to 0.01. + depth_multiplier: int, optional + Depth multiplier for the depthwise convolution. Defaults to 2 as in + [Santamaria2020]_. + pooling_sizes: list(int), optional + Pooling sizes for the inception blocks. Defaults to 4, 2, 2 and 2, as + in [Santamaria2020]_. References ---------- diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py index bba4a005d..7eda4273a 100644 --- a/braindecode/models/eeginception_mi.py +++ b/braindecode/models/eeginception_mi.py @@ -30,28 +30,25 @@ class EEGInceptionMI(nn.Module): Number of EEG channels. n_classes : int Number of classes. - input_size_ms : int - Size of the input, in milliseconds. Set to 1000 in [Santamaria2020]_. - sfreq : float - EEG sampling frequency. - drop_prob : float - Dropout rate inside all the network. - scales_time: list(int) - Windows for inception block, must be a list with proportional values of - the input_size_ms. - According to the authors: temporal scale (ms) of the convolutions - on each Inception module. - This parameter determines the kernel sizes of the filters. - n_filters : int - Initial number of convolutional filters. Set to 8 in [Santamaria2020]_. + input_window_s : float, optional + Size of the input, in seconds. Set to 4.5 s as in [1]_ for dataset + BCI IV 2a. + sfreq : float, optional + EEG sampling frequency in Hz. Defaults to 250 Hz as in [1]_ for dataset + BCI IV 2a. + n_convs : int, optional + Number of convolution per inception wide branching. Defaults to 5 as + in [1]_ for dataset BCI IV 2a. + n_filters : int, optional + Number of convolutional filters for all layers of this type. Set to 48 + as in [1]_ for dataset BCI IV 2a. + kernel_unit_s : float, optional + Size in seconds of the basic 1D convolutional kernel used in inception + modules. Each convolutional layer in such modules have kernels of + increasing size, odd multiples of this value (e.g. 0.1, 0.3, 0.5, 0.7, + 0.9 here for `n_convs`=5). Defaults to 0.1 s. activation: nn.Module - Activation function, default: ELU activation. - batch_norm_alpha: float - Momentum for BatchNorm2d. - depth_multiplier: int - Depth multiplier for the depthwise convolution. - pooling_sizes: list(int) - Pooling sizes for the inception block. + Activation function. Defaults to ReLU activation. References ---------- @@ -65,8 +62,8 @@ def __init__( self, in_channels, n_classes, - input_window_samples=750, - sfreq=128, + input_window_s=4.5, + sfreq=250, n_convs=5, n_filters=48, kernel_unit_s=0.1, @@ -76,7 +73,8 @@ def __init__( self.in_channels = in_channels self.n_classes = n_classes - self.input_window_samples = input_window_samples + self.input_window_s = input_window_s + self.input_window_samples = input_window_s * sfreq self.sfreq = sfreq self.n_convs = n_convs self.n_filters = n_filters @@ -215,10 +213,10 @@ def __init__( nn.Conv1d( in_channels=self.n_filters, out_channels=self.n_filters, - kernel_size=n_units * kernel_unit, + kernel_size=(n_units * 2 + 1) * kernel_unit, padding="same", bias=True, - ) for n_units in range(1, self.n_convs + 1) + ) for n_units in range(self.n_convs) ]) self.bn = nn.BatchNorm1d() From 0924066457b1a562ec718c86490f0f1d01f2a86c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Fri, 4 Nov 2022 23:24:19 +0100 Subject: [PATCH 04/13] Add forward test and fix EEGInceptionMI --- braindecode/models/eeginception_mi.py | 41 +++++++++++++++++---------- test/unit_tests/models/test_models.py | 18 ++++++++++-- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py index 7eda4273a..558a3f7f0 100644 --- a/braindecode/models/eeginception_mi.py +++ b/braindecode/models/eeginception_mi.py @@ -6,7 +6,6 @@ from torch import nn from .modules import Expression, Ensure4d -from .functions import transpose_time_to_spat class EEGInceptionMI(nn.Module): @@ -74,7 +73,7 @@ def __init__( self.in_channels = in_channels self.n_classes = n_classes self.input_window_s = input_window_s - self.input_window_samples = input_window_s * sfreq + self.input_window_samples = int(input_window_s * sfreq) self.sfreq = sfreq self.n_convs = n_convs self.n_filters = n_filters @@ -82,7 +81,7 @@ def __init__( self.activation = activation self.ensuredims = Ensure4d() - self.dimshuffle = Expression(transpose_time_to_spat) + self.dimshuffle = Expression(_transpose_to_b_c_1_t) # ======== Inception branches ======================== @@ -139,6 +138,7 @@ def __init__( # stride 1 # self.ave_pooling = nn.AvgPool1d() + self.flat = nn.Flatten() self.fc = nn.Linear( in_features=self.input_window_samples * intermediate_in_channels, out_features=self.n_classes, @@ -149,6 +149,9 @@ def forward( self, X: torch.Tensor, ) -> torch.Tensor: + X = self.ensuredims(X) + X = self.dimshuffle(X) + res1 = self.residual_block_1(X) out = self.initial_inception_module(X) @@ -165,7 +168,8 @@ def forward( out = res2 + out # out = self.ave_pooling(out) - return self.fc(out.flatten()) + out = self.flat(out) + return self.fc(out) class _InceptionModuleMI(nn.Module): @@ -185,24 +189,26 @@ def __init__( self.kernel_unit_s = kernel_unit_s self.sfreq = sfreq - self.bottleneck = nn.Conv1d( + self.bottleneck = nn.Conv2d( in_channels=self.in_channels, out_channels=self.n_filters, kernel_size=1, bias=True, ) - kernel_unit = self.kernel_unit_s * self.sfreq + kernel_unit = int(self.kernel_unit_s * self.sfreq) # XXX I wonder whether stride is correct here. This is how MaxPooling # is usually used, but table3 in the paper indicate an unchanged # output shape... Are they using stride=1? - self.pooling = nn.MaxPool1d( - kernel_size=kernel_unit, - stride=kernel_unit, + self.pooling = nn.MaxPool2d( + kernel_size=(1, kernel_unit), + # stride=kernel_unit, + stride=1, + padding=(0, int(kernel_unit // 2)), ) - self.pooling_conv = nn.Conv1d( + self.pooling_conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.n_filters, kernel_size=1, @@ -210,16 +216,16 @@ def __init__( ) self.conv_list = nn.ModuleList([ - nn.Conv1d( + nn.Conv2d( in_channels=self.n_filters, out_channels=self.n_filters, - kernel_size=(n_units * 2 + 1) * kernel_unit, + kernel_size=(1, (n_units * 2 + 1) * kernel_unit), padding="same", bias=True, ) for n_units in range(self.n_convs) ]) - self.bn = nn.BatchNorm1d() + self.bn = nn.BatchNorm2d(self.n_filters * (self.n_convs + 1)) self.activation = activation @@ -232,6 +238,7 @@ def forward( X1 = [conv(X1) for conv in self.conv_list] X2 = self.pooling(X) + X2 = X2[..., :-1] # XXX Ugly, but allows to preserve spatial dim... X2 = self.pooling_conv(X2) out = torch.cat(X1 + [X2], 1) @@ -252,8 +259,8 @@ def __init__( self.n_filters = n_filters self.activation = activation - self.bn = nn.BatchNorm1d() - self.conv = nn.Conv1d( + self.bn = nn.BatchNorm2d(self.n_filters) + self.conv = nn.Conv2d( in_channels=self.in_channels, out_channels=self.n_filters, kernel_size=1, @@ -267,3 +274,7 @@ def forward( out = self.conv(X) out = self.bn(out) return self.activation(out) + + +def _transpose_to_b_c_1_t(x): + return x.permute(0, 1, 3, 2) diff --git a/test/unit_tests/models/test_models.py b/test/unit_tests/models/test_models.py index 5cd12c1c3..2b4913fbd 100644 --- a/test/unit_tests/models/test_models.py +++ b/test/unit_tests/models/test_models.py @@ -13,7 +13,7 @@ from braindecode.models import ( Deep4Net, EEGNetv4, EEGNetv1, HybridNet, ShallowFBCSPNet, EEGResNet, TCN, SleepStagerChambon2018, SleepStagerBlanco2020, SleepStagerEldele2021, USleep, - EEGITNet, EEGInceptionERP, TIDNet) + EEGITNet, EEGInceptionERP, EEGInceptionMI, TIDNet) from braindecode.util import set_random_seeds @@ -121,7 +121,7 @@ def test_eegitnet(input_sizes): check_forward_pass(model, input_sizes,) -def test_eeginception(input_sizes): +def test_eeginception_erp(input_sizes): model = EEGInceptionERP( n_classes=input_sizes['n_classes'], in_channels=input_sizes['n_channels'], @@ -130,7 +130,7 @@ def test_eeginception(input_sizes): check_forward_pass(model, input_sizes,) -def test_eeginception_n_params(): +def test_eeginception_erp_n_params(): """Make sure the number of parameters is the same as in the paper when using the same architecture hyperparameters. """ @@ -149,6 +149,18 @@ def test_eeginception_n_params(): assert n_params == 14926 # From paper's TABLE IV EEG-Inception Architecture Details +def test_eeginception_mi(input_sizes): + sfreq = 100 + model = EEGInceptionMI( + n_classes=input_sizes['n_classes'], + in_channels=input_sizes['n_channels'], + input_window_s=input_sizes['n_in_times'] / sfreq, + sfreq=100, + ) + + check_forward_pass(model, input_sizes,) + + @pytest.mark.parametrize( "n_channels,sfreq,n_classes,input_size_s", [(20, 128, 5, 30), (10, 256, 4, 20), (1, 64, 2, 30)], From 9a9e9a0f8cf5fd9934ab476dc952ae0d774682e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Mon, 7 Nov 2022 14:28:23 +0100 Subject: [PATCH 05/13] Add EEGInception test for number of params --- braindecode/models/eeginception_mi.py | 29 ++++++++++++++--------- test/unit_tests/models/test_models.py | 33 +++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 13 deletions(-) diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py index 558a3f7f0..51e0593bd 100644 --- a/braindecode/models/eeginception_mi.py +++ b/braindecode/models/eeginception_mi.py @@ -133,14 +133,21 @@ def __init__( # XXX The paper mentions a final average pooling but does not indicate # the kernel size... The only info available is figure1 showing a # final AveragePooling layer and the table3 indicating the spatial and - # channel dimensions are unchanged by this layer... - # My best guess is they use the same size as the MaxPooling, with - # stride 1 - # self.ave_pooling = nn.AvgPool1d() + # channel dimensions are unchanged by this layer... This could indicate + # a stride=1 as for MaxPooling layers. Howevere, when we look at the + # number of parameters of the linear layer following the average + # pooling, we see a small number of parameters, potentially indicating + # that the whole time dimension is averaged on this stage for each + # channel. We follow this last hypothesis here to comply with the + # number of parameters reported in the paper. + self.ave_pooling = nn.AvgPool2d( + kernel_size=(1, self.input_window_samples), + ) self.flat = nn.Flatten() self.fc = nn.Linear( - in_features=self.input_window_samples * intermediate_in_channels, + # in_features=self.input_window_samples * intermediate_in_channels, + in_features=intermediate_in_channels, out_features=self.n_classes, bias=True, ) @@ -151,6 +158,7 @@ def forward( ) -> torch.Tensor: X = self.ensuredims(X) X = self.dimshuffle(X) + import ipdb; ipdb.set_trace() res1 = self.residual_block_1(X) @@ -167,7 +175,7 @@ def forward( out = res2 + out - # out = self.ave_pooling(out) + out = self.ave_pooling(out) out = self.flat(out) return self.fc(out) @@ -198,12 +206,12 @@ def __init__( kernel_unit = int(self.kernel_unit_s * self.sfreq) - # XXX I wonder whether stride is correct here. This is how MaxPooling - # is usually used, but table3 in the paper indicate an unchanged - # output shape... Are they using stride=1? + # XXX Maxpooling is usually used to reduce spatial resolution, with a + # stride equal to the kernel size... But it seems the authors use + # stride=1 in their paper according to the output shapes from Table3, + # although this is not clearly specified in the paper text. self.pooling = nn.MaxPool2d( kernel_size=(1, kernel_unit), - # stride=kernel_unit, stride=1, padding=(0, int(kernel_unit // 2)), ) @@ -238,7 +246,6 @@ def forward( X1 = [conv(X1) for conv in self.conv_list] X2 = self.pooling(X) - X2 = X2[..., :-1] # XXX Ugly, but allows to preserve spatial dim... X2 = self.pooling_conv(X2) out = torch.cat(X1 + [X2], 1) diff --git a/test/unit_tests/models/test_models.py b/test/unit_tests/models/test_models.py index 2b4913fbd..752f8dfe0 100644 --- a/test/unit_tests/models/test_models.py +++ b/test/unit_tests/models/test_models.py @@ -150,17 +150,46 @@ def test_eeginception_erp_n_params(): def test_eeginception_mi(input_sizes): - sfreq = 100 + sfreq = 250 model = EEGInceptionMI( n_classes=input_sizes['n_classes'], in_channels=input_sizes['n_channels'], input_window_s=input_sizes['n_in_times'] / sfreq, - sfreq=100, + sfreq=sfreq, ) check_forward_pass(model, input_sizes,) +@pytest.mark.parametrize( + "n_filter,reported", + [(6, 51386), (12, 204002), (16, 361986), (24, 812930), (64, 5767170)] +) +def test_eeginception_mi_binary_n_params(n_filter, reported): + """Make sure the number of parameters is the same as in the paper when + using the same architecture hyperparameters. + + Note + ---- + For some reason, we match the correct number of parameters for all + configurations in the binary classification case, but none for the 4-class + case... Should be investigated by contacting the authors. + """ + model = EEGInceptionMI( + in_channels=3, + n_classes=2, + input_window_s=3., # input_time + sfreq=250, + n_convs=3, + n_filters=n_filter, + kernel_unit_s=0.1, + ) + + n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + # From first column of TABLE 2 in EEG-Inception paper + assert n_params == reported + + @pytest.mark.parametrize( "n_channels,sfreq,n_classes,input_size_s", [(20, 128, 5, 30), (10, 256, 4, 20), (1, 64, 2, 30)], From d4c5c35fc7a2bd0cf94574a423f7069840b8b8f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Mon, 7 Nov 2022 14:35:31 +0100 Subject: [PATCH 06/13] Remove forgotten breakpoint --- braindecode/models/eeginception_mi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py index 51e0593bd..2a31e267c 100644 --- a/braindecode/models/eeginception_mi.py +++ b/braindecode/models/eeginception_mi.py @@ -158,7 +158,6 @@ def forward( ) -> torch.Tensor: X = self.ensuredims(X) X = self.dimshuffle(X) - import ipdb; ipdb.set_trace() res1 = self.residual_block_1(X) From f9a9846b1bee4e18cd2732c6013ba6a29bf6a847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Mon, 7 Nov 2022 17:18:07 +0100 Subject: [PATCH 07/13] Add softmax layer in the end --- braindecode/models/eeginception_mi.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py index 2a31e267c..b68fa6f70 100644 --- a/braindecode/models/eeginception_mi.py +++ b/braindecode/models/eeginception_mi.py @@ -152,6 +152,8 @@ def __init__( bias=True, ) + self.softmax = nn.LogSoftmax(dim=1) + def forward( self, X: torch.Tensor, @@ -176,7 +178,8 @@ def forward( out = self.ave_pooling(out) out = self.flat(out) - return self.fc(out) + out = self.fc(out) + return self.softmax(out) class _InceptionModuleMI(nn.Module): From 2a6970c99b9d606d11f7712047aba9718db03f80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Tue, 8 Nov 2022 09:16:39 +0100 Subject: [PATCH 08/13] Fix reference formatting in docstring Co-authored-by: Alexandre Gramfort --- braindecode/models/eeginception_mi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py index b68fa6f70..59ab2867d 100644 --- a/braindecode/models/eeginception_mi.py +++ b/braindecode/models/eeginception_mi.py @@ -52,9 +52,9 @@ class EEGInceptionMI(nn.Module): References ---------- .. [1] Zhang, C., Kim, Y. K., & Eskandarian, A. (2021). - EEG-inception: an accurate and robust end-to-end neural network for - EEG-based motor imagery classification. - Journal of Neural Engineering, 18(4), 046014. + EEG-inception: an accurate and robust end-to-end neural network + for EEG-based motor imagery classification. + Journal of Neural Engineering, 18(4), 046014. """ def __init__( From d311ee983ba84db9534e8c0523aaad3db9eddfaa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Tue, 8 Nov 2022 09:34:40 +0100 Subject: [PATCH 09/13] Properly deprecate prev version of EEGInception --- braindecode/models/__init__.py | 1 + braindecode/models/eeginception.py | 296 +++++++++++++++++++++++++++++ 2 files changed, 297 insertions(+) create mode 100644 braindecode/models/eeginception.py diff --git a/braindecode/models/__init__.py b/braindecode/models/__init__.py index 7bff63df1..419684b34 100644 --- a/braindecode/models/__init__.py +++ b/braindecode/models/__init__.py @@ -7,6 +7,7 @@ from .hybrid import HybridNet from .shallow_fbcsp import ShallowFBCSPNet from .eegresnet import EEGResNet +from .eeginception import EEGInception from .eeginception_erp import EEGInceptionERP from .eeginception_mi import EEGInceptionMI from .tcn import TCN diff --git a/braindecode/models/eeginception.py b/braindecode/models/eeginception.py new file mode 100644 index 000000000..a1e3434c7 --- /dev/null +++ b/braindecode/models/eeginception.py @@ -0,0 +1,296 @@ +# Authors: Bruno Aristimunha +# Cedric Rommel +# +# License: BSD (3-clause) +from warnings import warn + +from numpy import prod + +from torch import nn +from .modules import Expression, Ensure4d +from .eegnet import _glorot_weight_zero_bias +from .eegitnet import _InceptionBlock, _DepthwiseConv2d + + +def _transpose_to_b_1_c_0(x): + return x.permute(0, 3, 1, 2) + + +class EEGInception(nn.Sequential): + """ EEG Inception for ERP-based classification + + --> DEPERECATED <-- + THIS CLASS IS DEPRECATED AND WILL BE REMOVED IN THE NEXT RELEASE OF + BRAINDECODE. PLEASE USE braindecode.models.EEGInceptionERP INSTEAD IN THE + FUTURE. + + The code for the paper and this model is also available at [Santamaria2020]_ + and an adaptation for PyTorch [2]_. + + The model is strongly based on the original InceptionNet for an image. The main goal is + to extract features in parallel with different scales. The authors extracted three scales + proportional to the window sample size. The network had three parts: + 1-larger inception block largest, 2-smaller inception block followed by 3-bottleneck + for classification. + + One advantage of the EEG-Inception block is that it allows a network + to learn simultaneous components of low and high frequency associated with the signal. + The winners of BEETL Competition/NeurIps 2021 used parts of the model [beetl]_. + + The model is fully described in [Santamaria2020]_. + + Notes + ----- + This implementation is not guaranteed to be correct, has not been checked + by original authors, only reimplemented from the paper based on [2]_. + + Parameters + ---------- + in_channels : int + Number of EEG channels. + n_classes : int + Number of classes. + input_window_samples : int, optional + Size of the input, in number of sampels. Set to 128 (1s) as in + [Santamaria2020]_. + sfreq : float, optional + EEG sampling frequency. Defaults to 128 as in [Santamaria2020]_. + drop_prob : float, optional + Dropout rate inside all the network. Defaults to 0.5 as in + [Santamaria2020]_. + scales_samples_s: list(float), optional + Windows for inception block. Temporal scale (s) of the convolutions on + each Inception module. This parameter determines the kernel sizes of + the filters. Defaults to 0.5, 0.25, 0.125 seconds, as in + [Santamaria2020]_. + n_filters : int, optional + Initial number of convolutional filters. Defaults to 8 as in + [Santamaria2020]_. + activation: nn.Module, optional + Activation function. Defaults to ELU activation as in + [Santamaria2020]_. + batch_norm_alpha: float, optional + Momentum for BatchNorm2d. Defaults to 0.01. + depth_multiplier: int, optional + Depth multiplier for the depthwise convolution. Defaults to 2 as in + [Santamaria2020]_. + pooling_sizes: list(int), optional + Pooling sizes for the inception blocks. Defaults to 4, 2, 2 and 2, as + in [Santamaria2020]_. + + References + ---------- + .. [Santamaria2020] Santamaria-Vazquez, E., Martinez-Cagigal, V., + Vaquerizo-Villar, F., & Hornero, R. (2020). + EEG-inception: A novel deep convolutional neural network for assistive + ERP-based brain-computer interfaces. + IEEE Transactions on Neural Systems and Rehabilitation Engineering , v. 28. + Online: http://dx.doi.org/10.1109/TNSRE.2020.3048106 + .. [2] Grifcc. Implementation of the EEGInception in torch (2022). + Online: https://github.com/Grifcc/EEG/tree/90e412a407c5242dfc953d5ffb490bdb32faf022 + .. [beetl]_ Wei, X., Faisal, A.A., Grosse-Wentrup, M., Gramfort, A., Chevallier, S., + Jayaram, V., Jeunet, C., Bakas, S., Ludwig, S., Barmpas, K., Bahri, M., Panagakis, + Y., Laskaris, N., Adamos, D.A., Zafeiriou, S., Duong, W.C., Gordon, S.M., + Lawhern, V.J., ƚliwowski, M., Rouanne, V. & Tempczyk, P.. (2022). + 2021 BEETL Competition: Advancing Transfer Learning for Subject Independence & + Heterogenous EEG Data Sets. Proceedings of the NeurIPS 2021 Competitions and + Demonstrations Track, in Proceedings of Machine Learning Research + 176:205-219 Available from https://proceedings.mlr.press/v176/wei22a.html. + + """ + + def __init__( + self, + in_channels, + n_classes, + input_window_samples=1000, + sfreq=128, + drop_prob=0.5, + scales_samples_s=(0.5, 0.25, 0.125), + n_filters=8, + activation=nn.ELU(), + batch_norm_alpha=0.01, + depth_multiplier=2, + pooling_sizes=(4, 2, 2, 2), + ): + super().__init__() + + warn( + "The class EEGInception is deprecated and will be removed in the " + "next release of braindecode. Please use " + "braindecode.models.EEGInceptionERP instead in the future.", + DeprecationWarning + ) + + self.in_channels = in_channels + self.n_classes = n_classes + self.input_window_samples = input_window_samples + self.drop_prob = drop_prob + self.sfreq = sfreq + self.n_filters = n_filters + self.scales_samples_s = scales_samples_s + self.scales_samples = tuple( + int(size_s * self.sfreq) for size_s in self.scales_samples_s) + self.activation = activation + self.alpha_momentum = batch_norm_alpha + self.depth_multiplier = depth_multiplier + self.pooling_sizes = pooling_sizes + + self.add_module("ensuredims", Ensure4d()) + + self.add_module("dimshuffle", Expression(_transpose_to_b_1_c_0)) + + # ======== Inception branches ======================== + block11 = self._get_inception_branch_1( + in_channels=in_channels, + out_channels=self.n_filters, + kernel_length=self.scales_samples[0], + alpha_momentum=self.alpha_momentum, + activation=self.activation, + drop_prob=self.drop_prob, + depth_multiplier=self.depth_multiplier, + ) + block12 = self._get_inception_branch_1( + in_channels=in_channels, + out_channels=self.n_filters, + kernel_length=self.scales_samples[1], + alpha_momentum=self.alpha_momentum, + activation=self.activation, + drop_prob=self.drop_prob, + depth_multiplier=self.depth_multiplier, + ) + block13 = self._get_inception_branch_1( + in_channels=in_channels, + out_channels=self.n_filters, + kernel_length=self.scales_samples[2], + alpha_momentum=self.alpha_momentum, + activation=self.activation, + drop_prob=self.drop_prob, + depth_multiplier=self.depth_multiplier, + ) + + self.add_module("inception_block_1", _InceptionBlock((block11, block12, block13))) + + self.add_module("avg_pool_1", nn.AvgPool2d((1, self.pooling_sizes[0]))) + + # ======== Inception branches ======================== + n_concat_filters = len(self.scales_samples) * self.n_filters + n_concat_dw_filters = n_concat_filters * self.depth_multiplier + block21 = self._get_inception_branch_2( + in_channels=n_concat_dw_filters, + out_channels=self.n_filters, + kernel_length=self.scales_samples[0] // 4, + alpha_momentum=self.alpha_momentum, + activation=self.activation, + drop_prob=self.drop_prob + ) + block22 = self._get_inception_branch_2( + in_channels=n_concat_dw_filters, + out_channels=self.n_filters, + kernel_length=self.scales_samples[1] // 4, + alpha_momentum=self.alpha_momentum, + activation=self.activation, + drop_prob=self.drop_prob + ) + block23 = self._get_inception_branch_2( + in_channels=n_concat_dw_filters, + out_channels=self.n_filters, + kernel_length=self.scales_samples[2] // 4, + alpha_momentum=self.alpha_momentum, + activation=self.activation, + drop_prob=self.drop_prob + ) + + self.add_module( + "inception_block_2", _InceptionBlock((block21, block22, block23))) + + self.add_module("avg_pool_2", nn.AvgPool2d((1, self.pooling_sizes[1]))) + + self.add_module("final_block", nn.Sequential( + nn.Conv2d( + n_concat_filters, + n_concat_filters // 2, + (1, 8), + padding="same", + bias=False + ), + nn.BatchNorm2d(n_concat_filters // 2, + momentum=self.alpha_momentum), + activation, + nn.Dropout(self.drop_prob), + nn.AvgPool2d((1, self.pooling_sizes[2])), + + nn.Conv2d( + n_concat_filters // 2, + n_concat_filters // 4, + (1, 4), + padding="same", + bias=False + ), + nn.BatchNorm2d(n_concat_filters // 4, + momentum=self.alpha_momentum), + activation, + nn.Dropout(self.drop_prob), + nn.AvgPool2d((1, self.pooling_sizes[3])), + )) + + spatial_dim_last_layer = ( + input_window_samples // prod(self.pooling_sizes)) + n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4 + + self.add_module("classification", nn.Sequential( + nn.Flatten(), + nn.Linear( + spatial_dim_last_layer * n_channels_last_layer, + self.n_classes + ), + nn.Softmax(1) + )) + + _glorot_weight_zero_bias(self) + + @staticmethod + def _get_inception_branch_1(in_channels, out_channels, kernel_length, + alpha_momentum, drop_prob, activation, + depth_multiplier): + return nn.Sequential( + nn.Conv2d( + 1, + out_channels, + kernel_size=(1, kernel_length), + padding="same", + bias=True + ), + nn.BatchNorm2d(out_channels, momentum=alpha_momentum), + activation, + nn.Dropout(drop_prob), + _DepthwiseConv2d( + out_channels, + kernel_size=(in_channels, 1), + depth_multiplier=depth_multiplier, + bias=False, + padding="valid", + ), + nn.BatchNorm2d( + depth_multiplier * out_channels, + momentum=alpha_momentum + ), + activation, + nn.Dropout(drop_prob), + ) + + @staticmethod + def _get_inception_branch_2(in_channels, out_channels, kernel_length, + alpha_momentum, drop_prob, activation): + return nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=(1, kernel_length), + padding="same", + bias=False + ), + nn.BatchNorm2d(out_channels, momentum=alpha_momentum), + activation, + nn.Dropout(drop_prob), + ) From 1e09b394e33650f5b70f70be63e235d47f6b797e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Tue, 8 Nov 2022 10:51:44 +0100 Subject: [PATCH 10/13] Keep testing deprecated EEGInception --- test/unit_tests/models/test_models.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/test/unit_tests/models/test_models.py b/test/unit_tests/models/test_models.py index 752f8dfe0..d61876d64 100644 --- a/test/unit_tests/models/test_models.py +++ b/test/unit_tests/models/test_models.py @@ -13,7 +13,7 @@ from braindecode.models import ( Deep4Net, EEGNetv4, EEGNetv1, HybridNet, ShallowFBCSPNet, EEGResNet, TCN, SleepStagerChambon2018, SleepStagerBlanco2020, SleepStagerEldele2021, USleep, - EEGITNet, EEGInceptionERP, EEGInceptionMI, TIDNet) + EEGITNet, EEGInception, EEGInceptionERP, EEGInceptionMI, TIDNet) from braindecode.util import set_random_seeds @@ -121,8 +121,9 @@ def test_eegitnet(input_sizes): check_forward_pass(model, input_sizes,) -def test_eeginception_erp(input_sizes): - model = EEGInceptionERP( +@pytest.mark.parametrize("model_cls", [EEGInception, EEGInceptionERP]) +def test_eeginception_erp(input_sizes, model_cls): + model = model_cls( n_classes=input_sizes['n_classes'], in_channels=input_sizes['n_channels'], input_window_samples=input_sizes['n_in_times']) @@ -130,11 +131,12 @@ def test_eeginception_erp(input_sizes): check_forward_pass(model, input_sizes,) -def test_eeginception_erp_n_params(): +@pytest.mark.parametrize("model_cls", [EEGInception, EEGInceptionERP]) +def test_eeginception_erp_n_params(model_cls): """Make sure the number of parameters is the same as in the paper when using the same architecture hyperparameters. """ - model = EEGInceptionERP( + model = model_cls( in_channels=8, n_classes=2, input_window_samples=128, # input_time From 68f4bf2ec373d6a5885c03777494cb1e7dc3dc68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Tue, 8 Nov 2022 16:35:12 +0100 Subject: [PATCH 11/13] Restablish deprecated docstring --- braindecode/models/eeginception.py | 53 ++++++++++++++---------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/braindecode/models/eeginception.py b/braindecode/models/eeginception.py index a1e3434c7..6878c06ca 100644 --- a/braindecode/models/eeginception.py +++ b/braindecode/models/eeginception.py @@ -20,7 +20,7 @@ class EEGInception(nn.Sequential): """ EEG Inception for ERP-based classification --> DEPERECATED <-- - THIS CLASS IS DEPRECATED AND WILL BE REMOVED IN THE NEXT RELEASE OF + THIS CLASS IS DEPRECATED AND WILL BE REMOVED IN THE RELEASE 0.9 OF BRAINDECODE. PLEASE USE braindecode.models.EEGInceptionERP INSTEAD IN THE FUTURE. @@ -50,33 +50,28 @@ class EEGInception(nn.Sequential): Number of EEG channels. n_classes : int Number of classes. - input_window_samples : int, optional - Size of the input, in number of sampels. Set to 128 (1s) as in - [Santamaria2020]_. - sfreq : float, optional - EEG sampling frequency. Defaults to 128 as in [Santamaria2020]_. - drop_prob : float, optional - Dropout rate inside all the network. Defaults to 0.5 as in - [Santamaria2020]_. - scales_samples_s: list(float), optional - Windows for inception block. Temporal scale (s) of the convolutions on - each Inception module. This parameter determines the kernel sizes of - the filters. Defaults to 0.5, 0.25, 0.125 seconds, as in - [Santamaria2020]_. - n_filters : int, optional - Initial number of convolutional filters. Defaults to 8 as in - [Santamaria2020]_. - activation: nn.Module, optional - Activation function. Defaults to ELU activation as in - [Santamaria2020]_. - batch_norm_alpha: float, optional - Momentum for BatchNorm2d. Defaults to 0.01. - depth_multiplier: int, optional - Depth multiplier for the depthwise convolution. Defaults to 2 as in - [Santamaria2020]_. - pooling_sizes: list(int), optional - Pooling sizes for the inception blocks. Defaults to 4, 2, 2 and 2, as - in [Santamaria2020]_. + input_size_ms : int + Size of the input, in milliseconds. Set to 1000 in [Santamaria2020]_. + sfreq : float + EEG sampling frequency. + drop_prob : float + Dropout rate inside all the network. + scales_time: list(int) + Windows for inception block, must be a list with proportional values of + the input_size_ms. + According to the authors: temporal scale (ms) of the convolutions + on each Inception module. + This parameter determines the kernel sizes of the filters. + n_filters : int + Initial number of convolutional filters. Set to 8 in [Santamaria2020]_. + activation: nn.Module + Activation function, default: ELU activation. + batch_norm_alpha: float + Momentum for BatchNorm2d. + depth_multiplier: int + Depth multiplier for the depthwise convolution. + pooling_sizes: list(int) + Pooling sizes for the inception block. References ---------- @@ -117,7 +112,7 @@ def __init__( warn( "The class EEGInception is deprecated and will be removed in the " - "next release of braindecode. Please use " + "release 0.9 of braindecode. Please use " "braindecode.models.EEGInceptionERP instead in the future.", DeprecationWarning ) From d9ba02738490690cab8ea191acc601e06079dc4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Tue, 8 Nov 2022 16:35:22 +0100 Subject: [PATCH 12/13] Update api.rst and whats_new.rst --- docs/api.rst | 2 ++ docs/whats_new.rst | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/docs/api.rst b/docs/api.rst index 7e6883f5f..a71163a93 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -56,7 +56,9 @@ Models ShallowFBCSPNet Deep4Net + EEGInception EEGInceptionERP + EEGInceptionMI EEGITNet EEGNetv1 EEGNetv4 diff --git a/docs/whats_new.rst b/docs/whats_new.rst index 45b3c8e44..c00254aef 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -29,6 +29,22 @@ Bugs API changes ~~~~~~~~~~~ +.. _changes_0_8_0: + +Current 0.8 () +---------------------- + +Enhancements +~~~~~~~~~~~~ +- Adding :class:`braindecode.models.EEGInceptionMI` network for motor imagery (:gh:`428` by `Cedric Rommel`_) + +Bugs +~~~~ + + +API changes +~~~~~~~~~~~ +- Renaming the :class:`braindecode.models.EEGInception` network as :class:`braindecode.models.EEGInceptionERP` (:gh:`428` by `Cedric Rommel`_) .. _changes_0_7_0: From d1c03d6d451920bccab9ed2a23d5edee32527027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Rommel?= Date: Wed, 9 Nov 2022 11:43:28 +0100 Subject: [PATCH 13/13] Rename _ResidualModuleMI --- braindecode/models/eeginception_mi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/braindecode/models/eeginception_mi.py b/braindecode/models/eeginception_mi.py index 59ab2867d..d34e17c3c 100644 --- a/braindecode/models/eeginception_mi.py +++ b/braindecode/models/eeginception_mi.py @@ -107,7 +107,7 @@ def __init__( ) for _ in range(2) ]) - self.residual_block_1 = _ResidualBlockMI( + self.residual_block_1 = _ResidualModuleMI( in_channels=self.in_channels, n_filters=intermediate_in_channels, activation=self.activation, @@ -124,7 +124,7 @@ def __init__( ) for _ in range(3) ]) - self.residual_block_2 = _ResidualBlockMI( + self.residual_block_2 = _ResidualModuleMI( in_channels=intermediate_in_channels, n_filters=intermediate_in_channels, activation=self.activation, @@ -256,7 +256,7 @@ def forward( return self.activation(out) -class _ResidualBlockMI(nn.Module): +class _ResidualModuleMI(nn.Module): def __init__( self, in_channels,