Summary
segmented_mm accepts two 2D inputs documented as MxK and KxN, but it does not currently validate that the inner dimensions match before constructing the primitive.
Neighboring matmul-like APIs such as matmul, addmm, block_masked_mm, and gather_mm reject incompatible matrix dimensions at the API boundary.
Example
import mlx.core as mx
a = mx.ones((2, 3))  # M x K, with K = 3
b = mx.ones((4, 5))  # K x N, but here K = 4
segments = mx.array([[0, 3]], dtype=mx.uint32)
mx.segmented_mm(a, b, segments)
This should raise a ValueError because the inner dimensions do not match: a.shape[1] is 3 and b.shape[0] is 4.
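A minimal sketch of the kind of boundary check being requested, written in plain Python over shape tuples so it runs without MLX. The helper name check_segmented_mm_shapes and the error wording are hypothetical. It assumes strictly 2D inputs laid out as MxK and KxN, as the documentation describes, and it is not MLX's actual implementation (the real check would live wherever segmented_mm builds its primitive).

```python
def check_segmented_mm_shapes(a_shape, b_shape):
    """Hypothetical pre-check for segmented_mm(a, b, segments).

    Assumes a is (M, K) and b is (K, N), matching the documented shapes,
    and raises ValueError before any primitive would be constructed.
    """
    # Assumption: only 2D inputs are accepted, per the documented MxK / KxN.
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError(
            f"segmented_mm expects 2D inputs, got shapes "
            f"{tuple(a_shape)} and {tuple(b_shape)}"
        )
    # Inner dimension: columns of a (K) must equal rows of b (K).
    k_a, k_b = a_shape[1], b_shape[0]
    if k_a != k_b:
        raise ValueError(
            f"segmented_mm inner dimensions must match, got "
            f"a{tuple(a_shape)} and b{tuple(b_shape)} ({k_a} != {k_b})"
        )


# Compatible shapes pass silently.
check_segmented_mm_shapes((2, 3), (3, 5))

# The shapes from the example above are rejected.
try:
    check_segmented_mm_shapes((2, 3), (4, 5))
except ValueError as err:
    print(err)
```

With the example's shapes, the helper raises ValueError and the message ends in "(3 != 4)", which is the behavior matmul-style APIs already show for mismatched inner dimensions.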
Impact
Rejecting the mismatch early avoids constructing a segmented matmul with incompatible input shapes and makes segmented_mm consistent with the rest of the matmul API surface.