Skip to content

feat: add MVPFormer foundation model#1057

Merged
bruAristimunha merged 9 commits into
masterfrom
MVPFormer
Jun 30, 2026
Merged

feat: add MVPFormer foundation model#1057
bruAristimunha merged 9 commits into
masterfrom
MVPFormer

Conversation

@bruAristimunha

@bruAristimunha bruAristimunha commented Jun 18, 2026

Copy link
Copy Markdown
Collaborator

Summary

Adds MVPFormer (Carzaniga et al., 2025, arXiv:2506.20354) — "A foundation model with multi-variate parallel attention to generate neuronal activity" — to braindecode.models.

MVPFormer is a decoder-only (Llama2-style) foundation model for heterogeneous multi-variate iEEG. Its multi-variate parallel attention (MVPA) decomposes self-attention over a (segment, channel) token grid into three additive terms — content, time-relative and channel-relative — so temporal and spatial structure are modelled jointly at the attention level. The raw signal is tokenized channel-wise into continuous db4 wavelet embeddings.

What's included

  • braindecode.models.MVPFormer — full architecture following braindecode conventions: canonical signal params, final_layer head + reset_head, return_features, config/HF round-trip, registration + a summary.csv row with Foundation Model / Attention/Transformer badges. MEDFormer-style docstring (Architecture Overview / Macro Components / Temporal–Spatial–Spectral Encoding / Additional Mechanisms), einops throughout with explicit axis names. Runs on CPU and accelerators (validated on MPS).
  • db4 wavelet encoder with no new dependencydaubechies_filters and wavelet_decomposition added to braindecode.functional, computing the Daubechies-4 filters from first principles (numpy spectral factorization) and running the periodic DWT as a strided, circular-padded conv1d. Verified bit-identical to pywt / ptwt periodic wavedec over lengths 1–5200.

Licensing

The architecture is transcribed from the authors' reference implementation (Copyright IBM Corp. 2024-2025), which is Apache-2.0, so braindecode/models/mvpformer.py is marked Apache-2.0 (following the luna.py precedent). The independently-implemented db4-wavelet functions in braindecode.functional remain BSD-3.

Released-weight loading

A temporary converter for the authors' currently-released SWEC ("GeNIE") base + LoRA checkpoints is not merged — it is preserved as a PR comment and will be added in a braindecode-native form once the authors re-host the weights.

Test plan

  • test_functional.py (+2): daubechies_filters and wavelet_decomposition vs pywt/ptwt (bound via importorskip, so they stay test-only and never a runtime dep).
  • Generic test_integration.py model suite passes for MVPFormer (forward/backward, compile, summary/completeness, docstring); registered with the usual skip-list entry for TorchScript.
  • ruff + pre-commit clean.

Notes for reviewers

  • The Triton FlashMVPA kernel and the contrastive generative pre-training loop are not ported — this PR provides the architecture and classification path, pure-PyTorch and CPU/accelerator-runnable. Consequently the content-attention path materialises the full token×token matrix (memory quadratic in tokens); the windowed-kernel optimisation is the un-ported FlashMVPA.
  • No new runtime dependency: the wavelet transform is reimplemented from numpy; pywt / ptwt are optional test oracles only.

MVPFormer (Carzaniga et al., 2025, arXiv:2506.20354): a decoder-only
foundation model with multi-variate parallel attention (MVPA) for
heterogeneous multi-variate iEEG, decomposing self-attention into content,
time-relative and channel-relative terms over a (segment, channel) token grid.

- braindecode.models.MVPFormer following EEGModuleMixin conventions (signal
  params, final_layer head, return_features, config round-trip, full
  registration + summary.csv badges).
- db4 wavelet signal encoder computed from first principles in
  braindecode.functional (daubechies_filters / wavelet_decomposition); no new
  runtime dependency, bit-identical to pywt periodic wavedec.
- Temporary state-dict converter (_mvpformer_convert) to load the authors'
  released SWEC ("GeNIE") base + LoRA checkpoints; to be removed once the
  weights are re-hosted in the braindecode-native layout.
- Bit-exact equivalence tests vs a transcribed reference oracle (both attention
  modes) and CPU loading of the real MVPFormer-S checkpoints (env-gated).
Copilot AI review requested due to automatic review settings June 18, 2026 22:04

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 30366fc4a4

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread braindecode/functional/functions.py Outdated

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds the MVPFormer foundation model to braindecode.models, including a db4 wavelet-based tokenizer/encoder, a temporary upstream-checkpoint conversion/LoRA-merge shim, and unit tests validating numerical equivalence and checkpoint loading.

Changes:

  • Introduce braindecode.models.MVPFormer (MVPA decoder-only transformer) plus temporary checkpoint conversion utilities.
  • Add wavelet primitives (daubechies_filters, wavelet_decomposition, dwt_max_level) to braindecode.functional.
  • Add unit tests (model equivalence, LoRA merge, optional real-checkpoint loading) and register/document the new model.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
test/unit_tests/models/test_mvpformer.py New unit tests for MVPFormer equivalence, LoRA merge, and optional real checkpoint loading.
test/unit_tests/models/test_integration.py Excludes MVPFormer from TorchScript/export test list.
test/unit_tests/models/test_functional.py Adds tests for db wavelet filters and wavelet decomposition vs optional reference libs.
test/unit_tests/models/_mvpformer_reference.py Adds a CPU-only upstream transcription oracle used for equivalence testing.
docs/whats_new.rst Documents the addition of MVPFormer.
braindecode/models/util.py Adds MVPFormer to model parameter metadata/defaults.
braindecode/models/summary.csv Adds MVPFormer to the model summary table.
braindecode/models/mvpformer.py New MVPFormer implementation (wavelet encoder + MVPA blocks + head).
braindecode/models/_mvpformer_convert.py Temporary upstream checkpoint key remapping + LoRA merge utility.
braindecode/models/init.py Exports MVPFormer from the models package.
braindecode/functional/functions.py Adds Daubechies filter generation + periodic DWT decomposition helpers.
braindecode/functional/init.py Exposes the new functional wavelet APIs.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread braindecode/functional/functions.py
Comment thread braindecode/models/mvpformer.py Outdated
Comment thread braindecode/models/mvpformer.py Outdated
Comment thread braindecode/models/mvpformer.py
Comment thread test/unit_tests/models/test_integration.py Outdated
Comment thread test/unit_tests/models/test_functional.py
Comment thread braindecode/models/_mvpformer_convert.py Outdated
- Remove the dedicated test suite (test_mvpformer.py) and the transcribed
  reference oracle (_mvpformer_reference.py). The model stays covered by the
  generic models integration suite; the wavelet functions by test_functional.
- Remove the temporary released-weight converter (_mvpformer_convert.py) from
  the tree; it is kept as a PR comment until the authors re-host the weights.
- Drop the released-weight-loading claim from whats_new and the model docstring
  to match.
- The architecture is transcribed from the authors' Apache-2.0 reference
  implementation (Copyright IBM Corp. 2024-2025), so mvpformer.py is marked
  Apache-2.0 following the braindecode luna.py convention. The independent
  db4-wavelet functions in braindecode.functional remain BSD-3.
@bruAristimunha

Copy link
Copy Markdown
Collaborator Author

Temporary released-weight converter (not merged)

This converter was removed from the PR to keep the merged code to the model
architecture only. It is preserved here for anyone who needs to load the
currently released MVPFormer / GeNIE checkpoints (genie-{s,m}-{base,swec}.pt,
raw PyTorch state dicts) onto braindecode.models.MVPFormer.

It is a deliberate stop-gap: once the authors re-host the weights in a
braindecode-native layout, this becomes unnecessary. Until then, drop it in as
braindecode/models/_mvpformer_convert.py (or a local script):

# Authors: Bruno Aristimunha <b.aristimunha@gmail.com>
#
# License: BSD-3
#
# ============================================================================
# TEMPORARY -- DELETE WHEN UPSTREAM RE-HOSTS WEIGHTS
# ----------------------------------------------------------------------------
# Converts the *currently released* MVPFormer ("GeNIE") checkpoints (raw
# PyTorch state dicts, IBM Box) onto braindecode's `MVPFormer` parameter names.
# Once the authors re-publish braindecode-native weights, remove this whole
# file (it is the only place that knows the upstream key scheme).
#
# Observed in the released Sept/Oct-2024 checkpoints (genie-{s,m}-{base,swec}.pt):
#   * backbone prefix is ``genie.`` (the public code later renamed it
#     ``mvpformer.``; both are accepted here);
#   * each attention block carries vestigial ``time_bias`` / ``masked_bias``
#     buffers and a dead ``ln_2`` -- all dropped;
#   * the ``base`` head ``head.head`` is the (d_model, d_model) *generative*
#     projection, not a classifier -- dropped (attach a fresh head / reset_head);
#   * ``swec`` files contain only LoRA deltas + the real classification head
#     (no backbone), so a seizure classifier = base backbone + merged LoRA +
#     swec head.
# ============================================================================

from __future__ import annotations


# ponytail: keep all upstream-key knowledge in this one disposable file.

_BACKBONE_PREFIXES = ("genie.", "mvpformer.")


def _strip_backbone_prefix(key: str) -> str:
    for prefix in _BACKBONE_PREFIXES:
        if key.startswith(prefix):
            return key[len(prefix) :]
    return key


def _is_dropped(key: str) -> bool:
    """Upstream keys with no braindecode counterpart."""
    return (
        key == "seizure_embeddings"  # generative-only
        or ".ln_2." in key  # instantiated upstream but never used
        or key.endswith(".time_bias")  # vestigial attention buffer
        or key.endswith(".masked_bias")  # vestigial attention buffer
        or ".lora_" in key  # LoRA deltas (merged separately for swec)
        or key.startswith("head.")  # generative projection; classifier via swec
    )


def _map_backbone_key(key: str) -> str | None:
    """Map one upstream *backbone* key to its braindecode name (or ``None``)."""
    if _is_dropped(key):
        return None
    if key.startswith("encoder."):
        return "patch_embed." + key[len("encoder.") :]
    rest = _strip_backbone_prefix(key)  # genie./mvpformer. -> ""
    if rest.startswith("h."):
        return "blocks." + rest[len("h.") :]
    return rest  # ln_f.weight, positional_embedding.weight, channel_embedding.weight


def convert_mvpformer_state_dict(state_dict: dict, kind: str = "base") -> dict:
    """Convert a raw upstream MVPFormer/GeNIE checkpoint to ``MVPFormer`` keys.

    .. warning::
        Temporary shim for the currently released checkpoints; see the module
        banner. It will be removed once braindecode-native weights are
        published.

    Parameters
    ----------
    state_dict : dict
        Raw ``torch.load`` of an upstream ``*-base.pt`` checkpoint.
    kind : {"base"}
        Only ``"base"`` (the generatively pre-trained, LoRA-free backbone) is
        supported here. A full ``swec`` seizure classifier is built with
        :func:`merge_swec_checkpoint` (base backbone + LoRA + swec head).

    Returns
    -------
    dict
        State dict with ``MVPFormer`` parameter names. Load with
        ``model.load_state_dict(converted, strict=False)``; the fresh
        classification head appears as the only missing key (expected for a
        backbone) -- call :meth:`MVPFormer.reset_head` or train it.
    """
    if kind != "base":
        raise ValueError(
            f"convert_mvpformer_state_dict only handles kind='base', got {kind!r}. "
            "Use merge_swec_checkpoint for swec classifiers."
        )
    converted = {}
    for key, value in state_dict.items():
        new_key = _map_backbone_key(key)
        if new_key is not None:
            converted[new_key] = value
    return converted


def merge_swec_checkpoint(
    base_state_dict: dict,
    swec_state_dict: dict,
    lora_alpha: int = 16,
    lora_rank: int = 8,
) -> dict:
    """Build a full MVPFormer seizure classifier from a base backbone + swec.

    .. warning::
        Temporary shim for the currently released checkpoints; see the module
        banner.

    The released ``swec`` files contain only LoRA deltas (on ``q_attn`` and
    ``c_attn``, as plain ``loralib.Linear`` adapters -- verified against
    ``loralib``) plus the classification head. This merges
    ``W <- W + (lora_alpha / lora_rank) * (lora_B @ lora_A)`` into the backbone
    and installs the swec head as ``final_layer``.

    Parameters
    ----------
    base_state_dict : dict
        Raw upstream ``*-base.pt`` checkpoint (the backbone).
    swec_state_dict : dict
        Raw upstream ``*-swec.pt`` checkpoint (LoRA deltas + head).
    lora_alpha, lora_rank : int
        LoRA scaling parameters (released config: 16 and 8 -> scaling 2).

    Returns
    -------
    dict
        Full classifier state dict in ``MVPFormer`` naming;
        ``model.load_state_dict(merged)`` should match exactly (use a model with
        ``n_outputs`` equal to the swec head, e.g. 2 for seizure detection).
    """
    converted = convert_mvpformer_state_dict(base_state_dict, kind="base")
    scaling = lora_alpha / lora_rank

    for key in swec_state_dict:
        if not key.endswith(".lora_A"):
            continue
        stem = key[: -len(".lora_A")]  # e.g. genie.h.0.attn.q_attn
        lora_A = swec_state_dict[key]
        lora_B = swec_state_dict[stem + ".lora_B"]
        target = _map_backbone_key(stem + ".weight")  # blocks.{i}.attn.*_attn.weight
        if target is None or target not in converted:
            raise KeyError(f"LoRA target for {key!r} ({target}) not in backbone.")
        converted[target] = converted[target] + scaling * (lora_B @ lora_A)

    head = swec_state_dict.get("head.head.weight")
    if head is not None:
        converted["final_layer.weight"] = head
    return converted

From the automated PR review (Codex/Copilot) on #1057:
- wavelet_decomposition: build the conv weight on the input device (not just
  its dtype) so the functional API works on CUDA/MPS. Validated on MPS.
- _rel_shift_chan: construct index tensors on x.device with long dtype for
  GPU-safe advanced indexing.
- _WaveletPatchEmbed: project in the projection-weight dtype (robust across
  AMP / float16 inputs); identical on the float32 path.
- test_daubechies_filters_match_pywt: gate on pywt (not ptwt), bound via
  importorskip.
- test_integration: correct the stale "ptwt.wavedec" TorchScript skip comment.

All changes are numerically neutral on the float32 CPU path; an MPS forward
matches CPU to ~2e-8.
Copilot AI review requested due to automatic review settings June 18, 2026 22:18

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 5 comments.

Comment thread test/unit_tests/models/test_functional.py
Comment thread braindecode/models/mvpformer.py
Comment thread braindecode/models/mvpformer.py
Comment thread braindecode/models/mvpformer.py
Comment thread docs/whats_new.rst
- ruff format (pinned v0.14.9): wrap the long dec_hi line in
  daubechies_filters (the failing pre-commit hook).
- reset_head: re-apply _init_weights so a reset head matches fresh
  construction (std=0.02) instead of default Linear init.
- test_wavelet_decomposition_matches_pywt: importorskip both ptwt and pywt
  so the test skips (never errors) when either is missing.
- License header: add the greppable `# License: Apache-2.0` comment and drop
  the awkward in-docstring sentence; the model stays Apache-2.0 (transcribed
  from the authors' Apache reference).
- Docstring: stop claiming sub-quadratic cost — the pure-PyTorch path
  materialises the full content-attention matrix (the windowed FlashMVPA
  kernel is not ported).

PR description updated to the slimmed scope (no in-tree converter/oracle/
test_mvpformer; converter preserved as a PR comment).
einops.repeat expanding new axes returns a memory-sharing view, so the
in-place window_mask[-n_channels:] = 1 raised RuntimeError (more than one
element refers to a single memory location). Triggers whenever the global
attention mask path runs with a single segment (n_times == segment, i.e.
exactly one window). clone() before the in-place write; values unchanged.
@bruAristimunha

Copy link
Copy Markdown
Collaborator Author

End-to-end validation + FNUSA replication

Validated the model on real data and replicated the paper's FNUSA result.

Bug fixed (pushed in this branch)

_MVPAttention._rel_attn built window_mask via einops.repeat (expanding new axes → a memory-sharing view) and then wrote in-place window_mask[-n_channels:] = 1, raising:

RuntimeError: unsupported operation: more than one element of the written-to
tensor refers to a single memory location. Please clone() ...

It triggers whenever the default global_att=True mask path runs with a single segment (n_times == segment, i.e. one 5 s window — the common single-clip case). Fix: .clone() the expanded mask before the in-place write (values unchanged). Restored unit tests (test_mvpformer.py, reference-equivalence) pass.

Replication — FNUSA iEEG (Nejedly 2020), MVPFormer-S

Used the released genie-s-base backbone (frozen), trained a 2-class head on the first 4 patients, tested on the remaining 10 (193,118 clips; pathology-vs-rest). Clips resampled 5 kHz→512 Hz, padded to one 2560-sample segment. Paper protocol from App. Results on the MAYO and FNUSA datasets.

F1 sens spec
paper (reported) 0.46 0.99 0.03
ours @ paper operating point 0.41 0.99 0.02
all-positive prevalence baseline 0.41 1.00 0.00
ours @ balanced threshold (discriminating) 0.37 0.72 0.25

Matches within Δ=0.05, fully explained by test prevalence (25.6% → all-positive F1 = 0.41). Note the FNUSA F1 is the predict-all-positive prevalence baseline — the paper reports every model (incl. Brant-2) collapsing to sens≈0.99 / spec≈0.03. Our frozen-backbone head additionally discriminates (F1=0.37 at a balanced threshold).

Copilot AI review requested due to automatic review settings June 19, 2026 19:34

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.

Comment thread braindecode/models/mvpformer.py
Comment thread braindecode/models/mvpformer.py
Comment thread braindecode/models/mvpformer.py
Comment thread test/unit_tests/models/test_functional.py
Comment thread braindecode/models/mvpformer.py
Comment thread braindecode/models/mvpformer.py
Comment thread braindecode/functional/functions.py
Comment thread braindecode/models/summary.csv Outdated
…nto MVPFormer

# Conflicts:
#	braindecode/functional/functions.py
#	braindecode/models/summary.csv
#	docs/whats_new.rst
#	test/unit_tests/models/test_functional.py
#	test/unit_tests/models/test_integration.py
Copilot AI review requested due to automatic review settings June 30, 2026 12:29

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.

Comment on lines +163 to +165
global_att : bool
Whether to use the global content-attention term.
max_segments : int
Comment on lines +290 to +291
n = n_vanishing
poly = np.array([math.comb(n - 1 + k, k) for k in range(n)], dtype=float)
Comment on lines +89 to +91
def test_wavelet_decomposition_matches_pywt():
"""wavelet_decomposition is bit-identical to pywt/ptwt wavedec(mode='periodic')
across sizes (skipped if the reference library is unavailable)."""
Comment on lines +393 to +395
global_att : bool
Whether to add the global content-attention term.
local_window : int
Comment on lines +505 to +509
idxes = torch.triu_indices(chan_size, chan_size, offset=1, device=device)
shifting_idxes = torch.zeros(
chan_size, chan_size, dtype=torch.long, device=device
)
shifting_idxes[..., idxes[0], idxes[1]] = upper_val
Comment on lines +589 to +593
window = self.local_window
ones = torch.ones((n_segments, n_segments), device=query.device, dtype=bool)
window_mask = torch.logical_and(
torch.tril(ones, diagonal=window),
torch.triu(ones, diagonal=-window),
Copilot AI review requested due to automatic review settings June 30, 2026 12:35
@bruAristimunha bruAristimunha merged commit 4c0d44b into master Jun 30, 2026
12 checks passed

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 4 comments.

Comment on lines +290 to +291
n = n_vanishing
poly = np.array([math.comb(n - 1 + k, k) for k in range(n)], dtype=float)
Comment on lines +90 to +91
"""wavelet_decomposition is bit-identical to pywt/ptwt wavedec(mode='periodic')
across sizes (skipped if the reference library is unavailable)."""
Comment on lines +512 to +516
if chan_size > 1:
upper_val = torch.cat(
[
torch.arange(1, chan_size - i, dtype=torch.long, device=device)
for i in range(chan_size - 1)
Comment on lines +357 to +360
filter_len = filters.shape[-1]
if n_levels is None:
n_levels = dwt_max_level(x.shape[-1], filter_len)
leading = x.shape[:-1]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants