Fix SincShallowNet MPS backward failure#1087
Conversation
There was a problem hiding this comment.
Pull request overview
This PR addresses an Apple MPS-specific backward failure in SincShallowNet by ensuring the tensor passed into the depthwise BatchNorm2d is contiguous after an axis permutation.
Changes:
- Move the axis permutation out of the
depthwiseconvnn.Sequentialand intoforward, applying.contiguous()afterpermute. - Add a unit test that asserts the depthwise
BatchNorm2dreceives a contiguous input and that backward completes.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
braindecode/models/sinc_shallow.py |
Ensures contiguous memory layout before depthwise BatchNorm by permuting and calling .contiguous() in forward. |
test/unit_tests/models/test_sinc_shallow.py |
Adds a regression test that hooks the depthwise BatchNorm to assert its input is contiguous and that .backward() runs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def check_input_contiguous(_module, inputs): | ||
| batch_norm_inputs_are_contiguous.append(inputs[0].is_contiguous()) | ||
|
|
||
| handle = model.depthwiseconv[0].register_forward_pre_hook(check_input_contiguous) |
| # batch timefil time nfilter -> batch nfilter timefil time | ||
| x = x.permute(0, 3, 1, 2).contiguous() |
| # batch timefil time nfilter -> batch nfilter timefil time | ||
| x = x.permute(0, 3, 1, 2).contiguous() | ||
| x = self.depthwiseconv(x) |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ce4cb0e6bc
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| self.depthwiseconv = nn.Sequential( | ||
| # Matching dim to depth wise conv! | ||
| Rearrange("batch timefil time nfilter -> batch nfilter timefil time"), | ||
| nn.BatchNorm2d( | ||
| self.n_filters, momentum=0.99 | ||
| ), # To match keras implementation |
There was a problem hiding this comment.
Preserve depthwiseconv checkpoint keys
For checkpoints saved before this change, the depthwise batch norm and conv weights are named depthwiseconv.1.* and depthwiseconv.2.*; starting this Sequential with BatchNorm2d renames them to depthwiseconv.0.* and depthwiseconv.1.*. That makes existing SincShallowNet checkpoints fail under default load_state_dict strict loading, and can silently leave these layers randomly initialized when users load non-strictly. Keeping a no-parameter contiguous rearrange wrapper at index 0, or adding a mapping, would avoid breaking saved weights.
Useful? React with 👍 / 👎.
SincShallowNetcould fail during backward on Apple MPS becauseBatchNorm2dreceived a non-contiguous tensor from an einops axis permutation. CPU/CUDA were unaffected, but MPS BatchNorm backward can attempt an incompatible view on the saved input.BatchNorm2dso the saved BatchNorm input is contiguous.SincShallowNettest that verifies the depthwise BatchNorm receives contiguous input and that backward completes.