Strategy proposal: data-dependent output-shape ops (unique, nonzero, boolean indexing) via a static size= argument #3685
katlun-lgtm started this conversation in Ideas
A consistent strategy for data-dependent output-shape ops in MLX
Draft design discussion for ml-explore/mlx, covering unique*, nonzero, boolean-mask read indexing a[mask], unary where(cond), compress/extract, and repeat with a tensor argument.

1. Problem statement
MLX fixes every array's shape at graph-build time. Concretely (mlx source):
- Primitive::output_shapes(const std::vector<array>& inputs) derives output shapes from input shapes only, never from input values.
- array is constructed with its shape fixed.
- UnaryPrimitive::eval_cpu / eval_gpu(inputs, array& output) receive an already-allocated, already-shaped output and merely fill it.

So any op whose output element count depends on input data has nowhere to express its shape. This is why the maintainer deferred the whole class (issue #856: "ops whose output shape depends on input data … we will not implement it until we have figured out a consistent strategy"), why a[mask] read indexing raises "boolean indices are not yet supported" (#865, #246), and why nonzero and unary where are absent (MLX indexing docs).

This document surveys how other major frameworks solve the problem and proposes a consistent, phased strategy for MLX.
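The core of the argument is that two inputs with identical shapes can need different output lengths, so no function of input shapes alone can produce an unbounded nonzero's shape, while adding a size argument restores one. A minimal pure-Python model of that observation, assuming 1-D inputs; shape_of, nonzero_count, and output_shapes_nonzero_sized are hypothetical illustration names, not MLX's C++ API:

```python
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]

def shape_of(x: Sequence[float]) -> Shape:
    """Shape of a 1-D list; stands in for an array's build-time shape."""
    return (len(x),)

def nonzero_count(x: Sequence[float]) -> int:
    """The length an unbounded nonzero would return: it depends on values."""
    return sum(1 for v in x if v != 0)

def output_shapes_nonzero_sized(input_shapes: List[Shape], size: int) -> List[Shape]:
    """Hypothetical shape rule for a size-bounded 1-D nonzero.

    It reads only the input shapes and the `size` argument, never the
    values, so it could run at graph-build time."""
    (shape,) = input_shapes
    if len(shape) != 1:
        raise ValueError("this sketch only models 1-D inputs")
    return [(size,)]

a = [0.0, 1.0, 0.0]
b = [1.0, 1.0, 1.0]
same_shape = shape_of(a) == shape_of(b)                     # True: both (3,)
different_len = nonzero_count(a) != nonzero_count(b)        # True: 1 vs 3
sized = output_shapes_nonzero_sized([shape_of(a)], size=3)  # [(3,)] for any values
```

Because a and b share a shape but not a count, any shape-only rule for an unbounded nonzero must be wrong for one of them; the size argument removes that ambiguity.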
2. The op class (array API)
The Python array API standard has a dedicated design note, "Data-dependent output shapes", and a capabilities()["data-dependent shapes"] boolean flag so that a conforming library can advertise that it does not support them ([array-api design topic], [array-api info.capabilities]). The ops it flags:
- unique_all / unique_counts / unique_inverse / unique_values
- nonzero
- x[mask] (explicitly optional; the standard says graph-building libraries like JAX and Dask find it hard to implement, and a library may omit it and still conform) ([array-api indexing]).

MLX already returns False for that capability flag in our new __array_namespace_info__, so MLX is conformant today by declining these ops. The question is whether, and how, to offer them.

3. How other frameworks handle data-dependent shapes
3.1 The array API standard — make it optional + advertise a capability
For boolean-mask filtering, "the output array shape is data-dependent"; the spec makes boolean-array indexing optional and names compute-graph libraries (JAX, Dask) as the ones that legitimately omit it ([array-api indexing], [array-api design topic]). → Precedent: declining is conformant; the decision is about ergonomics, not standards compliance.
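A library-agnostic caller can read that flag and fall back to a size-bounded path. A minimal sketch, assuming only the standard's inspection API (__array_namespace_info__().capabilities()); the fake namespaces are stand-ins for testing, not real libraries, and treating a missing inspection API as "unsupported" is this sketch's conservative choice rather than something the standard prescribes:

```python
from types import SimpleNamespace

def supports_data_dependent_shapes(xp) -> bool:
    """Return True if array namespace `xp` advertises data-dependent shapes."""
    info = getattr(xp, "__array_namespace_info__", None)
    if info is None:
        # No inspection API at all: assume unsupported (conservative choice).
        return False
    return bool(info().capabilities().get("data-dependent shapes", False))

def choose_nonzero_path(xp) -> str:
    """'exact' if nonzero may return a value-dependent length, else 'size-bounded'."""
    return "exact" if supports_data_dependent_shapes(xp) else "size-bounded"

# Stand-in namespaces for illustration only.
class _FakeInfo:
    def __init__(self, flag: bool):
        self._flag = flag

    def capabilities(self) -> dict:
        return {"boolean indexing": self._flag, "data-dependent shapes": self._flag}

eager_like = SimpleNamespace(__array_namespace_info__=lambda: _FakeInfo(True))
graph_like = SimpleNamespace(__array_namespace_info__=lambda: _FakeInfo(False))
legacy = SimpleNamespace()  # predates the inspection API
```

With MLX reporting False, such a caller would route to the size-bounded path, which is exactly the gap the rest of this proposal tries to fill inside MLX itself.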
3.2 JAX / XLA — static size= argument + fill_value padding (the pragmatic winner)
- jnp.unique is "not by default compatible with jit()" because its output size is data-dependent ([jax unique]).
- It accepts a size argument that fixes the output length at trace time; without it, under jit you get a concrete-value / abstract-tracer error ([jax unique], [jax gotchas]).
- If size exceeds the true count, the tail is padded with fill_value (default: the minimum unique value), so the buffer always matches the statically declared shape ([jax unique]).
- jnp.nonzero works the same way ([jax nonzero]).
- jax2tf shape polymorphism explicitly cannot handle output shapes that depend on input values, only symbolic expressions of input dimensions ([jax2tf README]).

→ This pattern maps onto MLX with zero core changes: size is an argument, so output_shapes(inputs) can return [{size}], which is known at build time. (More in §5.)

3.3 XLA / StableHLO — bounded dynamic shapes (the general fix)
- XLA supports bounded dynamic dimensions, written like s32[<=4]. The HLO primitive is set-dimension-size, which takes a static-shape array plus a scalar size and yields the bounded-dynamic dimension ([jax #26265]).
- StableHLO's dynamism document covers data-dependent shapes (e.g. nonzeros) and recommends modeling them via bounded dynamism: specify an upper bound, and hardware implements it via tensor padding ([stablehlo dynamism]).

→ This is the "real" answer (true dynamic dimensions), but it is a deep core change: shape inference, allocation, compile specialization, and Metal kernels would all have to learn about bounded dimensions.
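The bounded-dynamic representation amounts to a static buffer of length N (what shape inference and allocation see) plus a runtime scalar n ≤ N (what set-dimension-size attaches). A toy model, assuming a 1-D array; BoundedVector, set_dimension_size, and nonzero_bounded are hypothetical names for illustration, not XLA or MLX APIs:

```python
from dataclasses import dataclass
from typing import Sequence, Tuple

@dataclass(frozen=True)
class BoundedVector:
    """1-D value with a static bound and a runtime size, like s32[<=N].

    `buffer` always has length `bound` (the static shape); only the first
    `size` entries are logically valid, the rest are padding."""
    buffer: Tuple[int, ...]
    size: int

    @property
    def bound(self) -> int:
        return len(self.buffer)

    def valid(self) -> Tuple[int, ...]:
        return self.buffer[: self.size]

def set_dimension_size(buffer: Sequence[int], size: int) -> BoundedVector:
    """Toy analogue of set-dimension-size: static buffer plus scalar size."""
    if not 0 <= size <= len(buffer):
        raise ValueError("size must lie in [0, bound]")
    return BoundedVector(tuple(buffer), size)

def nonzero_bounded(x: Sequence[float]) -> BoundedVector:
    """1-D nonzero whose bound is len(x): the result can never be longer."""
    idx = [i for i, v in enumerate(x) if v != 0]
    padded = idx + [0] * (len(x) - len(idx))  # pad up to the static bound
    return set_dimension_size(padded, len(idx))

nz = nonzero_bounded([0, 7, 0, 4])
# nz.bound == 4 follows from the input shape alone; nz.size == 2 is known only after the data is seen.
```

The contrast with a size= argument is who owns the count: with size= the caller fixes it and the padding is visible in the result, whereas here the runtime count travels with the value, so every consumer of the array has to understand it.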
3.4 PyTorch — eager natively; torch.compile/export via unbacked SymInts
- Under torch.compile/export, output sizes that cannot be known at compile time are represented as unbacked SymInts: symbolic ints with no concrete value or hint, introduced for nonzero() / item() ([pt backed-unbacked], [pt guardon]).
- Guarding on such a value raises GuardOnDataDependentSymNode, which is addressed with torch._check(...) / mark_unbacked ([pt guardon]).
- PyTorch/XLA supports nonzero / masked_select / masked_scatter behind XLA_EXPERIMENTAL=... and uses bounded dynamic shapes: torch.nonzero() returns torch.Size([<=25, 2]) ([pt/xla dynamic_shape], [pt/xla #3884]).

→ This confirms the universal split: eager execution is trivial; ahead-of-time or graph execution needs symbolic or bounded dimensions. MLX's lazy graph sits on the graph side.
3.5 GPU implementation reality — stream compaction
The actual GPU kernel for unique / nonzero / compress is stream compaction: a prefix sum (scan) over a predicate mask to compute output offsets, followed by a scatter. NVIDIA CUB exposes exactly this as cub::DeviceSelect::Unique / Flagged, which writes a device-side d_num_selected_out count ([cub DeviceSelect]); the count must be read back to the host to know the final size ([stream-compaction blog], [arxiv 2311.02103]). On Metal this is a scan + scatter over the sorted array (for unique) or over the mask (for nonzero). That is implementable with MLX's existing scan/scatter machinery, but the output size still has to come back to the host unless it is fixed by a size argument.

4. MLX-specific constraints any solution must satisfy
- compile: compile() specializes on shapes (it has a shapeless path); a data-dependent dimension defeats specialization unless it is represented symbolically.
- vmap / jvp / vjp: every primitive must provide them (or throw). Batched data-dependent sizes differ per row; only the size=-bounded form vmaps cleanly (uniform size).
- export: serializes shapes.
- Kernels (eval_gpu): write into a pre-sized buffer; a count-then-compact op needs either a host sync to read the count, or a fixed size.

5. Options for MLX
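- Option A: static size= + fill_value (JAX-style). The output shape is f(arg), known at build time.
- Option B: eager eval (sync to the host and return a concrete-shape array).
- Option C: bounded dynamic dimensions (<=N, StableHLO/XLA-style).

Why A is the right first step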
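The compositions listed below can be checked in plain Python before writing any MLX code. This sketch assumes 1-D inputs: it computes output offsets with an exclusive prefix sum over a keep-mask, then scatters into a buffer of length size pre-filled with fill_value, i.e. the stream compaction from §3.5 with the count pinned by the caller. unique_values and nonzero here are pure-Python stand-ins for the proposed (not yet existing) MLX ops; defaulting unique's fill_value to the minimum follows JAX's jnp.unique, while truncating entries past size and nonzero's default of 0 are this sketch's choices.

```python
from typing import List, Sequence, Tuple

def exclusive_cumsum(flags: Sequence[int]) -> List[int]:
    """offsets[i] = flags[0] + ... + flags[i-1] (the scan step)."""
    offsets, running = [], 0
    for f in flags:
        offsets.append(running)
        running += f
    return offsets

def compact(values: Sequence, keep: Sequence[int], size: int, fill_value) -> Tuple[list, int]:
    """Scatter kept values to their scan offsets in a fixed-size buffer.

    The buffer length is `size` regardless of the data, so the output shape
    is known up front. Also returns the true count, i.e. the number a GPU
    kernel would have to read back to the host if size were not fixed."""
    offsets = exclusive_cumsum(keep)
    out = [fill_value] * size
    for v, k, off in zip(values, keep, offsets):
        if k and off < size:  # kept entries past the buffer are truncated
            out[off] = v
    return out, sum(keep)

def unique_values(x: Sequence, *, size: int, fill_value=None) -> list:
    """sort -> adjacent-diff mask -> scan offsets -> scatter first occurrences."""
    s = sorted(x)
    keep = [1 if i == 0 or s[i] != s[i - 1] else 0 for i in range(len(s))]
    if fill_value is None:
        fill_value = s[0] if s else 0  # minimum value, as jnp.unique defaults
    out, _ = compact(s, keep, size, fill_value)
    return out

def nonzero(x: Sequence, *, size: int, fill_value: int = 0) -> list:
    """mask -> scan offsets -> scatter indices (1-D only)."""
    keep = [1 if v != 0 else 0 for v in x]
    out, _ = compact(list(range(len(x))), keep, size, fill_value)
    return out

unique_values([3, 1, 3, 2, 1], size=5, fill_value=-1)  # [1, 2, 3, -1, -1]
unique_values([3, 1, 3, 2, 1], size=2)                 # [1, 2] (truncated)
nonzero([0, 7, 0, 4], size=3)                          # [1, 3, 0]
```

Every call with the same size returns a list of exactly size elements, so mapping nonzero over a batch of rows (for example [[1, 0, 1], [0, 0, 0]] with size=2 gives [[0, 2], [0, 0]]) yields a rectangular result, which is why the size-bounded form vmaps cleanly.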
- size is a function argument, so output_shapes returns [{size, ...}]. This is fully compatible with today's build-time shape model, with no change to the array/shape/eval core.
- unique / nonzero / compress become compositions of existing primitives once size is fixed:
  - unique_values(x, size, fill_value) = sort(x) → adjacent-diff mask → cumsum offsets → scatter first occurrences into a size-length buffer pre-filled with fill_value (clamp/truncate to size).
  - unique_counts / unique_inverse / unique_all add the inverse map and run-length counts (also static once size is fixed).
  - nonzero(x, size, fill_value) = mask → cumsum offsets → scatter indices.
- … size, but a size-extended superset is the standard escape hatch every static-shape backend uses).
- It works under compile, vmap (uniform size), and export, and even has a sensible vjp (gather/scatter are differentiable).

What A does not give
True no-size unique(x) / x[mask] returning exactly N elements. That genuinely needs Option C. But A unblocks nearly all real use cases (you almost always have an upper bound) and gives MLX a conformant, documented answer instead of "not supported."

6. Recommended phased plan
- Phase 1: size=-bounded ops (no core change). Add unique_values / unique_counts / unique_inverse / unique_all(x, *, size, fill_value=...), nonzero(x, *, size, fill_value=...), and optionally compress. These are pure compositions of sort + cumsum + scatter and work on CPU and Metal today. Document the size / fill_value contract (copy JAX's wording).
- Phase 2 (optional): an eager convenience (e.g. mx.unique(x) with no size that internally evals and returns a concrete-shape array) for notebooks and the REPL, explicitly marked as breaking laziness and unavailable under compile / vmap / grad. This mirrors PyTorch eager and the maintainer's current "convert to NumPy" advice, but in-framework. (Include only if the team is comfortable with one eager op; otherwise skip.)
- Phase 3: a <=N bounded-dimension concept (à la XLA set-dimension-size / StableHLO bounded dynamism / PyTorch-XLA), letting the no-size forms and a[mask] return bounded-dynamic outputs. This is the largest change, and it is the "consistent strategy" the maintainer referenced and theirs to own. Phases 1–2 deliver value immediately, and Phase 1's ops become the size-pinned fast path under it.

7. Concrete ask for the MLX team
Is the size= / fill_value API (Phase 1) acceptable as the sanctioned pattern for this op class? If so, I can open a PR for unique_* + nonzero built on the existing sort / cumsum / scatter (CPU + Metal, with tests).

Appendix — sources (all primary unless noted)
- MLX: docs (… usage/indexing.html); issues #856 "Need for implementing .unique() for arrays" (unique), #865 "requirement for implementing boolean indices for mx.array" (boolean indices), #246 "boolean mask or filter?" (boolean mask/filter).
- Array API standard: design_topics/data_dependent_output_shapes.html; API_specification/indexing.html; 2024.12 … info.capabilities.
- JAX: jax.numpy.unique, jax.numpy.nonzero, Common_Gotchas_in_JAX, jax2tf/README.md, jax issue #26265 (bounded shapes / set-dimension-size).
- PyTorch: torch.compiler_dynamic_shapes, dynamic_shapes_backed_unbacked, dynamic_shapes_troubleshooting_guardon_errors; PyTorch/XLA dynamic_shape + issue #3884.
- StableHLO: openxla.org/stablehlo/dynamism.
- … ShapeInference.html.
- Stream compaction: CUB DeviceSelect; arXiv 2311.02103; "stream compaction using wave intrinsics" (blog).