[Quantization] Add metal quantization for MPS devices! #43934
Conversation
ArthurZucker
left a comment
Super nice, missing some tests tho!!
SunMarc
left a comment
Thanks! Maybe we can change the name to Metal instead of Mlx, as it can create confusion? In the future, we might have mlx if we add compatibility with mlx models. Please add some e2e tests, plus tests to check that we have the right dtype after quantization and dequantization.
```python
orig_dtype = value.dtype  # e.g. bfloat16 for Llama
return {
    target_key: w_packed,
    scale_key: scales.to(orig_dtype),
    bias_key: biases.to(orig_dtype),
}
```
Fine, but I think `_affine_quantize_tensor` should return them in the right dtype already.
Not sure about this: when we quantize, we keep the scales' dtype the same as the weight dtype before quantization, which is float32.
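To make the dtype contract under discussion concrete, here is a minimal NumPy sketch of per-tensor affine quantization that casts the scales and biases back to the weight's original dtype before returning them (the reviewer's suggestion). The function names and the float16 stand-in for bfloat16 are illustrative assumptions, not the PR's actual `_affine_quantize_tensor` implementation.

```python
import numpy as np

def affine_quantize_tensor(w, n_bits=4):
    # Hypothetical sketch; the PR's real _affine_quantize_tensor may differ.
    orig_dtype = w.dtype
    qmin, qmax = 0, (1 << n_bits) - 1
    w32 = w.astype(np.float32)                 # compute scale/bias in float32 for accuracy
    lo, hi = w32.min(), w32.max()
    scale = (hi - lo) / (qmax - qmin) or np.float32(1.0)
    bias = lo
    q = np.clip(np.round((w32 - bias) / scale), qmin, qmax).astype(np.uint8)
    # Cast scale/bias back to the weight's original dtype before returning,
    # so callers do not have to do scales.to(orig_dtype) themselves.
    return q, np.asarray(scale, dtype=orig_dtype), np.asarray(bias, dtype=orig_dtype)

def affine_dequantize_tensor(q, scale, bias):
    # Dequantize in float32, then return in the scale's (= original weight) dtype.
    out = q.astype(np.float32) * np.float32(scale) + np.float32(bias)
    return out.astype(scale.dtype)

w = np.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=np.float16)  # float16 stands in for bfloat16
q, s, b = affine_quantize_tensor(w)
w_hat = affine_dequantize_tensor(q, s, b)
assert s.dtype == w.dtype and b.dtype == w.dtype and w_hat.dtype == w.dtype
```

A dtype round-trip check like the final assertion is also the shape of test the review asks for: quantize, dequantize, and verify the output dtype matches the input weight's dtype.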
Force-pushed from 212b192 to 6ac192b
SunMarc
left a comment
Thanks a lot! Can you just update the overview docs to add this quantization method?
[For maintainers] Suggested jobs to run (before merge): run-slow: metal
…3934)

* first commit
* style
* fix
* fix
* mlx -> metal
* other fixes
* add tests
* fixes
* weight -> qweight
* fix
* tests
* fix style
* fix
* toctree
* some docs
* qweight -> weight
* fix dtype
* rm print
* overview

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
What does this PR do?
Adds Metal quantization for MPS devices, leveraging the `kernels` library for pre-built kernels!