Merged
24 commits
84b0545
refine
wenhuach21 Oct 29, 2025
9c660c8
Merge branch 'main' into refine_1129
wenhuach21 Oct 29, 2025
6460958
mv AutoScheme class
wenhuach21 Oct 29, 2025
302bb18
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
e0b7301
Add autoscheme usage in homepage
wenhuach21 Oct 29, 2025
6ddec8a
update
wenhuach21 Oct 29, 2025
92d4350
Merge branch 'refine_1129' of https://github.com/intel/auto-round int…
wenhuach21 Oct 29, 2025
ab882cf
fix preci
wenhuach21 Oct 29, 2025
3206b2d
update
wenhuach21 Oct 29, 2025
1c31fcc
update
wenhuach21 Oct 29, 2025
3e43b48
fix post_init
wenhuach21 Oct 29, 2025
4c96eac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
a96737b
fix import error
n1ck-guo Oct 29, 2025
3fe7e65
fix
wenhuach21 Oct 29, 2025
e9b893b
Merge branch 'refine_1129' of https://github.com/intel/auto-round int…
wenhuach21 Oct 29, 2025
b687410
Merge branch 'main' into refine_1129
wenhuach21 Oct 29, 2025
5e24487
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
b545e18
try to fix import issue in another way
wenhuach21 Oct 29, 2025
d4bd21e
Merge branch 'refine_1129' of https://github.com/intel/auto-round int…
wenhuach21 Oct 29, 2025
4487fc2
try to fix preci issue
wenhuach21 Oct 29, 2025
7158730
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
6b6246f
try to fix preci issue
wenhuach21 Oct 29, 2025
9a75aae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2025
b9d5824
update readme
wenhuach21 Oct 29, 2025
83 changes: 38 additions & 45 deletions README.md
@@ -45,9 +45,7 @@ refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and
for some accuracy results.

[2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
all bits other than 3 bits. Example
models: [Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound](https://huggingface.co/Intel/Qwen3-235B-A22B-q2ks-mixed-AutoRound)
and [Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound](https://huggingface.co/Intel/DeepSeek-R1-0528-q2ks-mixed-AutoRound). **A more advanced algorithm** tailored for specific configurations may be available in
all bits other than 3 bits. **A more advanced algorithm** tailored for specific configurations may be available in
v0.8.1.

[2025/05] AutoRound has been integrated into **vLLM**. You can now run models in the AutoRound format directly with
@@ -186,58 +184,54 @@ ar = AutoRound(model_name_or_path, scheme="W4A16")
# ar = AutoRound(model_name_or_path, nsamples=128, iters=50, lr=5e-3)

# Supported formats: "auto_round" (default), "auto_gptq", "auto_awq", "llm_compressor", "gguf:q4_k_m", etc.
ar.quantize_and_save(output_dir="./tmp_autoround", format="auto_round")
ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
```

<details>
<summary>Detailed Hyperparameters</summary>

- `model`: The PyTorch model to be quantized.

- `tokenizer`: An optional tokenizer for processing input data. If none, a dataset must be provided.

- `bits (int)`: Number of bits for quantization (default is 4).

- `group_size (int)`: Size of the quantization group (default is 128).

- `sym (bool)`: Whether to use symmetric quantization (default is True).

- `enable_quanted_input (bool)`: Whether to use the output of the previous quantized block as the input for the current
block for tuning (default is True).
### AutoScheme Usage
Please refer to the [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for more details on AutoScheme.
~~~python
from auto_round import AutoRound, AutoScheme

- `enable_minmax_tuning (bool)`: Whether to enable weight min-max tuning (default is True).
model_name = "Qwen/Qwen3-8B"
avg_bits = 3.0
scheme = AutoScheme(avg_bits=avg_bits, options=("GGUF:Q2_K_S", "GGUF:Q4_K_S"), ignore_scale_zp_bits=True)
layer_config = {"lm_head": "GGUF:Q6_K"}

- `iters (int)`: Number of tuning iterations (default is 200).
# Change iters to 200 for non-GGUF schemes
ar = AutoRound(model=model_name, scheme=scheme, layer_config=layer_config, iters=0)
ar.quantize_and_save()
~~~

- `lr (float)`: The learning rate for rounding value (default is None, it will be set to 1.0/iters automatically).

- `minmax_lr (float)`: The learning rate for min-max tuning (default is None, it will be set to lr automatically).

- `nsamples (int)`: Number of samples for tuning (default is 128).

- `seqlen (int)`: Data length of the sequence for tuning (default is 2048).

- `batch_size (int)`: Batch size for training (default is 8).

- `scale_dtype (str)`: The data type of quantization scale to be used (default is "float16"), different kernels have
different choices.
<details>
<summary>Important Hyperparameters</summary>

- `amp (bool)`: Whether to use automatic mixed precision (default is True).
##### Quantization Scheme & Configuration
- **`scheme` (str|dict|AutoScheme)**: The predefined quantization scheme to apply, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`.
- **`bits` (int)**: Number of bits for quantization (default is `None`). If not None, it will override the scheme setting.
- **`group_size` (int)**: Size of the quantization group (default is `None`). If not None, it will override the scheme setting.
- **`sym` (bool)**: Whether to use symmetric quantization (default is `None`). If not None, it will override the scheme setting.
- **`layer_config` (dict)**: Configuration for weight quantization (default is `None`), mainly for mixed schemes.

- `nblocks (int)`: Packing several blocks as one for tuning together (default is 1).

- `gradient_accumulate_steps (int)`: Number of gradient accumulation steps (default is 1).
##### Algorithm Settings
- **`enable_alg_ext` (bool)**: Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
- **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled).

- `low_gpu_mem_usage (bool)`: Whether to save GPU memory at the cost of ~20% more tuning time (default is False).
##### Tuning Process Parameters
- **`iters` (int)**: Number of tuning iterations (default is `200`). Common values: 0 (RTN mode), 50 (with lr=5e-3 recommended), 1000. Higher values increase accuracy but slow down tuning.
- **`lr` (float)**: The learning rate for rounding value (default is `None`). When None, it will be set to `1.0/iters` automatically.
- **`batch_size` (int)**: Batch size for training (default is `8`). 4 is also commonly used.

- `dataset Union[str, list, tuple, torch.utils.data.DataLoader]`: The dataset name for tuning (default is "
NeelNanda/pile-10k"). Local json file and combination of datasets have been supported, e.g. "
./tmp.json,NeelNanda/pile-10k:train, mbpp:train+validation+test"
##### Calibration Dataset
- **`dataset` (str|list|tuple|torch.utils.data.DataLoader)**: The dataset for tuning (default is `"NeelNanda/pile-10k"`). Supports local JSON files and dataset combinations, e.g. `"./tmp.json,NeelNanda/pile-10k:train,mbpp:train+validation+test"`.
- **`nsamples` (int)**: Number of samples for tuning (default is `128`).
- **`seqlen` (int)**: Data length of the sequence for tuning (default is `2048`).

- `layer_config (dict)`: Configuration for weight quantization (default is None), mainly for mixed bits
or mixed precision.

- `device`: The device to be used for tuning. The default is set to 'auto', allowing for automatic detection.
##### Device/Speed Configuration
- **`enable_torch_compile` (bool)**: Typically recommended to set to `True` for faster quantization with lower resource usage, provided no exception is raised.
- **`low_gpu_mem_usage` (bool)**: Whether to offload intermediate features to CPU at the cost of ~20% more tuning time (default is `False`).
- **`device_map` (str|dict|int)**: The device to be used for tuning, e.g., `"cpu"`, `"cuda"`, `"0,1,2"` (default is `'0'`).

</details>
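
As a quick illustration of how the hyperparameters above combine, here is a minimal sketch; the model name and values are illustrative placeholders, not recommendations:

```python
from auto_round import AutoRound

# A minimal sketch combining the hyperparameters documented above;
# model name and values are illustrative, not recommendations.
ar = AutoRound(
    "Qwen/Qwen3-8B",                # model name or path
    scheme="W4A16",                 # predefined quantization scheme
    iters=50,                       # fewer tuning iterations for speed
    lr=5e-3,                        # recommended pairing with iters=50
    nsamples=128,                   # calibration samples
    seqlen=2048,                    # calibration sequence length
    dataset="NeelNanda/pile-10k",   # default calibration dataset
    low_gpu_mem_usage=True,         # offload features to CPU (~20% slower)
    device_map="0",                 # tune on GPU 0
)
ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
```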

@@ -263,7 +257,7 @@ from auto_round import AutoRoundMLLM
model_name_or_path = "Qwen/Qwen2.5-VL-7B-Instruct"
# Quantize the model
ar = AutoRoundMLLM(model_name_or_path, scheme="W4A16")
output_dir = "./tmp_autoround"
output_dir = "./qmodel"
ar.quantize_and_save(output_dir)
```

@@ -307,7 +301,6 @@ sampling_params = {"temperature": 0.6, "top_p": 0.95}

outputs = llm.generate(prompts, sampling_params)
for prompt, output in zip(prompts, outputs):
print("===============================")
print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
```

3 changes: 2 additions & 1 deletion auto_round/__init__.py
@@ -15,7 +15,8 @@

# support for old api
from auto_round.autoround import AutoRoundLLM, AutoRoundMLLM, AutoRoundAdam, AutoRoundDiffusion
from auto_round.schemes import QuantizationScheme, AutoScheme
from auto_round.schemes import QuantizationScheme
from auto_round.auto_scheme import AutoScheme
from auto_round.utils import LazyImport


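
Since `auto_round/__init__.py` still re-exports `AutoScheme` (now sourced from `auto_round.auto_scheme`), the public import path is unchanged; a quick sanity check:

```python
# Public import path, unchanged by the refactor:
from auto_round import AutoRound, AutoScheme

# Internal location after the move (previously auto_round.schemes):
from auto_round.auto_scheme import AutoScheme
```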
3 changes: 2 additions & 1 deletion auto_round/__main__.py
@@ -15,9 +15,10 @@
import os
import sys

from auto_round.auto_scheme import AutoScheme
from auto_round.compressors import BaseCompressor
from auto_round.eval.eval_cli import EvalArgumentParser, _eval_init, eval, eval_task_by_task
from auto_round.schemes import PRESET_SCHEMES, AutoScheme
from auto_round.schemes import PRESET_SCHEMES
from auto_round.utils import (
clear_memory,
get_device_and_parallelism,
34 changes: 7 additions & 27 deletions auto_round/auto_scheme/__init__.py
@@ -11,32 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_round.logger import logger

AUTO_SCHEME_METHODS = {}
from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS


def register_scheme_methods(names):
"""Class decorator to register a mixed precision algorithm to the registry.

Decorator function used before a Pattern subclass.

Args:
names: A string. Define the export type.

Returns:
cls: The class of register.
"""

def register(alg):
if isinstance(names, (tuple, list)):
for name in names:
AUTO_SCHEME_METHODS[name] = alg
else:
AUTO_SCHEME_METHODS[names] = alg

return alg

return register


import auto_round.auto_scheme.default_alg
try:
import auto_round.auto_scheme.default_alg
except ImportError:
logger.warning("AutoScheme is currently supported only on Linux.")
Binary file modified auto_round/auto_scheme/default_alg.abi3.so
32 changes: 27 additions & 5 deletions auto_round/auto_scheme/gen_auto_scheme.py
@@ -11,21 +11,43 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import asdict
from typing import Iterable, Union

from dataclasses import dataclass
from typing import Iterable, Optional, Union

import torch

from auto_round import AutoScheme
from auto_round.auto_scheme import AUTO_SCHEME_METHODS
from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS
from auto_round.auto_scheme.utils import compute_avg_bits_for_scheme
from auto_round.compressors.utils import gguf_type_fallback
from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG
from auto_round.logger import logger
from auto_round.schemes import QuantizationScheme
from auto_round.utils import get_layer_features, get_module


@dataclass
class AutoScheme:
avg_bits: float
options: Union[str, list[Union[QuantizationScheme, str]], tuple[Union[QuantizationScheme, str], ...]]
shared_layers: Optional[Iterable[Iterable[str]]] = None
method: str = "default"
ignore_scale_zp_bits: bool = False
batch_size: Optional[int] = None
nsamples: Optional[int] = None
seqlen: Optional[int] = None
dataset: Optional[str] = None  # Important: no comma inside each dataset item
device_map: Optional[Union[str, torch.device, int, dict]] = None
enable_torch_compile: Optional[bool] = None
disable_opt_rtn: bool = True
low_gpu_mem_usage: bool = True

def __post_init__(self):
if isinstance(self.options, str):
options = self.options.upper().replace(" ", "")
self.options = options.split(",")


class GenScheme:
"""Generate and validate quantization schemes for model layers."""

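
As the `__post_init__` above shows, a comma-separated `options` string is normalized to an uppercase list; a minimal sketch:

```python
from auto_round import AutoScheme

# A comma-separated string of options...
scheme = AutoScheme(avg_bits=3.0, options="gguf:q2_k_s, gguf:q4_k_s")
# ...is uppercased, stripped of spaces, and split on commas:
print(scheme.options)  # ['GGUF:Q2_K_S', 'GGUF:Q4_K_S']
```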
39 changes: 39 additions & 0 deletions auto_round/auto_scheme/register.py
@@ -0,0 +1,39 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

AUTO_SCHEME_METHODS = {}


def register_scheme_methods(names):
"""Class decorator to register a mixed precision algorithm to the registry.

Decorator function used before a Pattern subclass.

Args:
names: A string. Define the export type.

Returns:
cls: The class of register.
"""

def register(alg):
if isinstance(names, (tuple, list)):
for name in names:
AUTO_SCHEME_METHODS[name] = alg
else:
AUTO_SCHEME_METHODS[names] = alg

return alg

return register
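
A hypothetical usage sketch of the registry; `MyAlg` and the names are illustrative, not part of the codebase:

```python
from auto_round.auto_scheme.register import AUTO_SCHEME_METHODS, register_scheme_methods

@register_scheme_methods(("my_alg", "my_alg_v2"))  # register under two names
class MyAlg:  # hypothetical mixed-precision algorithm
    pass

assert AUTO_SCHEME_METHODS["my_alg"] is MyAlg
assert AUTO_SCHEME_METHODS["my_alg_v2"] is MyAlg
```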
7 changes: 5 additions & 2 deletions auto_round/autoround.py
@@ -13,7 +13,7 @@
# limitations under the License.
from __future__ import annotations

from typing import Union
from typing import TYPE_CHECKING, Union

import torch

@@ -26,9 +26,12 @@
MLLMCompressor,
)
from auto_round.logger import deprecated, logger
from auto_round.schemes import AutoScheme, QuantizationScheme
from auto_round.schemes import QuantizationScheme
from auto_round.utils import is_diffusion_model, is_mllm_model

if TYPE_CHECKING:
from auto_round.auto_scheme.gen_auto_scheme import AutoScheme


class AutoRound:
"""Automatic weight rounding (Signed Gradient Descent) for LLM quantization
14 changes: 5 additions & 9 deletions auto_round/compressors/base.py
@@ -30,6 +30,7 @@
from tqdm import tqdm
from transformers import set_seed

from auto_round.auto_scheme.gen_auto_scheme import AutoScheme
from auto_round.compressors.utils import (
block_forward,
check_need_act_calibration,
@@ -52,7 +53,7 @@
from auto_round.export.export_to_autoround import AutoRoundFormat
from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG, ModelType
from auto_round.logger import logger
from auto_round.schemes import AutoScheme, QuantizationScheme, get_gguf_scheme, preset_name_to_scheme
from auto_round.schemes import QuantizationScheme, get_gguf_scheme, preset_name_to_scheme
from auto_round.sign_sgd import SignSGD
from auto_round.special_model_handler import _handle_moe_model
from auto_round.utils import (
Expand Down Expand Up @@ -139,6 +140,8 @@ def __init__(
low_gpu_mem_usage: bool = False,
device_map: Union[str, torch.device, int, dict] = 0,
enable_torch_compile: bool = False,
enable_alg_ext: bool = False,
disable_opt_rtn: bool = True,
seed: int = 42,
**kwargs,
):
@@ -189,14 +192,9 @@ def __init__(

>>> layer_config = {
... "layer1": {
... "data_type": "int",
... "bits": 4,
... "bits": 3,
... "group_size": 128,
... "sym": True,
... "act_data_type": None,
... "act_bits": 16,
... "act_group_size": None,
... "act_sym": None,
... },
... "layer2": {
... "W8A16"
@@ -214,10 +212,8 @@
# Major version releases may pack them with extra configuration options
amp = kwargs.pop("amp", True)
lr = kwargs.pop("lr", None)
enable_alg_ext = kwargs.pop("enable_alg_ext", False)
enable_minmax_tuning = kwargs.pop("enable_minmax_tuning", True)
minmax_lr = kwargs.pop("minmax_lr", None)
disable_opt_rtn = kwargs.pop("disable_opt_rtn", False)
lr_scheduler = kwargs.pop("lr_scheduler", None)
sampler = kwargs.pop("sampler", "rand")
not_use_best_mse = kwargs.pop("not_use_best_mse", False)
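
With `enable_alg_ext` and `disable_opt_rtn` promoted from `**kwargs` to explicit parameters, they can now be passed directly; a sketch, assuming `AutoRound` forwards them to the compressor:

```python
from auto_round import AutoRound

# A sketch using the newly explicit flags; model name is illustrative.
ar = AutoRound(
    "Qwen/Qwen3-8B",
    scheme="MXFP4",
    enable_alg_ext=True,    # algorithm variant for MXFP4/W2A16 schemes
    disable_opt_rtn=False,  # keep the improved RTN mode
)
```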
28 changes: 2 additions & 26 deletions auto_round/schemes.py
@@ -14,11 +14,9 @@
import copy
from copy import deepcopy
from dataclasses import dataclass, fields
from typing import Iterable, Optional, Union
from typing import Optional, Union

import torch

__all__ = ["QuantizationScheme", "get_gguf_scheme", "preset_name_to_scheme", "AutoScheme"]
__all__ = ["QuantizationScheme", "get_gguf_scheme", "preset_name_to_scheme"]


@dataclass
@@ -285,25 +283,3 @@ def get_gguf_scheme(scheme: Union[str, QuantizationScheme]) -> str:
if equal:
return key
return ""


@dataclass
class AutoScheme:
avg_bits: float
options: Union[str, list[Union[QuantizationScheme, str]], tuple[Union[QuantizationScheme, str], ...]]
shared_layers: Optional[Iterable[Iterable[str]]] = None
method: str = "default"
ignore_scale_zp_bits: bool = False
batch_size: Optional[int] = None
nsamples: Optional[int] = None
seqlen: Optional[int] = None
dataset: Optional[str] = None # Import Notice no comma for each item
device_map: Optional[Union[str, torch.device, int, dict]] = None
enable_torch_compile: Optional[bool] = None
disable_opt_rtn: bool = True
low_gpu_mem_usage: bool = True

def __post_init__(self):
if isinstance(self.options, str):
options = self.options.upper().replace(" ", "")
self.options = options.split(",")