Enable MetalConfig to load pre-quantized MLX models from HuggingFace Hub #44348

Open
n0kovo wants to merge 4 commits into huggingface:main from n0kovo:metal-mlx-pretrained

Conversation


@n0kovo n0kovo commented Feb 28, 2026

Summary

Most quantized models for Apple Silicon on the Hub are in MLX format. The MetalConfig quantization backend supports on-the-fly quantization of standard checkpoints but cannot load pre-quantized MLX models. This PR fixes the five issues blocking that:

  • quantizers/auto.py: Detect MLX affine quantization configs (mode=affine + bits + group_size) in AutoQuantizationConfig.from_dict() and map them to the Metal quantization method
  • utils/hub.py: Handle stale shard index files — MLX repos often copy model.safetensors.index.json from the original model referencing non-existent shards. Adds _rebuild_shard_index_from_repo() fallback that discovers actual safetensors files via HfApi and rebuilds the weight_map from their headers
  • quantizers/quantizer_metal.py: Weight conversions for pre-quantized loading (MLX biases → qbiases rename, embed_tokens dequantization, skip auto-exclusion of lm_head for pre-quantized checkpoints)
  • conversion_mapping.py: Qwen3VL key prefix mappings for MLX checkpoint format
  • integrations/metal_quantization.py: Make MetalDequantize flexible with source patterns, add _load_from_state_dict fallback for biases→qbiases rename, and add locally-compiled Metal shader fallback via torch.mps.compile_shader when the Hub kernel is unavailable or targets an incompatible MSL version

Usage

from transformers import MetalConfig, Qwen3VLForConditionalGeneration

model = Qwen3VLForConditionalGeneration.from_pretrained(
    "lmstudio-community/Qwen3-VL-8B-Instruct-MLX-8bit",
    device_map="mps",
    quantization_config=MetalConfig(bits=8, group_size=64),
)

Test plan

  • Model loads successfully with all weight conversions (253 uint32 MetalLinear layers, embed_tokens dequantized to bfloat16, lm_head as MetalLinear)
  • Metal shader fallback compiles and produces correct results (max abs diff < 0.00002 vs reference dequantize+matmul)
  • Basic text generation works
  • CI tests for existing Metal quantization functionality (on-the-fly quantization, dequantize mode)

Most quantized models for Apple Silicon on the Hub are in MLX format.
The MetalConfig quantization backend can quantize standard checkpoints
on-the-fly but cannot load pre-quantized MLX models. This commit fixes
the five issues blocking that:

1. auto.py: Detect MLX affine quantization configs (mode=affine + bits
   + group_size) in AutoQuantizationConfig.from_dict() and map them to
   the Metal quantization method.
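
The detection described in item 1 can be sketched in plain Python. This is a hypothetical standalone helper (the real logic lives inside AutoQuantizationConfig.from_dict()); the function name and return convention here are illustrative assumptions:

```python
# Hypothetical sketch: an MLX-style affine quantization config is recognized
# by mode == "affine" plus explicit bits and group_size, and mapped to the
# "metal" quantization method; otherwise fall back to any declared method.
def detect_quant_method(config: dict):
    if (
        config.get("mode") == "affine"
        and "bits" in config
        and "group_size" in config
    ):
        return "metal"
    return config.get("quant_method")
```

For example, an MLX config like {"mode": "affine", "bits": 8, "group_size": 64} would be routed to the Metal backend, while a config carrying an explicit quant_method is left untouched.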

2. hub.py: Handle stale shard index files. MLX repos often copy
   model.safetensors.index.json from the original model, referencing
   non-existent shards. Add _rebuild_shard_index_from_repo() fallback
   that discovers actual safetensors files via HfApi and rebuilds the
   weight_map from their headers.
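
The rebuild in item 2 relies on the safetensors on-disk layout: an 8-byte little-endian length prefix followed by a JSON header mapping tensor names to their metadata. A minimal local sketch of that idea (the real _rebuild_shard_index_from_repo() discovers shard files via HfApi; the helper names here are hypothetical):

```python
import json
import struct
from pathlib import Path


def read_safetensors_header(path):
    """Read the JSON header of a safetensors file: an 8-byte little-endian
    length prefix followed by that many bytes of JSON."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len))


def rebuild_weight_map(shard_paths):
    """Rebuild an index-style weight_map {tensor_name: shard_filename} from
    the shard files actually present, ignoring the optional metadata entry."""
    weight_map = {}
    for path in shard_paths:
        for name in read_safetensors_header(path):
            if name != "__metadata__":
                weight_map[name] = Path(path).name
    return weight_map
```

This recovers a valid weight_map even when the repo's model.safetensors.index.json references shards that do not exist.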

3. quantizer_metal.py: Add weight conversions for pre-quantized loading:
   - Rename MLX "biases" keys to "qbiases" (MetalLinear convention)
   - Dequantize embed_tokens back to float (nn.Embedding expects float)
   - Skip auto-exclusion of lm_head for pre-quantized checkpoints since
     MLX models typically quantize the output head too
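
The first conversion in item 3 is a pure key rename. A minimal sketch, assuming a flat state-dict view (the function name is hypothetical; the real rename happens during weight conversion):

```python
# Hypothetical sketch of the MLX -> MetalLinear key rename: every tensor key
# ending in ".biases" (MLX's affine zero-point tensors) becomes ".qbiases".
# Plain ".bias" keys are untouched, since ".bias" does not end with ".biases".
def rename_mlx_keys(state_dict: dict) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        if key.endswith(".biases"):
            key = key[: -len(".biases")] + ".qbiases"
        renamed[key] = value
    return renamed
```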

4. conversion_mapping.py: Add Qwen3VL key prefix mappings for MLX
   checkpoint format (language_model.model.* -> model.language_model.*,
   language_model.lm_head.* -> lm_head.*, vision_tower.* -> model.visual.*)
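
The prefix mappings in item 4 can be expressed as start-anchored regex rewrites. This is an illustrative sketch built directly from the three mappings listed above, not the actual conversion_mapping.py entry format:

```python
import re

# Hypothetical sketch of the Qwen3VL MLX-checkpoint key remapping.
# Each pattern is anchored at the start of the key; first match wins.
_MLX_KEY_MAP = [
    (re.compile(r"^language_model\.model\."), "model.language_model."),
    (re.compile(r"^language_model\.lm_head\."), "lm_head."),
    (re.compile(r"^vision_tower\."), "model.visual."),
]


def remap_mlx_key(key: str) -> str:
    for pattern, replacement in _MLX_KEY_MAP:
        new_key, n = pattern.subn(replacement, key)
        if n:
            return new_key
    return key  # keys with no MLX prefix pass through unchanged
```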

5. metal_quantization.py:
   - Make MetalDequantize use source_patterns from kwargs as dict keys
     for flexibility with pattern-specific converters
   - Add _load_from_state_dict fallback for biases->qbiases rename
   - Add locally-compiled Metal shader fallback via torch.mps.compile_shader
     when the Hub kernel is unavailable or targets an incompatible MSL version
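
The shader fallback itself needs an MPS device, but the computation it performs — and that the test plan compares against — is group-wise affine dequantization. A pure-Python reference of that math, assuming one scale and one bias per group of group_size elements along a row:

```python
# Reference for group-wise affine dequantization, the operation the Metal
# kernel (or the torch.mps.compile_shader fallback) computes:
#   w[i] = scales[g] * q[i] + biases[g],  where g = i // group_size.
def dequantize_affine(q_row, scales, biases, group_size):
    out = []
    for i, q in enumerate(q_row):
        g = i // group_size
        out.append(scales[g] * q + biases[g])
    return out
```

With group_size=2, scales=[2.0, 0.5], biases=[1.0, 0.0], the quantized row [0, 1, 2, 3] dequantizes to [1.0, 3.0, 1.0, 1.5].
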

The qwen3_vl entry in conversion_mapping.py broke CI tests because:
1. ruff format: formatting issues in quantizer_metal.py and hub.py
2. tests_torch: conversion_mapping entries must be bidirectional (for
   save/load round-trips), but the regex-anchored MLX key patterns
   aren't reversible

Fix by moving the MLX-specific key renamings into MetalHfQuantizer's
get_weight_conversions(), where they only apply during pre-quantized
loads and don't interfere with standard checkpoint tests.

The pre-built metallib targets MSL 4.0 (macOS 26) which is rejected
by the Metal runtime on macOS 15.x, printing "Failed to create Metal
library from embedded header" to stderr before raising. Redirect fd 2
to /dev/null during the smoke test to avoid noisy output.
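
The fd-2 redirect described in this commit can be sketched as a small context manager. This is an illustrative stdlib-only version (the name is hypothetical); it works at the file-descriptor level because the Metal runtime writes to stderr from C, below Python's sys.stderr:

```python
import contextlib
import os


@contextlib.contextmanager
def suppress_fd2():
    """Temporarily point file descriptor 2 at /dev/null, then restore it,
    so C-level stderr writes during the smoke test do not reach the terminal."""
    saved = os.dup(2)
    devnull = os.open(os.devnull, os.O_WRONLY)
    try:
        os.dup2(devnull, 2)
        yield
    finally:
        os.dup2(saved, 2)
        os.close(devnull)
        os.close(saved)
```
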

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: metal

@Rocketknight1
Member

cc @MekkCyber maybe? You might need to pull in someone else for MLX though

@n0kovo
Author

n0kovo commented Mar 2, 2026

cc @SunMarc

Member

@SunMarc SunMarc left a comment


Thanks for adding this @n0kovo ! We are not sure yet if this is how we want to run mlx models in transformers. We will keep you updated !

@n0kovo
Author

n0kovo commented Mar 2, 2026

Thanks for adding this @n0kovo ! We are not sure yet if this is how we want to run mlx models in transformers. We will keep you updated !

Thanks! Feel free to ping me if I can contribute in any way!
