Skip to content

Commit

Permalink
DOC Using pure MNE Epochs (#539)
Browse files Browse the repository at this point in the history
* Fix set signal params when module is str

* Add absic usage example

* Update whats_new.rst

* Change title

* rename file

* FIX moving and returning the order

* Apply suggestions Bruno

Co-authored-by: Bru <a.bruno@aluno.ufabc.edu.br>

* Fix a few typos

---------

Co-authored-by: Bru <a.bruno@aluno.ufabc.edu.br>
  • Loading branch information
PierreGtch and bruAristimunha committed Sep 15, 2023
1 parent a4c896d commit 64b8e38
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 17 deletions.
20 changes: 13 additions & 7 deletions braindecode/eegneuralnet.py
Expand Up @@ -24,6 +24,16 @@
log = logging.getLogger(__name__)


def _get_model(model):
''' Returns the corresponding class in case the model passed is a string. '''
if isinstance(model, str):
if model in models_dict:
model = models_dict[model]
else:
raise ValueError(f'Unknown model name {model!r}.')
return model


class _EEGNeuralNet(NeuralNet, abc.ABC):
signal_args_set_ = False

Expand All @@ -41,12 +51,7 @@ def initialize_module(self):
"""
kwargs = self.get_params_for('module')
module = self.module
if isinstance(module, str):
if module in models_dict:
module = models_dict[module]
else:
raise ValueError(f'Unknown model name {module!r}.')
module = _get_model(self.module)
module = self.initialized_instance(module, kwargs)
# pylint: disable=attribute-defined-outside-init
self.module_ = module
Expand Down Expand Up @@ -215,7 +220,8 @@ def _set_signal_args(self, X, y, classes):

# kick out missing kwargs:
module_kwargs = dict()
all_module_kwargs = inspect.signature(self.module.__init__).parameters.keys()
module = _get_model(self.module)
all_module_kwargs = inspect.signature(module.__init__).parameters.keys()
for k, v in signal_kwargs.items():
if v is None:
continue
Expand Down
18 changes: 8 additions & 10 deletions docs/conf.py
Expand Up @@ -23,6 +23,7 @@
import os.path as op

import matplotlib

matplotlib.use('agg')
from datetime import datetime, timezone
import faulthandler
Expand Down Expand Up @@ -129,6 +130,7 @@ def linkcode_resolve(domain, info):

return f"{repo}/blob/master/braindecode/{fn}{linespec}"


# -- Options for sphinx gallery --------------------------------------------
faulthandler.enable()
os.environ['_BRAINDECODE_BROWSER_NO_BLOCK'] = 'true'
Expand Down Expand Up @@ -188,6 +190,7 @@ def linkcode_resolve(domain, info):
#
# The short X.Y version.
import braindecode

release = braindecode.__version__
# The full version, including alpha/beta/rc tags.
version = '.'.join(release.split('.')[:2])
Expand All @@ -210,7 +213,6 @@ def linkcode_resolve(domain, info):
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True


# Sphinx-gallery configuration

# Example configuration for intersphinx: refer to the Python standard library.
Expand Down Expand Up @@ -250,6 +252,7 @@ def linkcode_resolve(domain, info):
# a list of builtin themes.
#
import sphinx_rtd_theme # noqa

html_theme = "pydata_sphinx_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
switcher_version_match = 'dev' if release.endswith('dev0') else version
Expand All @@ -275,16 +278,16 @@ def linkcode_resolve(domain, info):
'show_toc_level': 1,
'navbar_end': ['theme-switcher', 'version-switcher'],
'switcher': {
'json_url': 'https://braindecode.org/stable/_static/versions.json',
'version_match': switcher_version_match,
'json_url': 'https://braindecode.org/stable/_static/versions.json',
'version_match': switcher_version_match,
},
"logo": {
"image_light": "_static/braindecode_symbol.png",
"image_dark": "_static/braindecode_symbol.png",
"alt_text": "Braindecode Logo",
},
'footer_start': ['copyright'],
#'pygment_light_style': 'default',
# 'pygment_light_style': 'default',
'analytics': dict(google_analytics_id='G-7Q43R82K6D'),
}

Expand Down Expand Up @@ -345,7 +348,6 @@ def linkcode_resolve(domain, info):

}


# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
Expand Down Expand Up @@ -377,8 +379,6 @@ def linkcode_resolve(domain, info):
'Robin Tibor Schirrmeister', 'manual'),
]



# -- Fontawesome support -----------------------------------------------------

# here the "fab" and "fas" refer to "brand" and "solid" (determines which font
Expand All @@ -398,7 +398,7 @@ def linkcode_resolve(domain, info):
icons = dict()
for icon in brand_icons + fixed_icons + other_icons:
font = ('fab' if icon in brand_icons else 'fas',) # brand or solid font
fw = ('fa-fw',) if icon in fixed_icons else () # fixed-width
fw = ('fa-fw',) if icon in fixed_icons else () # fixed-width
icons[icon] = font + fw

prolog = ''
Expand All @@ -422,7 +422,6 @@ def linkcode_resolve(domain, info):
.. |ensp| unicode:: U+2002 .. EN SPACE
'''


# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
Expand All @@ -432,7 +431,6 @@ def linkcode_resolve(domain, info):
[author], 1)
]


# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
Expand Down
1 change: 1 addition & 0 deletions docs/whats_new.rst
Expand Up @@ -51,6 +51,7 @@ Enhancements
- Add ``models_dict`` to :mod:`braindecode.models.util` (:gh:`524` by `Pierre Guetschel`_)
- Add support for :class:`mne.Epochs` in :class:`braindecode.EEGClassifier` and :class:`braindecode.EEGRegressor` (:gh:`529` by `Pierre Guetschel`_)
- Allow passing only the name of a braindecode model to :class:`braindecode.EEGClassifier` and :class:`braindecode.EEGRegressor` (:gh:`528` by `Pierre Guetschel`_)
- Add basic training example with MNE epochs (:gh:`539` by `Pierre Guetschel`_)

Bugs
~~~~
Expand Down
177 changes: 177 additions & 0 deletions examples/model_building/plot_basic_training_epochs.py
@@ -0,0 +1,177 @@
"""
Simple training on MNE epochs
=============================
The braindecode library gives you access to a large number of neural network
architectures that were developed for EEG data decoding. This tutorial will
show you how you can easily use any of these models to decode your own data.
In particular, we assume that have your data in an MNE format and want to
train one of the Braindecode models on it.
.. contents:: This example covers:
:local:
:depth: 2
"""

# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)

######################################################################
# Finding the model you want
# --------------------------
#
# Exploring the braindecode online documentation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Let's suppose you recently stumbled upon the Schirrmeister 2017 article [1]_.
# In this article, the authors mention that their novel architecture ShallowConvNet
# is performing well on the BCI Competition IV 2a dataset and you would like to use
# it on your own data. Fortunately, the authors also mentioned they published their
# architecture on Braindecode!
#
# In order to use this architecture, you first need to find what is its exact
# name in Braindecode. To do so, you can visit the Braindecode online documentation
# which lists all the available models.
#
# Models list: https://braindecode.org/stable/api.html#models
#
# Alternatively, the API also provide a dictionary with all available models:

from braindecode.models.util import models_dict

print(f'All the Braindecode models:\n{list(models_dict.keys())}')

######################################################################
# After your investigation, you found out that the model you are looking for is
# ``ShallowFBCSPNet``. You can now import it from Braindecode:

from braindecode.models import ShallowFBCSPNet

######################################################################
# Examining the model
# ~~~~~~~~~~~~~~~~~~~
#
# Now that you found your model, you must check which parameters it expects.
# You can find this information either in the online documentation here:
# :class:`braindecode.models.ShallowFBCSPNet` or directly in the module's docstring:

print(ShallowFBCSPNet.__doc__)

######################################################################
# Additionally, you might be interested in visualizing the model's architecture.
# This can be done by initializing the model and calling its ``__str__()`` method.
# To initialize it, we need to specify some parameters that we set at random
# values for now:

model = ShallowFBCSPNet(
n_chans=32,
n_times=1000,
n_outputs=2,
final_conv_length='auto',
)
print(model)

######################################################################
# Loading your own data with MNE
# ------------------------------
#
# In this tutorial, we demonstrate how to train the model on MNE data.
# MNE is quite a popular library for EEG data analysis as it provides methods
# to load data from many different file formats and a large collection of algorithms
# to preprocess it.
# However, Braindecode is not limited to MNE and can be used with numpy arrays or
# PyTorch tensors/datasets.
#
# For this example, we generate some random data containing 100 examples with each
# 3 channels and 1024 time points. We also generate some random labels for our data
# that simulate a 4-class classification problem.

import mne
import numpy as np

info = mne.create_info(ch_names=['C3', 'C4', 'Cz'], sfreq=256., ch_types='eeg')
X = np.random.randn(100, 3, 1024) # 100 epochs, 3 channels, 4 seconds (@256Hz)
epochs = mne.EpochsArray(X, info=info)
y = np.random.randint(0, 4, size=100) # 4 classes
print(epochs)

######################################################################
# Training your model (scikit-learn compatible)
# ---------------------------------------------
#
# Now that you know which model you want to use, you know how to instantiate it,
# and that we have some fake data, it is time to train the model!
#
# .. note::
# `Skorch <https://skorch.readthedocs.io>`_ is a library that allows you to wrapp
# any PyTorch module into a scikit-learn-compatible classifier or regressor.
# Braindecode provides wrappers that inherit form the original Skorch ones and simply
# implement a few additional features that facilitate the use of Braindecode models.
#
# To train a Braindecode model, the easiest way is by using braindecode's
# Skorch wrappers. These wrappers are :class:`braindecode.EEGClassifier` and
# :class:`braindecode.EEGRegressor`. As our fake data is a classification task,
# we will use the former.
#
# The wrapper :class:`braindecode.EEGClassifier` expects a model class as its first argument but
# to facilitate the usage, you can also simply pass the name of any braindecode model as a string.
# The wrapper automatically finds and instantiates the model for you.
# If you want to pass parameters to your model, you can give them to the wrapper
# with the prefix ``module__``.
#
from skorch.dataset import ValidSplit
from braindecode import EEGClassifier

net = EEGClassifier(
'ShallowFBCSPNet',
module__final_conv_length='auto',
train_split=ValidSplit(0.2),
# To train a neural network you need validation split, here, we use 20%.
)

######################################################################
# In this example, we passed one additional parameter to the wrapper: ``module__final_conv_length``
# that will be forwarded to the model (without the prefix ``module__``).
#
# We also note that the parameters ``n_chans``, ``n_times`` and ``n_outputs`` were not specified
# even if :class:`braindecode.ShallowFBCSPNet` needs them to be initialized. This is because the
# wrapper will automatically infer them, along with some other signal-related parameters,
# from the input data at training time.
#
# Now that we have our model wrapped in a scikit-learn-compatible classifier,
# we can train it by simply calling the ``fit`` method:

net.fit(epochs, y)

######################################################################
# The pre-trained model is accessible via the ``module_`` attribute:

print(net.module_)

######################################################################
# And we can see that all the following parameters were automatically inferred
# from the training data:

print(f'{net.module_.n_chans=}\n{net.module_.n_times=}\n{net.module_.n_outputs=}'
f'\n{net.module_.input_window_seconds=}\n{net.module_.sfreq=}\n{net.module_.chs_info=}')

######################################################################
# Depending on the type of data used for training, some parameters might not be
# possible to infer. For example if you pass a numpy array or a
# :class:`braindecode.dataset.WindowsDataset` with ``target_from="metadata"``,
# then only ``n_chans``, ``n_times`` and ``n_outputs`` will be inferred.
# And if you pass other types of datasets, only ``n_chans`` and ``n_times`` will be inferred.
# In these case, you will have to pass the missing parameters manually
# (with the prefix ``module__``).

######################################################################
# References
# ----------
#
# .. [1] Schirrmeister, R.T., Springenberg, J.T., Fiederer, L.D.J., Glasstetter,
# M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T.(2017).
# Deep learning with convolutional neural networks for EEG decoding and visualization.
# Human Brain Mapping, Aug. 2017.
# Online: http://dx.doi.org/10.1002/hbm.23730

0 comments on commit 64b8e38

Please sign in to comment.