Hqq serialization #33141

Merged: 29 commits into huggingface:main on Sep 30, 2024

Conversation

@mobicham (Contributor) commented Aug 27, 2024

Follow-up to #32379
The goal of this PR is to add full support for saving/loading HQQ-quantized models directly in transformers. Until now, serialization was handled on the hqq-lib side via the .pt format, which is not safe and does not work with very large models (>100B params) since the model is not sharded.

What was done during this PR:

  • Make sure saving/loading HQQ-quantized models works properly.
  • Make sure multi-gpu support works with the hqq backends (required some updates on the hqq-lib side).
  • Make sure adding biases works in architectures that do not have a bias by default (biases are used in some HQQ-calibrated models).
  • Added an update_expected_keys() call in the quantizer. This allows loading quantized models whose layers were initialized with torch.nn.Linear instead of HQQLinear.

Full gist to try it out: https://gist.github.com/mobicham/701dd564c52590203ee09631425ad797
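For quick reference, here is a minimal sketch of the save/reload flow this PR enables. The model id and output path are placeholders, and the HqqConfig arguments follow the defaults discussed later in this thread (nbits=4, group_size=64, axis=1); see the gist above for the full script.

# Minimal sketch of the save/reload flow; model id and output path are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id  = "meta-llama/Meta-Llama-3-8B-Instruct"
save_path = "llama3-8b-hqq-4bit"  # placeholder output directory

quant_config = HqqConfig(nbits=4, group_size=64, axis=1)

# Quantize on the fly while loading the fp16 weights
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda:0",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Save the quantized model as sharded safetensors
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

# Reload the already-quantized model directly from the saved folder
model = AutoModelForCausalLM.from_pretrained(
    save_path, torch_dtype=torch.float16, device_map="cuda:0"
)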

@mobicham (Contributor, Author) commented Aug 27, 2024

1/3: 5cb7d81
Removed the check_old_param hack.
The problem, however, is that HQQLinear.state_dict is huge, which makes loading extremely slow. So I added run_expected_keys_check, which skips those checks for HQQLinear params. I am not sure if it's a clean way: if you just init a dummy HQQLinear you wouldn't get all the state_dict params anyway 🤔, so if you disable that check it will complain that the parameter is not in the expected keys. Let me know if there's a better way of doing this.
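For illustration only, here is a rough sketch of the idea behind run_expected_keys_check; the helper name and structure below are hypothetical and do not reflect the actual transformers implementation.

# Hypothetical sketch only: skip the expected-keys check for parameters that
# belong to an HQQLinear module, since a freshly initialized HQQLinear does not
# expose its full quantized state_dict yet.
from hqq.core.quantize import HQQLinear

def should_check_expected_keys(model, param_name: str) -> bool:
    # Derive the owning module name from the parameter name, e.g.
    # "model.layers.0.self_attn.q_proj.W_q" -> "model.layers.0.self_attn.q_proj"
    module_name = param_name.rsplit(".", 1)[0]
    module = dict(model.named_modules()).get(module_name, None)
    # Only run the expected-keys check for non-HQQ modules
    return not isinstance(module, HQQLinear)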

@LysandreJik requested a review from SunMarc on August 27, 2024.

@SunMarc (Member) left a review:

Nice! Let's fix the issue regarding the torchao backend and we can merge this. I left a few comments.

Review comments on:
  • src/transformers/quantizers/quantizer_hqq.py
  • src/transformers/modeling_utils.py
@mobicham (Contributor, Author) commented:
2/3: Multi-gpu loading
Loading on multi-gpu looks like it's working fine. There's an issue with the BitBlas backend that I just reported here.
Forcing the input to use the same device was done on the hqq-lib side.
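For reference, a minimal multi-gpu loading sketch, assuming device_map="auto" is used to dispatch the quantized layers across the visible GPUs (model id is a placeholder):

# Sketch: multi-gpu loading of an HQQ-quantized model via device_map="auto".
import torch
from transformers import AutoModelForCausalLM, HqqConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B",  # placeholder model id
    torch_dtype=torch.float16,
    device_map="auto",              # dispatch layers across all visible GPUs
    quantization_config=HqqConfig(nbits=4, group_size=64, axis=1),
)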

3/3: state_dict on the same safetensors chunk
I ran tests with different models and it's working fine (gist):

model_id  = 'meta-llama/Meta-Llama-3-8B-Instruct' #OK
model_id  = 'meta-llama/Meta-Llama-3-70B' #OK 
model_id = "facebook/opt-125m" #OK
model_id = "meta-llama/Llama-2-13b-chat-hf" #OK
model_id = "microsoft/Phi-3-mini-128k-instruct" #OK
model_id = "google/gemma-2-9b-it" #OK
model_id = "google/gemma-2-2b" #OK

So I think for the moment we can leave it as is until someone reports an issue; I can't reproduce the problem anyway.

Next steps:

  • Revisit the comments above (@mobicham)
  • Change/disable settings for HqqConfig, because saving/loading no longer supports quant scales/zeros or meta-data offloading. These also need to be deprecated on the hqq-lib side with a new pip release 0.2.0 (@mobicham)

@mobicham (Contributor, Author) commented:
@SunMarc

  • Reverted back to if isinstance(module, (torch.nn.Linear, HQQLinear)):, but we still need that run_expected_keys_check, otherwise it breaks.
  • Updated the default HqqConfig params, since quant_scale, quant_zero, and offload_meta are now deprecated (also done on the hqq-lib side). I also updated the tests and the docs, and made a new hqq pip release, 0.2.0. A sketch of the updated defaults is shown below.
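For reference, a config using the updated defaults would look like this (a sketch; the deprecated quant_scale, quant_zero, and offload_meta kwargs are simply dropped):

# Sketch: HqqConfig after the 0.2.0 update; deprecated kwargs are no longer passed.
from transformers import HqqConfig

quant_config = HqqConfig(nbits=4, group_size=64, axis=1)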

@SunMarc (Member) left a review:

Added a couple of comments!

Review comments on:
  • src/transformers/modeling_utils.py
  • src/transformers/quantizers/quantizer_hqq.py
  • tests/quantization/hqq/test_hqq.py
  • src/transformers/utils/quantization_config.py (two comments)
@mobicham (Contributor, Author) commented:
Regarding this: #33141 (comment)
The issue is that, to remove that additional check, we need to have all the HQQLinear dict keys for each layer in the list of expected keys. There are 19 keys per HQQLinear module, so for a small model like Llama3-8B that means 32*7*19 = 4256 extra keys to scan per parameter, which is extremely slow.
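To make the arithmetic explicit, a worked sketch of the numbers quoted above (the per-layer module count is an assumption for a Llama-style decoder):

# Back-of-the-envelope count of the extra expected keys for Llama3-8B.
num_layers = 32           # decoder layers
linears_per_layer = 7     # q/k/v/o projections + gate/up/down MLP projections (assumed)
keys_per_hqq_linear = 19  # entries in HQQLinear.state_dict()

extra_keys = num_layers * linears_per_layer * keys_per_hqq_linear
print(extra_keys)  # 4256; scanning a list this long for every loaded parameter is slow
# Assumption: a set-based lookup, e.g. `param_name in set(expected_keys)`,
# would make each membership check O(1) instead of a linear scan.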

@SunMarc (Member) left a review:

Left a suggestion about axis.

Review comments on:
  • src/transformers/utils/quantization_config.py (two comments)
@blap commented Sep 26, 2024

Just out of curiosity, what is still missing before this can be merged?

@SunMarc (Member) commented Sep 26, 2024

Just out of curiosity, what is still missing before this can be merged?

Waiting for @mobicham to check the latest review and give me the heads-up to merge! This should be done soon! Also it looks like there are some conflicts to fix.

@SunMarc (Member) commented Sep 30, 2024

Thanks for iterating @mobicham! Merging!

@SunMarc SunMarc merged commit f5247ac into huggingface:main Sep 30, 2024
24 checks passed
Cyrilvallez pushed a commit that referenced this pull request Sep 30, 2024
* HQQ model serialization attempt

* fix hqq dispatch and unexpected keys

* style

* remove check_old_param

* revert to check HQQLinear in quantizer_hqq.py

* revert to check HQQLinear in quantizer_hqq.py

* update HqqConfig default params

* make ci happy

* make ci happy

* revert to HQQLinear check in quantizer_hqq.py

* check hqq_min version 0.2.0

* set axis=1 as default in quantization_config.py

* validate_env with hqq>=0.2.0 version message

* deprecated hqq kwargs message

* make ci happy

* remove run_expected_keys_check hack + bump to 0.2.1 min hqq version

* fix unexpected_keys hqq update

* add pre_quantized check

* add update_expected_keys to base quantizerr

* ci base.py fix?

* ci base.py fix?

* fix "quantization typo" src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fix post merge

---------

Co-authored-by: Marc Sun <marc@huggingface.co>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@rohit-gupta commented:
@mobicham minor documentation issue: the quantization overview page in the transformers docs has a giant feature matrix that still says serialization of HQQ models is not supported.

https://huggingface.co/docs/transformers/main/quantization/overview

@SunMarc (Member) commented Nov 27, 2024

Would you like to open a PR to fix this, @rohit-gupta?

@mobicham (Contributor, Author) commented Dec 2, 2024

@rohit-gupta thanks for flagging!

@blap commented Dec 2, 2024

Now model.save_pretrained(save_path) gives this:


Traceback (most recent call last):
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\hqq1b.py", line 35, in <module>
    model.save_pretrained(save_path)
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\transformers\modeling_utils.py", line 2932, in save_pretrained
    state_dict_split = split_torch_state_dict_into_shards(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\huggingface_hub\serialization\_torch.py", line 330, in split_torch_state_dict_into_shards
    return split_state_dict_into_shards_factory(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\huggingface_hub\serialization\_base.py", line 108, in split_state_dict_into_shards_factory
    storage_id = get_storage_id(tensor)
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\huggingface_hub\serialization\_torch.py", line 382, in get_torch_storage_id
    if tensor.device.type == "meta":
       ^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'device'


@mobicham (Contributor, Author) commented Dec 3, 2024

@blap is this related to the latest transformers changes? Otherwise, which hqq version causes this?

@blap commented Dec 3, 2024

@blap is this related to the latest transformers changes? Otherwise, which hqq version causes this?

I think so. I didn't have this problem when hqq support was first released in transformers.
hqq version: 0.2.3
transformers version: 4.47.0.dev0

@mobicham (Contributor, Author) commented Dec 3, 2024

@blap is this related to the latest transformers changes? Otherwise, which hqq version causes this?

I think so. I didn't have this problem when hqq support was first released in transformers. hqq version: 0.2.3, transformers version: 4.47.0.dev0

@SunMarc do you know what was changed, by any chance?

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
@blap commented Dec 9, 2024

Transformers version 4.48.0.dev0 still has this problem...

@mobicham (Contributor, Author) commented Dec 9, 2024

Could anyone from the HF team track down this problem, please? What changed? Nothing much changed on the hqq-lib side.

@blap commented Dec 17, 2024

@SunMarc ?

@SunMarc (Member) commented Dec 24, 2024

Can you share your script @blap? I'll have a look asap!

@blap commented Dec 24, 2024

Can you share your script @blap? I'll have a look asap!


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id      = "mllmTeam/PhoneLM-1.5B"
repo          = "PhoneLM-1.5B"
nbits         = 4
group_size    = None
axis          = 0
save_path     = repo+"-nbits"+str(nbits)+"-GS"+str(group_size)+"-Axis"+str(axis)+"-HQQ2"
cache_dir     = repo+"-cache"
device        = "cpu"
compute_dtype = torch.float16

#Quantize
quant_config  = HqqConfig(nbits=nbits, group_size=group_size, axis=axis, quant_scale=False, quant_zero=False)

#Load the model
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    torch_dtype=compute_dtype, 
    cache_dir=cache_dir,
    device_map=device, 
    quantization_config=quant_config,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)

#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

# Save
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)

Error:


Traceback (most recent call last):
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\hqq1b.py", line 32, in <module>
    model.save_pretrained(save_path)
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\transformers\modeling_utils.py", line 2971, in save_pretrained
    state_dict_split = split_torch_state_dict_into_shards(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\huggingface_hub\serialization\_torch.py", line 369, in split_torch_state_dict_into_shards
    return split_state_dict_into_shards_factory(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\huggingface_hub\serialization\_base.py", line 108, in split_state_dict_into_shards_factory
    storage_id = get_storage_id(tensor)
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\huggingface_hub\serialization\_torch.py", line 746, in get_torch_storage_id
    if tensor.device.type == "meta":
       ^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'device'


@blap commented Dec 30, 2024

So... any ideas on how to save?

@mobicham (Contributor, Author) commented:
@blap why don't you use the latest release? It worked fine the last time I tried (last week).

@blap commented Dec 30, 2024

@blap why don't you use the latest release? It worked fine the last time I tried (last week).

Which version do you use?

Version 4.45.2 gives me this:


Traceback (most recent call last):
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\hqq1b.py", line 37, in <module>
    model.save_pretrained(save_path)
  File "C:\Users\Admin\Desktop\Python\0.LLMs\hqq\venv\Lib\site-packages\transformers\modeling_utils.py", line 2565, in save_pretrained
    raise ValueError(
ValueError: The model is quantized with QuantizationMethod.HQQ and is not serializable - check out the warnings from the logger on the traceback to understand the reason why the quantized model is not serializable.


@mobicham (Contributor, Author) commented:
@blap 4.47.0 works for sure

@blap commented Dec 30, 2024

@blap 4.47.0 works for sure

I just got the same error in this version too.
I tried other models without success.
It only runs properly with hqq directly, not through transformers.
Can you show some code, please?

@mobicham (Contributor, Author) commented:
@blap

# pip install transformers==4.47.0;
# pip install hqq --upgrade;
##################################################################
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
quant_model = "quant_model"

quant_config = HqqConfig(nbits=4, group_size=64, axis=1)

model = AutoModelForCausalLM.from_pretrained(model_path,
                                            torch_dtype=torch.float16,
                                            cache_dir='.',
                                            device_map="cuda:0",
                                            quantization_config=quant_config,
                                            low_cpu_mem_usage=True)

tokenizer = AutoTokenizer.from_pretrained(model_path)

model.save_pretrained(quant_model)
tokenizer.save_pretrained(quant_model)

@blap commented Dec 30, 2024

I found the problem:
if I use group_size=None I get the error.

@mobicham (Contributor, Author) commented Dec 31, 2024

I found the problem: if I use group_size=None I get the error.

Hmm interesting, thanks for flagging! Fixed here.

I would recommend using 64 or 128 though; some of the fast kernels, like Marlin in vLLM and TinyGemm in torchao, don't support group_size=None anyway.
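For reference, the minimal change to the earlier failing script would be along these lines (a sketch; everything else in that script stays the same, and the nbits and axis variables come from it):

# Sketch: use a supported group size (e.g. 64 or 128) instead of group_size=None,
# and drop the deprecated quant_scale / quant_zero kwargs.
quant_config = HqqConfig(nbits=nbits, group_size=64, axis=axis)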
