Merged
31 changes: 18 additions & 13 deletions auto_round/compressors/base.py
@@ -2271,10 +2271,10 @@ def _quantize_layer(
init_loss = None
gradient_accumulate_steps = self.batch_size # Force to low gpu
batch_size = 1 # Force to low gpu
pick_samples = batch_size * gradient_accumulate_steps
pick_samples = min(nsamples, pick_samples)
global_batch_size = batch_size * gradient_accumulate_steps
global_batch_size = min(nsamples, global_batch_size)
if self.sampler != "rand":
whole_indices = torch.randperm(nsamples)[:pick_samples]
whole_indices = torch.randperm(nsamples)[:global_batch_size]
total_loss = 0
num_elm = 1
mse_reduction = "mean"
@@ -2285,7 +2285,7 @@ def _quantize_layer(
for i in range(self.iters):
total_loss = 0
if self.sampler == "rand":
whole_indices = torch.randperm(nsamples)[:pick_samples]
whole_indices = torch.randperm(nsamples)[:global_batch_size]
if gradient_accumulate_steps != 1:
if q_inputs is not None:
num_elm = self._get_current_num_elm(q_inputs, whole_indices)
@@ -2564,10 +2564,10 @@ def _quantize_block(
else:
nsamples = len(input_ids)

pick_samples = self.batch_size * self.gradient_accumulate_steps
pick_samples = min(nsamples, pick_samples)
global_batch_size = self.batch_size * self.gradient_accumulate_steps
global_batch_size = min(nsamples, global_batch_size)
if self.sampler != "rand":
whole_indices = torch.randperm(nsamples)[:pick_samples]
whole_indices = torch.randperm(nsamples)[:global_batch_size]
last_best_iter = 0
best_loss = torch.finfo(torch.float).max
num_elm = 1
@@ -2579,13 +2579,15 @@ def _quantize_block(
init_loss = None
best_params = {}
total_loss = 0
# We assume the block input and output shape is same
if self.gradient_accumulate_steps != 1:
whole_indices = torch.arange(global_batch_size)
num_elm = self._get_current_num_elm(input_ids, whole_indices)

for i in range(self.iters):
total_loss = 0
if self.sampler == "rand":
whole_indices = torch.randperm(nsamples)[:pick_samples]
# We assume the block input and output shape is same
if self.gradient_accumulate_steps != 1:
num_elm = self._get_current_num_elm(input_ids, whole_indices)
whole_indices = torch.randperm(nsamples)[:global_batch_size]

for tmp_step in range(self.gradient_accumulate_steps):
indices = whole_indices[tmp_step * self.batch_size : (tmp_step + 1) * self.batch_size]
@@ -2600,6 +2602,9 @@ def _quantize_block(
tmp_attention_mask = [self.attention_mask[i] for i in indices]
tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
tmp_attention_mask.unsqueeze_(-1)
num_elm = torch.sum(tmp_attention_mask).item()
if num_elm == 0:
num_elm = 1
else:
tmp_attention_mask = 1.0
if self.amp:
@@ -2615,7 +2620,6 @@ def _quantize_block(

total_loss += loss.item() / num_elm
self._scale_loss_and_backward(scaler, loss)
clear_memory_if_reached_threshold(threshold=0.85)

if i == 0:
init_loss = total_loss
@@ -2655,7 +2659,8 @@ def _quantize_block(
set_amax_for_all_moe_layers(block, attr_name="orig_layer.act_max")

if self.enable_quanted_input:
clear_memory()
if self.low_gpu_mem_usage:
clear_memory()
q_outputs = self._get_block_outputs(
block,
input_ids,
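Note on the num_elm changes above: a minimal, self-contained sketch of the pattern, in plain PyTorch with hypothetical names (masked_mse_loss is not part of auto_round). The per-step loss is divided by the number of unmasked tokens so padded positions do not dilute the accumulated gradient, with a guard so an all-padding batch cannot divide by zero.

import torch

def masked_mse_loss(output: torch.Tensor, target: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: 1 for real tokens, 0 for padding; broadcast over the hidden dim
    mask = attention_mask.unsqueeze(-1).to(output.dtype)
    num_elm = int(mask.sum().item())
    if num_elm == 0:  # all-padding batch: avoid division by zero, as the diff does
        num_elm = 1
    return torch.sum(((output - target) * mask) ** 2) / num_elm

# usage sketch
out, ref = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
loss = masked_mse_loss(out, ref, mask)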
3 changes: 3 additions & 0 deletions auto_round/compressors/utils.py
@@ -377,6 +377,9 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
if lm_head_name not in layer_config and quant_lm_head:
layer_config[lm_head_name] = copy.deepcopy(default_dict)

if not quant_lm_head and not gguf_name:
layer_config.pop(lm_head_name, None)

# 8. enforce shape divisibility for int weight-only
if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16 and not gguf_name:
for n, m in model.named_modules():
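The utils.py change above removes any pre-seeded lm_head entry from the layer config when the head is not quantized and no GGUF export is requested, which is what the updated tests below assert. A small standalone sketch of that rule under assumed names (prune_lm_head_entry is hypothetical; only the pop logic mirrors the diff):

import copy
from typing import Optional

def prune_lm_head_entry(layer_config: dict, quant_lm_head: bool, gguf_name: Optional[str],
                        lm_head_name: str = "lm_head") -> dict:
    # Keep an lm_head entry only when the head is quantized or a GGUF format needs it.
    cfg = copy.deepcopy(layer_config)
    if not quant_lm_head and not gguf_name:
        cfg.pop(lm_head_name, None)
    return cfg

# usage sketch: the entry disappears from the exported extra_config unless quant_lm_head is set
cfg = {"lm_head": {"bits": 8}, "model.decoder.layers.1.self_attn.k_proj": {"bits": 8}}
assert "lm_head" not in prune_lm_head_entry(cfg, quant_lm_head=False, gguf_name=None)
assert "lm_head" in prune_lm_head_entry(cfg, quant_lm_head=True, gguf_name=None)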
41 changes: 5 additions & 36 deletions test/test_cpu/test_act_quantization.py
@@ -154,17 +154,8 @@ def test_act_config_MXFP4_saving(self):
quantized_model_path = self.save_dir
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
lmhead_config = model.config.quantization_config.extra_config["lm_head"]
assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "mx_fp_rceil"
assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 32
assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "mx_fp"
assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 32
assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
assert "lm_head" not in model.config.quantization_config.extra_config

# check inblock layer config values
kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"]
assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "mx_fp_rceil"
@@ -204,7 +195,7 @@ def test_act_config_NVFP4_saving(self):

def test_WOQ_config_INT_saving(self):
scheme = "W4A16"
layer_config = {"k_proj": {"bits": 8}} # "lm_head": {"bits": 4},
layer_config = {"k_proj": {"bits": 8}}
autoround = AutoRound(
self.model_name,
scheme=scheme,
@@ -218,18 +209,6 @@ def test_WOQ_config_INT_saving(self):
autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
extra_config = model.config.quantization_config.extra_config
# lmhead_config = extra_config["lm_head"]
# assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "float"
# assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 16
# assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 128
# assert "act_sym" in lmhead_config.keys() and not lmhead_config["act_sym"]
# assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "int"
# assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 4
# assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 128
# assert "sym" in lmhead_config.keys() and not lmhead_config["sym"]
# assert "act_dynamic" in lmhead_config.keys() and lmhead_config["act_dynamic"]
# assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
# assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None

# check inblock layer config values
kproj_config = extra_config["model.decoder.layers.1.self_attn.k_proj"]
@@ -270,18 +249,8 @@ def test_act_config_FP8_saving(self):
from transformers import AutoConfig

extra_config = AutoConfig.from_pretrained(quantized_model_path).quantization_config["extra_config"]
lmhead_config = extra_config["lm_head"]
assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "fp"
assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 0
assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "fp"
assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == -1
assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
assert "act_dynamic" in lmhead_config.keys() and not lmhead_config["act_dynamic"]
assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
assert "lm_head" not in extra_config

# check inblock layer config values
kproj_config = extra_config["model.decoder.layers.0.self_attn.k_proj"]
assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "fp"
16 changes: 14 additions & 2 deletions test/test_cpu/test_autoround.py
@@ -743,8 +743,20 @@ def test_invalid_layer_config(self):

def test_quant_lm_head(self):
model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
ar = AutoRound(model_name, quant_lm_head=True, iters=1, nsamples=1, seqlen=32)
ar.quantize()
ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True)
ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
assert "lm_head" in model.config.quantization_config.extra_config
assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4

def test_quant_lm_head_layer_config(self):
model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B"
layer_config = {"lm_head": {"bits": 4}}
ar = AutoRound(model_name, quant_lm_head=True, iters=0, disable_opt_rtn=True, layer_config=layer_config)
ar.quantize_and_save(output_dir=self.save_folder, format="auto_round")
model = AutoModelForCausalLM.from_pretrained(self.save_folder, device_map="cpu")
assert "lm_head" in model.config.quantization_config.extra_config
assert model.config.quantization_config.extra_config["lm_head"]["bits"] == 4

def test_compressor(self):
model_name = "Qwen/Qwen2-VL-2B-Instruct"