Reproducer (verified on clean upstream/main, no local changes):
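import mlx.core as mx
# a has shape (4, 5, 3) and is not mapped; idx has shape (2, 2, 1, 3) and is mapped over axis 0
a = mx.arange(4 * 5 * 3).reshape(4, 5, 3)
idx = mx.zeros((2, 2, 1, 3), dtype=mx.int32)
out = mx.vmap(lambda x, y: mx.take_along_axis(x, y, axis=0), in_axes=(None, 0))(a, idx)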
Result: Bus error (core dumped)
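For comparison, here is a rough sketch of what the call would presumably return instead of crashing. It uses NumPy as a stand-in and assumes that MLX's take_along_axis follows NumPy's broadcasting rules for the non-axis dimensions; that assumption is mine and is not verified against MLX here. The Python loop emulates vmap with in_axes=(None, 0), and the name `expected` is only illustrative.

```python
import numpy as np

# Same inputs as the reproducer, built with NumPy instead of MLX.
a = np.arange(4 * 5 * 3).reshape(4, 5, 3)
idx = np.zeros((2, 2, 1, 3), dtype=np.int32)

# Emulate vmap with in_axes=(None, 0): keep `a` fixed and loop over axis 0 of `idx`.
# Each idx[i] has shape (2, 1, 3) and broadcasts against a's non-axis dims (5, 3).
expected = np.stack(
    [np.take_along_axis(a, idx[i], axis=0) for i in range(idx.shape[0])]
)

print(expected.shape)  # (2, 2, 5, 3)
```

Because every index is zero, each (5, 3) slice of `expected` equals a[0] under this assumption.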
Environment:
- MLX: upstream/main @ e40ada3 (clean worktree)
- macOS 26.3.1 (25D2128)
- Hardware: MacBook Pro, Apple M3 Max, 128 GB RAM