[core] Fix torchao #42289
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 57f7874 to 9607be8.
SunMarc left a comment:
A few nits but overall fine!
```python
source_keys: list[str],
target_keys: list[str],
full_layer_name: str,
model,
missing_keys,
config,
**kwargs,
```
Most of these args should be optional kwargs, so that we can clean up the other convert functions with `**kwargs` and only keep the args that are actually used. But that's fine, we should do that later on.
agreed
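A minimal sketch of what that cleanup might look like, assuming the converter API stays as in the diff above; everything beyond the names shown in the diff is hypothetical:

```python
# Hypothetical sketch of the suggested cleanup: only the arguments a given
# convert function actually reads stay named; the rest travels in **kwargs.
def convert(
    self,
    source_keys: list[str],
    target_keys: list[str],
    *,
    full_layer_name: str | None = None,
    missing_keys: set[str] | None = None,
    **kwargs,  # model, config, ... are forwarded but not read here
):
    ...
```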
```python
if self.pre_quantized:
    return False
```
This is typically one of the cases where we would get a wrong numel calculation, since we are skipping them. We should try to fix that at some point; it should be quite simple.
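A small illustration of the mismatch, assuming the skipped parameter is stored in a packed format (an int4 layout here, purely for illustration):

```python
# Illustration only: an int4-packed weight stores two values per uint8 byte,
# so counting elements of the stored tensor (or skipping it entirely, as in
# the branch above) under-reports the logical parameter count.
import torch

logical = torch.empty(1024, 1024, dtype=torch.bfloat16)  # 1_048_576 logical params
packed = torch.empty(1024, 512, dtype=torch.uint8)       # int4-packed storage

print(logical.numel())  # 1048576 -- the count we would want to report
print(packed.numel())   # 524288  -- what a naive count over stored tensors sees
```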
```python
WeightConverter(
    source_keys=["weight:_data"],
    target_keys="weight",
    operations=[TorchAoDeserialize(self)],
)
```
A WeightRename should be enough in this case, no?
Yes, both are fine I guess.
FYI, we just changed `weight:_data` to `weight_qdata` so these things can be attached to the module directly, in case we need it in the future: pytorch/ao@ba3ac9f
A WeightConverter is better than a WeightRename here because there is an op!
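For context, roughly the trade-off being weighed here, sketched with the classes from the diff above (the WeightRename usage is an assumption, mirrored from WeightConverter):

```python
# A pure rename would only remap the checkpoint key (assumed usage):
# WeightRename(source_keys=["weight:_data"], target_keys="weight")
#
# A WeightConverter can additionally run operations on the tensor, which
# matters here because TorchAoDeserialize has to rebuild the torchao tensor
# subclass rather than just move bytes under a new name:
WeightConverter(
    source_keys=["weight:_data"],
    target_keys="weight",
    operations=[TorchAoDeserialize(self)],
)
```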
```python
full_layer_name: str | None = None,
missing_keys=None,
**kwargs,
) -> dict[str, torch.Tensor]:
```
The safe serialization doesn't work yet because of torchao, so it's fine to just clean it up a bit; we can come back to that later on.
```python
# print("metadata", self.hf_quantizer.metadata)
raise ValueError("To use `safetensors` serialization, you should have `torchao>=0.14.0` installed")

new_param = unflatten_tensor_state_dict(param_data, self.hf_quantizer.metadata)[full_layer_name]
```
In a follow-up PR, we can modify this to work with all tensor subclasses and with sharded checkpoint files.
I'm thinking that in this convert function, we load the tensor subclass components (i.e. `_weight_qdata`) as module parameters. After all files are loaded, we can delete them and replace the actual layer weights with the reconstructed quantized tensors.
See #41998 for details. Will this approach still work with the new refactoring? cc @jerryzh168
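A rough sketch of that two-phase idea. Only `unflatten_tensor_state_dict` comes from the diff above; its import path, the helper names, the `_weight_` key convention, and the simplified key handling are all assumptions:

```python
import torch
# Assumed import path; the helper itself is the one used in the diff above.
from torchao.utils import unflatten_tensor_state_dict

# Phase 1: while streaming (possibly sharded) files, stash the flat tensor
# subclass components (e.g. "_weight_qdata", "_weight_scale") on the module.
def stash_component(module: torch.nn.Module, flat_key: str, tensor: torch.Tensor):
    module.register_buffer(flat_key, tensor, persistent=False)

# Phase 2: after every shard has been consumed, rebuild the quantized tensor
# from the collected components and drop the temporaries.
def finalize(module: torch.nn.Module, metadata: dict):
    flat = {name: buf for name, buf in module.named_buffers(recurse=False)
            if name.startswith("_weight_")}
    weight = unflatten_tensor_state_dict(flat, metadata)["weight"]
    for name in flat:
        delattr(module, name)
    module.weight = torch.nn.Parameter(weight, requires_grad=False)
```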
@liangel-02 yeah, I think our original approach should still work. I guess it's fine to land this PR first and you can re-open #41998 on top of these new changes, since you are more familiar with this part.
Thanks both for chiming in! 🤗
```python
if self.pre_quantized:
    return [
        WeightConverter(
            source_keys=["weight:qdata", "weight:scale", "weight:zero_point"],
```
Nit: maybe also add `[weight_qdata, weight_scale]` as well, since `zero_point` may be optional, like https://github.com/pytorch/ao/blob/2ff1eb2e356275cfbe46832327387d382c72945d/torchao/quantization/quantize_/workflows/float8/float8_tensor.py#L99
Let's do that in a follow-up PR, since the safetensors support is broken with the latest torchao version.
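For that follow-up, a sketch of what the nit could look like; whether two converters may share a target key is an assumption about the new API:

```python
# Sketch: one converter for layouts that serialize a zero_point, and one for
# layouts that do not (e.g. float8), both rebuilding the same target weight.
converters = [
    WeightConverter(
        source_keys=["weight:qdata", "weight:scale", "weight:zero_point"],
        target_keys="weight",
        operations=[TorchAoDeserialize(self)],
    ),
    WeightConverter(
        source_keys=["weight:qdata", "weight:scale"],  # zero_point absent
        target_keys="weight",
        operations=[TorchAoDeserialize(self)],
    ),
]
```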
ArthurZucker left a comment:
Great work! Thanks 🤗
```python
if hf_quantizer is not None:
    weight_conversions.extend(hf_quantizer.get_weight_conversions())
```
nice!
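The hook being praised composes quantizer-specific converters into the model-level conversion list; a sketch with hypothetical surrounding names:

```python
# Sketch of the composition point (names outside the diff are hypothetical):
weight_conversions = list(model_specific_conversions)
if hf_quantizer is not None:
    # Each quantizer contributes its own WeightConverter entries.
    weight_conversions.extend(hf_quantizer.get_weight_conversions())
```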
[For maintainers] Suggested jobs to run (before merge): run-slow: finegrained_fp8, torchao_integration
What does this PR do?
Refactors the torchao quantization method to use conversion ops instead of the classical `create_quantized_param`.
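In rough terms, the shape of the change; the class and method names come from the PR, but both bodies here are illustrative only:

```python
# Before: the quantizer materialized each quantized parameter imperatively.
class TorchAoHfQuantizer:
    def create_quantized_param(self, model, param_value, param_name, **kwargs):
        ...  # build the quantized tensor and set it on the module directly

# After: the quantizer declares conversions; the generic loader applies them.
class TorchAoHfQuantizer:
    def get_weight_conversions(self):
        return [
            WeightConverter(
                source_keys=["weight:qdata", "weight:scale", "weight:zero_point"],
                target_keys="weight",
                operations=[TorchAoDeserialize(self)],
            )
        ]
```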