Merged

Changes from all commits (67 commits)
8ceb6a1
reduce vram
wenhuach21 Nov 18, 2025
97f460e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
dd27f91
update
wenhuach21 Nov 18, 2025
0ea4fa2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
d40de66
update
wenhuach21 Nov 18, 2025
67a1b34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
2468f0a
fix bug
wenhuach21 Nov 18, 2025
0bd2cf9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
7701487
update
wenhuach21 Nov 18, 2025
b54c1e4
Merge branch 'optimize_gguf_vram' of https://github.com/intel/auto-ro…
wenhuach21 Nov 18, 2025
42e5cc0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2025
1307ddd
git push
wenhuach21 Nov 18, 2025
6d6d86a
fix accuracy bug
wenhuach21 Nov 18, 2025
e2586f9
trigger ut
wenhuach21 Nov 18, 2025
c7b3c24
clean code
wenhuach21 Nov 18, 2025
8ad2019
q80 q4k
wenhuach21 Nov 19, 2025
d316854
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
1743472
q5k
wenhuach21 Nov 19, 2025
5ffa12b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
db5c642
all ggufs use inplace ops
wenhuach21 Nov 19, 2025
ec6cb46
update
wenhuach21 Nov 19, 2025
ab5067b
Merge branch 'optimize_gguf_vram' of https://github.com/intel/auto-ro…
wenhuach21 Nov 19, 2025
5a503b4
update
wenhuach21 Nov 19, 2025
737977a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
c3b9213
Update auto_round/export/export_to_gguf/packing.py
wenhuach21 Nov 19, 2025
c2fe267
Update auto_round/export/export_to_gguf/packing.py
wenhuach21 Nov 19, 2025
963e6f9
Update auto_round/export/export_to_gguf/packing.py
wenhuach21 Nov 19, 2025
4c6366a
Update auto_round/export/export_to_gguf/packing.py
wenhuach21 Nov 19, 2025
343dbb6
Update auto_round/data_type/gguf.py
wenhuach21 Nov 19, 2025
932f407
Update auto_round/compressors/base.py
wenhuach21 Nov 19, 2025
8304a02
Update auto_round/export/export_to_gguf/packing.py
wenhuach21 Nov 19, 2025
bc86fdc
fix by comments
wenhuach21 Nov 19, 2025
08514f9
Merge branch 'optimize_gguf_vram' of https://github.com/intel/auto-ro…
wenhuach21 Nov 19, 2025
033330e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
c905bd2
fix line too long
wenhuach21 Nov 19, 2025
a616579
update readme
wenhuach21 Nov 19, 2025
cd01f13
update
wenhuach21 Nov 19, 2025
876267e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 19, 2025
1138c73
clean code
wenhuach21 Nov 20, 2025
111933f
Merge branch 'optimize_gguf_vram' of https://github.com/intel/auto-ro…
wenhuach21 Nov 20, 2025
be5e13c
update
wenhuach21 Nov 20, 2025
d97429f
Merge branch 'main' into optimize_gguf_vram
wenhuach21 Nov 20, 2025
190ea03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
f16cde5
fix typo
wenhuach21 Nov 20, 2025
9f408d1
update
wenhuach21 Nov 20, 2025
2d059f3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
7f27d72
Merge branch 'main' into optimize_gguf_vram
wenhuach21 Nov 20, 2025
78499a7
try to fix ut failure
wenhuach21 Nov 20, 2025
575103f
try to fix ut failure
wenhuach21 Nov 20, 2025
55efbbc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
9fa4cd9
try to fix ut failure
wenhuach21 Nov 20, 2025
56abaa8
Merge branch 'optimize_gguf_vram' of https://github.com/intel/auto-ro…
wenhuach21 Nov 20, 2025
70a3fdb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 20, 2025
b035b4f
try to fix ut failure
wenhuach21 Nov 20, 2025
a7cd959
update
wenhuach21 Nov 21, 2025
ae99930
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2025
79b2490
fix
wenhuach21 Nov 21, 2025
29b5188
update
wenhuach21 Nov 21, 2025
f840d2b
Merge branch 'optimize_gguf_vram' of https://github.com/intel/auto-ro…
wenhuach21 Nov 21, 2025
5e0bb6c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2025
d6d2979
fix typo
wenhuach21 Nov 21, 2025
9dcea24
Merge branch 'optimize_gguf_vram' of https://github.com/intel/auto-ro…
wenhuach21 Nov 21, 2025
1c8fe02
fix bug of gguf mllm
n1ck-guo Nov 24, 2025
0c254d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2025
f85fe7e
refine a little
wenhuach21 Nov 24, 2025
94085dc
refine a little
wenhuach21 Nov 24, 2025
1c57b19
Merge branch 'main' into optimize_gguf_vram
wenhuach21 Nov 24, 2025
11 changes: 6 additions & 5 deletions README.md
@@ -189,14 +189,14 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
<summary>Important Hyperparameters</summary>

##### Quantization Scheme & Configuration
- **`scheme` (str|dict|AutoScheme)**: The predefined quantization keys, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`.
- **`scheme` (str|dict|AutoScheme)**: The predefined quantization scheme, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`. For MXFP4/NVFP4, we recommend exporting to the LLM-Compressor format.
- **`bits` (int)**: Number of bits for quantization (default is `None`). If not None, it will override the scheme setting.
- **`group_size` (int)**: Size of the quantization group (default is `None`). If not None, it will override the scheme setting.
- **`sym` (bool)**: Whether to use symmetric quantization (default is `None`). If not None, it will override the scheme setting.
- **`layer_config` (dict)**: Configuration for weight quantization (default is `None`), mainly for mixed schemes.
- **`layer_config` (dict)**: Configuration for layer-wise schemes (default is `None`), mainly for customized mixed schemes (see the usage sketch after this section).

##### Algorithm Settings
- **`enable_alg_ext` (bool)**: [Experimental Feature] Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
- **`enable_alg_ext` (bool)**: [Experimental Feature] Only takes effect when `iters>0`. Enables algorithm variants for specific schemes (e.g., MXFP4/W2A16) that can bring notable improvements. Default is `False`.
- **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled).

##### Tuning Process Parameters
@@ -217,7 +217,8 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")

</details>
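As a quick illustration of how the hyperparameters above fit together, here is a minimal sketch based on the `AutoRound` API shown earlier in this README. The model name and the `layer_config` entry are placeholders for illustration, not part of this PR.

```python
from auto_round import AutoRound

# Placeholder model id; any HF causal LM should work here.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# Hypothetical per-layer override on top of the W4A16 scheme.
layer_config = {"lm_head": {"bits": 8}}

ar = AutoRound(
    model_name,
    scheme="W4A16",             # predefined scheme key
    layer_config=layer_config,  # customized mixed scheme
    enable_alg_ext=False,       # experimental, only takes effect when iters > 0
    disable_opt_rtn=False,      # keep the improved RTN path
)
ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
```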

### AutoScheme Usage
### Adaptive Bits/Dtype Usage
AutoScheme provides an automatic algorithm that generates mixed-bit/data-type quantization recipes. For accuracy results, please refer to this [doc](https://github.com/intel/auto-round/blob/main/docs/auto_scheme_acc.md).
Please refer to the [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for more details on AutoScheme.
~~~python
from auto_round import AutoRound, AutoScheme
@@ -299,7 +300,7 @@ for output in outputs:


### SGLang (Intel GPU/CUDA)
Please note that support for the MoE models and visual language models is currently limited.
**Please note that support for MoE models and vision-language models is currently limited.**

```python
import sglang as sgl
84 changes: 51 additions & 33 deletions auto_round/compressors/base.py
Expand Up @@ -713,11 +713,10 @@ def _check_compatibility(self) -> None:
raise ValueError("Gguf format is not compatible with other formats, please choose only one of them")
if has_gguf and self.iters != 0 and self.bits != 3 and not self.enable_alg_ext:
logger.warning(
"`iters=0` is recommended when exporting to GGUF format except for bits 3,"
" as we have optimized the RTN method for this case."
" Or add enable_alg_ext to use the new algorithm,"
" refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md"
" to check the acc."
"`iters=0` is recommended when exporting to current GGUF format"
" or add `enable_alg_ext` for better accuracy with much more tuning cost."
" Please refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md"
" for the accuracy results."
)
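To make the recommendation in this warning concrete, here is a minimal sketch of the two suggested paths, assuming the public `AutoRound` API; the model name and the exact export format string are assumptions, not taken from this PR.

```python
from auto_round import AutoRound

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model

# Recommended: optimized RTN (iters=0) when exporting GGUF.
ar = AutoRound(model_name, scheme="GGUF:Q4_K_M", iters=0)
ar.quantize_and_save(output_dir="./qmodel-gguf", format="gguf:q4_k_m")

# Alternative: keep tuning and enable the algorithm extension
# (much higher tuning cost, usually better accuracy).
# ar = AutoRound(model_name, scheme="GGUF:Q4_K_M", iters=200, enable_alg_ext=True)
```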

if (
@@ -1087,11 +1086,16 @@ def _quantize_embedding_layer(self):
dtype = f"rtn_{dtype}"

quant_func = QUANT_FUNC_WITH_DTYPE[dtype]
dtype = module.weight.dtype
# Since float32 is typically used in RTN to search for scale and zero-point,
# use float32 here to avoid caching an extra bf16 copy
if config["super_group_size"] is not None:
dtype = torch.float32

# Attempt quantization on GPU, fall back to CPU if OOM
try:
weight, scale, zp = quant_func(
module.weight.to(self.device),
module.weight.to(dtype=dtype, device=self.device),
**{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]},
)
except torch.OutOfMemoryError:
@@ -1124,8 +1128,9 @@ def _quantize_embedding_layer(self):

# Update config
self.layer_config.setdefault(name, {}).update(config)

# Release memory
del weight
del scale
del zp
clear_memory(device_list=self.device_list)

return is_quantized
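The embedding-quantization hunk above tries the accelerator first and falls back to CPU on out-of-memory. A stripped-down sketch of that pattern follows; the `quant_func` signature here is illustrative, not AutoRound's exact one.

```python
import torch

def quantize_with_oom_fallback(weight: torch.Tensor, quant_func, device: str, **cfg):
    """Try quantizing on the accelerator; retry on CPU if the device runs out of memory."""
    try:
        return quant_func(weight.to(device), **cfg)
    except torch.OutOfMemoryError:
        # Release the cached device copy before retrying on CPU.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return quant_func(weight.to("cpu"), **cfg)
```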
@@ -1224,7 +1229,7 @@ def get_imatrix_hook(module, input, output):
for hook in hooks:
hook.remove()

def _quantize_layer_via_rtn(self, name: str) -> None:
def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=True) -> None:
"""Quantizes a layer using RTN (Round-To-Nearest) if available.

This function attempts to quantize a layer by switching its data type to a
@@ -1241,19 +1246,20 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
RuntimeError: If quantization fails for reasons unrelated to memory.
"""
m = get_module(self.model, name)
if dtype is not None:
m = m.to(dtype)

if is_fp8_linear(m):
m = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device)
set_module(self.model, name, m)

tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device
# Step 1: Try quantization on GPU first, fall back to CPU if OOM
# if only exporting GGUF, use GGUF packing instead of RTN
if self.immediate_packing and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn:
m = m.to(tuning_device)
m.scale = None
m.zp = None
else:
try:
tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device
m = m.to(tuning_device)
m = WrapperLinear(
m,
@@ -1265,7 +1271,6 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
disable_opt_rtn=self.disable_opt_rtn,
)
m = m.unwrapper({})
m.to("cpu")
except torch.OutOfMemoryError:
cuda_error_msg = traceback.format_exc()
m = m.orig_layer if hasattr(m, "orig_layer") else m
@@ -1285,18 +1290,23 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
raise

# Step 2: Optional immediate packing/export
if self.immediate_packing:
if self.immediate_packing:  # For GGUF, packing is performed at the block level
self._immediate_pack(name)
if to_cpu:
m = m.to("cpu")
else:
if to_cpu:
m = m.to("cpu")
set_module(self.model, name, m)

if self.immediate_saving:
all_to_quantized_module_names = [n for n, m in self.model.named_modules() if check_to_quantized(m)]
last_module = (len(all_to_quantized_module_names) == 0) or (name == all_to_quantized_module_names[-1])
m = get_module(self.model, name)
immediate_saving(self, m, name, last_module)

def _immediate_pack(self, name: str):
if not self.immediate_packing:
return
m = get_module(self.model, name)
if not check_to_quantized(m):
return
@@ -1353,16 +1363,18 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
for module in tqdm(modules, desc="Update weight global scale for fuse module"):
update_fused_layer_global_scales(module)

has_gguf_k = (
any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None
)

self._quantize_embedding_layer()
if not (any("gguf" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None):
self._quantize_embedding_layer()  # for GGUF, embedding quantization is left to the GGUF path itself

self.model.to("cpu")
# Release memory
clear_memory(device_list=self.device_list)

enable_imatrix = False
if not self.disable_opt_rtn:
has_gguf_k = (
any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None
)
if has_gguf_k:
enable_imatrix = True
elif self.data_type == "int" and self.sym:
@@ -1498,39 +1510,44 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
self.device,
self.cache_device,
)

if len(self.device_list) > 1:
accelerate.hooks.remove_hook_from_submodules(block)

if is_nv_fp(self.act_data_type) or is_static_wfp8afp8(self):
# enable moe experts act_max automatic generation for Linear
set_amax_for_all_moe_layers(block, attr_name="act_max")
# Normalize imatrix and quantize layers
if self.low_gpu_mem_usage:
block.to("cpu")
clear_memory(device_list=self.device_list)

for _, m in block.named_modules():
# fix issue: Ling-flash-2.0-q2_k_s fails inference on CUDA but works on CPU
# https://huggingface.co/Intel/Ling-flash-2.0-gguf-q2ks-mixed-AutoRound/discussions/1
if hasattr(m, "imatrix"):
m.imatrix /= m.imatrix_cnt
if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
self._quantize_layer_via_rtn(m.tmp_name)
self._quantize_layer_via_rtn(m.tmp_name, to_cpu=False)
all_to_quantized_module_names.remove(m.tmp_name)
if not self.immediate_saving:
mv_module_from_gpu(block)
if block_name == block_names[-1]:
clear_memory(input_ids, device_list=self.device_list)
else:
clear_memory(device_list=self.device_list)

memory_monitor.log_summary()
pbar.update(1)

pbar.close()
cnt = 1
block_names_cnt = len(flatten_list(get_block_names(self.model, True)))
clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt
if clear_mem_freq == 0:
clear_mem_freq = 1
# Process remaining layers not in blocks
for name in all_to_quantized_module_names:
self._quantize_layer_via_rtn(name)
if cnt % clear_mem_freq == 0:
clear_memory(device_list=self.device_list)
cnt = 1
cnt += 1
dtype = None
if self.super_group_size is not None:
dtype = torch.float32
self._quantize_layer_via_rtn(name, dtype=dtype)
# clear_memory(device_list=self.device_list)

def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tensor]:
keys = inputs.keys()
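The block-wise path above averages each layer's accumulated importance matrix (`m.imatrix /= m.imatrix_cnt`) before calling RTN. A rough sketch of how such an importance matrix can be accumulated with a forward hook is shown below; this is an illustration, not AutoRound's actual hook.

```python
import torch

def register_imatrix_hook(linear: torch.nn.Linear):
    """Accumulate squared input activations per input channel, plus a sample count."""
    linear.imatrix = torch.zeros(linear.in_features)
    linear.imatrix_cnt = 0

    def hook(module, inputs, output):
        x = inputs[0].detach().float().reshape(-1, module.in_features)
        module.imatrix += (x * x).sum(dim=0).cpu()
        module.imatrix_cnt += x.shape[0]

    return linear.register_forward_hook(hook)

# After calibration, mirror the PR: module.imatrix /= module.imatrix_cnt
```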
@@ -1631,6 +1648,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
logger.info("start to cache block inputs")
all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names)
is_quantized_embedding = self._quantize_embedding_layer()
clear_memory(device_list=self.device_list)
all_q_inputs = None
if is_quantized_embedding:
all_inputs = copy.deepcopy(self.inputs)
@@ -2838,7 +2856,7 @@ def _quantize_block(
if auto_offload:
mv_module_from_gpu(block)

clear_memory(input_ids)
clear_memory(input_ids, device_list=self.device_list)
memory_info_summary = memory_monitor.get_summary()
logger.infoclean(dump_info + "," + memory_info_summary)

@@ -2848,7 +2866,7 @@
accelerate.hooks.remove_hook_from_submodules(block)
if auto_offload:
mv_module_from_gpu(block)
clear_memory(input_ids)
clear_memory(input_ids, device_list=self.device_list)
memory_info_summary = memory_monitor.get_summary()
logger.infoclean(dump_info + "," + memory_info_summary)

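Several hunks in this PR switch `clear_memory(...)` calls to pass `device_list=self.device_list`. As a rough sketch of what such a helper typically does (an assumption about the utility, not its actual implementation; the real helper also accepts tensors to release, e.g. `clear_memory(input_ids, device_list=...)`, which is omitted here):

```python
import gc
import torch

def clear_memory(device_list=None):
    """Run the garbage collector, then release cached CUDA blocks on each listed device."""
    gc.collect()
    if torch.cuda.is_available():
        devices = device_list if device_list else range(torch.cuda.device_count())
        for dev in devices:
            with torch.cuda.device(dev):
                torch.cuda.empty_cache()
```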