### Step: Set up general parameters
We'll be using the small version of the model and a subset of the EEG channels to simplify the analysis.

In [1]:
!uv pip install numpy pandas scikit-learn matplotlib tqdm torch moabb pyriemann
!uv pip install git+https://github.com/moment-timeseries-foundation-model/moment.git

TORCH_DEVICE = 'mps' # Apple Silicon GPU
MODEL_NAME = 'AutonLab/MOMENT-1-small'
EEG_CHANNELS = ['Fz', 'C3', 'Cz', 'C4', 'P3', 'Pz', 'P4', 'O1', 'O2'] 
SAMPLING_FREQUENCY = 512

[2mUsing Python 3.10.14 environment at /Users/scastanoc/miniforge3[0m
[2mAudited [1m8 packages[0m [2min 8ms[0m[0m
[2mUsing Python 3.10.14 environment at /Users/scastanoc/miniforge3[0m
[2K[36m[1mUpdating[0m[39m https://github.com/moment-timeseries-foundation-model/moment.git ([2mHEAD[0m)
[2K[1A[36m[1mUpdating[0m[39m https://github.com/moment-timeseries-foundation-model/moment.git ([2mHEAD[0m)
[2K[1A[36m[1mUpdating[0m[39m https://github.com/moment-timeseries-foundation-model/moment.git ([2mHEAD[0m)
[2K[1A[36m[1mUpdating[0m[39m https://github.com/moment-timeseries-foundation-model/moment.git ([2mHEAD[0m)
[2K[1A [32m[1mUpdated[0m[39m https://github.com/moment-timeseries-foundation-model/moment.git ([2mfb62093[
[2K[2mResolved [1m24 packages[0m [2min 580ms[0m[0m                                        [0m
[2mAudited [1m24 packages[0m [2min 0.07ms[0m[0m


### Step: Load the data
We will use one of the sessions of the ERP study published by [Hübner et al. 2017](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0175856) and available via de Mother Of All BCI Benchmarks [MOABB](http://moabb.neurotechx.com/docs/index.html). The experiment in a nutshell: The patients are presented a rapid sequence of different visual stimuli, some of which they are instructed to pay attention to---target stimuli---and some of which they have to ignore---non-target stimuli---. It is expected that the brain response to each class is different.

The dataset contains data for 13 subjects participating in 3 sessions each. It is not the scope of this post to do a throughout benchmark, so we will limit the dataset to one of the sessions of one of the subjects.

In [2]:

import contextlib
import io
import warnings

import numpy as np
from sklearn.preprocessing import LabelEncoder
from moabb import datasets
from moabb.paradigms import P300


dataset = datasets.Huebner2017(interval=[0,.99])
dataset.download()



paradigm = P300(    
    resample=SAMPLING_FREQUENCY,
    baseline=None,
    channels=EEG_CHANNELS
)


stdout = io.StringIO()
with contextlib.redirect_stdout(stdout), warnings.catch_warnings(record=True) as w:
    X, y, metadata = paradigm.get_data(
        dataset=dataset,
        subjects=[1],
        return_epochs=False,
        return_raws=False,
        cache_config=None,
        postprocess_pipeline=None,
    )

# Limit to one session
session = '0'
ids_mask_session = metadata.session == session
y_encoded = LabelEncoder().fit_transform(y)

X = X[ids_mask_session]
y = y_encoded[ids_mask_session]

# Print number of classes and number of samples per class
unique_classes, counts = np.unique(y, return_counts=True)
print(f"Number of classes: {len(unique_classes)}")
for cls, count in zip(unique_classes, counts):
    print(f"Class '{cls}': {count} samples")



Number of classes: 2
Class '0': 3275 samples
Class '1': 1008 samples


### Step: Explore the data
Let's take a look at what the average brain activity is for each target and non target stimuli, for each of the channels. I have highlighted two periods, where we find archetypical responses for this type of stimuli, the so-called N100 and P300: a (N)egative spike around 100ms after stimuli, and a (P)ositive spike around 300ms after stimuli. N100 is to be seen predominantly in electrodes around the visual cortex, i.e., Occital Channels O1 and O2, where as P300 is commonly seen in Central-Parietal, closer to the top of the skull, i.e., Cz and Pz.

In [3]:
import plotly.express as px
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from plotly.subplots import make_subplots


ids_target = y == 1
X_target = X[ids_target]

ids_nontarget = y == 0
X_nontarget = X[ids_nontarget]

X_target_avg = X_target.mean(axis=0)
X_nontarget_avg = X_nontarget.mean(axis=0)

num_channels = X_target_avg.shape[0]


# Create a figure with subplots in a square grid
num_cols = int(np.ceil(np.sqrt(num_channels)))
num_rows = int(np.ceil(num_channels / num_cols))
fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=EEG_CHANNELS, shared_xaxes=True, shared_yaxes=True, vertical_spacing=0.05, horizontal_spacing=0.02)

time_vector: np.ndarray = np.arange(X_target_avg.shape[1])/SAMPLING_FREQUENCY
# Plot each channel in a separate subplot
for i in range(num_channels):
    row: int = i // num_cols + 1
    col: int = i % num_cols + 1
    showlegend: bool = (i == 0)
    fig.add_trace(
        go.Scatter(
            x=time_vector, 
            y=X_target_avg[i], 
            mode='lines', 
            name='Target' if showlegend else None, 
            line=dict(color='red'),
            showlegend=showlegend
        ), 
        row=row, col=col
    )
    fig.add_trace(
        go.Scatter(
            x=time_vector, 
            y=X_nontarget_avg[i], 
            mode='lines', 
            name='NonTarget' if showlegend else None, 
            line=dict(color='blue'),
            showlegend=showlegend
        ), 
        row=row, col=col
    )
    fig.add_vrect(
        x0=60/SAMPLING_FREQUENCY, x1=80/SAMPLING_FREQUENCY, 
        fillcolor="#EFCB66", opacity=0.5, 
        layer="below", line_width=0,
        row=row, col=col
    )
    fig.add_vrect(
        x0=128/SAMPLING_FREQUENCY, x1=256/SAMPLING_FREQUENCY, 
        fillcolor="#90EE90", opacity=0.5, 
        layer="below", line_width=0,
        row=row, col=col
    )
    if row == num_rows:
        fig.update_xaxes(title_text="seconds after stimulus", row=row, col=col)

    if col == 1:
        fig.update_yaxes(title_text="µV", row=row, col=col)

# Update layout
y_axis_range: list[float] = [min(X_nontarget_avg.min(), X_target_avg.min()), max(X_nontarget_avg.max(), X_target_avg.max())]
fig.update_layout(
    height=800, 
    width=1000, 
    showlegend=True
)
for row in range(1, num_rows + 1):
    fig.update_yaxes(range=y_axis_range, row=row, col=1)
fig.show()


### Step: Baseline pipeline
As baseline, we will implemented a pipeline that is commonly used with this type of data. It consist of a *spatial filter*, which means, a linear mix across all channels. The spatial filters are learned from data, such that the difference between classes is maximized. This is the [XDawn](https://ieeexplore.ieee.org/abstract/document/4760273) part of the pipeline. 

The spatial filters are followed by a standard LDA classifier.


In [4]:
from sklearn.metrics import get_scorer
from pyriemann.estimation import Xdawn
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import train_test_split
from mne.decoding import Vectorizer

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)


baseline_pipeline = make_pipeline(
    Xdawn(nfilter=1), 
    Vectorizer(),
    LDA(
        solver='lsqr', 
        shrinkage='auto'
    )
)
# Fit the baseline pipeline
baseline_pipeline.fit(X_train, y_train)

# Predict on the test set
y_pred = baseline_pipeline.predict(X_test)

# Calculate the test score
test_score: float = get_scorer(paradigm.scoring)(baseline_pipeline, X_test, y_test)
print(f"Test ROC-AUC XDAWN+LDA: {test_score}")



Test ROC-AUC XDAWN+LDA: 0.9788115284974094


pas mal.

### Step: Data preparation for the MOMENT torch mdoel

In [5]:

N_INPUT_SAMPLES = 512

from sklearn.preprocessing import LabelEncoder
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split

def prepare_data(X_train, X_test, y_train, y_test):
    # Convert to torch tensors
    X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.long)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long)

    # pad the input to 512
    X_train_tensor = torch.nn.functional.pad(X_train_tensor, (0, N_INPUT_SAMPLES - X_train_tensor.shape[-1]))
    X_test_tensor = torch.nn.functional.pad(X_test_tensor, (0, N_INPUT_SAMPLES - X_test_tensor.shape[-1]))


    input_mask: torch.Tensor = torch.cat([torch.ones(X_train.shape[-1], dtype=torch.bool), torch.zeros(N_INPUT_SAMPLES - X_train.shape[-1], dtype=torch.bool)])
    input_mask = input_mask.unsqueeze(0).repeat(X_train_tensor.shape[0], 1)


    # Create TensorDataset
    train_dataset = TensorDataset(X_train_tensor, y_train_tensor, input_mask[:X_train_tensor.shape[0]])
    test_dataset = TensorDataset(X_test_tensor, y_test_tensor, input_mask[:X_test_tensor.shape[0]])


    # Create DataLoader
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    return train_dataloader, test_dataloader

train_dataloader, test_dataloader = prepare_data(X_train, X_test, y_train, y_test)


Just making sure that the conversion of data to torch tensors went OK

In [6]:
# Get tensors from dataloader
def collect_tensors(dataloader: DataLoader) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    X_list = []
    y_list = []
    mask_list = []
    for X, y, mask in dataloader:
        X_list.append(X)
        y_list.append(y) 
        mask_list.append(mask)
    return torch.cat(X_list), torch.cat(y_list), torch.cat(mask_list)

X_train_tensor, y_train_tensor, train_input_mask = collect_tensors(train_dataloader)
X_test_tensor, y_test_tensor, test_input_mask = collect_tensors(test_dataloader)

# Apply baseline_pipeline to X_train_tensor and X_test_tensor
baseline_pipeline.fit(X_train_tensor.cpu().numpy(), y_train_tensor.cpu().numpy())

# Predict on the test set
y_pred_tensor = baseline_pipeline.predict(X_test_tensor.cpu().numpy())

# Calculate the test score
test_score_tensor: float = get_scorer(paradigm.scoring)(baseline_pipeline, X_test_tensor.cpu().numpy(), y_test_tensor.cpu().numpy())
print(f"Test ROC-AUC XDAWN+LDA: {test_score_tensor}")



Test ROC-AUC XDAWN+LDA: 0.9788147668393783


### Calculate MOMENT embeddings
We calculate the embeddings for each channel individually. It’s important to note that this setup is inherently biased against MOMENT, as it isn’t trained to learn across channels. To address this, we reshape the data so that each channel’s embedding is obtained separately, and then we concatenate these embeddings. Consequently, the baseline pipeline extracts inter-channel information, while MOMENT does not.

In [7]:
from tqdm import tqdm
from momentfm import MOMENTPipeline
def get_embedding(model, dataloader):
    embeddings, labels = [], []
    with torch.no_grad():
        for batch_x, batch_labels, batch_mask in tqdm(dataloader, total=len(dataloader)):
            batch_size, n_channels, n_timesteps = batch_x.shape
            batch_x = batch_x.reshape(batch_size * n_channels, 1, n_timesteps).to(TORCH_DEVICE)
            batch_mask = batch_mask.to(TORCH_DEVICE)
            output = model(x_enc=batch_x) # [batch_size * n_channels x emb_dim]
            embedding = output.embeddings.cpu().reshape(batch_size, -1) # [batch_size x n_channels * emb_dim]
            embeddings.append(embedding)
            labels.append(batch_labels)        

    embeddings, labels = np.concatenate(embeddings), np.concatenate(labels)
    return embeddings, labels

model = MOMENTPipeline.from_pretrained(
    MODEL_NAME, 
    model_kwargs={
        'task_name': 'embedding',
    }
)
model.init()
model.to(TORCH_DEVICE)


train_embeddings, train_labels = get_embedding(model, train_dataloader)
test_embeddings, test_labels = get_embedding(model, test_dataloader)

print(train_embeddings.shape, train_labels.shape)
print(test_embeddings.shape, test_labels.shape)



IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


`torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.


Only reconstruction head is pre-trained. Classification and forecasting heads must be fine-tuned.


torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.


None of the inputs have requires_grad=True. Gradients will be None

100%|██████████| 94/94 [00:17<00:00,  5.29it/s]
100%|██████████| 41/41 [00:07<00:00,  5.49it/s]

(2998, 4608) (2998,)
(1285, 4608) (1285,)





### Step: Downstream classification
We use an LDA and hyperparameter optimized SVM


In [8]:
from momentfm.models.statistical_classifiers import fit_svm

moment_lda_pipeline = make_pipeline(
    LDA(solver='lsqr', shrinkage='auto')
)
moment_lda_pipeline.fit(train_embeddings, train_labels)
y_pred = moment_lda_pipeline.predict(test_embeddings)
test_score = get_scorer(paradigm.scoring)(moment_lda_pipeline, test_embeddings, test_labels)
print(f"Test ROC-AUC MOMENT+LDA: {test_score}")



Test ROC-AUC MOMENT+LDA: 0.609520725388601


In [9]:

moment_svm_pipeline = fit_svm(train_embeddings, train_labels)
y_pred = moment_svm_pipeline.predict(test_embeddings)
test_score = get_scorer(paradigm.scoring)(moment_svm_pipeline, test_embeddings, test_labels)
print(f"Test ROC-AUC MOMENT+SVM: {test_score}")


Test ROC-AUC MOMENT+SVM: 0.549556347150259


Well, that is not very good. I suspect that it has to do with the dimensionality of the input data (number of channels x the number of embedding dimensions). Let me try once again but this time applying PCA before the LDA classifier to reduce the number of input dimensions by 90%.

In [10]:
from sklearn.decomposition import PCA

moment_lda_pipeline = make_pipeline(
    PCA(n_components=int(train_embeddings.shape[-1] * 0.1)),
    LDA(solver='lsqr', shrinkage='auto')
)
moment_lda_pipeline.fit(train_embeddings, train_labels)
y_pred = moment_lda_pipeline.predict(test_embeddings)
test_score = get_scorer(paradigm.scoring)(moment_lda_pipeline, test_embeddings, test_labels)
print(f"Test ROC-AUC MOMENT+PCA+LDA: {test_score}")


Test ROC-AUC MOMENT+PCA+LDA: 0.6235427461139896


It improved, although not as much as I would have expected.

### Step: Dimensionality reduction before embeddings 
I want to try one last thing. Since MOMENT is still not built to merge information across channels, we will help it a bit by spatially filtering the data *before* calculating the embeddings. This should substantially increase the classification score.

In [11]:
x_dawn = Xdawn(nfilter=1)
X_train_filtered = x_dawn.fit_transform(X_train, y_train)
X_test_filtered = x_dawn.transform(X_test)


train_dataloader_filtered, test_dataloader_filtered = prepare_data(X_train_filtered, X_test_filtered, y_train, y_test)

train_embeddings_filtered, train_labels_filtered = get_embedding(model, train_dataloader_filtered)
test_embeddings_filtered, test_labels_filtered = get_embedding(model, test_dataloader_filtered)

moment_lda_pipeline = make_pipeline(
    LDA(solver='lsqr', shrinkage='auto')
)
moment_lda_pipeline.fit(train_embeddings_filtered, train_labels_filtered)
y_pred = moment_lda_pipeline.predict(test_embeddings_filtered)
test_score = get_scorer(paradigm.scoring)(moment_lda_pipeline, test_embeddings_filtered, test_labels_filtered)
print(f"Test ROC-AUC Xdawn+MOMENT+LDA: {test_score}")



torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.


None of the inputs have requires_grad=True. Gradients will be None

100%|██████████| 94/94 [00:04<00:00, 22.03it/s]
100%|██████████| 41/41 [00:01<00:00, 21.73it/s]


Test ROC-AUC Xdawn+MOMENT+LDA: 0.8285330310880828


Aha! the assumption of inter-channel independence in MOMENT appears to be a significant challenge. While we’re still not reaching the baseline’s 0.97 ROC-AUC, we’ve seen a clear improvement from 0.64 to 0.81 by spatially mixing the channels and, thus, reducing the channel count before calculating embeddings.

Remaining considerations:

1.	Cross-subject, cross-sesion: I still wonder who well these embeddings work if we merge the data across different subjects. This is still a challenging topic in the BCI community. In our case, I suspect that adding cross subject data could even improve the performance, since it seems that we are still under the curse of dimensionality.
2.	Generalizability to other experiments: It would be interesting to test if this pipeline can adapt to different types of experiments. For instance, in scenarios where discriminative features are not specific amplitude peaks (like N100 or P300) but other characteristics—such as power increases within a specific frequency band over longer periods of time, as seen in Motor Imagery experiments. This is actually one of the main selling points of foundation models: decent zero-shot performance across tasks, so it would make sense to try this one out.

I plan to explore this further in the future, but I’ll leave it here for now.