diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 869fc74de..b6c236399 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -1135,8 +1135,10 @@ def get_imatrix_hook(module, input, output):
             if not hasattr(module, "imatrix"):
                 module.imatrix = squared
+                module.imatrix_cnt = input.shape[0]  # number of calibration samples seen
             else:
                 module.imatrix += squared.to(module.imatrix.device)
+                module.imatrix_cnt += input.shape[0]


         hook_handles = []
         for name, module in model.named_modules():
@@ -1454,6 +1456,10 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
             set_amax_for_all_moe_layers(block, attr_name="act_max")
         # Normalize imatrix and quantize layers
         for _, m in block.named_modules():
+            # Fix: Ling-flash-2.0 q2_k_s failed inference on CUDA but worked on CPU, see
+            # https://huggingface.co/Intel/Ling-flash-2.0-gguf-q2ks-mixed-AutoRound/discussions/1
+            if hasattr(m, "imatrix"):
+                m.imatrix /= m.imatrix_cnt
             if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
                 self._quantize_layer_via_rtn(m.tmp_name)
                 all_to_quantized_module_names.remove(m.tmp_name)
diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py
index 6eb43e056..8a8fe602c 100644
--- a/auto_round/compressors/utils.py
+++ b/auto_round/compressors/utils.py
@@ -510,7 +510,7 @@ def gguf_args_check(args_or_ar, formats: list[str] = None, model_type=ModelType.
         if model_architecture not in ModelBase._model_classes[ModelType.TEXT]:
             logger.warning(
                 f"Current version of gguf export does not support for {model_architecture},"
-                " will re-download dependency file."
+                " will re-download the dependency file. Please restart the task."
             )
             redownload = True
     except ModuleNotFoundError as e:
diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py
index 4b4c51942..f20c6c7a6 100644
--- a/auto_round/data_type/gguf.py
+++ b/auto_round/data_type/gguf.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Any, Callable, Union

 import torch

@@ -285,6 +286,52 @@ def quant_tensor_asym_dq(
     return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}


+def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tensor, bits: int):
+    """Sanitize zero entries in an importance matrix before it is used as quantization weights.
+
+    Rows in which more than half of a group is zero fall back to a weight-only
+    heuristic; rows with fewer zeros have the zeros filled with the row mean of
+    the non-zero entries.
+    """
+    if not isinstance(imatrix, torch.Tensor):
+        return imatrix
+
+    group_size = 16 if bits == 2 else 32
+    imatrix = imatrix.reshape(-1, imatrix.shape[-1])
+    if torch.min(imatrix) == 0:
+        logger.warning_once(
+            "calibration activations contain zeros; please use more data by increasing `nsamples` to improve accuracy"
+        )
+
+        zero_cnt = torch.sum(imatrix == 0, dim=-1)
+        replace_index = zero_cnt > group_size // 2
+        if torch.sum(replace_index) > 0:
+            ## fallback to no imatrix
+            if bits == 2:
+                tmp_quant_weights = torch.abs(weight)
+            elif bits == 4 or bits == 5:
+                sigma2 = torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / 32  ## Note 32 is different from QK_K
+                av_x = torch.sqrt(sigma2)
+                tmp_quant_weights = torch.abs(weight) + av_x
+            elif bits == 6:
+                ## sym path (q6_k): keep the fallback from the inline code this helper replaces
+                tmp_quant_weights = weight * weight
+            else:
+                ## sym path (e.g. q3_k): keep the fallback from the inline code this helper replaces
+                sigma2 = 2 * torch.sum(torch.pow(weight, 2), dim=-1, keepdim=True) / QK_K
+                tmp_quant_weights = torch.sqrt(sigma2 + weight * weight)
+            tmp_quant_weights = tmp_quant_weights.to(imatrix.dtype)
+            imatrix[replace_index, :] = tmp_quant_weights[replace_index, :]
+        mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
+        if torch.sum(mean_replace_index) > 0:
+            ## use mean values to fill zero values
+            tmp_quant_weights = torch.sum(imatrix, dim=-1) / (imatrix.shape[1] - zero_cnt)
+            tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, imatrix.shape[1])
+            replace_idx = imatrix == 0
+            imatrix[replace_idx] = tmp_quant_weights[replace_idx]
+    return imatrix.reshape(weight.shape)
+
+
 @torch.no_grad()
 def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
     super_bits = 4 if bits == 2 else 6
@@ -337,30 +384,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri
         weights = weights.expand(tensor.numel() // weights.numel(), -1)
         quant_weights = weights.reshape(tensor.shape)
-        if torch.min(quant_weights) == 0:
-            logger.warning_once(
-                "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
-            )
-
-            zero_cnt = torch.sum(quant_weights == 0, dim=-1)
-            replace_index = zero_cnt > group_size // 2
-            if torch.sum(replace_index) > 0:
-                ## fallback to no imatrix
-                if bits == 2:
-                    tmp_quant_weights = torch.abs(tensor)
-                elif bits == 4 or bits == 5:
-                    sigma2 = (
-                        torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / 32
-                    )  ## Note 32 is different from QK_K
-                    av_x = torch.sqrt(sigma2)
-                    tmp_quant_weights = torch.abs(tensor) + av_x
-                quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :]
-            mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
-            if torch.sum(mean_replace_index) > 0:
-                ## use mean values to fill zero values
-                tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt)
-                tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1])
-                quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :]
+        quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits)

     # sigma2 = torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
     # if imatrix is None:
@@ -532,27 +556,8 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
         weights = imatrix.reshape(1, -1)
         weights = weights.expand(tensor.numel() // weights.numel(), -1)
         quant_weights = weights.reshape(tensor.shape)
-        if torch.min(quant_weights) == 0:
-            logger.warning_once(
-                "please use more data via setting `nsamples` to improve accuracy as calibration activations contain 0"
-            )
-            zero_cnt = torch.sum(quant_weights == 0, dim=-1)
-            replace_index = zero_cnt > group_size // 2
-            if torch.sum(replace_index) > 0:
-                if bits == 6:
-                    quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index]
-                else:
-                    sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K
-                    tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor)
-                    quant_weights[replace_index] = tmp_quant_weights[replace_index]
-            mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2)
-            if torch.sum(mean_replace_index) > 0:
-                ## use mean values to fill zero values
-                tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt)
-                tmp_quant_weights = (
-                    tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape)
-                )
-                quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index]
+
+        quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits)

     scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
     return scale
diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py
index bd0533581..699466dc8 100644
--- a/auto_round/data_type/int.py
+++ b/auto_round/data_type/int.py
@@ -17,6 +17,7 @@

 from auto_round.data_type.register import register_dtype
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste
+from auto_round.logger import logger
 from auto_round.utils import get_reciprocal

@@ -44,7 +45,7 @@


 @register_dtype("rtn_int_sym")
-def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
-    """Quantize and de-quantize tensor asymmetrically. full range, credict goes to llamacpp community
+def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5, imatrix=None, **kwargs):
+    """Quantize and de-quantize a tensor symmetrically over the full range; credit goes to the llama.cpp community.

     Args:
@@ -62,6 +63,7 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
     Returns:
         Quantized and de-quantized tensor, scale, zero-point
     """
+    from auto_round.data_type.gguf import _imatrix_handle_zero

     tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
     maxq = 2 ** (bits - 1)
@@ -73,6 +75,8 @@ def quant_tensor_rnt_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
         imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
         imatrix = imatrix.reshape(tensor.shape)

+    imatrix = _imatrix_handle_zero(imatrix, tensor, bits)
+
     scale = search_scales(tensor, bits, qw=imatrix)
     scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh))
     int_w = round_ste(tensor / scale + v)
diff --git a/auto_round/export/export_to_gguf/convert.py b/auto_round/export/export_to_gguf/convert.py
index 5a7e803c5..03981db61 100644
--- a/auto_round/export/export_to_gguf/convert.py
+++ b/auto_round/export/export_to_gguf/convert.py
@@ -412,7 +412,7 @@ def prepare_tensors(cls):
             skip = False
             for tensor_info in cls.gguf_writer.tensors:
                 if new_name in tensor_info:
-                    logger.warning(f"{new_name} already add to gguf_writer, skip")
+                    logger.info(f"{new_name} already added to gguf_writer, skipping")
                     skip = True
                     break
             if skip:
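
Below, a minimal sketch of the imatrix accumulation and the new normalization, for reference. The computation of `squared` is not visible in the hunk above, so the per-channel sum of squared activations is an assumption here, and all shapes are illustrative:

    import torch

    # four calibration batches of 8 samples with 64 channels each
    acts = [torch.randn(8, 64) for _ in range(4)]

    imatrix, imatrix_cnt = None, 0
    for x in acts:
        squared = torch.sum(x.to(torch.float32) ** 2, dim=0)  # assumed form of `squared`
        imatrix = squared if imatrix is None else imatrix + squared
        imatrix_cnt += x.shape[0]  # mirrors module.imatrix_cnt in get_imatrix_hook

    # mirrors `m.imatrix /= m.imatrix_cnt` in _quantize_via_rtn_blockwise: the
    # running sum becomes a per-sample mean, so its magnitude no longer depends
    # on how many calibration samples were fed through the model
    imatrix /= imatrix_cnt
    assert torch.allclose(imatrix, torch.cat(acts).pow(2).mean(dim=0))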
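And a quick standalone check of the two regimes in `_imatrix_handle_zero` once the patch is applied. The shapes, seed, and `bits=2` grouping (16 elements, so the replacement threshold is 8 zeros per row) are chosen for illustration; the helper mutates its input, hence the `clone()`:

    import torch

    from auto_round.data_type.gguf import _imatrix_handle_zero

    torch.manual_seed(0)
    weight = torch.randn(4, 16)
    imatrix = torch.rand(4, 16) + 0.1  # strictly positive importance values
    imatrix[0] = 0.0      # > 8 zeros in the row: falls back to abs(weight)
    imatrix[1, :3] = 0.0  # <= 8 zeros: the zeros are filled with the row mean

    out = _imatrix_handle_zero(imatrix.clone(), weight, bits=2)
    assert torch.equal(out[0], weight[0].abs())  # weight-only fallback branch
    assert (out[1] > 0).all()                    # mean-fill branch removed the zeros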