Skip to content

Commit

Permalink
Refine WOQ Config (#1356)
Browse files Browse the repository at this point in the history
Co-authored-by: Spycsh <sihan.chen@intel.com>
Co-authored-by: Cheng Penghui <penghui.cheng@intel.com>
Co-authored-by: wenhuach21 <108330088+wenhuach21@users.noreply.github.com>
  • Loading branch information
4 people committed Mar 13, 2024
1 parent 260155a commit 97f0db9
Show file tree
Hide file tree
Showing 43 changed files with 1,686 additions and 992 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ outputs = model.generate(inputs)
You can also load the low-bit model quantized by GPTQ/AWQ/RTN/AutoRound algorithm.
```python
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, GPTQConfig

# Download Hugging Face GPTQ/AWQ model or use local quantize model
model_name = "PATH_TO_MODEL" # local path to model
woq_config = WeightOnlyQuantConfig(use_gptq=True) # use_awq=True for AWQ; use_autoround=True for AutoRound
woq_config = GPTQConfig(bits=4) # use AwqConfig for AWQ models, and AutoRoundConfig for AutoRound models
prompt = "Once upon a time, a little girl"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
Expand Down
74 changes: 49 additions & 25 deletions docs/weightonlyquant.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Weight Only Quantization (WOQ)

1. [Introduction](#introduction)

2. [Supported Framework Model Matrix](#supported-framework-model-matrix)
2. [Supported Algorithms](#supported-algorithms)

3. [Examples For CPU/CUDA](#examples-for-cpu-and-cuda)

Expand All @@ -12,40 +12,65 @@ Weight Only Quantization (WOQ)
## Introduction

As large language models (LLMs) become more prevalent, there is a growing need for new and improved quantization methods that can meet the computational demands of these modern architectures while maintaining the accuracy. Compared to [normal quantization](https://github.com/intel/intel-extension-for-transformers/blob/main/docs/quantization.md) like W8A8, weight only quantization is probably a better trade-off to balance the performance and the accuracy, since we will see below that the bottleneck of deploying LLMs is the memory bandwidth and normally weight only quantization could lead to better accuracy.
## Supported Framework Model Matrix
## Supported Algorithms

| Algorithms/Framework | PyTorch | LLM Runtime |
|:--------------:|:----------:|:----------:|
| RTN | &#10004; | &#10004; |
| AWQ | &#10004; | stay tuned |
| TEQ | &#10004; | stay tuned |
| GPTQ | &#10004; | &#10004; |
| Support Device | Rtn | Awq | Teq | GPTQ | AutoRound |
|:--------------:|:----------:|:----------:|:----------:|:----:|:----:|
| Intel CPU | &#10004; | &#10004; | &#10004; | &#10004; | &#10004; |
| Intel GPU | &#10004; | stay tuned | stay tuned | stay tuned | stay tuned |

| Support Device | RTN | AWQ | TEQ | GPTQ |
|:--------------:|:----------:|:----------:|:----------:|:----:|
| CPU | &#10004; | &#10004; | &#10004; | &#10004; |
| GPU | &#10004; | stay tuned | stay tuned | stay tuned |
> **RTN:** A quantification method that we can think of very intuitively. It does not require additional datasets and is a very fast quantization method. Generally speaking, RTN will convert the weight into a uniformly distributed integer data type, but some algorithms, such as Qlora, propose a non-uniform NF4 data type and prove its theoretical optimality.
**RTN**[[1\]](https://github.com/intel/intel-extension-for-transformers/blob/548c13ed2e19cde91729530ca26c3b875c1b3d10/docs/weightonlyquant.md#1)(&#9733;&#9733;&#9733;): Rounding to Nearest (RTN) is an intuitively simple method that rounds values to the nearest integer. It boasts simplicity, requiring no additional datasets, and offers fast quantization. Besides, it could be easily applied in other datatype like NF4(non-uniform). Typically, it performs well on configurations such as W4G32 or W8, but worse than advanced algorithms at lower precision level.

> **GPTQ:** A new one-shot weight quantization method based on approximate second-order information, that is both highly-accurate and highly efficient. The weights of each column are updated based on the fixed-scale pseudo-quantization error and the inverse of the Hessian matrix calculated from the activations. The updated columns sharing the same scale may generate a new max/min value, so the scale needs to be saved for restoration.

> **AWQ:** Proved that protecting only 1% of salient weights can greatly reduce quantization error. the salient weight channels are selected by observing the distribution of activation and weight per channel. The salient weights are also quantized after multiplying a big scale factor before quantization for preserving.
**Teq**[[2\]](https://github.com/intel/intel-extension-for-transformers/blob/548c13ed2e19cde91729530ca26c3b875c1b3d10/docs/weightonlyquant.md#4)(&#9733;&#9733;&#9733;): To our knowledge, it is the first trainable equivalent ransformation method (summited for peer review in 202306). However, it requires more memory than other methods as model-wise loss is used and the equivalent transformation imposes certain requirements on model architecture.

> **TEQ:** A trainable equivalent transformation that preserves the FP32 precision in weight-only quantization. It is inspired by AWQ while providing a new solution to search for the optimal per-channel scaling factor between activations and weights.

**GPTQ**[[2\]](https://github.com/intel/intel-extension-for-transformers/blob/548c13ed2e19cde91729530ca26c3b875c1b3d10/docs/weightonlyquant.md#2)(&#9733;&#9733;&#9733;&#9733;): GPTQ is a widely adopted method based on the Optimal Brain Surgeon. It quantizes weight block by block and fine-tunes the remaining unquantized ones to mitigate quantization errors. Occasionally, Non-positive semidefinite matrices may occur, necessitating adjustments to hyperparameters.



**Awq**[[4\]](https://github.com/intel/intel-extension-for-transformers/blob/548c13ed2e19cde91729530ca26c3b875c1b3d10/docs/weightonlyquant.md#3)(&#9733;&#9733;&#9733;&#9733;): AWQ is a popular method that explores weight min-max values and equivalent transformations in a handcrafted space. While effective, the equivalent transformation imposes certain requirements on model architecture, limiting its applicability to broader models or increasing engineering efforts.



**AutoRound**[[5\]](https://github.com/intel/intel-extension-for-transformers/blob/548c13ed2e19cde91729530ca26c3b875c1b3d10/docs/weightonlyquant.md#5)(&#9733;&#9733;&#9733;&#9733;&#9734;): AutoRound utilizes sign gradient descent to optimize rounding values and minmax values of weights within just 200 steps, showcasing impressive performance compared to recent methods like GPTQ/AWQ. Additionally, it offers hypeparameters tuning compatibility to further enhance performance. However, due to its reliance on gradient backpropagation, currently it is not quite fit for backends like ONNX.

### references
<a id="1">[1]</a>
Gunho Park, Baeseong Park, Se Jung Kwon, Byeongwook Kim, Youngjoo Lee, and Dongsoo Lee.
nuqmm: Quantized matmul for efficient inference of large-scale generative language models.
arXiv preprint arXiv:2206.09557, 2022.

<a id="2">[2]</a>
Cheng, W., Cai, Y., Lv, K & Shen, H. (2023).
TEQ: Trainable Equivalent Transformation for Quantization of LLMs.
arXiv preprint arXiv:2310.10944.

<a id="3">[3]</a>
Frantar, Elias, et al. "Gptq: Accurate post-training quantization for generative pre-trained transformers." arXiv preprint arXiv:2210.17323 (2022).

<a id="4">[4]</a>
Lin, Ji, et al.(2023).
AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration.
arXiv preprint arXiv:2306.00978.

<a id="5">[5]</a>
Cheng, W., Zhang, W., Shen, H., Cai, Y., He, X., & Lv, K. (2023).
Optimize weight rounding via signed gradient descent for the quantization of llms.
arXiv preprint arXiv:2309.05516.

## Examples For CPU AND CUDA

Our motivation is improve CPU support for weight only quantization, since `bitsandbytes` only support CUDA GPU device. We have extended the `from_pretrained` function so that `quantization_config` can accept [`WeightOnlyQuantConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/quantization_config.py#L28) to implement conversion on the CPU. We not only support PyTorch but also provide LLM Runtime backend based cpp programming language. Here are the example codes.
Our motivation is to improve CPU support for weight only quantization, since `bitsandbytes`, `auto-gptq`, `autoawq` only support CUDA GPU device. We have extended the `from_pretrained` function so that `quantization_config` can accept [`RtnConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/config.py#L608), [`AwqConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/config.py#L793), [`TeqConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/config.py#L28), [`GPTQConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/config.py#L855), [`AutoroundConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/config.py#L912) to implement conversion on the CPU. We not only support PyTorch but also provide LLM Runtime backend based cpp programming language. Here are the example codes.

### Example for CPU device
4-bit/8-bit inference with `WeightOnlyQuantConfig` on CPU device.
4-bit/8-bit inference with `RtnConfig`, `AwqConfig`, `TeqConfig`, `GPTQConfig`, `AutoRoundConfig` on CPU device.
```bash
cd intel_extension_for_transformers/llm/runtime/graph
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
cd examples/huggingface/pytorch/text-generation/quantization
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig
model_name_or_path = "Intel/neural-chat-7b-v3-3"
# weight_dtype: int8/int4, compute_dtype: int8/fp32
woq_config = WeightOnlyQuantConfig(weight_dtype="int4", compute_dtype="int8")
woq_config = RtnConfig(bits=4, compute_dtype="int8")
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
quantization_config=woq_config,
Expand Down Expand Up @@ -82,7 +107,7 @@ gen_ids = woq_model.generate(input_ids, max_new_tokens=32, **generate_kwargs)
gen_text = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
print(gen_text)
```
`load_in_4bit` and `load_in_8bit` both support on CPU and CUDA GPU device. If device set to use GPU, the BitsAndBytesConfig will be used, if the device set to use CPU, the WeightOnlyQuantConfig will be used.
`load_in_4bit` and `load_in_8bit` both support on CPU and CUDA GPU device. If device set to use GPU, the BitsAndBytesConfig will be used, if the device set to use CPU, the RtnConfig will be used.
```bash
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
woq_model = AutoModelForCausalLM.from_pretrained(
Expand Down Expand Up @@ -160,7 +185,6 @@ pip install intel-extension-for-transformers
import intel_extension_for_pytorch as ipex
from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
from transformers import AutoTokenizer
import torch

device = "xpu"
model_name = "Qwen/Qwen-7B"
Expand All @@ -171,7 +195,7 @@ inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
qmodel = AutoModelForCausalLM.from_pretrained(model_name, load_in_4bit=True, device_map="xpu", trust_remote_code=True)

# optimize the model with ipex, it will improve performance.
qmodel = ipex.optimize_transformers(qmodel, inplace=True, dtype=torch.float16, quantization_config={}, device="xpu")
qmodel = ipex.optimize_transformers(qmodel, inplace=True, dtype=torch.float16, woq=True, device="xpu")

output = user_model.generate(inputs)
```
Expand All @@ -195,7 +219,7 @@ model.save_pretrained("saved_dir")
loaded_model = AutoModelForCausalLM.from_pretrained("saved_dir", trust_remote_code=True)

# Before executed the loaded model, you can call ipex.optimize_transformers function.
loaded_model = ipex.optimize_transformers(loaded_model, inplace=True, dtype=torch.float16, woq=True, device="xpu")
loaded_model = ipex.optimize_transformers(loaded_model, inplace=True, dtype=torch.float16, quantization_config={}, device="xpu")

output = loaded_model.generate(inputs)

Expand Down
4 changes: 2 additions & 2 deletions examples/huggingface/neural_speed/perplexity/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_ppl(sum_nll, sum_nll2, cnt: int):

def perplexity(model_name, dataset_name, **kwargs):
import datasets
from intel_extension_for_transformers.transformers import (AutoModelForCausalLM, WeightOnlyQuantConfig)
from intel_extension_for_transformers.transformers import (AutoModelForCausalLM, RtnConfig)
from transformers import AutoTokenizer, AutoConfig
model_name = try_resolve_dir(model_name)
dataset_name = try_resolve_dir(dataset_name)
Expand Down Expand Up @@ -107,7 +107,7 @@ def perplexity(model_name, dataset_name, **kwargs):
for k in kwargs
if k in ['use_cache', 'compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']
}
woq_config = WeightOnlyQuantConfig(**woq_kwargs)
woq_config = RtnConfig(**woq_kwargs)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)

model_kwargs = {k: kwargs[k] for k in kwargs if k in ['n_keep', 'shift_roped_k', 'memory_dtype']}
Expand Down
4 changes: 2 additions & 2 deletions examples/huggingface/neural_speed/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pathlib import Path
from typing import List, Optional
from transformers import AutoTokenizer,TextStreamer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig
def main(args_in: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(description="Convert a PyTorch model to a NE compatible file")
parser.add_argument("--model_path",type=Path,
Expand All @@ -32,7 +32,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--max_new_tokens", type=int, help="max_new_tokens", default=300)
args = parser.parse_args(args_in)
model_name = args.model_path
woq_config = WeightOnlyQuantConfig(load_in_4bit=True, use_quant=args.not_quant,
woq_config = RtnConfig(load_in_4bit=True, use_quant=args.not_quant,
weight_dtype=args.weight_dtype, compute_dtype=args.compute_dtype, group_size=args.group_size, use_gptq=args.use_gptq)
prompt = args.prompt
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
Expand Down

0 comments on commit 97f0db9

Please sign in to comment.