Conversation

@MekkCyber (Contributor) commented Nov 19, 2025

What does this PR do?

Refactors the torchao quantization method to use conversion ops instead of the classic create_quantized_param path.
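
For context, the conversion-op approach declares per-parameter mappings on the quantizer instead of imperatively creating each quantized param. A minimal sketch of the shape this takes, pieced together from the hunks quoted in the review below (the import path and the placeholder op definition are assumptions, not the verbatim implementation):

# Sketch of the conversion-op style this PR moves to; WeightConverter and
# TorchAoDeserialize appear in the diff hunks quoted in the review below.
from transformers.core_model_loading import WeightConverter  # assumed import path

class TorchAoDeserialize:  # placeholder: the real op is defined in this PR
    def __init__(self, quantizer):
        self.quantizer = quantizer

class TorchAoHfQuantizer:
    pre_quantized: bool = True

    def get_weight_conversions(self):
        # For pre-quantized checkpoints, gather the serialized torchao
        # tensor-subclass components and deserialize them into `weight`.
        if self.pre_quantized:
            return [
                WeightConverter(
                    source_keys=["weight:qdata", "weight:scale", "weight:zero_point"],
                    target_keys="weight",
                    operations=[TorchAoDeserialize(self)],
                )
            ]
        return []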

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member) left a comment

A few nits but overall fine!

Comment on lines 119 to 125
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
**kwargs,
Member

Most of the args should be optional kwargs, so that we can clean up the other convert functions that take **kwargs and only keep the args that are actually used. That's fine though, we should do that later on.
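
For illustration, the suggestion amounts to something like this (a hypothetical sketch, not code from this PR; `self` implies these live on a converter class, bodies elided):

# Current shape (from the hunk above): every argument is explicit.
def convert(self, source_keys: list[str], target_keys: list[str],
            full_layer_name: str, model, missing_keys, config, **kwargs):
    ...

# Suggested shape: only the args a given converter actually uses stay
# explicit; the rest ride along in **kwargs.
def convert(self, source_keys: list[str], target_keys: list[str], **kwargs):
    full_layer_name = kwargs.get("full_layer_name")
    ...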

Collaborator

agreed

Comment on lines +241 to +242
if self.pre_quantized:
return False
@SunMarc (Member) commented Nov 21, 2025

This is typically one of the cases where we would get the wrong numel calculation, since we are skipping these params. We should try to fix that at some point, as it should be quite simple.

Comment on lines +552 to +555
WeightConverter(
source_keys=["weight:_data"],
target_keys="weight",
operations=[TorchAoDeserialize(self)],
Member

A WeightRename should be enough in this case, no?

Contributor (Author)

Yes, both are fine I guess.

@jerryzh168 (Contributor) commented Nov 22, 2025

FYI, we just changed weight:_data to weight_qdata so these things can be attached to the module directly, in case we need that in the future: pytorch/ao@ba3ac9f

Collaborator

WeightConverter is better than WeightRename here because there is an op!
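
i.e., the two alternatives discussed here, side by side (a sketch reusing the names from the hunk above; the WeightRename signature is assumed, not taken from the PR):

# Rename only: remaps the checkpoint key, runs no operation.
WeightRename(source_keys=["weight:_data"], target_keys="weight")

# Converter: same remap, but also runs the deserialization op on load.
WeightConverter(
    source_keys=["weight:_data"],
    target_keys="weight",
    operations=[TorchAoDeserialize(self)],
)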

full_layer_name: str | None = None,
missing_keys=None,
**kwargs,
) -> dict[str, torch.Tensor]:
Member

Safe serialization doesn't work yet because of torchao, so it is fine to just clean this up a bit; we can come back to it later on.

# print("metadata", self.hf_quantizer.metadata)
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
@liangel-02 (Contributor) commented Nov 21, 2025

In a follow-up PR, we can modify this to work with all tensor subclasses and with sharded checkpoint files.

I'm thinking that in this convert function, we load the tensor subclass components (i.e. _weight_qdata) as module parameters. After all files are loaded, we can delete them and replace the actual layer weights with the reconstructed quantized tensors, as in the sketch below.

See #41998 for details. Will this approach still work with the new refactoring? cc @jerryzh168
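
Roughly, the two-phase idea described above (hypothetical helper names, not code from #41998):

import torch

def finalize_quantized_weights(model, reconstruct):
    # `reconstruct` is a hypothetical callable (qdata, scale) -> quantized
    # tensor subclass; phase 1 is assumed to have attached the raw
    # components (e.g. `_weight_qdata`, `_weight_scale`) to each module.
    for _, module in model.named_modules():
        if not hasattr(module, "_weight_qdata"):
            continue
        # Phase 2: rebuild the torchao tensor subclass from the stored
        # components...
        qweight = reconstruct(module._weight_qdata, getattr(module, "_weight_scale", None))
        module.weight = torch.nn.Parameter(qweight, requires_grad=False)
        # ...then drop the temporary component attributes.
        del module._weight_qdata
        if hasattr(module, "_weight_scale"):
            del module._weight_scale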

@jerryzh168 (Contributor) commented Nov 22, 2025

@liangel-02 yeah, I think our original approach should still work. I guess it's fine to land this PR first, and you can re-open #41998 on top of these new changes since you are more familiar with this part.

Collaborator

Thanks both for chiming in! 🤗

if self.pre_quantized:
return [
WeightConverter(
source_keys=["weight:qdata", "weight:scale", "weight:zero_point"],
Contributor

nit: maybe also add [weight_qdata, weight_scale], since zero_point may be optional, as in https://github.com/pytorch/ao/blob/2ff1eb2e356275cfbe46832327387d382c72945d/torchao/quantization/quantize_/workflows/float8/float8_tensor.py#L99
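
i.e., roughly (a sketch, assuming the loader can try multiple source-key patterns; whether the keys are spelled weight:qdata or weight_qdata depends on the torchao version, per the rename noted above):

if self.pre_quantized:
    return [
        # Full pattern: qdata + scale + zero_point.
        WeightConverter(
            source_keys=["weight:qdata", "weight:scale", "weight:zero_point"],
            target_keys="weight",
            operations=[TorchAoDeserialize(self)],
        ),
        # Fallback for subclasses where zero_point is optional, e.g.
        # torchao's Float8Tensor (see the link above).
        WeightConverter(
            source_keys=["weight:qdata", "weight:scale"],
            target_keys="weight",
            operations=[TorchAoDeserialize(self)],
        ),
    ]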

Contributor (Author)

Let's do that in a follow-up PR, since safetensors support is broken with the latest torchao version.

@ArthurZucker (Collaborator) left a comment

Great work, thanks! 🤗

# print("metadata", self.hf_quantizer.metadata)
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks both for chiming in! 🤗

Comment on lines 119 to 125
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed

Comment on lines +3833 to +3834
if hf_quantizer is not None:
weight_conversions.extend(hf_quantizer.get_weight_conversions())
Collaborator

nice!

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: finegrained_fp8, torchao_integration

@MekkCyber merged commit 5169c23 into main on Nov 25, 2025 (24 checks passed).
@MekkCyber deleted the fix-torchao-v2 branch on November 25, 2025 at 09:22.
@rogeryoungh mentioned this pull request on Nov 26, 2025.