
segmented_mm should validate inner matrix dimensions #3571

@fallintoplace


Summary

segmented_mm accepts two 2D inputs documented as MxK and KxN, but it does not currently validate that the inner dimensions match before constructing the primitive.

Neighboring matmul-like APIs such as matmul, addmm, block_masked_mm, and gather_mm reject incompatible matrix dimensions at the API boundary.

Example

import mlx.core as mx

a = mx.ones((2, 3))   # M x K with K = 3
b = mx.ones((4, 5))   # K x N with K = 4 (mismatch)
segments = mx.array([[0, 3]], dtype=mx.uint32)
mx.segmented_mm(a, b, segments)

This should raise a ValueError because the inner dimensions do not match: a has 3 columns but b has 4 rows (3 != 4).
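A minimal sketch of the requested check follows. It is written in plain Python over shape tuples so it runs without MLX. `check_segmented_mm_shapes` is a hypothetical helper, not part of MLX, and the error message wording is an assumption. The real validation would live wherever `segmented_mm` constructs its primitive. The sketch also rejects non-2D shapes, since the inputs are documented as 2D.

```python
from typing import Sequence, Tuple


def check_segmented_mm_shapes(
    a_shape: Sequence[int], b_shape: Sequence[int]
) -> Tuple[int, int, int]:
    """Validate that a (M x K) and b (K x N) have compatible inner dimensions.

    Hypothetical helper illustrating the check proposed in this issue;
    it is not MLX's implementation. Returns (M, K, N) on success.
    """
    # The inputs are documented as 2D, so reject anything else up front.
    if len(a_shape) != 2 or len(b_shape) != 2:
        raise ValueError(
            f"[segmented_mm] Expected 2D inputs but got a.ndim={len(a_shape)} "
            f"and b.ndim={len(b_shape)}."
        )
    m, k_a = a_shape
    k_b, n = b_shape
    # Inner dimensions must agree, as matmul already requires.
    if k_a != k_b:
        raise ValueError(
            f"[segmented_mm] Last dimension of a ({k_a}) does not match "
            f"first dimension of b ({k_b})."
        )
    return m, k_a, n


# Shapes from the example above: K = 3 for a, but K = 4 for b.
try:
    check_segmented_mm_shapes((2, 3), (4, 5))
except ValueError as err:
    print(err)
```

Raising before any primitive is built keeps the failure at the API boundary, matching how the neighboring matmul-like ops behave.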

Impact

Rejecting the mismatch early avoids constructing a segmented matmul with incompatible input shapes and makes segmented_mm consistent with the rest of the matmul API surface.
