[Quantization] AutoGPTQ refactor and matmul combination support (#694)
This PR refactors the AutoGPTQ integration to better align with the
framework design. It also adds support for AutoGPTQ quantization with
matmul combination in MLC LLM.

With this PR, you can compile Llama-2 with AutoGPTQ quantization using the
following command:
```bash
python -m mlc_llm.build --model=Llama-2-7b-chat-hf --quantization autogptq_llama_q4f16_1 --target cuda
```
**Note that the first run may take around 10 minutes for the AutoGPTQ
quantization computation; subsequent runs will be much quicker.** AutoGPTQ
quantization requires the Python `auto_gptq` package at version 0.2.0 or
later.
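If you want to confirm that the installed package meets this requirement before
building, a minimal standard-library check such as the following (not part of
this PR) can be used:

```python
# Minimal sketch (not from this PR): check that the installed auto_gptq
# package satisfies the >= 0.2.0 requirement mentioned above.
from importlib.metadata import version

installed = version("auto_gptq")
major, minor = (int(part) for part in installed.split(".")[:2])
assert (major, minor) >= (0, 2), f"auto_gptq {installed} found; >= 0.2.0 is required"
print(f"auto_gptq {installed} satisfies the requirement")
```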

Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>
LeiWang1999 and MasterJH5574 committed Aug 25, 2023
1 parent 94bda91 commit 5fe6344
Showing 7 changed files with 284 additions and 420 deletions.
19 changes: 12 additions & 7 deletions mlc_llm/core.py
@@ -195,9 +195,11 @@ class BuildArgs:
no_cublas: bool = field(
default=False,
metadata={
"help": ("Disable the step that offloads matmul to cuBLAS. Without this flag, "
"matmul will be offloaded to cuBLAS if quantization mode is q0f16 or q0f32, "
"target is CUDA and TVM has been built with cuBLAS enbaled."),
"help": (
"Disable the step that offloads matmul to cuBLAS. Without this flag, "
"matmul will be offloaded to cuBLAS if quantization mode is q0f16 or q0f32, "
"target is CUDA and TVM has been built with cuBLAS enbaled."
),
"action": "store_true",
},
)
@@ -358,13 +360,16 @@ def mod_transform_before_build(
) # pylint: disable=not-callable

if "num_attention_heads" in config and "hidden_size" in config:
max_seq_len = config["max_sequence_length"]

max_seq_len = None
if args.max_seq_len > 0:
max_seq_len = args.max_seq_len
elif "max_sequence_length" in config:
max_seq_len = config["max_sequence_length"]

mod = fuse_split_rotary_embedding(mod, config["num_attention_heads"], config["hidden_size"],
max_seq_len)
if max_seq_len:
mod = fuse_split_rotary_embedding(
mod, config["num_attention_heads"], config["hidden_size"], max_seq_len
)

if args.target_kind == "cuda":
patterns = []
24 changes: 12 additions & 12 deletions mlc_llm/quantization/__init__.py
@@ -3,33 +3,33 @@
from .quantization import QuantizationSpec, NoQuantizationSpec, ParamQuantKind
from .quantization import QuantSpecUpdater
from .group_quantization import GroupQuantizationSpec
from .autogptq_quantization import AutogptqQuantizationSpec, load_autogptq_params
from .autogptq_quantization import AutogptqQuantizationSpec
from .ft_rowwise_quantization import FTRowwiseQuantizationSpec, FTQuantizeUpdater


# The predefined quantization schemes.
quantization_schemes = {
"autogptq_llama_q4f16_0": QuantizationScheme(
pre_quantized=True,
name="autogptq_llama_q4f16_0",
linear_weight=AutogptqQuantizationSpec(
dtype="float16",
mode="int4",
sym=False,
storage_nbit=32,
group_size=128,
),
embedding_table=NoQuantizationSpec("float16"),
final_fc_weight=NoQuantizationSpec("float16"),
),
"autogptq_llama_q4f16_1": QuantizationScheme(
name="autogptq_llama_q4f16_1",
linear_weight=AutogptqQuantizationSpec(
dtype="float16",
mode="int4",
sym=False,
group_size=-1,
transpose=True,
),
embedding_table=NoQuantizationSpec("float16"),
final_fc_weight=NoQuantizationSpec("float16"),
_base_model_prefix="model",
_layers_block_name="model.layers",
_inside_layer_modules=[
["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
["self_attn.o_proj"],
["mlp.gate_proj", "mlp.down_proj", "mlp.up_proj"],
],
_load_quantized_params_func=load_autogptq_params,
),
"q0f16": QuantizationScheme("q0f16", NoQuantizationSpec("float16")),
"q0f32": QuantizationScheme("q0f32", NoQuantizationSpec("float32")),
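The new scheme added in the diff above is registered under the same name that
the `--quantization` flag accepts. A minimal sketch of looking it up, assuming
`mlc_llm` is importable and `QuantizationScheme` keeps its constructor
arguments (`name`, `linear_weight`, ...) as attributes:

```python
# Minimal sketch, not part of this diff: look up the scheme registered above.
from mlc_llm.quantization import quantization_schemes

scheme = quantization_schemes["autogptq_llama_q4f16_1"]
print(scheme.name)           # "autogptq_llama_q4f16_1"
print(scheme.linear_weight)  # AutogptqQuantizationSpec(dtype="float16", mode="int4", ...)
```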
