Claude/cellpose mlx support #1

Merged: hkmoon merged 2 commits into main from claude/cellpose-mlx-support-rwGYw on Apr 1, 2026

Conversation

hkmoon (Owner) commented Apr 1, 2026

This pull request adds support for Apple Silicon acceleration in Cellpose by introducing an MLX backend. The main changes include adding a new MLX-based model implementation, updating the CLI and core logic to allow selecting the MLX backend, and ensuring the relevant device checks and inference code paths are available. These changes enable users on macOS with Apple Silicon to leverage MLX for faster inference without relying on PyTorch/MPS.

MLX backend support for Apple Silicon:

  • Added a new module cellpose/mlx_net.py implementing the Cellpose-SAM Transformer model using MLX for native Apple Silicon GPU acceleration, including all necessary model components and weight-loading logic.
  • In cellpose/core.py, added MLX availability detection and a use_mlx() helper to check for MLX and Apple Silicon at runtime.
  • Implemented _forward_mlx() for running inference with MLX models, and updated run_net() and run_3D() to optionally use the MLX backend for inference.

CLI and argument updates:

  • Updated cellpose/cli.py to add a --use_mlx command-line argument, allowing users to enable or auto-detect the MLX backend for Apple Silicon acceleration.

Main execution and device selection:

  • Modified the main CLI execution in cellpose/__main__.py to pass the use_mlx argument and initialize the model with MLX support when requested.
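A rough sketch of that wiring, for illustration only (parameter names follow the PR description; build_model is a hypothetical helper, not code from the PR):

    # Hypothetical sketch of the __main__.py wiring described above.
    from cellpose import models

    def build_model(args):
        # args.use_mlx is False by default, True when --use_mlx is passed,
        # or the string "auto" when --use_mlx auto is passed (second commit).
        return models.CellposeModel(gpu=args.use_gpu, use_mlx=args.use_mlx)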

claude added 2 commits March 31, 2026 18:47

Commit 1:
Implement MLX (Apple's ML framework) as an alternative backend for
cellpose inference on Apple Silicon Macs (M1/M2/M3/M4), providing
native GPU acceleration without PyTorch MPS overhead.

New files:
- cellpose/mlx_net.py: MLX implementation of the CP-SAM Transformer
  (ViT encoder, attention with relative position embeddings, neck,
  readout head with pixel shuffle)
- cellpose/mlx_utils.py: PyTorch-to-MLX weight conversion utilities
  (key mapping, Conv2d weight transposition OIHW->OHWI; see the sketch
  after this commit message)

Modified files:
- cellpose/models.py: Add use_mlx parameter to CellposeModel
- cellpose/core.py: Add _forward_mlx() and pass use_mlx_backend through
  run_net/run_3D
- cellpose/cli.py: Add --use_mlx CLI flag
- cellpose/__main__.py: Wire use_mlx to model creation
- setup.py: Add optional 'mlx' extras_require

Usage:
  model = CellposeModel(use_mlx=True)
  # or CLI: cellpose --use_mlx --dir /path/to/images

Falls back gracefully to PyTorch if MLX is not installed.

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
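The OIHW->OHWI transposition mentioned for cellpose/mlx_utils.py reflects the differing Conv2d weight layouts: PyTorch stores weights as (out, in, H, W) while MLX's Conv2d expects (out, H, W, in). A minimal sketch of that conversion, assuming a torch.Tensor input (the helper name is illustrative, not necessarily the one in the PR):

    import numpy as np
    import mlx.core as mx

    def convert_conv2d_weight(torch_weight):
        """Transpose a PyTorch Conv2d weight (O, I, H, W) to MLX layout (O, H, W, I)."""
        w = torch_weight.detach().cpu().numpy()          # (O, I, H, W)
        return mx.array(np.transpose(w, (0, 2, 3, 1)))   # (O, H, W, I)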
Commit 2:

- Fix _add_decomposed_rel_pos: add proper dimension expansion (unsqueeze)
  for rel_h and rel_w when adding to the 5D attention tensor, matching SAM's
  original implementation exactly (see the sketch after this commit message)
- Pre-compute interpolated relative position embeddings at weight load
  time to avoid repeated scipy interpolation during inference (48x per
  forward pass)
- Guard torch.cuda.empty_cache() with is_available() check to prevent
  crash on non-CUDA systems (Apple Silicon)
- Add use_mlx="auto" mode for automatic MLX detection on Apple Silicon
  when CUDA is not available
- Update CLI --use_mlx to accept optional 'auto' value

https://claude.ai/code/session_01EfKu1kx3mC9ZWXvzrmPgy5
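On the first fix: rel_h has shape (B, q_h, q_w, k_h) and rel_w has shape (B, q_h, q_w, k_w), so each needs an extra axis before it can broadcast against the 5D (B, q_h, q_w, k_h, k_w) attention tensor. A NumPy sketch of the broadcasting pattern (dimensions are illustrative; the real code operates on MLX arrays):

    import numpy as np

    B, q_h, q_w, k_h, k_w = 1, 4, 4, 4, 4
    attn = np.zeros((B, q_h, q_w, k_h, k_w))   # 5D attention tensor
    rel_h = np.ones((B, q_h, q_w, k_h))        # relative-height term
    rel_w = np.ones((B, q_h, q_w, k_w))        # relative-width term

    # The fix: expand rel_h over the k_w axis and rel_w over the k_h axis,
    # matching SAM's add_decomposed_rel_pos.
    attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    attn = attn.reshape(B, q_h * q_w, k_h * k_w)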
Copilot AI review requested due to automatic review settings April 1, 2026 07:18
@hkmoon hkmoon merged commit fe09e09 into main Apr 1, 2026
2 checks passed
@hkmoon hkmoon deleted the claude/cellpose-mlx-support-rwGYw branch April 1, 2026 07:19

Copilot AI left a comment

Pull request overview

This PR introduces an optional MLX inference backend for the Cellpose-SAM (CP4) Transformer to accelerate inference on Apple Silicon, and wires backend selection through the CLI and model execution path.

Changes:

  • Added an MLX-native CP-SAM Transformer implementation and PyTorch→MLX weight conversion utilities.
  • Added MLX backend selection plumbing through CellposeModel into core.run_net / core.run_3D.
  • Added a --use_mlx CLI flag and propagated it through __main__.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.

Summary per file:

  • setup.py: Adds an optional mlx extra dependency group.
  • cellpose/models.py: Adds use_mlx option to CellposeModel and routes inference through MLX when enabled.
  • cellpose/core.py: Adds MLX availability detection and an MLX forward path for run_net / run_3D.
  • cellpose/cli.py: Adds --use_mlx argument to enable/auto-detect MLX backend.
  • cellpose/__main__.py: Passes CLI use_mlx into CellposeModel.
  • cellpose/mlx_utils.py: New utilities for converting/loading PyTorch checkpoint weights into MLX format.
  • cellpose/mlx_net.py: New MLX implementation of the CP-SAM Transformer model.


Comment thread cellpose/core.py
    y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
    style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
    """
    forward_fn = _forward_mlx if use_mlx_backend else _forward

Copilot AI Apr 1, 2026

run_net(..., use_mlx_backend=True) will crash with a NameError when MLX isn’t installed because mx is undefined (only set inside the import try/except). Please add an explicit guard (e.g., if use_mlx_backend and not MLX_AVAILABLE: raise ImportError with install instructions) before selecting _forward_mlx.

Suggested change

-    forward_fn = _forward_mlx if use_mlx_backend else _forward
+    if use_mlx_backend:
+        if not MLX_AVAILABLE:
+            raise ImportError(
+                "MLX backend requested (use_mlx_backend=True) but the 'mlx' package is not installed. "
+                "Please install it with `pip install mlx` and try again."
+            )
+        forward_fn = _forward_mlx
+    else:
+        forward_fn = _forward

Comment thread cellpose/core.py
Comment on lines +43 to +59
def use_mlx():
    """Check if MLX is available and running on Apple Silicon.

    Returns:
        bool: True if MLX is available.
    """
    if not MLX_AVAILABLE:
        return False
    try:
        _ = mx.zeros((1, 1))
        mx.eval(_)
        core_logger.info("** MLX available on Apple Silicon. **")
        return True
    except Exception:
        return False


Copilot AI Apr 1, 2026

use_mlx() is currently unused anywhere in the codebase (search only finds its definition). Consider either wiring it into the MLX selection logic (e.g., CellposeModel(use_mlx="auto")) or removing it to avoid maintaining a misleading/unused API surface.

Suggested change (remove the unused helper)

-    def use_mlx():
-        """Check if MLX is available and running on Apple Silicon.
-        Returns:
-            bool: True if MLX is available.
-        """
-        if not MLX_AVAILABLE:
-            return False
-        try:
-            _ = mx.zeros((1, 1))
-            mx.eval(_)
-            core_logger.info("** MLX available on Apple Silicon. **")
-            return True
-        except Exception:
-            return False

Comment thread cellpose/cli.py
"--gpu_device", required=False, default="0", type=str,
help="which gpu device to use, use an integer for torch, or mps for M1")
hardware_args.add_argument(
"--use_mlx", nargs="?", const=True, default=False,

Copilot AI Apr 1, 2026

--use_mlx currently accepts any optional value as a string (because nargs="?" with no type/choices), so --use_mlx false or --use_mlx 0 will be treated as truthy and enable MLX. Consider restricting values via choices=["auto"] (so only --use_mlx or --use_mlx auto are valid) or adding a parser that normalizes true/false/auto to avoid surprising behavior.

Suggested change

-    "--use_mlx", nargs="?", const=True, default=False,
+    "--use_mlx", nargs="?", const=True, default=False, type=str, choices=["auto"],

Comment thread cellpose/models.py
Comment on lines +127 to +146
# Check if MLX backend is requested or auto-detected
if use_mlx == "auto":
    # Auto-detect: use MLX on Apple Silicon when CUDA is not available
    if MLX_AVAILABLE and not torch.cuda.is_available():
        self.use_mlx = True
        models_logger.info(
            "MLX auto-detected on Apple Silicon (no CUDA available). "
            "Using MLX backend."
        )
    else:
        self.use_mlx = False
elif use_mlx:
    self.use_mlx = MLX_AVAILABLE
    if not MLX_AVAILABLE:
        models_logger.warning(
            "MLX backend requested but MLX is not available. "
            "Install with: pip install mlx. Falling back to PyTorch."
        )
else:
    self.use_mlx = False

Copilot AI Apr 1, 2026

use_mlx treats any truthy value as a request for MLX, so if CLI passes a string like "false"/"0" it will still enable MLX. Since the CLI flag can produce strings, please validate/normalize string inputs (e.g., only accept "auto" as a string and otherwise raise a ValueError, or coerce common string booleans) before the elif use_mlx: branch.

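One possible normalizer along the lines Copilot suggests (a hypothetical helper, not part of the PR):

    def _normalize_use_mlx(value):
        """Coerce CLI strings to False, True, or "auto" before backend selection."""
        if isinstance(value, str):
            v = value.strip().lower()
            if v == "auto":
                return "auto"
            if v in ("1", "true", "yes"):
                return True
            if v in ("0", "false", "no"):
                return False
            raise ValueError(f"Invalid use_mlx value: {value!r}")
        return bool(value)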
Comment thread cellpose/mlx_utils.py
Comment on lines +12 to +17
try:
    import mlx.core as mx
    MLX_AVAILABLE = True
except ImportError:
    MLX_AVAILABLE = False

Copilot AI Apr 1, 2026

When MLX isn’t installed, MLX_AVAILABLE is set to False but mx is never defined; functions like convert_pytorch_to_mlx_weights and save_mlx_weights will then raise NameError if called. Consider defining mx=None in the ImportError branch and adding an explicit check at the start of MLX-dependent functions to raise a clear ImportError when MLX_AVAILABLE is False.

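A sketch of the guard pattern described above (illustrative; _require_mlx is a hypothetical name):

    try:
        import mlx.core as mx
        MLX_AVAILABLE = True
    except ImportError:
        mx = None  # defined so later references fail with a clear error, not NameError
        MLX_AVAILABLE = False

    def _require_mlx():
        """Raise a clear ImportError when MLX-dependent functions are called without MLX."""
        if not MLX_AVAILABLE:
            raise ImportError(
                "MLX is required for this function. Install it with `pip install mlx`."
            )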