From 9c75983a10b5f6e3c8b9fe5ce459817220004e76 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 28 Oct 2025 03:25:55 -0400 Subject: [PATCH 01/13] fix bug of imatrix contains 0 Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 4b4c51942..b7f76a43e 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -360,7 +360,8 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri ## use mean values to fill zero values tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt) tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]) - quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :] + replace_idx = quant_weights == 0 + quant_weights[replace_idx] = tmp_quant_weights[replace_idx] # sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K # if imatrix is None: From 09f2a5f9c58622a81d1aeed3df8ca1ac69aed9d1 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 28 Oct 2025 04:09:18 -0400 Subject: [PATCH 02/13] fix Signed-off-by: n1ck-guo --- auto_round/compressors/utils.py | 2 +- auto_round/export/export_to_gguf/convert.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 6eb43e056..8a8fe602c 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -510,7 +510,7 @@ def gguf_args_check(args_or_ar, formats: list[str] = None, model_type=ModelType. if model_architecture not in ModelBase._model_classes[ModelType.TEXT]: logger.warning( f"Current version of gguf export does not support for {model_architecture}," - " will re-download dependency file." + " will re-download dependency file. Please restart the task." ) redownload = True except ModuleNotFoundError as e: diff --git a/auto_round/export/export_to_gguf/convert.py b/auto_round/export/export_to_gguf/convert.py index a03921624..03981db61 100644 --- a/auto_round/export/export_to_gguf/convert.py +++ b/auto_round/export/export_to_gguf/convert.py @@ -412,7 +412,7 @@ def prepare_tensors(cls): skip = False for tensor_info in cls.gguf_writer.tensors: if new_name in tensor_info: - print("new_name already add to gguf_writer, skip") + logger.info(f"{new_name} already add to gguf_writer, skip") skip = True break if skip: From 2abee8297bfc0a035fc95044709f1b4566a6b506 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 28 Oct 2025 21:05:17 -0400 Subject: [PATCH 03/13] update Signed-off-by: n1ck-guo --- auto_round/data_type/int.py | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index fb6c5f3db..8555114a9 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -17,6 +17,7 @@ from auto_round.data_type.register import register_dtype from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste +from auto_round.logger import logger from auto_round.utils import get_reciprocal @@ -71,9 +72,34 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5 imatrix = imatrix.reshape(1, -1) imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1) - imatrix = imatrix.reshape(tensor.shape) - - scale = search_scales(tensor, bits, qw=imatrix) + quant_weights = imatrix.reshape(tensor.shape) + + if torch.min(quant_weights) == 0: + logger.warning_once( + "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" + ) + zero_cnt = torch.sum(quant_weights == 0, dim=-1) + replace_index = zero_cnt > group_size // 2 + if torch.sum(replace_index) > 0: + ## fallback to no imatrix + if bits == 2: + tmp_quant_weights = torch.abs(tensor) + elif bits == 4 or bits == 5: + sigma2 = ( + torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32 + ) ## Note 32 is different from QK_K + av_x = torch.sqrt(sigma2) + tmp_quant_weights = torch.abs(tensor) + av_x + quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :] + mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) + if torch.sum(mean_replace_index) > 0: + ## use mean values to fill zero values + tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt) + tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]) + replace_idx = quant_weights == 0 + quant_weights[replace_idx] = tmp_quant_weights[replace_idx] + + scale = search_scales(tensor, bits, qw=quant_weights) scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) int_w = round_ste(tensor / scale + v) q = torch.clamp(int_w, -maxq, maxq - 1) From a91144b1a2308e3072d54c39435b3f4cc5fada3d Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 28 Oct 2025 21:42:59 -0400 Subject: [PATCH 04/13] extract func Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 54 +++++++++++++++++++----------------- auto_round/data_type/int.py | 34 ++++------------------- 2 files changed, 35 insertions(+), 53 deletions(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index b7f76a43e..7dccc51d2 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -285,6 +285,34 @@ def quant_tensor_asym_dq( return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin} +def _imatrix_handle_zero(imatrix, weight, bits): + group_size = 16 if bits == 2 else 32 + if torch.min(imatrix) == 0: + logger.warning_once( + "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" + ) + + zero_cnt = torch.sum(imatrix == 0, dim=-1) + replace_index = zero_cnt > group_size // 2 + if torch.sum(replace_index) > 0: + ## fallback to no imatrix + if bits == 2: + tmp_quant_weights = torch.abs(weight) + elif bits == 4 or bits == 5: + sigma2 = torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / 32 ## Note 32 is different from QK_K + av_x = torch.sqrt(sigma2) + tmp_quant_weights = torch.abs(weight) + av_x + imatrix[replace_index, :] = tmp_quant_weights[replace_index, :] + mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) + if torch.sum(mean_replace_index) > 0: + ## use mean values to fill zero values + tmp_quant_weights = torch.sum(imatrix, dim=-1) / (imatrix.shape[1] - zero_cnt) + tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, imatrix.shape[1]) + replace_idx = imatrix == 0 + imatrix[replace_idx] = tmp_quant_weights[replace_idx] + return imatrix + + @torch.no_grad() def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None): super_bits = 4 if bits == 2 else 6 @@ -337,31 +365,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri weights = weights.expand(tensor.numel() // weights.numel(), -1) quant_weights = weights.reshape(tensor.shape) - if torch.min(quant_weights) == 0: - logger.warning_once( - "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" - ) - - zero_cnt = torch.sum(quant_weights == 0, dim=-1) - replace_index = zero_cnt > group_size // 2 - if torch.sum(replace_index) > 0: - ## fallback to no imatrix - if bits == 2: - tmp_quant_weights = torch.abs(tensor) - elif bits == 4 or bits == 5: - sigma2 = ( - torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32 - ) ## Note 32 is different from QK_K - av_x = torch.sqrt(sigma2) - tmp_quant_weights = torch.abs(tensor) + av_x - quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :] - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) - if torch.sum(mean_replace_index) > 0: - ## use mean values to fill zero values - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt) - tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]) - replace_idx = quant_weights == 0 - quant_weights[replace_idx] = tmp_quant_weights[replace_idx] + _imatrix_handle_zero(quant_weights, tensor, bits) # sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K # if imatrix is None: diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 8555114a9..dc68aacc3 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -63,6 +63,7 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5 Returns: Quantized and de-quantized tensor, scale, zero-point """ + from auto_round.data_type.gguf import _imatrix_handle_zero tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) maxq = 2 ** (bits - 1) @@ -72,34 +73,11 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5 imatrix = imatrix.reshape(1, -1) imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1) - quant_weights = imatrix.reshape(tensor.shape) - - if torch.min(quant_weights) == 0: - logger.warning_once( - "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" - ) - zero_cnt = torch.sum(quant_weights == 0, dim=-1) - replace_index = zero_cnt > group_size // 2 - if torch.sum(replace_index) > 0: - ## fallback to no imatrix - if bits == 2: - tmp_quant_weights = torch.abs(tensor) - elif bits == 4 or bits == 5: - sigma2 = ( - torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32 - ) ## Note 32 is different from QK_K - av_x = torch.sqrt(sigma2) - tmp_quant_weights = torch.abs(tensor) + av_x - quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :] - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) - if torch.sum(mean_replace_index) > 0: - ## use mean values to fill zero values - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt) - tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]) - replace_idx = quant_weights == 0 - quant_weights[replace_idx] = tmp_quant_weights[replace_idx] - - scale = search_scales(tensor, bits, qw=quant_weights) + imatrix = imatrix.reshape(tensor.shape) + + imatrix = _imatrix_handle_zero(imatrix, tensor, bits) + + scale = search_scales(tensor, bits, qw=imatrix) scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) int_w = round_ste(tensor / scale + v) q = torch.clamp(int_w, -maxq, maxq - 1) From 7d2c3db591a6441d263bfb0690da69df1be5faf9 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Tue, 28 Oct 2025 21:46:36 -0400 Subject: [PATCH 05/13] fix Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 7dccc51d2..84ff0e187 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -286,6 +286,9 @@ def quant_tensor_asym_dq( def _imatrix_handle_zero(imatrix, weight, bits): + if not isinstance(imatrix, torch.Tensor): + return imatrix + group_size = 16 if bits == 2 else 32 if torch.min(imatrix) == 0: logger.warning_once( From 06d06283cb0f67df937d5bab3e0e37cb43bdb010 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Wed, 29 Oct 2025 01:59:59 -0400 Subject: [PATCH 06/13] fix Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 1 + auto_round/data_type/int.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 84ff0e187..df8ccda21 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -305,6 +305,7 @@ def _imatrix_handle_zero(imatrix, weight, bits): sigma2 = torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / 32 ## Note 32 is different from QK_K av_x = torch.sqrt(sigma2) tmp_quant_weights = torch.abs(weight) + av_x + tmp_quant_weights = tmp_quant_weights.to(imatrix.dtype) imatrix[replace_index, :] = tmp_quant_weights[replace_index, :] mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) if torch.sum(mean_replace_index) > 0: diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index dc68aacc3..8de278d0a 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -75,7 +75,7 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5 imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1) imatrix = imatrix.reshape(tensor.shape) - imatrix = _imatrix_handle_zero(imatrix, tensor, bits) + imatrix = _imatrix_handle_zero(imatrix, tensor, bits) scale = search_scales(tensor, bits, qw=imatrix) scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) From 383201bda894fc8411bf6cba367fd0e4d9dcbffe Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Wed, 29 Oct 2025 02:33:30 -0400 Subject: [PATCH 07/13] change sym Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index df8ccda21..e9453129f 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -541,27 +541,8 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype): weights = imatrix.reshape(1, -1) weights = weights.expand(tensor.numel() // weights.numel(), -1) quant_weights = weights.reshape(tensor.shape) - if torch.min(quant_weights) == 0: - logger.warning_once( - "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" - ) - zero_cnt = torch.sum(quant_weights == 0, dim=-1) - replace_index = zero_cnt > group_size // 2 - if torch.sum(replace_index) > 0: - if bits == 6: - quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index] - else: - sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K - tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor) - quant_weights[replace_index] = tmp_quant_weights[replace_index] - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) - if torch.sum(mean_replace_index) > 0: - ## use mean values to fill zero values - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt) - tmp_quant_weights = ( - tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape) - ) - quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index] + + quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits) scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights) return scale From 21e0807d887ecf9df86b46d8eb9085319e56efa0 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Wed, 29 Oct 2025 02:34:44 -0400 Subject: [PATCH 08/13] correct name Signed-off-by: n1ck-guo --- auto_round/data_type/int.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 8de278d0a..6cbb6cd52 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -45,7 +45,7 @@ def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, f @register_dtype("rtn_int_sym") -def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs): +def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs): """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community Args: From 39dcf7081afb1528ef559fa9857da71bd9a3aac9 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Wed, 29 Oct 2025 20:32:04 -0400 Subject: [PATCH 09/13] fix Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index e9453129f..abbd53d03 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -290,6 +290,7 @@ def _imatrix_handle_zero(imatrix, weight, bits): return imatrix group_size = 16 if bits == 2 else 32 + imatrix = imatrix.reshape(-1, imatrix.shape[-1]) if torch.min(imatrix) == 0: logger.warning_once( "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0" @@ -314,7 +315,7 @@ def _imatrix_handle_zero(imatrix, weight, bits): tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, imatrix.shape[1]) replace_idx = imatrix == 0 imatrix[replace_idx] = tmp_quant_weights[replace_idx] - return imatrix + return imatrix.reshape(weight.shape) @torch.no_grad() From 35c052a52110cafad1f5cd6ca0a664293f420190 Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Wed, 29 Oct 2025 21:44:28 -0400 Subject: [PATCH 10/13] fix ling-flash-2.0 infer Signed-off-by: n1ck-guo --- auto_round/compressors/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 9608e6ee4..48d3687af 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1267,8 +1267,10 @@ def get_imatrix_hook(module, input, output): if not hasattr(module, "imatrix"): module.imatrix = squared + module.imatrix_cnt = input.shape[0] else: module.imatrix += squared.to(module.imatrix.device) + module.imatrix_cnt += input.shape[0] hook_handles = [] for name, module in model.named_modules(): @@ -1593,6 +1595,9 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) set_amax_for_all_moe_layers(block, attr_name="act_max") # Normalize imatrix and quantize layers for _, m in block.named_modules(): + # fix issue: Ling-flash-2.0-q2_k_s fail infer on cuda but well on cpu + if hasattr(m, "imatrix"): + m.imatrix /= m.imatrix_cnt if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names: self._quantize_layer_via_rtn(m.tmp_name) all_to_quantized_module_names.remove(m.tmp_name) From b1f3d9f6d4cbf2f361445ec864b033e0225134fe Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Wed, 29 Oct 2025 21:51:12 -0400 Subject: [PATCH 11/13] add issue link Signed-off-by: n1ck-guo --- auto_round/compressors/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 48d3687af..54bde158d 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1596,6 +1596,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) # Normalize imatrix and quantize layers for _, m in block.named_modules(): # fix issue: Ling-flash-2.0-q2_k_s fail infer on cuda but well on cpu + # https://huggingface.co/Intel/Ling-flash-2.0-gguf-q2ks-mixed-AutoRound/discussions/1 if hasattr(m, "imatrix"): m.imatrix /= m.imatrix_cnt if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names: From d6a433a3163e3d63b6a71552fbe43a3196fd255b Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Thu, 30 Oct 2025 00:42:48 -0400 Subject: [PATCH 12/13] update Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index abbd53d03..1df1a1cb3 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -315,7 +315,7 @@ def _imatrix_handle_zero(imatrix, weight, bits): tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, imatrix.shape[1]) replace_idx = imatrix == 0 imatrix[replace_idx] = tmp_quant_weights[replace_idx] - return imatrix.reshape(weight.shape) + return imatrix.reshape(weight.shape) @torch.no_grad() @@ -370,7 +370,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri weights = weights.expand(tensor.numel() // weights.numel(), -1) quant_weights = weights.reshape(tensor.shape) - _imatrix_handle_zero(quant_weights, tensor, bits) + quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits) # sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K # if imatrix is None: From 1bc765f7fb5db1eeaf90709327dde9c6c04af9ac Mon Sep 17 00:00:00 2001 From: n1ck-guo Date: Thu, 30 Oct 2025 04:12:25 -0400 Subject: [PATCH 13/13] type annotation Signed-off-by: n1ck-guo --- auto_round/data_type/gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 1df1a1cb3..f20c6c7a6 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable, Union import torch @@ -285,7 +286,7 @@ def quant_tensor_asym_dq( return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin} -def _imatrix_handle_zero(imatrix, weight, bits): +def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tensor, bits: int): if not isinstance(imatrix, torch.Tensor): return imatrix