diff --git a/README.md b/README.md index b107b3d0a..a6289e2f9 100644 --- a/README.md +++ b/README.md @@ -189,14 +189,14 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round") Important Hyperparameters ##### Quantization Scheme & Configuration -- **`scheme` (str|dict|AutoScheme)**: The predefined quantization keys, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`. +- **`scheme` (str|dict|AutoScheme)**: The predefined quantization keys, e.g. `W4A16`, `MXFP4`, `NVFP4`, `GGUF:Q4_K_M`. For MXFP4/NVFP4, we recommend exporting to LLM-Compressor format. - **`bits` (int)**: Number of bits for quantization (default is `None`). If not None, it will override the scheme setting. - **`group_size` (int)**: Size of the quantization group (default is `None`). If not None, it will override the scheme setting. - **`sym` (bool)**: Whether to use symmetric quantization (default is `None`). If not None, it will override the scheme setting. -- **`layer_config` (dict)**: Configuration for weight quantization (default is `None`), mainly for mixed schemes. +- **`layer_config` (dict)**: Configuration for layer_wise scheme (default is `None`), mainly for customized mixed schemes. ##### Algorithm Settings -- **`enable_alg_ext` (bool)**: [Experimental Feature] Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`. +- **`enable_alg_ext` (bool)**: [Experimental Feature] Only for `iters>0`. Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`. - **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled). ##### Tuning Process Parameters @@ -217,7 +217,8 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round") -### AutoScheme Usage +### Adaptive Bits/Dtype Usage +AutoScheme provides an automatic algorithm to generate mixed bits/data_type quantization recipes. 
For accuracy results, please refer to this [doc](https://github.com/intel/auto-round/blob/main/docs/auto_scheme_acc.md). Please refer to the [user guide](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md#autoscheme) for more details on AutoScheme. ~~~python from auto_round import AutoRound, AutoScheme @@ -299,7 +300,7 @@ for output in outputs: ### SGLang (Intel GPU/CUDA) -Please note that support for the MoE models and visual language models is currently limited. +**Please note that support for the MoE models and visual language models is currently limited.** ```python import sglang as sgl diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 31683eb7c..62c2f6a32 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -713,11 +713,10 @@ def _check_compatibility(self) -> None: raise ValueError("Gguf format is not compatible with other formats, please choose only one of them") if has_gguf and self.iters != 0 and self.bits != 3 and not self.enable_alg_ext: logger.warning( - "`iters=0` is recommended when exporting to GGUF format except for bits 3," - " as we have optimized the RTN method for this case." - " Or add enable_alg_ext to use the new algorithm," - " refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md" - " to check the acc." + "`iters=0` is recommended when exporting to current GGUF format" + " or add `enable_alg_ext` for better accuracy with much more tuning cost." + " Please refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md" + " for the accuracy results." 
) if ( @@ -1087,11 +1086,16 @@ def _quantize_embedding_layer(self): dtype = f"rtn_{dtype}" quant_func = QUANT_FUNC_WITH_DTYPE[dtype] + dtype = module.weight.dtype + # As typically float32 are used in RTN to search scale zp, + # to avoid cache a bf16 copy we'd better use float32 + if config["super_group_size"] is not None: + dtype = torch.float32 # Attempt quantization on GPU, fall back to CPU if OOM try: weight, scale, zp = quant_func( - module.weight.to(self.device), + module.weight.to(dtype=dtype, device=self.device), **{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]}, ) except torch.OutOfMemoryError: @@ -1124,8 +1128,9 @@ def _quantize_embedding_layer(self): # Update config self.layer_config.setdefault(name, {}).update(config) - - # Release memory + del weight + del scale + del zp clear_memory(device_list=self.device_list) return is_quantized @@ -1224,7 +1229,7 @@ def get_imatrix_hook(module, input, output): for hook in hooks: hook.remove() - def _quantize_layer_via_rtn(self, name: str) -> None: + def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=True) -> None: """Quantizes a layer using RTN (Round-To-Nearest) if available. This function attempts to quantize a layer by switching its data type to a @@ -1241,19 +1246,20 @@ def _quantize_layer_via_rtn(self, name: str) -> None: RuntimeError: If quantization fails for reasons unrelated to memory. 
""" m = get_module(self.model, name) + if dtype is not None: + m = m.to(dtype) if is_fp8_linear(m): m = convert_fp8_layer_to_linear(m, self.amp_dtype, self.device) set_module(self.model, name, m) - + tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device # Step 1: Try quantization on GPU first, fall back to CPU if OOM - # if only export gguf, using gguf-packing instead of rtn if self.immediate_packing and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn: + m = m.to(tuning_device) m.scale = None m.zp = None else: try: - tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device m = m.to(tuning_device) m = WrapperLinear( m, @@ -1265,7 +1271,6 @@ def _quantize_layer_via_rtn(self, name: str) -> None: disable_opt_rtn=self.disable_opt_rtn, ) m = m.unwrapper({}) - m.to("cpu") except torch.OutOfMemoryError: cuda_error_msg = traceback.format_exc() m = m.orig_layer if hasattr(m, "orig_layer") else m @@ -1285,11 +1290,14 @@ def _quantize_layer_via_rtn(self, name: str) -> None: raise # Step 2: Optional immediate packing/export - if self.immediate_packing: + if self.immediate_packing: # For gguf, packing conducts on block level self._immediate_pack(name) + if to_cpu: + m = m.to("cpu") else: + if to_cpu: + m = m.to("cpu") set_module(self.model, name, m) - if self.immediate_saving: all_to_quantized_module_names = [n for n, m in self.model.named_modules() if check_to_quantized(m)] last_module = (len(all_to_quantized_module_names) == 0) or (name == all_to_quantized_module_names[-1]) @@ -1297,6 +1305,8 @@ def _quantize_layer_via_rtn(self, name: str) -> None: immediate_saving(self, m, name, last_module) def _immediate_pack(self, name: str): + if not self.immediate_packing: + return m = get_module(self.model, name) if not check_to_quantized(m): return @@ -1353,16 +1363,18 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: for module in tqdm(modules, desc="Update weight global scale for fuse 
module"): update_fused_layer_global_scales(module) - has_gguf_k = ( - any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None - ) - - self._quantize_embedding_layer() + if not (any("gguf" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None): + self._quantize_embedding_layer() # leave to gguf itself to handle self.model.to("cpu") + # Release memory + clear_memory(device_list=self.device_list) enable_imatrix = False if not self.disable_opt_rtn: + has_gguf_k = ( + any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None + ) if has_gguf_k: enable_imatrix = True elif self.data_type == "int" and self.sym: @@ -1498,6 +1510,7 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) self.device, self.cache_device, ) + if len(self.device_list) > 1: accelerate.hooks.remove_hook_from_submodules(block) @@ -1505,32 +1518,36 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str]) # enable moe experts act_max automatic generation for Linear set_amax_for_all_moe_layers(block, attr_name="act_max") # Normalize imatrix and quantize layers + if self.low_gpu_mem_usage: + block.to("cpu") + clear_memory(device_list=self.device_list) + for _, m in block.named_modules(): # fix issue: Ling-flash-2.0-q2_k_s fail infer on cuda but well on cpu # https://huggingface.co/Intel/Ling-flash-2.0-gguf-q2ks-mixed-AutoRound/discussions/1 if hasattr(m, "imatrix"): m.imatrix /= m.imatrix_cnt if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names: - self._quantize_layer_via_rtn(m.tmp_name) + self._quantize_layer_via_rtn(m.tmp_name, to_cpu=False) all_to_quantized_module_names.remove(m.tmp_name) if not self.immediate_saving: mv_module_from_gpu(block) + if block_name == block_names[-1]: + clear_memory(input_ids, device_list=self.device_list) + else: + clear_memory(device_list=self.device_list) + 
memory_monitor.log_summary() pbar.update(1) pbar.close() - cnt = 1 - block_names_cnt = len(flatten_list(get_block_names(self.model, True))) - clear_mem_freq = len(all_to_quantized_module_names) // block_names_cnt - if clear_mem_freq == 0: - clear_mem_freq = 1 # Process remaining layers not in blocks for name in all_to_quantized_module_names: - self._quantize_layer_via_rtn(name) - if cnt % clear_mem_freq == 0: - clear_memory(device_list=self.device_list) - cnt = 1 - cnt += 1 + dtype = None + if self.super_group_size is not None: + dtype = torch.float32 + self._quantize_layer_via_rtn(name, dtype=dtype) + # clear_memory(device_list=self.device_list) def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tensor]: keys = inputs.keys() @@ -1631,6 +1648,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: logger.info("start to cache block inputs") all_inputs = self.try_cache_inter_data_gpucpu(all_first_block_names, self.nsamples, layer_names=layer_names) is_quantized_embedding = self._quantize_embedding_layer() + clear_memory(device_list=self.device_list) all_q_inputs = None if is_quantized_embedding: all_inputs = copy.deepcopy(self.inputs) @@ -2838,7 +2856,7 @@ def _quantize_block( if auto_offload: mv_module_from_gpu(block) - clear_memory(input_ids) + clear_memory(input_ids, device_list=self.device_list) memory_info_summary = memory_monitor.get_summary() logger.infoclean(dump_info + "," + memory_info_summary) @@ -2848,7 +2866,7 @@ def _quantize_block( accelerate.hooks.remove_hook_from_submodules(block) if auto_offload: mv_module_from_gpu(block) - clear_memory(input_ids) + clear_memory(input_ids, device_list=self.device_list) memory_info_summary = memory_monitor.get_summary() logger.infoclean(dump_info + "," + memory_info_summary) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 577ccf34e..27a97f010 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -18,7 +18,7 @@ from 
auto_round.data_type.register import register_dtype from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad, round_ste from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES -from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants +from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants, make_qx_quants_chunk from auto_round.logger import logger from auto_round.utils import get_reciprocal from auto_round.utils.device import clear_memory @@ -165,6 +165,37 @@ def double_quant_tensor_sym(tensor, bits): return qdq_tensor, scale +def double_quant_tensor_sym_rtn(tensor, bits): + """ + Inplace-optimized symmetric double quantization. + - Uses float32 inplace where possible + - Minimizes temporary tensor allocations + """ + # Ensure tensor is float32 inplace (if tensor already float32, no copy) + if tensor.dtype != torch.float32: + tensor = tensor.float() # .float() creates a copy if needed + + maxq = 2 ** (bits - 1) + + # Compute absolute max along last dim + # abs_() is inplace + tensor_abs = tensor.abs() # cannot inplace abs on original if we need original sign + imax = tensor_abs.argmax(dim=-1, keepdim=True) + wmax = torch.take_along_dim(tensor, imax, dim=-1) + + # Compute scale inplace + scale = wmax / -maxq + inverse_scale = get_reciprocal(scale) + + # Inplace quantization + tensor = tensor.mul_(inverse_scale) # tensor * inverse_scale inplace + tensor = tensor.round_() # round inplace + tensor.clamp_(-maxq, maxq - 1) # clamp inplace + tensor.mul_(scale) # multiply scale inplace + + return tensor, scale + + def make_qp_quants(nmax, data, quant_weights): data = data.to(torch.float32) quant_weights = quant_weights.to(torch.float32) @@ -320,13 +351,11 @@ def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tens return imatrix.reshape(weight.shape) -@torch.no_grad() +@torch.inference_mode() def search_gguf_scale_min_asym(tensor, bits=4, 
scale_dtype=torch.float16, imatrix=None, split_num=1): super_bits = 4 if bits == 2 else 6 super_group_size = 16 if bits == 2 else 8 - group_size = 16 if bits == 2 else 32 - if bits not in [2, 4, 5]: - raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq") + quant_weights = None if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0): search_kwargs = { @@ -408,6 +437,8 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri d_wmin = d_wmin.unsqueeze(-1) scale = (d_scale * q_scale).view(-1, 1) wmin = (d_wmin * q_wmin).view(-1, 1) + if split_num > 1: + clear_memory(device_list=[tensor.device]) return scale, wmin, d_scale, d_wmin @@ -415,7 +446,6 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri def quant_tensor_gguf_asym_dq( tensor: torch.Tensor, bits: int = 4, - v=0, scale_dtype=torch.float16, imatrix=None, scale=None, @@ -437,14 +467,12 @@ def quant_tensor_gguf_asym_dq( Returns: Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary) """ + if bits not in [2, 4, 5]: + raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq") orig_dtype = tensor.dtype maxq = 2**bits - 1 group_size = 16 if bits == 2 else 32 split_num = 1 - for dim in tensor.shape: - if dim > 100_000: - split_num = 16 - break tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) @@ -455,83 +483,17 @@ def quant_tensor_gguf_asym_dq( ) inverse_scale = get_reciprocal(scale) - int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq) - qdq_result = (scale * int_w - wmin).to(orig_dtype) - qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) - return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin} - - -def iterative_wls_quant_search_non_chunk(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None): - """Adapted from Llamacpp. 
Performs iterative weighted least squares quantization search. - - Args: - data (torch.Tensor): Input tensor to quantize. - bits (int): Number of quantization bits. - rrmin (float): Initial range scaling factor. - rdelta (float): Step size for range scaling. - nstep (int): Number of search steps. - use_mad (bool): Whether to use mean absolute deviation instead of squared error. - weights (torch.Tensor): Weight matrix for each element. - - Returns: - Tuple: (Optimal scale tensor, optimal minimum value tensor) - """ - dtype = torch.float32 - data = data.to(dtype) - maxq = 2**bits - 1 - minq = 0 - weights = 1.0 if weights is None else weights.to(dtype) - - rmin = torch.min(data, dim=1, keepdim=True)[0] - rmax = torch.max(data, dim=1, keepdim=True)[0] - - sum_w = torch.sum(weights, dim=1, keepdim=True) - sum_x = torch.sum(weights * data, dim=1, keepdim=True) - - # scale = 1 / ((maxq - minq) / (rmax - rmin + 1e-8)) - scale = (rmax - rmin) / (maxq - minq) - iscale = get_reciprocal(scale) - # quant_data = torch.clamp(torch.round((maxq - minq) / (rmax - rmin + 1e-8) * (data - rmin)), minq, maxq) - quant_data = torch.clamp(torch.round(iscale * (data - rmin)), minq, maxq) - diff = scale * quant_data + rmin - data - - best_mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * torch.pow(diff, 2), dim=1, keepdim=True) - - for is_ in range(nstep): - factor = rrmin + rdelta * is_ + maxq - minq - # iscale_new = factor / (rmax - rmin + 1e-8) - scale_new = (rmax - rmin) / factor - iscale_new = get_reciprocal(scale_new) - quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin)), minq, maxq) - - mul_weights_quant_data = weights * quant_data_new - sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True) - sum_l2 = torch.sum(mul_weights_quant_data * quant_data_new, dim=-1, keepdim=True) - sum_xl = torch.sum(mul_weights_quant_data * data, dim=-1, keepdim=True) - - D = sum_w * sum_l2 - torch.pow(sum_l, 2) - this_scale = (sum_w * sum_xl - sum_x * sum_l) / 
D - this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D - this_min[this_min > 0] = 0 - this_scale[this_min > 0] = (sum_xl / sum_l2)[this_min > 0] - reverse_this_scale = get_reciprocal(this_scale) - - quant_data = torch.clamp(torch.round(reverse_this_scale * (data - this_min)), minq, maxq) - diff = this_scale * quant_data + this_min - data - # diff = this_scale * quant_data_new + this_min - data - mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * torch.pow(diff, 2), dim=-1, keepdim=True) - - idx_to_replace = torch.where((mad < best_mad) & (D > 0))[0] - best_mad[idx_to_replace] = mad[idx_to_replace] - scale[idx_to_replace] = this_scale[idx_to_replace] - rmin[idx_to_replace] = this_min[idx_to_replace] - - return scale.to(torch.float32), -rmin.to(torch.float32) + tensor = tensor + wmin + tensor = (tensor.mul_(inverse_scale)).round_().clamp_(0, maxq) + tensor = tensor.mul_(scale) + tensor = tensor.sub_(wmin).to(orig_dtype) + tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len) + return tensor, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin} # TODO consolidate iterative_wls_quant_search_chunk and non-chunk def iterative_wls_quant_search_chunk( - data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=8 + data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1 ): dtype = torch.float32 data = data.to(dtype) @@ -541,56 +503,104 @@ def iterative_wls_quant_search_chunk( results_scale = [] results_rmin = [] + chunk_size = (data.shape[0] + split_num - 1) // split_num + for start in range(0, data.shape[0], chunk_size): end = min(start + chunk_size, data.shape[0]) chunk = data[start:end] chunk_weights = weights if isinstance(weights, float) else weights[start:end] + # Pre-allocate reusable buffers to avoid new allocations + tmp = torch.empty_like(chunk) + quant_data = torch.empty_like(chunk) + diff = torch.empty_like(chunk) + rmin = torch.min(chunk, 
dim=1, keepdim=True)[0] rmax = torch.max(chunk, dim=1, keepdim=True)[0] sum_w = torch.sum(chunk_weights, dim=1, keepdim=True) sum_x = torch.sum(chunk_weights * chunk, dim=1, keepdim=True) + scale = (rmax - rmin) / (maxq - minq) iscale = get_reciprocal(scale) - quant_data = torch.clamp(torch.round(iscale * (chunk - rmin)), minq, maxq) - diff = scale * quant_data + rmin - chunk - best_mad = torch.sum( - (chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2), dim=1, keepdim=True - ) + + # tmp = (chunk - rmin) * iscale + tmp.copy_(chunk).sub_(rmin).mul_(iscale) + + # quant_data = round(tmp).clamp_() + torch.round(tmp, out=quant_data) + quant_data.clamp_(minq, maxq) + + # diff = scale * quant_data + rmin - chunk + diff.copy_(quant_data).mul_(scale).add_(rmin).sub_(chunk) + + if use_mad: + best_mad = (chunk_weights * diff.abs_()).sum(dim=1, keepdim=True) + else: + diff.pow_(2) + best_mad = (chunk_weights * diff).sum(dim=1, keepdim=True) for is_ in range(nstep): factor = rrmin + rdelta * is_ + maxq - minq + scale_new = (rmax - rmin) / factor iscale_new = get_reciprocal(scale_new) - quant_data_new = torch.clamp(torch.round(iscale_new * (chunk - rmin)), minq, maxq) - mul_weights_quant_data = chunk_weights * quant_data_new - sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True) - sum_l2 = torch.sum(mul_weights_quant_data * quant_data_new, dim=-1, keepdim=True) - sum_xl = torch.sum(mul_weights_quant_data * chunk, dim=-1, keepdim=True) - D = sum_w * sum_l2 - torch.pow(sum_l, 2) + + # tmp = (chunk - rmin) * iscale_new + tmp.copy_(chunk).sub_(rmin).mul_(iscale_new) + + torch.round(tmp, out=quant_data) + quant_data.clamp_(minq, maxq) + + # tmp = chunk_weights * quant_data + tmp.copy_(quant_data).mul_(chunk_weights) + + sum_l = tmp.sum(dim=-1, keepdim=True) + sum_l2 = (tmp * quant_data).sum(dim=-1, keepdim=True) + sum_xl = (tmp * chunk).sum(dim=-1, keepdim=True) + + D = sum_w * sum_l2 - sum_l * sum_l + this_scale = (sum_w * sum_xl - sum_x * 
sum_l) / D this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D - this_min[this_min > 0] = 0 - this_scale[this_min > 0] = (sum_xl / sum_l2)[this_min > 0] + + mask = this_min > 0 + if mask.any(): + this_min[mask] = 0 + this_scale[mask] = (sum_xl / sum_l2)[mask] + reverse_this_scale = get_reciprocal(this_scale) - quant_data = torch.clamp(torch.round(reverse_this_scale * (chunk - this_min)), minq, maxq) - diff = this_scale * quant_data + this_min - chunk - mad = torch.sum( - (chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2), - dim=-1, - keepdim=True, - ) + + # tmp = (chunk - this_min) * reverse_this_scale + tmp.copy_(chunk).sub_(this_min).mul_(reverse_this_scale) + + torch.round(tmp, out=quant_data) + quant_data.clamp_(minq, maxq) + + # diff = this_scale * quant_data + this_min - chunk + diff.copy_(quant_data).mul_(this_scale).add_(this_min).sub_(chunk) + + if use_mad: + mad = (chunk_weights * diff.abs_()).sum(dim=-1, keepdim=True) + else: + diff.pow_(2) + mad = (chunk_weights * diff).sum(dim=-1, keepdim=True) + idx_to_replace = torch.where((mad < best_mad) & (D > 0))[0] + best_mad[idx_to_replace] = mad[idx_to_replace] scale[idx_to_replace] = this_scale[idx_to_replace] rmin[idx_to_replace] = this_min[idx_to_replace] + results_scale.append(scale.to(torch.float32)) results_rmin.append(-rmin.to(torch.float32)) - if split_num > 1: - clear_memory(device_list=[data.device]) - return torch.cat(results_scale, dim=0), torch.cat(results_rmin, dim=0) + if split_num > 1: + clear_memory(device_list=data.device) + if len(results_scale) > 1: + return torch.cat(results_scale, dim=0), torch.cat(results_rmin, dim=0) + else: + return results_scale[0], results_rmin[0] def iterative_wls_quant_search( @@ -612,41 +622,29 @@ def iterative_wls_quant_search( """ # TODO this one should change to try catch later - if split_num > 1: - return iterative_wls_quant_search_chunk( - data=data, - bits=bits, - rrmin=rrmin, - rdelta=rdelta, - nstep=nstep, - use_mad=use_mad, 
- weights=weights, - split_num=split_num, - ) - else: - return iterative_wls_quant_search_non_chunk( - data=data, - bits=bits, - rrmin=rrmin, - rdelta=rdelta, - nstep=nstep, - use_mad=use_mad, - weights=weights, - ) - -@torch.no_grad() -def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype): - from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K + return iterative_wls_quant_search_chunk( + data=data, + bits=bits, + rrmin=rrmin, + rdelta=rdelta, + nstep=nstep, + use_mad=use_mad, + weights=weights, + split_num=split_num, + ) - group_size = 16 +@torch.inference_mode() +def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num): if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0): if bits == 3: + # Note: make_q3_quants does not support split_num/chunking; + # 3-bit quantization is performed in a single chunk. scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True) - ##scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) + # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) elif bits == 6: - scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) + scale, int_w = make_qx_quants_chunk(tensor, bits=bits, rmse_type=1, qw=None, split_num=split_num) else: imatrix = imatrix.to(tensor.device) weights = imatrix.reshape(1, -1) @@ -655,7 +653,9 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype): quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits) - scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights) + scale, int_w = make_qx_quants_chunk(tensor, bits=bits, rmse_type=1, qw=quant_weights, split_num=split_num) + if split_num > 1: + clear_memory(device_list=[tensor.device]) return scale @@ -667,6 +667,7 @@ def quant_tensor_gguf_sym_dq( scale=None, d_scale=None, scale_dtype=torch.float16, + split_num=1, **kwargs, ): """Quantize and de-quantize tensor asymmetrically. For Q3_K, Q6_K. 
@@ -704,18 +705,20 @@ def quant_tensor_gguf_sym_dq( # (nb, 16, 16) tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size) if scale is None and d_scale is None: - scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype) + scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num=split_num) scale = scale.to(scale_dtype) scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) # conduct double quant - scale, d_scale = double_quant_tensor_sym(scale, super_bits) + scale, d_scale = double_quant_tensor_sym_rtn(scale, super_bits) scale = scale.unsqueeze(-1) - zp = torch.full_like(scale, maxq) # pylint: disable=E1130 + # zp = torch.full_like(scale, maxq) # pylint: disable=E1130 inverse_scale = get_reciprocal(scale) - int_w = round_ste(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq - qdq_result = (scale * (int_w - zp)).to(orig_dtype) - qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) + # int_w = round_ste(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq + # qdq_result = (scale * (int_w - zp)).to(orig_dtype) + tensor = tensor.mul_(inverse_scale).round_().clamp_(-maxq, maxq - 1) + tensor = tensor.mul_(scale).to(orig_dtype) + tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len) - return qdq_result, {"scale": scale, "d_scale": d_scale}, zp + return tensor, {"scale": scale, "d_scale": d_scale}, maxq diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 8fc6f79a0..960a7fc08 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -22,25 +22,60 @@ def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, float] = None) -> torch.Tensor: - nmax = pow(2, bits - 1) + # Maximum absolute value for symmetric quantization + nmax = 1 << (bits - 1) # equivalent to pow(2, bits-1) + + # Find per-group max along the last dimension imax = torch.abs(data).argmax(dim=-1, keepdim=True) 
group_max = torch.take_along_dim(data, imax, dim=-1) + + # Compute initial inverse scales iscales = -nmax * get_reciprocal(group_max) - scales = get_reciprocal(iscales) - L = torch.round(1.0 * iscales * data).clip(-nmax, nmax - 1) + scales = get_reciprocal(iscales) # scale = 1 / iscales + + # Initial quantized values (in-place round and clamp) + L = torch.empty_like(data) + torch.round(iscales * data, out=L) + L.clamp_(-nmax, nmax - 1) + + # Set default weight if None if qw is None: qw = 1.0 - best_loss = torch.sum(((scales * L - data).to(torch.float32)) ** 2 * qw, dim=-1) + + # Compute initial best loss + best_loss = ((scales * L - data).to(torch.float32)) ** 2 + if isinstance(qw, torch.Tensor): + best_loss.mul_(qw) # inplace multiply by weight + best_loss = torch.sum(best_loss, dim=-1) + + # Iterative search over small adjustments for _is in range(-18 * 5, 18 * 5 + 1): if _is == 0: continue - iscales = -(nmax - 0.01 * _is) * get_reciprocal(group_max) - tmp_L = torch.round(iscales * data).clip(-nmax, nmax - 1) - tmp_scales = get_reciprocal(iscales) - loss = torch.sum(((tmp_scales * tmp_L - data).to(torch.float32)) ** 2 * qw, dim=-1) + + # Update iscales in-place + iscales_tmp = -(nmax - 0.01 * _is) * get_reciprocal(group_max) + + # Compute temporary quantized values (in-place round + clamp) + tmp_L = torch.empty_like(data) + torch.round(iscales_tmp * data, out=tmp_L) + tmp_L.clamp_(-nmax, nmax - 1) + + # Compute temporary scales + tmp_scales = get_reciprocal(iscales_tmp) + + # Compute temporary loss + loss = ((tmp_scales * tmp_L - data).to(torch.float32)) ** 2 + if isinstance(qw, torch.Tensor): + loss.mul_(qw) + loss = torch.sum(loss, dim=-1) + + # Replace scales where loss improves (in-place) replace_id = loss < best_loss - scales[replace_id] = tmp_scales[replace_id] - best_loss[replace_id] = loss[replace_id] + if replace_id.any(): + scales[replace_id] = tmp_scales[replace_id] + best_loss[replace_id] = loss[replace_id] + return scales @@ -53,11 +88,6 @@ def 
quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5 bits: Number of bits for quantization (e.g., 2, 3, 4, 8) group_size: Number of elements to share scale for quantization v: Rounding value perturbation - min_scale: Minimum scale coefficient for tensor - max_scale: Maximum scale coefficient for tensor - tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. - tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. - scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability Returns: @@ -79,9 +109,8 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5 scale = search_scales(tensor, bits, qw=imatrix) scale = torch.where(scale < 0, torch.clamp(scale, max=-q_scale_thresh), torch.clamp(scale, min=q_scale_thresh)) - int_w = round_ste(tensor / scale + v) - q = torch.clamp(int_w, -maxq, maxq - 1) - qdq_result = (scale * q).to(tensor.dtype) + int_w = tensor.div(scale).round_().clamp_(-maxq, maxq - 1) + qdq_result = (int_w.mul_(scale)).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, scale, maxq diff --git a/auto_round/export/export_to_gguf/convert.py b/auto_round/export/export_to_gguf/convert.py index 3ac31932d..5f42535d3 100644 --- a/auto_round/export/export_to_gguf/convert.py +++ b/auto_round/export/export_to_gguf/convert.py @@ -50,7 +50,15 @@ from auto_round.export.export_to_gguf.config import ModelType from auto_round.export.export_to_gguf.packing import ggml_quant -from auto_round.utils import LazyImport, get_module, get_packing_device, is_fp8_model, logger +from auto_round.utils import ( + LazyImport, + clean_module_parameter, + clear_memory, + get_module, + get_packing_device, + is_fp8_model, + logger, +) gguf = 
LazyImport("gguf") @@ -58,17 +66,6 @@ from torch import Tensor -def clean_module_parameter(submodule, parameter): - if submodule is None: - return - is_buffer = parameter in submodule._buffers - with torch.no_grad(): - if is_buffer: - submodule._buffers[parameter] = None - else: - submodule._parameters[parameter] = None - - def download_convert_file(redownload=False): CONVERT_URL = "https://raw.githubusercontent.com/ggml-org/llama.cpp/refs/heads/master/convert_hf_to_gguf.py" FILE_NAME = "convert_hf_to_gguf.py" @@ -190,7 +187,7 @@ def is_extra_tensor(tensor_name): def _quant_data_with_args(data_torch, data_qtype, scale, zp, d_scale=None, wmin=None, d_wmin=None, imatrix=None): - device = get_packing_device() + device = data_torch.device data_torch = data_torch.to(torch.float32) scale = scale.to(torch.float32) if isinstance(scale, torch.Tensor) else scale zp = zp.to(torch.float32) if isinstance(zp, torch.Tensor) else zp @@ -215,7 +212,7 @@ def _quant_data_with_args(data_torch, data_qtype, scale, zp, d_scale=None, wmin= def _quant_data(cls, data_torch, data_qtype, name, modify_name, bid): suffix = ".weight" - device = get_packing_device() + device = data_torch.device if suffix in name: layer_name = name[: -len(suffix)] module = get_module(cls.model, layer_name) @@ -375,7 +372,8 @@ def prepare_tensors(cls): max_name_len = max(len(s) for _, s in cls.tensor_map.mapping.values()) + len(".weight,") for name, data_torch in chain(cls.generate_extra_tensors(), cls.get_tensors()): - if data_torch is None: + + if data_torch is None or data_torch.numel() == 0: continue # we don't need these if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): @@ -406,9 +404,7 @@ def prepare_tensors(cls): modify_name = _special_name_handle(cls, name) orig_device = data_torch.device - data_torch = data_torch.to("cpu") for new_name, data_torch in cls.modify_tensors(data_torch, modify_name, bid): - data_torch.to(orig_device) skip = False for tensor_info in 
cls.gguf_writer.tensors: if new_name in tensor_info: @@ -417,12 +413,7 @@ def prepare_tensors(cls): break if skip: continue - data = data_torch.squeeze().cpu().numpy() - - # if data ends up empty, it means data_torch was a scalar tensor -> restore - if len(data.shape) == 0: - data = data_torch.numpy() - + data = data_torch.squeeze() n_dims = len(data.shape) data_qtype: gguf.GGMLQuantizationType | bool = cls.tensor_force_quant(name, new_name, bid, n_dims) @@ -537,6 +528,11 @@ def prepare_tensors(cls): gguf.GGMLQuantizationType.BF16, gguf.GGMLQuantizationType.F32, ]: + data = data_torch.squeeze().cpu().numpy() + + # if data ends up empty, it means data_torch was a scalar tensor -> restore + if len(data.shape) == 0: + data = data_torch.numpy() try: data = gguf.quants.quantize(data, data_qtype) except gguf.QuantError as e: @@ -598,6 +594,8 @@ def prepare_tensors(cls): logger.info( f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype}" f" --> {data_qtype.name}, shape = {shape_str}" ) + if not (hasattr(cls, "current_packing_block") and cls.current_packing_block is not None): + clear_memory(device_list=[orig_device]) cls.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) diff --git a/auto_round/export/export_to_gguf/export.py b/auto_round/export/export_to_gguf/export.py index 890a93880..0a5bfc461 100644 --- a/auto_round/export/export_to_gguf/export.py +++ b/auto_round/export/export_to_gguf/export.py @@ -159,13 +159,16 @@ def pack_gguf_layer( model_type=convert_hf_to_gguf.ModelType.MMPROJ, ) ) + if not hasattr(model, "last_layer_name_to_block_name"): block_name_to_last_layer_name = {} block_names = get_block_names(model, quant_vision=True) block_names_flatten = flatten_list(block_names) + all_qlayer_name = [] for n, m in model.named_modules(): if not check_to_quantized(m): continue + all_qlayer_name.append(n) for block_name in block_names_flatten: block_name_split = block_name.split(".") name_split = n.split(".") @@ -177,23 +180,46 @@ def pack_gguf_layer( 
block_name_to_last_layer_name[block_name] = n last_layer_name_to_block_name = {v: k for k, v in block_name_to_last_layer_name.items()} model.last_layer_name_to_block_name = last_layer_name_to_block_name + names_in_blocks = [] + for block_name in block_names_flatten: + block = get_module(model, block_name) + for n, m in block.named_modules(): + if check_to_quantized(m): + names_in_blocks.append(m.tmp_name) + names_outside_blocks = list(set(layer_config.keys()) - set(names_in_blocks)) + model.names_outside_blocks = names_outside_blocks + if name in model.last_layer_name_to_block_name: - ##packing block + # Packing block + block = get_module(model, model.last_layer_name_to_block_name[name]) for gguf_model in gguf_model_instance_global: gguf_model.current_packing_block = model.last_layer_name_to_block_name[name] gguf_model.prepare_tensors() - block = get_module(model, model.last_layer_name_to_block_name[name]) for n, m in block.named_modules(): if hasattr(m, "weight"): m.weight = None if hasattr(m, "bias"): m.bias = None - clear_memory() model.last_layer_name_to_block_name.pop(name) if len(model.last_layer_name_to_block_name) == 0: for gguf_model in gguf_model_instance_global: gguf_model.current_packing_block = None + if name in model.names_outside_blocks: + # Packing block + for gguf_model in gguf_model_instance_global: + gguf_model.current_packing_block = name + gguf_model.prepare_tensors() + + layer = get_module(model, name) + if hasattr(layer, "weight"): + layer.weight = None + if hasattr(layer, "bias"): + layer.bias = None + model.names_outside_blocks.remove(name) + if len(model.names_outside_blocks) == 0: + for gguf_model in gguf_model_instance_global: + gguf_model.current_packing_block = None @torch.inference_mode() diff --git a/auto_round/export/export_to_gguf/packing.py b/auto_round/export/export_to_gguf/packing.py index bc9189b7b..9a48cfa0d 100644 --- a/auto_round/export/export_to_gguf/packing.py +++ b/auto_round/export/export_to_gguf/packing.py @@ -29,6 
+29,35 @@ def register(cls): return register +def ggml_quant_core(quant_func, blocks, scale, zp, wmin, d_scale, d_wmin, imatrix, original): + try: + new_data = quant_func( + blocks, + scale, + zp=zp, + wmin=wmin, + d_scale=d_scale, + d_wmin=d_wmin, + imatrix=imatrix, + original=original, + ) + except torch.OutOfMemoryError: + orig_device = blocks.device + cpu_device = "cpu" + blocks = blocks.to(cpu_device) + scale = scale.to(cpu_device) if scale is not None else scale + zp = zp.to(cpu_device) if zp is not None and isinstance(zp, torch.Tensor) else zp + wmin = wmin.to(cpu_device) if wmin is not None else wmin + d_scale = d_scale.to(cpu_device) if d_scale is not None else d_scale + d_wmin = d_wmin.to(cpu_device) if d_wmin is not None else d_wmin + imatrix = imatrix.to(cpu_device) if imatrix is not None else imatrix + clear_memory(device_list=orig_device) + new_data = quant_func( + blocks, scale, zp=zp, wmin=wmin, d_scale=d_scale, d_wmin=d_wmin, imatrix=imatrix, original=original + ) + return new_data + + def ggml_quant( data, ggml_type, @@ -42,38 +71,44 @@ def ggml_quant( original=False, ): block_size, type_size = GGML_QUANT_SIZES[ggml_type] - data = data.to(torch.float32).to(device) - scale = scale.to(device) if scale is not None else scale - zp = zp.to(device) if zp is not None and isinstance(zp, torch.Tensor) else zp - wmin = wmin.to(device) if wmin is not None else wmin - d_scale = d_scale.to(device) if d_scale is not None else d_scale - d_wmin = d_wmin.to(device) if d_wmin is not None else d_wmin - shape = data.shape n_blocks = data.nelement() // block_size + split_num = 16 if max(data.shape) > 100_000 else 1 blocks = data.reshape((n_blocks, block_size)) + scale = scale.to(device).reshape(blocks.shape[0], -1) if scale is not None else scale + zp = zp.to(device).reshape(blocks.shape[0], -1) if zp is not None and isinstance(zp, torch.Tensor) else zp + wmin = wmin.to(device).reshape(blocks.shape[0], -1) if wmin is not None else wmin + d_scale = 
d_scale.to(device).reshape(blocks.shape[0], -1) if d_scale is not None else d_scale + d_wmin = d_wmin.to(device).reshape(blocks.shape[0], -1) if d_wmin is not None else d_wmin quant_func = GGML_QUANT_TYPE[ggml_type] - try: - new_data = quant_func( - blocks, scale, zp=zp, wmin=wmin, d_scale=d_scale, d_wmin=d_wmin, imatrix=imatrix, original=original - ) - except Exception: - device = "cpu" - blocks = blocks.to(device) - scale = scale.to(device) if scale is not None else scale - zp = zp.to(device) if zp is not None and isinstance(zp, torch.Tensor) else zp - wmin = wmin.to(device) if wmin is not None else wmin - d_scale = d_scale.to(device) if d_scale is not None else d_scale - d_wmin = d_wmin.to(device) if d_wmin is not None else d_wmin - imatrix = imatrix.to(device) if imatrix is not None else imatrix - clear_memory() - new_data = quant_func( - blocks, scale, zp=zp, wmin=wmin, d_scale=d_scale, d_wmin=d_wmin, imatrix=imatrix, original=original - ) + results = [] + chunk_size = (n_blocks + split_num - 1) // split_num + if split_num > 1: + for i in range(split_num): + start = chunk_size * i + end = chunk_size * (i + 1) + tmp_blocks = blocks[start:end] + tmp_scale = scale[start:end] if scale is not None else scale + tmp_zp = zp[start:end] if zp is not None and isinstance(zp, torch.Tensor) else zp + tmp_wmin = wmin[start:end] if wmin is not None else wmin + tmp_d_scale = d_scale[start:end] if d_scale is not None else d_scale + tmp_d_wmin = d_wmin[start:end] if d_wmin is not None else d_wmin + new_data = ggml_quant_core( + quant_func, tmp_blocks, tmp_scale, tmp_zp, tmp_wmin, tmp_d_scale, tmp_d_wmin, imatrix, original + ) + results.append(new_data) + if split_num > 1: + clear_memory(device_list=device) + else: + new_data = ggml_quant_core(quant_func, blocks, scale, zp, wmin, d_scale, d_wmin, imatrix, original) + results.append(new_data) - assert new_data.shape[-1] == type_size - new_data = new_data.reshape(*shape[:-1], shape[-1] // block_size * type_size) + if len(results) 
== 1: + new_data = results[0] + else: + new_data = np.concatenate(results, axis=0) + new_data = new_data.reshape(*shape[:-1], shape[-1] // block_size * type_size) # Check shape correctness new_data = new_data.reshape(*shape[:-1], -1) return new_data @@ -81,10 +116,107 @@ def ggml_quant( def torch_roundf(n): a = torch.abs(n) floored = torch.floor(a) - b = floored + torch.floor(2 * (a - floored)) + b = floored + torch.floor((a - floored).mul_(2)) return torch.sign(n) * b +def make_qx_quants_chunk(data, bits, rmse_type=0, qw=None, split_num=1): + """ + Extreme VRAM-optimized version of quantization. + + - Processes data in chunks along the batch dimension (dim=0) to reduce peak memory usage. + - Uses inplace operations to avoid unnecessary tensor copies. + - Reuses buffers for temporary calculations wherever possible. + """ + nmax = 2 ** (bits - 1) + scales_list = [] + L_list = [] + chunk_size = (data.shape[0] + split_num - 1) // split_num + for start in range(0, data.shape[0], chunk_size): + end = min(start + chunk_size, data.shape[0]) + chunk = data[start:end] # Slice a batch chunk to reduce memory footprint + + # Compute absolute values inplace to avoid extra tensor allocation + chunk_abs = chunk.abs() + imax = chunk_abs.argmax(dim=-1, keepdim=True) + group_max = torch.take_along_dim(chunk, imax, dim=-1) + + # Compute scale factors (inverse max) without extra tensor + + iscales = -nmax * get_reciprocal(group_max) + + # L buffer stores quantized values, modified inplace to save memory + L = (chunk * iscales).round_().clamp_(-nmax, nmax - 1) + + # Simple case: rmse_type == 0 + if rmse_type == 0: + L.add_(nmax) # Shift to unsigned representation inplace + scales = (1 / iscales).reshape(iscales.shape[:2]) + scales_list.append(scales) + L_list.append(L.to(torch.uint8)) + continue + + return_early = False + if rmse_type < 0: + rmse_type = -rmse_type + return_early = True + + # Compute weighting tensor w based on rmse_type + if qw is not None: + w = qw + elif rmse_type == 
1: + w = chunk * chunk + elif rmse_type == 2: + w = torch.ones_like(chunk) + elif rmse_type == 3: + w = chunk.abs() + else: + w = chunk.abs().sqrt() + + # Compute sumlx and suml2 using the pre-allocated L buffer + sumlx = (w * chunk * L).sum(dim=-1) + suml2 = (w * L * L).sum(dim=-1) + scales = sumlx / suml2 + + if return_early: + iscales_inv = (1 / iscales).reshape(iscales.shape[:2]) + # Mix the current scale with inverse scale if suml2 > 0 + scales = torch.where(suml2 > 0, 0.5 * (scales + iscales_inv), iscales_inv) + L.add_(nmax) + scales_list.append(scales) + L_list.append(L.to(torch.uint8)) + continue + + # Iteratively refine scales and quantized values + best = scales * sumlx + for _is in range(-9, 10): + if _is == 0: + continue + iscales_tmp = -(nmax + -0.1 * _is) / group_max + # Use a temporary L buffer to avoid creating new large tensor + L_tmp = (chunk * iscales_tmp).round_().clamp_(-nmax, nmax - 1) + sumlx_tmp = (w * chunk * L_tmp).sum(dim=-1) + suml2_tmp = (w * L_tmp * L_tmp).sum(dim=-1) + # Determine which elements should be replaced + replace_id = (suml2_tmp > 0) & (sumlx_tmp * sumlx_tmp > best * suml2_tmp) + # Inplace update of L and scales + L[replace_id] = L_tmp[replace_id] + scales[replace_id] = sumlx_tmp[replace_id] / suml2_tmp[replace_id] + best[replace_id] = scales[replace_id] * sumlx_tmp[replace_id] + + L.add_(nmax) # Final shift to unsigned + scales_list.append(scales) + L_list.append(L.to(torch.uint8)) + + # Concatenate all chunks along batch dimension + if len(scales_list) > 1: + scales = torch.cat(scales_list, dim=0) + L = torch.cat(L_list, dim=0) + return scales, L + else: + return scales, L + + def make_qx_quants(data, bits, rmse_type=0, qw=None): """ adapted from llmacpp @@ -142,34 +274,78 @@ def make_qx_quants(data, bits, rmse_type=0, qw=None): def make_q3_quants(data, bits, do_rmse=False): - nmax = pow(2, bits - 1) + # Maximum absolute integer value for symmetric quantization + nmax = 1 << (bits - 1) # equivalent to pow(2, bits-1) + + # 
Find per-group max indices along last dim imax = abs(data).argmax(axis=-1, keepdims=True) + + # Gather group-wise maximum values group_max = torch.take_along_dim(data, imax, dim=-1) + + # Compute inverse scale in-place (multiplying by -nmax) iscale = -nmax * get_reciprocal(group_max) + if do_rmse: - L = torch.round(iscale * data).clip(-nmax, nmax - 1) - w = torch.pow(data, 2) + # Initial quantization L (in-place round and clamp) + L = torch.empty_like(data) + torch.round(iscale * data, out=L) + L.clamp_(-nmax, nmax - 1) + + # Weight for RMSE = x^2 (in-place) + w = data.clone().pow_(2) + + # Precompute sums sumlx = torch.sum(w * data * L, dim=-1) suml2 = torch.sum(w * L * L, dim=-1) - for itry in range(5): + # Iterative RMSE refinement + for _ in range(5): for i in range(sumlx.shape[-1]): - w_tmp, data_tmp, L_tmp = w[:, :, i], data[:, :, i], L[:, :, i] + # Extract current slice + w_tmp = w[:, :, i] + data_tmp = data[:, :, i] + L_tmp = L[:, :, i] + + # Exclude current slice from sums slx = sumlx - w_tmp * data_tmp * L_tmp replace_idx = slx > 0 - sl2 = suml2 - w_tmp * torch.pow(L_tmp, 2) - new_L = torch.round(data_tmp * sl2 / slx).clip(-nmax, nmax - 1) + sl2 = suml2 - w_tmp * L_tmp * L_tmp + + # Compute new L candidate (in-place round and clamp) + new_L = torch.empty_like(L_tmp) + torch.round(data_tmp * sl2 / slx, out=new_L) + new_L.clamp_(-nmax, nmax - 1) + + # Identify positions to update tmp_replace_idx = replace_idx & (new_L != L_tmp) + + # Update sums where L changes slx[tmp_replace_idx] += w_tmp[tmp_replace_idx] * data_tmp[tmp_replace_idx] * new_L[tmp_replace_idx] sl2[tmp_replace_idx] += w_tmp[tmp_replace_idx] * new_L[tmp_replace_idx] * new_L[tmp_replace_idx] + + # Further check condition for improvement replace_idx &= (sl2 > 0) & (slx * slx * suml2 > sumlx * sumlx * sl2) - L[:, :, i][replace_idx] = new_L[replace_idx] + + # Update L in-place + L_tmp[replace_idx] = new_L[replace_idx] + + # Update global sums sumlx = slx suml2 = sl2 + + # Compute final scale and 
return quantized L return sumlx * get_reciprocal(suml2), L.to(torch.uint8) - L = torch.round(iscale * data).clip(-nmax, nmax - 1) + nmax + # Fast path: quantize without RMSE (in-place round, clamp, shift) + L = torch.empty_like(data) + torch.round(iscale * data, out=L) + L.clamp_(-nmax, nmax - 1) + L.add_(nmax) + + # Compute scales (reciprocal of iscale) scales = get_reciprocal(iscale).reshape(iscale.shape[:2]) + return scales, L.to(torch.uint8) @@ -248,10 +424,6 @@ def make_qkx2_quants(data, bits, weights=None, rmin=-1.0, rdelta=0.1, nstep=20, return scale.reshape(scale.shape[:2]), L, the_mins.reshape(the_mins.shape[:2]) -def make_qkx3_quants(data, bits, weights, rmin=-1.0, rdelta=0.1, nstep=20, use_mad=False): - return make_qkx2_quants(data, bits, weights, rmin=rmin, rdelta=rdelta, nstep=nstep, use_mad=use_mad) - - def make_qp_quants(nmax, data, quant_weights): group_max = torch.max(data, dim=-1, keepdim=True)[0] scale = group_max / nmax @@ -320,10 +492,9 @@ def q4_0_quant_block(blocks, scale=None, zp=None, **kwargs): max = torch.take_along_dim(blocks, imax, dim=-1) d = max / -8 id = get_reciprocal(d) - - qs = torch.trunc(blocks.to(torch.float64) * id.to(torch.float64) + 8.5).clip(0, 15).to(torch.uint8) - n_blocks = blocks.shape[0] + qs = torch.trunc(blocks.to(torch.float64).mul_(id.to(torch.float64)).add_(8.5)).clamp_(0, 15).to(torch.uint8) + block_size = GGML_QUANT_SIZES["q4_0"][0] qs = qs.reshape((n_blocks, 2, block_size // 2)).cpu().numpy() qs = qs[..., 0, :] | (qs[..., 1, :] << 4) @@ -343,10 +514,10 @@ def q4_1_quant_block(blocks, scale=None, zp=None, **kwargs): min = blocks.min(axis=-1, keepdims=True)[0] d = (max - min) / 15 id = get_reciprocal(d) + n_blocks = blocks.shape[0] - qs = torch.trunc((blocks - min) * id + 0.5).clip(0, 15).to(torch.uint8) + qs = torch.trunc(blocks.sub_(min).mul_(id).add_(0.5)).clamp_(0, 15).to(torch.uint8) - n_blocks = blocks.shape[0] block_size = GGML_QUANT_SIZES["q4_1"][0] qs = qs.reshape((n_blocks, 2, block_size // 
2)).cpu().numpy() qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) @@ -370,7 +541,13 @@ def q5_0_quant_block(blocks: np.array, scale=None, zp=None, **kwargs): block_size = GGML_QUANT_SIZES["q5_0"][0] # FIXME: Q5_0's reference rounding is cursed and depends on FMA - q = torch.trunc(blocks.to(torch.float64) * id.to(torch.float64) + 16.5).clip(0, 31).to(torch.uint8).cpu().numpy() + q = ( + torch.trunc(blocks.to(torch.float64).mul_(id.to(torch.float64)).add_(16.5)) + .clamp_(0, 31) + .to(torch.uint8) + .cpu() + .numpy() + ) qs = q.reshape((n_blocks, 2, block_size // 2)) qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4)) @@ -396,7 +573,7 @@ def q5_1_quant_block(blocks: np.array, scale=None, zp=None, **kwargs): block_size = GGML_QUANT_SIZES["q5_1"][0] id = get_reciprocal(d) - q = torch.trunc((blocks - min) * id + 0.5).clip(0, 31).to(torch.uint8).cpu().numpy() + q = torch.trunc(blocks.sub_(min).mul_(id).add_(0.5)).clamp_(0, 31).to(torch.uint8).cpu().numpy() qs = q.reshape((n_blocks, 2, block_size // 2)) qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4)) @@ -416,8 +593,8 @@ def q8_0_quant_block(blocks, scale=None, zp=None, **kwargs) -> np.ndarray: else: d = torch.abs(blocks).max(dim=1, keepdim=True)[0] / 127 id = get_reciprocal(d) - - qs = torch.clip(torch_roundf(blocks * id), -128, 127) + blocks = blocks.mul(id) + qs = torch_roundf(blocks).clamp_(-128, 127) # (n_blocks, 2) d = d.cpu().numpy().astype(np.float16).view(np.uint8) @@ -430,7 +607,7 @@ def q8_0_quant_block(blocks, scale=None, zp=None, **kwargs) -> np.ndarray: @register_qtype("q2_k") def q2_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, imatrix=None, original=False, **kwargs): nb = blocks.shape[0] - + device = blocks.device blocks = blocks.reshape((nb, QK_K // 16, 16)) # (nb, 16, 16) if scale is not None: @@ -438,9 +615,9 @@ def q2_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, i mins = wmin.reshape((-1, QK_K // 16)) 
output_d = d_scale.reshape(-1, 1).to(torch.float32) output_dmin = d_wmin.reshape(-1, 1).to(torch.float32) - output_scale = torch.round(scales * get_reciprocal(output_d)).clip(0, 15).to(torch.uint8) - output_scale |= torch.round(mins * get_reciprocal(output_dmin)).clip(0, 15).to(torch.uint8) << 4 - all_L = torch.round((blocks + mins.unsqueeze(-1)) / scales.unsqueeze(-1)).clip(0, 3).to(torch.uint8) + output_scale = (scales * get_reciprocal(output_d)).round_().clamp_(0, 15).to(torch.uint8) + output_scale |= (mins * get_reciprocal(output_dmin)).round_().clamp_(0, 15).to(torch.uint8) << 4 + all_L = blocks.add_(mins.unsqueeze(-1)).div_(scales.unsqueeze(-1)).round_().clamp_(0, 3).to(torch.uint8) elif original: scales, all_L, mins = make_qkx2_quants(blocks, bits=2, rmin=-0.5, rdelta=0.1, nstep=15, use_mad=True) max_scales = torch.max(scales, dim=-1, keepdim=True)[0] @@ -466,10 +643,14 @@ def q2_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, i replace_ids = d_tmp != 0 all_L[replace_ids] = ( - torch.round((blocks[replace_ids] + dm_tmp[replace_ids].unsqueeze(-1)) / d_tmp[replace_ids].unsqueeze(-1)) - .clip(0, 3) + blocks[replace_ids] + .add_(dm_tmp[replace_ids].unsqueeze(-1)) + .div_(d_tmp[replace_ids].unsqueeze(-1)) + .round_() + .clamp_(0, 3) .to(torch.uint8) ) + else: from auto_round.data_type.gguf import quant_tensor_gguf_asym_dq @@ -477,15 +658,14 @@ def q2_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, i blocks, scales, mins = quant_tensor_gguf_asym_dq(blocks, bits=2, scale_dtype=torch.float32, imatrix=imatrix) scales, d_scale = scales["scale"], scales["d_scale"] mins, d_wmin = mins["wmin"], mins["d_wmin"] - blocks = blocks.reshape((nb, QK_K // 16, 16)) scales = scales.reshape((-1, QK_K // 16)) mins = mins.reshape((-1, QK_K // 16)) output_d = d_scale.reshape(-1, 1).to(torch.float32) output_dmin = d_wmin.reshape(-1, 1).to(torch.float32) - output_scale = torch.round(scales * get_reciprocal(output_d)).clip(0, 
15).to(torch.uint8) - output_scale |= torch.round(mins * get_reciprocal(output_dmin)).clip(0, 15).to(torch.uint8) << 4 - all_L = torch.round((blocks + mins.unsqueeze(-1)) / scales.unsqueeze(-1)).clip(0, 3).to(torch.uint8) + output_scale = scales.mul(get_reciprocal(output_d)).round_().clamp_(0, 15).to(torch.uint8) + output_scale |= (mins * get_reciprocal(output_dmin)).round_().clamp_(0, 15).to(torch.uint8) << 4 + all_L = blocks.add_(mins.unsqueeze(-1)).div_(scales.unsqueeze(-1)).round_().clamp_(0, 3).to(torch.uint8) output_scale = output_scale.cpu().numpy() all_L = all_L.reshape(-1, 4, 32) @@ -505,36 +685,37 @@ def q3_k_quant_block(blocks: np.array, scale=None, d_scale=None, original=False, nb = blocks.shape[0] blocks = blocks.reshape(nb, QK_K // 16, 16) - output_scale = np.empty((nb, K_SCALE_SIZE), dtype=np.uint8) - if scale is not None: qdq_scale = scale.reshape(-1, QK_K // 16).to(torch.float32) dq_scale = d_scale.reshape(-1, 1).to(torch.float32) - all_L = (torch.round(blocks * get_reciprocal(qdq_scale.unsqueeze(-1))).clip(-4, 3) + 4).to(torch.uint8) - q_scales_offset = torch.round(qdq_scale * get_reciprocal(dq_scale)).clip(-32, 31) + 32 - elif original: ## this is correct + all_L = blocks.mul(get_reciprocal(qdq_scale.unsqueeze(-1))).round_().clamp_(-4, 3).add_(4).to(torch.uint8) + q_scales_offset = (qdq_scale * get_reciprocal(dq_scale)).round_().clamp_(-32, 31).add_(32) + elif original: scales, _ = make_q3_quants(blocks, bits=3, do_rmse=True) scales_abs_max = abs(scales).argmax(dim=-1, keepdim=True) max_scales_mag = torch.take_along_dim(scales, scales_abs_max, dim=-1) inverse_dq_scale = -32 * get_reciprocal(max_scales_mag) dq_scale = get_reciprocal(inverse_dq_scale) - qscale = torch.round(inverse_dq_scale * scales).clip(-32, 31) + qscale = (inverse_dq_scale * scales).round_().clamp_(-32, 31) qdq_scale = dq_scale.to(torch.float32) * qscale reverse_qdq_scale = get_reciprocal(qdq_scale) - all_L = (torch.round(blocks * reverse_qdq_scale.unsqueeze(-1)).clip(-4, 3) + 
4).to(torch.uint8) - q_scales_offset = torch.round(qdq_scale * inverse_dq_scale).clip(-32, 31) + 32 + all_L = blocks.mul_(reverse_qdq_scale.unsqueeze(-1)).round_().clamp_(-4, 3).add_(4).to(torch.uint8) + q_scales_offset = (qdq_scale * inverse_dq_scale).round_().clamp_(-32, 31).add_(32) else: from auto_round.data_type.gguf import quant_tensor_gguf_sym_dq blocks = blocks.reshape(blocks.shape[0], -1) blocks, scales, _ = quant_tensor_gguf_sym_dq(blocks, bits=3, scale_dtype=torch.float32, imatrix=imatrix) scales, d_scale = scales["scale"], scales["d_scale"] + blocks = blocks.reshape((nb, QK_K // 16, 16)) qdq_scale = scales.reshape((-1, QK_K // 16)).to(torch.float32) dq_scale = d_scale.reshape(-1, 1).to(torch.float32) - all_L = (torch.round(blocks * get_reciprocal(qdq_scale.unsqueeze(-1))).clip(-4, 3) + 4).to(torch.uint8) - q_scales_offset = torch.round(qdq_scale * get_reciprocal(dq_scale)).clip(-32, 31) + 32 + all_L = blocks.mul_(get_reciprocal(qdq_scale.unsqueeze(-1))).round_().clamp_(-4, 3).add_(4).to(torch.uint8) + + q_scales_offset = (qdq_scale * get_reciprocal(dq_scale)).round_().clamp_(-32, 31).add_(32) + output_scale = np.empty((nb, K_SCALE_SIZE), dtype=np.uint8) q_scales_offset = q_scales_offset.cpu().numpy().astype(np.uint8) output_scale[:, :8] = (q_scales_offset[:, :8] & 0xF) | ((q_scales_offset[:, 8:] & 0xF) << 4) hmask = q_scales_offset >> 4 @@ -554,24 +735,27 @@ def q3_k_quant_block(blocks: np.array, scale=None, d_scale=None, original=False, @register_qtype("q4_k") -def q4_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, imatrix=None, original=False, **kwargs): +def q4_k_quant_block( + blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, imatrix=None, original=False, split_num=1, **kwargs +): nb = blocks.shape[0] blocks = blocks.reshape((nb, QK_K // 32, 32)) - output_scale = torch.empty((nb, K_SCALE_SIZE), dtype=torch.uint8, device=blocks.device) - if scale is not None: scales = scale.reshape(-1, QK_K // 32) mins = 
wmin.reshape(-1, QK_K // 32) output_d = d_scale.reshape(-1, 1).to(torch.float32) output_dmin = d_wmin.reshape(-1, 1).to(torch.float32) - q_scales = torch.round(scales * get_reciprocal(output_d)).clip(0, 63).to(torch.uint8) - q_mins = torch.round(mins * get_reciprocal(output_dmin)).clip(0, 63).to(torch.uint8) + q_scales = (scales * get_reciprocal(output_d)).round_().clamp_(0, 63).to(torch.uint8) + q_mins = (mins * get_reciprocal(output_dmin)).round_().clamp_(0, 63).to(torch.uint8) all_L = ( - torch.round((blocks + mins.unsqueeze(-1)) * get_reciprocal(scales.unsqueeze(-1))) - .clip(0, 15) + blocks.add_(mins.unsqueeze(-1)) + .mul_(get_reciprocal(scales.unsqueeze(-1))) + .round_() + .clamp_(0, 15) .to(torch.uint8) ) + elif original: scales, all_L, mins = make_qkx2_quants(blocks, bits=4, rmin=-1, rdelta=0.1, nstep=20, use_mad=False) max_scales = torch.max(scales, dim=-1, keepdim=True)[0] @@ -580,17 +764,21 @@ def q4_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, i id_mins = (63 * get_reciprocal(max_mins)).clamp(min=0) output_d = max_scales / 63 output_dmin = max_mins / 63 - q_scales = torch.round(id_scales * scales).clip(0, 63).to(torch.uint8) - q_mins = torch.round(id_mins * mins).clip(0, 63).to(torch.uint8) + q_scales = (id_scales * scales).round_().clamp_(0, 63).to(torch.uint8) + q_mins = (id_mins * mins).round_().clamp_(0, 63).to(torch.uint8) d_tmp = output_d * q_scales dm_tmp = output_dmin * q_mins replace_ids = d_tmp != 0 all_L[replace_ids] = ( - torch.round((blocks[replace_ids] + dm_tmp[replace_ids].unsqueeze(-1)) / d_tmp[replace_ids].unsqueeze(-1)) - .clip(0, 15) + blocks[replace_ids] + .add_(dm_tmp[replace_ids].unsqueeze(-1)) + .div_(d_tmp[replace_ids].unsqueeze(-1)) + .round_() + .clamp_(0, 15) .to(torch.uint8) ) + else: from auto_round.data_type.gguf import quant_tensor_gguf_asym_dq @@ -604,14 +792,16 @@ def q4_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, i mins = mins.reshape((-1, QK_K // 32)) output_d = 
d_scale.reshape(-1, 1).to(torch.float32) output_dmin = d_wmin.reshape(-1, 1).to(torch.float32) - q_scales = torch.round(scales * get_reciprocal(output_d)).clip(0, 63).to(torch.uint8) - q_mins = torch.round(mins * get_reciprocal(output_dmin)).clip(0, 63).to(torch.uint8) + q_scales = (scales * get_reciprocal(output_d)).round_().clamp_(0, 63).to(torch.uint8) + q_mins = (mins * get_reciprocal(output_dmin)).round_().clamp_(0, 63).to(torch.uint8) all_L = ( - torch.round((blocks + mins.unsqueeze(-1)) * get_reciprocal(scales.unsqueeze(-1))) - .clip(0, 15) + blocks.add_(mins.unsqueeze(-1)) + .mul_(get_reciprocal(scales.unsqueeze(-1))) + .round_() + .clamp_(0, 15) .to(torch.uint8) ) - + output_scale = torch.empty((nb, K_SCALE_SIZE), dtype=torch.uint8, device=blocks.device) output_scale[:, :4] = q_scales[:, :4] output_scale[:, 4:8] = q_mins[:, :4] @@ -634,25 +824,34 @@ def q4_k_quant_block(blocks, scale=None, wmin=None, d_scale=None, d_wmin=None, i @register_qtype("q5_k") def q5_k_quant_block( - blocks, scale=None, zp=None, wmin=None, d_scale=None, d_wmin=None, imatrix=None, original=False, **kwargs + blocks, + scale=None, + zp=None, + wmin=None, + d_scale=None, + d_wmin=None, + imatrix=None, + original=False, + **kwargs, ): nb = blocks.shape[0] blocks = blocks.reshape((nb, QK_K // 32, 32)) - output_scale = torch.empty((nb, K_SCALE_SIZE), dtype=torch.uint8, device=blocks.device) - if scale is not None: scales = scale.reshape(-1, QK_K // 32) mins = wmin.reshape(-1, QK_K // 32) output_d = d_scale.reshape(-1, 1).to(torch.float32) output_dmin = d_wmin.reshape(-1, 1).to(torch.float32) - q_scales = torch.round(scales * get_reciprocal(output_d)).clip(0, 63).to(torch.uint8) - q_mins = torch.round(mins * get_reciprocal(output_dmin)).clip(0, 63).to(torch.uint8) + q_scales = (scales * get_reciprocal(output_d)).round_().clamp_(0, 63).to(torch.uint8) + q_mins = (mins * get_reciprocal(output_dmin)).round_().clamp_(0, 63).to(torch.uint8) all_L = ( - torch.round((blocks + mins.unsqueeze(-1)) 
* get_reciprocal(scales.unsqueeze(-1))) - .clip(0, 31) + blocks.add_(mins.unsqueeze(-1)) + .mul_(get_reciprocal(scales.unsqueeze(-1))) + .round_() + .clamp_(0, 31) .to(torch.uint8) ) + elif original: scales, all_L, mins = make_qkx2_quants(blocks, bits=5, rmin=-0.5, rdelta=0.1, nstep=15, use_mad=False) max_scales = torch.max(scales, dim=-1, keepdim=True)[0] @@ -661,15 +860,18 @@ def q5_k_quant_block( id_mins = (63 * get_reciprocal(max_mins)).clamp(min=0) output_d = max_scales / 63 output_dmin = max_mins / 63 - q_scales = torch.round(id_scales * scales).clip(0, 63).to(torch.uint8) - q_mins = torch.round(id_mins * mins).clip(0, 63).to(torch.uint8) + q_scales = (id_scales * scales).round_().clamp_(0, 63).to(torch.uint8) + q_mins = (id_mins * mins).round_().clamp_(0, 63).to(torch.uint8) d_tmp = output_d * q_scales dm_tmp = output_dmin * q_mins replace_ids = d_tmp != 0 all_L[replace_ids] = ( - torch.round((blocks[replace_ids] + dm_tmp[replace_ids].unsqueeze(-1)) / d_tmp[replace_ids].unsqueeze(-1)) - .clip(0, 31) + blocks[replace_ids] + .add_(dm_tmp[replace_ids].unsqueeze(-1)) + .div_(d_tmp[replace_ids].unsqueeze(-1)) + .round_() + .clamp_(0, 31) .to(torch.uint8) ) else: @@ -685,13 +887,16 @@ def q5_k_quant_block( mins = mins.reshape((-1, QK_K // 32)) output_d = d_scale.reshape(-1, 1).to(torch.float32) output_dmin = d_wmin.reshape(-1, 1).to(torch.float32) - q_scales = torch.round(scales * get_reciprocal(output_d)).clip(0, 63).to(torch.uint8) - q_mins = torch.round(mins * get_reciprocal(output_dmin)).clip(0, 63).to(torch.uint8) + q_scales = (scales * get_reciprocal(output_d)).round_().clamp_(0, 63).to(torch.uint8) + q_mins = (mins * get_reciprocal(output_dmin)).round_().clamp_(0, 63).to(torch.uint8) all_L = ( - torch.round((blocks + mins.unsqueeze(-1)) * get_reciprocal(scales.unsqueeze(-1))) - .clip(0, 31) + blocks.add_(mins.unsqueeze(-1)) + .mul_(get_reciprocal(scales.unsqueeze(-1))) + .round_() + .clamp_(0, 31) .to(torch.uint8) ) + output_scale = torch.empty((nb, 
K_SCALE_SIZE), dtype=torch.uint8, device=blocks.device) output_scale[:, :4] = q_scales[:, :4] output_scale[:, 4:8] = q_mins[:, :4] @@ -721,23 +926,26 @@ def q5_k_quant_block( def q6_k_quant_block(blocks: np.array, scale=None, d_scale=None, original=False, imatrix=None, **kwargs): nb = blocks.shape[0] blocks = blocks.reshape((nb, QK_K // 16, 16)) - + device = blocks.device if scale is not None: scales = scale.reshape(-1, QK_K // 16) output_d = d_scale.reshape(-1, 1).to(torch.float32) - output_scale = torch.round(scales * get_reciprocal(output_d)).clip(max=127).to(torch.int8) - all_L = torch.round(blocks * get_reciprocal(scales.unsqueeze(-1)) + 32).clip(0, 63).to(torch.uint8) + rd = get_reciprocal(output_d) + output_scale = scales.mul(rd).round_().clamp_(max=127).to(torch.int8) + rs = get_reciprocal(scales).unsqueeze_(-1) # unsqueeze for broadcasting + all_L = blocks.mul_(rs).add_(32).round_().clamp_(0, 63).to(torch.uint8) elif original: scales, all_L = make_qx_quants(blocks, bits=6, rmse_type=1, qw=None) imax = abs(scales).argmax(dim=-1, keepdim=True) max_scales = torch.take_along_dim(scales, imax, dim=-1) + iscales = -128 * get_reciprocal(max_scales) output_d = get_reciprocal(iscales) - output_scale = torch.round(iscales * scales).clip(max=127).to(torch.int8) + output_scale = (iscales * scales).round_().clamp_(max=127).to(torch.int8) d_tmp = output_d * output_scale.to(torch.float32) replace_ids = d_tmp != 0 all_L[replace_ids] = ( - torch.round(blocks[replace_ids] / d_tmp[replace_ids].reshape(-1, 1) + 32).clip(0, 63).to(torch.uint8) + blocks[replace_ids].div(d_tmp[replace_ids]).reshape(-1, 1).add_(32).round_().clamp_(0, 63).to(torch.uint8) ) else: from auto_round.data_type.gguf import quant_tensor_gguf_sym_dq @@ -745,11 +953,12 @@ def q6_k_quant_block(blocks: np.array, scale=None, d_scale=None, original=False, blocks = blocks.reshape(blocks.shape[0], -1) blocks, scales, _ = quant_tensor_gguf_sym_dq(blocks, bits=6, scale_dtype=torch.float32, imatrix=imatrix) scales, 
d_scale = scales["scale"], scales["d_scale"] + blocks = blocks.reshape((nb, QK_K // 16, 16)) scales = scales.reshape((-1, QK_K // 16)) output_d = d_scale.reshape(-1, 1).to(torch.float32) - output_scale = torch.round(scales * get_reciprocal(output_d)).clip(max=127).to(torch.int8) - all_L = torch.round(blocks * get_reciprocal(scales.unsqueeze(-1)) + 32).clip(0, 63).to(torch.uint8) + output_scale = (scales * get_reciprocal(output_d)).round_().clamp_(max=127).to(torch.int8) + all_L = blocks.mul_(get_reciprocal(scales.unsqueeze(-1))).add_(32).round_().clamp_(0, 63).to(torch.uint8) tmp_L = all_L.reshape(nb, 4, 64) & 0xF output_ql = (tmp_L[:, ::2] | (tmp_L[:, 1::2] << 4)).reshape(nb, QK_K // 2).cpu().numpy().astype(np.uint8) diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py index 6b1717e7b..3241f0cb1 100644 --- a/auto_round/utils/common.py +++ b/auto_round/utils/common.py @@ -308,11 +308,27 @@ def json_serialize(obj: Any): def get_reciprocal(tensor): - if torch.dtype is torch.float16: - tensor = torch.sign(tensor) * torch.clamp(torch.abs(tensor), min=1e-5) - else: - tensor = torch.where(torch.abs(tensor) < 1e-30, 0, tensor) - return torch.where(tensor != 0, 1 / tensor, torch.zeros_like(tensor)) + """ + Memory-frugal reciprocal: + - Inplace operations on original tensor + - Only allocates small boolean mask + """ + eps = 1e-5 if tensor.dtype == torch.float16 else 1e-30 + + # Create mask for very small elements (small overhead) + mask = tensor.abs() < eps + + # Prepare output in place: reuse tensor if allowed, otherwise create once + recip = torch.empty_like(tensor) + + # Safe reciprocal: for nonzero elements + nonzero_mask = ~mask + recip[nonzero_mask] = 1.0 / tensor[nonzero_mask] + + # Zero out elements below threshold + recip[mask] = 0.0 + + return recip def normalize_input(decoding_layer_inputs: list[tuple[Any]]) -> Tuple[List[torch.Tensor], Dict[str, Any]]: diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py index 
38674c508..12c904b3e 100644 --- a/auto_round/utils/device.py +++ b/auto_round/utils/device.py @@ -409,34 +409,57 @@ def bytes_to_gigabytes(bytes) -> int: def _clear_memory_for_cpu_and_cuda( - tensor: torch.Tensor | list[torch.Tensor] | None = None, device_list: tuple | list | None = None + tensor: torch.Tensor | list[torch.Tensor] | None = None, + device_list: tuple | list | str | torch.device | None = None, ): + # ------------------------ + # Clear CPU-side references + # ------------------------ if isinstance(tensor, list): for i in range(len(tensor)): tensor[i] = None - if tensor is not None: - del tensor + tensor = None gc.collect() + + # ------------------------ + # Normalize device_list + # ------------------------ + if isinstance(device_list, (str, torch.device)): + device_list = [device_list] + + # ----------------------------------- + # CUDA-specific clearing + # ----------------------------------- if torch.cuda.is_available(): - if device_list is None: - torch.cuda.synchronize() + # No device_list → clear all GPUs + if not device_list: # Fix https://github.com/intel/auto-round/issues/1004 + torch.cuda.synchronize() torch.cuda.empty_cache() - - elif len(device_list) > 1: + else: + # Parse valid CUDA device IDs devices = [] - for device in device_list: - if not device.startswith("cuda"): + for dev in device_list: + dev = str(dev) + if not dev.startswith("cuda"): continue - if ":" in device: - device = device.split(":")[-1] + # cuda / cuda:0 / cuda:1 + if ":" in dev: + devid = int(dev.split(":")[-1]) else: - device = 0 - devices.append(int(device)) - for device in devices: - torch.cuda.synchronize(device) + devid = 0 + devices.append(devid) + + for d in devices: + torch.cuda.synchronize(d) + torch.cuda.empty_cache() - if torch.xpu.is_available(): + + # ----------------------------------- + # XPU-specific clearing + # ----------------------------------- + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.synchronize() torch.xpu.empty_cache() 
@@ -1344,7 +1367,7 @@ def update(self, device_list=None): process = psutil.Process() current_ram = process.memory_info().rss / 1024**3 # GB self.peak_ram = max(self.peak_ram, current_ram) - if device_list is None: # TODO this have issue, wait for clean memory all pass device_list + if device_list is None: # TODO this has issue, wait for clean_memory all pass device_list device_list = [0] if device_list is not None: if not isinstance(device_list, (list, tuple)): @@ -1356,12 +1379,16 @@ def update(self, device_list=None): device_list = list(range(torch.xpu.device_count())) for device in device_list: - if device == "cpu": + if str(device) == "cpu": continue if torch.cuda.is_available(): current_vram = torch.cuda.memory_reserved(device) / 1024**3 # GB + if device == "cuda": + device = "0" elif torch.xpu.is_available(): current_vram = torch.xpu.memory_reserved(device) / 1024**3 # GB + if device == "xpu": + device = "0" else: return diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py index ff9b3b57c..1c2fc7987 100644 --- a/auto_round/utils/model.py +++ b/auto_round/utils/model.py @@ -29,6 +29,33 @@ from auto_round.schemes import QuantizationScheme +def clean_module_parameter(submodule: torch.nn.Module, param_name: str) -> None: + """This function is recommended to be used instead of module.weight = None. + For models like `tie_word_embeddings`, setting the embedding weight to None + causes `lm_head` to reallocate memory for its weight instead of treating it as a "bound shared weight," + it's now iterated over as an independent parameter, + resulting in an additional `lm_head` parameter in `named_parameters`. 
+ + Args: + submodule (torch.nn.Module): submodule to clean + param_name (str): "weight" or "bias" + """ + if submodule is None: + return + is_buffer = param_name in submodule._buffers + with torch.no_grad(): + if is_buffer: + buf = submodule._buffers[param_name] + if buf is not None: + buf.data = torch.empty(0, dtype=buf.dtype, device=buf.device) + buf.requires_grad = False + else: + param = submodule._parameters[param_name] + if param is not None: + param.data = torch.empty(0, dtype=param.dtype, device=param.device) + param.requires_grad = False + + def convert_dtype_str2torch(str_dtype): """Converts a string dtype to its corresponding PyTorch dtype. diff --git a/auto_round/wrapper.py b/auto_round/wrapper.py index 4d2c7d5fd..4b599d3a5 100644 --- a/auto_round/wrapper.py +++ b/auto_round/wrapper.py @@ -319,7 +319,7 @@ def unwrapper(self, best_params): if self.orig_layer.weight.device.type == "meta": self.orig_layer.to(self.device) - ##unwrapper weight + # Unwrapper weight qdq_weight, scale, zp = self._qdq_weight(v, min_scale, max_scale) # if hasattr(self.orig_layer, "imatrix"): # self.orig_layer.imatrix = None @@ -380,7 +380,7 @@ def _set_dict_attr(attr_dict, attr_name): self.orig_layer.update() self.orig_layer.to("meta") - ##unwrapper act + # Unwrapper act if self.enable_act_quant: if not self.orig_layer.act_dynamic: act_max_scale = best_params.get("act_max_scale", torch.tensor(1.0)).to(self.device) diff --git a/docs/step_by_step.md b/docs/step_by_step.md index 29924bc79..e8c1dca8d 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -559,6 +559,7 @@ autoround.save_quantized(format="auto_awq", output_dir="tmp_autoround") - **Reduced CPU Memory Usage :** + - Enable low_cpu_mem_usage (experimental): Only one export format is supported. The quantized model is saved immediately after each block is packed, reducing peak CPU memory usage. 
- Trigger immediate packing: Packing will be triggered immediately when using the command-line interface or the quantize_and_save API, as long as only one export format is specified.