[Quantization] AutoGPTQ refactor and matmul combination support
This PR refactors the AutoGPTQ integration to better align with the
framework design. It also adds support for AutoGPTQ quantization in
MLC LLM with matmul combination.

With this PR, you can compile Llama-2 with AutoGPTQ quantization using the
following command:
```bash
python -m mlc_llm.build --model=Llama-2-7b-chat-hf --quantization autogptq_llama_q4f16_1 --target cuda
```
**Note that the first run may take around 10 minutes for the AutoGPTQ
quantization computation; subsequent runs will be much quicker.** AutoGPTQ
quantization requires the Python `auto_gptq` package at version 0.2.0 or later.
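
Since the build depends on `auto_gptq` >= 0.2.0, a quick pre-flight check can save a failed compile. Below is a minimal sketch (not part of this PR) that verifies the installed version; it assumes the package is distributed as `auto-gptq` on PyPI and that `packaging` is available:
```python
# Minimal sketch: verify the installed auto_gptq version before building.
# Assumes the PyPI distribution name is "auto-gptq" and `packaging` is installed.
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

try:
    installed = Version(version("auto-gptq"))
except PackageNotFoundError as err:
    raise RuntimeError("auto-gptq is not installed; run `pip install auto-gptq>=0.2.0`") from err

if installed < Version("0.2.0"):
    raise RuntimeError(f"auto-gptq {installed} is too old; version >= 0.2.0 is required")
```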

Co-authored-by: Lei Wang <LeiWang1999@users.noreply.github.com>
MasterJH5574 and LeiWang1999 committed Aug 16, 2023
1 parent abac1a3 commit b5c1162
Showing 6 changed files with 274 additions and 419 deletions.
24 changes: 12 additions & 12 deletions mlc_llm/quantization/__init__.py
@@ -3,33 +3,33 @@
from .quantization import QuantizationSpec, NoQuantizationSpec, ParamQuantKind
from .quantization import QuantSpecUpdater
from .group_quantization import GroupQuantizationSpec
from .autogptq_quantization import AutogptqQuantizationSpec, load_autogptq_params
from .autogptq_quantization import AutogptqQuantizationSpec
from .ft_rowwise_quantization import FTRowwiseQuantizationSpec, FTQuantizeUpdater


# The predefined quantization schemes.
quantization_schemes = {
"autogptq_llama_q4f16_0": QuantizationScheme(
pre_quantized=True,
name="autogptq_llama_q4f16_0",
linear_weight=AutogptqQuantizationSpec(
dtype="float16",
mode="int4",
sym=False,
storage_nbit=32,
group_size=128,
),
embedding_table=NoQuantizationSpec("float16"),
final_fc_weight=NoQuantizationSpec("float16"),
),
"autogptq_llama_q4f16_1": QuantizationScheme(
name="autogptq_llama_q4f16_1",
linear_weight=AutogptqQuantizationSpec(
dtype="float16",
mode="int4",
sym=False,
group_size=-1,
transpose=True,
),
embedding_table=NoQuantizationSpec("float16"),
final_fc_weight=NoQuantizationSpec("float16"),
_base_model_prefix="model",
_layers_block_name="model.layers",
_inside_layer_modules=[
["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
["self_attn.o_proj"],
["mlp.gate_proj", "mlp.down_proj", "mlp.up_proj"],
],
_load_quantized_params_func=load_autogptq_params,
),
"q0f16": QuantizationScheme("q0f16", NoQuantizationSpec("float16")),
"q0f32": QuantizationScheme("q0f32", NoQuantizationSpec("float32")),
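
For reference, here is a minimal, hypothetical sketch of how the `--quantization` flag might resolve to one of the predefined schemes in the `quantization_schemes` dict shown above; the helper name `get_scheme` and the `scheme.name` attribute access are illustrative assumptions, not part of this PR:
```python
# Hypothetical sketch: look up a predefined QuantizationScheme by the name
# passed on the command line (e.g. --quantization autogptq_llama_q4f16_1).
from mlc_llm.quantization import quantization_schemes  # dict defined in this file


def get_scheme(name: str):
    """Return the QuantizationScheme registered under `name`."""
    try:
        return quantization_schemes[name]
    except KeyError as err:
        known = ", ".join(sorted(quantization_schemes))
        raise ValueError(f"Unknown quantization scheme {name!r}; known schemes: {known}") from err


scheme = get_scheme("autogptq_llama_q4f16_1")
print(scheme.name)  # assumes QuantizationScheme stores the `name` keyword argument
```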
