diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index d6df02bc1..cdaf01f96 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -397,6 +397,8 @@ def __init__(
             import habana_frameworks.torch.core as htcore  # pylint: disable=E0401
             import habana_frameworks.torch.hpu as hthpu  # pylint: disable=E0401
 
+        self.attention_mask = []
+
     def _gen_auto_scheme(
         self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
     ) -> dict[str, dict]:
@@ -797,21 +799,6 @@ def _check_compatibility(self) -> None:
                 " We are likely to release new algorithm for certain configurations in the future."
             )
 
-        # # Check group_size 32 for auto_round
-        # if (
-        #     self.data_type == "int"
-        #     and hasattr(self, "formats")
-        #     and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq"))
-        # ):
-        #     for n, m in self.model.named_modules():
-        #         if type(m) in self.supported_types:
-        #             if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
-        #                 self.layer_config[n] = {"bits": 16}
-        #                 logger.info(
-        #                     f"{n} will not be quantized due to its shape not being divisible by 32,"
-        #                     " resulting in an exporting issue to autogptq"
-        #                 )
-
         if (
             self.seqlen is not None
             and hasattr(self.model, "config")
@@ -1185,7 +1172,7 @@ def _quantize_embedding_layer(self):
                     module.weight.to(self.device),
                     **{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]},
                 )
-            except RuntimeError as e:
+            except torch.OutOfMemoryError:
                 cuda_error_msg = traceback.format_exc()
                 try:
                     logger.error(cuda_error_msg)
@@ -1286,7 +1273,7 @@ def get_imatrix_hook(module, input, output):
                 model = model.to("cpu")
                 clear_memory()
                 self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-            except RuntimeError as e:
+            except torch.OutOfMemoryError:
                 cuda_error_msg = traceback.format_exc()
                 try:
                     logger.error(cuda_error_msg)
@@ -1360,7 +1347,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 )
                 m = m.unwrapper({})
                 m.to("cpu")
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             m = m.orig_layer if hasattr(m, "orig_layer") else m
             try:
@@ -1462,7 +1449,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         hook_handles = self._register_act_max_hook(self.model)
        try:
            self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
            logger.warning("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
            self.model = self.model.to("cpu")
            clear_memory()
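The hunks above narrow the fallback handling from a blanket `except RuntimeError` to `except torch.OutOfMemoryError`, so only genuine accelerator out-of-memory failures trigger the CPU path. A minimal sketch of that pattern, assuming PyTorch 2.5+ (which exposes `torch.OutOfMemoryError`); `run_on_device` is a hypothetical stand-in for the blockwise RTN quantization call and is not part of the patch:

import traceback

import torch


def quantize_with_cpu_fallback(run_on_device, model, device="cuda"):
    """Try the work on the accelerator; on OOM, move the model to CPU and retry.

    `run_on_device` is a hypothetical callable standing in for the blockwise RTN
    quantization in base.py; only the control flow mirrors the patch.
    """
    try:
        return run_on_device(model.to(device))
    except torch.OutOfMemoryError:  # narrower than RuntimeError: other errors now propagate
        print(traceback.format_exc())
        print("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
        model = model.to("cpu")
        torch.cuda.empty_cache()  # stands in for auto_round's clear_memory()
        return run_on_device(model)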
@@ -1920,7 +1907,9 @@ def calib(self, nsamples, bs):
         """
         from auto_round.calib_dataset import get_dataloader
 
+        need_attention_mask = True
         if isinstance(self.dataset, str):
+            need_attention_mask = False  # none of the built-in datasets use padding
             dataset = self.dataset.replace(" ", "")  ##remove all whitespaces
 
             # slow here
@@ -1983,6 +1972,41 @@ def calib(self, nsamples, bs):
                     raise error
             except Exception as error:
                 raise error
+            if need_attention_mask:
+                if (
+                    isinstance(data_new, dict)
+                    and "attention_mask" in data_new
+                    and data_new["attention_mask"] is not None
+                ):
+                    new_attention_mask = data_new["attention_mask"]
+                elif (
+                    self.tokenizer is not None
+                    and hasattr(self.tokenizer, "pad_token")
+                    and self.tokenizer.pad_token is not None
+                ):
+                    new_attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.long)
+                else:
+                    # Default to all ones
+                    new_attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+
+                # For each sample, check whether there are trailing repeated tokens.
+                # If so, treat them as padding and set their mask to 0.
+                batch_size, seq_len = input_ids.shape
+                for i in range(batch_size):
+                    last_token = input_ids[i, -1]
+                    # Walk backwards over trailing repeats
+                    j = seq_len - 2
+                    repeated = False
+                    while j >= 0 and input_ids[i, j] == last_token:
+                        repeated = True
+                        new_attention_mask[i, j] = 0
+                        j -= 1
+                    # If there was at least one repeat, the last token is padding as well
+                    if repeated:
+                        new_attention_mask[i, -1] = 0
+
+                self.attention_mask.extend(list(torch.split(new_attention_mask, 1, dim=0)))
+
             total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1
             if total_cnt >= nsamples:
                 break
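When the calibration dataset is user-supplied (not one of the built-in string datasets), `calib` now records one attention mask per sample: it prefers the mask shipped with the batch, then a `pad_token_id` comparison, and finally a heuristic that treats a trailing run of identical tokens as right padding. A standalone sketch of that last heuristic; the function name and toy token ids are illustrative, not part of the patch:

import torch


def infer_attention_mask(input_ids: torch.Tensor) -> torch.Tensor:
    """Mirror the patch's fallback: mask a trailing run of repeated tokens as padding."""
    mask = torch.ones_like(input_ids, dtype=torch.long)
    batch_size, seq_len = input_ids.shape
    for i in range(batch_size):
        last_token = input_ids[i, -1]
        j = seq_len - 2
        repeated = False
        while j >= 0 and input_ids[i, j] == last_token:
            repeated = True
            mask[i, j] = 0
            j -= 1
        if repeated:  # the repeated last token itself is padding too
            mask[i, -1] = 0
    return mask


# Toy right-padded batch; 151643 plays the role of a pad/eos id (illustrative value).
batch = torch.tensor([[101, 102, 103, 151643, 151643], [101, 102, 103, 104, 105]])
print(infer_attention_mask(batch))
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])

The resulting masks are stored in `self.attention_mask` as one `(1, seq_len)` tensor per sample (via `torch.split`), so the tuning loops below can index them by sample.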
@@ -2058,7 +2082,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
 
             if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
                 accelerate.hooks.remove_hook_from_submodules(self.model)
-        except RuntimeError as e:
+        except torch.OutOfMemoryError:
             cuda_error_msg = traceback.format_exc()
             try:
                 logger.info("switch to cpu to cache block inputs")
@@ -2070,10 +2094,10 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                 if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
                     accelerate.hooks.remove_hook_from_submodules(
                         self.model
-                    )  ##self.model.hf_device_map has not been changed
+                    )  # self.model.hf_device_map has not been changed
                 self.model = mv_module_from_gpu(self.model)
                 clear_memory()
-                ## Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
+                # Important change after v0.51: on CPU, we use RTN mode for layers in layer_names
                 all_inputs = self.cache_inter_data(
                     block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
                 )
@@ -2385,15 +2409,24 @@ def _quantize_layer(
                 org_input = current_input
             with torch.no_grad():
                 current_output = layer(org_input)
+            if self.attention_mask:
+                tmp_attention_mask = [self.attention_mask[i] for i in indices]
+                tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                tmp_attention_mask.unsqueeze_(-1)
+            else:
+                tmp_attention_mask = 1.0
 
             if self.amp:
                 with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
                     output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
-                    loss = mse_loss(output_q, current_output)  # pylint: disable=not-callable
+                    loss = mse_loss(  # pylint: disable=not-callable
+                        output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                    )
             else:
                 output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
                 loss = mse_loss(  # pylint: disable=not-callable
-                    output_q.to(torch.float32), current_output.to(torch.float32)
+                    output_q.to(torch.float32) * tmp_attention_mask,
+                    current_output.to(torch.float32) * tmp_attention_mask,
                 )
             total_loss += loss.item() / num_elm
@@ -2662,12 +2695,21 @@ def _quantize_block(
                 current_output = to_device(current_output, device)
 
             output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
+            if self.attention_mask:
+                tmp_attention_mask = [self.attention_mask[i] for i in indices]
+                tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
+                tmp_attention_mask.unsqueeze_(-1)
+            else:
+                tmp_attention_mask = 1.0
             if self.amp:
                 with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
-                    loss = mse_loss(output_q, current_output)  # pylint: disable=not-callable
+                    loss = mse_loss(  # pylint: disable=not-callable
+                        output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                    )
             else:
                 loss = mse_loss(  # pylint: disable=not-callable
-                    output_q.to(torch.float32), current_output.to(torch.float32)
+                    output_q.to(torch.float32) * tmp_attention_mask,
+                    current_output.to(torch.float32) * tmp_attention_mask,
                 )
 
             total_loss += loss.item() / num_elm
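Both loss sites (`_quantize_layer` and `_quantize_block`) consume the cached masks the same way: gather by `indices`, concatenate along the batch dimension, and unsqueeze so the `(batch, seq_len, 1)` mask broadcasts over the hidden dimension, zeroing padded positions on both sides of the MSE. A minimal sketch with dummy tensors; all shapes and values are illustrative:

import torch
import torch.nn.functional as F

batch, seq_len, hidden = 2, 5, 8
output_q = torch.randn(batch, seq_len, hidden)        # quantized block/layer output
current_output = torch.randn(batch, seq_len, hidden)  # reference output
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Same broadcasting as the patch: (batch, seq_len) -> (batch, seq_len, 1).
tmp_attention_mask = attention_mask.unsqueeze(-1)

# Padded positions are zeroed on both sides, so they add nothing to the error term.
loss = F.mse_loss(output_q * tmp_attention_mask, current_output * tmp_attention_mask)
print(loss)

Note that the normalization (`num_elm` in the patch) is unchanged, so masked positions still count toward the denominator; the mask only removes their contribution to the error itself.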
"hello world"] + res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) + res.data.pop("attention_mask") + data = [res.data] + + text = ["qudd", "hfd"] + res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) + res.data.pop("attention_mask") + data.append(res.data) + from auto_round import AutoRound + + ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8, quant_lm_head=True) + ar.quantize() + if __name__ == "__main__": unittest.main()