
Fix save_pretrained for quantized models with custom serialization#43096

Open
480284856 wants to merge 2 commits into huggingface:main from 480284856:fix-save-pretrained-quantized-models

Conversation

@480284856

Problem

Calling save_pretrained() on an mxfp4 quantized model raises a NotImplementedError, because revert_weight_conversion() tries to reverse conversion operations that don't implement reverse_op.

Environment

  • Transformers Version: Commit a7f29523361b2cc12e51c1f5133d95f122f6f45c (main branch)
  • Python Version: 3.10.12
  • Model: openai/gpt-oss-20b (mxfp4 quantized)

Reproduction Code

from transformers import AutoModelForCausalLM

# Load the mxfp4 quantized model
model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")

# Attempt to save the model
model.save_pretrained("openai/gpt-oss-20b-hf")

Error Traceback

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 3
      1 from transformers import AutoModelForCausalLM
      2 model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
----> 3 model.save_pretrained("openai/gpt-oss-20b-hf")

File /workspace/transformers/src/transformers/modeling_utils.py:3237, in PreTrainedModel.save_pretrained(self, save_directory, is_main_process, state_dict, push_to_hub, max_shard_size, variant, token, save_peft_format, save_original_format, **kwargs)
   3235 # Revert all renaming and/or weight operations
   3236 if save_original_format:
-> 3237     state_dict = revert_weight_conversion(model_to_save, state_dict)

File /workspace/transformers/src/transformers/core_model_loading.py:1164, in revert_weight_conversion(model, state_dict)
   1161     return state_dict
   1163 # Reverse all Transform to correctly match keys
-> 1164 reverse_weight_conversion = [conversion.reverse_transform() for conversion in weight_conversions]

File /workspace/transformers/src/transformers/core_model_loading.py:1164, in <listcomp>(.0)
-> 1164 reverse_weight_conversion = [conversion.reverse_transform() for conversion in weight_conversions]

File /workspace/transformers/src/transformers/core_model_loading.py:531, in WeightTransform.reverse_transform(self)
   528 # Add the reverse ops if applicable (it needs to be provided at __init__)
   529 if hasattr(self, "operations"):
   530     # All reverse ops, in reverse order
-> 531     kwargs["operations"] = [op.reverse_op for op in self.operations[::-1]]

File /workspace/transformers/src/transformers/core_model_loading.py:531, in <listcomp>(.0)
-> 531     kwargs["operations"] = [op.reverse_op for op in self.operations[::-1]]

File /workspace/transformers/src/transformers/core_model_loading.py:102, in ConversionOps.reverse_op(self)
   100 @property
   101 def reverse_op(self) -> ConversionOps:
-> 102     raise NotImplementedError

NotImplementedError: 

Solution

Skip revert_weight_conversion() when the quantizer has already provided the state_dict via get_state_dict_and_metadata(), since quantizers handle their own serialization logic.

Changes

  • Added quantizer_provided_state_dict flag to track when quantizer provides state_dict
  • Skip revert_weight_conversion() when quantizer already provided serialized state_dict

Fixes the issue where mxfp4 quantized models cannot be saved due to missing reverse_op implementation.
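A minimal sketch of the proposed control flow (the flag name follows the PR description; the surrounding `save_pretrained` internals and helper signatures are simplified stand-ins, not the actual transformers code):

```python
# Simplified sketch of the save_pretrained control flow proposed in this PR.
# The real method lives in transformers' modeling_utils.py; the helpers below
# are illustrative stand-ins only.

def revert_weight_conversion(model, state_dict):
    # In transformers, this reverses renaming/weight operations. For mxfp4
    # models it fails because some ConversionOps don't implement reverse_op.
    raise NotImplementedError

def save_pretrained_sketch(model, hf_quantizer, state_dict, save_original_format=True):
    quantizer_provided_state_dict = False
    if hf_quantizer is not None:
        # Quantizers that handle their own serialization return a
        # ready-to-save state_dict (plus metadata, ignored here).
        state_dict, _metadata = hf_quantizer.get_state_dict_and_metadata(model)
        quantizer_provided_state_dict = True

    # The fix: skip the reversal when the quantizer already produced the
    # serialized state_dict, since it owns its serialization logic.
    if save_original_format and not quantizer_provided_state_dict:
        state_dict = revert_weight_conversion(model, state_dict)
    return state_dict
```

With a quantizer present, the reversal is skipped and the quantizer's state_dict is returned unchanged; without one, the original `save_original_format` path still applies.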

by Claude Opus 4.5

When a quantizer provides state_dict via get_state_dict_and_metadata(),
the state_dict is already in the correct serialization format. However,
revert_weight_conversion() was still being called, which failed for
quantizers like mxfp4 whose ConversionOps don't implement reverse_op.

This fix skips revert_weight_conversion() when the quantizer has already
provided the state_dict, since quantizers handle their own serialization
logic in get_state_dict_and_metadata().

Fixes NotImplementedError when calling save_pretrained() on mxfp4
quantized models.
@ArthurZucker
Collaborator

cc @SunMarc I think we plan to bring back serialization

@SunMarc
Member

SunMarc commented Jan 5, 2026

hmmm indeed, we need to fix this cc @MekkCyber

Contributor

@MekkCyber left a comment


Hey @480284856! Thanks for the contribution, but I'm not really sure we want to handle it this way. I think it makes more sense to have a reverse op for quantization too, instead of handling it differently via a quantizer-provided state dict.
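A hypothetical sketch of this alternative: give the quantization ConversionOps a `reverse_op` instead of special-casing the quantizer path. The class names mirror the traceback but are simplified stand-ins, not the actual transformers classes:

```python
# Hypothetical sketch of the reviewer's suggestion: quantization ops implement
# reverse_op so revert_weight_conversion() can build a reverse pipeline for
# them like any other ConversionOps. Names are illustrative stand-ins.

class ConversionOps:
    @property
    def reverse_op(self) -> "ConversionOps":
        # Base behavior matching the traceback: no reverse defined.
        raise NotImplementedError

class Mxfp4Quantize(ConversionOps):
    """Stand-in for an mxfp4 quantization op."""
    @property
    def reverse_op(self) -> "ConversionOps":
        # Returning the matching dequantize op lets the reversal machinery
        # reconstruct the original serialization format.
        return Mxfp4Dequantize()

class Mxfp4Dequantize(ConversionOps):
    """Stand-in for the matching dequantization op."""
    @property
    def reverse_op(self) -> "ConversionOps":
        return Mxfp4Quantize()
```

With such pairs in place, the listcomp in `reverse_transform()` shown in the traceback would no longer hit the base-class `NotImplementedError`.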
