diff --git a/README.md b/README.md index 58bee872a..6b9def6c8 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,7 @@ refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and for some accuracy results. [2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for - all bits other than 3 bits. Example - models: [Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound](https://huggingface.co/Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound) - and [Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound](https://huggingface.co/Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound). **A more advanced algorithm** tailored for specific configurations may be available in + all bits other than 3 bits. **A more advanced algorithm** tailored for specific configurations may be available in v0.8.1. [2025/05] AutoRound has been integrated into **vLLM**. You can now run models in the AutoRound format directly with @@ -186,58 +184,54 @@ ar = AutoRound(model_name_or_path, scheme="W4A16") # ar = AutoRound(model_name_or_path, nsamples=128, iters=50, lr=5e-3) # Supported formats: "auto_round" (default), "auto_gptq", "auto_awq", "llm_compressor", "gguf:q4_k_m", etc. -ar.quantize_and_save(output_dir="./tmp_autoround", format="auto_round") +ar.quantize_and_save(output_dir="./qmodel", format="auto_round") ``` -
- Detailed Hyperparameters - -- `model`: The PyTorch model to be quantized. - -- `tokenizer`: An optional tokenizer for processing input data. If none, a dataset must be provided. - -- `bits (int)`: Number of bits for quantization (default is 4). - -- `group_size (int)`: Size of the quantization group (default is 128). - -- `sym (bool)`: Whether to use symmetric quantization (default is True). - -- `enable_quanted_input (bool)`: Whether to use the output of the previous quantized block as the input for the current - block for tuning (default is True). +### AutoScheme Usage +Please refer to the [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for more details on AutoScheme. +~~~python +from auto_round import AutoRound, AutoScheme -- `enable_minmax_tuning (bool)`: Whether to enable weight min-max tuning (default is True). +model_name = "Qwen/Qwen3-8B" +avg_bits = 3.0 +scheme = AutoScheme(avg_bits=avg_bits, options=("GGUF:Q2_K_S", "GGUF:Q4_K_S"), ignore_scale_zp_bits=True) +layer_config = {"lm_head": "GGUF:Q6_K"} -- `iters (int)`: Number of tuning iterations (default is 200). +# Change iters to 200 for non-GGUF schemes +ar = AutoRound(model=model_name, scheme=scheme, layer_config=layer_config, iters=0) +ar.quantize_and_save() +~~~ -- `lr (float)`: The learning rate for rounding value (default is None, it will be set to 1.0/iters automatically). - -- `minmax_lr (float)`: The learning rate for min-max tuning (default is None, it will be set to lr automatically). - -- `nsamples (int)`: Number of samples for tuning (default is 128). - -- `seqlen (int)`: Data length of the sequence for tuning (default is 2048). - -- `batch_size (int)`: Batch size for training (default is 8). - -- `scale_dtype (str)`: The data type of quantization scale to be used (default is "float16"), different kernels have - different choices. +
+Important Hyperparameters

-- `amp (bool)`: Whether to use automatic mixed precision (default is True).
+##### Quantization Scheme & Configuration
+- **`scheme` (str|dict|AutoScheme)**: A preset quantization scheme name, a custom configuration dict, or an `AutoScheme` instance, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`.
+- **`bits` (int)**: Number of bits for quantization (default is `None`). If not `None`, it overrides the scheme setting.
+- **`group_size` (int)**: Size of the quantization group (default is `None`). If not `None`, it overrides the scheme setting.
+- **`sym` (bool)**: Whether to use symmetric quantization (default is `None`). If not `None`, it overrides the scheme setting.
+- **`layer_config` (dict)**: Per-layer quantization configuration (default is `None`), mainly for mixed schemes.

-- `nblocks (int)`: Packing several blocks as one for tuning together (default is 1).

-- `gradient_accumulate_steps (int)`: Number of gradient accumulation steps (default is 1).
+##### Algorithm Settings
+- **`enable_alg_ext` (bool)**: Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that can bring notable accuracy improvements. Default is `False`.
+- **`disable_opt_rtn` (bool)**: Use plain RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled).

-- `low_gpu_mem_usage (bool)`: Whether to save GPU memory at the cost of ~20% more tuning time (default is False).
+##### Tuning Process Parameters
+- **`iters` (int)**: Number of tuning iterations (default is `200`). Common values: 0 (RTN mode), 50 (with lr=5e-3 recommended), 1000. Higher values typically improve accuracy but slow down tuning.
+- **`lr` (float)**: The learning rate for the rounding values (default is `None`). When `None`, it is set to `1.0/iters` automatically.
+- **`batch_size` (int)**: Batch size for training (default is `8`). 4 is also commonly used.

-- `dataset Union[str, list, tuple, torch.utils.data.DataLoader]`: The dataset name for tuning (default is "
-  NeelNanda/pile-10k"). Local json file and combination of datasets have been supported, e.g. "
-  ./tmp.json,NeelNanda/pile-10k:train, mbpp:train+validation+test"
+##### Calibration Dataset
+- **`dataset` (str|list|tuple|torch.utils.data.DataLoader)**: The dataset for tuning (default is `"NeelNanda/pile-10k"`). Supports local JSON files and dataset combinations, e.g. `"./tmp.json,NeelNanda/pile-10k:train,mbpp:train+validation+test"`.
+- **`nsamples` (int)**: Number of samples for tuning (default is `128`).
+- **`seqlen` (int)**: Sequence length of the calibration data (default is `2048`).

-- `layer_config (dict)`: Configuration for weight quantization (default is None), mainly for mixed bits
-  or mixed precision.
-- `device`: The device to be used for tuning. The default is set to 'auto', allowing for automatic detection.
+##### Device/Speed Configuration
+- **`enable_torch_compile` (bool)**: We recommend setting it to `True` for faster quantization with lower resource usage, provided no exception is raised.
+- **`low_gpu_mem_usage` (bool)**: Whether to offload intermediate features to CPU at the cost of ~20% more tuning time (default is `False`).
+- **`device_map` (str|dict|int)**: The device to be used for tuning, e.g., `"cpu"`, `"cuda"`, `"0,1,2"` (default is `'0'`).
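To make the hyperparameters above concrete, a minimal illustrative sketch (not part of the patch) that combines a preset scheme, a per-layer override, and the shorter tuning schedule suggested for `iters=50`; the model name, the `lm_head` override, and the output directory are placeholders, with the call pattern taken from the README examples earlier in this diff and the `layer_config` docstring below.

```python
from auto_round import AutoRound

model_name = "Qwen/Qwen3-8B"  # placeholder model

# Per-layer override: keep lm_head at 8 bits while the rest follows the W4A16 scheme.
layer_config = {"lm_head": {"bits": 8, "group_size": 128, "sym": True}}

# Fewer iterations with a larger learning rate, as noted above for iters=50.
ar = AutoRound(model_name, scheme="W4A16", layer_config=layer_config, iters=50, lr=5e-3)
ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
```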
@@ -263,7 +257,7 @@ from auto_round import AutoRoundMLLM model_name_or_path = "Qwen/Qwen2.5-VL-7B-Instruct" # Quantize the model ar = AutoRoundMLLM(model_name_or_path, scheme="W4A16") -output_dir = "./tmp_autoround" +output_dir = "./qmodel" ar.quantize_and_save(output_dir) ``` @@ -307,7 +301,6 @@ sampling_params = {"temperature": 0.6, "top_p": 0.95} outputs = llm.generate(prompts, sampling_params) for prompt, output in zip(prompts, outputs): - print("===============================") print(f"Prompt: {prompt}\nGenerated text: {output['text']}") ``` diff --git a/auto_round/__init__.py b/auto_round/__init__.py index 268065ba4..f83b46160 100644 --- a/auto_round/__init__.py +++ b/auto_round/__init__.py @@ -15,7 +15,8 @@ # support for old api from auto_round.autoround import AutoRoundLLM, AutoRoundMLLM, AutoRoundAdam, AutoRoundDiffusion -from auto_round.schemes import QuantizationScheme, AutoScheme +from auto_round.schemes import QuantizationScheme +from auto_round.auto_scheme import AutoScheme from auto_round.utils import LazyImport diff --git a/auto_round/__main__.py b/auto_round/__main__.py index c403ee863..844f366bb 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -15,9 +15,10 @@ import os import sys +from auto_round.auto_scheme import AutoScheme from auto_round.compressors import BaseCompressor from auto_round.eval.eval_cli import EvalArgumentParser, _eval_init, eval, eval_task_by_task -from auto_round.schemes import PRESET_SCHEMES, AutoScheme +from auto_round.schemes import PRESET_SCHEMES from auto_round.utils import ( clear_memory, get_device_and_parallelism, diff --git a/auto_round/auto_scheme/__init__.py b/auto_round/auto_scheme/__init__.py index f4a3d2b23..f2682d376 100644 --- a/auto_round/auto_scheme/__init__.py +++ b/auto_round/auto_scheme/__init__.py @@ -11,32 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from auto_round.logger import logger -AUTO_SCHEME_METHODS = {} +from auto_round.auto_scheme.gen_auto_scheme import AutoScheme +from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS - -def register_scheme_methods(names): - """Class decorator to register a mixed precision algorithm to the registry. - - Decorator function used before a Pattern subclass. - - Args: - names: A string. Define the export type. - - Returns: - cls: The class of register. - """ - - def register(alg): - if isinstance(names, (tuple, list)): - for name in names: - AUTO_SCHEME_METHODS[name] = alg - else: - AUTO_SCHEME_METHODS[names] = alg - - return alg - - return register - - -import auto_round.auto_scheme.default_alg +try: + import auto_round.auto_scheme.default_alg +except ImportError: + logger.warning("AutoScheme is currently supported only on Linux.") diff --git a/auto_round/auto_scheme/default_alg.abi3.so b/auto_round/auto_scheme/default_alg.abi3.so index 220fb3ce5..41d8d5634 100644 Binary files a/auto_round/auto_scheme/default_alg.abi3.so and b/auto_round/auto_scheme/default_alg.abi3.so differ diff --git a/auto_round/auto_scheme/gen_auto_scheme.py b/auto_round/auto_scheme/gen_auto_scheme.py index aab18ac7f..3b8f9bb58 100644 --- a/auto_round/auto_scheme/gen_auto_scheme.py +++ b/auto_round/auto_scheme/gen_auto_scheme.py @@ -11,21 +11,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import math -from dataclasses import asdict -from typing import Iterable, Union + +from dataclasses import dataclass +from typing import Iterable, Optional, Union import torch -from auto_round import AutoScheme -from auto_round.auto_scheme import AUTO_SCHEME_METHODS +from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS from auto_round.auto_scheme.utils import compute_avg_bits_for_scheme from auto_round.compressors.utils import gguf_type_fallback from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG from auto_round.logger import logger +from auto_round.schemes import QuantizationScheme from auto_round.utils import get_layer_features, get_module +@dataclass +class AutoScheme: + avg_bits: float + options: Union[str, list[Union[QuantizationScheme, str]], tuple[Union[QuantizationScheme, str], ...]] + shared_layers: Optional[Iterable[Iterable[str]]] = None + method: str = "default" + ignore_scale_zp_bits: bool = False + batch_size: Optional[int] = None + nsamples: Optional[int] = None + seqlen: Optional[int] = None + dataset: Optional[str] = None # Import Notice no comma for each item + device_map: Optional[Union[str, torch.device, int, dict]] = None + enable_torch_compile: Optional[bool] = None + disable_opt_rtn: bool = True + low_gpu_mem_usage: bool = True + + def __post_init__(self): + if isinstance(self.options, str): + options = self.options.upper().replace(" ", "") + self.options = options.split(",") + + class GenScheme: """Generate and validate quantization schemes for model layers.""" diff --git a/auto_round/auto_scheme/register.py b/auto_round/auto_scheme/register.py new file mode 100644 index 000000000..fa01e7939 --- /dev/null +++ b/auto_round/auto_scheme/register.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +AUTO_SCHEME_METHODS = {} + + +def register_scheme_methods(names): + """Class decorator to register a mixed precision algorithm to the registry. + + Decorator function used before a Pattern subclass. + + Args: + names: A string. Define the export type. + + Returns: + cls: The class of register. + """ + + def register(alg): + if isinstance(names, (tuple, list)): + for name in names: + AUTO_SCHEME_METHODS[name] = alg + else: + AUTO_SCHEME_METHODS[names] = alg + + return alg + + return register diff --git a/auto_round/autoround.py b/auto_round/autoround.py index a78c94737..cceb0668a 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -13,7 +13,7 @@ # limitations under the License. 
from __future__ import annotations -from typing import Union +from typing import TYPE_CHECKING, Union import torch @@ -26,9 +26,12 @@ MLLMCompressor, ) from auto_round.logger import deprecated, logger -from auto_round.schemes import AutoScheme, QuantizationScheme +from auto_round.schemes import QuantizationScheme from auto_round.utils import is_diffusion_model, is_mllm_model +if TYPE_CHECKING: + from auto_round.auto_scheme.gen_auto_scheme import AutoScheme + class AutoRound: """Automatic weight rounding (Signed Gradient Descent) for LLM quantization diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index c148164ae..869fc74de 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -30,6 +30,7 @@ from tqdm import tqdm from transformers import set_seed +from auto_round.auto_scheme.gen_auto_scheme import AutoScheme from auto_round.compressors.utils import ( block_forward, check_need_act_calibration, @@ -52,7 +53,7 @@ from auto_round.export.export_to_autoround import AutoRoundFormat from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG, ModelType from auto_round.logger import logger -from auto_round.schemes import AutoScheme, QuantizationScheme, get_gguf_scheme, preset_name_to_scheme +from auto_round.schemes import QuantizationScheme, get_gguf_scheme, preset_name_to_scheme from auto_round.sign_sgd import SignSGD from auto_round.special_model_handler import _handle_moe_model from auto_round.utils import ( @@ -139,6 +140,8 @@ def __init__( low_gpu_mem_usage: bool = False, device_map: Union[str, torch.device, int, dict] = 0, enable_torch_compile: bool = False, + enable_alg_ext: bool = False, + disable_opt_rtn: bool = True, seed: int = 42, **kwargs, ): @@ -189,14 +192,9 @@ def __init__( >>> layer_config = { ... "layer1": { - ... "data_type": "int", - ... "bits": 4, + ... "bits": 3, ... "group_size": 128, ... "sym": True, - ... "act_data_type": None, - ... "act_bits": 16, - ... "act_group_size": None, - ... "act_sym": None, ... }, ... "layer2": { ... 
"W8A16" @@ -214,10 +212,8 @@ def __init__( # Major version releases may pack them with extra configuration options amp = kwargs.pop("amp", True) lr = kwargs.pop("lr", None) - enable_alg_ext = kwargs.pop("enable_alg_ext", False) enable_minmax_tuning = kwargs.pop("enable_minmax_tuning", True) minmax_lr = kwargs.pop("minmax_lr", None) - disable_opt_rtn = kwargs.pop("disable_opt_rtn", False) lr_scheduler = kwargs.pop("lr_scheduler", None) sampler = kwargs.pop("sampler", "rand") not_use_best_mse = kwargs.pop("not_use_best_mse", False) diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 1e4dcdf2b..60cbdb349 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -14,11 +14,9 @@ import copy from copy import deepcopy from dataclasses import dataclass, fields -from typing import Iterable, Optional, Union +from typing import Optional, Union -import torch - -__all__ = ["QuantizationScheme", "get_gguf_scheme", "preset_name_to_scheme", "AutoScheme"] +__all__ = ["QuantizationScheme", "get_gguf_scheme", "preset_name_to_scheme"] @dataclass @@ -285,25 +283,3 @@ def get_gguf_scheme(scheme: Union[str, QuantizationScheme]) -> str: if equal: return key return "" - - -@dataclass -class AutoScheme: - avg_bits: float - options: Union[str, list[Union[QuantizationScheme, str]], tuple[Union[QuantizationScheme, str], ...]] - shared_layers: Optional[Iterable[Iterable[str]]] = None - method: str = "default" - ignore_scale_zp_bits: bool = False - batch_size: Optional[int] = None - nsamples: Optional[int] = None - seqlen: Optional[int] = None - dataset: Optional[str] = None # Import Notice no comma for each item - device_map: Optional[Union[str, torch.device, int, dict]] = None - enable_torch_compile: Optional[bool] = None - disable_opt_rtn: bool = True - low_gpu_mem_usage: bool = True - - def __post_init__(self): - if isinstance(self.options, str): - options = self.options.upper().replace(" ", "") - self.options = options.split(",") diff --git a/docs/step_by_step.md b/docs/step_by_step.md index f692543b7..6efbc85e7 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -119,6 +119,7 @@ AutoRound supports several Schemes: - **W8A16**(bits:8,group_size:128,sym:True,act_bits:16) - **W3A16**(bits:3,group_size:128,sym:True,act_bits:16) - **W2A16**(bits:2,group_size:128,sym:True,act_bits:16) +- **GGUF:Q4_K_M**(all Q*_K,Q*_0,Q*_1 are supported) - **Mixed Bits Weight only** - **NVFP4**(Experimental feature, recommend exporting to llm-compressor format. data_type:nvfp4,act_data_type:nvfp4,static_global_scale,group_size 16) - **MXFP4**(**Research feature,no real kernel**, data_type:mxfp4,act_data_type:mxfp4,rceil,group_size 32) @@ -281,8 +282,6 @@ W2G64 Average Accuracy of 13 tasks and Time Cost Results(Testing was conducted o AutoScheme provide automatically algorithm to provide mixed bits/data_type quantization recipes. For some accuracy result, please refer this doc [here](./auto_scheme_acc.md) -We strongly recommend set `enable_torch_compile` to True to save VRAM. - **Please note that mixed data types are supported during tuning, but cannot be exported to real models at this time..** #### CLI Usage