From 7138aceac4e6325c5252ceadde389af1361581f2 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 5 Nov 2025 13:47:32 +0800 Subject: [PATCH 1/9] update --- auto_round/compressors/base.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index fa0c4c7e0..cabe0c9c7 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2271,10 +2271,10 @@ def _quantize_layer( init_loss = None gradient_accumulate_steps = self.batch_size # Force to low gpu batch_size = 1 # Force to low gpu - pick_samples = batch_size * gradient_accumulate_steps - pick_samples = min(nsamples, pick_samples) + global_batch_size = batch_size * gradient_accumulate_steps + global_batch_size = min(nsamples, global_batch_size) if self.sampler != "rand": - whole_indices = torch.randperm(nsamples)[:pick_samples] + whole_indices = torch.randperm(nsamples)[:global_batch_size] total_loss = 0 num_elm = 1 mse_reduction = "mean" @@ -2285,7 +2285,7 @@ def _quantize_layer( for i in range(self.iters): total_loss = 0 if self.sampler == "rand": - whole_indices = torch.randperm(nsamples)[:pick_samples] + whole_indices = torch.randperm(nsamples)[:global_batch_size] if gradient_accumulate_steps != 1: if q_inputs is not None: num_elm = self._get_current_num_elm(q_inputs, whole_indices) @@ -2564,10 +2564,10 @@ def _quantize_block( else: nsamples = len(input_ids) - pick_samples = self.batch_size * self.gradient_accumulate_steps - pick_samples = min(nsamples, pick_samples) + global_batch_size = self.batch_size * self.gradient_accumulate_steps + global_batch_size = min(nsamples, global_batch_size) if self.sampler != "rand": - whole_indices = torch.randperm(nsamples)[:pick_samples] + whole_indices = torch.randperm(nsamples)[:global_batch_size] last_best_iter = 0 best_loss = torch.finfo(torch.float).max num_elm = 1 @@ -2579,13 +2579,15 @@ def _quantize_block( init_loss = None best_params = {} total_loss = 0 + # We assume the block input and output shape is same + if self.gradient_accumulate_steps != 1: + whole_indices = torch.arange(global_batch_size) + num_elm = self._get_current_num_elm(input_ids, whole_indices) + for i in range(self.iters): total_loss = 0 if self.sampler == "rand": - whole_indices = torch.randperm(nsamples)[:pick_samples] - # We assume the block input and output shape is same - if self.gradient_accumulate_steps != 1: - num_elm = self._get_current_num_elm(input_ids, whole_indices) + whole_indices = torch.randperm(nsamples)[:global_batch_size] for tmp_step in range(self.gradient_accumulate_steps): indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size] @@ -2600,6 +2602,7 @@ def _quantize_block( tmp_attention_mask = [self.attention_mask[i] for i in indices] tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) tmp_attention_mask.unsqueeze_(-1) + num_elm = torch.sum(tmp_attention_mask).item() else: tmp_attention_mask = 1.0 if self.amp: @@ -2615,7 +2618,8 @@ def _quantize_block( total_loss += loss.item() / num_elm self._scale_loss_and_backward(scaler, loss) - clear_memory_if_reached_threshold(threshold=0.85) + if self.low_gpu_mem_usage: + clear_memory_if_reached_threshold(threshold=0.85) if i == 0: init_loss = total_loss From 72a53252c559b935f19817bf6e49ee8447707b2a Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 5 Nov 2025 17:36:20 +0800 Subject: [PATCH 2/9] fix lm-head and clear memory issue --- auto_round/compressors/base.py | 3 +-- 
auto_round/compressors/utils.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index cabe0c9c7..633211a8c 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2618,8 +2618,7 @@ def _quantize_block( total_loss += loss.item() / num_elm self._scale_loss_and_backward(scaler, loss) - if self.low_gpu_mem_usage: - clear_memory_if_reached_threshold(threshold=0.85) + if i == 0: init_loss = total_loss diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py index 79e8b7133..7d6f28806 100644 --- a/auto_round/compressors/utils.py +++ b/auto_round/compressors/utils.py @@ -377,6 +377,9 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str if lm_head_name not in layer_config and quant_lm_head: layer_config[lm_head_name] = copy.deepcopy(default_dict) + if not quant_lm_head and not gguf_name: + layer_config.pop(lm_head_name, None) + # 8. enforce shape divisibility for int weight-only if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16 and not gguf_name: for n, m in model.named_modules(): From 1dafcf648cf95c1b3e892f5bdf23312bf9fd7e7c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Nov 2025 09:38:29 +0000 Subject: [PATCH 3/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 633211a8c..66f99b422 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2583,7 +2583,7 @@ def _quantize_block( if self.gradient_accumulate_steps != 1: whole_indices = torch.arange(global_batch_size) num_elm = self._get_current_num_elm(input_ids, whole_indices) - + for i in range(self.iters): total_loss = 0 if self.sampler == "rand": @@ -2619,7 +2619,6 @@ def _quantize_block( total_loss += loss.item() / num_elm self._scale_loss_and_backward(scaler, loss) - if i == 0: init_loss = total_loss From 3fa4131de53dcf7daac1ef0b65e2909fdcb7de30 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 5 Nov 2025 17:41:22 +0800 Subject: [PATCH 4/9] move one more clear memory to low_gpu_mem_usage --- auto_round/compressors/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 66f99b422..c09b7f915 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2657,7 +2657,8 @@ def _quantize_block( set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max") if self.enable_quanted_input: - clear_memory() + if self.low_gpu_mem_usage: + clear_memory() q_outputs = self._get_block_outputs( block, input_ids, From 18e027fa3fa11674b6b3f89a40dc2fb0ede1ba1d Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 5 Nov 2025 19:58:40 +0800 Subject: [PATCH 5/9] fix issue --- auto_round/compressors/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index c09b7f915..d5cbb14c0 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2603,6 +2603,8 @@ def _quantize_block( tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) tmp_attention_mask.unsqueeze_(-1) num_elm = torch.sum(tmp_attention_mask).item() + if num_elm==0: + 
num_elm=1 else: tmp_attention_mask = 1.0 if self.amp: From 85a860681e9ee3db6adb463f60602069eb5ab917 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 5 Nov 2025 12:00:40 +0000 Subject: [PATCH 6/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 8633a666a..39602680c 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2603,8 +2603,8 @@ def _quantize_block( tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) tmp_attention_mask.unsqueeze_(-1) num_elm = torch.sum(tmp_attention_mask).item() - if num_elm==0: - num_elm=1 + if num_elm == 0: + num_elm = 1 else: tmp_attention_mask = 1.0 if self.amp: From 3a13d8a1cc77029c708c70bb999c22cde8cbd593 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Thu, 6 Nov 2025 09:55:25 +0800 Subject: [PATCH 7/9] update ut --- test/test_cpu/test_act_quantization.py | 41 ++++---------------------- test/test_cpu/test_autoround.py | 33 ++++++++++++++------- 2 files changed, 28 insertions(+), 46 deletions(-) diff --git a/test/test_cpu/test_act_quantization.py b/test/test_cpu/test_act_quantization.py index dfc387dee..31ba51f1b 100644 --- a/test/test_cpu/test_act_quantization.py +++ b/test/test_cpu/test_act_quantization.py @@ -154,17 +154,8 @@ def test_act_config_MXFP4_saving(self): quantized_model_path = self.save_dir autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") - lmhead_config = model.config.quantization_config.extra_config["lm_head"] - assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "mx_fp_rceil" - assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8 - assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 32 - assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"] - assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "mx_fp" - assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8 - assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 32 - assert "sym" in lmhead_config.keys() and lmhead_config["sym"] - assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None - assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None + assert "lm_head" not in model.config.quantization_config.extra_config + # check inblock layer config values kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"] assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "mx_fp_rceil" @@ -204,7 +195,7 @@ def test_act_config_NVFP4_saving(self): def test_WOQ_config_INT_saving(self): scheme = "W4A16" - layer_config = {"k_proj": {"bits": 8}} # "lm_head": {"bits": 4}, + layer_config = {"k_proj": {"bits": 8}} autoround = AutoRound( self.model_name, scheme=scheme, @@ -218,18 +209,6 @@ def test_WOQ_config_INT_saving(self): autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round") model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu") extra_config = model.config.quantization_config.extra_config - 
# lmhead_config = extra_config["lm_head"] - # assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "float" - # assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 16 - # assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 128 - # assert "act_sym" in lmhead_config.keys() and not lmhead_config["act_sym"] - # assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "int" - # assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 4 - # assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 128 - # assert "sym" in lmhead_config.keys() and not lmhead_config["sym"] - # assert "act_dynamic" in lmhead_config.keys() and lmhead_config["act_dynamic"] - # assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None - # assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None # check inblock layer config values kproj_config = extra_config["model.decoder.layers.1.self_attn.k_proj"] @@ -270,18 +249,8 @@ def test_act_config_FP8_saving(self): from transformers import AutoConfig extra_config = AutoConfig.from_pretrained(quantized_model_path).quantization_config["extra_config"] - lmhead_config = extra_config["lm_head"] - assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "fp" - assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8 - assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 0 - assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"] - assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "fp" - assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8 - assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == -1 - assert "sym" in lmhead_config.keys() and lmhead_config["sym"] - assert "act_dynamic" in lmhead_config.keys() and not lmhead_config["act_dynamic"] - assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None - assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None + assert "lm_head" not in extra_config + # check inblock layer config values kproj_config = extra_config["model.decoder.layers.0.self_attn.k_proj"] assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "fp" diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 980db236d..7ebc5881b 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -743,20 +743,33 @@ def test_invalid_layer_config(self): def test_quant_lm_head(self): model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" - ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32) - ar.quantize() + ar = AutoRound(model_name, quant_lm_head=True, iters=0,disabel_opt_rtn=True) + ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") + model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") + assert "lm_head" in model.config.quantization_config.extra_config + assert model.config.quantization_config.extra_config["lm_head"]["bits"]==4 + + def test_quant_lm_head_layer_config(self): + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" + layer_config ={"lm_head":{"bits":4}} + ar = AutoRound(model_name, quant_lm_head=True, iters=0,disabel_opt_rtn=True,layer_config=layer_config) + 
ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") + model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") + assert "lm_head" in model.config.quantization_config.extra_config + assert model.config.quantization_config.extra_config["lm_head"]["bits"]==4 + def test_compressor(self): - model_name = "Qwen/Qwen2-VL-2B-Instruct" - ar = AutoRound(model_name, enable_adam=True) - self.assertEqual(ar.optimizer, torch.optim.AdamW) - self.assertTrue(ar.mllm) + model_name = "Qwen/Qwen2-VL-2B-Instruct" + ar = AutoRound(model_name, enable_adam=True) + self.assertEqual(ar.optimizer, torch.optim.AdamW) + self.assertTrue(ar.mllm) - # test old api - from auto_round import AutoRoundMLLM + # test old api + from auto_round import AutoRoundMLLM - ar = AutoRoundMLLM(model_name) - self.assertTrue(ar.mllm) + ar = AutoRoundMLLM(model_name) + self.assertTrue(ar.mllm) def test_attention_mask_in_dataset(self): from transformers import AutoTokenizer From f788dfe123db9005c48f1169aebd89c2221aca8f Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Thu, 6 Nov 2025 09:57:31 +0800 Subject: [PATCH 8/9] fix --- test/test_cpu/test_autoround.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 7ebc5881b..2ae437adf 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -743,7 +743,7 @@ def test_invalid_layer_config(self): def test_quant_lm_head(self): model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" - ar = AutoRound(model_name, quant_lm_head=True, iters=0,disabel_opt_rtn=True) + ar = AutoRound(model_name, quant_lm_head=True, iters=0,disable_opt_rtn=True) ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") assert "lm_head" in model.config.quantization_config.extra_config @@ -752,7 +752,7 @@ def test_quant_lm_head(self): def test_quant_lm_head_layer_config(self): model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" layer_config ={"lm_head":{"bits":4}} - ar = AutoRound(model_name, quant_lm_head=True, iters=0,disabel_opt_rtn=True,layer_config=layer_config) + ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True,layer_config=layer_config) ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") assert "lm_head" in model.config.quantization_config.extra_config From 9bbaefe96c930ce1a2f283f2ebe0a3f074a64baa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:58:23 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/test_cpu/test_autoround.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 2ae437adf..7dfd4b479 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -743,33 +743,32 @@ def test_invalid_layer_config(self): def test_quant_lm_head(self): model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" - ar = AutoRound(model_name, quant_lm_head=True, iters=0,disable_opt_rtn=True) + ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True) ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") 
model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") assert "lm_head" in model.config.quantization_config.extra_config - assert model.config.quantization_config.extra_config["lm_head"]["bits"]==4 + assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4 def test_quant_lm_head_layer_config(self): model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" - layer_config ={"lm_head":{"bits":4}} - ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True,layer_config=layer_config) + layer_config = {"lm_head": {"bits": 4}} + ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True, layer_config=layer_config) ar.quantize_and_save(output_dir=self.save_folder, format="auto_round") model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu") assert "lm_head" in model.config.quantization_config.extra_config - assert model.config.quantization_config.extra_config["lm_head"]["bits"]==4 - + assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4 def test_compressor(self): - model_name = "Qwen/Qwen2-VL-2B-Instruct" - ar = AutoRound(model_name, enable_adam=True) - self.assertEqual(ar.optimizer, torch.optim.AdamW) - self.assertTrue(ar.mllm) + model_name = "Qwen/Qwen2-VL-2B-Instruct" + ar = AutoRound(model_name, enable_adam=True) + self.assertEqual(ar.optimizer, torch.optim.AdamW) + self.assertTrue(ar.mllm) - # test old api - from auto_round import AutoRoundMLLM + # test old api + from auto_round import AutoRoundMLLM - ar = AutoRoundMLLM(model_name) - self.assertTrue(ar.mllm) + ar = AutoRoundMLLM(model_name) + self.assertTrue(ar.mllm) def test_attention_mask_in_dataset(self): from transformers import AutoTokenizer
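
The functional change threaded through patches 1-6 is the loss normalization in the `_quantize_block` / `_quantize_layer` gradient-accumulation loop: a global batch of `batch_size * gradient_accumulate_steps` samples is drawn, split into micro-batches, and each micro-batch loss is divided by the number of unmasked elements (`num_elm`), which patches 5-6 recompute from the attention mask and guard against zero. The standalone Python sketch below only illustrates that normalization pattern; it is not the AutoRound implementation, and the tensor shapes and names (`masked_accumulated_loss`, `fp_outputs`, `q_outputs`) are assumptions made for the example.

import torch


def masked_accumulated_loss(q_outputs, fp_outputs, attention_masks,
                            batch_size, gradient_accumulate_steps):
    # Split a global batch into micro-batches and normalize each micro-batch
    # loss by the number of unmasked elements, mirroring the patched loop.
    mse = torch.nn.MSELoss(reduction="sum")
    nsamples = len(fp_outputs)
    global_batch_size = min(nsamples, batch_size * gradient_accumulate_steps)
    whole_indices = torch.randperm(nsamples)[:global_batch_size]
    total_loss = 0.0
    for step in range(gradient_accumulate_steps):
        idx = whole_indices[step * batch_size: (step + 1) * batch_size]
        if idx.numel() == 0:
            break
        # assumed shapes: (b, seqlen) masks -> (b, seqlen, 1) so they broadcast
        # over the hidden dimension of the (b, seqlen, hidden) outputs
        mask = torch.stack([attention_masks[i] for i in idx]).unsqueeze(-1)
        num_elm = int(mask.sum().item()) or 1  # zero guard added in patches 5/6
        loss = mse(q_outputs[idx] * mask, fp_outputs[idx] * mask)
        total_loss += loss.item() / num_elm
    return total_loss


# toy usage: 8 samples, seqlen 4, hidden 16, accumulated over 2 micro-batches
fp = torch.randn(8, 4, 16)
q = fp + 0.01 * torch.randn_like(fp)
masks = [torch.ones(4) for _ in range(8)]
print(masked_accumulated_loss(q, fp, masks, batch_size=2, gradient_accumulate_steps=2))

The remaining patches are complementary: patches 2 and 4 remove the per-step clear_memory_if_reached_threshold call and gate the remaining clear_memory() behind self.low_gpu_mem_usage; patch 2 also drops lm_head from layer_config when quant_lm_head is false and no GGUF format is requested, which is what the updated CPU tests in patches 7-9 assert.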