diff --git a/README.md b/README.md
index 29eeae56f..598c3fd54 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@ See our [paper](https://arxiv.org/pdf/2309.05516) for more details. For usage in
 
 ## 🆕 What's New
 
-[2025/11] AutoRound now offers preliminary support for an **enhanced GGUF quantization algorithm** via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the accompanying [documentation](./docs/gguf_alg_ext_acc.md).
+[2025/11] AutoRound now offers preliminary support for an enhanced GGUF quantization algorithm via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the [documentation](./docs/gguf_alg_ext_acc.md).
 
 [2025/10] AutoRound has been integrated into **SGLang**. You can now run models in the AutoRound format directly using the latest SGLang later than v0.5.4.
 
@@ -46,8 +46,7 @@ refer to the documentation for accuracy [results](./docs/auto_scheme_acc.md) and
 for some accuracy results.
 
 [2025/07] AutoRound now offers experimental support for **GGUF** format, and recommends using optimized RTN mode (--iters 0) for
-  all bits other than 3 bits. **A more advanced algorithm** tailored for specific configurations may be available in
-  v0.8.1.
+  all bits other than 3 bits.
 
 [2025/05] AutoRound has been integrated into **Transformers** and **vLLM**.
 
@@ -192,7 +191,7 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
 - **`layer_config` (dict)**: Configuration for weight quantization (default is `None`), mainly for mixed schemes.
 
 ##### Algorithm Settings
-- **`enable_alg_ext` (bool)**: Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
+- **`enable_alg_ext` (bool)**: [Experimental Feature] Enable algorithm variants for specific schemes (e.g., MXFP4/W2A16) that could bring notable improvements. Default is `False`.
 - **`disable_opt_rtn` (bool)**: Use pure RTN mode for specific schemes (e.g., GGUF and WOQ). Default is `False` (improved RTN enabled).
 
 ##### Tuning Process Parameters
@@ -208,6 +207,7 @@ ar.quantize_and_save(output_dir="./qmodel", format="auto_round")
 ##### Device/Speed Configuration
 - **`enable_torch_compile` (bool)**: If no exception is raised, typically we recommend setting it to True for faster quantization with lower resource.
 - **`low_gpu_mem_usage` (bool)**: Whether to offload intermediate features to CPU at the cost of ~20% more tuning time (default is `False`).
+- **`low_cpu_mem_usage` (bool)**: [Experimental Feature] Whether to enable immediate saving during quantization to reduce RAM usage (default is `False`).
 - **`device_map` (str|dict|int)**: The device to be used for tuning, e.g., `auto`, "cpu"`, `"cuda"`, `"0,1,2"` (default is `'0'`). When using "auto", it will try to use all available GPUs.
diff --git a/auto_round/__main__.py b/auto_round/__main__.py
index 76a8f73d1..19fef935a 100644
--- a/auto_round/__main__.py
+++ b/auto_round/__main__.py
@@ -172,6 +172,12 @@ def __init__(self, *args, **kwargs):
             type=float,
             help="Learning rate specifically for min-max tuning. " "If None, uses the same value as --lr. ",
         )
+        tuning.add_argument(
+            "--momentum",
+            default=0,
+            type=float,
+            help="Momentum factor for the optimizer. Default is 0 (no momentum).",
+        )
         tuning.add_argument(
             "--gradient_accumulate_steps",
             default=1,
@@ -591,6 +597,7 @@ def tune(args):
         extra_config=extra_config,
         layer_config=layer_config,
         model_dtype=args.model_dtype,
+        momentum=args.momentum,
     )
 
     model_name = args.model.rstrip("/")
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index cadf81d9d..871fbcb2c 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -193,7 +193,7 @@ def __init__(
             super_group_size, super_bits, scale_dtype ("fp16" etc.), nblocks, to_quant_block_names,
             enable_norm_bias_tuning, enable_quanted_input,
-            disable_deterministic_algorithms, mllm, static_kv_dtype,enable_deterministic_algorithms
+            disable_deterministic_algorithms, mllm, static_kv_dtype, enable_deterministic_algorithms, momentum
 
         Raises:
             ValueError: If invalid device is provided or tokenizer is missing for non-str model with iters > 0.
             RuntimeError: If model parameters are on meta device.
@@ -234,6 +234,7 @@ def __init__(
         enable_quanted_input: bool = kwargs.pop("enable_quanted_input", True)
         disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True)
         enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)
+        self.momentum = kwargs.pop("momentum", 0.0)
         static_kv_dtype = kwargs.pop("static_kv_dtype", None)
         model_dtype = kwargs.pop("model_dtype", None)
         device = kwargs.pop("device", None)
@@ -1567,11 +1568,12 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         # It is best to modify the model structure in the quantize function and check the format,
         # because it may cause the gguf format to not be exported normally.
         self.model = _handle_moe_model(self.model, formats=formats)
-        # Assign temporary names after replacing modules
-        for n, m in self.model.named_modules():  # TODO check if could removed
+
+        # Temporary names must be assigned after _handle_moe_model;
+        # placing them earlier would cause them to be removed when the module is replaced.
+        for n, m in self.model.named_modules():
             m.tmp_name = n
-        # TODO check scale_dtype
 
         if not self.is_auto_scheme:
             enable_gguf_official_mixed = True
         else:
@@ -2661,12 +2663,24 @@ def _quantize_block(
         lr = torch.tensor(self.lr)
         minmax_lr = torch.tensor(self.minmax_lr)
 
+        is_adam = "adam" in self.__class__.__name__.lower()
+
+        extra_kwargs = {} if is_adam else {"momentum": self.momentum}
+
         if self.enable_minmax_tuning:
-            optimizer = self.optimizer(
-                [{"params": round_params}, {"params": minmax_params, "lr": minmax_lr}], lr=lr, weight_decay=0
-            )
+            params = [
+                {"params": round_params},
+                {"params": minmax_params, "lr": minmax_lr},
+            ]
         else:
-            optimizer = self.optimizer(round_params, lr=lr, weight_decay=0)
+            params = round_params
+
+        optimizer = self.optimizer(
+            params,
+            lr=lr,
+            weight_decay=0,
+            **extra_kwargs,
+        )
 
         if len(round_params) + len(minmax_params) <= 0:
             dump_info = (
diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py
index f20c6c7a6..577ccf34e 100644
--- a/auto_round/data_type/gguf.py
+++ b/auto_round/data_type/gguf.py
@@ -21,6 +21,7 @@
 from auto_round.export.export_to_gguf.packing import make_q3_quants, make_qx_quants
 from auto_round.logger import logger
 from auto_round.utils import get_reciprocal
+from auto_round.utils.device import clear_memory
 
 
 @register_dtype("int_sym_dq")
@@ -320,7 +321,7 @@ def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tens
 
 @torch.no_grad()
-def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None):
+def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None, split_num=1):
     super_bits = 4 if bits == 2 else 6
     super_group_size = 16 if bits == 2 else 8
     group_size = 16 if bits == 2 else 32
@@ -348,6 +349,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri
             nstep=params["nstep"],
             use_mad=params["use_mad"],
             weights=quant_weights,
+            split_num=split_num,
         )
     scale = scale.to(scale_dtype)
     scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
@@ -428,16 +430,8 @@ def quant_tensor_gguf_asym_dq(
     Args:
         tensor (torch.Tensor): Input tensor to quantize.
         bits (int): Number of bits for quantization.
-        group_size (int): Group size for per-group quantization.
         v (float): Perturbation added before rounding.
-        min_scale (float): Minimum allowed scale value.
-        max_scale (float): Maximum allowed scale value.
         scale_dtype (torch.dtype): Data type for quantized scale.
-        tensor_min (torch.Tensor, optional): Minimum values for the tensor groups.
-        tensor_max (torch.Tensor, optional): Maximum values for the tensor groups.
-        q_scale_thresh (float): Threshold to clamp the quantized scale.
-        super_group_size (int): Number of groups to bundle for secondary quantization.
-        super_bits (int): Number of bits used in secondary quantization.
         imatrix (torch.Tensor, optional): Importance matrix for weighted quantization.
 
     Returns:
@@ -446,10 +440,19 @@ def quant_tensor_gguf_asym_dq(
     orig_dtype = tensor.dtype
     maxq = 2**bits - 1
     group_size = 16 if bits == 2 else 32
+    split_num = 1
+    for dim in tensor.shape:
+        if dim > 100_000:
+            split_num = 16
+            break
+
     tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size)
+    tensor = tensor.to(torch.float32)
     if scale is None:
-        scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(tensor, bits, scale_dtype, imatrix)
+        scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym(
+            tensor, bits, scale_dtype, imatrix, split_num=split_num
+        )
 
     inverse_scale = get_reciprocal(scale)
     int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
@@ -458,7 +461,7 @@ def quant_tensor_gguf_asym_dq(
     return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin}
 
 
-def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None):
+def iterative_wls_quant_search_non_chunk(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None):
     """Adapted from Llamacpp. Performs iterative weighted least squares quantization search.
 
     Args:
@@ -526,6 +529,112 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
     return scale.to(torch.float32), -rmin.to(torch.float32)
 
 
+# TODO consolidate iterative_wls_quant_search_chunk and non-chunk
+def iterative_wls_quant_search_chunk(
+    data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=8
+):
+    dtype = torch.float32
+    data = data.to(dtype)
+    maxq = 2**bits - 1
+    minq = 0
+    weights = 1.0 if weights is None else weights.to(dtype)
+
+    results_scale = []
+    results_rmin = []
+    chunk_size = (data.shape[0] + split_num - 1) // split_num
+    for start in range(0, data.shape[0], chunk_size):
+        end = min(start + chunk_size, data.shape[0])
+        chunk = data[start:end]
+        chunk_weights = weights if isinstance(weights, float) else weights[start:end]
+
+        rmin = torch.min(chunk, dim=1, keepdim=True)[0]
+        rmax = torch.max(chunk, dim=1, keepdim=True)[0]
+        sum_w = torch.sum(chunk_weights, dim=1, keepdim=True)
+        sum_x = torch.sum(chunk_weights * chunk, dim=1, keepdim=True)
+        scale = (rmax - rmin) / (maxq - minq)
+        iscale = get_reciprocal(scale)
+        quant_data = torch.clamp(torch.round(iscale * (chunk - rmin)), minq, maxq)
+        diff = scale * quant_data + rmin - chunk
+        best_mad = torch.sum(
+            (chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2), dim=1, keepdim=True
+        )
+
+        for is_ in range(nstep):
+            factor = rrmin + rdelta * is_ + maxq - minq
+            scale_new = (rmax - rmin) / factor
+            iscale_new = get_reciprocal(scale_new)
+            quant_data_new = torch.clamp(torch.round(iscale_new * (chunk - rmin)), minq, maxq)
+            mul_weights_quant_data = chunk_weights * quant_data_new
+            sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True)
+            sum_l2 = torch.sum(mul_weights_quant_data * quant_data_new, dim=-1, keepdim=True)
+            sum_xl = torch.sum(mul_weights_quant_data * chunk, dim=-1, keepdim=True)
+            D = sum_w * sum_l2 - torch.pow(sum_l, 2)
+            this_scale = (sum_w * sum_xl - sum_x * sum_l) / D
+            this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D
+            this_min[this_min > 0] = 0
+            this_scale[this_min > 0] = (sum_xl / sum_l2)[this_min > 0]
+            reverse_this_scale = get_reciprocal(this_scale)
+            quant_data = torch.clamp(torch.round(reverse_this_scale * (chunk - this_min)), minq, maxq)
+            diff = this_scale * quant_data + this_min - chunk
+            mad = torch.sum(
+                (chunk_weights * torch.abs(diff)) if use_mad else chunk_weights * torch.pow(diff, 2),
+                dim=-1,
+                keepdim=True,
+            )
+            idx_to_replace = torch.where((mad < best_mad) & (D > 0))[0]
+            best_mad[idx_to_replace] = mad[idx_to_replace]
+            scale[idx_to_replace] = this_scale[idx_to_replace]
+            rmin[idx_to_replace] = this_min[idx_to_replace]
+        results_scale.append(scale.to(torch.float32))
+        results_rmin.append(-rmin.to(torch.float32))
+        if split_num > 1:
+            clear_memory(device_list=[data.device])
+
+    return torch.cat(results_scale, dim=0), torch.cat(results_rmin, dim=0)
+
+
+def iterative_wls_quant_search(
+    data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1
+):
+    """Adapted from Llamacpp. Performs iterative weighted least squares quantization search.
+
+    Args:
+        data (torch.Tensor): Input tensor to quantize.
+        bits (int): Number of quantization bits.
+        rrmin (float): Initial range scaling factor.
+        rdelta (float): Step size for range scaling.
+        nstep (int): Number of search steps.
+        use_mad (bool): Whether to use mean absolute deviation instead of squared error.
+        weights (torch.Tensor): Weight matrix for each element.
+        split_num (int): Number of row chunks to process sequentially to reduce peak memory.
+
+    Returns:
+        Tuple: (Optimal scale tensor, optimal minimum value tensor)
+    """
+
+    # TODO: switch this dispatch to a try/except fallback later
+    if split_num > 1:
+        return iterative_wls_quant_search_chunk(
+            data=data,
+            bits=bits,
+            rrmin=rrmin,
+            rdelta=rdelta,
+            nstep=nstep,
+            use_mad=use_mad,
+            weights=weights,
+            split_num=split_num,
+        )
+    else:
+        return iterative_wls_quant_search_non_chunk(
+            data=data,
+            bits=bits,
+            rrmin=rrmin,
+            rdelta=rdelta,
+            nstep=nstep,
+            use_mad=use_mad,
+            weights=weights,
+        )
+
+
 @torch.no_grad()
 def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
     from auto_round.export.export_to_gguf.config import K_SCALE_SIZE, QK_K
@@ -550,7 +659,6 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype):
     return scale
 
 
-# @register_dtype("rtn_int_sym_dq")
 def quant_tensor_gguf_sym_dq(
     tensor,
@@ -566,7 +674,6 @@ def quant_tensor_gguf_sym_dq(
     Args:
         tensor: Tensor containing the tensor to be quantized
         bits: Number of bits for quantization (e.g., 2, 3, 4, 8)
-        group_size: Number of elements to share scale for quantization
         v: Rounding value perturbation
         min_scale: Minimum scale coefficient for tensor
        max_scale: Maximum scale coefficient for tensor
diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py
index 9f02fa163..8fc6f79a0 100644
--- a/auto_round/data_type/int.py
+++ b/auto_round/data_type/int.py
@@ -71,7 +71,6 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
         imatrix = 1.0
     else:
         imatrix = imatrix.reshape(1, -1)
-        imatrix = reshape_pad_tensor_by_group_size(imatrix, group_size, val=1e-5)[0].view(1, -1)
         imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
         imatrix = imatrix.reshape(tensor.shape)
diff --git a/auto_round/export/export_to_awq/utils.py b/auto_round/export/export_to_awq/utils.py
index 0052ec9b1..871e4287a 100644
--- a/auto_round/export/export_to_awq/utils.py
+++ b/auto_round/export/export_to_awq/utils.py
@@ -316,10 +316,3 @@ def extra_repr(self) -> str:
             self.w_bit,
             self.group_size,
         )
-
-
-def clear_memory(weight=None):
-    if weight is not None:
-        del weight
-    gc.collect()
-    torch.cuda.empty_cache()
diff --git a/auto_round/export/export_to_gguf/packing.py b/auto_round/export/export_to_gguf/packing.py
index 4c64a75d5..bc9189b7b 100644
--- a/auto_round/export/export_to_gguf/packing.py
+++ b/auto_round/export/export_to_gguf/packing.py
@@ -16,7 +16,7 @@
 import torch
 
 from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K
-from auto_round.utils import get_reciprocal
+from auto_round.utils import clear_memory, get_reciprocal
 
 GGML_QUANT_TYPE = {}
@@ -66,6 +66,8 @@ def ggml_quant(
     wmin = wmin.to(device) if wmin is not None else wmin
     d_scale = d_scale.to(device) if d_scale is not None else d_scale
     d_wmin = d_wmin.to(device) if d_wmin is not None else d_wmin
+    imatrix = imatrix.to(device) if imatrix is not None else imatrix
+    clear_memory()
     new_data = quant_func(
         blocks, scale, zp=zp, wmin=wmin, d_scale=d_scale, d_wmin=d_wmin, imatrix=imatrix, original=original
     )
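
Reviewer note (not part of the patch): a minimal, self-contained sketch of the optimizer dispatch introduced in `_quantize_block` above, using plain torch. In the patch the `is_adam` check is done on the compressor class name (`self.__class__.__name__`); here it is done on the optimizer class name only so the snippet runs standalone. The tensors, learning rates, and the choice of `torch.optim.SGD` are illustrative assumptions, not AutoRound's actual defaults.

import torch

# dummy stand-ins for the rounding and min-max parameters tuned per block
round_params = [torch.zeros(4, requires_grad=True)]
minmax_params = [torch.ones(2, requires_grad=True)]
lr, minmax_lr, momentum = 0.005, 0.01, 0.9

optimizer_cls = torch.optim.SGD
# Adam-style optimizers have no `momentum` kwarg, so it is only forwarded for SGD-like ones
is_adam = "adam" in optimizer_cls.__name__.lower()
extra_kwargs = {} if is_adam else {"momentum": momentum}

# same param-group layout as the patch: min-max params get their own learning rate
params = [
    {"params": round_params},
    {"params": minmax_params, "lr": minmax_lr},
]
optimizer = optimizer_cls(params, lr=lr, weight_decay=0, **extra_kwargs)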
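
Reviewer note (not part of the patch): a minimal sketch of the row-chunking idea behind `split_num` in `iterative_wls_quant_search_chunk`: process the group dimension in slices so per-step temporaries stay bounded, then concatenate the per-chunk results. Plain torch only; `search_one_chunk` is a hypothetical stand-in for the real per-chunk scale search, not AutoRound code.

import torch

def chunked_rows(data: torch.Tensor, split_num: int):
    # yield consecutive row slices; ceil division mirrors the patch's chunk_size computation
    chunk_size = (data.shape[0] + split_num - 1) // split_num
    for start in range(0, data.shape[0], chunk_size):
        yield data[start : start + chunk_size]

def search_one_chunk(chunk: torch.Tensor) -> torch.Tensor:
    # placeholder for the per-chunk search; here just a per-row abs-max scale for 4-bit (maxq = 15)
    return chunk.abs().amax(dim=1, keepdim=True) / 15.0

data = torch.randn(1000, 32)
scales = torch.cat([search_one_chunk(c) for c in chunked_rows(data, split_num=4)], dim=0)
assert scales.shape == (1000, 1)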