From f849461ea60b570a193cad53621cd10be30a0be6 Mon Sep 17 00:00:00 2001
From: "He, Xin3"
Date: Tue, 28 Oct 2025 21:57:51 -0400
Subject: [PATCH 1/2] add self attribution and fix avg_bits error

Signed-off-by: He, Xin3
---
 auto_round/auto_scheme/gen_auto_scheme.py |  3 ++-
 auto_round/auto_scheme/utils.py           | 11 ++++++++---
 auto_round/compressors/base.py            |  4 ++--
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/auto_round/auto_scheme/gen_auto_scheme.py b/auto_round/auto_scheme/gen_auto_scheme.py
index ca0abdfe0..fa7c8b7fa 100644
--- a/auto_round/auto_scheme/gen_auto_scheme.py
+++ b/auto_round/auto_scheme/gen_auto_scheme.py
@@ -153,4 +153,5 @@ def compute_avg_bit_range(self) -> tuple[float, float]:
             )[0]
             for option in self.auto_scheme.options
         ]
-        return min(avg_bits), max(avg_bits)
+        self.min_avg_bit, self.max_avg_bit = min(avg_bits), max(avg_bits)
+        return self.min_avg_bit, self.max_avg_bit
diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py
index 0f2b00e06..f23b93c1f 100644
--- a/auto_round/auto_scheme/utils.py
+++ b/auto_round/auto_scheme/utils.py
@@ -160,6 +160,8 @@ def compute_layer_bits(
     n_param = weight.numel()
     weight_bits = getattr(layer, "bits", 16)
     group_size = getattr(layer, "group_size", 128)
+    data_type = getattr(layer, "data_type", "int")
+    is_sym = getattr(layer, "sym", False)
     super_group_size = getattr(layer, "super_group_size", None)
     super_weight_bits = getattr(layer, "super_bits", None)
 
@@ -175,7 +177,7 @@ def compute_layer_bits(
 
     # Determine number of groups based on group size
     if group_size > 0:
-        n_group = out_features * (in_features + group_size - 1) // group_size
+        n_group = out_features * ((in_features + group_size - 1) // group_size)
     elif group_size == 0:
         n_group = 1
     elif group_size == -1:
@@ -185,9 +187,12 @@ def compute_layer_bits(
 
     # Compute auxiliary bits (scales, zero-points, or double quantization)
     aux_total_bits = 0
-    if not super_group_size:
+    if "mx_fp" in data_type or "nv_fp" in data_type or "fp4" in data_type:
+        scale_bits = 8
+    else:
         scale_bits = 16
-        zp_bits = weight_bits
+    zp_bits = weight_bits if not is_sym else 0
+    if not super_group_size:
         aux_total_bits = n_group * (scale_bits + zp_bits)
     else:
         aux_total_bits += n_group * super_weight_bits * 2
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 9608e6ee4..ed784bed4 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -437,7 +437,7 @@ def _gen_auto_scheme(
         if not self.enable_torch_compile and self.super_bits is None:
             logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM")
 
-        gen_scheme = GenScheme(
+        self.scheme_generator = GenScheme(
             scheme,
             self.model,
             quant_layer_names,
@@ -447,7 +447,7 @@ def _gen_auto_scheme(
             tokenizer=self.tokenizer,
             enable_torch_compile=self.enable_torch_compile,
         )
-        layer_config = gen_scheme.get_layer_config()
+        layer_config = self.scheme_generator.get_layer_config()
         return layer_config
 
     def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:

From 97d8ba90db11558861ad5395fc706792fb4eaf70 Mon Sep 17 00:00:00 2001
From: "He, Xin3"
Date: Tue, 28 Oct 2025 23:31:47 -0400
Subject: [PATCH 2/2] fix per review

Signed-off-by: He, Xin3
---
 auto_round/auto_scheme/utils.py    | 2 +-
 test/test_cuda/test_auto_scheme.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py
index f23b93c1f..cc96455d6 100644
--- a/auto_round/auto_scheme/utils.py
+++ b/auto_round/auto_scheme/utils.py
@@ -191,7 +191,7 @@ def compute_layer_bits(
         scale_bits = 8
     else:
         scale_bits = 16
-    zp_bits = weight_bits if not is_sym else 0
+    zp_bits = weight_bits if not is_sym or "int" in data_type else 0
     if not super_group_size:
         aux_total_bits = n_group * (scale_bits + zp_bits)
     else:
diff --git a/test/test_cuda/test_auto_scheme.py b/test/test_cuda/test_auto_scheme.py
index ac486fef4..70366cf05 100644
--- a/test/test_cuda/test_auto_scheme.py
+++ b/test/test_cuda/test_auto_scheme.py
@@ -146,7 +146,7 @@ def test_min_target_bits(self):
     # def test_max_target_bits(self):
     def test_max_target_bits(self):
         model_name = "/models/opt-125m"
-        target_bits = 8.211
+        target_bits = 8.025
         scheme = AutoScheme(avg_bits=target_bits, options=("MXFP4", "W8A16"))
         ar = AutoRound(model=model_name, scheme=scheme, iters=0, nsamples=1)
         model, layer_config = ar.quantize()