
fix: propagate quantization mode in QuantizedAllToShardedLinear / QuantizedShardedToAllLinear#3133

Merged
angeloskath merged 1 commit into ml-explore:main from vskiwi:fix-quantized-sharded-mode-propagation
Feb 16, 2026

Conversation

Contributor

@vskiwi vskiwi commented Feb 15, 2026

Summary

Fixes #3132

QuantizedAllToShardedLinear and QuantizedShardedToAllLinear in mlx/nn/layers/distributed.py do not accept, store, or pass the mode parameter to mx.quantized_matmul. When an MXFP8-quantized QuantizedLinear is converted via shard_linear(), the mode is silently lost. The resulting sharded layer calls quantized_matmul without mode=, which defaults to "affine" — interpreting FP8 packed weights as affine int8, producing garbage output with no error.

Additionally, MXFP8 does not use biases, but both classes unconditionally accessed self["biases"], which would raise ValueError once the mode fix is applied.
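For reference, a minimal illustration of that asymmetry (assuming the `mode` keyword on `mx.quantize` described in this PR; shapes and group sizes are arbitrary):

```python
import mlx.core as mx

w = mx.random.normal((512, 256))

# Affine quantization returns (w_q, scales, biases)...
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode="affine")

# ...but mxfp8 returns only (w_q, scales), so an unconditional self["biases"]
# lookup breaks as soon as the mode is actually forwarded.
packed = mx.quantize(w, group_size=32, bits=8, mode="mxfp8")
print(len(packed))  # 2
```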

Changes

  • Add mode parameter (default "affine") to both __init__ methods
  • Store self.mode and pass it to mx.quantize and mx.quantized_matmul (see the sketch after this list)
  • Use *biases unpacking to handle modes that don't produce biases (mxfp8, mxfp4)
  • Use self.get("biases") instead of self["biases"] for safe access (consistent with QuantizedLinear)
  • Propagate mode from source layer in from_quantized_linear
  • Include mode in _extra_repr output
  • Add distributed test for mxfp8 quantized shard_linear
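
A minimal sketch of the resulting local forward pass shared by both sharded layers. The distributed collectives are omitted and the helper name `_sharded_quantized_forward` is illustrative, not taken from the diff:

```python
import mlx.core as mx

def _sharded_quantized_forward(layer, x):
    # Local shard matmul with the quantization mode forwarded; previously the
    # mode keyword was omitted, so non-affine weights were decoded as affine int8.
    out = mx.quantized_matmul(
        x,
        layer["weight"],
        layer["scales"],
        layer.get("biases"),   # None for mxfp8 / mxfp4, present for affine
        transpose=True,
        group_size=layer.group_size,
        bits=layer.bits,
        mode=layer.mode,
    )
    if "bias" in layer:        # the (unquantized) output bias, if any
        out = out + layer["bias"]
    return out
```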

Impact

This unblocks tensor parallel inference for all MXFP8-quantized models (and likely mxfp4). Confirmed working: GLM-5 754B (mlx-community/GLM-5-8bit-MXFP8, mode=mxfp8, group_size=32, bits=8) on 2× M3 Ultra 512GB at ~14 tok/s with tensor parallel.

No changes to the affine (default) code path — full backward compatibility.

Test plan

  • Existing test_shard_linear test for affine quantization is unchanged and should still pass
  • New mxfp8 test in test_shard_linear verifies mode propagation, biases=None, and output correctness (sketched below)
  • Formatted with black
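
The new check roughly takes the following shape. This is only a sketch: the QuantizedLinear `mode` keyword and the `from_quantized_linear(..., group)` signature used here are assumptions, not the exact test code:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import QuantizedAllToShardedLinear

group = mx.distributed.init()

# Hypothetical mxfp8-quantized source layer (keyword names assumed).
qlin = nn.QuantizedLinear(256, 512, bias=False, group_size=32, bits=8, mode="mxfp8")

sharded = QuantizedAllToShardedLinear.from_quantized_linear(qlin, group)

assert sharded.mode == "mxfp8"        # mode is propagated from the source layer
assert sharded.get("biases") is None  # mxfp8 produces no biases
# The real test additionally compares the (gathered) sharded output against
# the unsharded QuantizedLinear's output.
```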

Made with Cursor

QuantizedAllToShardedLinear and QuantizedShardedToAllLinear did not
accept, store, or forward the `mode` parameter to `mx.quantized_matmul`.
When a non-affine QuantizedLinear (e.g. mode="mxfp8") was converted via
`shard_linear()`, the mode was silently lost and `quantized_matmul`
defaulted to "affine", producing garbage output with no error.

Additionally, MXFP8 does not use biases, but both classes
unconditionally accessed `self["biases"]` which would fail once the mode
fix was applied because `mx.quantize` does not return biases for mxfp8.

Changes:
- Add `mode` parameter (default "affine") to both __init__ methods
- Store `self.mode` and pass it to `mx.quantize` and `mx.quantized_matmul`
- Use `*biases` unpacking to handle modes that don't produce biases
- Use `self.get("biases")` instead of `self["biases"]` for safe access
- Propagate mode from source layer in `from_quantized_linear`
- Include mode in `_extra_repr` output
- Add distributed test for mxfp8 quantized shard_linear

Fixes ml-explore#3132

Co-authored-by: Cursor <cursoragent@cursor.com>
Member

@angeloskath angeloskath left a comment


Thank you that looks great!

I'll merge after the tests pass.

@angeloskath angeloskath merged commit e226af7 into ml-explore:main Feb 16, 2026
16 checks passed
@vskiwi vskiwi deleted the fix-quantized-sharded-mode-propagation branch February 16, 2026 12:11


Development

Successfully merging this pull request may close these issues.

[BUG] QuantizedAllToShardedLinear / QuantizedShardedToAllLinear don't propagate quantization mode — silent garbage output for MXFP8 tensor parallel
