diff --git a/README.md b/README.md
index ddaecc53a..29eeae56f 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,8 @@ See our [paper](https://arxiv.org/pdf/2309.05516) for more details. For usage in
 
 ## 🆕 What's New
 
+[2025/11] AutoRound now offers preliminary support for an **enhanced GGUF quantization algorithm** via `--enable_alg_ext`. For detailed accuracy benchmarks, please refer to the accompanying [documentation](./docs/gguf_alg_ext_acc.md).
+
 [2025/10] AutoRound has been integrated into **SGLang**. You can now run models in the AutoRound format directly using the latest SGLang later than v0.5.4.
 
 [2025/10] We enhanced the RTN mode (--iters 0) to significantly reduce quantization cost compared to the default tuning mode. Check out [this doc](./docs/opt_rtn.md) for some accuracy results. If you don’t have sufficient resources, you can use this mode for 4-bit quantization.
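The `--enable_alg_ext` flag announced above is wired to the `enable_alg_ext` option handled in `BaseCompressor.__init__` later in this patch. Below is a minimal usage sketch; the Python entry point, the `quantize_and_save` call, and the `gguf:q2_k_s` format string reflect the current public AutoRound API rather than anything added here, so treat the exact argument names as assumptions.

```python
# Sketch only: the constructor/export arguments are assumptions based on the public
# AutoRound API, not on this patch; adjust them to your installed version.
from auto_round import AutoRound

ar = AutoRound("Qwen/Qwen2.5-7B-Instruct", enable_alg_ext=True)  # opt into the algorithm extension
ar.quantize_and_save("./Qwen2.5-7B-Instruct-q2ks", format="gguf:q2_k_s")  # export a GGUF q2_k_s model
```

On the command line, the README entry amounts to appending `--enable_alg_ext` to an existing GGUF quantization command.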
diff --git a/auto_round/alg_ext.abi3.so b/auto_round/alg_ext.abi3.so
index 4b3f3bca3..423a615ca 100755
Binary files a/auto_round/alg_ext.abi3.so and b/auto_round/alg_ext.abi3.so differ
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index c098320f9..f31bccbbc 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -381,6 +381,16 @@ def __init__(
 
         self.attention_mask = []
 
+        self.wrapper_block = wrapper_block
+        if self.enable_alg_ext:
+            try:
+                logger.warning_once("using algorithm extension for quantization.")
+                from auto_round.alg_ext import wrapper_autoround
+
+                wrapper_autoround(self)
+            except (ImportError, ModuleNotFoundError):
+                logger.error("algorithm extension import error, fallback to default mode")
+
     def _gen_auto_scheme(
         self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
     ) -> dict[str, dict]:
@@ -2495,6 +2505,32 @@ def quantize_block(
         input_ids, input_others = normalize_input(inputs)
         return self._quantize_block(block, input_ids, input_others, q_input, device, auto_offload)
 
+    def _get_loss(
+        self,
+        output_q: torch.Tensor,
+        current_output: torch.Tensor,
+        indices: torch.Tensor,
+        mse_loss: Callable,
+        device: Union[str, torch.device] = "cpu",
+    ):
+        if self.attention_mask:
+            tmp_attention_mask = [self.attention_mask[i] for i in indices]
+            tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+            tmp_attention_mask.unsqueeze_(-1)
+        else:
+            tmp_attention_mask = 1.0
+        if self.amp:
+            with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                loss = mse_loss(  # pylint: disable=not-callable
+                    output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                )
+        else:
+            loss = mse_loss(  # pylint: disable=not-callable
+                output_q.to(torch.float32) * tmp_attention_mask,
+                current_output.to(torch.float32) * tmp_attention_mask,
+            )
+        return loss
+
     def _quantize_block(
         self,
         block: torch.nn.Module,
@@ -2579,7 +2615,7 @@
             clear_memory(device_list=self.device_list)
             input_ids = q_input
 
-        quantized_layer_names, unquantized_layer_names = wrapper_block(
+        quantized_layer_names, unquantized_layer_names = self.wrapper_block(
             block,
             self.enable_minmax_tuning,
             self.enable_norm_bias_tuning,
@@ -2654,6 +2690,9 @@
             num_elm = self._get_current_num_elm(input_ids, whole_indices)
 
         for i in range(self.iters):
+            if self.enable_alg_ext and self.data_type.endswith("dq"):
+                for n, m in block.named_modules():
+                    m.cur_iter = i
             total_loss = 0
             if self.sampler == "rand":
                 whole_indices = torch.randperm(nsamples)[:global_batch_size]
@@ -2667,25 +2706,7 @@ def _quantize_block(
 
                 output_q = self._get_current_q_output(block, input_ids, input_others, indices, device, loss_device)
 
-                if self.attention_mask:
-                    tmp_attention_mask = [self.attention_mask[i] for i in indices]
-                    tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(loss_device)
-                    tmp_attention_mask.unsqueeze_(-1)
-                    num_elm = torch.sum(tmp_attention_mask).item()
-                    if num_elm == 0:
-                        num_elm = 1
-                else:
-                    tmp_attention_mask = 1.0
-                if self.amp:
-                    with autocast(device_type=str(loss_device).split(":")[0], dtype=self.amp_dtype):
-                        loss = mse_loss(  # pylint: disable=not-callable
-                            output_q * tmp_attention_mask, current_output * tmp_attention_mask
-                        )
-                else:
-                    loss = mse_loss(  # pylint: disable=not-callable
-                        output_q.to(torch.float32) * tmp_attention_mask,
-                        current_output.to(torch.float32) * tmp_attention_mask,
-                    )
+                loss = self._get_loss(output_q, current_output, indices, mse_loss, device)
 
                 total_loss += loss.item() / num_elm
 
@@ -2815,44 +2836,6 @@ def _quantize_blocks(
             for i in range(len(input_others[key])):
                 to_dtype(input_others[key][i], tmp_dtype)
 
-        if (
-            self.sym
-            and self.enable_alg_ext
-            and self.super_group_size is None
-            and (
-                (self.data_type.startswith("int") and self.act_bits >= 8)
-                or self.data_type.startswith("mx")
-                or self.data_type.startswith("nv")
-            )
-        ):
-            try:
-                from auto_round.alg_ext import quantize_block_ext
-
-                BaseCompressor.quantize_block_ext = quantize_block_ext
-                quantize_block = self.quantize_block_ext  # must use self.quantize_block_ext
-                if self.bits > 2 and (not self.data_type.startswith("mx") or not self.data_type.startswith("nv")):
-                    logger.warning(
-                        "algorithm extension has only undergone limited validation on "
-                        "INT2,mxfp4 and nvfp4; use with caution."
-                    )
-                else:
-                    logger.info("using algorithm extension for quantization.")
-            except (ImportError, ModuleNotFoundError):
-                logger.error("algorithm extension import error, fallback to default mode")
-                quantize_block = self._quantize_block
-        elif self.enable_alg_ext and self.data_type.endswith("dq"):
-            try:
-                from auto_round.alg_ext import dq_quantize_block_ext
-
-                BaseCompressor.dq_quantize_block_ext = dq_quantize_block_ext
-                quantize_block = self.dq_quantize_block_ext
-                logger.info("using algorithm extension for quantization.")
-            except (ImportError, ModuleNotFoundError):
-                logger.error("algorithm extension import error, fallback to default mode")
-                quantize_block = self._quantize_block
-        else:
-            quantize_block = self._quantize_block
-
         if pbar is None:
             pbar = tqdm(range(0, len(block_names), nblocks))
 
@@ -2870,7 +2853,7 @@ def _quantize_blocks(
 
             m = WrapperMultiblock(modules)
             m.config = model.config if hasattr(model, "config") else None
-            q_input, input_ids = quantize_block(
+            q_input, input_ids = self._quantize_block(
                 m,
                 input_ids,
                 input_others,
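The new `BaseCompressor._get_loss` above factors the attention-mask-weighted MSE out of `_quantize_block`. The stand-alone sketch below is illustrative only (it drops the autocast path and the scalar-mask fallback the real helper keeps) and shows the core idea: padded token positions are zeroed in both the quantized and the reference block outputs before the MSE is taken, so they cannot contribute to the tuning loss.

```python
from typing import Optional

import torch


def masked_mse(output_q: torch.Tensor, output_ref: torch.Tensor,
               attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    """Illustrative stand-in for _get_loss: mask out padded positions, then take the MSE."""
    if attention_mask is not None:
        mask = attention_mask.unsqueeze(-1).to(output_q.dtype)  # [batch, seq, 1]
        output_q = output_q * mask
        output_ref = output_ref * mask
    return torch.nn.functional.mse_loss(output_q.float(), output_ref.float())


# Toy check: a perturbation confined to masked (padded) positions contributes no loss.
q = torch.randn(2, 4, 8)
ref = q.clone()
ref[:, -1] += 1.0  # perturb only the last sequence position
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 0]])
assert masked_mse(q, ref, mask).item() == 0.0
```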
diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py
index 699466dc8..9f02fa163 100644
--- a/auto_round/data_type/int.py
+++ b/auto_round/data_type/int.py
@@ -72,6 +72,7 @@ def quant_tensor_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh=1e-5
     else:
         imatrix = imatrix.reshape(1, -1)
 
+        imatrix = reshape_pad_tensor_by_group_size(imatrix, group_size, val=1e-5)[0].view(1, -1)
         imatrix = imatrix.expand(tensor.numel() // imatrix.numel(), -1)
         imatrix = imatrix.reshape(tensor.shape)
 
diff --git a/auto_round/data_type/utils.py b/auto_round/data_type/utils.py
index 1bb53a14b..acdc9db93 100644
--- a/auto_round/data_type/utils.py
+++ b/auto_round/data_type/utils.py
@@ -23,7 +23,7 @@
 
 from auto_round.utils import logger
 
-def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int):
+def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int, val: float = 0.0):
     """Reshapes and pads the tensor to ensure that it can be quantized in groups of `group_size`.
 
     This function adjusts the
@@ -55,7 +55,7 @@ def reshape_pad_tensor_by_group_size(data: torch.Tensor, group_size: int):
         return data, orig_shape, pad_len
     else:
         pad_len = (data.shape[1] + group_size - 1) // group_size * group_size - data.shape[1]
-        data_new = torch.nn.functional.pad(data, (0, pad_len))
+        data_new = torch.nn.functional.pad(data, (0, pad_len), value=val)
         data_new = data_new.reshape(-1, group_size)
         return data_new, orig_shape, pad_len
 
diff --git a/docs/gguf_alg_ext_acc.md b/docs/gguf_alg_ext_acc.md
index 8b874ae25..555f68e60 100644
--- a/docs/gguf_alg_ext_acc.md
+++ b/docs/gguf_alg_ext_acc.md
@@ -8,9 +8,21 @@ to stabilize accuracy during evaluation. All other settings follow the default c
 |method|scheme|Llama-3.1-8B|Qwen2.5-7B-Instruct|Qwen3-8b|Qwen3-30B-A3B-Instruct-2507|
 |:-----|:-----|:-----------|:------------------|:-------|:--------------------------|
 |**BF16** | - |0.6295(100%)|0.6571(100%) |0.6322(100%)|0.6746(100%) |
-| **original** | q2_k_s | 0.5535(87.92%)| 0.6266(95.35%)|0.5901(93.35%)|0.6386(94.66%)|
-| **enable_alg_ext** |q2_k_s|0.5740(91.18%)|0.6349(96.62%)|0.5962(94.31%)|0.6460(95.77%)|
-| **original** | q3_k_s | 0.6040(95.95%)|0.6382(97.12%)|0.6128(96.94%)|0.6598(97.82%)|
-| **enable_alg_ext** |q3_k_s|0.6081(96.59%)|0.6503(98.97%)|0.6252(98.89%)|0.6622(98.17%)|
-| **original** | q4_k_s | 0.6228(98.94%)|0.6560(99.83%)|0.6303(99.70%)|0.6762(100.24%)|
-| **enable_alg_ext** |q4_k_s|0.6239(99.11%)|0.6605(100.51%)|0.6320(99.98%)|0.6777(100.46%)|
\ No newline at end of file
+| **Optimized RTN** | q2_k_s | 0.5535(87.92%)| 0.6266(95.35%)|0.5901(93.35%)|0.6386(94.66%)|
+| **AutoRound+alg_ext** |q2_k_s|0.5740(91.18%)|0.6349(96.62%)|0.5962(94.31%)|0.6460(95.77%)|
+| **Optimized RTN** | q3_k_s | 0.6040(95.95%)|0.6382(97.12%)|0.6128(96.94%)|0.6598(97.82%)|
+| **AutoRound+alg_ext** |q3_k_s|0.6081(96.59%)|0.6503(98.97%)|0.6252(98.89%)|0.6622(98.17%)|
+| **Optimized RTN** | q3_k_m |0.6083(96.63%) |0.6418(97.68%)|0.6194(97.97%)||
+| **AutoRound+alg_ext** |q3_k_m|0.6127(97.33%)|0.6533(99.42%)|0.6197(98.02%)||
+| **Optimized RTN** | q4_k_s | 0.6228(98.94%)|0.6560(99.83%)|0.6303(99.70%)|0.6762(100.24%)|
+| **AutoRound+alg_ext** |q4_k_s|0.6239(99.11%)|0.6605(100.51%)|0.6320(99.98%)|0.6777(100.46%)|
+| **Optimized RTN** | q4_k_m |0.6252(99.32%) |0.6558(99.80%)|0.6296(99.59%)||
+| **AutoRound+alg_ext** |q4_k_m|0.6257(99.40%)|0.6575(100.06%)|0.6340(100.29%)||
+
+**Time cost**
+|model |Optimized RTN |AutoRound+alg_ext|
+|:--------------------------|:-------------|:----------------|
+|Llama-3.1-8B |1m25s |29m43s |
+|Qwen2.5-7B-Instruct |1m20s |35m35s |
+|Qwen3-8b |1m29s |47m58s |
+|Qwen3-30B-A3B-Instruct-2507|25m12s |12h47m39s |
\ No newline at end of file
diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py
index dd188e6ad..71c6ee506 100644
--- a/test/test_cpu/test_autoround.py
+++ b/test/test_cpu/test_autoround.py
@@ -699,7 +699,7 @@ def test_alg_ext(self):
         ar.quantize()
 
     def test_alg_ext_import(self):
-        from auto_round.alg_ext import dq_quantize_block_ext, quantize_block_ext
+        from auto_round.alg_ext import wrapper_autoround
 
     def test_invalid_layer_config(self):
         with self.assertRaises(ValueError):
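Returning to the `data_type` changes earlier in this patch: `reshape_pad_tensor_by_group_size` gains a `val` argument, and `quant_tensor_rtn_sym` passes `val=1e-5` when padding the importance matrix, so the padded tail of a group keeps a tiny but non-zero weight instead of being zeroed out. A self-contained sketch of that effect, with a simplified signature and padding arithmetic that are not the library function itself:

```python
import torch
import torch.nn.functional as F


def reshape_pad_by_group(data: torch.Tensor, group_size: int, val: float = 0.0):
    """Simplified model of reshape_pad_tensor_by_group_size: pad the last dimension
    up to a multiple of group_size with `val`, then fold it into one group per row."""
    pad_len = (-data.shape[1]) % group_size
    padded = F.pad(data, (0, pad_len), value=val)
    return padded.reshape(-1, group_size), data.shape, pad_len


imatrix = torch.ones(1, 6)  # importance weights for 6 columns, quantized in groups of 4
groups_zero, _, _ = reshape_pad_by_group(imatrix, 4)           # previous behavior: pad with 0
groups_eps, _, _ = reshape_pad_by_group(imatrix, 4, val=1e-5)  # patched behavior: pad with 1e-5
print(groups_zero[-1])  # padded columns get zero importance
print(groups_eps[-1])   # padded columns keep a tiny, non-zero importance
```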