feat: add MVPFormer foundation model #1057
MVPFormer (Carzaniga et al., 2025, arXiv:2506.20354): a decoder-only
foundation model with multi-variate parallel attention (MVPA) for
heterogeneous multi-variate iEEG, decomposing self-attention into content,
time-relative and channel-relative terms over a (segment, channel) token grid.
- braindecode.models.MVPFormer following EEGModuleMixin conventions (signal
params, final_layer head, return_features, config round-trip, full
registration + summary.csv badges).
- db4 wavelet signal encoder computed from first principles in
braindecode.functional (daubechies_filters / wavelet_decomposition); no new
runtime dependency, bit-identical to pywt periodic wavedec.
- Temporary state-dict converter (_mvpformer_convert) to load the authors'
released SWEC ("GeNIE") base + LoRA checkpoints; to be removed once the
weights are re-hosted in the braindecode-native layout.
- Bit-exact equivalence tests vs a transcribed reference oracle (both attention
modes) and CPU loading of the real MVPFormer-S checkpoints (env-gated).
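The "strided, circular" view of a periodic DWT mentioned in the wavelet bullet can be sketched in a few lines. This is a minimal pure-Python illustration, not the code in braindecode.functional: it uses the 4-tap Daubechies-2 filter (whose closed form is short) instead of db4, the helper name `periodic_dwt_level` is hypothetical, and it makes no claim about matching pywt's coefficient ordering or phase. It only shows that wrap-around filtering followed by a stride of 2 halves the length and preserves signal energy.

```python
import math

# Daubechies-2 (4-tap) low-pass filter in closed form; db4 would have 8 taps.
_S3 = math.sqrt(3.0)
LO = [c / (4 * math.sqrt(2.0)) for c in (1 + _S3, 3 + _S3, 3 - _S3, 1 - _S3)]
# Quadrature-mirror high-pass filter: reversed low-pass with alternating signs.
HI = [(-1) ** k * LO[len(LO) - 1 - k] for k in range(len(LO))]


def periodic_dwt_level(x):
    """One DWT level: circular (wrap-around) correlation, then stride 2."""
    n = len(x)
    if n % 2:
        raise ValueError("this sketch only handles even-length inputs")
    approx, detail = [], []
    for start in range(0, n, 2):  # a stride of 2 halves the output length
        # Indexing modulo n plays the role of circular padding.
        window = [x[(start + j) % n] for j in range(len(LO))]
        approx.append(sum(h * v for h, v in zip(LO, window)))
        detail.append(sum(g * v for g, v in zip(HI, window)))
    return approx, detail


signal = [0.0, 1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0]
approx, detail = periodic_dwt_level(signal)
energy_in = sum(v * v for v in signal)
energy_out = sum(v * v for v in approx) + sum(v * v for v in detail)

# A constant signal has no detail content under this filter pair.
_, flat_detail = periodic_dwt_level([3.0] * 8)
```

Because the filter pair is orthonormal, `energy_in` and `energy_out` agree up to rounding, which is a cheap sanity check for any periodic DWT implementation.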
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 30366fc4a4
Pull request overview
Adds the MVPFormer foundation model to braindecode.models, including a db4 wavelet-based tokenizer/encoder, a temporary upstream-checkpoint conversion/LoRA-merge shim, and unit tests validating numerical equivalence and checkpoint loading.
Changes:
- Introduce braindecode.models.MVPFormer (MVPA decoder-only transformer) plus temporary checkpoint conversion utilities.
- Add wavelet primitives (daubechies_filters, wavelet_decomposition, dwt_max_level) to braindecode.functional.
- Add unit tests (model equivalence, LoRA merge, optional real-checkpoint loading) and register/document the new model.
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| test/unit_tests/models/test_mvpformer.py | New unit tests for MVPFormer equivalence, LoRA merge, and optional real checkpoint loading. |
| test/unit_tests/models/test_integration.py | Excludes MVPFormer from TorchScript/export test list. |
| test/unit_tests/models/test_functional.py | Adds tests for db wavelet filters and wavelet decomposition vs optional reference libs. |
| test/unit_tests/models/_mvpformer_reference.py | Adds a CPU-only upstream transcription oracle used for equivalence testing. |
| docs/whats_new.rst | Documents the addition of MVPFormer. |
| braindecode/models/util.py | Adds MVPFormer to model parameter metadata/defaults. |
| braindecode/models/summary.csv | Adds MVPFormer to the model summary table. |
| braindecode/models/mvpformer.py | New MVPFormer implementation (wavelet encoder + MVPA blocks + head). |
| braindecode/models/_mvpformer_convert.py | Temporary upstream checkpoint key remapping + LoRA merge utility. |
| braindecode/models/__init__.py | Exports MVPFormer from the models package. |
| braindecode/functional/functions.py | Adds Daubechies filter generation + periodic DWT decomposition helpers. |
| braindecode/functional/__init__.py | Exposes the new functional wavelet APIs. |
- Remove the dedicated test suite (test_mvpformer.py) and the transcribed reference oracle (_mvpformer_reference.py). The model stays covered by the generic models integration suite; the wavelet functions by test_functional.
- Remove the temporary released-weight converter (_mvpformer_convert.py) from the tree; it is kept as a PR comment until the authors re-host the weights.
- Drop the released-weight-loading claim from whats_new and the model docstring to match.
- The architecture is transcribed from the authors' Apache-2.0 reference implementation (Copyright IBM Corp. 2024-2025), so mvpformer.py is marked Apache-2.0 following the braindecode luna.py convention. The independent db4-wavelet functions in braindecode.functional remain BSD-3.
Temporary released-weight converter (not merged)
This converter was removed from the PR to keep the merged code to the model
It is a deliberate stop-gap: once the authors re-host the weights in a
# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD-3
#
# ============================================================================
# TEMPORARY -- DELETE WHEN UPSTREAM RE-HOSTS WEIGHTS
# ----------------------------------------------------------------------------
# Converts the *currently released* MVPFormer ("GeNIE") checkpoints (raw
# PyTorch state dicts, IBM Box) onto braindecode's `MVPFormer` parameter names.
# Once the authors re-publish braindecode-native weights, remove this whole
# file (it is the only place that knows the upstream key scheme).
#
# Observed in the released Sept/Oct-2024 checkpoints (genie-{s,m}-{base,swec}.pt):
# * backbone prefix is ``genie.`` (the public code later renamed it
# ``mvpformer.``; both are accepted here);
# * each attention block carries vestigial ``time_bias`` / ``masked_bias``
# buffers and a dead ``ln_2`` -- all dropped;
# * the ``base`` head ``head.head`` is the (d_model, d_model) *generative*
# projection, not a classifier -- dropped (attach a fresh head / reset_head);
# * ``swec`` files contain only LoRA deltas + the real classification head
# (no backbone), so a seizure classifier = base backbone + merged LoRA +
# swec head.
# ============================================================================
from __future__ import annotations

# ponytail: keep all upstream-key knowledge in this one disposable file.
_BACKBONE_PREFIXES = ("genie.", "mvpformer.")


def _strip_backbone_prefix(key: str) -> str:
    for prefix in _BACKBONE_PREFIXES:
        if key.startswith(prefix):
            return key[len(prefix) :]
    return key


def _is_dropped(key: str) -> bool:
    """Upstream keys with no braindecode counterpart."""
    return (
        key == "seizure_embeddings"  # generative-only
        or ".ln_2." in key  # instantiated upstream but never used
        or key.endswith(".time_bias")  # vestigial attention buffer
        or key.endswith(".masked_bias")  # vestigial attention buffer
        or ".lora_" in key  # LoRA deltas (merged separately for swec)
        or key.startswith("head.")  # generative projection; classifier via swec
    )


def _map_backbone_key(key: str) -> str | None:
    """Map one upstream *backbone* key to its braindecode name (or ``None``)."""
    if _is_dropped(key):
        return None
    if key.startswith("encoder."):
        return "patch_embed." + key[len("encoder.") :]
    rest = _strip_backbone_prefix(key)  # genie./mvpformer. -> ""
    if rest.startswith("h."):
        return "blocks." + rest[len("h.") :]
    return rest  # ln_f.weight, positional_embedding.weight, channel_embedding.weight


def convert_mvpformer_state_dict(state_dict: dict, kind: str = "base") -> dict:
    """Convert a raw upstream MVPFormer/GeNIE checkpoint to ``MVPFormer`` keys.

    .. warning::
        Temporary shim for the currently released checkpoints; see the module
        banner. It will be removed once braindecode-native weights are
        published.

    Parameters
    ----------
    state_dict : dict
        Raw ``torch.load`` of an upstream ``*-base.pt`` checkpoint.
    kind : {"base"}
        Only ``"base"`` (the generatively pre-trained, LoRA-free backbone) is
        supported here. A full ``swec`` seizure classifier is built with
        :func:`merge_swec_checkpoint` (base backbone + LoRA + swec head).

    Returns
    -------
    dict
        State dict with ``MVPFormer`` parameter names. Load with
        ``model.load_state_dict(converted, strict=False)``; the fresh
        classification head appears as the only missing key (expected for a
        backbone) -- call :meth:`MVPFormer.reset_head` or train it.
    """
    if kind != "base":
        raise ValueError(
            f"convert_mvpformer_state_dict only handles kind='base', got {kind!r}. "
            "Use merge_swec_checkpoint for swec classifiers."
        )
    converted = {}
    for key, value in state_dict.items():
        new_key = _map_backbone_key(key)
        if new_key is not None:
            converted[new_key] = value
    return converted


def merge_swec_checkpoint(
    base_state_dict: dict,
    swec_state_dict: dict,
    lora_alpha: int = 16,
    lora_rank: int = 8,
) -> dict:
    """Build a full MVPFormer seizure classifier from a base backbone + swec.

    .. warning::
        Temporary shim for the currently released checkpoints; see the module
        banner.

    The released ``swec`` files contain only LoRA deltas (on ``q_attn`` and
    ``c_attn``, as plain ``loralib.Linear`` adapters -- verified against
    ``loralib``) plus the classification head. This merges
    ``W <- W + (lora_alpha / lora_rank) * (lora_B @ lora_A)`` into the backbone
    and installs the swec head as ``final_layer``.

    Parameters
    ----------
    base_state_dict : dict
        Raw upstream ``*-base.pt`` checkpoint (the backbone).
    swec_state_dict : dict
        Raw upstream ``*-swec.pt`` checkpoint (LoRA deltas + head).
    lora_alpha, lora_rank : int
        LoRA scaling parameters (released config: 16 and 8 -> scaling 2).

    Returns
    -------
    dict
        Full classifier state dict in ``MVPFormer`` naming;
        ``model.load_state_dict(merged)`` should match exactly (use a model with
        ``n_outputs`` equal to the swec head, e.g. 2 for seizure detection).
    """
    converted = convert_mvpformer_state_dict(base_state_dict, kind="base")
    scaling = lora_alpha / lora_rank
    for key in swec_state_dict:
        if not key.endswith(".lora_A"):
            continue
        stem = key[: -len(".lora_A")]  # e.g. genie.h.0.attn.q_attn
        lora_A = swec_state_dict[key]
        lora_B = swec_state_dict[stem + ".lora_B"]
        target = _map_backbone_key(stem + ".weight")  # blocks.{i}.attn.*_attn.weight
        if target is None or target not in converted:
            raise KeyError(f"LoRA target for {key!r} ({target}) not in backbone.")
        converted[target] = converted[target] + scaling * (lora_B @ lora_A)
    head = swec_state_dict.get("head.head.weight")
    if head is not None:
        converted["final_layer.weight"] = head
    return converted
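The merge rule in the docstring above, `W <- W + (lora_alpha / lora_rank) * (lora_B @ lora_A)`, can be checked by hand on a toy matrix. The sketch below is pure Python with nested lists rather than tensors; `matmul` and `merge_lora` are hypothetical helper names, and the toy values (`lora_alpha=4`, `lora_rank=2`) are chosen only so that the scaling equals 2, the same factor as the released 16 / 8 configuration.

```python
def matmul(a, b):
    """Multiply two matrices stored as nested lists (rows of floats)."""
    inner = len(b)
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(weight, lora_a, lora_b, lora_alpha, lora_rank):
    """Return weight + (lora_alpha / lora_rank) * (lora_b @ lora_a)."""
    scaling = lora_alpha / lora_rank
    delta = matmul(lora_b, lora_a)  # same shape as weight: (out, in)
    return [
        [w + scaling * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Toy shapes: weight is (out=2, in=3); lora_a is (rank=2, in=3); lora_b is (out=2, rank=2).
weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
lora_a = [[1.0, 0.0, 1.0], [0.0, 2.0, 0.0]]
lora_b = [[0.5, 0.0], [0.0, -1.0]]

merged = merge_lora(weight, lora_a, lora_b, lora_alpha=4, lora_rank=2)
# lora_b @ lora_a = [[0.5, 0, 0.5], [0, -2, 0]]; scaled by 2 and added to weight.
```

Note that the scaling uses the configured rank, not the shape of the adapter matrices, which is why `merge_swec_checkpoint` takes `lora_rank` as an explicit argument.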
From the automated PR review (Codex/Copilot) on #1057:
- wavelet_decomposition: build the conv weight on the input device (not just its dtype) so the functional API works on CUDA/MPS. Validated on MPS.
- _rel_shift_chan: construct index tensors on x.device with long dtype for GPU-safe advanced indexing.
- _WaveletPatchEmbed: project in the projection-weight dtype (robust across AMP / float16 inputs); identical on the float32 path.
- test_daubechies_filters_match_pywt: gate on pywt (not ptwt), bound via importorskip.
- test_integration: correct the stale "ptwt.wavedec" TorchScript skip comment.
All changes are numerically neutral on the float32 CPU path; an MPS forward matches CPU to ~2e-8.
- ruff format (pinned v0.14.9): wrap the long dec_hi line in daubechies_filters (the failing pre-commit hook).
- reset_head: re-apply _init_weights so a reset head matches fresh construction (std=0.02) instead of default Linear init.
- test_wavelet_decomposition_matches_pywt: importorskip both ptwt and pywt so the test skips (never errors) when either is missing.
- License header: add the greppable `# License: Apache-2.0` comment and drop the awkward in-docstring sentence; the model stays Apache-2.0 (transcribed from the authors' Apache reference).
- Docstring: stop claiming sub-quadratic cost, since the pure-PyTorch path materialises the full content-attention matrix (the windowed FlashMVPA kernel is not ported).
PR description updated to the slimmed scope (no in-tree converter/oracle/test_mvpformer; converter preserved as a PR comment).
einops.repeat returns a memory-sharing view when it expands new axes, so the in-place write window_mask[-n_channels:] = 1 raised a RuntimeError (more than one element refers to a single memory location). The error triggers whenever the global attention mask path runs with a single segment (n_times == segment, i.e. exactly one window). Fix: clone() before the in-place write; values are unchanged.
End-to-end validation + FNUSA replication
Validated the model on real data and replicated the paper's FNUSA result.
Bug fixed (pushed in this branch)
It triggers whenever the default
Replication: FNUSA iEEG (Nejedly 2020), MVPFormer-S
Used the released
Matches within Δ=0.05, fully explained by test prevalence (25.6% → all-positive F1 = 0.41). Note that the FNUSA F1 is the predict-all-positive prevalence baseline: the paper reports every model (incl. Brant-2) collapsing to
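The prevalence baseline quoted above follows from the definition of F1. A classifier that labels every sample positive has recall 1 and precision equal to the positive prevalence p, so F1 = 2p / (1 + p). The short check below uses a hypothetical helper name, `all_positive_f1`, and only the 25.6% prevalence stated in the comment.

```python
def all_positive_f1(prevalence):
    """F1 score of a classifier that predicts the positive class for every sample."""
    precision = prevalence  # all predictions are positive, so precision = prevalence
    recall = 1.0  # no true positive is ever missed
    return 2 * precision * recall / (precision + recall)


baseline = all_positive_f1(0.256)
print(round(baseline, 2))  # 0.41
```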
…nto MVPFormer
# Conflicts:
# braindecode/functional/functions.py
# braindecode/models/summary.csv
# docs/whats_new.rst
# test/unit_tests/models/test_functional.py
# test/unit_tests/models/test_integration.py
global_att : bool
    Whether to use the global content-attention term.
max_segments : int
n = n_vanishing
poly = np.array([math.comb(n - 1 + k, k) for k in range(n)], dtype=float)
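The coefficients built in this snippet, C(n - 1 + k, k) for k = 0..n-1, are those of the polynomial P(y) used in Daubechies' construction, which satisfies the identity (1 - y)^n P(y) + y^n P(1 - y) = 1. The sketch below checks that identity numerically in pure Python; the helper names are hypothetical, and it stops well short of the spectral factorization that turns P into filter taps.

```python
import math


def daubechies_poly(n_vanishing):
    """Coefficients of P(y) = sum_k C(n - 1 + k, k) * y**k, lowest degree first."""
    n = n_vanishing
    return [math.comb(n - 1 + k, k) for k in range(n)]


def evaluate(coeffs, y):
    """Evaluate a polynomial given lowest-degree-first coefficients."""
    return sum(c * y**k for k, c in enumerate(coeffs))


def bezout_residual(n_vanishing, y):
    """(1 - y)^n P(y) + y^n P(1 - y) - 1, which should vanish for every y."""
    n = n_vanishing
    coeffs = daubechies_poly(n)
    return (1 - y) ** n * evaluate(coeffs, y) + y**n * evaluate(coeffs, 1 - y) - 1.0


coeffs_db4 = daubechies_poly(4)
print(coeffs_db4)  # [1, 4, 10, 20]
residuals = [bezout_residual(4, y) for y in (0.0, 0.2, 0.5, 0.9)]
```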
def test_wavelet_decomposition_matches_pywt():
    """wavelet_decomposition is bit-identical to pywt/ptwt wavedec(mode='periodic')
    across sizes (skipped if the reference library is unavailable)."""
global_att : bool
    Whether to add the global content-attention term.
local_window : int
idxes = torch.triu_indices(chan_size, chan_size, offset=1, device=device)
shifting_idxes = torch.zeros(
    chan_size, chan_size, dtype=torch.long, device=device
)
shifting_idxes[..., idxes[0], idxes[1]] = upper_val
window = self.local_window
ones = torch.ones((n_segments, n_segments), device=query.device, dtype=bool)
window_mask = torch.logical_and(
    torch.tril(ones, diagonal=window),
    torch.triu(ones, diagonal=-window),
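The tril/triu combination in this snippet yields a band mask: segment i may attend to segment j only when |i - j| <= window. The pure-Python sketch below mirrors that logic with nested lists so the equivalence is easy to inspect; the helper names are hypothetical and it does not reproduce the rest of the model's masking.

```python
def tril(n, diagonal):
    """Boolean n x n matrix keeping entries on or below the given diagonal."""
    return [[(j - i) <= diagonal for j in range(n)] for i in range(n)]


def triu(n, diagonal):
    """Boolean n x n matrix keeping entries on or above the given diagonal."""
    return [[(j - i) >= diagonal for j in range(n)] for i in range(n)]


def band_mask(n_segments, window):
    """Elementwise AND of tril(+window) and triu(-window)."""
    lower = tril(n_segments, window)
    upper = triu(n_segments, -window)
    return [
        [a and b for a, b in zip(lower_row, upper_row)]
        for lower_row, upper_row in zip(lower, upper)
    ]


mask = band_mask(n_segments=4, window=1)
for row in mask:
    print(["x" if allowed else "." for allowed in row])
```

With `window=1` each row allows the segment itself and its immediate neighbours on either side.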
if chan_size > 1:
    upper_val = torch.cat(
        [
            torch.arange(1, chan_size - i, dtype=torch.long, device=device)
            for i in range(chan_size - 1)
filter_len = filters.shape[-1]
if n_levels is None:
    n_levels = dwt_max_level(x.shape[-1], filter_len)
leading = x.shape[:-1]
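When n_levels is None, the snippet falls back to dwt_max_level. PyWavelets defines the maximum useful level as floor(log2(data_len / (filter_len - 1))). The sketch below implements that convention in integer arithmetic; whether braindecode's dwt_max_level follows exactly the same rule is an assumption here, and `dwt_max_level_sketch` is a hypothetical name.

```python
def dwt_max_level_sketch(data_len, filter_len):
    """floor(log2(data_len / (filter_len - 1))), or 0 when the signal is too short.

    Assumption: this mirrors the PyWavelets convention, not necessarily the
    helper added in this PR.
    """
    if filter_len < 2 or data_len < filter_len - 1:
        return 0
    # floor(log2(x)) for x >= 1 equals bit_length(floor(x)) - 1.
    return (data_len // (filter_len - 1)).bit_length() - 1


# db4 has 8 filter taps, so each level needs at least 7 samples of support.
levels_1024 = dwt_max_level_sketch(1024, 8)
print(levels_1024)  # 7
```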
Summary
Adds MVPFormer (Carzaniga et al., 2025, arXiv:2506.20354), "A foundation model with multi-variate parallel attention to generate neuronal activity", to braindecode.models.
MVPFormer is a decoder-only (Llama2-style) foundation model for heterogeneous multi-variate iEEG. Its multi-variate parallel attention (MVPA) decomposes self-attention over a (segment, channel) token grid into three additive terms (content, time-relative and channel-relative), so temporal and spatial structure are modelled jointly at the attention level. The raw signal is tokenized channel-wise into continuous db4 wavelet embeddings.
What's included
- braindecode.models.MVPFormer: full architecture following braindecode conventions: canonical signal params, final_layer head + reset_head, return_features, config/HF round-trip, registration + a summary.csv row with Foundation Model / Attention/Transformer badges. MEDFormer-style docstring (Architecture Overview / Macro Components / Temporal–Spatial–Spectral Encoding / Additional Mechanisms), einops throughout with explicit axis names. Runs on CPU and accelerators (validated on MPS).
- daubechies_filters and wavelet_decomposition added to braindecode.functional, computing the Daubechies-4 filters from first principles (numpy spectral factorization) and running the periodic DWT as a strided, circular-padded conv1d. Verified bit-identical to pywt/ptwt periodic wavedec over lengths 1–5200.
Licensing
The architecture is transcribed from the authors' reference implementation (Copyright IBM Corp. 2024-2025), which is Apache-2.0, so braindecode/models/mvpformer.py is marked Apache-2.0 (following the luna.py precedent). The independently-implemented db4-wavelet functions in braindecode.functional remain BSD-3.
Released-weight loading
A temporary converter for the authors' currently-released SWEC ("GeNIE") base + LoRA checkpoints is not merged. It is preserved as a PR comment and will be added in a braindecode-native form once the authors re-host the weights.
Test plan
- test_functional.py (+2): daubechies_filters and wavelet_decomposition vs pywt/ptwt (bound via importorskip, so they stay test-only and never a runtime dep).
- test_integration.py model suite passes for MVPFormer (forward/backward, compile, summary/completeness, docstring); registered with the usual skip-list entry for TorchScript.
- ruff + pre-commit clean.
Notes for reviewers
- The FlashMVPA kernel and the contrastive generative pre-training loop are not ported: this PR provides the architecture and classification path, pure-PyTorch and CPU/accelerator-runnable. Consequently the content-attention path materialises the full token×token matrix (memory quadratic in tokens); the windowed-kernel optimisation is the un-ported FlashMVPA.
- pywt/ptwt are optional test oracles only.
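To make the "memory quadratic in tokens" note concrete, the back-of-the-envelope sketch below counts the bytes of one dense content-attention matrix over a (segment, channel) token grid. The function name, the example grid sizes, and the float32 default are illustrative assumptions, not values taken from the model's configuration.

```python
def content_attention_bytes(
    n_segments, n_channels, n_heads=1, batch_size=1, bytes_per_element=4
):
    """Bytes needed to materialise a dense token x token attention matrix."""
    n_tokens = n_segments * n_channels  # one token per (segment, channel) cell
    return batch_size * n_heads * n_tokens * n_tokens * bytes_per_element


small = content_attention_bytes(n_segments=10, n_channels=64)
large = content_attention_bytes(n_segments=100, n_channels=64)
print(small, large, large // small)  # 1638400 163840000 100
```

Ten times more segments costs one hundred times more memory per head, which is the cost a windowed kernel such as the un-ported FlashMVPA is meant to avoid.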