Claude/cellpose mlx support#1
Conversation
Implement MLX (Apple's ML framework) as an alternative backend for cellpose inference on Apple Silicon Macs (M1/M2/M3/M4), providing native GPU acceleration without PyTorch MPS overhead.

New files:
- `cellpose/mlx_net.py`: MLX implementation of the CP-SAM Transformer (ViT encoder, attention with relative position embeddings, neck, readout head with pixel shuffle)
- `cellpose/mlx_utils.py`: PyTorch-to-MLX weight conversion utilities (key mapping, Conv2d weight transposition OIHW -> OHWI)

Modified files:
- `cellpose/models.py`: Add `use_mlx` parameter to `CellposeModel`
- `cellpose/core.py`: Add `_forward_mlx()` and pass `use_mlx_backend` through `run_net`/`run_3D`
- `cellpose/cli.py`: Add `--use_mlx` CLI flag
- `cellpose/__main__.py`: Wire `use_mlx` to model creation
- `setup.py`: Add optional `mlx` extras_require

Usage: `model = CellposeModel(use_mlx=True)` — or via the CLI: `cellpose --use_mlx --dir /path/to/images`

Falls back gracefully to PyTorch if MLX is not installed.

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
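The OIHW -> OHWI Conv2d weight transposition mentioned above can be sketched in isolation as follows. This is a minimal NumPy illustration of the layout change, not the actual `mlx_utils.py` code (which also remaps parameter keys); the 7x7 stem shape is hypothetical.

```python
import numpy as np

def conv2d_weight_to_mlx(w_oihw):
    """Transpose a PyTorch Conv2d weight from (out, in, H, W) to MLX's
    channels-last (out, H, W, in) layout."""
    return np.transpose(w_oihw, (0, 2, 3, 1))

# Hypothetical 7x7 stem convolution: 3 input channels, 64 output channels.
w_pt = np.random.randn(64, 3, 7, 7).astype(np.float32)
w_mlx = conv2d_weight_to_mlx(w_pt)
assert w_mlx.shape == (64, 7, 7, 3)
```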
- Fix `_add_decomposed_rel_pos`: add proper dimension expansion (`unsqueeze`) for `rel_h` and `rel_w` when adding to the 5D attention tensor, matching SAM's original implementation exactly
- Pre-compute interpolated relative position embeddings at weight-load time to avoid repeated scipy interpolation during inference (48x per forward pass)
- Guard `torch.cuda.empty_cache()` with an `is_available()` check to prevent a crash on non-CUDA systems (Apple Silicon)
- Add `use_mlx="auto"` mode for automatic MLX detection on Apple Silicon when CUDA is not available
- Update CLI `--use_mlx` to accept an optional `auto` value

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
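The pre-computation idea behind the second bullet can be sketched like this: resize the relative-position table once at weight-load time and reuse the result on every forward pass. NumPy's `interp` stands in here for the scipy interpolation used by the PR, and the table sizes are hypothetical.

```python
import numpy as np

def precompute_rel_pos(table, dst_len):
    """Resize a 1D relative-position table to dst_len once, at load time.

    Stand-in for the scipy interpolation the PR moves out of the forward pass
    (previously repeated ~48x per image, once per attention block/axis).
    """
    src = np.linspace(0.0, 1.0, len(table))
    dst = np.linspace(0.0, 1.0, dst_len)
    return np.interp(dst, src, table)

# Hypothetical table for a 7x7 window (2*7 - 1 = 13 relative offsets),
# resized once for a 14x14 window (2*14 - 1 = 27 offsets).
table = np.arange(13, dtype=np.float64)
cached = precompute_rel_pos(table, 27)
assert cached.shape == (27,)
```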
Pull request overview
This PR introduces an optional MLX inference backend for the Cellpose-SAM (CP4) Transformer to accelerate inference on Apple Silicon, and wires backend selection through the CLI and model execution path.
Changes:
- Added an MLX-native CP-SAM Transformer implementation and PyTorch→MLX weight conversion utilities.
- Added MLX backend selection plumbing through `CellposeModel` into `core.run_net`/`core.run_3D`.
- Added a `--use_mlx` CLI flag and propagated it through `__main__`.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| setup.py | Adds an optional mlx extra dependency group. |
| cellpose/models.py | Adds use_mlx option to CellposeModel and routes inference through MLX when enabled. |
| cellpose/core.py | Adds MLX availability detection and an MLX forward path for run_net / run_3D. |
| cellpose/cli.py | Adds --use_mlx argument to enable/auto-detect MLX backend. |
| cellpose/__main__.py | Passes CLI use_mlx into CellposeModel. |
| cellpose/mlx_utils.py | New utilities for converting/loading PyTorch checkpoint weights into MLX format. |
| cellpose/mlx_net.py | New MLX implementation of the CP-SAM Transformer model. |
```python
        y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
        style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
    """
    forward_fn = _forward_mlx if use_mlx_backend else _forward
```
run_net(..., use_mlx_backend=True) will crash with a NameError when MLX isn’t installed because mx is undefined (only set inside the import try/except). Please add an explicit guard (e.g., if use_mlx_backend and not MLX_AVAILABLE: raise ImportError with install instructions) before selecting _forward_mlx.
```suggestion
    if use_mlx_backend:
        if not MLX_AVAILABLE:
            raise ImportError(
                "MLX backend requested (use_mlx_backend=True) but the 'mlx' package is not installed. "
                "Please install it with `pip install mlx` and try again."
            )
        forward_fn = _forward_mlx
    else:
        forward_fn = _forward
```
```python
def use_mlx():
    """Check if MLX is available and running on Apple Silicon.

    Returns:
        bool: True if MLX is available.
    """
    if not MLX_AVAILABLE:
        return False
    try:
        _ = mx.zeros((1, 1))
        mx.eval(_)
        core_logger.info("** MLX available on Apple Silicon. **")
        return True
    except Exception:
        return False
```
use_mlx() is currently unused anywhere in the codebase (search only finds its definition). Consider either wiring it into the MLX selection logic (e.g., CellposeModel(use_mlx="auto")) or removing it to avoid maintaining a misleading/unused API surface.
```suggestion
```
```python
    "--gpu_device", required=False, default="0", type=str,
    help="which gpu device to use, use an integer for torch, or mps for M1")
hardware_args.add_argument(
    "--use_mlx", nargs="?", const=True, default=False,
```
--use_mlx currently accepts any optional value as a string (because nargs="?" with no type/choices), so --use_mlx false or --use_mlx 0 will be treated as truthy and enable MLX. Consider restricting values via choices=["auto"] (so only --use_mlx or --use_mlx auto are valid) or adding a parser that normalizes true/false/auto to avoid surprising behavior.
```suggestion
    "--use_mlx", nargs="?", const=True, default=False, type=str, choices=["auto"],
```
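The pitfall described above can be reproduced in isolation with the original argument definition (a standalone sketch, not the actual cellpose parser):

```python
import argparse

parser = argparse.ArgumentParser()
# Original definition from the PR: an optional value is kept as a raw string.
parser.add_argument("--use_mlx", nargs="?", const=True, default=False)

assert parser.parse_args([]).use_mlx is False            # flag absent -> default
assert parser.parse_args(["--use_mlx"]).use_mlx is True  # bare flag -> const
# Surprising case: "--use_mlx false" stores the truthy string "false".
args = parser.parse_args(["--use_mlx", "false"])
assert args.use_mlx == "false" and bool(args.use_mlx)
```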
```python
# Check if MLX backend is requested or auto-detected
if use_mlx == "auto":
    # Auto-detect: use MLX on Apple Silicon when CUDA is not available
    if MLX_AVAILABLE and not torch.cuda.is_available():
        self.use_mlx = True
        models_logger.info(
            "MLX auto-detected on Apple Silicon (no CUDA available). "
            "Using MLX backend."
        )
    else:
        self.use_mlx = False
elif use_mlx:
    self.use_mlx = MLX_AVAILABLE
    if not MLX_AVAILABLE:
        models_logger.warning(
            "MLX backend requested but MLX is not available. "
            "Install with: pip install mlx. Falling back to PyTorch."
        )
else:
    self.use_mlx = False
```
use_mlx treats any truthy value as a request for MLX, so if CLI passes a string like "false"/"0" it will still enable MLX. Since the CLI flag can produce strings, please validate/normalize string inputs (e.g., only accept "auto" as a string and otherwise raise a ValueError, or coerce common string booleans) before the elif use_mlx: branch.
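The normalization suggested above could look like the following. This is a hypothetical helper sketching one possible fix, not code from the PR:

```python
def normalize_use_mlx(value):
    """Coerce a CLI/user-supplied use_mlx value to True, False, or "auto".

    Hypothetical helper: accepts the bool produced by the bare CLI flag,
    the string "auto", and common string booleans; rejects everything else.
    """
    if isinstance(value, str):
        v = value.strip().lower()
        if v == "auto":
            return "auto"
        if v in ("1", "true", "yes", "on"):
            return True
        if v in ("0", "false", "no", "off"):
            return False
        raise ValueError(f"invalid use_mlx value: {value!r}")
    return bool(value)
```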
```python
try:
    import mlx.core as mx
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False
```
When MLX isn’t installed, MLX_AVAILABLE is set to False but mx is never defined; functions like convert_pytorch_to_mlx_weights and save_mlx_weights will then raise NameError if called. Consider defining mx=None in the ImportError branch and adding an explicit check at the start of MLX-dependent functions to raise a clear ImportError when MLX_AVAILABLE is False.
This pull request adds support for Apple Silicon acceleration in Cellpose by introducing an MLX backend. The main changes include adding a new MLX-based model implementation, updating the CLI and core logic to allow selecting the MLX backend, and ensuring the relevant device checks and inference code paths are available. These changes enable users on macOS with Apple Silicon to leverage MLX for faster inference without relying on PyTorch/MPS.
MLX backend support for Apple Silicon:
- `cellpose/mlx_net.py` implements the Cellpose-SAM Transformer model using MLX for native Apple Silicon GPU acceleration, including all necessary model components and weight-loading logic.
- In `cellpose/core.py`, added MLX availability detection and a `use_mlx()` helper to check for MLX and Apple Silicon at runtime.
- Added `_forward_mlx()` for running inference with MLX models, and updated `run_net()` and `run_3D()` to optionally use the MLX backend for inference.

CLI and argument updates:
- Updated `cellpose/cli.py` to add a `--use_mlx` command-line argument, allowing users to enable or auto-detect the MLX backend for Apple Silicon acceleration.

Main execution and device selection:
- Updated `cellpose/__main__.py` to pass the `use_mlx` argument and initialize the model with MLX support when requested.