From 64b8e386d0bebc32d7e2e485fe0c7c13fd73218d Mon Sep 17 00:00:00 2001 From: PierreGtch <25532709+PierreGtch@users.noreply.github.com> Date: Fri, 15 Sep 2023 13:00:34 +0200 Subject: [PATCH] DOC Using pure MNE Epochs (#539) * Fix set signal params when module is str * Add absic usage example * Update whats_new.rst * Change title * rename file * FIX moving and returning the order * Apply suggestions Bruno Co-authored-by: Bru * Fix a few typos --------- Co-authored-by: Bru --- braindecode/eegneuralnet.py | 20 +- docs/conf.py | 18 +- docs/whats_new.rst | 1 + .../plot_basic_training_epochs.py | 177 ++++++++++++++++++ 4 files changed, 199 insertions(+), 17 deletions(-) create mode 100644 examples/model_building/plot_basic_training_epochs.py diff --git a/braindecode/eegneuralnet.py b/braindecode/eegneuralnet.py index f9129c463..72bf3c107 100644 --- a/braindecode/eegneuralnet.py +++ b/braindecode/eegneuralnet.py @@ -24,6 +24,16 @@ log = logging.getLogger(__name__) +def _get_model(model): + ''' Returns the corresponding class in case the model passed is a string. ''' + if isinstance(model, str): + if model in models_dict: + model = models_dict[model] + else: + raise ValueError(f'Unknown model name {model!r}.') + return model + + class _EEGNeuralNet(NeuralNet, abc.ABC): signal_args_set_ = False @@ -41,12 +51,7 @@ def initialize_module(self): """ kwargs = self.get_params_for('module') - module = self.module - if isinstance(module, str): - if module in models_dict: - module = models_dict[module] - else: - raise ValueError(f'Unknown model name {module!r}.') + module = _get_model(self.module) module = self.initialized_instance(module, kwargs) # pylint: disable=attribute-defined-outside-init self.module_ = module @@ -215,7 +220,8 @@ def _set_signal_args(self, X, y, classes): # kick out missing kwargs: module_kwargs = dict() - all_module_kwargs = inspect.signature(self.module.__init__).parameters.keys() + module = _get_model(self.module) + all_module_kwargs = inspect.signature(module.__init__).parameters.keys() for k, v in signal_kwargs.items(): if v is None: continue diff --git a/docs/conf.py b/docs/conf.py index 4354e60ec..7cace3a63 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -23,6 +23,7 @@ import os.path as op import matplotlib + matplotlib.use('agg') from datetime import datetime, timezone import faulthandler @@ -129,6 +130,7 @@ def linkcode_resolve(domain, info): return f"{repo}/blob/master/braindecode/{fn}{linespec}" + # -- Options for sphinx gallery -------------------------------------------- faulthandler.enable() os.environ['_BRAINDECODE_BROWSER_NO_BLOCK'] = 'true' @@ -188,6 +190,7 @@ def linkcode_resolve(domain, info): # # The short X.Y version. import braindecode + release = braindecode.__version__ # The full version, including alpha/beta/rc tags. version = '.'.join(release.split('.')[:2]) @@ -210,7 +213,6 @@ def linkcode_resolve(domain, info): # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True - # Sphinx-gallery configuration # Example configuration for intersphinx: refer to the Python standard library. @@ -250,6 +252,7 @@ def linkcode_resolve(domain, info): # a list of builtin themes. # import sphinx_rtd_theme # noqa + html_theme = "pydata_sphinx_theme" html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] switcher_version_match = 'dev' if release.endswith('dev0') else version @@ -275,8 +278,8 @@ def linkcode_resolve(domain, info): 'show_toc_level': 1, 'navbar_end': ['theme-switcher', 'version-switcher'], 'switcher': { - 'json_url': 'https://braindecode.org/stable/_static/versions.json', - 'version_match': switcher_version_match, + 'json_url': 'https://braindecode.org/stable/_static/versions.json', + 'version_match': switcher_version_match, }, "logo": { "image_light": "_static/braindecode_symbol.png", @@ -284,7 +287,7 @@ def linkcode_resolve(domain, info): "alt_text": "Braindecode Logo", }, 'footer_start': ['copyright'], - #'pygment_light_style': 'default', + # 'pygment_light_style': 'default', 'analytics': dict(google_analytics_id='G-7Q43R82K6D'), } @@ -345,7 +348,6 @@ def linkcode_resolve(domain, info): } - # -- Options for LaTeX output --------------------------------------------- latex_elements = { @@ -377,8 +379,6 @@ def linkcode_resolve(domain, info): 'Robin Tibor Schirrmeister', 'manual'), ] - - # -- Fontawesome support ----------------------------------------------------- # here the "fab" and "fas" refer to "brand" and "solid" (determines which font @@ -398,7 +398,7 @@ def linkcode_resolve(domain, info): icons = dict() for icon in brand_icons + fixed_icons + other_icons: font = ('fab' if icon in brand_icons else 'fas',) # brand or solid font - fw = ('fa-fw',) if icon in fixed_icons else () # fixed-width + fw = ('fa-fw',) if icon in fixed_icons else () # fixed-width icons[icon] = font + fw prolog = '' @@ -422,7 +422,6 @@ def linkcode_resolve(domain, info): .. |ensp| unicode:: U+2002 .. EN SPACE ''' - # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples @@ -432,7 +431,6 @@ def linkcode_resolve(domain, info): [author], 1) ] - # -- Options for Texinfo output ------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples diff --git a/docs/whats_new.rst b/docs/whats_new.rst index 73ebe5578..53b2c26f1 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -51,6 +51,7 @@ Enhancements - Add ``models_dict`` to :mod:`braindecode.models.util` (:gh:`524` by `Pierre Guetschel`_) - Add support for :class:`mne.Epochs` in :class:`braindecode.EEGClassifier` and :class:`braindecode.EEGRegressor` (:gh:`529` by `Pierre Guetschel`_) - Allow passing only the name of a braindecode model to :class:`braindecode.EEGClassifier` and :class:`braindecode.EEGRegressor` (:gh:`528` by `Pierre Guetschel`_) +- Add basic training example with MNE epochs (:gh:`539` by `Pierre Guetschel`_) Bugs ~~~~ diff --git a/examples/model_building/plot_basic_training_epochs.py b/examples/model_building/plot_basic_training_epochs.py new file mode 100644 index 000000000..6feb16f29 --- /dev/null +++ b/examples/model_building/plot_basic_training_epochs.py @@ -0,0 +1,177 @@ +""" +Simple training on MNE epochs +============================= + +The braindecode library gives you access to a large number of neural network +architectures that were developed for EEG data decoding. This tutorial will +show you how you can easily use any of these models to decode your own data. +In particular, we assume that have your data in an MNE format and want to +train one of the Braindecode models on it. + +.. contents:: This example covers: + :local: + :depth: 2 + +""" + +# Authors: Pierre Guetschel +# +# License: BSD (3-clause) + +###################################################################### +# Finding the model you want +# -------------------------- +# +# Exploring the braindecode online documentation +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Let's suppose you recently stumbled upon the Schirrmeister 2017 article [1]_. +# In this article, the authors mention that their novel architecture ShallowConvNet +# is performing well on the BCI Competition IV 2a dataset and you would like to use +# it on your own data. Fortunately, the authors also mentioned they published their +# architecture on Braindecode! +# +# In order to use this architecture, you first need to find what is its exact +# name in Braindecode. To do so, you can visit the Braindecode online documentation +# which lists all the available models. +# +# Models list: https://braindecode.org/stable/api.html#models +# +# Alternatively, the API also provide a dictionary with all available models: + +from braindecode.models.util import models_dict + +print(f'All the Braindecode models:\n{list(models_dict.keys())}') + +###################################################################### +# After your investigation, you found out that the model you are looking for is +# ``ShallowFBCSPNet``. You can now import it from Braindecode: + +from braindecode.models import ShallowFBCSPNet + +###################################################################### +# Examining the model +# ~~~~~~~~~~~~~~~~~~~ +# +# Now that you found your model, you must check which parameters it expects. +# You can find this information either in the online documentation here: +# :class:`braindecode.models.ShallowFBCSPNet` or directly in the module's docstring: + +print(ShallowFBCSPNet.__doc__) + +###################################################################### +# Additionally, you might be interested in visualizing the model's architecture. +# This can be done by initializing the model and calling its ``__str__()`` method. +# To initialize it, we need to specify some parameters that we set at random +# values for now: + +model = ShallowFBCSPNet( + n_chans=32, + n_times=1000, + n_outputs=2, + final_conv_length='auto', +) +print(model) + +###################################################################### +# Loading your own data with MNE +# ------------------------------ +# +# In this tutorial, we demonstrate how to train the model on MNE data. +# MNE is quite a popular library for EEG data analysis as it provides methods +# to load data from many different file formats and a large collection of algorithms +# to preprocess it. +# However, Braindecode is not limited to MNE and can be used with numpy arrays or +# PyTorch tensors/datasets. +# +# For this example, we generate some random data containing 100 examples with each +# 3 channels and 1024 time points. We also generate some random labels for our data +# that simulate a 4-class classification problem. + +import mne +import numpy as np + +info = mne.create_info(ch_names=['C3', 'C4', 'Cz'], sfreq=256., ch_types='eeg') +X = np.random.randn(100, 3, 1024) # 100 epochs, 3 channels, 4 seconds (@256Hz) +epochs = mne.EpochsArray(X, info=info) +y = np.random.randint(0, 4, size=100) # 4 classes +print(epochs) + +###################################################################### +# Training your model (scikit-learn compatible) +# --------------------------------------------- +# +# Now that you know which model you want to use, you know how to instantiate it, +# and that we have some fake data, it is time to train the model! +# +# .. note:: +# `Skorch `_ is a library that allows you to wrapp +# any PyTorch module into a scikit-learn-compatible classifier or regressor. +# Braindecode provides wrappers that inherit form the original Skorch ones and simply +# implement a few additional features that facilitate the use of Braindecode models. +# +# To train a Braindecode model, the easiest way is by using braindecode's +# Skorch wrappers. These wrappers are :class:`braindecode.EEGClassifier` and +# :class:`braindecode.EEGRegressor`. As our fake data is a classification task, +# we will use the former. +# +# The wrapper :class:`braindecode.EEGClassifier` expects a model class as its first argument but +# to facilitate the usage, you can also simply pass the name of any braindecode model as a string. +# The wrapper automatically finds and instantiates the model for you. +# If you want to pass parameters to your model, you can give them to the wrapper +# with the prefix ``module__``. +# +from skorch.dataset import ValidSplit +from braindecode import EEGClassifier + +net = EEGClassifier( + 'ShallowFBCSPNet', + module__final_conv_length='auto', + train_split=ValidSplit(0.2), + # To train a neural network you need validation split, here, we use 20%. +) + +###################################################################### +# In this example, we passed one additional parameter to the wrapper: ``module__final_conv_length`` +# that will be forwarded to the model (without the prefix ``module__``). +# +# We also note that the parameters ``n_chans``, ``n_times`` and ``n_outputs`` were not specified +# even if :class:`braindecode.ShallowFBCSPNet` needs them to be initialized. This is because the +# wrapper will automatically infer them, along with some other signal-related parameters, +# from the input data at training time. +# +# Now that we have our model wrapped in a scikit-learn-compatible classifier, +# we can train it by simply calling the ``fit`` method: + +net.fit(epochs, y) + +###################################################################### +# The pre-trained model is accessible via the ``module_`` attribute: + +print(net.module_) + +###################################################################### +# And we can see that all the following parameters were automatically inferred +# from the training data: + +print(f'{net.module_.n_chans=}\n{net.module_.n_times=}\n{net.module_.n_outputs=}' + f'\n{net.module_.input_window_seconds=}\n{net.module_.sfreq=}\n{net.module_.chs_info=}') + +###################################################################### +# Depending on the type of data used for training, some parameters might not be +# possible to infer. For example if you pass a numpy array or a +# :class:`braindecode.dataset.WindowsDataset` with ``target_from="metadata"``, +# then only ``n_chans``, ``n_times`` and ``n_outputs`` will be inferred. +# And if you pass other types of datasets, only ``n_chans`` and ``n_times`` will be inferred. +# In these case, you will have to pass the missing parameters manually +# (with the prefix ``module__``). + +###################################################################### +# References +# ---------- +# +# .. [1] Schirrmeister, R.T., Springenberg, J.T., Fiederer, L.D.J., Glasstetter, +# M., Eggensperger, K., Tangermann, M., Hutter, F. & Ball, T.(2017). +# Deep learning with convolutional neural networks for EEG decoding and visualization. +# Human Brain Mapping, Aug. 2017. +# Online: http://dx.doi.org/10.1002/hbm.23730