diff --git a/auto_round/auto_scheme/default_alg.abi3.so b/auto_round/auto_scheme/default_alg.abi3.so index 15abd5cf4..83da28647 100644 Binary files a/auto_round/auto_scheme/default_alg.abi3.so and b/auto_round/auto_scheme/default_alg.abi3.so differ diff --git a/auto_round/auto_scheme/gen_auto_scheme.py b/auto_round/auto_scheme/gen_auto_scheme.py index 9dcfe0d3c..9ed0b21fa 100644 --- a/auto_round/auto_scheme/gen_auto_scheme.py +++ b/auto_round/auto_scheme/gen_auto_scheme.py @@ -51,6 +51,7 @@ def __init__( if self.auto_scheme.enable_torch_compile is None else self.auto_scheme.enable_torch_compile ) + self.disable_opt_rtn = self.auto_scheme.disable_opt_rtn self._check_configs() def _check_configs(self) -> None: @@ -89,6 +90,7 @@ def get_layer_config(self) -> dict[str, dict]: self.tokenizer, device_map=self.device_map, enable_torch_compile=self.enable_torch_compile, + disable_opt_rtn=self.disable_opt_rtn, ) layer_config = self.fallback_gguf_layer_config(layer_config) return layer_config diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 3b5cf48d4..4f4d42dfa 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -463,8 +463,8 @@ def _gen_auto_scheme( # mainly using quant_layers and fixed by users from auto_round.auto_scheme.gen_auto_scheme import GenScheme - if self.enable_torch_compile is False: - logger.warning("we strongly recommend to enable torch compile for AutoScheme to save VRAM") + if not self.enable_torch_compile and self.super_bits is None: + logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM") gen_scheme = GenScheme( scheme, self.model, @@ -1275,14 +1275,12 @@ def get_imatrix_hook(module, input, output): if not hasattr(module, "imatrix"): module.imatrix = squared - module.imatrix_cnt = input.shape[0] else: module.imatrix += squared.to(module.imatrix.device) - module.imatrix_cnt += input.shape[0] hook_handles = [] for name, module in model.named_modules(): - if isinstance(module, self.supported_types) and check_to_quantized(module): + if type(module) in self.supported_types and check_to_quantized(module): hook = module.register_forward_hook(get_imatrix_hook) hook_handles.append(hook) return hook_handles @@ -1452,7 +1450,9 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: for module in tqdm(modules, desc="Update weight global scale for fuse module"): update_fused_layer_global_scales(module) - has_gguf_k = any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) + has_gguf_k = ( + any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None + ) self._quantize_embedding_layer() @@ -1595,8 +1595,6 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) set_amax_for_all_moe_layers(block, attr_name="act_max") # Normalize imatrix and quantize layers for _, m in block.named_modules(): - if hasattr(m, "imatrix"): - m.imatrix /= m.imatrix_cnt if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names: self._quantize_layer_via_rtn(m.tmp_name) all_to_quantized_module_names.remove(m.tmp_name) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 6aa19a3d5..4b4c51942 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -16,6 +16,8 @@ from auto_round.data_type.register import register_dtype from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste +from 
auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES +from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants from auto_round.logger import logger from auto_round.utils import get_reciprocal @@ -283,48 +285,11 @@ def quant_tensor_asym_dq( return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin} -@register_dtype("rtn_int_asym_dq") -def quant_tensor_gguf_asym_dq( - tensor, - bits=4, - v=0, - min_scale=1.0, - max_scale=1.0, - scale_dtype=torch.float16, - tensor_min=None, - tensor_max=None, - q_scale_thresh=1e-5, - imatrix=None, - **kwargs, -): - """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K. - Only fit for iters 0 - - Args: - tensor (torch.Tensor): Input tensor to quantize. - bits (int): Number of bits for quantization. - group_size (int): Group size for per-group quantization. - v (float): Perturbation added before rounding. - min_scale (float): Minimum allowed scale value. - max_scale (float): Maximum allowed scale value. - scale_dtype (torch.dtype): Data type for quantized scale. - tensor_min (torch.Tensor, optional): Minimum values for the tensor groups. - tensor_max (torch.Tensor, optional): Maximum values for the tensor groups. - q_scale_thresh (float): Threshold to clamp the quantized scale. - super_group_size (int): Number of groups to bundle for secondary quantization. - super_bits (int): Number of bits used in secondary quantization. - imatrix (torch.Tensor, optional): Importance matrix for weighted quantization. - - Returns: - Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary) - """ - orig_dtype = tensor.dtype - maxq = 2**bits - 1 - group_size = 16 if bits == 2 else 32 +@torch.no_grad() +def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None): super_bits = 4 if bits == 2 else 6 super_group_size = 16 if bits == 2 else 8 - tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - tensor = tensor.to(torch.float32) + group_size = 16 if bits == 2 else 32 if bits not in [2, 4, 5]: raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq") quant_weights = None @@ -430,8 +395,52 @@ def quant_tensor_gguf_asym_dq( d_wmin = d_wmin.unsqueeze(-1) scale = (d_scale * q_scale).view(-1, 1) wmin = (d_wmin * q_wmin).view(-1, 1) - inverse_scale = get_reciprocal(scale) + return scale, wmin, d_scale, d_wmin + +@register_dtype("rtn_int_asym_dq") +def quant_tensor_gguf_asym_dq( + tensor: torch.Tensor, + bits: int = 4, + v=0, + scale_dtype=torch.float16, + imatrix=None, + scale=None, + wmin=None, + d_scale=None, + d_wmin=None, + **kwargs, +): + """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K. + Only fit for iters 0 + + Args: + tensor (torch.Tensor): Input tensor to quantize. + bits (int): Number of bits for quantization. + group_size (int): Group size for per-group quantization. + v (float): Perturbation added before rounding. + min_scale (float): Minimum allowed scale value. + max_scale (float): Maximum allowed scale value. + scale_dtype (torch.dtype): Data type for quantized scale. + tensor_min (torch.Tensor, optional): Minimum values for the tensor groups. + tensor_max (torch.Tensor, optional): Maximum values for the tensor groups. + q_scale_thresh (float): Threshold to clamp the quantized scale. + super_group_size (int): Number of groups to bundle for secondary quantization. 
+ super_bits (int): Number of bits used in secondary quantization. + imatrix (torch.Tensor, optional): Importance matrix for weighted quantization. + + Returns: + Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary) + """ + orig_dtype = tensor.dtype + maxq = 2**bits - 1 + group_size = 16 if bits == 2 else 32 + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + tensor = tensor.to(torch.float32) + if scale is None: + scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix) + + inverse_scale = get_reciprocal(scale) int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq) qdq_result = (scale * int_w - wmin).to(orig_dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) @@ -506,18 +515,58 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u return scale.to(torch.float32), -rmin.to(torch.float32) +@torch.no_grad() +def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype): + from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K + + group_size = 16 + + if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0): + if bits == 3: + scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True) + ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) + elif bits == 6: + scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) + else: + imatrix = imatrix.to(tensor.device) + weights = imatrix.reshape(1, -1) + weights = weights.expand(tensor.numel() // weights.numel(), -1) + quant_weights = weights.reshape(tensor.shape) + if torch.min(quant_weights) == 0: + logger.warning_once( + "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" + ) + zero_cnt = torch.sum(quant_weights == 0, dim=-1) + replace_index = zero_cnt > group_size // 2 + if torch.sum(replace_index) > 0: + if bits == 6: + quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index] + else: + sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K + tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor) + quant_weights[replace_index] = tmp_quant_weights[replace_index] + mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) + if torch.sum(mean_replace_index) > 0: + ## use mean values to fill zero values + tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt) + tmp_quant_weights = ( + tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape) + ) + quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index] + + scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights) + return scale + + +# @register_dtype("rtn_int_sym_dq") def quant_tensor_gguf_sym_dq( tensor, bits=3, - v=0, - min_scale=1.0, - max_scale=1.0, - scale_dtype=torch.float16, - tensor_min=None, - tensor_max=None, - q_scale_thresh=1e-5, imatrix=None, + scale=None, + d_scale=None, + scale_dtype=torch.float16, **kwargs, ): """Quantize and de-quantize tensor asymmetrically. For Q3_K, Q6_K. 
@@ -537,72 +586,28 @@ def quant_tensor_gguf_sym_dq( Returns: Quantized and de-quantized tensor, scale, zero-point """ - from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K - from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants + + from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K if bits not in [3, 6]: raise KeyError(f"bits={bits} is not supported by gguf_int_sym_dq, please check.") maxq = 2 ** (bits - 1) group_size = 16 + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + orig_dtype = tensor.dtype super_bits = 6 if bits == 3 else 8 super_group_size = 16 - - tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) ggml_type = f"q{bits}_k" block_size, type_size = GGML_QUANT_SIZES[ggml_type] - orig_dtype = tensor.dtype - tensor = tensor.to(torch.float32) n_blocks = tensor.nelement() // block_size # (nb, 16, 16) tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size) + if scale is None and d_scale is None: + scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype) - if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0): - if bits == 3: - scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True) - ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) - elif bits == 6: - scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) - else: - imatrix = imatrix.to(tensor.device) - - # if bits == 3: - # # sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K - # # imatrix = imatrix.reshape(1, -1).expand(tensor.numel() // imatrix.numel(), -1).reshape(tensor.shape) - # # quant_weights = imatrix * torch.sqrt(sigma2 + tensor * tensor) - # # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights) - # weights = imatrix.reshape(1, -1) - # weights = weights.expand(tensor.numel() // weights.numel(), -1) - # quant_weights = weights.reshape(tensor.shape) - # elif bits == 6: - - weights = imatrix.reshape(1, -1) - weights = weights.expand(tensor.numel() // weights.numel(), -1) - quant_weights = weights.reshape(tensor.shape) - if torch.min(quant_weights) == 0: - logger.warning_once( - "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" - ) - zero_cnt = torch.sum(quant_weights == 0, dim=-1) - replace_index = zero_cnt > group_size // 2 - if torch.sum(replace_index) > 0: - if bits == 6: - quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index] - else: - sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K - tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor) - quant_weights[replace_index] = tmp_quant_weights[replace_index] - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) - if torch.sum(mean_replace_index) > 0: - ## use mean values to fill zero values - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt) - tmp_quant_weights = ( - tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape) - ) - quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index] - - scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights) + scale = scale.to(scale_dtype) scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) # conduct double quant scale, d_scale = double_quant_tensor_sym(scale, super_bits) @@ -610,7 +615,7 @@ 
def quant_tensor_gguf_sym_dq( scale = scale.unsqueeze(-1) zp = torch.full_like(scale, maxq) # pylint: disable=E1130 inverse_scale = get_reciprocal(scale) - int_w = torch.round(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq + int_w = round_ste(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq qdq_result = (scale * (int_w - zp)).to(orig_dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) diff --git a/auto_round/schemes.py b/auto_round/schemes.py index 701324aec..96faf62c0 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -299,6 +299,7 @@ class AutoScheme: dataset: Optional[str] = None # Import Notice no comma for each item device_map: Optional[Union[str, torch.device, int, dict]] = None enable_torch_compile: Optional[bool] = None + disable_opt_rtn: bool = True def __post_init__(self): if isinstance(self.options, str): diff --git a/docs/step_by_step.md b/docs/step_by_step.md index f8a576cc2..ba27ef5d1 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -351,15 +351,21 @@ ar.quantize_and_save() The tuning cost of AutoScheme is approximately 2 to 4 times that of model's bf16 size, depending on the options. We tested it on Nvidia A100 80G using torch v2.8. -| Models | Scheme | VRAM Cost
<br>(torch compile) | Time Cost<br>(torch compile) | VRAM Cost<br>(w/o torch compile) | Time Cost<br>(w/o torch compile) |
-| -------- | ----------------- | ---------------------------- | ----------------------------- | --------------------------------- | --------------------------------- |
-| Qwen3-8B | W2A16 / W4A16 / W8A16 | 34G | 30s × len of options | 61G | 40s × len of options |
-| Qwen3-8B | MXFP4 / MXFP8 | 36G | 60s × len of options | 54G | 120s × len of options |
-| Qwen3-8B | GGUF* | 54G | 30s × len of options | 50G | 23s × len of options |
+We will try to optimize VRAM usage in the future.
+
+| Models | Scheme | VRAM Cost<br>(torch compile) | Time Cost<br>(torch compile) | VRAM Cost<br>(w/o torch compile) | Time Cost<br>(w/o torch compile) |
+| --------- | ----------------- | ---------------------------- | ---------------------------- | -------------------------------- | -------------------------------- |
+| Qwen3-8B | W2A16/W4A16/W8A16 | 34G | 30s * len of options | 61G | 40s * len of options |
+| Qwen3-8B | MXFP4/MXFP8 | 36G | 60s * len of options | 54G | 120s * len of options |
+| Qwen3-8B | GGUF* | 54G | 30s * len of options | 50G | 23s * len of options |
+| Qwen3-32B | W2A16/W4A16/W8A16 | OOM with 240G | --- | OOM with 240G | --- |
+| Qwen3-32B | MXFP4/MXFP8 | 160G | 200s * len of options | 200G | 240s * len of options |
+| Qwen3-32B | GGUF* | 210G | 80s * len of options | 200G | 60s * len of options |
+
 #### Limitations
-Embedding layer is supported in AutoScheme, it will use the best scheme in options.
+The embedding layer is not supported in AutoScheme; it will use the best scheme among the options.
 
 ### RTN mode
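
Reviewer note: the `auto_round/data_type/gguf.py` hunks above split the scale search (`search_gguf_scale_min_asym` / `search_gguf_scale_min_sym`) from the quantize-dequantize apply step, so callers can pass a precomputed `scale`/`wmin` and skip the search. A simplified sketch of the apply step for the asymmetric path, using plain `torch.round` in place of `round_ste` and an explicit zero-safe reciprocal in place of `get_reciprocal` (names and shapes here are illustrative, not the library API):

```python
import torch

def asym_dq_apply_sketch(tensor: torch.Tensor, scale: torch.Tensor, wmin: torch.Tensor, bits: int = 4) -> torch.Tensor:
    """Apply step only: given a per-group scale and (positive) wmin from the
    search step, clamp the rounded integers to [0, 2**bits - 1] and dequantize.
    Mirrors the tail of quant_tensor_gguf_asym_dq in this diff, simplified."""
    maxq = 2**bits - 1
    # zero-safe reciprocal: groups with zero scale dequantize to -wmin
    inverse_scale = torch.where(scale == 0, torch.zeros_like(scale), 1.0 / scale)
    int_w = torch.clamp(torch.round((tensor + wmin) * inverse_scale), 0, maxq)
    return scale * int_w - wmin  # quantize-dequantize result
```

Since the search step is the expensive part (especially with an imatrix), reusing its outputs is presumably what lets AutoScheme re-evaluate layers without repeating it.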
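On the `torch.round` → `round_ste` change in `quant_tensor_gguf_sym_dq`: `round_ste` is imported from `auto_round.data_type.utils`; assuming it follows the standard straight-through-estimator formulation, a minimal sketch of that behavior is:

```python
import torch

def round_ste_sketch(x: torch.Tensor) -> torch.Tensor:
    """Straight-through rounding: the forward value equals torch.round(x),
    while the backward pass treats rounding as identity (gradient of 1)."""
    return (x.round() - x).detach() + x

x = torch.tensor([0.2, 1.7], requires_grad=True)
round_ste_sketch(x).sum().backward()
print(x.grad)  # tensor([1., 1.]) -- gradients flow through the rounding op
```

This keeps the symmetric GGUF path differentiable, matching the asymmetric path, which already used `round_ste`.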