diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index a5e4df776..39602680c 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -2271,10 +2271,10 @@ def _quantize_layer(
         init_loss = None
         gradient_accumulate_steps = self.batch_size  # Force to low gpu
         batch_size = 1  # Force to low gpu
-        pick_samples = batch_size * gradient_accumulate_steps
-        pick_samples = min(nsamples, pick_samples)
+        global_batch_size = batch_size * gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
         if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:pick_samples]
+            whole_indices = torch.randperm(nsamples)[:global_batch_size]
         total_loss = 0
         num_elm = 1
         mse_reduction = "mean"
@@ -2285,7 +2285,7 @@ def _quantize_layer(
         for i in range(self.iters):
             total_loss = 0
             if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:pick_samples]
+                whole_indices = torch.randperm(nsamples)[:global_batch_size]
             if gradient_accumulate_steps != 1:
                 if q_inputs is not None:
                     num_elm = self._get_current_num_elm(q_inputs, whole_indices)
@@ -2564,10 +2564,10 @@ def _quantize_block(
         else:
             nsamples = len(input_ids)
 
-        pick_samples = self.batch_size * self.gradient_accumulate_steps
-        pick_samples = min(nsamples, pick_samples)
+        global_batch_size = self.batch_size * self.gradient_accumulate_steps
+        global_batch_size = min(nsamples, global_batch_size)
         if self.sampler != "rand":
-            whole_indices = torch.randperm(nsamples)[:pick_samples]
+            whole_indices = torch.randperm(nsamples)[:global_batch_size]
         last_best_iter = 0
         best_loss = torch.finfo(torch.float).max
         num_elm = 1
@@ -2579,13 +2579,15 @@ def _quantize_block(
         init_loss = None
         best_params = {}
         total_loss = 0
+        # We assume the block input and output shape is same
+        if self.gradient_accumulate_steps != 1:
+            whole_indices = torch.arange(global_batch_size)
+            num_elm = self._get_current_num_elm(input_ids, whole_indices)
+
         for i in range(self.iters):
             total_loss = 0
             if self.sampler == "rand":
-                whole_indices = torch.randperm(nsamples)[:pick_samples]
-                # We assume the block input and output shape is same
-                if self.gradient_accumulate_steps != 1:
-                    num_elm = self._get_current_num_elm(input_ids, whole_indices)
+                whole_indices = torch.randperm(nsamples)[:global_batch_size]
 
             for tmp_step in range(self.gradient_accumulate_steps):
                 indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size]
@@ -2600,6 +2602,9 @@ def _quantize_block(
                     tmp_attention_mask = [self.attention_mask[i] for i in indices]
                     tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
                     tmp_attention_mask.unsqueeze_(-1)
+                    num_elm = torch.sum(tmp_attention_mask).item()
+                    if num_elm == 0:
+                        num_elm = 1
                 else:
                     tmp_attention_mask = 1.0
                 if self.amp:
@@ -2615,7 +2620,6 @@ def _quantize_block(
                 total_loss += loss.item() / num_elm
 
                 self._scale_loss_and_backward(scaler, loss)
-                clear_memory_if_reached_threshold(threshold=0.85)
 
             if i == 0:
                 init_loss = total_loss
@@ -2655,7 +2659,8 @@ def _quantize_block(
             set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")
 
         if self.enable_quanted_input:
-            clear_memory()
+            if self.low_gpu_mem_usage:
+                clear_memory()
             q_outputs = self._get_block_outputs(
                 block,
                 input_ids,
diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py
index 79e8b7133..7d6f28806 100644
--- a/auto_round/compressors/utils.py
+++ b/auto_round/compressors/utils.py
@@ -377,6 +377,9 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
     if lm_head_name not in layer_config and quant_lm_head:
         layer_config[lm_head_name] = copy.deepcopy(default_dict)
 
+    if not quant_lm_head and not gguf_name:
+        layer_config.pop(lm_head_name, None)
+
     # 8. enforce shape divisibility for int weight-only
     if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16 and not gguf_name:
         for n, m in model.named_modules():
diff --git a/test/test_cpu/test_act_quantization.py b/test/test_cpu/test_act_quantization.py
index dfc387dee..31ba51f1b 100644
--- a/test/test_cpu/test_act_quantization.py
+++ b/test/test_cpu/test_act_quantization.py
@@ -154,17 +154,8 @@ def test_act_config_MXFP4_saving(self):
         quantized_model_path = self.save_dir
         autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
         model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
-        lmhead_config = model.config.quantization_config.extra_config["lm_head"]
-        assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "mx_fp_rceil"
-        assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
-        assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 32
-        assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
-        assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "mx_fp"
-        assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
-        assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 32
-        assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
-        assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
-        assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
+        assert "lm_head" not in model.config.quantization_config.extra_config
+
         # check inblock layer config values
         kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"]
         assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "mx_fp_rceil"
@@ -204,7 +195,7 @@ def test_act_config_NVFP4_saving(self):
 
     def test_WOQ_config_INT_saving(self):
         scheme = "W4A16"
-        layer_config = {"k_proj": {"bits": 8}}  # "lm_head": {"bits": 4},
+        layer_config = {"k_proj": {"bits": 8}}
         autoround = AutoRound(
             self.model_name,
             scheme=scheme,
@@ -218,18 +209,6 @@ def test_WOQ_config_INT_saving(self):
         autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
         model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
         extra_config = model.config.quantization_config.extra_config
-        # lmhead_config = extra_config["lm_head"]
-        # assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "float"
-        # assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 16
-        # assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 128
-        # assert "act_sym" in lmhead_config.keys() and not lmhead_config["act_sym"]
-        # assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "int"
-        # assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 4
-        # assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 128
-        # assert "sym" in lmhead_config.keys() and not lmhead_config["sym"]
-        # assert "act_dynamic" in lmhead_config.keys() and lmhead_config["act_dynamic"]
-        # assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
-        # assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
 
         # check inblock layer config values
         kproj_config = extra_config["model.decoder.layers.1.self_attn.k_proj"]
@@ -270,18 +249,8 @@ def test_act_config_FP8_saving(self):
         from transformers import AutoConfig
 
         extra_config = AutoConfig.from_pretrained(quantized_model_path).quantization_config["extra_config"]
-        lmhead_config = extra_config["lm_head"]
-        assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "fp"
-        assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
-        assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 0
-        assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
-        assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "fp"
-        assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
-        assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == -1
-        assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
-        assert "act_dynamic" in lmhead_config.keys() and not lmhead_config["act_dynamic"]
-        assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
-        assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
+        assert "lm_head" not in extra_config
+
         # check inblock layer config values
         kproj_config = extra_config["model.decoder.layers.0.self_attn.k_proj"]
         assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "fp"
diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py
index 980db236d..7dfd4b479 100644
--- a/test/test_cpu/test_autoround.py
+++ b/test/test_cpu/test_autoround.py
@@ -743,8 +743,20 @@ def test_invalid_layer_config(self):
 
     def test_quant_lm_head(self):
         model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
-        ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32)
-        ar.quantize()
+        ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True)
+        ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
+        assert "lm_head" in model.config.quantization_config.extra_config
+        assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4
+
+    def test_quant_lm_head_layer_config(self):
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
+        layer_config = {"lm_head": {"bits": 4}}
+        ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True, layer_config=layer_config)
+        ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
+        assert "lm_head" in model.config.quantization_config.extra_config
+        assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4
 
     def test_compressor(self):
         model_name = "Qwen/Qwen2-VL-2B-Instruct"
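
Usage note (not part of the patch): the sketch below mirrors the new test_quant_lm_head test and shows how the reworked lm_head handling can be checked after quantization. The model id and output directory are placeholders; iters=0 and disable_opt_rtn=True are taken verbatim from the test to keep the run cheap.

    # Minimal sketch, assuming the public auto_round API used in the tests above.
    from transformers import AutoModelForCausalLM

    from auto_round import AutoRound

    model_name = "Qwen/Qwen3-8B"          # placeholder: any causal LM checkpoint
    output_dir = "./qwen3-8b-autoround"   # placeholder save location

    # With quant_lm_head=True the saved extra_config keeps an "lm_head" entry;
    # without it (and for non-GGUF formats) the entry is now dropped.
    ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True)
    ar.quantize_and_save(output_dir=output_dir, format="auto_round")

    model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="cpu")
    extra_config = model.config.quantization_config.extra_config
    assert "lm_head" in extra_config and extra_config["lm_head"]["bits"] == 4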