# Braindecode

In this brief tutorial we show how to use the main EEG libraries (MNE, Braindecode) for benchmarking and exploratory work, that is, for testing models on a target dataset. This eases model exploration, because state-of-the-art models from the literature are already implemented and, where available, pretrained.

In [7]:
import mne 
import numpy as np 
from braindecode.models.util import models_dict

from skorch.dataset import ValidSplit
from braindecode import EEGClassifier

## Data Loading & Format

First, we need to understand how EEG channels are named. Electrode names follow a nomenclature known as the **10-20 system**: a letter prefix identifying the scalp region, followed by either a number or a trailing "z". Valid names match the regex `^(Fp|F|C|P|O|T|AF|FC|CP|PO|FT|TP)([1-9][0-9]*|z)$`

1) Characters: Fp (Frontopolar), F (Frontal), C (Central), P (Parietal), T (Temporal), O (Occipital). Two-letter combinations such as FC or CP denote sites lying between two of these regions.
2) Number: Indicates the lateral (left-right) position, with **odd numbers -> left hemisphere** and **even numbers -> right hemisphere**.
3) Ending "z": Indicates a position on the midline of the skull (Fz, Cz, ...).
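
The naming rules above can be checked mechanically. The following is a minimal sketch using only the standard library; the helper name `describe_electrode` is hypothetical, and the pattern is the one quoted in this section rather than an exhaustive list of every electrode label in use.

```python
import re

# Pattern copied from the text above (not an exhaustive standard).
ELECTRODE_RE = re.compile(r"^(Fp|F|C|P|O|T|AF|FC|CP|PO|FT|TP)([1-9][0-9]*|z)$")


def describe_electrode(name):
    """Return (region_prefix, side) for a channel name, or None if it does not match."""
    match = ELECTRODE_RE.match(name)
    if match is None:
        return None
    prefix, suffix = match.groups()
    if suffix == "z":
        side = "midline"   # trailing "z": on the midline
    elif int(suffix) % 2 == 1:
        side = "left"      # odd number: left hemisphere
    else:
        side = "right"     # even number: right hemisphere
    return prefix, side


for ch in ["C3", "C4", "Cz", "Fp1", "X9"]:
    print(ch, describe_electrode(ch))
```

A name that does not follow the pattern, such as `X9`, yields `None`, which makes this a convenient sanity check before building the channel list passed to `mne.create_info`.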

In [5]:
# 1. DEFINING METADATA (The "Header")
# This object holds the hardware specifications.
# ch_names: The specific electrode locations (Standard 10-20 system).
# sfreq: Sampling Frequency (Hz). 256 Hz means we record 256 data points per second.
info = mne.create_info(
    ch_names=["C3", "C4", "Cz"], 
    sfreq=256, 
    ch_types="eeg"
)

# 2. GENERATING DATA (The "Signal")
# We create a 3D NumPy array of Gaussian noise as a stand-in for recorded potentials.
# The shape must strictly follow: (n_epochs, n_channels, n_times)
#   - 100 Epochs:   100 distinct trials/events (e.g., 100 imagined movements).
#   - 3 Channels:   The spatial dimension (C3, C4, Cz).
#   - 1024 Samples: The temporal dimension. 
#                   Calculation: 4 seconds * 256 Hz = 1024 time points.
X = np.random.randn(100, 3, 1024)

# 3. CREATING THE CONTAINER (The MNE Object)
# EpochsArray binds the raw data (X) with the metadata (info).
# This provides access to built-in signal processing methods (filtering, plotting).
epochs = mne.EpochsArray(X, info=info)

# 4. TARGET LABELS
# The classification targets for the 100 trials.
# Classes: 0, 1, 2, 3 (e.g., Left, Right, Tongue, Passive).
y = np.random.randint(0, 4, size=100)

# 5. INSPECTION
print(epochs)
# The summary should report 100 events (all good) spanning 0 to about 3.996 s.

Not setting metadata
100 matching events found
No baseline correction applied
0 projection items activated
<EpochsArray | 100 events (all good), 0 – 3.996 s (baseline off), ~2.4 MiB, data loaded,
 '1': 100>
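
The reported span of 0 to 3.996 s, rather than a full 4 s, follows from the sampling grid: the first sample sits at t = 0, so the last of the 1024 samples falls at 1023 / 256 s. A small sketch of that arithmetic, with `epoch_time_span` as a hypothetical helper name:

```python
def epoch_time_span(n_times, sfreq, tmin=0.0):
    """Return (tmin, tmax) in seconds for n_times samples recorded at sfreq Hz."""
    # The first sample sits at tmin, so the last one is (n_times - 1) steps later.
    tmax = tmin + (n_times - 1) / sfreq
    return tmin, tmax


sfreq = 256            # sampling frequency in Hz
n_times = 4 * sfreq    # 4 seconds of signal -> 1024 samples
tmin, tmax = epoch_time_span(n_times, sfreq)
print(n_times, tmin, round(tmax, 3))
```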


## Modelling

Braindecode builds on PyTorch and uses skorch to expose a scikit-learn compatible interface. It provides models for **3 different objectives (regression, classification, embeddings)** and for multiple **use cases**.

In [None]:
print(f"All the Braindecode models:\n{list(models_dict.keys())}")


All the Braindecode models:
['ATCNet', 'AttentionBaseNet', 'AttnSleep', 'BDTCN', 'BIOT', 'CTNet', 'ContraWR', 'Deep4Net', 'DeepSleepNet', 'EEGConformer', 'EEGITNet', 'EEGInceptionERP', 'EEGInceptionMI', 'EEGMiner', 'EEGNeX', 'EEGNet', 'EEGSimpleConv', 'EEGTCNet', 'FBCNet', 'FBLightConvNet', 'FBMSNet', 'IFNet', 'Labram', 'MSVTNet', 'SCCNet', 'SPARCNet', 'ShallowFBCSPNet', 'SignalJEPA', 'SignalJEPA_Contextual', 'SignalJEPA_PostLocal', 'SignalJEPA_PreLocal', 'SincShallowNet', 'SleepStagerBlanco2020', 'SleepStagerChambon2018', 'SyncNet', 'TIDNet', 'TSception', 'USleep']


In [12]:
net = EEGClassifier(
    "ShallowFBCSPNet",
    module__final_conv_length="auto",
    train_split=ValidSplit(0.2),
)
# An mne.Epochs object can be passed directly as X; y holds one label per epoch.
net.fit(epochs, y)

  epoch    valid_acc    valid_loss     dur
-------  -----------  ------------  ------
      1       0.3000       44.7866  0.0185
      2       0.3000       44.7866  0.0093
      3       0.3000       44.7866  0.0096
      4       0.3000       44.7866  0.0108
      5       0.3000       44.7866  0.0098
      6       0.3000       44.7866  0.0182
      7       0.3000       44.7866  0.0142
      8       0.3000       44.7866  0.0124
      9       0.3000       44.7866  0.0096
     10       0.3000       44.7866  0.0082


| Parameter | Value |
|---|---|
| module | `'ShallowFBCSPNet'` |
| criterion | `<class 'torch...sEntropyLoss'>` |
| cropped | `False` |
| callbacks | |
| iterator_train__shuffle | `True` |
| iterator_train__drop_last | `True` |
| aggregate_predictions | `True` |
| optimizer | `<class 'torch.optim.sgd.SGD'>` |
| lr | `0.01` |
| max_epochs | `10` |
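
The training log above has no training-loss column, and the validation metrics never change. One plausible explanation, assuming skorch's default `batch_size` of 128 is in effect: `ValidSplit(0.2)` leaves about 80 of the 100 trials for training, and with `iterator_train__drop_last=True` an incomplete batch is discarded, so no optimisation step ever runs. The sketch below only reproduces that arithmetic; `n_train_batches` is a hypothetical helper, not part of skorch or Braindecode.

```python
def n_train_batches(n_samples, valid_fraction, batch_size, drop_last=True):
    """Count the training batches per epoch after a train/validation split."""
    n_valid = round(n_samples * valid_fraction)
    n_train = n_samples - n_valid
    if drop_last:
        # Incomplete batches are discarded.
        return n_train // batch_size
    # Otherwise the last, smaller batch is kept (ceiling division).
    return -(-n_train // batch_size)


print(n_train_batches(100, 0.2, batch_size=128))  # 0 full batches
print(n_train_batches(100, 0.2, batch_size=32))   # 2 full batches
```

Under these assumptions, passing a smaller `batch_size` (for example 32) to `EEGClassifier` would give the optimiser at least two full batches per epoch on this toy dataset.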


In [13]:
print(net.module_)

Layer (type (var_name):depth-idx)             Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)             [1, 3, 1024]              [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1                  [1, 3, 1024]              [1, 3, 1024, 1]           --                        --
├─Rearrange (dimshuffle): 1-2                 [1, 3, 1024, 1]           [1, 1, 1024, 3]           --                        --
├─CombinedConv (conv_time_spat): 1-3          [1, 1, 1024, 3]           [1, 40, 1000, 1]          5,840                     --
├─BatchNorm2d (bnorm): 1-4                    [1, 40, 1000, 1]          [1, 40, 1000, 1]          80                        --
├─Expression (conv_nonlin_exp): 1-5           [1, 40, 1000, 1]          [1, 40, 1000, 1]          --                        --
├─AvgPool2d (pool): 1-6                       [1, 40, 1000, 1]          [1, 40, 62, 1]            -- 

Alternatively, import the model class and instantiate it directly:

In [10]:
import torch
from braindecode.models import ShallowFBCSPNet
from braindecode.util import set_random_seeds
print(ShallowFBCSPNet.__doc__)

Shallow ConvNet model from Schirrmeister et al (2017) [Schirrmeister2017]_.

:bdg-success:`Convolution`

.. figure:: https://onlinelibrary.wiley.com/cms/asset/221ea375-6701-40d3-ab3f-e411aad62d9e/hbm23730-fig-0002-m.jpg
    :align: center
    :alt: ShallowNet Architecture

Model described in [Schirrmeister2017]_.

Parameters
----------
n_chans : int
    Number of EEG channels.
n_outputs : int
    Number of outputs of the model. This is the number of classes
    in the case of classification.
n_times : int
    Number of time samples of the input window.
n_filters_time: int
    Number of temporal filters.
filter_time_length: int
    Length of the temporal filter.
n_filters_spat: int
    Number of spatial filters.
pool_time_length: int
    Length of temporal pooling filter.
pool_time_stride: int
    Length of stride between temporal pooling filters.
final_conv_length: int | str
    Length of the final convolution layer.
    If set to "auto", length of the input signal must be specified.

In [11]:
model = ShallowFBCSPNet(
    n_chans=32,
    n_times=1000,
    n_outputs=2,
    final_conv_length="auto",
)
print(model)

Layer (type (var_name):depth-idx)             Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)             [1, 32, 1000]             [1, 2]                    --                        --
├─Ensure4d (ensuredims): 1-1                  [1, 32, 1000]             [1, 32, 1000, 1]          --                        --
├─Rearrange (dimshuffle): 1-2                 [1, 32, 1000, 1]          [1, 1, 1000, 32]          --                        --
├─CombinedConv (conv_time_spat): 1-3          [1, 1, 1000, 32]          [1, 40, 976, 1]           52,240                    --
├─BatchNorm2d (bnorm): 1-4                    [1, 40, 976, 1]           [1, 40, 976, 1]           80                        --
├─Expression (conv_nonlin_exp): 1-5           [1, 40, 976, 1]           [1, 40, 976, 1]           --                        --
├─AvgPool2d (pool): 1-6                       [1, 40, 976, 1]           [1, 40, 61, 1]            -- 
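
The output shapes and parameter counts printed in both summaries can be reproduced by hand. The sketch below assumes hyperparameter values that are not stated in this section (40 temporal and 40 spatial filters, a temporal kernel of 25 samples, pooling of length 75 with stride 15) and assumes that only the temporal convolution carries a bias; these values were chosen because they reproduce the printed numbers. The helper `shallow_shapes` is hypothetical.

```python
def shallow_shapes(n_chans, n_times, n_filters_time=40, filter_time_length=25,
                   n_filters_spat=40, pool_time_length=75, pool_time_stride=15):
    """Return (conv_len, pool_len, n_params) for the combined temporal-spatial conv."""
    # An unpadded temporal convolution shortens the signal by (kernel - 1).
    conv_len = n_times - filter_time_length + 1
    # Unpadded average pooling: floor((L - kernel) / stride) + 1.
    pool_len = (conv_len - pool_time_length) // pool_time_stride + 1
    # Temporal conv: one kernel per filter, plus one bias per filter (assumed).
    temporal_params = n_filters_time * filter_time_length + n_filters_time
    # Spatial conv: mixes every temporal filter across every channel, no bias (assumed).
    spatial_params = n_filters_spat * n_filters_time * n_chans
    return conv_len, pool_len, temporal_params + spatial_params


print(shallow_shapes(3, 1024))   # (1000, 62, 5840)
print(shallow_shapes(32, 1000))  # (976, 61, 52240)
```

The first call matches the summary of the fitted `EEGClassifier` (3 channels, 1024 samples), the second the directly instantiated model (32 channels, 1000 samples).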

## Benchmarking different models

## Using our Dataset