Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions braindecode/preprocessing/eegprep_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,53 @@ def apply_eeg(self, eeg: dict[str, Any], raw: BaseRaw) -> dict[str, Any]:
"""Apply the preprocessor to an EEGLAB EEG structure. Overridden by subclass."""
...

@staticmethod
def _get_annotation_durations(raw: BaseRaw) -> list[float]:
"""Capture annotation durations in seconds before the EEGLAB round-trip."""
if raw.annotations is None or len(raw.annotations) == 0:
return []
return [float(duration) for duration in raw.annotations.duration]

@staticmethod
def _duration_to_samples(duration_s: float, sfreq: float) -> int:
"""Convert a duration in seconds to EEGLAB event samples."""
duration_samples = int(round(duration_s * sfreq))
if duration_s > 0:
return max(1, duration_samples)
return duration_samples

@staticmethod
def _restore_event_durations(
eeg: dict[str, Any],
annotation_durations: Sequence[float],
opname: str,
) -> None:
"""Restore annotation durations after EEGPrep may have changed the sampling rate."""
events = eeg.get("event", [])
if len(events) == 0:
return

non_boundary_events = [ev for ev in events if ev.get("type") != "boundary"]
if annotation_durations and len(non_boundary_events) != len(
annotation_durations
):
log.warning(
"EEGPrep event count changed during %s processing (%d annotated events,"
" %d non-boundary events); restoring durations in order for the"
" overlapping subset only.",
opname,
len(annotation_durations),
len(non_boundary_events),
)

sfreq = float(eeg["srate"])
to_samples = EEGPrepBasePreprocessor._duration_to_samples
for ev, duration_s in zip(non_boundary_events, annotation_durations):
ev["duration"] = to_samples(duration_s, sfreq)

for ev in events:
ev.setdefault("duration", 1)

def _apply_op(self, raw: BaseRaw) -> None:
"""Internal method that does the actual work; this is called by Preprocessor.apply()."""
# handle error if eegprep is not available
Expand Down Expand Up @@ -121,17 +168,13 @@ def _apply_op(self, raw: BaseRaw) -> None:
eeg = raw
non_eeg = None

annotation_durations = self._get_annotation_durations(eeg)
eeg = eegprep.mne2eeg(eeg)

# back up channel locations for potential later use
orig_chanlocs = [cl.copy() for cl in eeg["chanlocs"]]

# ensure all events in EEG structure have a 'duration' field; this is
# necessary for some of the EEGPrep operations to succeed
if not all("duration" in ev for ev in eeg["event"]):
for ev in eeg["event"]:
if "duration" not in ev:
ev["duration"] = 1
self._restore_event_durations(eeg, annotation_durations, opname)

if self.force_dtype is not None:
eeg["data"] = eeg["data"].astype(self.force_dtype)
Expand All @@ -142,6 +185,8 @@ def _apply_op(self, raw: BaseRaw) -> None:
if self.force_dtype is not None:
eeg["data"] = eeg["data"].astype(self.force_dtype)

self._restore_event_durations(eeg, annotation_durations, opname)

# rename EEGLAB-type boundary events to a form that's recognized by MNE so they
# (or intersecting epochs) are ignored during potential downstream epoching
# done by braindecode pipelines
Expand Down
6 changes: 6 additions & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ Current 1.5.0 (GitHub)
Enhancements
============

- Tutorials now train for a few epochs and then load pretrained weights from
  the Hugging Face Hub to show full training curves and metrics. All 9 tutorial
  checkpoints are published to ``huggingface.co/braindecode/``. The offline
  training script used to produce the checkpoints is available as a gist:
  https://gist.github.com/bruAristimunha/27d74c8410fe9d0db258a03f42efa7c6.
Comment thread
bruAristimunha marked this conversation as resolved.
(:pr:`985` by :user:`bruAristimunha`)
Comment thread
bruAristimunha marked this conversation as resolved.
- Use ``F.scaled_dot_product_attention`` in :class:`braindecode.modules.MultiHeadAttention`,
enabling optimized attention kernels (flash-attention on CUDA,
memory-efficient backends on other devices).
Expand Down
41 changes: 32 additions & 9 deletions examples/advanced_training/bcic_iv_4_ecog_cropped.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@
######################################################################
# We select only the first 30 seconds from the training dataset to limit the time and
# memory needed to run this example. We split the training dataset into train and
# validation sets (only 6 seconds for validation).
# To obtain full results whole datasets should be used.
valid_set = preprocess(
copy.deepcopy(train_set), [Preprocessor("crop", tmin=24, tmax=30)], n_jobs=-1
)
Expand Down Expand Up @@ -274,7 +273,7 @@
# cross validation on your training data.
#

from skorch.callbacks import LRScheduler
from skorch.callbacks import EarlyStopping, LRScheduler
from skorch.helper import predefined_split

from braindecode import EEGRegressor
Expand All @@ -299,7 +298,7 @@
iterator_train__shuffle=True,
batch_size=batch_size,
callbacks=[
("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=n_epochs - 1)),
("lr_scheduler", LRScheduler("CosineAnnealingLR", T_max=max(1, n_epochs - 1))),
(
"r2_train",
CroppedTimeSeriesEpochScoring(
Expand All @@ -318,6 +317,7 @@
name="r2_valid",
),
),
("early_stopping", EarlyStopping(patience=10, load_best=True)),
],
device=device,
)
Expand All @@ -328,6 +328,35 @@
# in the dataset.
regressor.fit(train_set, y=None, epochs=n_epochs)

######################################################################
# Training for longer
# -------------------
#
# The gallery build above uses only ``n_epochs = 8``. When trained
# offline on the full recording for up to 100 epochs with early stopping,
# the model reaches a test-set mean Pearson r of **0.07**.
#
# We can load the pretrained checkpoint from the Hugging Face Hub and
# inspect the full training curves:

import warnings

repo_id = "braindecode/bcic_iv_4_ecog_cropped"
try:
from huggingface_hub import hf_hub_download

regressor.initialize()
regressor.load_params(
f_params=hf_hub_download(repo_id, "params.safetensors"),
f_history=hf_hub_download(repo_id, "history.json"),
use_safetensors=True,
)
Comment thread
bruAristimunha marked this conversation as resolved.
except Exception as exc:
warnings.warn(
f"Could not load pretrained checkpoint from {repo_id} ({exc}); "
"continuing with the locally trained short-run model.",
stacklevel=2,
)

Comment thread
bruAristimunha marked this conversation as resolved.
######################################################################
# Obtaining predictions and targets for the test, train, and validation dataset
Expand Down Expand Up @@ -364,12 +393,6 @@ def pad_and_select_predictions(preds, y):
######################################################################
# We plot target and predicted finger flexion on training, validation, and test sets.
#
# .. note::
# The model is trained and validated on limited dataset (to decrease the time needed to run
# this example) which does not contain diverse dataset in terms of fingers flexions and may
# cause overfitting. To obtain better results use whole dataset as well as improve the decoding
# pipeline which may be not optimal for ECoG.
#
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.lines import Line2D
Expand Down
Loading
Loading