Merged

24 commits
a13bdf0
fix imatrix pad issue
wenhuach21 Nov 13, 2025
4e20199
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2025
405bde7
update
wenhuach21 Nov 13, 2025
11171c4
Merge branch 'fix_imatrix' of https://github.com/intel/auto-round int…
wenhuach21 Nov 13, 2025
886a6c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2025
e2d7e70
refine
wenhuach21 Nov 13, 2025
cd97899
Merge branch 'fix_imatrix' of https://github.com/intel/auto-round int…
wenhuach21 Nov 13, 2025
2130075
clean
wenhuach21 Nov 13, 2025
9ecf7e6
update
wenhuach21 Nov 13, 2025
ea310ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2025
967af55
update
wenhuach21 Nov 13, 2025
5c5f72d
Merge branch 'fix_imatrix' of https://github.com/intel/auto-round int…
wenhuach21 Nov 13, 2025
356ee30
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 13, 2025
5f4d85c
Update auto_round/__main__.py
wenhuach21 Nov 13, 2025
a9fe211
update
wenhuach21 Nov 14, 2025
77ea33f
Merge branch 'fix_imatrix' of https://github.com/intel/auto-round int…
wenhuach21 Nov 14, 2025
63ae0c2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2025
6c289d0
Merge branch 'main' into fix_imatrix
wenhuach21 Nov 14, 2025
36d41af
refine comments
wenhuach21 Nov 14, 2025
a3a19e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2025
e5fce1e
Merge branch 'main' into fix_imatrix
wenhuach21 Nov 14, 2025
c584039
update readme
wenhuach21 Nov 14, 2025
267ff64
refine readme
wenhuach21 Nov 14, 2025
0bc902f
refine
wenhuach21 Nov 14, 2025
8 changes: 4 additions & 4 deletions README.md
@@ -30,7 +30,7 @@ See our [paper](https://arxiv.org/pdf/2309.05516) for more details. For usage in


## 🆕 What's New
[2025/11] AutoRound now offers preliminary support for an **enhanced GGUF quantization algorithm** via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the accompanying [documentation](./docs/gguf_alg_ext_acc.md).
[2025/11] AutoRound now offers preliminary support for an enhanced GGUF quantization algorithm via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the [documentation](./docs/gguf_alg_ext_acc.md).

[2025/10] AutoRound has been integrated into **SGLang**. You can now run models in the AutoRound format directly with SGLang versions later than v0.5.4.

@@ -46,8 +46,7 @@ refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and
for some accuracy results.

[2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
all bits other than 3 bits. **A more advanced algorithm** tailored for specific configurations may be available in
v0.8.1.
all bits other than 3 bits.

[2025/05] AutoRound has been integrated into **Transformers** and **vLLM**.

@@ -192,7 +191,7 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
- **`layer_config` (dict)**: Configuration for weight quantization (default is `None`), mainly for mixed schemes.

##### Algorithm Settings
- **`enable_alg_ext` (bool)**: Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
- **`enable_alg_ext` (bool)**: [Experimental Feature] Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
- **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled).

##### Tuning Process Parameters
@@ -208,6 +207,7 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
##### Device/Speed Configuration
- **`enable_torch_compile` (bool)**: Typically we recommend setting it to `True` for faster quantization with lower resource usage, provided no exception is raised.
- **`low_gpu_mem_usage` (bool)**: Whether to offload intermediate features to CPU at the cost of ~20% more tuning time (default is `False`).
- **`low_cpu_mem_usage` (bool)**: [Experimental Feature] Whether to enable immediate saving to reduce RAM usage (default is `False`).
- **`device_map` (str|dict|int)**: The device to be used for tuning, e.g., `"auto"`, `"cpu"`, `"cuda"`, `"0,1,2"` (default is `"0"`). When using `"auto"`, it will try to use all available GPUs.
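
A minimal usage sketch combining several of the settings above (the model id and scheme value are placeholders, not a prescribed recipe):

```python
from auto_round import AutoRound

# Hypothetical sketch: "facebook/opt-125m" and scheme="W4A16" are placeholders.
ar = AutoRound(
    "facebook/opt-125m",
    scheme="W4A16",
    enable_torch_compile=True,  # usually faster tuning when no exception is raised
    low_gpu_mem_usage=True,     # offload intermediate features to CPU
    device_map="auto",          # try to use all available GPUs
)
ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
```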

</details>
7 changes: 7 additions & 0 deletions auto_round/__main__.py
@@ -172,6 +172,12 @@ def __init__(self, *args, **kwargs):
type=float,
help="Learning rate specifically for min-max tuning. " "If None, uses the same value as --lr. ",
)
tuning.add_argument(
"--momentum",
default=0,
type=float,
help="Momentum factor for the optimizer. Default is 0 (no momentum).",
)
tuning.add_argument(
"--gradient_accumulate_steps",
default=1,
@@ -591,6 +597,7 @@ def tune(args):
extra_config=extra_config,
layer_config=layer_config,
model_dtype=args.model_dtype,
momentum=args.momentum,
)

model_name = args.model.rstrip("/")
30 changes: 22 additions & 8 deletions auto_round/compressors/base.py
@@ -193,7 +193,7 @@ def __init__(
super_group_size, super_bits, scale_dtype ("fp16" etc.),
nblocks, to_quant_block_names,
enable_norm_bias_tuning, enable_quanted_input,
disable_deterministic_algorithms, mllm, static_kv_dtype,enable_deterministic_algorithms
disable_deterministic_algorithms, mllm, static_kv_dtype, enable_deterministic_algorithms, momentum
Raises:
ValueError: If invalid device is provided or tokenizer is missing for non-str model with iters > 0.
RuntimeError: If model parameters are on meta device.
@@ -234,6 +234,7 @@ def __init__(
enable_quanted_input: bool = kwargs.pop("enable_quanted_input", True)
disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True)
enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
self.momentum = kwargs.pop("momentum", 0.0)
static_kv_dtype = kwargs.pop("static_kv_dtype", None)
model_dtype = kwargs.pop("model_dtype", None)
device = kwargs.pop("device", None)
@@ -1567,11 +1568,12 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
# It is best to modify the model structure in the quantize function and check the format,
# because it may cause the gguf format to not be exported normally.
self.model = _handle_moe_model(self.model, formats=formats)
# Assign temporary names after replacing modules
for n, m in self.model.named_modules(): # TODO check if could removed

# Temporary names must be assigned after handle_moe_model;
# placing them earlier would cause them to be removed when the module is replaced.
for n, m in self.model.named_modules():
m.tmp_name = n

# TODO check scale_dtype
if not self.is_auto_scheme:
enable_gguf_official_mixed = True
else:
@@ -2661,12 +2663,24 @@ def _quantize_block(

lr = torch.tensor(self.lr)
minmax_lr = torch.tensor(self.minmax_lr)
is_adam = "adam" in self.__class__.__name__.lower()

extra_kwargs = {} if is_adam else {"momentum": self.momentum}

if self.enable_minmax_tuning:
optimizer = self.optimizer(
[{"params": round_params}, {"params": minmax_params, "lr": minmax_lr}], lr=lr, weight_decay=0
)
params = [
{"params": round_params},
{"params": minmax_params, "lr": minmax_lr},
]
else:
optimizer = self.optimizer(round_params, lr=lr, weight_decay=0)
params = round_params

optimizer = self.optimizer(
params,
lr=lr,
weight_decay=0,
**extra_kwargs,
)

if len(round_params) + len(minmax_params) <= 0:
dump_info = (
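For context on the new `extra_kwargs` branch above: `torch.optim.SGD` accepts a `momentum` factor while `torch.optim.Adam` does not, so the flag is only forwarded for non-Adam optimizers. A standalone sketch (not part of this PR) of that distinction:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(4))]

# SGD exposes a momentum factor; Adam tracks its own moment estimates
# and would raise a TypeError if momentum= were passed to it.
sgd = torch.optim.SGD(params, lr=5e-3, weight_decay=0, momentum=0.9)
adam = torch.optim.Adam(params, lr=5e-3, weight_decay=0)
```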
133 changes: 120 additions & 13 deletions auto_round/data_type/gguf.py
@@ -21,6 +21,7 @@
from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
from auto_round.logger import logger
from auto_round.utils import get_reciprocal
from auto_round.utils.device import clear_memory


@register_dtype("int_sym_dq")
@@ -320,7 +321,7 @@ def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tens


@torch.no_grad()
def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None, split_num=1):
super_bits = 4 if bits == 2 else 6
super_group_size = 16 if bits == 2 else 8
group_size = 16 if bits == 2 else 32
@@ -348,6 +349,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri
nstep=params["nstep"],
use_mad=params["use_mad"],
weights=quant_weights,
split_num=split_num,
)
scale = scale.to(scale_dtype)
scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
@@ -428,16 +430,8 @@ def quant_tensor_gguf_asym_dq(
Args:
tensor (torch.Tensor): Input tensor to quantize.
bits (int): Number of bits for quantization.
group_size (int): Group size for per-group quantization.
v (float): Perturbation added before rounding.
min_scale (float): Minimum allowed scale value.
max_scale (float): Maximum allowed scale value.
scale_dtype (torch.dtype): Data type for quantized scale.
tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
q_scale_thresh (float): Threshold to clamp the quantized scale.
super_group_size (int): Number of groups to bundle for secondary quantization.
super_bits (int): Number of bits used in secondary quantization.
imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.

Returns:
@@ -446,10 +440,19 @@
orig_dtype = tensor.dtype
maxq = 2**bits - 1
group_size = 16 if bits == 2 else 32
split_num = 1
for dim in tensor.shape:
if dim > 100_000:
split_num = 16
break

tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)

tensor = tensor.to(torch.float32)
if scale is None:
scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix)
scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(
tensor, bits, scale_dtype, imatrix, split_num=split_num
)

inverse_scale = get_reciprocal(scale)
int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
Expand All @@ -458,7 +461,7 @@ def quant_tensor_gguf_asym_dq(
return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}
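
To make the asymmetric quantize-dequantize step above concrete, here is a small hand-worked sketch (not part of the diff; scale and wmin are chosen by hand rather than produced by the search, with `v=0`):

```python
import torch

# Hand-picked values for one group of four weights at bits=4 (maxq = 15).
maxq = 2**4 - 1
w = torch.tensor([[-0.30, -0.10, 0.05, 0.45]])
wmin = torch.tensor([[0.30]])                    # negated group minimum
scale = torch.tensor([[(0.45 + 0.30) / maxq]])   # (max - min) / maxq
int_w = torch.clamp(torch.round((w + wmin) / scale), 0, maxq)
w_qdq = scale * int_w - wmin                     # recovers w exactly for this toy group
```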


def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None):
def iterative_wls_quant_search_non_chunk(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None):
"""Adapted from Llamacpp. Performs iterative weighted least squares quantization search.

Args:
@@ -526,6 +529,112 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
return scale.to(torch.float32), -rmin.to(torch.float32)


# TODO consolidate iterative_wls_quant_search_chunk and non-chunk
def iterative_wls_quant_search_chunk(
data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=8
):
dtype = torch.float32
data = data.to(dtype)
maxq = 2**bits - 1
minq = 0
weights = 1.0 if weights is None else weights.to(dtype)

results_scale = []
results_rmin = []
chunk_size = (data.shape[0] + split_num - 1) // split_num
for start in range(0, data.shape[0], chunk_size):
end = min(start + chunk_size, data.shape[0])
chunk = data[start:end]
chunk_weights = weights if isinstance(weights, float) else weights[start:end]

rmin = torch.min(chunk, dim=1, keepdim=True)[0]
rmax = torch.max(chunk, dim=1, keepdim=True)[0]
sum_w = torch.sum(chunk_weights, dim=1, keepdim=True)
sum_x = torch.sum(chunk_weights * chunk, dim=1, keepdim=True)
scale = (rmax - rmin) / (maxq - minq)
iscale = get_reciprocal(scale)
quant_data = torch.clamp(torch.round(iscale * (chunk - rmin)), minq, maxq)
diff = scale * quant_data + rmin - chunk
best_mad = torch.sum(
(chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2), dim=1, keepdim=True
)

for is_ in range(nstep):
factor = rrmin + rdelta * is_ + maxq - minq
scale_new = (rmax - rmin) / factor
iscale_new = get_reciprocal(scale_new)
quant_data_new = torch.clamp(torch.round(iscale_new * (chunk - rmin)), minq, maxq)
mul_weights_quant_data = chunk_weights * quant_data_new
sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True)
sum_l2 = torch.sum(mul_weights_quant_data * quant_data_new, dim=-1, keepdim=True)
sum_xl = torch.sum(mul_weights_quant_data * chunk, dim=-1, keepdim=True)
D = sum_w * sum_l2 - torch.pow(sum_l, 2)
this_scale = (sum_w * sum_xl - sum_x * sum_l) / D
this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D
this_min[this_min > 0] = 0
this_scale[this_min > 0] = (sum_xl / sum_l2)[this_min > 0]
reverse_this_scale = get_reciprocal(this_scale)
quant_data = torch.clamp(torch.round(reverse_this_scale * (chunk - this_min)), minq, maxq)
diff = this_scale * quant_data + this_min - chunk
mad = torch.sum(
(chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2),
dim=-1,
keepdim=True,
)
idx_to_replace = torch.where((mad < best_mad) & (D > 0))[0]
best_mad[idx_to_replace] = mad[idx_to_replace]
scale[idx_to_replace] = this_scale[idx_to_replace]
rmin[idx_to_replace] = this_min[idx_to_replace]
results_scale.append(scale.to(torch.float32))
results_rmin.append(-rmin.to(torch.float32))
if split_num > 1:
clear_memory(device_list=[data.device])

return torch.cat(results_scale, dim=0), torch.cat(results_rmin, dim=0)


def iterative_wls_quant_search(
data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1
):
"""Adapted from Llamacpp. Performs iterative weighted least squares quantization search.

Args:
data (torch.Tensor): Input tensor to quantize.
bits (int): Number of quantization bits.
rrmin (float): Initial range scaling factor.
rdelta (float): Step size for range scaling.
nstep (int): Number of search steps.
use_mad (bool): Whether to use mean absolute deviation instead of squared error.
        weights (torch.Tensor): Weight matrix for each element.
        split_num (int): Number of row chunks processed sequentially to limit peak memory (1 means no chunking).

Returns:
Tuple: (Optimal scale tensor, optimal minimum value tensor)
"""

    # TODO: change this to a try/except fallback later
if split_num > 1:
return iterative_wls_quant_search_chunk(
data=data,
bits=bits,
rrmin=rrmin,
rdelta=rdelta,
nstep=nstep,
use_mad=use_mad,
weights=weights,
split_num=split_num,
)
else:
return iterative_wls_quant_search_non_chunk(
data=data,
bits=bits,
rrmin=rrmin,
rdelta=rdelta,
nstep=nstep,
use_mad=use_mad,
weights=weights,
)
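
A hypothetical call sketch for the dispatcher above (not part of the PR); the shapes follow the per-group layout used elsewhere in this file, and the exact values depend on the search:

```python
import torch

from auto_round.data_type.gguf import iterative_wls_quant_search

data = torch.randn(4096, 32)             # 4096 groups of 32 weights each
weights = torch.rand_like(data) + 1e-3   # per-element importance weights
scale, wmin = iterative_wls_quant_search(data, bits=4, weights=weights, split_num=4)
print(scale.shape, wmin.shape)           # expected: torch.Size([4096, 1]) for both
```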


@torch.no_grad()
def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
@@ -550,7 +659,6 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
return scale


#
@register_dtype("rtn_int_sym_dq")
def quant_tensor_gguf_sym_dq(
tensor,
@@ -566,7 +674,6 @@ def quant_tensor_gguf_sym_dq(
Args:
tensor: Tensor containing the tensor to be quantized
bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
group_size: Number of elements to share scale for quantization
v: Rounding value perturbation
min_scale: Minimum scale coefficient for tensor
max_scale: Maximum scale coefficient for tensor
1 change: 0 additions & 1 deletion auto_round/data_type/int.py
@@ -71,7 +71,6 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
imatrix = 1.0
else:
imatrix = imatrix.reshape(1, -1)

imatrix = reshape_pad_tensor_by_group_size(imatrix, group_size, val=1e-5)[0].view(1, -1)
imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
imatrix = imatrix.reshape(tensor.shape)
7 changes: 0 additions & 7 deletions auto_round/export/export_to_awq/utils.py
@@ -316,10 +316,3 @@ def extra_repr(self) -> str:
self.w_bit,
self.group_size,
)


def clear_memory(weight=None):
if weight is not None:
del weight
gc.collect()
torch.cuda.empty_cache()
4 changes: 3 additions & 1 deletion auto_round/export/export_to_gguf/packing.py
@@ -16,7 +16,7 @@
import torch

from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
from auto_round.utils import get_reciprocal
from auto_round.utils import clear_memory, get_reciprocal

GGML_QUANT_TYPE = {}

@@ -66,6 +66,8 @@ def ggml_quant(
wmin = wmin.to(device) if wmin is not None else wmin
d_scale = d_scale.to(device) if d_scale is not None else d_scale
d_wmin = d_wmin.to(device) if d_wmin is not None else d_wmin
imatrix = imatrix.to(device) if imatrix is not None else imatrix
clear_memory()
new_data = quant_func(
blocks, scale, zp=zp, wmin=wmin, d_scale=d_scale, d_wmin=d_wmin, imatrix=imatrix, original=original
)