fix: correct MPS execution on Apple silicon #145

Merged
yanghan234 merged 11 commits into main from hanyang/fix-mps-device-handling on Apr 22, 2026

Conversation

@yanghan234
Collaborator

Summary

  • register SphericalBasisLayer.coef as a buffer so it moves with the model on MPS
  • precompute graph-derived indexing values in batch_to_dict() and move the input dict to the target device
    explicitly
  • remove MPS device-to-host synchronization hotspots in the M3GNet forward path and stress path
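The first bullet can be illustrated with a minimal sketch, which is not the PR's actual code. A tensor stored as a plain attribute on an nn.Module is not tracked by the module, so model.to(device) does not move it, whereas a tensor registered with register_buffer moves with the module and appears in its state_dict. The class names PlainAttribute and BufferAttribute and the helper pick_device are hypothetical; only the attribute name coef comes from the summary above.

```python
import torch
from torch import nn


class PlainAttribute(nn.Module):
    """Hypothetical layer: coefficients stored as a plain tensor attribute."""

    def __init__(self) -> None:
        super().__init__()
        # Not tracked by nn.Module, so .to(device) will not move it.
        self.coef = torch.arange(4, dtype=torch.float32)


class BufferAttribute(nn.Module):
    """Hypothetical layer: the same coefficients registered as a buffer."""

    def __init__(self) -> None:
        super().__init__()
        # Tracked by nn.Module, so it follows the module across devices.
        self.register_buffer("coef", torch.arange(4, dtype=torch.float32))


def pick_device() -> torch.device:
    # Prefer MPS when it is available; otherwise fall back to CPU.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


plain = PlainAttribute()
buffered = BufferAttribute()

# Only the registered buffer is visible to nn.Module bookkeeping.
plain_buffer_names = [name for name, _ in plain.named_buffers()]
buffered_buffer_names = [name for name, _ in buffered.named_buffers()]

device = pick_device()
buffered = buffered.to(device)  # coef moves together with the module
```

On a machine without MPS the sketch falls back to CPU, so the device move is a no-op there; the difference between the two classes still shows up in named_buffers() and state_dict().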

@yanghan234 force-pushed the hanyang/fix-mps-device-handling branch from 450d7cf to 191cf36 on April 7, 2026 12:20
@yanghan234
Collaborator Author

@copilot resolve the merge conflicts in this pull request

Contributor

Copilot AI commented Apr 21, 2026

@copilot resolve the merge conflicts in this pull request

Resolved the merge conflicts by merging origin/main into this branch and reconciling the conflicts in potential.py and tests/conftest.py. Addressed in commit 4e668b1.

yanghan234 and others added 7 commits April 21, 2026 17:54
Add conftest.py with a device fixture that auto-detects available
torch devices. Tests using the fixture run on all available backends.
A --device flag allows restricting to a single device.

Converted test_batch_relax.py from unittest to pytest style to use
the device fixture. Verified passing on both cpu and mps.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
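
A minimal sketch of how such a conftest.py could be laid out follows. The PR's actual conftest.py is not shown in this conversation, so this is an assumption about its shape. The helpers available_devices and select_devices are hypothetical names, and the sketch uses the pytest_generate_tests hook to parametrize any test that requests a device argument, which is one common way to get the "run on all available backends" behaviour described above.

```python
# conftest.py (illustrative sketch, not the PR's file)
from typing import List, Optional

import torch


def available_devices() -> List[str]:
    """Detect the torch backends usable on this machine (CPU is always included)."""
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    if torch.backends.mps.is_available():
        devices.append("mps")
    return devices


def select_devices(requested: Optional[str], detected: List[str]) -> List[str]:
    """Return every detected device, or only the requested one if it was detected."""
    if requested is None:
        return list(detected)
    if requested not in detected:
        raise ValueError(f"device {requested!r} is not available; detected {detected}")
    return [requested]


def pytest_addoption(parser):
    # Register the --device command-line flag mentioned in the commit message.
    parser.addoption(
        "--device",
        action="store",
        default=None,
        help="restrict device-parametrized tests to a single torch device",
    )


def pytest_generate_tests(metafunc):
    # Parametrize every test that takes a `device` argument.
    if "device" in metafunc.fixturenames:
        requested = metafunc.config.getoption("--device")
        metafunc.parametrize("device", select_devices(requested, available_devices()))
```

With a layout like this, a test written as def test_something(device): ... would run once per detected backend, and pytest --device cpu would restrict it to CPU.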
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- M3Gnet.forward now uses .get() with fallback computation for
  precomputed keys (total_num_atoms, bond_index_bias, etc.), so
  callers that construct input dicts directly won't hit a KeyError.
- batch_to_dict creates index_map on CPU (moved to device at the end),
  avoiding intermediate device mismatches.
- Remove unused pytest import in test_batch_relax.py.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
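
The .get()-with-fallback pattern from the first bullet can be sketched as below. This is an illustration under assumptions, not the code in M3Gnet.forward: the helper name resolve_total_num_atoms is hypothetical, and the sketch assumes the input dict carries a per-graph atom-count tensor under the key num_atoms from which total_num_atoms can be recomputed.

```python
import torch


def resolve_total_num_atoms(inputs: dict) -> int:
    """Use the precomputed value when present; otherwise recompute it.

    Assumption: inputs["num_atoms"] is a 1-D integer tensor with one
    entry per graph in the batch.
    """
    total = inputs.get("total_num_atoms")
    if total is None:
        # Fallback path: .item() copies the result to the host, which
        # forces a device-to-host synchronization on accelerators.
        total = inputs["num_atoms"].sum().item()
    return int(total)


# A caller that built the dict by hand, without the precomputed key.
handmade = {"num_atoms": torch.tensor([2, 3])}

# A dict that already carries the precomputed Python int.
precomputed = {"num_atoms": torch.tensor([2, 3]), "total_num_atoms": 5}

total_from_fallback = resolve_total_num_atoms(handmade)
total_from_precomputed = resolve_total_num_atoms(precomputed)
```

Both calls return the same total; the difference is that the precomputed path never reads a tensor value back to the host inside the forward pass.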
Prevents device mismatch if graph_batch tensors are already on
a non-CPU device when batch_to_dict is called.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Add total_num_atoms, total_num_bonds, bond_index_bias, and
three_body_edge_map to TENSOR_KEYS. Handle int values in device
assertions.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
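
A device check that tolerates plain int entries might look like the sketch below. The helpers to_device and assert_on_device and the key atom_pos are hypothetical and are not taken from the PR's test suite; the sketch only assumes that the input dict mixes tensors with Python ints such as total_num_atoms.

```python
import torch


def to_device(inputs: dict, device: str) -> dict:
    """Move every tensor value to `device`; leave non-tensor values untouched."""
    return {
        key: value.to(device) if isinstance(value, torch.Tensor) else value
        for key, value in inputs.items()
    }


def assert_on_device(inputs: dict, keys: list, device: str) -> None:
    """Check that each listed entry is on `device`, skipping plain ints."""
    expected = torch.device(device).type
    for key in keys:
        value = inputs[key]
        if isinstance(value, int):
            # Python ints live on the host and have no .device attribute.
            continue
        assert isinstance(value, torch.Tensor), f"{key} is not a tensor"
        assert value.device.type == expected, f"{key} is on {value.device}"


batch = {
    "atom_pos": torch.zeros(3, 3),  # hypothetical tensor entry
    "total_num_atoms": 3,           # precomputed Python int
}
moved = to_device(batch, "cpu")
assert_on_device(moved, ["atom_pos", "total_num_atoms"], "cpu")
```

Comparing device.type rather than the full torch.device keeps the check independent of a device index, which is a design choice of this sketch.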
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@yanghan234 force-pushed the hanyang/fix-mps-device-handling branch from 93a63fb to 578d379 on April 21, 2026 17:00
yanghan234 and others added 4 commits April 21, 2026 18:00
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@yanghan234 merged commit 452f378 into main on Apr 22, 2026
7 of 8 checks passed
@yanghan234 deleted the hanyang/fix-mps-device-handling branch on April 22, 2026 09:49

2 participants