[BUG] qmm_naive picks the wrong K-tail path for some FP quantized shapes #3444

@Lyxot

This shows up in the CUDA naive QMM kernel (qmm_naive) when all of the following are true:

  • transpose=False
  • mode is mxfp4, mxfp8, or nvfp4
  • K % group_size == 0
  • K % max(64, group_size) != 0

Examples:

  • mxfp4 / mxfp8: K = 544, 608
  • nvfp4: K = 528, 544, 560, 592, 608, 624
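
The conditions above can be checked without a GPU. The following is a minimal sketch, not part of MLX: `hits_bad_k_tail` and `affected_k` are hypothetical helper names, and the group sizes (32 for mxfp4 and mxfp8, 16 for nvfp4) are assumed to be the defaults for those modes. Under those assumptions, scanning K over 513..639 (a range picked only for illustration) gives the same values as the examples listed.

```python
# Assumed default group sizes per mode (an assumption, not stated in this issue).
GROUP_SIZE = {"mxfp4": 32, "mxfp8": 32, "nvfp4": 16}


def hits_bad_k_tail(k, mode, transpose=False):
    """Return True if (k, mode, transpose) matches the conditions listed above."""
    if transpose or mode not in GROUP_SIZE:
        return False
    group_size = GROUP_SIZE[mode]
    # K is a whole number of quantization groups, but not a whole number
    # of max(64, group_size)-wide steps.
    return k % group_size == 0 and k % max(64, group_size) != 0


def affected_k(mode, start, stop):
    """List every K in [start, stop) that matches the conditions for `mode`."""
    return [k for k in range(start, stop) if hits_bad_k_tail(k, mode)]


if __name__ == "__main__":
    for mode in GROUP_SIZE:
        print(mode, affected_k(mode, 513, 640))
```

A helper like this could be used to pick the K values for a regression test that sweeps shapes on both sides of the condition.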

Reproduce

mx.gather_qmm is the easiest way to observe it.

import mlx.core as mx

def gather_sort(x, indices):
    n, m = indices.shape
    flat = indices.flatten()
    order = mx.argsort(flat)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // m], flat[order], inv_order


def scatter_unsort(x, inv_order, shape):
    return mx.unflatten(x[inv_order], 0, shape)


mx.set_default_device(mx.gpu)

L, K0, D0, E, I = 64, 512, 544, 4, 2
transpose = False
mode = "mxfp4"
dtype = mx.bfloat16

K, D = (K0, D0) if transpose else (D0, K0)
key = mx.random.key(0)
k1, k2, k3 = mx.random.split(key, 3)

indices = (mx.random.uniform(shape=(L, I), key=k1) * E).astype(mx.uint32)
x = (mx.random.normal((L, 1, 1, K), key=k2) / K**0.5).astype(dtype)
w = (mx.random.normal((E, K, D), key=k3) / K**0.5).astype(dtype)
wq, scales = mx.quantize(w, mode=mode)

y_unsorted = mx.gather_qmm(
    x,
    wq,
    scales,
    None,
    mode=mode,
    transpose=transpose,
    rhs_indices=indices,
)

xs, rhs_sorted, inv_order = gather_sort(x, indices)
y_sorted = mx.gather_qmm(
    xs,
    wq,
    scales,
    None,
    mode=mode,
    transpose=transpose,
    rhs_indices=rhs_sorted,
    sorted_indices=True,
)
y_sorted = scatter_unsort(y_sorted, inv_order, indices.shape)

mx.eval(y_unsorted, y_sorted)
mx.synchronize()

print("max_diff =", mx.max(mx.abs(y_unsorted - y_sorted)).item())

Actual output:

max_diff = 0.059814453125

Expected output:

max_diff = 0.0
