94 changes: 68 additions & 26 deletions auto_round/compressors/base.py
@@ -397,6 +397,8 @@ def __init__(
import habana_frameworks.torch.core as htcore # pylint: disable=E0401
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401

self.attention_mask = []

def _gen_auto_scheme(
self, model: torch.nn.Module, scheme: AutoScheme, dataset: str, device_map: Union[str, int, dict, torch.device]
) -> dict[str, dict]:
@@ -797,21 +799,6 @@ def _check_compatibility(self) -> None:
" We are likely to release new algorithm for certain configurations in the future."
)

# # Check group_size 32 for auto_round
# if (
# self.data_type == "int"
# and hasattr(self, "formats")
# and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq"))
# ):
# for n, m in self.model.named_modules():
# if type(m) in self.supported_types:
# if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
# self.layer_config[n] = {"bits": 16}
# logger.info(
# f"{n} will not be quantized due to its shape not being divisible by 32,"
# " resulting in an exporting issue to autogptq"
# )

if (
self.seqlen is not None
and hasattr(self.model, "config")
@@ -1185,7 +1172,7 @@ def _quantize_embedding_layer(self):
module.weight.to(self.device),
**{k: config[k] for k in ["bits", "group_size", "super_bits", "super_group_size", "scale_dtype"]},
)
except RuntimeError as e:
except torch.OutOfMemoryError:
cuda_error_msg = traceback.format_exc()
try:
logger.error(cuda_error_msg)
@@ -1286,7 +1273,7 @@ def get_imatrix_hook(module, input, output):
model = model.to("cpu")
clear_memory()
self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
except RuntimeError as e:
except torch.OutOfMemoryError:
cuda_error_msg = traceback.format_exc()
try:
logger.error(cuda_error_msg)
@@ -1360,7 +1347,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
)
m = m.unwrapper({})
m.to("cpu")
except RuntimeError as e:
except torch.OutOfMemoryError:
cuda_error_msg = traceback.format_exc()
m = m.orig_layer if hasattr(m, "orig_layer") else m
try:
@@ -1462,7 +1449,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
hook_handles = self._register_act_max_hook(self.model)
try:
self._quantize_via_rtn_blockwise(all_to_quantized_module_names)
except RuntimeError as e:
except torch.OutOfMemoryError:
logger.warning("Fallback to CPU. Consider using more GPUs via `--device 0,1,2,3`.")
self.model = self.model.to("cpu")
clear_memory()
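The hunks in this file also narrow several broad `except RuntimeError` handlers to `torch.OutOfMemoryError`, so only genuine out-of-memory failures trigger the CPU fallback. A minimal sketch of that pattern, assuming a recent PyTorch release that exposes the device-agnostic `torch.OutOfMemoryError` (older releases only provide `torch.cuda.OutOfMemoryError`); `quantize_on_device` is a hypothetical callable and `torch.cuda.empty_cache()` stands in for auto_round's internal `clear_memory()`:

```python
import torch


def quantize_with_cpu_fallback(model, quantize_on_device):
    """Sketch only: try the accelerator first, fall back to CPU when memory runs out."""
    try:
        return quantize_on_device(model)
    except torch.OutOfMemoryError:  # narrower than RuntimeError: only a real OOM triggers the fallback
        model = model.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # stand-in for auto_round's clear_memory()
        return quantize_on_device(model)
```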
@@ -1920,7 +1907,9 @@ def calib(self, nsamples, bs):
"""
from auto_round.calib_dataset import get_dataloader

need_attention_mask = True
if isinstance(self.dataset, str):
need_attention_mask = False # none of the supported built-in datasets uses padding
dataset = self.dataset.replace(" ", "") ##remove all whitespaces

# slow here
@@ -1983,6 +1972,41 @@ def calib(self, nsamples, bs):
raise error
except Exception as error:
raise error
if need_attention_mask:
if (
isinstance(data_new, dict)
and "attention_mask" in data_new
and data_new["attention_mask"] is not None
):
new_attention_mask = data_new["attention_mask"]
elif (
self.tokenizer is not None
and hasattr(self.tokenizer, "pad_token")
and self.tokenizer.pad_token is not None
):
new_attention_mask = (input_ids != self.tokenizer.pad_token_id).to(torch.long)
else:
# Default all ones
new_attention_mask = torch.ones_like(input_ids, dtype=torch.long)

# For each sample, detect a trailing run of tokens identical to the last token (likely padding)
# and zero out the mask over that whole run, including the last token itself
batch_size, seq_len = input_ids.shape
for i in range(batch_size):
last_token = input_ids[i, -1]
# Check for trailing repeats
j = seq_len - 2
repeated = False
while j >= 0 and input_ids[i, j] == last_token:
repeated = True
new_attention_mask[i, j] = 0
j -= 1
# If there was at least one repeat, set last token mask to 0
if repeated:
new_attention_mask[i, -1] = 0

self.attention_mask.extend(list(torch.split(new_attention_mask, 1, dim=0)))

total_cnt += input_ids.shape[0] if len(input_ids.shape) > 1 else 1
if total_cnt >= nsamples:
break
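The new calibration logic above stores one attention mask per sample in `self.attention_mask`. A standalone sketch of the same idea, assuming the hunk's fallback order (use the dataset's own mask if present, else derive one from the tokenizer's pad token, else default to all ones) and its trailing-repeat heuristic; `build_attention_mask` is an illustrative helper, not an auto_round API:

```python
import torch


def build_attention_mask(input_ids, pad_token_id=None):
    """Sketch: build a (batch, seq_len) 0/1 mask for calibration samples."""
    if pad_token_id is not None:
        mask = (input_ids != pad_token_id).to(torch.long)
    else:
        mask = torch.ones_like(input_ids, dtype=torch.long)

    batch_size, seq_len = input_ids.shape
    for i in range(batch_size):
        last_token = input_ids[i, -1]
        j = seq_len - 2
        repeated = False
        # Walk backwards over a trailing run of tokens identical to the last one.
        while j >= 0 and input_ids[i, j] == last_token:
            repeated = True
            mask[i, j] = 0
            j -= 1
        if repeated:  # the final token belongs to the repeated run as well
            mask[i, -1] = 0
    return mask


# Example: a right-padded sample keeps zeros over the padded tail.
ids = torch.tensor([[11, 12, 13, 0, 0, 0], [21, 22, 23, 24, 25, 26]])
print(build_attention_mask(ids, pad_token_id=0))
# tensor([[1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 1, 1]])
```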
@@ -2058,7 +2082,7 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(self.model)

except RuntimeError as e:
except torch.OutOfMemoryError:
cuda_error_msg = traceback.format_exc()
try:
logger.info("switch to cpu to cache block inputs")
@@ -2070,10 +2094,10 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1:
accelerate.hooks.remove_hook_from_submodules(
self.model
) ##self.model.hf_device_map has not been changed
) # self.model.hf_device_map has not been changed
self.model = mv_module_from_gpu(self.model)
clear_memory()
## Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
# Important change after v0.51, on cpu, we use rtn mode for layers in layer_names
all_inputs = self.cache_inter_data(
block_names, nsamples, layer_names=[], last_cache_name=last_cache_name
)
@@ -2385,15 +2409,24 @@ def _quantize_layer(
org_input = current_input
with torch.no_grad():
current_output = layer(org_input)
if self.attention_mask:
tmp_attention_mask = [self.attention_mask[i] for i in indices]
tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
tmp_attention_mask.unsqueeze_(-1)
else:
tmp_attention_mask = 1.0

if self.amp:
with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
output_q = wrapper_linear(current_input) # pylint: disable=not-callable
loss = mse_loss(output_q, current_output) # pylint: disable=not-callable
loss = mse_loss( # pylint: disable=not-callable
output_q * tmp_attention_mask, current_output * tmp_attention_mask
)
else:
output_q = wrapper_linear(current_input) # pylint: disable=not-callable
loss = mse_loss( # pylint: disable=not-callable
output_q.to(torch.float32), current_output.to(torch.float32)
output_q.to(torch.float32) * tmp_attention_mask,
current_output.to(torch.float32) * tmp_attention_mask,
)
total_loss += loss.item() / num_elm
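During tuning, each stored mask is a `1 × seq_len` tensor (produced by `torch.split(..., 1, dim=0)`); the masks for the current mini-batch are gathered by `indices`, concatenated, and unsqueezed to `(batch, seq_len, 1)` so they broadcast over the hidden dimension and zero out the padded positions' contribution to the MSE loss. A small sketch of that masked loss, with shapes chosen purely for illustration:

```python
import torch
from torch.nn import functional as F

batch, seq_len, hidden = 2, 8, 16
output_q = torch.randn(batch, seq_len, hidden)        # quantized layer/block output
current_output = torch.randn(batch, seq_len, hidden)  # reference output

# Per-sample masks as stored in self.attention_mask: a list of (1, seq_len) tensors.
stored_masks = list(torch.split(torch.ones(batch, seq_len, dtype=torch.long), 1, dim=0))

indices = [0, 1]  # samples drawn for the current step
mask = torch.cat([stored_masks[i] for i in indices], dim=0).unsqueeze(-1)  # (batch, seq_len, 1)

# Masked positions contribute zero to the loss on both sides.
loss = F.mse_loss(output_q * mask, current_output * mask)
```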

@@ -2662,12 +2695,21 @@ def _quantize_block(
current_output = to_device(current_output, device)

output_q = self._get_current_q_output(block, input_ids, input_others, indices, device)
if self.attention_mask:
tmp_attention_mask = [self.attention_mask[i] for i in indices]
tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
tmp_attention_mask.unsqueeze_(-1)
else:
tmp_attention_mask = 1.0
if self.amp:
with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
loss = mse_loss(output_q, current_output) # pylint: disable=not-callable
loss = mse_loss( # pylint: disable=not-callable
output_q * tmp_attention_mask, current_output * tmp_attention_mask
)
else:
loss = mse_loss( # pylint: disable=not-callable
output_q.to(torch.float32), current_output.to(torch.float32)
output_q.to(torch.float32) * tmp_attention_mask,
current_output.to(torch.float32) * tmp_attention_mask,
)

total_loss += loss.item() / num_elm
38 changes: 38 additions & 0 deletions test/test_cpu/test_autoround.py
@@ -755,6 +755,44 @@ def test_compressor(self):
ar = AutoRoundMLLM(model_name)
self.assertTrue(ar.mllm)

def test_attention_mask_in_dataset(self):
from transformers import AutoTokenizer

model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
# model_name = "/models/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = ["haha", "hello world"]
res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
data = [res.data]

text = ["qudd", "hfd"]
res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
data.append(res.data)
from auto_round import AutoRound

ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8)
ar.quantize()

def test_attention_mask_via_tokenize_in_dataset(self):
from transformers import AutoTokenizer

model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-0.6B"
# model_name = "/models/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = ["haha", "hello world"]
res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
res.data.pop("attention_mask")
data = [res.data]

text = ["qudd", "hfd"]
res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
res.data.pop("attention_mask")
data.append(res.data)
from auto_round import AutoRound

ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8)
ar.quantize()


if __name__ == "__main__":
unittest.main()
20 changes: 20 additions & 0 deletions test/test_cuda/test_main_func.py
@@ -179,6 +179,26 @@ def test_autoround_asym(self): ##need to install false
assert accuracy > 0.35
shutil.rmtree("./saved", ignore_errors=True)

def test_attention_mask_lm_head(self):
from transformers import AutoTokenizer

model_name = "/models/Qwen3-8B"
# model_name = "/models/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = ["haha", "hello world"]
res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
res.data.pop("attention_mask")
data = [res.data]

text = ["qudd", "hfd"]
res = tokenizer(text, return_tensors="pt", max_length=8, padding="max_length", truncation=True)
res.data.pop("attention_mask")
data.append(res.data)
from auto_round import AutoRound

ar = AutoRound(model_name, iters=1, dataset=data, seqlen=8, quant_lm_head=True)
ar.quantize()


if __name__ == "__main__":
unittest.main()