Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HfQuantizer class for quantization-related stuff in modeling_utils.py #26610

Merged
merged 83 commits into from Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
e0650b2
squashed earlier commits for easier rebase
poedator Dec 23, 2023
42adf9d
rm rebase leftovers
poedator Dec 23, 2023
7f57f26
4bit save enabled @quantizers
poedator Dec 23, 2023
f1f5da0
TMP gptq test use exllama
poedator Dec 24, 2023
a94d3a7
fix AwqConfigTest::test_wrong_backend for A100
poedator Dec 25, 2023
0b30de4
quantizers AWQ fixes
poedator Dec 25, 2023
4cdaf0d
_load_pretrained_model low_cpu_mem_usage branch
poedator Dec 25, 2023
0db1107
quantizers style
poedator Dec 25, 2023
89d1177
remove require_low_cpu_mem_usage attr
poedator Dec 25, 2023
0c71b00
rm dtype arg from process_model_before_weight_loading
poedator Dec 25, 2023
2b4122a
rm config_origin from Q-config
poedator Dec 25, 2023
02ad562
rm inspect from q_config
poedator Dec 25, 2023
3e51d51
fixed docstrings in QuantizationConfigParser
poedator Dec 25, 2023
2569367
logger.warning fix
poedator Dec 25, 2023
3259243
mv is_loaded_in_4(8)bit to BnbHFQuantizer
poedator Dec 25, 2023
ab61417
is_accelerate_available error msg fix in quantizer
poedator Dec 25, 2023
95e44cd
split is_model_trainable in bnb quantizer class
poedator Dec 25, 2023
b936cfb
rm llm_int8_skip_modules as separate var in Q
poedator Dec 25, 2023
0b40d21
Q rm todo
poedator Dec 25, 2023
c53a3fb
fwd ref to HFQuantizer in type hint
poedator Dec 25, 2023
dbd93f2
rm note re optimum.gptq.GPTQQuantizer
poedator Dec 26, 2023
e34bd58
quantization_config in __init__ simplified
poedator Dec 26, 2023
fcd5a7a
replaced NonImplemented with create_quantized_param
poedator Dec 26, 2023
954c5e6
rm load_in_4/8_bit deprecation warning
poedator Dec 26, 2023
49e163f
QuantizationConfigParser refactoring
poedator Dec 26, 2023
f8b9e07
awq-related minor changes
poedator Jan 8, 2024
5eaf9ac
awq-related changes
poedator Jan 8, 2024
d678d99
awq config.modules_to_not_convert
poedator Jan 8, 2024
7c9c49b
raise error if no q-method in q-config in args
poedator Jan 8, 2024
0d739d3
minor cleanup
poedator Jan 10, 2024
b5f2bab
awq quantizer docstring
poedator Jan 10, 2024
af33463
combine common parts in bnb process_model_before_weight_loading
poedator Jan 10, 2024
d4af5f1
revert test_gptq
poedator Jan 10, 2024
94f2cc7
.process_model_ cleanup
poedator Jan 10, 2024
ec77d10
restore dict config warning
poedator Jan 10, 2024
f5b9849
removed typevars in quantizers.py
poedator Jan 16, 2024
fb37bb8
cleanup post-rebase 16 jan
poedator Jan 16, 2024
cdc71c8
QuantizationConfigParser classmethod refactor
poedator Jan 16, 2024
e6df6ed
rework of handling of unexpected aux elements of bnb weights
poedator Jan 17, 2024
1c433f5
moved q-related stuff from save_pretrained to quantizers
poedator Jan 17, 2024
60781dd
refactor v1
younesbelkada Jan 24, 2024
842391a
more changes
younesbelkada Jan 24, 2024
0803440
fix some tests
younesbelkada Jan 24, 2024
594d1a9
remove it from main init
younesbelkada Jan 24, 2024
a771ab7
ooops
younesbelkada Jan 24, 2024
aa4ec34
Apply suggestions from code review
younesbelkada Jan 25, 2024
53619de
fix awq issues
younesbelkada Jan 25, 2024
cd4aa90
Merge remote-tracking branch 'upstream/main' into hf-quantizer-work
younesbelkada Jan 25, 2024
a988d01
fix
younesbelkada Jan 25, 2024
a911e7d
fix
younesbelkada Jan 25, 2024
3886559
fix
younesbelkada Jan 25, 2024
43e5e70
fix
younesbelkada Jan 25, 2024
c1dcaa3
fix
younesbelkada Jan 25, 2024
b0ac4a7
fix
younesbelkada Jan 25, 2024
1575c47
Merge branch 'main' into hf-quantizer-work
younesbelkada Jan 25, 2024
ad8d7f6
add docs
younesbelkada Jan 25, 2024
89cf6cf
Apply suggestions from code review
younesbelkada Jan 26, 2024
0ebaf4e
Apply suggestions from code review
younesbelkada Jan 26, 2024
adaae05
Update docs/source/en/hf_quantizer.md
younesbelkada Jan 26, 2024
f0b5f96
address comments
younesbelkada Jan 26, 2024
30e1fc2
fix
younesbelkada Jan 26, 2024
3b7e625
Merge branch 'hf-quantizer-work' of https://github.com/younesbelkada/…
younesbelkada Jan 26, 2024
493d117
fixup
younesbelkada Jan 26, 2024
48c5761
Update src/transformers/modeling_utils.py
younesbelkada Jan 26, 2024
3744fb1
Update src/transformers/modeling_utils.py
younesbelkada Jan 26, 2024
c4995ab
address final comment
younesbelkada Jan 26, 2024
17f95bf
Merge branch 'hf-quantizer-work' of https://github.com/younesbelkada/…
younesbelkada Jan 26, 2024
abb4db3
update
younesbelkada Jan 26, 2024
7e5a5b8
Update src/transformers/quantizers/base.py
younesbelkada Jan 26, 2024
122b494
Update src/transformers/quantizers/auto.py
younesbelkada Jan 26, 2024
901ace5
fix
younesbelkada Jan 26, 2024
2da5233
Merge remote-tracking branch 'upstream/main' into hf-quantizer-work
younesbelkada Jan 29, 2024
2ab7fd5
add kwargs update
younesbelkada Jan 30, 2024
242682c
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada Jan 30, 2024
e387f68
Merge branch 'quant' into hf-quantizer-work
younesbelkada Jan 30, 2024
4c0c33e
fixup
younesbelkada Jan 30, 2024
c37b222
add `optimum_quantizer` attribute
younesbelkada Jan 30, 2024
ca40b04
oops
younesbelkada Jan 30, 2024
c0ed16a
Merge pull request #5 from younesbelkada/hf-quantizer-work
younesbelkada Jan 30, 2024
7a764fb
rm unneeded file
younesbelkada Jan 30, 2024
377943d
Merge branch 'quant' of https://github.com/poedator/transformers into…
younesbelkada Jan 30, 2024
85d4656
Merge remote-tracking branch 'upstream/main' into HEAD
younesbelkada Jan 30, 2024
deb7696
fix doctests
younesbelkada Jan 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Expand Up @@ -137,6 +137,8 @@
title: Overview
- local: quantization
title: Quantization
- local: hf_quantizer
title: Contribute new quantization method
- sections:
- local: perf_train_gpu_one
title: Methods and tools for efficient training on a single GPU
Expand Down
70 changes: 70 additions & 0 deletions docs/source/en/hf_quantizer.md
@@ -0,0 +1,70 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Contribute new quantization method

Transformers supports and integrates many quantization methods such as QLoRA, GPTQ, LLM.int8, and AWQ. However, there are other quantization approaches that are not yet integrated. To make adding and using these quantization methods with Transformers models easier, you should use the [`HfQuantizer`] class. The [`HfQuantizer`] is designed as an internal helper class for adding a quantization method instead of something you apply to every PyTorch module.

This guide will show you how to integrate a new quantization method with the [`HfQuantizer`] class.


## Requirements

Before integrating a new quantization method into Transformers, ensure the method you are trying to add meets the following prerequisites. Only quantization methods that can be run with PyTorch modules are currently supported.

- The quantization method is available through a Python package that is pip-installable by anyone (it is also fine if you can only install the package from source). Ideally, pre-compiled kernels are included in the pip package.
- The method can run on commonly-used hardware (CPU, GPU, ...).
- The method is wrapped in a `nn.Module` (e.g., `Linear8bitLt`, `Linear4bit`), and the quantized linear layer should have the following definition:

```py
class Linear4bit(nn.Module):
def __init__(self, ...):
...

def forward(self, x):
return my_4bit_kernel(x, self.weight, self.bias)
```
This way, Transformers models can be easily quantized by replacing some instances of `nn.Linear` with a target class.
- The quantization method should be serializable. You can save the quantized weights locally or push them to the Hub.
- Make sure the package that contains the quantization kernels/primitive is stable (no frequent breaking changes).

For some quantization methods, they may require "pre-quantizing" the models through data calibration (e.g., AWQ). In this case, we prefer to only support inference in Transformers and let the third-party library maintained by the ML community deal with the model quantization itself.

## Build a new HFQuantizer class

1. 📕 Create a new quantization config class inside `src/transformers/utils/quantization_config.py` and make sure to expose the new quantization config inside Transformers main `init` by adding it to the `_import_structure` object of `src/transformers/__init__.py`.

2- 🗃 Create a new file inside `src/transformers/quantizers/` named `quantizer_your_method.py`, and make it inherit from `src/transformers/quantizers/base.py::HfQuantizer`. Make sure to add the new quantizer and quantization config in the quantization auto-mapping in `src/transformers/quantizers/auto.py`

3- 🔩 Define the following class attributes/property methods for your quantization method:

* `requires_calibration`: Whether the quantization method requires a data calibration process. If set to `True`, you can only support inference (with quantized weights) and not inference and quantization.
* `required_packages`: A list of strings of the required packages to use the quantized weights. You might need to define some new utility methods such as `is_auto_awq_available` in `transformers/src/utils/import_utils.py`.
* `requires_parameters_quantization`: Only required if your quantization method requires extra attention to the underlying `nn.Parameter` object. For example, bitsandbytes uses `Params4bit` and `Int8Param`, which requires some extra attention when quantizing the model. Most of the recent quantization method packs int2/int4 weights inside `torch.uint8` weights, so this flag should not be really required (set to `False` by default).
* `is_serializable`: A property method to determine whether the method is serializable or not.
* `is_trainable`: A property method to determine whether you can fine-tune models on top of the quantization method (with or without PEFT approaches).


4- 🪛 Write the `validate_environment` and `update_torch_dtype` methods. These methods are called before creating the quantized model to ensure users use the right configuration. You can have a look at how this is done on other quantizers.

5- 🖋 Write the `_process_model_before_weight_loading` method. In Transformers, the quantized models are initialized first on the `"meta"` device before loading the weights. This means the `_process_model_before_weight_loading` method takes care of manipulating the model skeleton to replace some modules (e.g., `nn.Linear`) with the target modules (quantization modules). You can define a module replacement logic or any other utility method by creating a new file in `transformers/src/integrations/` and exposing the relevant methods in that folder's `__init__.py` file. The best starting point would be to have a look at another quantization methods such as `quantizer_awq.py`

6- 🖊 Write the `_process_model_after_weight_loading` method. This method enables implementing additional features that require manipulating the model after loading the weights.

7- 📖 Document everything! Make sure your quantization method is documented in the `docs/source/en/quantization.md` file.

8- 🟢 Add tests! You should add tests by first adding the package in our nightly Dockerfile inside `docker/transformers-all-latest-gpu` and then adding a new test file in `tests/quantization/xxx`. Feel free to check out how it is implemented for other quantization methods.

6 changes: 6 additions & 0 deletions docs/source/en/quantization.md
Expand Up @@ -20,6 +20,12 @@ Quantization techniques focus on representing data with less information while a

Transformers supports several quantization schemes to help you run inference with large language models (LLMs) and finetune adapters on quantized models. This guide will show you how to use Activation-aware Weight Quantization (AWQ), AutoGPTQ, and bitsandbytes.

<Tip>

Interested in adding a new quantization method to Transformers? Read the [HfQuantizer](./hf_quantizer) guide to learn how!

</Tip>

## AWQ

<Tip>
Expand Down
1 change: 1 addition & 0 deletions src/transformers/__init__.py
Expand Up @@ -1001,6 +1001,7 @@
"pipeline",
],
"processing_utils": ["ProcessorMixin"],
"quantizers": [],
"testing_utils": [],
"tokenization_utils": ["PreTrainedTokenizer"],
"tokenization_utils_base": [
Expand Down
11 changes: 6 additions & 5 deletions src/transformers/integrations/awq.py
Expand Up @@ -187,17 +187,18 @@ def fuse_awq_modules(model, quantization_config):
Args:
model (`~PreTrainedModel`):
The model to fuse - note this model should have been converted into AWQ format beforehand.
quantization_config (`dict`):
quantization_config (`Union[AwqConfig, dict]`):
The quantization configuration to use.
"""
# We need to convert it from dict in order to get an AwqConfig object
# otherwise the fields `backend` etc. will not be available
# https://github.com/huggingface/transformers/pull/27411#discussion_r1414044495
awq_config = AwqConfig.from_dict(quantization_config)
backend = awq_config.backend
if isinstance(quantization_config, dict):
quantization_config = AwqConfig.from_dict(quantization_config)
backend = quantization_config.backend

modules_to_fuse = get_modules_to_fuse(model, awq_config)
modules_to_not_convert = getattr(awq_config, "modules_to_not_convert", None)
modules_to_fuse = get_modules_to_fuse(model, quantization_config)
modules_to_not_convert = getattr(quantization_config, "modules_to_not_convert", None)

if backend == AwqBackendPackingMethod.AUTOAWQ:
from awq.modules.fused.attn import QuantAttentionFused
Expand Down