Skip to content

Fix SincShallowNet MPS backward failure#1087

Merged
bruAristimunha merged 3 commits into
masterfrom
copilot/fix-sincshallownet-backward-mps-error
Jun 28, 2026
Merged

Fix SincShallowNet MPS backward failure#1087
bruAristimunha merged 3 commits into
masterfrom
copilot/fix-sincshallownet-backward-mps-error

Conversation

Copilot AI commented Jun 28, 2026

Copy link
Copy Markdown
Contributor

SincShallowNet could fail during backward on Apple MPS because BatchNorm2d received a non-contiguous tensor from an einops axis permutation. CPU/CUDA were unaffected, but MPS BatchNorm backward can attempt an incompatible view on the saved input.

  • Model change
    • Added an einops-based contiguous rearrange wrapper.
    • Used it before the depthwise BatchNorm2d so the saved BatchNorm input is contiguous.
class _ContiguousRearrange(Rearrange):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return super().forward(x).contiguous()
  • Regression coverage
    • Added a focused SincShallowNet test that verifies the depthwise BatchNorm receives contiguous input and that backward completes.

Copilot AI changed the title [WIP] Fix SincShallowNet backward error on Apple MPS Fix SincShallowNet MPS backward failure Jun 28, 2026
Copilot AI requested a review from bruAristimunha June 28, 2026 21:06
@bruAristimunha bruAristimunha marked this pull request as ready for review June 28, 2026 21:16
Copilot AI review requested due to automatic review settings June 28, 2026 21:16
@bruAristimunha bruAristimunha merged commit cf19e02 into master Jun 28, 2026
10 of 11 checks passed

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses an Apple MPS-specific backward failure in SincShallowNet by ensuring the tensor passed into the depthwise BatchNorm2d is contiguous after an axis permutation.

Changes:

  • Move the axis permutation out of the depthwiseconv nn.Sequential and into forward, applying .contiguous() after permute.
  • Add a unit test that asserts the depthwise BatchNorm2d receives a contiguous input and that backward completes.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
braindecode/models/sinc_shallow.py Ensures contiguous memory layout before depthwise BatchNorm by permuting and calling .contiguous() in forward.
test/unit_tests/models/test_sinc_shallow.py Adds a regression test that hooks the depthwise BatchNorm to assert its input is contiguous and that .backward() runs.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

def check_input_contiguous(_module, inputs):
batch_norm_inputs_are_contiguous.append(inputs[0].is_contiguous())

handle = model.depthwiseconv[0].register_forward_pre_hook(check_input_contiguous)
Comment on lines +204 to +205
# batch timefil time nfilter -> batch nfilter timefil time
x = x.permute(0, 3, 1, 2).contiguous()
Comment on lines +204 to 206
# batch timefil time nfilter -> batch nfilter timefil time
x = x.permute(0, 3, 1, 2).contiguous()
x = self.depthwiseconv(x)

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ce4cb0e6bc

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment on lines 160 to 163
self.depthwiseconv = nn.Sequential(
# Matching dim to depth wise conv!
Rearrange("batch timefil time nfilter -> batch nfilter timefil time"),
nn.BatchNorm2d(
self.n_filters, momentum=0.99
), # To match keras implementation

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve depthwiseconv checkpoint keys

For checkpoints saved before this change, the depthwise batch norm and conv weights are named depthwiseconv.1.* and depthwiseconv.2.*; starting this Sequential with BatchNorm2d renames them to depthwiseconv.0.* and depthwiseconv.1.*. That makes existing SincShallowNet checkpoints fail under default load_state_dict strict loading, and can silently leave these layers randomly initialized when users load non-strictly. Keeping a no-parameter contiguous rearrange wrapper at index 0, or adding a mapping, would avoid breaking saved weights.

Useful? React with 👍 / 👎.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants