Add HQQ quantization support #29637

Merged · 78 commits · May 2, 2024

Conversation

@mobicham (Contributor) commented Mar 13, 2024

This PR adds support for Half-Quadratic Quantization (HQQ) to the transformers library, as requested in #28328.

HQQ has been gaining popularity lately since quantization is fast and it produces good-quality models without using any calibration data. More details here: https://github.com/mobiusml/hqq/

Since quantization requires a CUDA device, the quantization step happens in create_quantized_param().

The tricky part is serialization: the current logic is unfortunately not compatible with HQQLinear's state_dict structure. For now, I am reusing the logic from the hqq package, which stores the state_dicts of the modules and uses torch.save to save the weights. This is the best solution I have found so far; we can revisit it once we figure out a better approach.
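
Roughly, the saving side of that logic looks like the sketch below (illustrative only, with hypothetical helper names; the actual code mirrors the hqq package):

import torch

def save_hqq_weights(model, save_path):
    # Collect the state_dict of every quantized HQQLinear module and persist the mapping with torch.save
    weights = {
        name: module.state_dict()
        for name, module in model.named_modules()
        if type(module).__name__ == "HQQLinear"
    }
    torch.save(weights, save_path)

def load_hqq_weights(load_path):
    # Weights are later re-assigned to the corresponding modules via load_state_dict
    return torch.load(load_path, map_location="cpu")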

I added a progress bar (inside the model-loading progress bar :D) to keep track of the quantization step. I noticed that the quantization step is slower here than when doing it with the hqq package.

I wrote some basic tests to check if the quantization is done properly on a Mistral model with different settings.

Full example here: hqq_transformers_llama_example.py
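
For a quick look, here is a minimal usage sketch along the lines of that example (assumptions: a CUDA device, the hqq package installed, a Mistral checkpoint like the one used in the tests, and the HqqConfig name used in the merged version of this PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id = "mistralai/Mistral-7B-v0.1"

# 4-bit weight-only quantization, no calibration data needed
quant_config = HqqConfig(nbits=4, group_size=64)

# Quantization happens on the fly in create_quantized_param() while the checkpoint is loaded
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello my name is", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))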

Let me know if you have any questions or requests!

Thank you in advance!

@amyeroberts (Collaborator) commented Mar 13, 2024

Hi @mobicham, thanks for opening this PR!

Is there an associated issue / feature request for this? If so, could you add it to the PR description?

cc @younesbelkada @SunMarc

@SunMarc (Member) left a comment

Hi @mobicham, thanks for working on this! This looks very good already. I left a few comments. Let me know if you have any questions!

Review threads (outdated/resolved) on: docker/transformers-quantization-latest-gpu/Dockerfile, docs/source/en/quantization.md (×4), src/transformers/quantizers/quantizer_hqq.py
Comment on lines 104 to 112
if type(module) is not torch.nn.Linear:
    return
Can't we do this in check_quantized_param? If not, could you add a comment about the reason.

Done in check_quantized_param

Suggested change (remove these lines):

    if type(module) is not torch.nn.Linear:
        return

Is there a reason to keep this?

Review threads on: src/transformers/quantizers/quantizer_hqq.py (×2), src/transformers/utils/quantization_config.py
@younesbelkada (Contributor) left a comment

Thanks @mobicham for this huge work! I had a look at the PR with @SunMarc and left some comments. The serialization / loading logic for HQQ weights seems quite involved so far; maybe we could first go with a v1 with just on-the-fly quantization, and then I will refactor the HF quantizers to be able to incorporate some of the new logic there. WDYT?

Review threads (outdated/resolved) on: docs/source/en/quantization.md (×5), src/transformers/modeling_utils.py (×2), src/transformers/utils/quantization_config.py (×2), src/transformers/quantizers/quantizer_hqq.py
@rationalism

@mobicham Tried running this branch on my machine (merged into the 4.39.3 Transformers release branch), got this error:

(lm_fun) alyssa@alyssa-desktop:~/lm_fun/transformers$ CUDA_VISIBLE_DEVICES=0,1 python ../mixtral_quant.py 
  0%|                                                                                                                                                                                | 0/51 [00:00<?, ?it/s]
Loading checkpoint shards:   0%|                                                                                                                                                     | 0/19 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/alyssa/lm_fun/transformers/../mixtral_quant.py", line 23, in <module>
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda", quantization_config=quant_config)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/auto/auto_factory.py", line 563, in from_pretrained
    return model_class.from_pretrained(
  File "/home/alyssa/lm_fun/transformers/src/transformers/modeling_utils.py", line 3619, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/alyssa/lm_fun/transformers/src/transformers/modeling_utils.py", line 4046, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/alyssa/lm_fun/transformers/src/transformers/modeling_utils.py", line 814, in _load_state_dict_into_meta_model
    not hf_quantizer.check_quantized_param(
TypeError: HQQHfQuantizer.check_quantized_param() got an unexpected keyword argument 'param_device'

Code I was running:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HQQConfig
from hqq.core.quantize import *

# model_id = "mistral-community/Mixtral-8x22B-v0.1"
model_id = "mistralai/Mixtral-8x7B-v0.1"

attn_params = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True)
attn_params['scale_quant_params']['group_size'] = 256
experts_params = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=True)

# Each type of linear layer (referred to as a linear tag) will use different quantization parameters
quant_config = HQQConfig({
  'self_attn.q_proj': attn_params,
  'self_attn.k_proj': attn_params,
  'self_attn.v_proj': attn_params,
  'self_attn.o_proj': attn_params,

  'block_sparse_moe.experts.w1': experts_params,
  'block_sparse_moe.experts.w2': experts_params,
  'block_sparse_moe.experts.w3': experts_params
})

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda", quantization_config=quant_config)

@rationalism

@mobicham Tried saving and loading a model on the mobiusml:stable branch, got this error:

>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained('/home/alyssa/lm_fun/text-generation-webui/models/mixtral-8x22B-HQQ-2bit-g8')
/home/alyssa/lm_fun/text-generation-webui/models/mixtral-8x22B-HQQ-2bit-g8
None
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/auto/auto_factory.py", line 563, in from_pretrained
    return model_class.from_pretrained(
  File "/home/alyssa/lm_fun/transformers/src/transformers/modeling_utils.py", line 3062, in from_pretrained
    save_dir = BaseHQQModel.try_snapshot_download(pretrained_model_name_or_path, cache_dir)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/hqq/models/base.py", line 281, in try_snapshot_download
    save_dir = pjoin(cache_dir, save_dir_or_hub)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/posixpath.py", line 76, in join
    a = os.fspath(a)
TypeError: expected str, bytes or os.PathLike object, not NoneType

Adding a cache_dir argument to the from_pretrained() call fixed it, but ideally I think you should be able to omit that?
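
For reference, the workaround looks like this (paths are illustrative):

from transformers import AutoModelForCausalLM

# Passing cache_dir explicitly avoids the NoneType error in try_snapshot_download
model = AutoModelForCausalLM.from_pretrained(
    "/path/to/mixtral-8x22B-HQQ-2bit-g8",
    cache_dir="/home/user/.cache/huggingface",
)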

@rationalism commented Apr 12, 2024

@mobicham Tried using this to do an HQQ quantization of the new Mixtral-8x22B model on my desktop. It worked, and the model saved to disk fine. But when I tried to reload it from disk, this part appeared to work:

# Load weights
try:
    loaded_weights = BaseHQQModel.load_weights(save_dir)
except Exception:
    logger.warning("Failed to load the HQQ weights")
    return

while this part caused it to OOM (even though it had successfully loaded before):

# load modules
for name in logging.tqdm(name_to_module.keys(), "Loading"):
    module = name_to_module[name]
    parent = find_parent(model, name)
    node = name.split(".")[-1]
    setattr(parent, node, load_hqq_module(module, loaded_weights, compute_dtype, hqq_device))

here's the stack trace:

>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained('/home/alyssa/lm_fun/text-generation-webui/models/mixtral-8x22B-HQQ-2bit-g8', cache_dir='/home/alyssa/.cache/huggingface')
Loading:  59%|███████████████████████████████████████████████████████████████████████████████████████████▏                                                              | 1329/2243 [00:38<00:26, 34.79it/s]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/auto/auto_factory.py", line 563, in from_pretrained
    return model_class.from_pretrained(
  File "/home/alyssa/lm_fun/transformers/src/transformers/modeling_utils.py", line 3099, in from_pretrained
    setattr(parent, node, load_hqq_module(module, loaded_weights, compute_dtype, hqq_device))
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/utils/hqq_utils.py", line 93, in load_hqq_module
    module.load_state_dict(state_dict)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/hqq/core/quantize.py", line 543, in load_state_dict
    self.cuda(self.device)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/hqq/core/quantize.py", line 419, in cuda
    self.W_q.data, self.meta = Quantizer.cuda(self.W_q.data, self.meta, device)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/hqq/core/quantize.py", line 220, in cuda
    return Quantizer.to_inplace(W_q, meta, device=device)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/hqq/core/quantize.py", line 190, in to_inplace
    .to(device)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 24.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 6.62 MiB is free. Including non-PyTorch memory, this process has 23.57 GiB memory in use. Of the allocated memory 23.11 GiB is allocated by PyTorch, and 69.73 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

The cause seems likely to be a mismatch between the device map in BaseHQQModel and the device_map passed as an argument here? Thanks a lot.

@rationalism

@mobicham Used a monkey-patch to solve the device_map problem. Now the model loads! (although it still uses much more VRAM and CPU RAM when you load it from a pre-quantized folder than when you quantize it from scratch, which is annoying)

However, it now errors out when you try to run generation:

>>> text = "Hello my name is"
>>> inputs = tokenizer(text, return_tensors="pt").to("cuda:0")
>>> model = AutoModelForCausalLM.from_pretrained('/home/alyssa/lm_fun/text-generation-webui/models/mixtral-8x22B-HQQ-2bit-g8', cache_dir='/home/alyssa/.cache/huggingface', device_map=DEVICE_MAP_MIXTRAL)
Loading: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2243/2243 [00:49<00:00, 45.08it/s]
>>> outputs = model.generate(**inputs, max_new_tokens=20)
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/generation/utils.py", line 1522, in generate
    result = self._greedy_search(
  File "/home/alyssa/lm_fun/transformers/src/transformers/generation/utils.py", line 2405, in _greedy_search
    outputs = self(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 1360, in forward
    outputs = self.model(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 1228, in forward
    layer_outputs = decoder_layer(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 931, in forward
    hidden_states = self.input_layernorm(hidden_states)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 181, in forward
    return self.weight * hidden_states.to(input_dtype)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

@rationalism

@mobicham Tried generating from a multi-device HQQ model that had been quantized live (vs. quantized, saved, and then re-loaded), and got an interesting-looking stack trace:

Traceback (most recent call last):
  File "/home/alyssa/lm_fun/mixtral_quant.py", line 47, in <module>
    outputs = model.generate(**inputs, max_new_tokens=20)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/generation/utils.py", line 1522, in generate
    result = self._greedy_search(
  File "/home/alyssa/lm_fun/transformers/src/transformers/generation/utils.py", line 2405, in _greedy_search
    outputs = self(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 1360, in forward
    outputs = self.model(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 1228, in forward
    layer_outputs = decoder_layer(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 934, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/lm_fun/transformers/src/transformers/models/mixtral/modeling_mixtral.py", line 730, in forward
    query_states = self.q_proj(hidden_states)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/hqq/core/quantize.py", line 793, in forward_aten_backprop
    return HQQMatmulNoCacheDeq.apply(x, self.dequantize_aten, self.bias)
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/alyssa/anaconda3/envs/lm_fun/lib/python3.10/site-packages/hqq/core/quantize.py", line 261, in forward
    out = torch.matmul(x, dequantize().t())
RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
Aborted

Here's the code I used:

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HQQConfig
from hqq.core.quantize import *

DEVICE_MAP_MIXTRAL = {'model.embed_tokens': 0, 'model.layers.0': 0, 'model.layers.1': 0, 'model.layers.2': 0, 'model.layers.3': 0, 'model.layers.4': 0, 'model.layers.5': 0,
                      'model.layers.6': 0, 'model.layers.7': 0, 'model.layers.8': 0, 'model.layers.9': 0, 'model.layers.10': 0, 'model.layers.11': 0, 'model.layers.12': 0,
                      'model.layers.13': 0, 'model.layers.14': 0, 'model.layers.15': 0, 'model.layers.16': 0, 'model.layers.17': 0, 'model.layers.18': 0, 'model.layers.19': 0,
                      'model.layers.20': 0, 'model.layers.21': 0, 'model.layers.22': 0, 'model.layers.23': 0, 'model.layers.24': 0, 'model.layers.25': 0, 'model.layers.26': 0,
                      'model.layers.27': 0, 'model.layers.28': 1, 'model.layers.29': 1, 'model.layers.30': 1, 'model.layers.31': 1, 'model.layers.32': 1, 'model.layers.33': 1,
                      'model.layers.34': 1, 'model.layers.35': 1, 'model.layers.36': 1, 'model.layers.37': 1, 'model.layers.38': 1, 'model.layers.39': 1, 'model.layers.40': 1,
                      'model.layers.41': 1, 'model.layers.42': 1, 'model.layers.43': 1, 'model.layers.44': 1, 'model.layers.45': 1, 'model.layers.46': 1, 'model.layers.47': 1,
                      'model.layers.48': 1, 'model.layers.49': 1, 'model.layers.50': 1, 'model.layers.51': 1, 'model.layers.52': 1, 'model.layers.53': 1, 'model.layers.54': 1,
                      'model.layers.55': 1,
                      'model.norm': 1, 'lm_head': 1}

MODEL_DIR = '/home/alyssa/lm_fun/text-generation-webui/models/'
model_id = MODEL_DIR + 'mistral-community_Mixtral-8x22B-v0.1'
# quant_model_save = MODEL_DIR + 'mixtral-8x22B-HQQ-2bit-g8'
# model_id  = "mistral-community/Mixtral-8x22B-v0.1"
# model_id  = "mistralai/Mixtral-8x7B-v0.1"

attn_params = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True, offload_meta=True)
attn_params['scale_quant_params']['group_size'] = 128
experts_params = BaseQuantizeConfig(nbits=2, group_size=8, quant_zero=True, quant_scale=True, offload_meta=True)

# Each type of linear layer (referred to as a linear tag) will use different quantization parameters
quant_config = HQQConfig({
  'self_attn.q_proj': attn_params,
  'self_attn.k_proj': attn_params,
  'self_attn.v_proj': attn_params,
  'self_attn.o_proj': attn_params,

  'block_sparse_moe.experts.w1': experts_params,
  'block_sparse_moe.experts.w2': experts_params,
  'block_sparse_moe.experts.w3': experts_params
})

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=DEVICE_MAP_MIXTRAL, quantization_config=quant_config)
print("Model loaded successfully")
print(model.hf_device_map)

tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt").to("cuda:0")

print("Begin text generation")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print("Generation successful")

@mobicham (Contributor Author)

@rationalism I haven't tested it since the pull request; things have changed since then, so it is likely to break with a merge.
I will take a look at all of these issues next. There have been a lot of updates to the HQQ library as well.

@rationalism

@mobicham Thank you very much! Very excited

@mobicham mobicham reopened this Apr 24, 2024
@mobicham (Contributor Author)

Updated PR. You can test it with this: https://gist.github.com/mobicham/cb07c1eff443ad0918c49ab7bb03e269
There's an issue with multi-GPU that crashes, and I just can't figure out why.

@younesbelkada (Contributor) left a comment

Thanks again! Just left two tiny comments while doing a small review!
EDIT: False alarm

Review threads (outdated/resolved) on: docs/source/en/quicktour.md, src/transformers/integrations/integration_utils.py
@younesbelkada (Contributor)

cc @amyeroberts this is ready for a final review 🙏 BTW, I don't know why a llama test is failing on main; it seems unrelated to this PR though!

@amyeroberts (Collaborator) left a comment

Thanks for all the work iterating on this - looks great!

r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
pass

As long as the arguments are correct

That's what this method is supposed to check :)
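
For illustration, the kind of validation this method could perform looks something like the sketch below (not the exact merged implementation; an axis check was added to quantization_config later in this PR):

def post_init(self):
    # Sketch of argument validation for the HQQ config (assumed value ranges, based on what hqq supports)
    if self.nbits not in (1, 2, 3, 4, 8):
        raise ValueError(f"Invalid nbits value: {self.nbits}. Supported values are 1, 2, 3, 4 and 8.")
    if self.axis not in (0, 1):
        raise ValueError(f"Invalid axis value: {self.axis}. Only 0 and 1 are supported.")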

Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config

Ah, I see, HQQBaseQuantizeConfig is a dict in the hqq library 👍

@amyeroberts (Collaborator)

@younesbelkada Yes - the llama tests are unrelated and due to an upstream commit. I think we're free to merge!

@danielhanchen (Contributor)

@mobicham Fabulous work!!!

@mobicham (Contributor Author) commented May 3, 2024

Thank you @danielhanchen !

@kadirnar (Contributor) commented May 3, 2024

@mobicham
Can I use it with flash-attention?

My Code:

import logging
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, BitsAndBytesConfig, HqqConfig

hqq_config  = HqqConfig(
    nbits=1, 
    group_size=64, 
    quant_zero=False, 
    quant_scale=False, axis=0) #axis=0 is used by default


logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class SpeechToText:
    """Class for converting audio to text using a pre-trained speech recognition model."""

    def __init__(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
        self.model = None
        self.device = None

        if self.model is None:
            # pass quant_config through so the quantization config is actually applied
            self.load_model(model_id, quant_config)
        else:
            logging.info("Model already loaded.")

    def load_model(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
        """
        Loads the pre-trained speech recognition model and moves it to the specified device.

        Args:
            model_id (str): Identifier of the pre-trained model to be loaded.
        """
        logging.info("Loading model...")
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            quantization_config=quant_config,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            attn_implementation="flash_attention_2",
            device_map="auto")

        logging.info("Model loaded successfully.")

        processor = AutoProcessor.from_pretrained(model_id)

        self.processor = processor
        self.model = model

    def __call__(
            self,
            chunk_length_s: int = 30,
            stride_length_s: int = 5,
            audio_path: str = "test.mp3",
            max_new_tokens: int = 128,
            batch_size: int = 100,
            language: str = "turkish"):
        """
        Converts audio to text using the pre-trained speech recognition model.

        Args:
            audio_path (str): Path to the audio file to be transcribed.

        Returns:
            str: Transcribed text from the audio.
        """
        pipe = pipeline(
            "automatic-speech-recognition",
            model=self.model,
            chunk_length_s=chunk_length_s,
            stride_length_s=stride_length_s,
            max_new_tokens=max_new_tokens,
            batch_size=batch_size,
            device_map="auto",
            return_timestamps=True,
            tokenizer=self.processor.tokenizer,
            feature_extractor=self.processor.feature_extractor,
            model_kwargs={"use_flash_attention_2": True},
            generate_kwargs={"language": language},
        )
        logging.info("Transcribing audio...")
        result = pipe(audio_path)
        return result

output = SpeechToText(model_id="distil-whisper/distil-large-v3", quant_config=hqq_config) # or bnb_config
transcript = output(
    audio_path = "testv0.mp3",
    chunk_length_s = 30,
    stride_length_s = 5,
    max_new_tokens = 128,
    batch_size = 100,
    language = "english",
)

Error Message:

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:51, in _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax)
     49 maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     50 q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 51 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
     52     q,
     53     k,
     54     v,
     55     None,
     56     alibi_slopes,
     57     dropout_p,
     58     softmax_scale,
     59     causal,
     60     window_size[0],
     61     window_size[1],
     62     return_softmax,
     63     None,
     64 )
     65 return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: FlashAttention only support fp16 and bf16 data type

Cli-env:

- `transformers` version: 4.41.0.dev0
- Platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.23.0
- Safetensors version: 0.4.3
- Accelerate version: 0.29.3
- Accelerate config: 	not found
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: <fill in>
- Using distributed or parallel set-up in script?: <fill in>

@kadirnar (Contributor) commented May 3, 2024

It works when I change this parameter to "sdpa":

attn_implementation="flash_attention_2"

@mobicham (Contributor Author) commented May 3, 2024

@kadirnar that's not related to HQQ. Flash Attention only works with fp16/bf16, as the error says; try torch_dtype=torch.float16 when you load your model.
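
Concretely, a sketch of the suggested fix (reusing the model and config values from the snippet above):

import torch
from transformers import AutoModelForSpeechSeq2Seq, HqqConfig

model_id = "distil-whisper/distil-large-v3"
hqq_config = HqqConfig(nbits=1, group_size=64, quant_zero=False, quant_scale=False, axis=0)

# torch_dtype=torch.float16 keeps activations in fp16, which Flash Attention 2 supports
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    quantization_config=hqq_config,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)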

@kadirnar (Contributor) commented May 4, 2024

@mobicham
Thank you for your help. I tested the HQQ and 4-bit (bnb) methods separately. HQQ is just 1 second faster and the amount of VRAM is the same. Where am I making a mistake?

hqq_config  = HqqConfig(
    nbits=1,
    group_size=64,
    quant_zero=False,
    quant_scale=False, axis=0) #axis=0 is used by default


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

Load-Model:

  def load_model(self, model_id: str = "openai/whisper-large-v3", quant_config=None):
      model = AutoModelForSpeechSeq2Seq.from_pretrained(
          model_id,
          quantization_config=quant_config,
          low_cpu_mem_usage=True,
          use_safetensors=True,
          attn_implementation="flash_attention_2",
          torch_dtype=torch.float16,
          device_map='auto',
          max_memory={0: "24GiB"}
      )
      logging.info("Model loaded successfully.")

      processor = AutoProcessor.from_pretrained(model_id)

      self.processor = processor
      self.model = model

@mobicham (Contributor Author) commented May 4, 2024

@kadirnar

Just 1 second faster and the amount of VRAM is the same. Where am I making a mistake?

Well, that's already good :) !

For faster auto-regressive generation, you need:

from hqq.utils.patching import prepare_for_inference
prepare_for_inference(model, backend="torchao_int4")
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

You need to change your quant settings to use axis=1 instead of axis=0 for this to work.
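
For example, quant settings compatible with the torchao_int4 backend would look something like this (a sketch; torchao_int4 expects 4-bit, axis=1 settings):

from transformers import HqqConfig

# axis=1 (instead of the default axis=0) is required for the torchao_int4 backend
hqq_config = HqqConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=1)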

I am not sure how torch.compile would work with AutoModelForSpeechSeq2Seq. But you should get a 2-3x speed-up even without torch.compile.

Please refer to the repo for the documentation.

@kadirnar (Contributor) commented May 4, 2024

Thank you very much for your help ❤️ I added HQQ optimization to the WhisperPlus library and now it runs 2 seconds faster. I will share it today. torch.compile didn't run:

AttributeError: 'WhisperForConditionalGeneration' object has no attribute 'base_class'. Did you mean: '_auto_class'?

@kadirnar (Contributor) commented May 4, 2024

@sanchit-gandhi,
Do Automatic Speech Recognition models support torch.compile?

@mobicham (Contributor Author) commented May 4, 2024

@kadirnar if you got that error, it means it's not working properly. I will check on Monday; the whole pipeline was only tested with AutoModelForCausalLM. Please open an issue here: https://github.com/mobiusml/hqq/

@appoose commented May 4, 2024

@kadirnar Can you also try the 4-bit version with torchao_int4, since it will be much faster? So you should get faster execution time but hopefully with much lower WER.

@kadirnar (Contributor) commented May 4, 2024

@kadirnar Can you also try the 4-bit version with torchao_int4, since it will be much faster? So you should get faster execution time but hopefully with much lower WER.

I tried it with 4-bit; speed and word count were the same. I don't know the accuracy of the words.

@appoose commented May 4, 2024 via email

@kadirnar (Contributor) commented May 5, 2024

@appoose

I tested for 1-bit and 4-bit. Accuracy loss is very low. I will create a detailed doc in the 0.66 WhisperPlus repo.

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

Wer: 120.88 (hqq)
Wer: 120.48 (base)
Wer: 120.14 (bnb_config)

https://github.com/kadirnar/whisper-plus/blob/main/benckmarks.md

@appoose commented May 6, 2024

It is great to see similar performance to base. But 120% WER (assuming that is what you used) is quite poor, since it translates to 1.2 errors per word transcribed. Not sure if it is due to the general quality of the model or some mismatch in the eval script.
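
One quick way to sanity-check the eval script is to compute WER directly on a pair of normalized strings, e.g. with the jiwer package (an assumption on my part; I don't know what WhisperPlus uses internally):

# pip install jiwer
import jiwer

reference  = ["hello my name is john and i live in london"]
hypothesis = ["hello my name john and i live in london"]

# One deletion out of ten reference words -> 10% WER; values above 100% usually
# indicate misaligned or differently normalized reference/hypothesis pairs.
print(f"WER: {jiwer.wer(reference, hypothesis) * 100:.2f}%")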

@huseinzol05 (Contributor) commented May 6, 2024

@kadirnar, this is my gist to convert the Whisper model fully to HQQ: https://gist.github.com/huseinzol05/70daae3a4557616f315e7744ba3fcc93. However, the speed is not faster than Flash Attention 2 on 30-second examples, even though a simple matmul is faster: https://gist.github.com/huseinzol05/ff59996034604d17c1e53074e9adc03f. @mobicham any thoughts?

@mobicham (Contributor Author) commented May 6, 2024

@huseinzol05 @kadirnar taking a look at this; let's move this conversation here, please: mobiusml/hqq#68

itazap pushed a commit that referenced this pull request May 14, 2024
* update HQQ transformers integration

* push import_utils.py

* add force_hooks check in modeling_utils.py

* fix | with Optional

* force bias as param

* check bias is Tensor

* force forward for multi-gpu

* review fixes pass

* remove torch grad()

* if any key in linear_tags fix

* add cpu/disk check

* isinstance return

* add multigpu test + refactor tests

* clean hqq_utils imports in hqq.py

* clean hqq_utils imports in quantizer_hqq.py

* delete hqq_utils.py

* Delete src/transformers/utils/hqq_utils.py

* ruff init

* remove torch.float16 from __init__ in test

* refactor test

* isinstance -> type in quantizer_hqq.py

* cpu/disk device_map check in quantizer_hqq.py

* remove type(module) nn.linear check in quantizer_hqq.py

* add BaseQuantizeConfig import inside HqqConfig init

* remove hqq import in hqq.py

* remove accelerate import from test_hqq.py

* quant config.py doc update

* add hqqconfig to main_classes doc

* make style

* __init__ fix

* ruff __init__

* skip_modules list

* hqqconfig format fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* test_hqq.py remove mistral comment

* remove self.using_multi_gpu is False

* torch_dtype default val set and logger.info

* hqq.py isinstance fix

* remove torch=None

* torch_device test_hqq

* rename test_hqq

* MODEL_ID in test_hqq

* quantizer_hqq setattr fix

* quantizer_hqq typo fix

* imports quantizer_hqq.py

* isinstance quantizer_hqq

* hqq_layer.bias reformat quantizer_hqq

* Step 2 as comment in quantizer_hqq

* prepare_for_hqq_linear() comment

* keep_in_fp32_modules fix

* HqqHfQuantizer reformat

* quantization.md hqqconfig

* quantization.md model example reformat

* quantization.md # space

* quantization.md space   })

* quantization.md space   })

* quantization_config fix doc

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* axis value check in quantization_config

* format

* dynamic config explanation

* quant config method in quantization.md

* remove shard-level progress

* .cuda fix modeling_utils

* test_hqq fixes

* make fix-copies

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>