> *This notebook is a modified version of the Braindecode tutorial [Fine-tuning a Foundation Model (Signal-JEPA)](https://braindecode.org/dev/auto_examples/advanced_training/plot_finetune_foundation_model.html). The main difference is that we use the [2025 EEG challenge](https://eeg2025.github.io/) data.*

> **[Note 2025-10-30]** It will soon be even easier to load a pre-trained foundation model from Braindecode, thanks to [PR#795](https://github.com/braindecode/braindecode/pull/795). Many new models will be added!

# Fine-tuning a Foundation Model (Signal-JEPA)

Foundation models are large-scale pre-trained models that serve as a starting point
for a wide range of downstream tasks, leveraging their generalization capabilities.
Fine-tuning these models is necessary to adapt them to specific tasks or datasets,
ensuring optimal performance in specialized applications.

In this tutorial, we demonstrate how to load a pre-trained foundation model
and fine-tune it for a specific task. We use the Signal-JEPA model [1]_
and the Contrast Change Detection (CCD) task data from the 2025 EEG challenge,
on which we regress response times.


In [None]:
!pip install braindecode
# we need PR#792 which fixes Labram: f69f12b
# !pip install git+https://github.com/braindecode/braindecode.git@f69f12b38d33d6341172bdf43457034ccfeab5ba
!pip install eegdash



In [None]:
# Authors: Pierre Guetschel <pierre.guetschel@gmail.com>
#
# License: BSD (3-clause)
#
from pathlib import Path
import mne
import numpy as np
import torch
from braindecode import EEGRegressor
from braindecode.models import SignalJEPA_PreLocal
from braindecode.preprocessing import (
    create_windows_from_events,
    Preprocessor,
    preprocess,
)
from eegdash.dataset import EEGChallengeDataset
from eegdash.hbn.windows import (
    annotate_trials_with_target,
    add_aux_anchors,
    add_extras_columns,
    keep_only_recordings_with,
)

torch.manual_seed(12)
np.random.seed(12)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    msg = "CUDA-enabled GPU found. Training should be faster."
else:
    msg = (
        "No GPU found. Training will be carried out on CPU, which might be "
        "slower.\n\nIf running on Google Colab, you can request a GPU runtime by"
        " clicking\n`Runtime/Change runtime type` in the top bar menu, then "
        "selecting 'T4 GPU'\nunder 'Hardware accelerator'."
    )
print(msg)

CUDA-enabled GPU found. Training should be faster.


## Loading and preparing the data

### Loading a dataset

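We start by loading the Contrast Change Detection (CCD) task from release R1 of the 2025 EEG challenge data, using eegdash's `EEGChallengeDataset` (`mini=False` selects the full release rather than the reduced mini version).
We then set the standard GSN-HydroCel-129 montage so that each recording carries EEG channel positions.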




In [None]:
DATA_DIR = Path("data")
DATA_DIR.mkdir(parents=True, exist_ok=True)

dataset_ccd = EEGChallengeDataset(
    task="contrastChangeDetection",
    release="R1",
    cache_dir=DATA_DIR,
    mini=False,
)

# Set the montage for EEG channel locations
montage = mne.channels.make_standard_montage("GSN-HydroCel-129")
for ds in dataset_ccd.datasets:
    ds.raw.set_montage(montage)

### Define Dataset parameters

We extract the sampling frequency and check that it is identical across
all recordings. We also extract information about the EEG channels
(names, positions, etc.), which the model needs. The window length is not
read from the annotations; it is set explicitly when creating the windows.




In [None]:
# Extract the sampling frequency and check that it is identical across recordings
SFREQ = dataset_ccd.datasets[0].raw.info["sfreq"]
assert all(ds.raw.info["sfreq"] == SFREQ for ds in dataset_ccd.datasets)

# The window length is not read from the annotations here: it is set
# explicitly (EPOCH_LEN_S) when creating the windows below.

# Extract channel information (names, positions, etc.)
chs_info = dataset_ccd.datasets[0].raw.info["chs"]

print(f"{SFREQ=}, {len(chs_info)=}")

SFREQ=100.0, len(chs_info)=129
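The same consistency check can be applied to other fields, which matters here because a single `chs_info` is passed to the model for all recordings. A minimal sketch, assuming only that each `raw.info` can be indexed like a dictionary (as `mne.Info` can); the helper name `assert_consistent` is our own, and the toy dictionaries stand in for real recordings:

```python
def assert_consistent(infos, key):
    """Return infos[0][key] after checking that every info has the same value.

    `infos` is any sequence of dict-like objects (mne.Info can be indexed
    like a dict), e.g. [ds.raw.info for ds in dataset_ccd.datasets].
    """
    if not infos:
        raise ValueError("need at least one info object")
    reference = infos[0][key]
    # Indices of the recordings whose value differs from the first one
    mismatches = [i for i, info in enumerate(infos) if info[key] != reference]
    if mismatches:
        raise AssertionError(f"{key!r} differs in recordings {mismatches}")
    return reference


# Toy dictionaries standing in for mne.Info objects
toy_infos = [
    {"sfreq": 100.0, "ch_names": ["E1", "E2", "Cz"]},
    {"sfreq": 100.0, "ch_names": ["E1", "E2", "Cz"]},
]
sfreq = assert_consistent(toy_infos, "sfreq")
ch_names = assert_consistent(toy_infos, "ch_names")
print(sfreq, len(ch_names))
```

On the real data, one would call `assert_consistent([ds.raw.info for ds in dataset_ccd.datasets], "ch_names")`.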


### Create Windows from Events

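We first annotate each trial with its regression target, the response time
measured from stimulus onset (`rt_from_stimulus`), requiring both a stimulus and
a response, and add stimulus-locked anchor annotations. After keeping only the
recordings that contain such anchors, we use Braindecode's
`create_windows_from_events` function to cut 2 s windows starting 0.5 s after
each stimulus anchor.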
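Before running the next cell, it can help to work out the window geometry in samples. The sketch below is plain arithmetic under two assumptions: the anchor annotations have zero duration, and only full windows are counted (braindecode's handling of a partial last window via `drop_last_window` is ignored). The helper `n_windows_per_trial` is hypothetical, not a Braindecode function:

```python
def n_windows_per_trial(sfreq, start_offset_s, stop_offset_s, window_s, stride_s):
    """Count the full windows that fit between the start and stop offsets of one trial.

    Offsets are in seconds relative to the (zero-duration) anchor event.
    """
    start = int(start_offset_s * sfreq)  # first sample of the trial segment
    stop = int(stop_offset_s * sfreq)    # end sample of the trial segment
    size = int(window_s * sfreq)         # window length in samples
    stride = int(stride_s * sfreq)       # step between window starts in samples
    span = stop - start
    if span < size:
        return 0
    return (span - size) // stride + 1


# Values used in this notebook: 100 Hz, +0.5 s to +2.5 s, 2 s windows, 1 s stride
print(n_windows_per_trial(100.0, 0.5, 2.5, 2.0, 1.0))
```

With SFREQ = 100 Hz, the offsets of the next cell give a start at sample 50, a stop at sample 250 and a window of 200 samples, so each trial yields exactly one window and the stride of 100 samples has no effect.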




In [None]:
EPOCH_LEN_S = 2.0

transformation_offline = [
    Preprocessor(
        annotate_trials_with_target,
        target_field="rt_from_stimulus", epoch_length=EPOCH_LEN_S,
        require_stimulus=True, require_response=True,
        apply_on_array=False,
    ),
    Preprocessor(add_aux_anchors, apply_on_array=False),
]
preprocess(dataset_ccd, transformation_offline, n_jobs=3)

ANCHOR = "stimulus_anchor"
SHIFT_AFTER_STIM = 0.5
WINDOW_LEN = 2.0

# Keep only recordings that actually contain stimulus anchors
dataset = keep_only_recordings_with(ANCHOR, dataset_ccd)

# Create single-interval windows (stim-locked, long enough to include the response)
windows_dataset = create_windows_from_events(
    dataset,
    mapping={ANCHOR: 0},
    trial_start_offset_samples=int(SHIFT_AFTER_STIM * SFREQ),                 # +0.5 s
    trial_stop_offset_samples=int((SHIFT_AFTER_STIM + WINDOW_LEN) * SFREQ),   # +2.5 s
    window_size_samples=int(EPOCH_LEN_S * SFREQ),
    window_stride_samples=int(SFREQ),
    preload=True,
)

# Copy the trial-level information attached to the stimulus anchors (target,
# response times, correctness, response type) into the windows metadata
windows_dataset = add_extras_columns(
    windows_dataset,
    dataset,
    desc=ANCHOR,
    keys=("target", "rt_from_stimulus", "rt_from_trialstart",
          "stimulus_onset", "response_onset", "correct", "response_type")
)

metadata = windows_dataset.get_metadata()
print(metadata.head(10))

Used Annotations descriptions: [np.str_('stimulus_anchor')]

## Loading a pre-trained foundation model

### Download and Load Pre-trained Weights

We download the pre-trained weights for the SignalJEPA model from the Hugging Face Hub.
These weights will serve as the starting point for fine-tuning.




In [None]:
model_state_dict = torch.hub.load_state_dict_from_url(
    url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
)
# !wget https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth
# model_state_dict = torch.load('signal-jepa_16s-60_adeuwv4s.pth')
# print(model_state_dict.keys())
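Rather than printing every key, one can summarise which top-level modules the checkpoint contains by counting keys on their first dotted component. A minimal sketch; the key names below are hypothetical placeholders, and on the real checkpoint one would pass `model_state_dict.keys()`:

```python
from collections import Counter


def count_keys_by_module(keys):
    """Count state-dict keys by top-level module name (text before the first '.')."""
    return Counter(key.split(".", 1)[0] for key in keys)


# Hypothetical key names, only to illustrate the output format
example_keys = [
    "feature_encoder.1.0.weight",
    "feature_encoder.1.1.weight",
    "transformer.layers.0.self_attn.in_proj_weight",
    "pos_encoder.weight",
]
print(count_keys_by_module(example_keys))
```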

### Instantiate the Foundation Model

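We create an instance of the SignalJEPA model using the pre-local downstream
architecture. The model is initialized with the dataset's sampling frequency,
the window length (`EPOCH_LEN_S`), and the channel information. We set
`n_outputs=1` because we predict a single scalar per window (the response time).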




In [None]:
model = SignalJEPA_PreLocal(
    sfreq=SFREQ,
    input_window_seconds=EPOCH_LEN_S,
    chs_info=chs_info,
    n_outputs=1,  # Regression task
)
# print(model)

### Load the Pre-trained Weights into the Model

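We load the pre-trained weights into the model. The transformer weights are discarded
because this module is not used in the pre-local downstream architecture (see [1]_).
The newly added downstream layers (`spatial_conv` and `final_layer`) have no counterpart
in the checkpoint, so they keep their random initialization and are reported as missing keys.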




In [None]:
# Newly added downstream layers: absent from the pre-trained checkpoint,
# so they are expected to be reported as missing keys
new_layers = {
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
}

# Filter out transformer weights and load the state dictionary
model_state_dict = {
    k: v for k, v in model_state_dict.items() if not k.startswith("transformer.")
}
missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False)

# Ensure no unexpected keys and validate missing keys
assert unexpected_keys == [], f"{unexpected_keys=}"
assert set(missing_keys) == new_layers, f"{missing_keys=}"
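The checks above follow from how `load_state_dict(..., strict=False)` compares key sets: model keys absent from the checkpoint are reported as missing, and checkpoint keys absent from the model as unexpected. The sketch below reproduces that bookkeeping with plain sets, using hypothetical key names; note that PyTorch still raises an error on shape mismatches even with `strict=False`, which this sketch does not model:

```python
def expected_load_report(model_keys, checkpoint_keys, drop_prefixes=()):
    """Predict the (missing, unexpected) key sets of load_state_dict(strict=False).

    Checkpoint keys starting with any of `drop_prefixes` are removed first,
    as done above for "transformer.".
    """
    kept = {k for k in checkpoint_keys if not k.startswith(tuple(drop_prefixes))}
    missing = set(model_keys) - kept      # in the model, not in the checkpoint
    unexpected = kept - set(model_keys)   # in the checkpoint, not in the model
    return missing, unexpected


# Hypothetical key names mimicking the pre-local architecture
model_keys = {
    "spatial_conv.1.weight", "spatial_conv.1.bias",
    "feature_encoder.0.weight",
    "final_layer.1.weight", "final_layer.1.bias",
}
ckpt_keys = {"feature_encoder.0.weight", "transformer.layers.0.weight"}

missing, unexpected = expected_load_report(model_keys, ckpt_keys, drop_prefixes=("transformer.",))
print(sorted(missing), sorted(unexpected))
```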

## Fine-tuning the Model

Signal-JEPA is a model trained in a self-supervised manner on a masked
prediction task. In this task, the model is configured in a many-to-many
fashion, which is not suited for a downstream task with a single output per
window (classification or, as here, regression). Therefore, we need to
adjust the model architecture for fine-tuning. This is what the
:class:`SignalJEPA_PreLocal`, :class:`SignalJEPA_Contextual`, and
:class:`SignalJEPA_PostLocal` classes do: they add new layers specifically
for the downstream task, as described in the article [1]_ and in the following figure:

<img src="file://_static/model/sjepa_pre-local.jpg" alt="Signal-JEPA Pre-Local Downstream Architecture" align="center">

With this downstream architecture, two options are possible for fine-tuning:

1) Fine-tune only the newly added layers
2) Fine-tune the entire model

### Freezing Pre-trained Layers

Fine-tuning the entire model only requires leaving all parameters trainable,
so we focus on the first option here: we freeze all layers except the newly
added ones.




In [None]:
for name, param in model.named_parameters():
    if name not in new_layers:
        param.requires_grad = False

print("Trainable parameters:")
other_modules = set()
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
    else:
        other_modules.add(name.split(".")[0])

print("\nOther modules:")
print(other_modules)

Trainable parameters:
spatial_conv.1.weight
spatial_conv.1.bias
final_layer.1.weight
final_layer.1.bias

Other modules:
{'feature_encoder'}
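To quantify the effect of freezing, one can count trainable versus total scalar parameters. A minimal sketch on a toy `nn.Sequential` rather than the actual Signal-JEPA model; `freeze_all_except` and `count_parameters` are hypothetical helper names. Unlike the loop above, `freeze_all_except` also sets `requires_grad=True` on the listed names:

```python
import torch
from torch import nn


def freeze_all_except(module, trainable_names):
    """Make only the parameters whose names are in `trainable_names` trainable."""
    for name, param in module.named_parameters():
        param.requires_grad = name in trainable_names


def count_parameters(module):
    """Return (n_trainable, n_total) numbers of scalar parameters."""
    n_total = sum(p.numel() for p in module.parameters())
    n_trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return n_trainable, n_total


# Toy model: parameters are named "0.weight", "0.bias", "2.weight", "2.bias"
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
freeze_all_except(toy, {"2.weight", "2.bias"})
print(count_parameters(toy))  # (9, 49)
```

On the fine-tuning model, `count_parameters(model)` would report the trainable and total parameter counts after the freezing loop above.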


### Fine-tuning Procedure

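Finally, we set up the fine-tuning procedure using Braindecode's
:class:`EEGRegressor`. We set the optimizer, learning rate, batch size, and an
RMSE scoring callback, and leave the loss at the regressor's default. We then
fit the model to the windows dataset, with the response times in
`metadata["target"]` as targets.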

We only train for a few epochs for demonstration purposes.




In [None]:
clf = EEGRegressor(
    model,
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=64,
    callbacks=['neg_root_mean_squared_error'],
    device=device,
)
_ = clf.fit(windows_dataset, y=metadata["target"], epochs=10)

  epoch    train_loss    train_neg_root_mean_squared_error    valid_loss    valid_neg_root_mean_squared_error     dur
-------  ------------  -----------------------------------  ------------  -----------------------------------  ------
      1        [36m2.5882[0m                              [32m-1.5611[0m        [35m2.4270[0m                              [31m-1.5579[0m  5.2947
      2        [36m2.3084[0m                              [32m-1.4675[0m        [35m2.1564[0m                              [31m-1.4685[0m  4.3558
      3        [36m2.0284[0m                              [32m-1.3733[0m        [35m1.9051[0m                              [31m-1.3803[0m  3.5982
      4        [36m1.7724[0m                              [32m-1.2801[0m        [35m1.6724[0m                              [31m-1.2932[0m  4.3484
      5        [36m1.5264[0m                              [32m-1.1890[0m        [35m1.4581[0m                              [31m-1.2075[0m  3
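To read the log, note that `train_loss` and `valid_loss` are mean squared errors (assuming the default `torch.nn.MSELoss` criterion, which the numbers are consistent with), whereas the `*_neg_root_mean_squared_error` columns follow scikit-learn's convention of negating error metrics so that larger is better. For the validation columns, the score is the negative square root of the loss (e.g. sqrt(2.4270) ≈ 1.558 in the first epoch); the training columns differ slightly, presumably because they are aggregated differently during the epoch. Always predicting the training mean gives a useful reference RMSE. A minimal pure-Python sketch of these three quantities (helper names are our own):

```python
import math


def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def neg_rmse(y_true, y_pred):
    """Negative root mean squared error (scikit-learn style: larger is better)."""
    return -math.sqrt(mse(y_true, y_pred))


def mean_baseline_rmse(y_train, y_eval):
    """RMSE obtained by always predicting the mean of the training targets."""
    mean = sum(y_train) / len(y_train)
    return math.sqrt(mse(y_eval, [mean] * len(y_eval)))


y_true = [1.0, 2.0, 3.0]
y_pred = [1.5, 2.0, 2.0]
print(mse(y_true, y_pred), neg_rmse(y_true, y_pred))
print(mean_baseline_rmse(y_true, y_true))
```

Comparing the validation RMSE with `mean_baseline_rmse(train_targets, valid_targets)` tells whether the model does better than a constant prediction.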

### All-in-one Implementation

In the implementation above, we manually loaded the weights and froze the layers.
This forces us to pass an already initialized model to :class:`EEGRegressor`. Since the
pre-trained weights and the frozen flags then live in that single module instance,
re-fitting or cloning the estimator (for example in a cross-validation loop) may not
start from the intended pre-trained state.

Instead, we can implement the same procedure in a more compact and reproducible way
by using skorch's callback system.

Here, we import a callback to freeze layers and define a custom
callback to load the pre-trained weights at the beginning of training:




In [None]:
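from skorch.callbacks import Callback, Freezer


class WeightsLoader(Callback):
    """Load pre-trained weights into the module each time training starts."""

    def __init__(self, url, strict=False):
        self.url = url
        self.strict = strict

    def on_train_begin(self, net, X=None, y=None, **kwargs):
        # Download the checkpoint (or read it from the torch hub cache) onto the CPU
        state_dict = torch.hub.load_state_dict_from_url(url=self.url, map_location="cpu")
        # With strict=False, checkpoint keys unused by this architecture (e.g. the
        # transformer) are ignored and the new downstream layers keep their init
        net.module_.load_state_dict(state_dict, strict=self.strict)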

We can now define a regressor with those callbacks, without having
to pass an initialized model, and fit it as before:




In [None]:
regressors = {}

regressors['SignalJEPA_PreLocal'] = EEGRegressor(
    "SignalJEPA_PreLocal",  # model class name, instantiated by EEGRegressor at fit time
    optimizer=torch.optim.AdamW,
    optimizer__lr=0.005,
    batch_size=64,
    callbacks=[
        'neg_root_mean_squared_error',
        WeightsLoader(
            url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
        ),
        Freezer(patterns="feature_encoder.*"),
    ],
    device=device,
)



In [None]:
for name, reg in regressors.items():
    _ = reg.fit(windows_dataset, y=metadata["target"], epochs=10)

  epoch    train_loss    train_neg_root_mean_squared_error    valid_loss    valid_neg_root_mean_squared_error     dur
-------  ------------  -----------------------------------  ------------  -----------------------------------  ------
      1        [36m2.6664[0m                              [32m-1.5843[0m        [35m2.5012[0m                              [31m-1.5815[0m  4.9509
      2        [36m2.3731[0m                              [32m-1.4917[0m        [35m2.2270[0m                              [31m-1.4923[0m  4.8098
      3        [36m2.1096[0m                              [32m-1.3981[0m        [35m1.9713[0m                              [31m-1.4040[0m  4.2959
      4        [36m1.8351[0m                              [32m-1.3043[0m        [35m1.7328[0m                              [31m-1.3164[0m  5.1132
      5        [36m1.5817[0m                              [32m-1.2122[0m        [35m1.5136[0m                              [31m-1.2303[0m  4
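`Freezer(patterns="feature_encoder.*")` selects parameters by name. To our knowledge, skorch matches string patterns with shell-style wildcards as implemented by Python's `fnmatch` (check the skorch documentation of your installed version), so a pattern can be previewed before training. A minimal sketch; the `feature_encoder` parameter names are hypothetical stand-ins, while the others are the trainable names printed earlier:

```python
from fnmatch import fnmatch


def matching_names(names, pattern):
    """Return the names that match a shell-style wildcard pattern."""
    return [name for name in names if fnmatch(name, pattern)]


param_names = [
    "feature_encoder.1.0.weight",  # hypothetical
    "feature_encoder.1.2.weight",  # hypothetical
    "spatial_conv.1.weight",
    "spatial_conv.1.bias",
    "final_layer.1.weight",
    "final_layer.1.bias",
]
frozen = matching_names(param_names, "feature_encoder.*")
trainable = [n for n in param_names if n not in frozen]
print(trainable)
```

On the real model, one would pass `[n for n, _ in model.named_parameters()]` instead of the toy list; the remaining names should be exactly the four trainable parameters of the manual approach.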

## Conclusion and Next Steps

In this tutorial, we demonstrated how to fine-tune a pre-trained foundation
model, Signal-JEPA, to regress response times on the Contrast Change Detection
task of the 2025 EEG challenge. We now have a basic implementation that can
automatically load pre-trained weights and freeze specific layers.

This setup can easily be extended to explore different fine-tuning techniques,
base foundation models, and downstream tasks.




## References

.. [1] Guetschel, P., Moreau, T., and Tangermann, M. (2024)
       “S-JEPA: towards seamless cross-dataset transfer
       through dynamic spatial attention”. https://arxiv.org/abs/2403.11772

