diff --git a/auto_round/auto_scheme/default_alg.abi3.so b/auto_round/auto_scheme/default_alg.abi3.so
index 15abd5cf4..83da28647 100644
Binary files a/auto_round/auto_scheme/default_alg.abi3.so and b/auto_round/auto_scheme/default_alg.abi3.so differ
diff --git a/auto_round/auto_scheme/gen_auto_scheme.py b/auto_round/auto_scheme/gen_auto_scheme.py
index 9dcfe0d3c..9ed0b21fa 100644
--- a/auto_round/auto_scheme/gen_auto_scheme.py
+++ b/auto_round/auto_scheme/gen_auto_scheme.py
@@ -51,6 +51,7 @@ def __init__(
if self.auto_scheme.enable_torch_compile is None
else self.auto_scheme.enable_torch_compile
)
+ self.disable_opt_rtn = self.auto_scheme.disable_opt_rtn
self._check_configs()
def _check_configs(self) -> None:
@@ -89,6 +90,7 @@ def get_layer_config(self) -> dict[str, dict]:
self.tokenizer,
device_map=self.device_map,
enable_torch_compile=self.enable_torch_compile,
+ disable_opt_rtn=self.disable_opt_rtn,
)
layer_config = self.fallback_gguf_layer_config(layer_config)
return layer_config
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 3b5cf48d4..4f4d42dfa 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -463,8 +463,8 @@ def _gen_auto_scheme(
# mainly using quant_layers and fixed by users
from auto_round.auto_scheme.gen_auto_scheme import GenScheme
- if self.enable_torch_compile is False:
- logger.warning("we strongly recommend to enable torch compile for AutoScheme to save VRAM")
+ if not self.enable_torch_compile and self.super_bits is None:
+            logger.warning("we strongly recommend setting `enable_torch_compile` to True for AutoScheme to save VRAM")
gen_scheme = GenScheme(
scheme,
self.model,
@@ -1275,14 +1275,12 @@ def get_imatrix_hook(module, input, output):
if not hasattr(module, "imatrix"):
module.imatrix = squared
- module.imatrix_cnt = input.shape[0]
else:
module.imatrix += squared.to(module.imatrix.device)
- module.imatrix_cnt += input.shape[0]
hook_handles = []
for name, module in model.named_modules():
- if isinstance(module, self.supported_types) and check_to_quantized(module):
+ if type(module) in self.supported_types and check_to_quantized(module):
hook = module.register_forward_hook(get_imatrix_hook)
hook_handles.append(hook)
return hook_handles
@@ -1452,7 +1450,9 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
for module in tqdm(modules, desc="Update weight global scale for fuse module"):
update_fused_layer_global_scales(module)
- has_gguf_k = any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", []))
+ has_gguf_k = (
+ any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None
+ )
self._quantize_embedding_layer()
@@ -1595,8 +1595,6 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
set_amax_for_all_moe_layers(block, attr_name="act_max")
# Normalize imatrix and quantize layers
for _, m in block.named_modules():
- if hasattr(m, "imatrix"):
- m.imatrix /= m.imatrix_cnt
if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
self._quantize_layer_via_rtn(m.tmp_name)
all_to_quantized_module_names.remove(m.tmp_name)
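For reference, a minimal standalone sketch (not the library's exact code) of what the hook above now does after this change: `imatrix` becomes a plain running sum of squared activations (the separate `imatrix_cnt` normalization is dropped), and modules are matched by exact type rather than `isinstance`, so subclasses of the supported types are skipped. The helper and toy model below are illustrative only.

```python
import torch

def attach_imatrix_hooks(model: torch.nn.Module, supported_types=(torch.nn.Linear,)):
    """Register forward hooks that accumulate per-channel squared activations."""
    handles = []

    def get_imatrix_hook(module, inputs, output):
        x = inputs[0].detach()
        # Sum of squared inputs over every axis except the last (input features).
        squared = x.float().pow(2).reshape(-1, x.shape[-1]).sum(dim=0)
        if not hasattr(module, "imatrix"):
            module.imatrix = squared
        else:
            module.imatrix += squared.to(module.imatrix.device)

    for _, module in model.named_modules():
        # Exact-type match mirrors the switch from isinstance() to `type(module) in ...`.
        if type(module) in supported_types:
            handles.append(module.register_forward_hook(get_imatrix_hook))
    return handles

# Usage: run calibration forwards, then read `module.imatrix` per layer.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
handles = attach_imatrix_hooks(model)
model(torch.randn(2, 8))
print(model[0].imatrix.shape)  # torch.Size([8])
for h in handles:
    h.remove()
```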
diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py
index 6aa19a3d5..4b4c51942 100644
--- a/auto_round/data_type/gguf.py
+++ b/auto_round/data_type/gguf.py
@@ -16,6 +16,8 @@
from auto_round.data_type.register import register_dtype
from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES
+from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
from auto_round.logger import logger
from auto_round.utils import get_reciprocal
@@ -283,48 +285,11 @@ def quant_tensor_asym_dq(
return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}
-@register_dtype("rtn_int_asym_dq")
-def quant_tensor_gguf_asym_dq(
- tensor,
- bits=4,
- v=0,
- min_scale=1.0,
- max_scale=1.0,
- scale_dtype=torch.float16,
- tensor_min=None,
- tensor_max=None,
- q_scale_thresh=1e-5,
- imatrix=None,
- **kwargs,
-):
- """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
- Only fit for iters 0
-
- Args:
- tensor (torch.Tensor): Input tensor to quantize.
- bits (int): Number of bits for quantization.
- group_size (int): Group size for per-group quantization.
- v (float): Perturbation added before rounding.
- min_scale (float): Minimum allowed scale value.
- max_scale (float): Maximum allowed scale value.
- scale_dtype (torch.dtype): Data type for quantized scale.
- tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
- tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
- q_scale_thresh (float): Threshold to clamp the quantized scale.
- super_group_size (int): Number of groups to bundle for secondary quantization.
- super_bits (int): Number of bits used in secondary quantization.
- imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
-
- Returns:
- Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
- """
- orig_dtype = tensor.dtype
- maxq = 2**bits - 1
- group_size = 16 if bits == 2 else 32
+@torch.no_grad()
+def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
super_bits = 4 if bits == 2 else 6
super_group_size = 16 if bits == 2 else 8
- tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
- tensor = tensor.to(torch.float32)
+ group_size = 16 if bits == 2 else 32
if bits not in [2, 4, 5]:
raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq")
quant_weights = None
@@ -430,8 +395,52 @@ def quant_tensor_gguf_asym_dq(
d_wmin = d_wmin.unsqueeze(-1)
scale = (d_scale * q_scale).view(-1, 1)
wmin = (d_wmin * q_wmin).view(-1, 1)
- inverse_scale = get_reciprocal(scale)
+ return scale, wmin, d_scale, d_wmin
+
+@register_dtype("rtn_int_asym_dq")
+def quant_tensor_gguf_asym_dq(
+ tensor: torch.Tensor,
+ bits: int = 4,
+ v=0,
+ scale_dtype=torch.float16,
+ imatrix=None,
+ scale=None,
+ wmin=None,
+ d_scale=None,
+ d_wmin=None,
+ **kwargs,
+):
+    """Quantize and dequantize a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K.
+    Only suitable for iters=0 (RTN mode).
+
+    Args:
+        tensor (torch.Tensor): Input tensor to quantize.
+        bits (int): Number of bits for quantization (2, 4, or 5).
+        v (float): Perturbation added before rounding.
+        scale_dtype (torch.dtype): Data type for the quantized scale.
+        imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
+        scale (torch.Tensor, optional): Precomputed per-group scale; searched when None.
+        wmin (torch.Tensor, optional): Precomputed per-group minimum offset; searched when None.
+        d_scale (torch.Tensor, optional): Precomputed super-group scale for the scales.
+        d_wmin (torch.Tensor, optional): Precomputed super-group scale for the minimum offsets.
+
+    Returns:
+        Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary)
+    """
+ orig_dtype = tensor.dtype
+ maxq = 2**bits - 1
+ group_size = 16 if bits == 2 else 32
+ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+ tensor = tensor.to(torch.float32)
+ if scale is None:
+ scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix)
+
+ inverse_scale = get_reciprocal(scale)
int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
qdq_result = (scale * int_w - wmin).to(orig_dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
@@ -506,18 +515,58 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
return scale.to(torch.float32), -rmin.to(torch.float32)
+@torch.no_grad()
+def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
+ from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
+
+ group_size = 16
+
+ if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0):
+ if bits == 3:
+ scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
+ ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
+ elif bits == 6:
+ scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
+ else:
+ imatrix = imatrix.to(tensor.device)
+ weights = imatrix.reshape(1, -1)
+ weights = weights.expand(tensor.numel() // weights.numel(), -1)
+ quant_weights = weights.reshape(tensor.shape)
+ if torch.min(quant_weights) == 0:
+ logger.warning_once(
+ "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
+ )
+ zero_cnt = torch.sum(quant_weights == 0, dim=-1)
+ replace_index = zero_cnt > group_size // 2
+ if torch.sum(replace_index) > 0:
+ if bits == 6:
+ quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
+ else:
+ sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
+ tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
+ quant_weights[replace_index] = tmp_quant_weights[replace_index]
+ mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
+ if torch.sum(mean_replace_index) > 0:
+ ## use mean values to fill zero values
+ tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
+ tmp_quant_weights = (
+ tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
+ )
+ quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]
+
+ scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
+ return scale
+
+
@register_dtype("rtn_int_sym_dq")
def quant_tensor_gguf_sym_dq(
tensor,
bits=3,
- v=0,
- min_scale=1.0,
- max_scale=1.0,
- scale_dtype=torch.float16,
- tensor_min=None,
- tensor_max=None,
- q_scale_thresh=1e-5,
imatrix=None,
+ scale=None,
+ d_scale=None,
+ scale_dtype=torch.float16,
**kwargs,
):
"""Quantize and de-quantize tensor asymmetrically. For Q3_K, Q6_K.
@@ -537,72 +586,28 @@ def quant_tensor_gguf_sym_dq(
Returns:
Quantized and de-quantized tensor, scale, zero-point
"""
- from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
- from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
+
+ from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
if bits not in [3, 6]:
raise KeyError(f"bits={bits} is not supported by gguf_int_sym_dq, please check.")
maxq = 2 ** (bits - 1)
group_size = 16
+ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+ orig_dtype = tensor.dtype
super_bits = 6 if bits == 3 else 8
super_group_size = 16
-
- tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
ggml_type = f"q{bits}_k"
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
- orig_dtype = tensor.dtype
-
tensor = tensor.to(torch.float32)
n_blocks = tensor.nelement() // block_size
# (nb, 16, 16)
tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size)
+ if scale is None and d_scale is None:
+ scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype)
- if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0):
- if bits == 3:
- scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True)
- ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
- elif bits == 6:
- scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None)
- else:
- imatrix = imatrix.to(tensor.device)
-
- # if bits == 3:
- # # sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
- # # imatrix = imatrix.reshape(1, -1).expand(tensor.numel() // imatrix.numel(), -1).reshape(tensor.shape)
- # # quant_weights = imatrix * torch.sqrt(sigma2 + tensor * tensor)
- # # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
- # weights = imatrix.reshape(1, -1)
- # weights = weights.expand(tensor.numel() // weights.numel(), -1)
- # quant_weights = weights.reshape(tensor.shape)
- # elif bits == 6:
-
- weights = imatrix.reshape(1, -1)
- weights = weights.expand(tensor.numel() // weights.numel(), -1)
- quant_weights = weights.reshape(tensor.shape)
- if torch.min(quant_weights) == 0:
- logger.warning_once(
- "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
- )
- zero_cnt = torch.sum(quant_weights == 0, dim=-1)
- replace_index = zero_cnt > group_size // 2
- if torch.sum(replace_index) > 0:
- if bits == 6:
- quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
- else:
- sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
- tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
- quant_weights[replace_index] = tmp_quant_weights[replace_index]
- mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
- if torch.sum(mean_replace_index) > 0:
- ## use mean values to fill zero values
- tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
- tmp_quant_weights = (
- tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
- )
- quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]
-
- scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
+ scale = scale.to(scale_dtype)
scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
# conduct double quant
scale, d_scale = double_quant_tensor_sym(scale, super_bits)
@@ -610,7 +615,7 @@ def quant_tensor_gguf_sym_dq(
scale = scale.unsqueeze(-1)
zp = torch.full_like(scale, maxq) # pylint: disable=E1130
inverse_scale = get_reciprocal(scale)
- int_w = torch.round(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
+ int_w = round_ste(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
qdq_result = (scale * (int_w - zp)).to(orig_dtype)
qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
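Net effect of the gguf.py changes: the scale/min search is factored out of the quantize-dequantize functions (`search_gguf_scale_min_asym`, `search_gguf_scale_min_sym`), so precomputed parameters can be passed back in instead of being re-searched on every call. A simplified, hypothetical illustration of that pattern (per-tensor rather than per-group, no importance matrix):

```python
import torch

@torch.no_grad()
def search_scale_min(tensor: torch.Tensor, bits: int = 4):
    # Toy asymmetric per-tensor search; the real code searches per group,
    # optionally weighted by an importance matrix.
    maxq = 2**bits - 1
    wmin = (-tensor.min()).clamp(min=0.0)   # stored as a positive offset, as in the diff
    scale = (tensor.max() + wmin) / maxq
    return scale, wmin

def quant_dequant(tensor: torch.Tensor, bits: int = 4, scale=None, wmin=None):
    maxq = 2**bits - 1
    if scale is None:  # search only when no precomputed parameters are supplied
        scale, wmin = search_scale_min(tensor, bits)
    int_w = torch.clamp(torch.round((tensor + wmin) / scale), 0, maxq)
    return scale * int_w - wmin

w = torch.randn(4, 32)
scale, wmin = search_scale_min(w)                      # search once...
out_reused = quant_dequant(w, scale=scale, wmin=wmin)  # ...and reuse the result
out_fresh = quant_dequant(w)                           # or let the function search internally
assert torch.allclose(out_reused, out_fresh)
```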
diff --git a/auto_round/schemes.py b/auto_round/schemes.py
index 701324aec..96faf62c0 100644
--- a/auto_round/schemes.py
+++ b/auto_round/schemes.py
@@ -299,6 +299,7 @@ class AutoScheme:
dataset: Optional[str] = None # Import Notice no comma for each item
device_map: Optional[Union[str, torch.device, int, dict]] = None
enable_torch_compile: Optional[bool] = None
+ disable_opt_rtn: bool = True
def __post_init__(self):
if isinstance(self.options, str):
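Usage sketch for the new field (illustrative, not part of the patch): `disable_opt_rtn` defaults to True, so the optimized-RTN pass is skipped during the scheme search unless the user opts back in. Only `enable_torch_compile` and `disable_opt_rtn` come from this change; `avg_bits`, `options`, and the `AutoRound(scheme=...)` wiring are assumed from the project documentation.

```python
from auto_round import AutoRound
from auto_round.schemes import AutoScheme

scheme = AutoScheme(
    avg_bits=3.0,                        # assumed target, following the docs' AutoScheme examples
    options=("W2A16", "W4A16", "W8A16"),
    enable_torch_compile=True,           # recommended above to reduce VRAM during the search
    disable_opt_rtn=False,               # opt back into optimized RTN (the new default is True)
)
ar = AutoRound("Qwen/Qwen3-8B", scheme=scheme)
ar.quantize_and_save()
```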
diff --git a/docs/step_by_step.md b/docs/step_by_step.md
index f8a576cc2..ba27ef5d1 100644
--- a/docs/step_by_step.md
+++ b/docs/step_by_step.md
@@ -351,15 +351,21 @@ ar.quantize_and_save()
The tuning cost of AutoScheme is approximately 2 to 4 times that of model's bf16 size, depending on the options.
We tested it on Nvidia A100 80G using torch v2.8.
-| Models | Scheme | VRAM Cost<br>(torch compile) | Time Cost<br>(torch compile) | VRAM Cost<br>(w/o torch compile) | Time Cost<br>(w/o torch compile) |
-| -------- | ----------------- | ---------------------------- | ----------------------------- | --------------------------------- | --------------------------------- |
-| Qwen3-8B | W2A16 / W4A16 / W8A16 | 34G | 30s × len of options | 61G | 40s × len of options |
-| Qwen3-8B | MXFP4 / MXFP8 | 36G | 60s × len of options | 54G | 120s × len of options |
-| Qwen3-8B | GGUF* | 54G | 30s × len of options | 50G | 23s × len of options |
+We plan to optimize VRAM usage in the future.
+
+| Models | Scheme | VRAM Cost<br>(torch compile) | Time Cost<br>(torch compile) | VRAM Cost<br>(w/o torch compile) | Time Cost<br>(w/o torch compile) |
+| --------- | ----------------- | ------------------------------- | ----------------------------- | -------------------------------- | -------------------------------- |
+| Qwen3-8B | W2A16/W4A16/W8A16 | 34G | 30s * len of options | 61G | 40s * len of options |
+| Qwen3-8B | MXFP4/MXFP8 | 36G | 60s * len of options | 54G | 120s * len of options |
+| Qwen3-8B  | GGUF*             | 54G | 30s * len of options | 50G | 23s * len of options |
+| Qwen3-32B | W2A16/W4A16/W8A16 | OOM with 240G | --- | OOM with 240G | --- |
+| Qwen3-32B | MXFP4/MXFP8 | 160G | 200s * len of options | 200G | 240s * len of options |
+| Qwen3-32B | GGUF* | 210G | 80s * len of options | 200G | 60s * len of options |
+
#### Limitations
-Embedding layer is supported in AutoScheme, it will use the best scheme in options.
+The embedding layer is not supported by AutoScheme; it will use the best scheme among the options instead.
### RTN mode