Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log accuracy score in EEGClassifier #541

Merged
merged 11 commits into from Sep 27, 2023
48 changes: 38 additions & 10 deletions braindecode/classifier.py
Expand Up @@ -11,6 +11,8 @@
import numpy as np
from skorch import NeuralNet
from skorch.classifier import NeuralNetClassifier
from skorch.callbacks import EpochScoring
from torch.nn import CrossEntropyLoss

from .eegneuralnet import _EEGNeuralNet
from .training.scoring import predict_trials
Expand Down Expand Up @@ -60,19 +62,30 @@ class name. Name conflicts are resolved by appending a count suffix
""" # noqa: E501
__doc__ = update_estimator_docstring(NeuralNetClassifier, doc)

def __init__(self, module, *args, cropped=False, callbacks=None,
iterator_train__shuffle=True,
iterator_train__drop_last=True,
aggregate_predictions=True, **kwargs):
def __init__(
self,
module,
*args,
criterion=CrossEntropyLoss,
cropped=False,
callbacks=None,
iterator_train__shuffle=True,
iterator_train__drop_last=True,
aggregate_predictions=True,
**kwargs
):
self.cropped = cropped
self.aggregate_predictions = aggregate_predictions
self._last_window_inds_ = None
super().__init__(module,
*args,
callbacks=callbacks,
iterator_train__shuffle=iterator_train__shuffle,
iterator_train__drop_last=iterator_train__drop_last,
**kwargs)
super().__init__(
module,
*args,
criterion=criterion,
callbacks=callbacks,
iterator_train__shuffle=iterator_train__shuffle,
iterator_train__drop_last=iterator_train__drop_last,
**kwargs,
)

def get_iterator(self, dataset, training=False, drop_index=True):
iterator = super().get_iterator(dataset, training=training)
Expand Down Expand Up @@ -231,3 +244,18 @@ def _get_n_outputs(self, y, classes):
else:
classes = classes_y
return len(classes)

# Only add the 'accuracy' callback if we are not in cropped mode.
@property
def _default_callbacks(self):
callbacks = list(super()._default_callbacks)
if not self.cropped:
callbacks.append((
'valid_acc',
EpochScoring(
'accuracy',
name='valid_acc',
lower_is_better=False,
)
))
return callbacks
3 changes: 1 addition & 2 deletions braindecode/eegneuralnet.py
Expand Up @@ -139,8 +139,7 @@ def predict_with_window_inds_and_ys(self, dataset):
preds=preds, i_window_in_trials=i_window_in_trials,
i_window_stops=i_window_stops, window_ys=window_ys)

# Removes default EpochScoring callback computing 'accuracy' to work
# properly with cropped decoding.
# Changes the default target extractor to noop
@property
def _default_callbacks(self):
return [
Expand Down
1 change: 1 addition & 0 deletions docs/whats_new.rst
Expand Up @@ -53,6 +53,7 @@ Enhancements
- Add support for :class:`mne.Epochs` in :class:`braindecode.EEGClassifier` and :class:`braindecode.EEGRegressor` (:gh:`529` by `Pierre Guetschel`_)
- Allow passing only the name of a braindecode model to :class:`braindecode.EEGClassifier` and :class:`braindecode.EEGRegressor` (:gh:`528` by `Pierre Guetschel`_)
- Add basic training example with MNE epochs (:gh:`539` by `Pierre Guetschel`_)
- Log validation accuracy in :class:`braindecode.EEGClassifier` (:gh:`541` by `Pierre Guetschel`_)

Bugs
~~~~
Expand Down
6 changes: 1 addition & 5 deletions examples/advanced_training/plot_relative_positioning.py
Expand Up @@ -314,14 +314,11 @@ def forward(self, x):
early_stopping = EarlyStopping(patience=10)
train_acc = EpochScoring(
scoring='accuracy', on_train=True, name='train_acc', lower_is_better=False)
valid_acc = EpochScoring(
scoring='accuracy', on_train=False, name='valid_acc',
lower_is_better=False)

callbacks = [
('cp', cp),
('patience', early_stopping),
('train_acc', train_acc),
('valid_acc', valid_acc)
]

clf = EEGClassifier(
Expand Down Expand Up @@ -372,7 +369,6 @@ def forward(self, x):
styles = ['-', ':']
markers = ['.', '.']


fig, ax1 = plt.subplots(figsize=(8, 3))
ax2 = ax1.twinx()
for y1, y2, style, marker in zip(ys1, ys2, styles, markers):
Expand Down
3 changes: 3 additions & 0 deletions examples/applied_examples/plot_sleep_staging_usleep.py
Expand Up @@ -249,6 +249,9 @@ def balanced_accuracy_multi(model, X, y):
device=device,
classes=classes,
)
# Deactivate the default valid_acc callback:
clf.set_params(callbacks__valid_acc=None)

# Model training for a specified number of epochs. `y` is None as it is already
# supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)
Expand Down
4 changes: 3 additions & 1 deletion test/unit_tests/test_eegneuralnet.py
Expand Up @@ -9,6 +9,7 @@
import numpy as np
import torch
import mne
from scipy.special import softmax
from sklearn.base import clone
from skorch.callbacks import LRScheduler
from skorch.utils import to_tensor
Expand Down Expand Up @@ -162,8 +163,9 @@ def test_trialwise_predict_and_predict_proba(eegneuralnet_cls):
)
eegneuralnet.initialize()
target_predict = preds if isinstance(eegneuralnet, EEGRegressor) else preds.argmax(1)
preds = preds if isinstance(eegneuralnet, EEGRegressor) else softmax(preds, axis=1)
np.testing.assert_array_equal(target_predict, eegneuralnet.predict(MockDataset()))
np.testing.assert_array_equal(preds, eegneuralnet.predict_proba(MockDataset()))
np.testing.assert_allclose(preds, eegneuralnet.predict_proba(MockDataset()))


def test_cropped_predict_and_predict_proba(eegneuralnet_cls, preds):
Expand Down