4 changes: 2 additions & 2 deletions README.md
@@ -27,7 +27,7 @@ and [fbaldassarri](https://huggingface.co/fbaldassarri). For usage instructions,


## 🆕 What's New
[2025/10] AutoRound team proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
[2025/10] We proposed a fast algorithm to generate mixed bits/datatypes schemes in minutes. Please
refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and [this guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for usage instructions.

[2025/09] AutoRound now includes experimental support for the mxfp4 and nvfp4 dtypes. For accuracy results, see the [documentation](./docs/mxnv_acc.md)
@@ -68,7 +68,7 @@ Support **AutoRound, AutoAWQ, AutoGPTQ, and GGUF** for maximum compatibility. De
Quantize 7B models in about 10 minutes on a single GPU. Details are shown in [quantization costs](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#quantization-costs)

✅ **Fast mixed bits/data-types scheme generation**
Automatically configure in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead.
Automatically configure in minutes, with about 2X-4X the model’s BF16 VRAM size as overhead. Accuracy [results](./docs/auto_scheme_acc.md) and [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme).

✅ **10+ VLMs Support**
Out-of-the-box quantization for 10+ vision-language models [example models](https://huggingface.co/collections/OPEA/vlms-autoround-675bc712fdd6a55ebaf11bfa), [support matrix](https://github.com/intel/auto-round/tree/main/auto_round/mllm#support-matrix)
Binary file modified auto_round/auto_scheme/default_alg.abi3.so
Binary file not shown.
10 changes: 7 additions & 3 deletions auto_round/auto_scheme/gen_auto_scheme.py
@@ -30,11 +30,11 @@ class GenScheme:

def __init__(
self,
auto_scheme: AutoScheme, # TODO support shared layer
auto_scheme: AutoScheme,
model: torch.nn.Module,
quant_layer_names: Iterable[str],
fixed_layer_scheme: dict[str, dict],
dataset: str = "pile-10k", # TODO use auto-round dataset
dataset: str = "pile-10k",
device_map: Union[str, torch.device, int, dict, None] = None,
tokenizer=None,
enable_torch_compile=False,
@@ -46,7 +46,11 @@ def __init__(
self.fixed_layer_scheme = fixed_layer_scheme
self.dataset = dataset
self.device_map = device_map if self.auto_scheme.device_map is None else self.auto_scheme.device_map
self.enable_torch_compile = enable_torch_compile
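# prefer the AutoScheme-level enable_torch_compile when it is set; otherwise fall back to the constructor argument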
self.enable_torch_compile = (
enable_torch_compile
if self.auto_scheme.enable_torch_compile is None
else self.auto_scheme.enable_torch_compile
)
self._check_configs()

def _check_configs(self) -> None:
8 changes: 5 additions & 3 deletions auto_round/compressors/base.py
@@ -463,6 +463,8 @@ def _gen_auto_scheme(
# mainly using quant_layers and fixed by users
from auto_round.auto_scheme.gen_auto_scheme import GenScheme

if self.enable_torch_compile is False:
logger.warning("we strongly recommend to enable torch compile for AutoScheme to save VRAM")
gen_scheme = GenScheme(
scheme,
self.model,
@@ -583,9 +585,9 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")

if is_debug_mode() and self.enable_torch_compile:
self.enable_torch_compile = False
logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")
# if is_debug_mode() and self.enable_torch_compile:
# self.enable_torch_compile = False
# logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")

if (self.data_type.startswith("fp") or self.act_data_type.startswith("fp")) and self.enable_torch_compile:
self.enable_torch_compile = False
1 change: 1 addition & 0 deletions auto_round/schemes.py
@@ -298,6 +298,7 @@ class AutoScheme:
seqlen: Optional[int] = None
dataset: Optional[str] = None  # Important notice: no comma for each item
device_map: Optional[Union[str, torch.device, int, dict]] = None
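# when set, overrides the enable_torch_compile passed to AutoRound during the AutoScheme search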
enable_torch_compile: Optional[bool] = None

def __post_init__(self):
if isinstance(self.options, str):
50 changes: 41 additions & 9 deletions docs/step_by_step.md
@@ -119,7 +119,7 @@ AutoRound supports several Schemes:
- **W8A16**(bits:8,group_size:128,sym:True,act_bits:16)
- **W3A16**(bits:3,group_size:128,sym:True,act_bits:16)
- **W2A16**(bits:2,group_size:128,sym:True,act_bits:16)
- **Mixed bits Weight only**
- **Mixed Bits Weight only**
- **NVFP4**(Experimental feature, recommend exporting to llm-compressor format. data_type:nvfp4,act_data_type:nvfp4,static_global_scale,group_size 16)
- **MXFP4**(**Research feature,no real kernel**, data_type:mxfp4,act_data_type:mxfp4,rceil,group_size 32)
- **FPW8A16**(**Research feature, no real kernel**, data_type:fp8, act_bits:16, group_size:0 -> per tensor)
@@ -160,15 +160,15 @@ CPU, Intel GPU, HPU and CUDA for both quantization and inference.
auto-round --model facebook/opt-125m --scheme "W4A16" --format "auto_gptq,auto_awq,auto_round"
```

- **Best Settings:**
- **AutoRoundBest Recipe:**

This setting provides the best accuracy in most scenarios but is 4–5× slower than the standard AutoRound recipe. It is especially recommended for 2-bit quantization and is a good choice if sufficient resources are available.

```bash
auto-round-best --model facebook/opt-125m --scheme "W4A16" --format "auto_gptq,auto_awq,auto_round"
```

- **Light Settings:**
- **AutoRoundLight Recipe:**

This setting offers the best speed (2-3X faster than AutoRound), but it may cause a significant accuracy drop for small models and 2-bit quantization. It is recommended for 4-bit settings and models larger than 3B.

@@ -195,7 +195,9 @@ output_dir = "./tmp_autoround"
ar.quantize_and_save(output_dir, format="auto_gptq,auto_awq,auto_round")
```

#### Mixed bits Usage
#### Mixed Bits Usage
AutoRound (>0.8) offers AutoScheme to generate mixed-bits recipes automatically; please refer to the [AutoScheme](#autoscheme) section for more details.

Auto-GPTQ and Auto-AWQ only support a limited set of mixed-bit configurations. If you're unsure about the details, we recommend using the AutoRound format.

vLLM and SGLang fuse MoE and QKV layers, so it's recommended not to assign different bit widths to these layers.
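
For reference, here is a minimal sketch of a manually specified mixed-bits `layer_config`; the layer names are illustrative (they follow facebook/opt-125m's naming) and should be adapted to your model:

```python
from auto_round import AutoRound

# illustrative per-layer overrides; unlisted layers follow the global scheme
layer_config = {
    "model.decoder.layers.2.self_attn.k_proj": {"bits": 8},       # keep a sensitive layer at 8 bits
    "model.decoder.layers.6.fc1": {"bits": 2, "group_size": 64},  # push a tolerant layer down to 2 bits
}

ar = AutoRound("facebook/opt-125m", scheme="W4A16", layer_config=layer_config)
ar.quantize_and_save("./tmp_autoround", format="auto_round")
```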
@@ -279,8 +281,11 @@ W2G64 Average Accuracy of 13 tasks and Time Cost Results(Testing was conducted o

AutoScheme provides an automatic algorithm to generate mixed bits/data_type quantization recipes. For accuracy results, please refer to the doc [here](./auto_scheme_acc.md).

We strongly recommend setting `enable_torch_compile` to True to save VRAM.

**Please note that mixed data types are supported during tuning, but cannot be exported to real models at this time.**
### CLI Usage

#### CLI Usage
Use `iters=200` for tuning.
~~~bash
auto_round \
@@ -292,25 +297,25 @@ auto_round \
--format fake
~~~

### API Usage
#### API Usage
~~~python
from auto_round import AutoRound, AutoScheme

model_name = "Qwen/Qwen3-8B"  # example model id; replace with your own
avg_bits = 3.0
scheme = AutoScheme(avg_bits=avg_bits, options=("W2A16G64", "W4A16", "W8A16"))
ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1)
ar.quantize_and_save()
~~~
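
To follow the VRAM recommendation above, the `enable_torch_compile` field that this PR adds to `AutoScheme` can be set directly; a minimal sketch, with the model id as an example:

```python
from auto_round import AutoRound, AutoScheme

scheme = AutoScheme(
    avg_bits=3.0,
    options=("W2A16G64", "W4A16", "W8A16"),
    enable_torch_compile=True,  # overrides AutoRound's setting during the scheme search
)
ar = AutoRound(model="Qwen/Qwen3-8B", scheme=scheme, iters=0, nsamples=1)
ar.quantize_and_save()
```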

### Hyperparameters in AutoScheme
#### Hyperparameters in AutoScheme
`avg_bits(float)`: Target average bits for the whole model; only layers to be quantized are counted in the average-bits calculation.

`options(Union[str, list[Union[QuantizationScheme, str]]])`: the quantization schemes to choose from. It can be a string like "W4A16", or a list of strings or QuantizationScheme objects.

`ignore_scale_zp_bits(bool)`: Whether to ignore the bits of scale and zero point in average bits calculation. Default is False.

`shared_layers (Optional[Iterable[Iterable[str]]])`: only supported in the API for now.

`device_map (Optional[Union[str, torch.device, int, dict]])`: only supported in the API for now. Since AutoScheme uses more VRAM than AutoRound tuning, you can set a different device_map for it.

`shared_layers (Optional[Iterable[Iterable[str]]])`: only supported in the API for now.

In some serving frameworks, certain layers (e.g., QKV or MoE) are fused to accelerate inference. These fused layers may require the same data type and bit configuration. The shared_layers option simplifies this setup by supporting both regex and full-name matching. **Note that regex matching is applied in a block-wise manner.**


@@ -329,6 +334,33 @@ ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1)
model, layer_config = ar.quantize()
```

Besides, if you want to fix the scheme for certain layers, you can set it via `layer_config` in the AutoRound API.
```python
from auto_round import AutoRound, AutoScheme

model_name = "Qwen/Qwen3-8B"
avg_bits = 3.0
scheme = AutoScheme(avg_bits=avg_bits, options=("GGUF:Q2_K_S", "GGUF:Q4_K_S"), ignore_scale_zp_bits=True)
layer_config = {"lm_head": "GGUF:Q6_K"}

ar = AutoRound(model=model_name, scheme=scheme, layer_config=layer_config, iters=0)
ar.quantize_and_save()
```

#### AutoScheme Cost
The VRAM cost of AutoScheme tuning is approximately 2 to 4 times the model's BF16 size, depending on the options. For example, Qwen3-8B is roughly 16 GB in BF16, so expect about 32-64 GB, consistent with the figures in the table below.
We tested it on an NVIDIA A100 80GB using torch v2.8.

| Models | Scheme | VRAM Cost <br /> (torch compile) | Time Cost <br /> (torch compile) | VRAM Cost <br /> (w/o torch compile) | Time Cost <br /> (w/o torch compile) |
| -------- | ----------------- | ---------------------------- | ----------------------------- | --------------------------------- | --------------------------------- |
| Qwen3-8B | W2A16 / W4A16 / W8A16 | 34G | 30s × len of options | 61G | 40s × len of options |
| Qwen3-8B | MXFP4 / MXFP8 | 36G | 60s × len of options | 54G | 120s × len of options |
| Qwen3-8B | GGUF* | 54G | 30s × len of options | 50G | 23s × len of options |


#### Limitations
The embedding layer is supported in AutoScheme; it will use the best scheme in the options.


### RTN mode
AutoRound also supports RTN (Round-To-Nearest) mode for fast, calibration-free baseline quantization. Try setting `iters=0`, and use `group_size=32` for better results.