From 52358f8506d1805e636c48fb9db723df8ca3bc75 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 22 Oct 2025 15:20:33 +0800 Subject: [PATCH 1/7] support attention mask --- auto_round/compressors/base.py | 67 +++++++++++++++++++------------- test/test_cpu/test_autoround.py | 35 +++++++++++++++++ test/test_cuda/test_main_func.py | 18 +++++++++ 3 files changed, 93 insertions(+), 27 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index d948de65b..48cf154e8 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -806,21 +806,6 @@ def _check_compatibility(self) -> None: " We are likely to release new algorithm for certain configurations in the future." ) - # # Check group_size 32 for auto_round - # if ( - # self.data_type == "int" - # and hasattr(self, "formats") - # and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq")) - # ): - # for n, m in self.model.named_modules(): - # if type(m) in self.supported_types: - # if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: - # self.layer_config[n] = {"bits": 16} - # logger.info( - # f"{n} will not be quantized due to its shape not being divisible by 32," - # " resulting in an exporting issue to autogptq" - # ) - if ( self.seqlen is not None and hasattr(self.model, "config") @@ -1194,7 +1179,7 @@ def _quantize_embedding_layer(self): module.weight.to(self.device), **{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]}, ) - except RuntimeError as e: + except torch.OutOfMemoryError: cuda_error_msg = traceback.format_exc() try: logger.error(cuda_error_msg) @@ -1295,7 +1280,7 @@ def get_imatrix_hook(module, input, output): model = model.to("cpu") clear_memory() self._quantize_via_rtn_blockwise(all_to_quantized_module_names) - except RuntimeError as e: + except torch.OutOfMemoryError: cuda_error_msg = traceback.format_exc() try: logger.error(cuda_error_msg) @@ -1369,7 +1354,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None: ) m = m.unwrapper({}) m.to("cpu") - except RuntimeError as e: + except torch.OutOfMemoryError: cuda_error_msg = traceback.format_exc() m = m.orig_layer if hasattr(m, "orig_layer") else m try: @@ -1471,7 +1456,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: hook_handles = self._register_act_max_hook(self.model) try: self._quantize_via_rtn_blockwise(all_to_quantized_module_names) - except RuntimeError as e: + except torch.OutOfMemoryError: logger.warning("Fallback to CPU. 
Consider using more GPUs via `--device 0,1,2,3`.") self.model = self.model.to("cpu") clear_memory() @@ -1928,8 +1913,9 @@ def calib(self, nsamples, bs): bs (int): The number of samples to use for calibration """ from auto_round.calib_dataset import get_dataloader - + need_attention_mask= True if isinstance(self.dataset, str): + need_attention_mask = False # all supportted datasets does not use pad dataset = self.dataset.replace(" ", "") ##remove all whitespaces # slow here @@ -1966,9 +1952,11 @@ def calib(self, nsamples, bs): for key in data.keys(): data_new[key] = data[key].to(self.model.device) input_ids = data_new["input_ids"] + need_attention_mask = True elif isinstance(data, tuple) or isinstance(data, list): data_new = to_device(data) input_ids = data_new[0] + need_attention_mask=True else: data_new = {} for key in data.keys(): @@ -1998,6 +1986,19 @@ def calib(self, nsamples, bs): raise error except Exception as error: raise error + if not hasattr(self, "attention_mask") or self.attention_mask is None: + self.attention_mask=[] + if need_attention_mask: + new_attention_mask = [] + if isinstance(data_new, dict) and "attention_mask" in data_new and data_new["attention_mask"] is not None: + new_attention_mask= data_new["attention_mask"] + elif self.tokenizer is not None and hasattr(self.tokenizer,"pad_token"): + new_attention_mask=(input_ids != self.tokenizer.pad_token_id).to(torch.long) + else: + new_attention_mask = torch.ones_like(input_ids).to(torch.long) + + self.attention_mask.extend(list(torch.split(new_attention_mask, 1, dim=0))) + total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1 if total_cnt >= nsamples: break @@ -2078,7 +2079,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: accelerate.hooks.remove_hook_from_submodules(self.model) - except RuntimeError as e: + except torch.OutOfMemoryError: cuda_error_msg = traceback.format_exc() try: logger.info("switch to cpu to cache block inputs") @@ -2090,10 +2091,10 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1: accelerate.hooks.remove_hook_from_submodules( self.model - ) ##self.model.hf_device_map has not been changed + ) # self.model.hf_device_map has not been changed self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage) clear_memory() - ## Important change after v0.51, on cpu, we use rtn mode for layers in layer_names + # Important change after v0.51, on cpu, we use rtn mode for layers in layer_names all_inputs = self.cache_inter_data( block_names, nsamples, layer_names=[], last_cache_name=last_cache_name ) @@ -2405,15 +2406,21 @@ def _quantize_layer( org_input = current_input with torch.no_grad(): current_output = layer(org_input) + if self.attention_mask: + tmp_attention_mask = [self.attention_mask[i] for i in indices] + tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) + tmp_attention_mask.unsqueeze_(-1) + else: + tmp_attention_mask = 1.0 if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): output_q = wrapper_linear(current_input) # pylint: disable=not-callable - loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + loss = mse_loss(output_q * tmp_attention_mask, current_output*tmp_attention_mask) # pylint: disable=not-callable else: output_q = wrapper_linear(current_input) # pylint: disable=not-callable 
loss = mse_loss( # pylint: disable=not-callable - output_q.to(torch.float32), current_output.to(torch.float32) + output_q.to(torch.float32)*tmp_attention_mask, current_output.to(torch.float32)*tmp_attention_mask ) total_loss += loss.item() / num_elm @@ -2678,12 +2685,18 @@ def _quantize_block( current_output = to_device(current_output, device) output_q = self._get_current_q_output(block, input_ids, input_others, indices, device) + if self.attention_mask: + tmp_attention_mask = [self.attention_mask[i] for i in indices] + tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) + tmp_attention_mask.unsqueeze_(-1) + else: + tmp_attention_mask=1.0 if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): - loss = mse_loss(output_q, current_output) # pylint: disable=not-callable + loss = mse_loss(output_q *tmp_attention_mask, current_output *tmp_attention_mask) # pylint: disable=not-callable else: loss = mse_loss( # pylint: disable=not-callable - output_q.to(torch.float32), current_output.to(torch.float32) + output_q.to(torch.float32)*tmp_attention_mask, current_output.to(torch.float32)*tmp_attention_mask ) total_loss += loss.item() / num_elm diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index cbd0583df..6c66ca56d 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -755,6 +755,41 @@ def test_compressor(self): ar = AutoRoundMLLM(model_name) self.assertTrue(ar.mllm) + def test_attention_mask_in_dataset(self): + from transformers import AutoTokenizer + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" + # model_name = "/models/Qwen3-0.6B" + tokenizer = AutoTokenizer.from_pretrained(model_name) + text = ["haha", "hello world"] + res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) + data = [res.data] + + text = ["qudd", "hfd"] + res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) + data.append(res.data) + from auto_round import AutoRound + ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8) + ar.quantize() + + def test_attention_mask_via_tokenize_in_dataset(self): + from transformers import AutoTokenizer + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" + # model_name = "/models/Qwen3-0.6B" + tokenizer = AutoTokenizer.from_pretrained(model_name) + text = ["haha", "hello world"] + res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) + res.data.pop("attention_mask") + data = [res.data] + + text = ["qudd", "hfd"] + res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) + res.data.pop("attention_mask") + data.append(res.data) + from auto_round import AutoRound + ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8) + ar.quantize() + + if __name__ == "__main__": unittest.main() diff --git a/test/test_cuda/test_main_func.py b/test/test_cuda/test_main_func.py index 7347118bc..c1bc651d6 100644 --- a/test/test_cuda/test_main_func.py +++ b/test/test_cuda/test_main_func.py @@ -179,6 +179,24 @@ def test_autoround_asym(self): ##need to install false assert accuracy > 0.35 shutil.rmtree("./saved", ignore_errors=True) + def test_attention_mask_lm_head(self): + from transformers import AutoTokenizer + model_name = "/models/Qwen3-8B" + # model_name = "/models/Qwen3-0.6B" + tokenizer = AutoTokenizer.from_pretrained(model_name) + text = ["haha", "hello world"] + res = tokenizer(text, return_tensors="pt", 
max_length=8, padding="max_length", truncation=True) + res.data.pop("attention_mask") + data = [res.data] + + text = ["qudd", "hfd"] + res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) + res.data.pop("attention_mask") + data.append(res.data) + from auto_round import AutoRound + ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8,quant_lm_head=True) + ar.quantize() + if __name__ == "__main__": unittest.main() From 00421a37191a5ccca21b91908343362a2155176f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Oct 2025 07:22:26 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 39 ++++++++++++++++++++------------ test/test_cpu/test_autoround.py | 5 +++- test/test_cuda/test_main_func.py | 4 +++- 3 files changed, 32 insertions(+), 16 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 48cf154e8..9751bb345 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1179,7 +1179,7 @@ def _quantize_embedding_layer(self): module.weight.to(self.device), **{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]}, ) - except torch.OutOfMemoryError: + except torch.OutOfMemoryError: cuda_error_msg = traceback.format_exc() try: logger.error(cuda_error_msg) @@ -1913,9 +1913,10 @@ def calib(self, nsamples, bs): bs (int): The number of samples to use for calibration """ from auto_round.calib_dataset import get_dataloader - need_attention_mask= True + + need_attention_mask = True if isinstance(self.dataset, str): - need_attention_mask = False # all supportted datasets does not use pad + need_attention_mask = False # all supported datasets does not use pad dataset = self.dataset.replace(" ", "") ##remove all whitespaces # slow here @@ -1956,7 +1957,7 @@ def calib(self, nsamples, bs): elif isinstance(data, tuple) or isinstance(data, list): data_new = to_device(data) input_ids = data_new[0] - need_attention_mask=True + need_attention_mask = True else: data_new = {} for key in data.keys(): @@ -1987,13 +1988,17 @@ def calib(self, nsamples, bs): except Exception as error: raise error if not hasattr(self, "attention_mask") or self.attention_mask is None: - self.attention_mask=[] + self.attention_mask = [] if need_attention_mask: new_attention_mask = [] - if isinstance(data_new, dict) and "attention_mask" in data_new and data_new["attention_mask"] is not None: - new_attention_mask= data_new["attention_mask"] - elif self.tokenizer is not None and hasattr(self.tokenizer,"pad_token"): - new_attention_mask=(input_ids != self.tokenizer.pad_token_id).to(torch.long) + if ( + isinstance(data_new, dict) + and "attention_mask" in data_new + and data_new["attention_mask"] is not None + ): + new_attention_mask = data_new["attention_mask"] + elif self.tokenizer is not None and hasattr(self.tokenizer, "pad_token"): + new_attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.long) else: new_attention_mask = torch.ones_like(input_ids).to(torch.long) @@ -2416,11 +2421,14 @@ def _quantize_layer( if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): output_q = wrapper_linear(current_input) # pylint: disable=not-callable - loss = mse_loss(output_q * tmp_attention_mask, current_output*tmp_attention_mask) # pylint: disable=not-callable + loss = mse_loss( + 
output_q * tmp_attention_mask, current_output * tmp_attention_mask + ) # pylint: disable=not-callable else: output_q = wrapper_linear(current_input) # pylint: disable=not-callable loss = mse_loss( # pylint: disable=not-callable - output_q.to(torch.float32)*tmp_attention_mask, current_output.to(torch.float32)*tmp_attention_mask + output_q.to(torch.float32) * tmp_attention_mask, + current_output.to(torch.float32) * tmp_attention_mask, ) total_loss += loss.item() / num_elm @@ -2690,13 +2698,16 @@ def _quantize_block( tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device) tmp_attention_mask.unsqueeze_(-1) else: - tmp_attention_mask=1.0 + tmp_attention_mask = 1.0 if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): - loss = mse_loss(output_q *tmp_attention_mask, current_output *tmp_attention_mask) # pylint: disable=not-callable + loss = mse_loss( + output_q * tmp_attention_mask, current_output * tmp_attention_mask + ) # pylint: disable=not-callable else: loss = mse_loss( # pylint: disable=not-callable - output_q.to(torch.float32)*tmp_attention_mask, current_output.to(torch.float32)*tmp_attention_mask + output_q.to(torch.float32) * tmp_attention_mask, + current_output.to(torch.float32) * tmp_attention_mask, ) total_loss += loss.item() / num_elm diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 6c66ca56d..ee801048b 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -757,6 +757,7 @@ def test_compressor(self): def test_attention_mask_in_dataset(self): from transformers import AutoTokenizer + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" # model_name = "/models/Qwen3-0.6B" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -768,11 +769,13 @@ def test_attention_mask_in_dataset(self): res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True) data.append(res.data) from auto_round import AutoRound + ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8) ar.quantize() def test_attention_mask_via_tokenize_in_dataset(self): from transformers import AutoTokenizer + model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B" # model_name = "/models/Qwen3-0.6B" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -786,10 +789,10 @@ def test_attention_mask_via_tokenize_in_dataset(self): res.data.pop("attention_mask") data.append(res.data) from auto_round import AutoRound + ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8) ar.quantize() - if __name__ == "__main__": unittest.main() diff --git a/test/test_cuda/test_main_func.py b/test/test_cuda/test_main_func.py index c1bc651d6..70b85b0aa 100644 --- a/test/test_cuda/test_main_func.py +++ b/test/test_cuda/test_main_func.py @@ -181,6 +181,7 @@ def test_autoround_asym(self): ##need to install false def test_attention_mask_lm_head(self): from transformers import AutoTokenizer + model_name = "/models/Qwen3-8B" # model_name = "/models/Qwen3-0.6B" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -194,7 +195,8 @@ def test_attention_mask_lm_head(self): res.data.pop("attention_mask") data.append(res.data) from auto_round import AutoRound - ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8,quant_lm_head=True) + + ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8, quant_lm_head=True) ar.quantize() From f34d268dcd3a43db47e490160596bc2e5f24217f Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 22 Oct 2025 15:37:20 +0800 Subject: [PATCH 3/7] fix preci --- 
auto_round/compressors/base.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 9751bb345..38fb63cda 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1953,11 +1953,9 @@ def calib(self, nsamples, bs): for key in data.keys(): data_new[key] = data[key].to(self.model.device) input_ids = data_new["input_ids"] - need_attention_mask = True elif isinstance(data, tuple) or isinstance(data, list): data_new = to_device(data) input_ids = data_new[0] - need_attention_mask = True else: data_new = {} for key in data.keys(): @@ -2421,9 +2419,9 @@ def _quantize_layer( if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): output_q = wrapper_linear(current_input) # pylint: disable=not-callable - loss = mse_loss( + loss = mse_loss( # pylint: disable=not-callable output_q * tmp_attention_mask, current_output * tmp_attention_mask - ) # pylint: disable=not-callable + ) else: output_q = wrapper_linear(current_input) # pylint: disable=not-callable loss = mse_loss( # pylint: disable=not-callable @@ -2701,9 +2699,9 @@ def _quantize_block( tmp_attention_mask = 1.0 if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): - loss = mse_loss( + loss = mse_loss( # pylint: disable=not-callable output_q * tmp_attention_mask, current_output * tmp_attention_mask - ) # pylint: disable=not-callable + ) else: loss = mse_loss( # pylint: disable=not-callable output_q.to(torch.float32) * tmp_attention_mask, From 87ffc1230a0f11d90505d5d9f9174eeda749fdee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Oct 2025 07:38:02 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 38fb63cda..e06c33abc 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -2419,7 +2419,7 @@ def _quantize_layer( if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): output_q = wrapper_linear(current_input) # pylint: disable=not-callable - loss = mse_loss( # pylint: disable=not-callable + loss = mse_loss( # pylint: disable=not-callable output_q * tmp_attention_mask, current_output * tmp_attention_mask ) else: @@ -2699,7 +2699,7 @@ def _quantize_block( tmp_attention_mask = 1.0 if self.amp: with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype): - loss = mse_loss( # pylint: disable=not-callable + loss = mse_loss( # pylint: disable=not-callable output_q * tmp_attention_mask, current_output * tmp_attention_mask ) else: From 6aaea6c60bb7c8ecbe44f7b509a2a9180127325a Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Wed, 22 Oct 2025 16:06:10 +0800 Subject: [PATCH 5/7] fix bug --- auto_round/compressors/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index e06c33abc..e7f69c1aa 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -401,6 +401,8 @@ def __init__( import habana_frameworks.torch.core as htcore # pylint: disable=E0401 import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] + self.attention_mask = [] + def _gen_auto_scheme( self, model: 
torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device] ) -> dict[str, dict]: @@ -1985,8 +1987,6 @@ def calib(self, nsamples, bs): raise error except Exception as error: raise error - if not hasattr(self, "attention_mask") or self.attention_mask is None: - self.attention_mask = [] if need_attention_mask: new_attention_mask = [] if ( From b873508e3258e17bbdc5f3ebdd6c5aad25b9bf7d Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Thu, 23 Oct 2025 13:48:32 +0800 Subject: [PATCH 6/7] update --- auto_round/compressors/base.py | 22 +++++- test/test_cpu/test_low_cpu_mem.py | 122 ------------------------------ 2 files changed, 19 insertions(+), 125 deletions(-) delete mode 100644 test/test_cpu/test_low_cpu_mem.py diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index e7f69c1aa..2d1bf1630 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1988,17 +1988,33 @@ def calib(self, nsamples, bs): except Exception as error: raise error if need_attention_mask: - new_attention_mask = [] if ( isinstance(data_new, dict) and "attention_mask" in data_new and data_new["attention_mask"] is not None ): new_attention_mask = data_new["attention_mask"] - elif self.tokenizer is not None and hasattr(self.tokenizer, "pad_token"): + elif self.tokenizer is not None and hasattr(self.tokenizer, "pad_token") and self.tokenizer.pad_token is not None: new_attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.long) else: - new_attention_mask = torch.ones_like(input_ids).to(torch.long) + # Default all ones + new_attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + # For each sample, check if there are trailing repeated tokens + # If so, set the mask of the last token to 0 + batch_size, seq_len = input_ids.shape + for i in range(batch_size): + last_token = input_ids[i, -1] + # Check for trailing repeats + j = seq_len - 2 + repeated = False + while j >= 0 and input_ids[i, j] == last_token: + repeated = True + new_attention_mask[i, j] = 0 + j -= 1 + # If there was at least one repeat, set last token mask to 0 + if repeated: + new_attention_mask[i, -1] = 0 self.attention_mask.extend(list(torch.split(new_attention_mask, 1, dim=0))) diff --git a/test/test_cpu/test_low_cpu_mem.py b/test/test_cpu/test_low_cpu_mem.py deleted file mode 100644 index 582b5e47b..000000000 --- a/test/test_cpu/test_low_cpu_mem.py +++ /dev/null @@ -1,122 +0,0 @@ -import os -import shutil -import sys -import unittest - -sys.path.insert(0, "../..") - -import torch -from transformers import AutoModelForCausalLM, AutoTokenizer - -from auto_round import AutoRound -from auto_round.low_cpu_mem.utils import ( - get_layers_before_block, - layer_wise_load, - layer_wise_save, - load_empty_model, - load_model_with_hooks, -) - - -class LLMDataLoader: - def __init__(self): - self.batch_size = 1 - - def __iter__(self): - for i in range(2): - yield torch.ones([1, 10], dtype=torch.long) - - -class TestLowCPUMem(unittest.TestCase): - @classmethod - def setUpClass(self): - self.model_name = "/tf_dataset/auto_round/models/facebook/opt-125m" - self.saved_path = "./test_tmp_saved" - self.ori_model = AutoModelForCausalLM.from_pretrained(self.model_name, trust_remote_code=True) - self.model = load_model_with_hooks( - self.model_name, AutoModelForCausalLM, saved_path=self.saved_path, device="cpu" - ) - self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) - self.llm_dataloader = LLMDataLoader() - - @classmethod - def 
tearDownClass(self): - shutil.rmtree(self.saved_path, ignore_errors=True) - - def test_default(self): - self.assertTrue(self.model.device.type, "meta") - - # TODO: change this func - # layers = get_layers_before_block(self.model) - # self.assertEqual(layers[0][0], "model.decoder.embed_tokens") - - # test get_weight bias - self.assertTrue( - torch.equal( - self.model.model.decoder.layers[0].self_attn.k_proj.get_weight(), - self.ori_model.model.decoder.layers[0].self_attn.k_proj.weight, - ) - ) - self.assertTrue( - torch.equal( - self.model.model.decoder.layers[0].self_attn.k_proj.get_bias(), - self.ori_model.model.decoder.layers[0].self_attn.k_proj.bias, - ) - ) - - # test hooks - text = ["Hello, my dog is cute"] - input = self.tokenizer(text) - for key in input: - input[key] = torch.tensor(input[key]) - ori_output = self.ori_model.generate(**input, max_new_tokens=5, do_sample=False) - ori_result = self.tokenizer.decode(ori_output[0]) - print(ori_result) - self.model.to("cpu") - output = self.model.generate(**input, max_new_tokens=5, do_sample=False) - result = self.tokenizer.decode(output[0]) - print(result) - self.assertEqual(ori_result, result) - self.model.to("meta") - - # test save and load - layer_wise_save(self.model, self.saved_path) - state_dict = layer_wise_load(self.saved_path) - self.assertTrue(torch.equal(state_dict["lm_head.weight"], self.ori_model.lm_head.weight)) - - # test layer-wise auto_round - bits, group_size, sym = 4, 128, False - autoround = AutoRound( - self.model, - self.tokenizer, - device="cpu", - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - enable_torch_compile=False, - ) - autoround.quantize() - - # test block-wise auto_round - self.model = load_empty_model(self.model_name, AutoModelForCausalLM, saved_path=self.saved_path, device="cpu") - bits, group_size, sym = 4, 128, False - autoround = AutoRound( - self.model, - self.tokenizer, - device="cpu", - bits=bits, - group_size=group_size, - sym=sym, - iters=2, - seqlen=2, - dataset=self.llm_dataloader, - low_cpu_mem_usage=True, - ) - autoround.quantize() - - -if __name__ == "__main__": - unittest.main() From abbe8b9a0a527f81b1020a7fe6472418636838ef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 23 Oct 2025 05:49:17 +0000 Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 2d1bf1630..e79b044a6 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -1994,7 +1994,11 @@ def calib(self, nsamples, bs): and data_new["attention_mask"] is not None ): new_attention_mask = data_new["attention_mask"] - elif self.tokenizer is not None and hasattr(self.tokenizer, "pad_token") and self.tokenizer.pad_token is not None: + elif ( + self.tokenizer is not None + and hasattr(self.tokenizer, "pad_token") + and self.tokenizer.pad_token is not None + ): new_attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.long) else: # Default all ones
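The series above boils down to two ideas: build a per-sample attention mask during calibration (taken from the dataset's own attention_mask, derived from the tokenizer's pad token, or all-ones as a fallback), then weight the layer/block reconstruction MSE by that mask so padded positions do not steer the rounding optimization. Below is a minimal standalone sketch of that masked loss, for illustration only; it does not use the auto_round API, and build_attention_mask, masked_mse, and the pad id 0 are hypothetical names chosen for this example.

# Illustrative sketch only (not the auto_round implementation): mask padded
# positions out of the calibration reconstruction loss.
from typing import Optional

import torch
import torch.nn.functional as F


def build_attention_mask(input_ids: torch.Tensor, pad_token_id: Optional[int]) -> torch.Tensor:
    """Return a [batch, seq_len] mask with 1 for real tokens and 0 for padding."""
    if pad_token_id is None:
        # No pad token known: fall back to an all-ones mask.
        return torch.ones_like(input_ids, dtype=torch.long)
    return (input_ids != pad_token_id).long()


def masked_mse(output_q: torch.Tensor, output_fp: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """MSE between quantized and float block outputs with padded positions zeroed.

    output_q / output_fp: [batch, seq_len, hidden]; mask: [batch, seq_len].
    """
    # Broadcast the mask over the hidden dimension, matching the dtype of the outputs.
    m = mask.unsqueeze(-1).to(output_q.dtype)
    return F.mse_loss(output_q * m, output_fp * m)


if __name__ == "__main__":
    ids = torch.tensor([[5, 7, 9, 0, 0]])  # 0 plays the role of the pad token here
    mask = build_attention_mask(ids, pad_token_id=0)
    q = torch.randn(1, 5, 4)
    fp = q + 0.1 * torch.randn(1, 5, 4)
    print(masked_mse(q, fp, mask))

Note that, mirroring the diff, the zeroed pad positions still count toward the mean (F.mse_loss averages over all elements), so the mask suppresses their contribution to the gradient rather than renormalizing the loss over valid tokens.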