
ENH Normalizing the last layer in all the models #520

Merged
merged 53 commits into from Sep 20, 2023
Changes from 38 commits
Commits
53 commits
6a6f9a8
Renaming last layer to final_layer
brunaafl Sep 8, 2023
cf63701
Renaming last layer to final_layer
brunaafl Sep 8, 2023
9ebe8e2
Renaming last layer to final_layer
brunaafl Sep 8, 2023
3665f10
Renaming last layer to final_layer
brunaafl Sep 8, 2023
a40f29e
Renaming last layer to final_layer
brunaafl Sep 8, 2023
002338f
Renaming last layer to final_layer
brunaafl Sep 8, 2023
c22e4e7
Commented tests involving return features (disabled for now)
brunaafl Sep 8, 2023
7ca5f94
Commented tests involving return features (disabled for now)
brunaafl Sep 8, 2023
0beaa80
Fixing examples
brunaafl Sep 8, 2023
9288a8d
Fixing examples
brunaafl Sep 8, 2023
9553342
Merge branch 'master' into Normalizing-layer-names
PierreGtch Sep 9, 2023
a77b65e
Updating whats_new file
brunaafl Sep 9, 2023
aa5bdd4
Changing last layer to final_layer on Sequential + load_state_dict on…
brunaafl Sep 12, 2023
c1deaf6
Changing last layer to final_layer on Sequential + load_state_dict on…
brunaafl Sep 12, 2023
f913a79
Changing last layer to final_layer on Sequential + load_state_dict on…
brunaafl Sep 12, 2023
14e90ba
Apply suggestions from code review
brunaafl Sep 13, 2023
9b87ad9
Changing last layer to final_layer on Modules
brunaafl Sep 13, 2023
a99c7ec
implementing load_state_dict on EEGMixin
brunaafl Sep 14, 2023
cde2ac5
Merge branch 'master' into Normalizing-layer-names
brunaafl Sep 14, 2023
9a8f8b1
implementing load_state_dict on EEGMixin
brunaafl Sep 14, 2023
7605b98
fixing flake
brunaafl Sep 14, 2023
d878f1b
fixing issue
brunaafl Sep 14, 2023
219d037
fixing issue
brunaafl Sep 14, 2023
fef6e44
fixing issue
brunaafl Sep 14, 2023
132435d
Update braindecode/models/eegconformer.py
bruAristimunha Sep 14, 2023
f5ea230
Update braindecode/models/eegconformer.py
bruAristimunha Sep 14, 2023
22691de
Update braindecode/models/eeginception.py
bruAristimunha Sep 14, 2023
ff874c4
Update braindecode/models/eegnet.py
bruAristimunha Sep 14, 2023
86b7587
Update braindecode/models/eegconformer.py
bruAristimunha Sep 14, 2023
819ff36
Update braindecode/models/eegnet.py
bruAristimunha Sep 14, 2023
2d5a5c7
Update braindecode/models/deepsleepnet.py
brunaafl Sep 18, 2023
07c1f41
Update braindecode/models/base.py
brunaafl Sep 18, 2023
a3161d1
Update braindecode/models/eegconformer.py
brunaafl Sep 18, 2023
e49cc51
Update braindecode/models/eegconformer.py
brunaafl Sep 18, 2023
3d04155
Update braindecode/models/eegconformer.py
brunaafl Sep 18, 2023
2b5cfb3
Update braindecode/models/eegconformer.py
brunaafl Sep 18, 2023
1ce7c70
removing a comment hybrid.py
brunaafl Sep 18, 2023
6e18a67
Merge remote-tracking branch 'origin/Normalizing-layer-names' into No…
brunaafl Sep 18, 2023
f302017
Update braindecode/models/eegconformer.py
brunaafl Sep 19, 2023
f429daf
Update braindecode/models/eegconformer.py
brunaafl Sep 19, 2023
7846f89
Update braindecode/models/eegconformer.py
brunaafl Sep 19, 2023
463b7eb
Update braindecode/models/eegconformer.py
brunaafl Sep 19, 2023
557e4ad
Update braindecode/models/eegconformer.py
brunaafl Sep 19, 2023
96d38ab
eegconformer update following suggestions
brunaafl Sep 19, 2023
20edb41
Update braindecode/models/hybrid.py
brunaafl Sep 19, 2023
f07079a
Update braindecode/models/tcn.py
brunaafl Sep 19, 2023
a7bc0af
Update braindecode/models/shallow_fbcsp.py
brunaafl Sep 19, 2023
2341366
Update braindecode/models/eegitnet.py
brunaafl Sep 19, 2023
911f112
Merge branch 'master' into Normalizing-layer-names
brunaafl Sep 19, 2023
426dc75
fixing flake
brunaafl Sep 19, 2023
c77fe04
Merge remote-tracking branch 'origin/Normalizing-layer-names' into No…
brunaafl Sep 19, 2023
04a1311
fixing test
brunaafl Sep 19, 2023
41d8ee0
Merge branch 'master' into Normalizing-layer-names
bruAristimunha Sep 19, 2023
14 changes: 10 additions & 4 deletions braindecode/models/atcnet.py
@@ -165,6 +165,12 @@ def __init__(
self.concat = concat
self.max_norm_const = max_norm_const

map = dict()
for w in range(self.n_windows):
map[f'max_norm_linears.[{w}].weight'] = f'final_layer.[{w}].weight'
map[f'max_norm_linears.[{w}].bias'] = f'final_layer.[{w}].bias'
self.mapping = map

# Check later if we want to keep the Ensure4d. Not sure if we can
# remove it or replace it with einsum.
self.ensuredims = Ensure4d()
@@ -209,15 +215,15 @@ def __init__(
])

if self.concat:
self.max_norm_linears = nn.ModuleList([
self.final_layer = nn.ModuleList([
MaxNormLinear(
in_features=self.F2 * self.n_windows,
out_features=self.n_outputs,
max_norm_val=self.max_norm_const
)
])
else:
self.max_norm_linears = nn.ModuleList([
self.final_layer = nn.ModuleList([
MaxNormLinear(
in_features=self.F2,
out_features=self.n_outputs,
@@ -261,14 +267,14 @@ def forward(self, X):
# mapped by dense layer or concatenated then mapped by a dense
# layer
if not self.concat:
tcn_feat = self.max_norm_linears[w](tcn_feat)
tcn_feat = self.final_layer[w](tcn_feat)

sw_concat.append(tcn_feat)

# ----- Aggregation and prediction -----
if self.concat:
sw_concat = torch.cat(sw_concat, dim=1)
sw_concat = self.max_norm_linears[0](sw_concat)
sw_concat = self.final_layer[0](sw_concat)
else:
if len(sw_concat) > 1: # more than one window
sw_concat = torch.stack(sw_concat, dim=0)
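
Since ATCNet keeps its classification layer in an nn.ModuleList with one entry per sliding window, the mapping above is built in a loop rather than written out by hand. A minimal, standalone sketch of that loop (n_windows = 3 is an illustrative value, not the model's default):

# Reproduces the per-window key mapping assembled in ATCNet.__init__ above.
n_windows = 3  # illustrative only; the real model uses its constructor argument
mapping = {}
for w in range(n_windows):
    mapping[f'max_norm_linears.[{w}].weight'] = f'final_layer.[{w}].weight'
    mapping[f'max_norm_linears.[{w}].bias'] = f'final_layer.[{w}].bias'

print(mapping['max_norm_linears.[0].weight'])  # -> 'final_layer.[0].weight'
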
20 changes: 18 additions & 2 deletions braindecode/models/base.py
@@ -6,6 +6,8 @@
import warnings
from typing import Dict, Iterable, List, Optional, Tuple

from collections import OrderedDict

import numpy as np
import torch
from docstring_inheritance import NumpyDocstringInheritanceInitMeta
@@ -202,8 +204,8 @@ def get_output_shape(self) -> Tuple[int]:
)).shape)
except RuntimeError as exc:
if str(exc).endswith(
("Output size is too small",
"Kernel size can't be greater than actual input size")
("Output size is too small",
"Kernel size can't be greater than actual input size")
):
msg = (
"During model prediction RuntimeError was thrown showing that at some "
@@ -215,6 +217,20 @@ def get_output_shape(self) -> Tuple[int]:
raise ValueError(msg) from exc
raise exc

mapping = None

def load_state_dict(self, state_dict, *args, **kwargs):

mapping = self.mapping if self.mapping else {}
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k in mapping:
new_state_dict[mapping[k]] = v
else:
new_state_dict[k] = v

return super().load_state_dict(new_state_dict, *args, **kwargs)

def to_dense_prediction_model(self, axis: Tuple[int] = (2, 3)) -> None:
"""
Transform a sequential model with strides to a model that outputs
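
This override is the backbone of the PR: each model stores a mapping from its pre-rename parameter names to the new final_layer.* names, and load_state_dict rewrites old checkpoint keys before delegating to torch.nn.Module, so checkpoints saved before the renaming still load. A minimal, self-contained sketch of the same idea with a toy module (TinyModel and its key names are hypothetical, not part of braindecode):

from collections import OrderedDict

import torch
from torch import nn


class TinyModel(nn.Module):
    # Old checkpoints used "classifier.*"; the renamed submodule is "final_layer.*".
    mapping = {"classifier.weight": "final_layer.weight",
               "classifier.bias": "final_layer.bias"}

    def __init__(self):
        super().__init__()
        self.final_layer = nn.Linear(4, 2)

    # Same mechanism as the override added to the braindecode base class above.
    def load_state_dict(self, state_dict, *args, **kwargs):
        mapping = self.mapping if self.mapping else {}
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[mapping.get(k, k)] = v
        return super().load_state_dict(new_state_dict, *args, **kwargs)


# A "checkpoint" with the old key names still loads after the remapping.
old_checkpoint = {"classifier.weight": torch.zeros(2, 4),
                  "classifier.bias": torch.zeros(2)}
TinyModel().load_state_dict(old_checkpoint)
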
59 changes: 29 additions & 30 deletions braindecode/models/deep4.py
@@ -2,8 +2,6 @@
#
# License: BSD (3-clause)

from collections import OrderedDict

from einops.layers.torch import Rearrange
from torch import nn
from torch.nn import init
@@ -167,6 +165,18 @@ def __init__(
self.batch_norm_alpha = batch_norm_alpha
self.stride_before_pool = stride_before_pool

# For load_state_dict:
# when standardizing all layers,
# add the old parameter names here
self.mapping = {
"conv_time.weight": "conv_time_spat.conv_time.weight",
"conv_spat.weight": "conv_time_spat.conv_spat.weight",
"conv_time.bias": "conv_time_spat.conv_time.bias",
"conv_spat.bias": "conv_time_spat.conv_spat.bias",
"conv_classifier.weight": "final_layer.conv_classifier.weight",
"conv_classifier.bias": "final_layer.conv_classifier.bias"
}

if self.stride_before_pool:
conv_stride = self.pool_time_stride
pool_stride = 1
@@ -272,18 +282,23 @@ def add_conv_pool_block(
self.eval()
if self.final_conv_length == "auto":
self.final_conv_length = self.get_output_shape()[2]
self.add_module(
"conv_classifier",
nn.Conv2d(
self.n_filters_4,
self.n_outputs,
(self.final_conv_length, 1),
bias=True,
),
)

# Incorporating classification module and subsequent ones in one final layer
module = nn.Sequential()

module.add_module("conv_classifier",
nn.Conv2d(
self.n_filters_4,
self.n_outputs,
(self.final_conv_length, 1),
bias=True, ))

if self.add_log_softmax:
self.add_module("logsoftmax", nn.LogSoftmax(dim=1))
self.add_module("squeeze", Expression(squeeze_final_output))

module.add_module("squeeze", Expression(squeeze_final_output))

self.add_module("final_layer", module)

# Initialization, xavier is same as in our paper...
# was default from lasagne
@@ -311,24 +326,8 @@ def add_conv_pool_block(
init.constant_(bnorm_weight, 1)
init.constant_(bnorm_bias, 0)

init.xavier_uniform_(self.conv_classifier.weight, gain=1)
init.constant_(self.conv_classifier.bias, 0)
init.xavier_uniform_(self.final_layer.conv_classifier.weight, gain=1)
init.constant_(self.final_layer.conv_classifier.bias, 0)

# Start in eval mode
self.eval()

def load_state_dict(self, state_dict, *args, **kwargs):
"""Wrapper to allow for loading of a state_dict from a model before CombinedConv was
implemented"""
keys_to_change = [
"conv_time.weight",
"conv_spat.weight",
"conv_time.bias",
"conv_spat.bias",
]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k in keys_to_change:
k = f"conv_time_spat.{k}"
new_state_dict[k] = v
return super().load_state_dict(new_state_dict, *args, **kwargs)
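
With the classifier wrapped in a Sequential registered as final_layer, the parameters previously reached as model.conv_classifier.* now live one level deeper. A short, hedged sketch of the new access path (constructor arguments are illustrative and assume the post-PR braindecode keyword API):

from braindecode.models import Deep4Net

# Illustrative configuration; any valid Deep4Net setup behaves the same way.
model = Deep4Net(n_chans=22, n_outputs=4, n_times=1125, final_conv_length="auto")

# Before this PR the classifier was model.conv_classifier; it now sits inside
# the "final_layer" Sequential built above, and old checkpoint keys are
# translated by the mapping defined in __init__.
print(model.final_layer.conv_classifier.weight.shape)
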
4 changes: 4 additions & 0 deletions braindecode/models/deepsleepnet.py
@@ -222,8 +222,12 @@ def __init__(
self.features_extractor = nn.Identity()
self.len_last_layer = 1024
self.return_feats = return_feats

# TODO: Add new way to handle return_features == True
if not return_feats:
self.final_layer = nn.Linear(1024, self.n_outputs)
else:
self.final_layer = nn.Identity()

def forward(self, x):
"""Forward pass.
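
DeepSleepNet's final layer is now either a linear classifier or, when return_feats is requested, an identity that passes the 1024-dimensional feature vector through unchanged. A minimal sketch of the two branches (the class count of 5 is an assumed example, not the model's fixed value):

import torch
from torch import nn

return_feats = False  # illustrative flag
final_layer = nn.Identity() if return_feats else nn.Linear(1024, 5)

feats = torch.randn(8, 1024)       # dummy pooled features
print(final_layer(feats).shape)    # (8, 5) here; (8, 1024) with return_feats=True
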
60 changes: 47 additions & 13 deletions braindecode/models/eegconformer.py
@@ -123,6 +123,11 @@ def __init__(
sfreq=sfreq,
add_log_softmax=add_log_softmax,
)
self.mapping = {
'classification_head.fc.6.weight': 'final_layer.final_layer.0.weight',
'classification_head.fc.6.bias': 'final_layer.final_layer.0.bias'
}

del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
del n_classes, n_channels, input_window_samples
if not (self.n_chans <= 64):
@@ -148,16 +153,19 @@ def __init__(
att_heads=att_heads,
att_drop=att_drop_prob)

self.classification_head = _ClassificationHead(
final_fc_length=final_fc_length,
n_classes=self.n_outputs, return_features=return_features,
add_log_softmax=self.add_log_softmax)
self.fc = _FullyConnected(
final_fc_length=final_fc_length)

self.final_layer = _FinalLayer(n_classes=self.n_outputs,
return_features=return_features,
add_log_softmax=self.add_log_softmax)

def forward(self, x: Tensor) -> Tensor:
x = torch.unsqueeze(x, dim=1) # add one extra dimension
x = self.patch_embedding(x)
x = self.transformer(x)
x = self.classification_head(x)
x = self.fc(x)
x = self.final_layer(x)
return x

def get_fc_size(self):
@@ -345,12 +353,11 @@ def __init__(self, att_depth, emb_size, att_heads, att_drop):
)


class _ClassificationHead(nn.Module):
def __init__(self, final_fc_length, n_classes,
class _FullyConnected(nn.Module):
def __init__(self, final_fc_length,
drop_prob_1=0.5, drop_prob_2=0.3, out_channels=256,
hidden_channels=32, return_features=False,
add_log_softmax=True):
""""Classification head for the transformer encoder.
hidden_channels=32):
"”””Fully-connected layer for the transformer encoder.

Parameters
----------
@@ -380,6 +387,32 @@ def __init__(self, final_fc_length, n_classes,
nn.Linear(out_channels, hidden_channels),
nn.ELU(),
nn.Dropout(drop_prob_2),
)

def forward(self, x):
x = x.contiguous().view(x.size(0), -1)
out = self.fc(x)
return out


class _FinalLayer(nn.Module):
def __init__(self, n_classes, hidden_channels=32, return_features=False, add_log_softmax=True):
""""Classification head for the transformer encoder.

Parameters
----------
n_classes : int
Number of classes for classification.
hidden_channels : int
Number of output channels for the second linear layer.
return_features : bool
Whether to return input features.
add_log_softmax : bool
Adding LogSoftmax or not.
"""

super().__init__()
self.final_layer = nn.Sequential(
nn.Linear(hidden_channels, n_classes),
)
self.return_features = return_features
Expand All @@ -389,9 +422,10 @@ def __init__(self, final_fc_length, n_classes,
self.classification = nn.Identity()

def forward(self, x):
x = x.contiguous().view(x.size(0), -1)
out = self.fc(x)
if self.return_features:
out = self.final_layer(x)
return out, x
else:
return self.classification(out)
self.final_layer.add_module('classification', self.classification)
out = self.final_layer(x)
return out
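
The conformer refactor splits the old _ClassificationHead into two pieces: a fully-connected feature block (self.fc) and a separate final_layer that alone depends on the number of classes. A minimal, hedged reproduction of that split with plain torch modules (all sizes are illustrative, not the model's real defaults):

import torch
from torch import nn

# Stand-in for _FullyConnected: reduces flattened features to hidden_channels.
fc = nn.Sequential(
    nn.Linear(2440, 256),
    nn.ELU(),
    nn.Dropout(0.5),
    nn.Linear(256, 32),
    nn.ELU(),
    nn.Dropout(0.3),
)

# Stand-in for _FinalLayer: maps hidden features to class scores.
final_layer = nn.Sequential(
    nn.Linear(32, 4),
    nn.LogSoftmax(dim=1),
)

x = torch.randn(8, 2440)      # dummy flattened transformer output
out = final_layer(fc(x))
print(out.shape)              # torch.Size([8, 4])
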
29 changes: 21 additions & 8 deletions braindecode/models/eeginception.py
@@ -144,6 +144,10 @@ def __init__(
self.depth_multiplier = depth_multiplier
self.pooling_sizes = pooling_sizes

self.mapping = {
'classification.1.weight': 'final_layer.fc.weight',
'classification.1.bias': 'final_layer.fc.bias'}

self.add_module("ensuredims", Ensure4d())

self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))
@@ -246,14 +250,23 @@ def __init__(
self.n_times // prod(self.pooling_sizes))
n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4

self.add_module("classification", nn.Sequential(
nn.Flatten(),
nn.Linear(
spatial_dim_last_layer * n_channels_last_layer,
self.n_outputs
),
nn.LogSoftmax(dim=1) if self.add_log_softmax else nn.Identity(),
))
self.add_module("flat", nn.Flatten())

module = nn.Sequential()

module.add_module("fc",
nn.Linear(
spatial_dim_last_layer * n_channels_last_layer,
self.n_outputs
), )

if self.add_log_softmax:
module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
else:
module.add_module("identity", nn.Identity())

# The conv_classifier will be the final_layer and the other ones will be incorporated
self.add_module("final_layer", module)

_glorot_weight_zero_bias(self)

29 changes: 21 additions & 8 deletions braindecode/models/eeginception_erp.py
@@ -140,6 +140,10 @@ def __init__(
self.depth_multiplier = depth_multiplier
self.pooling_sizes = pooling_sizes

self.mapping = {
'classification.1.weight': 'final_layer.fc.weight',
'classification.1.bias': 'final_layer.fc.bias'}

self.add_module("ensuredims", Ensure4d())

self.add_module("dimshuffle", Rearrange("batch C T 1 -> batch 1 C T"))
@@ -242,14 +246,23 @@ def __init__(
self.n_times // prod(self.pooling_sizes))
n_channels_last_layer = self.n_filters * len(self.scales_samples) // 4

self.add_module("classification", nn.Sequential(
nn.Flatten(),
nn.Linear(
spatial_dim_last_layer * n_channels_last_layer,
self.n_outputs
),
nn.LogSoftmax(dim=1) if self.add_log_softmax else nn.Identity(),
))
self.add_module("flat", nn.Flatten())

# Incorporating classification module and subsequent ones in one final layer
module = nn.Sequential()

module.add_module("fc",
nn.Linear(
spatial_dim_last_layer * n_channels_last_layer,
self.n_outputs
), )

if self.add_log_softmax:
module.add_module("logsoftmax", nn.LogSoftmax(dim=1))
else:
module.add_module("identity", nn.Identity())

self.add_module("final_layer", module)

_glorot_weight_zero_bias(self)
