This shows up on CUDA naive QMM when all of the following are true:

- `transpose=False`
- mode is `mxfp4`, `mxfp8`, or `nvfp4`
- `K % group_size == 0`
- `K % max(64, group_size) != 0`

Examples:

- mxfp4 / mxfp8: `K = 544, 608`
- nvfp4: `K = 528, 544, 560, 592, 608, 624`

## Reproduce

`mx.gather_qmm` is the easiest way to observe it.
import mlx.core as mx
def gather_sort(x, indices):
    """Sort the flattened expert indices and permute x to match.

    Returns the rows of x (flattened over all but the last two axes)
    reordered by ascending expert index, the sorted flat indices, and
    the permutation that undoes the sort.
    """
    per_token = indices.shape[1]
    flat_idx = indices.flatten()
    sort_order = mx.argsort(flat_idx)
    undo_order = mx.argsort(sort_order)
    x_rows = x.flatten(0, -3)
    return x_rows[sort_order // per_token], flat_idx[sort_order], undo_order
def scatter_unsort(x, inv_order, shape):
    """Invert gather_sort's permutation and restore the leading shape."""
    restored = x[inv_order]
    return mx.unflatten(restored, 0, shape)
# Run on the GPU backend, where the CUDA naive QMM kernel is exercised.
mx.set_default_device(mx.gpu)
# L tokens, E experts, I experts per token. With transpose=False the
# reduction dim K becomes D0 = 544 — per the trigger condition above,
# divisible by the quantization group size but not by 64.
L, K0, D0, E, I = 64, 512, 544, 4, 2
transpose = False
mode = "mxfp4"
dtype = mx.bfloat16
K, D = (K0, D0) if transpose else (D0, K0)
key = mx.random.key(0)
k1, k2, k3 = mx.random.split(key, 3)
# Random expert assignment for each token.
indices = (mx.random.uniform(shape=(L, I), key=k1) * E).astype(mx.uint32)
x = (mx.random.normal((L, 1, 1, K), key=k2) / K**0.5).astype(dtype)
w = (mx.random.normal((E, K, D), key=k3) / K**0.5).astype(dtype)
# Quantize the expert weights in the requested low-precision mode.
wq, scales = mx.quantize(w, mode=mode)
# Reference path: gather_qmm with unsorted indices.
y_unsorted = mx.gather_qmm(
x,
wq,
scales,
None,
mode=mode,
transpose=transpose,
rhs_indices=indices,
)
# Sorted path: identical computation with sorted_indices=True,
# then the permutation is undone so the two results are comparable.
xs, rhs_sorted, inv_order = gather_sort(x, indices)
y_sorted = mx.gather_qmm(
xs,
wq,
scales,
None,
mode=mode,
transpose=transpose,
rhs_indices=rhs_sorted,
sorted_indices=True,
)
y_sorted = scatter_unsort(y_sorted, inv_order, indices.shape)
mx.eval(y_unsorted, y_sorted)
mx.synchronize()
# Both paths compute the same matmul, so the difference should be ~0.
print("max_diff =", mx.max(mx.abs(y_unsorted - y_sorted)).item())
output:
max_diff = 0.059814453125
expected:
```
max_diff ≈ 0
```
(The sorted and unsorted `gather_qmm` paths compute the same matmul, so the results should match.)