Merged
1 change: 0 additions & 1 deletion auto_round/calib_dataset.py
@@ -657,7 +657,6 @@ def get_dataloader(
Returns:
DataLoader: The DataLoader for the calibrated dataset.
"""

dataset_names = dataset_name.split(",")

def filter_func(example):
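For orientation, here is a minimal, hedged usage sketch of `get_dataloader`; the comma-splitting of `dataset_name` is visible in the hunk above, but the remaining keyword names (`seqlen`, `nsamples`) are assumptions about the signature rather than facts from this diff.

# Hedged sketch: keyword names other than dataset_name are assumptions.
from transformers import AutoTokenizer
from auto_round.calib_dataset import get_dataloader

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Multiple calibration sets can be mixed by joining their names with commas.
dataloader = get_dataloader(tokenizer, seqlen=2048, dataset_name="NeelNanda/pile-10k", nsamples=128)
for batch in dataloader:  # each batch feeds the calibration loop
    pass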
180 changes: 96 additions & 84 deletions auto_round/compressors/base.py
@@ -207,35 +207,7 @@ def __init__(
... }
"""

if isinstance(scheme, AutoScheme):
if len(scheme.options) <= 0:
raise ValueError("options of AutoScheme must not be empty")
options = []
for option in scheme.options:
new_option = self._parse_and_set_scheme(option, kwargs)
options.append(new_option)
scheme.options = options
for opt in options:
if isinstance(opt, str) and opt == "BF16":
continue
if isinstance(opt, QuantizationScheme):
if opt.bits >= 16 and (opt.act_bits is None or opt.act_bits >= 16):
continue
self.scheme = opt # Choose the first one that not 16 bits
break

# apply scheme to set default bits
self._parse_and_set_scheme(self.scheme, kwargs)

self.is_auto_scheme = True

else:
self.scheme = self._parse_and_set_scheme(scheme, kwargs)
self.is_auto_scheme = False

scheme_keys = [f.name for f in fields(QuantizationScheme)]
for key in scheme_keys:
kwargs.pop(key, None)
self.scheme, self.is_auto_scheme = self._parse_and_set_scheme(scheme, kwargs)

gguf_scheme_name = get_gguf_scheme(self.scheme)
# GGUF uses fp32 scale dtype as default
@@ -500,65 +472,105 @@ def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:

def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> QuantizationScheme:
"""Parse and set the quantization scheme."""
res = ""
if isinstance(scheme, QuantizationScheme):
scheme = asdict(scheme)
elif isinstance(scheme, dict):
scheme = scheme
elif isinstance(scheme, str):
res = scheme # gguf:q4_k_s and gguf_q4_k_m has the same dict scheme, but the result is different
scheme = scheme.upper()
scheme = asdict(preset_name_to_scheme(scheme))
scheme_keys = [f.name for f in fields(QuantizationScheme)]
for key in scheme_keys:
if key in kwargs and kwargs[key] is not None:
setattr(self, key, kwargs[key])
else:
setattr(self, key, scheme.get(key, None))
# kwargs.pop(key, None)
if self.act_dynamic is None:
self.act_dynamic = True

tmp_bits = infer_bits_by_data_type(self.data_type)
if tmp_bits is not None and tmp_bits < 16 and tmp_bits != self.bits:
logger.warning(f"'data_type' do not match the specified 'bits' setting. Resetting 'bits' to {tmp_bits}.")
self.bits = tmp_bits
if tmp_bits is not None and tmp_bits < 16:
for supported_dtype in SUPPORTED_DTYPES: # to easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
if self.data_type.startswith(supported_dtype):
if supported_dtype + str(tmp_bits) == self.data_type: # could not replace FP8_e4m3
self.data_type = supported_dtype
break

self.act_group_size = self.act_group_size if self.act_group_size is not None else self.group_size
self.act_bits = self.act_bits if self.act_bits is not None else 16
self.act_sym = self.act_sym if self.act_sym is not None else self.sym
def _parse_and_set(scheme, kwargs):
res = ""
if isinstance(scheme, QuantizationScheme):
scheme = asdict(scheme)
elif isinstance(scheme, dict):
scheme = scheme
elif isinstance(scheme, str):
# We’d better keep the string scheme instead of the dict config,
# since GGUF uses different mixed-bit strategies for q4_k_s and q4_k_m
# even though they share the same scheme dict.
res = scheme
scheme = scheme.upper()
scheme = asdict(preset_name_to_scheme(scheme))
scheme_keys = [f.name for f in fields(QuantizationScheme)]
for key in scheme_keys:
if key in kwargs and kwargs[key] is not None:
setattr(self, key, kwargs[key])
else:
setattr(self, key, scheme.get(key, None))
# kwargs.pop(key, None)
if self.act_dynamic is None:
self.act_dynamic = True

if self.act_data_type is None:
if self.data_type in SUPPORTED_DTYPES and self.act_bits < 16:
self.act_data_type = self.data_type
logger.info(f"activation adopts {self.data_type}")
tmp_bits = infer_bits_by_data_type(self.data_type)
if tmp_bits is not None and tmp_bits < 16 and tmp_bits != self.bits:
logger.warning(
f"'data_type' do not match the specified 'bits' setting. Resetting 'bits' to {tmp_bits}."
)
self.bits = tmp_bits
if tmp_bits is not None and tmp_bits < 16:
for (
supported_dtype
) in SUPPORTED_DTYPES: # to easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
if self.data_type.startswith(supported_dtype):
if supported_dtype + str(tmp_bits) == self.data_type: # could not replace FP8_e4m3
self.data_type = supported_dtype
break

self.act_group_size = self.act_group_size if self.act_group_size is not None else self.group_size
self.act_bits = self.act_bits if self.act_bits is not None else 16
self.act_sym = self.act_sym if self.act_sym is not None else self.sym

if self.act_data_type is None:
if self.data_type in SUPPORTED_DTYPES and self.act_bits < 16:
self.act_data_type = self.data_type
logger.info(f"activation adopts {self.data_type}")
else:
self.act_data_type = "float"
tmp_act_bits = infer_bits_by_data_type(self.act_data_type)
if tmp_act_bits is not None and tmp_act_bits < 16 and tmp_act_bits != self.act_bits:
self.act_bits = tmp_act_bits
logger.warning(
f"`act_data_type` do not"
f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}."
)
if tmp_act_bits is not None and tmp_act_bits < 16:
for (
supported_dtype
) in SUPPORTED_DTYPES: # To easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
if self.act_data_type.startswith(supported_dtype):
if supported_dtype + str(tmp_act_bits) == self.act_data_type: # Could not replace FP8_e4m3
self.act_data_type = supported_dtype
break
for key in scheme_keys:
scheme[key] = getattr(self, key)
if res and QuantizationScheme.from_dict(scheme) == preset_name_to_scheme(res):
return res
else:
self.act_data_type = "float"
tmp_act_bits = infer_bits_by_data_type(self.act_data_type)
if tmp_act_bits is not None and tmp_act_bits < 16 and tmp_act_bits != self.act_bits:
self.act_bits = tmp_act_bits
logger.warning(
f"`act_data_type` do not"
f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}."
)
if tmp_act_bits is not None and tmp_act_bits < 16:
for supported_dtype in SUPPORTED_DTYPES: # To easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
if self.act_data_type.startswith(supported_dtype):
if supported_dtype + str(tmp_act_bits) == self.act_data_type: # Could not replace FP8_e4m3
self.act_data_type = supported_dtype
break
for key in scheme_keys:
scheme[key] = getattr(self, key)
if res and QuantizationScheme.from_dict(scheme) == preset_name_to_scheme(res):
return res
return QuantizationScheme.from_dict(scheme)

if isinstance(scheme, AutoScheme):
if len(scheme.options) <= 0:
raise ValueError("options of AutoScheme must not be empty")
options = []
for option in scheme.options:
new_option = _parse_and_set(option, kwargs)
options.append(new_option)
scheme.options = options
for opt in options:
if isinstance(opt, str) and opt == "BF16":
continue
if isinstance(opt, QuantizationScheme):
if opt.bits >= 16 and (opt.act_bits is None or opt.act_bits >= 16):
continue
self.scheme = opt  # Choose the first option that is not 16-bit
break
# apply scheme to set default bits
scheme = _parse_and_set(self.scheme, kwargs)
is_auto_scheme = True
else:
return QuantizationScheme.from_dict(scheme)
scheme = _parse_and_set(scheme, kwargs)
is_auto_scheme = False

scheme_keys = [f.name for f in fields(QuantizationScheme)]
for key in scheme_keys:
kwargs.pop(key, None)

return scheme, is_auto_scheme

def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
"""Sets the torch compile configuration for the tuning."""
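As a reading aid, a minimal sketch of the new return contract of `_parse_and_set_scheme`, matching the single assignment in `__init__` above; the `AutoScheme` import path and constructor arguments (`avg_bits`, the "W4A16" option) are assumptions, not taken from this diff.

# Sketch only: illustrates the (scheme, is_auto_scheme) tuple now returned for both input kinds.
from auto_round import AutoRound
from auto_round.schemes import AutoScheme  # import path is an assumption

# A preset string is kept as a string when nothing was overridden, so GGUF names like
# "gguf:q4_k_s" and "gguf:q4_k_m" stay distinguishable even though they share one scheme dict.
ar = AutoRound("Qwen/Qwen2.5-0.5B-Instruct", scheme="gguf:q4_k_s", iters=0)
assert ar.is_auto_scheme is False

# With an AutoScheme, every option is parsed and the first non-16-bit option seeds self.scheme.
ar2 = AutoRound("Qwen/Qwen2.5-0.5B-Instruct", scheme=AutoScheme(avg_bits=4, options=["BF16", "W4A16"]), iters=0)  # argument names are assumptions
assert ar2.is_auto_scheme is True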
34 changes: 30 additions & 4 deletions auto_round/utils.py
@@ -1958,9 +1958,9 @@ def _set_config(config, target_config):
config_tmp.pop(key, None)
matched_scheme = get_gguf_scheme(QuantizationScheme.from_dict(config_tmp)) # check matched
if not matched_scheme:
if config.get("super_group_size", None) is not None:
if config.get("super_group_size", None) is not None or config.get("super_bits", None) is not None:
new_type = new_type[:bits_index] + str(config["bits"]) + "_k"
if config.get("super_group_size", None) is None or new_type not in GGUF_INNER_CONFIG:
if new_type not in GGUF_INNER_CONFIG:
prefix_idx = 0 if config.get("sym", True) else 1
new_type = new_type[:bits_index] + str(config["bits"]) + f"_{prefix_idx}"
if new_type not in GGUF_INNER_CONFIG:
@@ -1992,7 +1992,8 @@ def _set_config(config, target_config):
elif new_type != "gguf:q8_0":
new_type = "gguf:q6_k"
elif lm_head_name is not None and layer_name == lm_head_name and tie_word_embeddings:
pass
# new_type = GGUF_CONFIG[target_gguf_format]["lm_head"]
continue
elif isinstance(layer, torch.nn.Embedding):
if "embedding" in GGUF_CONFIG[target_gguf_format]:
new_type = GGUF_CONFIG[target_gguf_format]["embedding"]
@@ -2914,7 +2915,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
if hasattr(model, "config") and hasattr(model.config, "tie_word_embeddings"):
tie_word_embeddings = model.config.tie_word_embeddings

if quant_lm_head and tie_word_embeddings:
if quant_lm_head and tie_word_embeddings and not gguf_name:
quant_lm_head = False
logger.warning(
"reset `quant_lm_head` to false as quantizing " "lm_head with tied weights has not been supported currently"
@@ -2966,6 +2967,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
return layer_config, has_qlayer_outside_block, regex_config

# embed + lm_head defaults for gguf
tie_word_embeddings &= not is_separate_lm_head(model)
if lm_head_name not in layer_config and not tie_word_embeddings:
cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["lm_head"]]
cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype}
@@ -3024,6 +3026,30 @@ def is_diffusion_model(model_or_path: Union[str, object]) -> bool:
return False


def is_separate_lm_head(model: torch.nn.Module) -> bool:
dir_path = model.name_or_path
if not os.path.isdir(dir_path):
dir_path = download_hf_model(dir_path)
lm_head_name: str = get_lm_head_name(model)
lm_head_name += ".weight"

if "model.safetensors.index.json" in os.listdir(dir_path):
with open(os.path.join(dir_path, "model.safetensors.index.json")) as f:
index_mapping = json.load(f)
if lm_head_name in index_mapping["weight_map"]:
return True
else:
return False
else:
from safetensors import safe_open

f = safe_open(os.path.join(dir_path, "model.safetensors"), framework="pt")
if lm_head_name in f.keys():
return True
else:
return False


def to_standard_regex(pattern: str) -> str:
"""
Convert a user-specified string into a standardized regex for layer matching.
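A short, hedged usage sketch of the new `is_separate_lm_head` helper: it reports whether the checkpoint actually stores an `lm_head.weight` tensor rather than only a tied embedding. The model name below is just an example and needs a local or downloadable checkpoint.

# Sketch: requires the checkpoint to be available locally or via download.
from transformers import AutoModelForCausalLM
from auto_round.utils import is_separate_lm_head

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype="auto")
# True when lm_head.weight is listed in model.safetensors.index.json (or the single safetensors file).
print(is_separate_lm_head(model))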
81 changes: 75 additions & 6 deletions test/test_cpu/test_gguf_format.py
@@ -12,6 +12,7 @@


class LLMDataLoader:

def __init__(self):
self.batch_size = 1

@@ -21,11 +22,10 @@ def __iter__(self):


class TestGGUF(unittest.TestCase):

@classmethod
def setUpClass(self):
self.model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct"
self.model_name = "Qwen/Qwen2.5-0.5B-Instruct"
self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
self.llm_dataloader = LLMDataLoader()

@@ -55,8 +55,7 @@ def test_basic_usage(self):
def test_q4_0(self):
bits, group_size, sym = 4, 32, True
autoround = AutoRound(
self.model,
self.tokenizer,
self.model_name,
bits=bits,
group_size=group_size,
sym=sym,
@@ -103,8 +102,7 @@ def test_func(self):
def test_func(self):
bits, group_size, sym = 4, 128, True
autoround = AutoRound(
self.model,
self.tokenizer,
self.model_name,
# bits=bits,
# group_size=group_size,
# sym=sym,
@@ -336,6 +334,77 @@ def test_vlm_gguf(self):
self.assertAlmostEqual(file_size, 892, delta=1.0)
shutil.rmtree("./saved", ignore_errors=True)

def test_qtype_setting(self):
# Qwen2.5-0.5B-Instruct: no output, token_embed q6_k falls back to q8_0, 336M
# Qwen3-0.6B: output q6_k, token_embed q4_0, 448M
# Qwen3-8B: output q6_k, token_embed q4_0, 4.5G
# Llama-3.2-1B-Instruct: no output, token_embed q6_k, 736M
from auto_round.export.export_to_gguf.config import ModelType
from auto_round.utils import get_layer_config_by_gguf_format, set_layer_config

model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct"
ar = AutoRound(model=model_name, scheme="gguf:q4_0", iters=0)
ar.formats = ["gguf:q4_0"]
ar.layer_config, _, _ = set_layer_config(
ar.model,
ar.layer_config,
ar.scheme,
ar.scale_dtype,
ar.supported_types,
ar.inner_supported_types,
ar.quant_block_list,
ar.fp_layers,
ar.quant_lm_head,
enable_gguf_official_mixed=True,
is_mllm=ar.mllm,
)
self.assertTrue(ar.layer_config["model.embed_tokens"]["bits"] == 8)
self.assertTrue("lm_head" not in ar.layer_config)

model_name = "Qwen/Qwen3-0.6B"
ar = AutoRound(model=model_name, scheme="gguf:q4_0", iters=0)
ar.formats = ["gguf:q4_0"]
ar.layer_config, _, _ = set_layer_config(
ar.model,
ar.layer_config,
ar.scheme,
ar.scale_dtype,
ar.supported_types,
ar.inner_supported_types,
ar.quant_block_list,
ar.fp_layers,
ar.quant_lm_head,
enable_gguf_official_mixed=True,
is_mllm=ar.mllm,
)
self.assertTrue(ar.layer_config["model.embed_tokens"]["bits"] == 4)
self.assertTrue(ar.layer_config["lm_head"]["bits"] == 6 and ar.layer_config["lm_head"]["super_bits"] == 8)

layer_config = {
"model.embed_tokens": {"bits": 6, "super_bits": 8},
"lm_head": {"bits": 4},
}
ar = AutoRound(model=model_name, scheme="gguf:q4_0", iters=0, layer_config=layer_config)
ar.formats = ["gguf:q4_0"]
ar.layer_config, _, _ = set_layer_config(
ar.model,
ar.layer_config,
ar.scheme,
ar.scale_dtype,
ar.supported_types,
ar.inner_supported_types,
ar.quant_block_list,
ar.fp_layers,
ar.quant_lm_head,
enable_gguf_official_mixed=True,
is_mllm=ar.mllm,
)
self.assertTrue(ar.layer_config["lm_head"]["bits"] == 4)
self.assertTrue(
ar.layer_config["model.embed_tokens"]["bits"] == 6
and ar.layer_config["model.embed_tokens"]["super_bits"] == 8
)


if __name__ == "__main__":
unittest.main()
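To run just the new case in isolation, a hedged sketch (assumes the working directory is test/test_cpu and the referenced models are reachable):

# Sketch: load and run only TestGGUF.test_qtype_setting.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("test_gguf_format.TestGGUF.test_qtype_setting")
unittest.TextTestRunner(verbosity=2).run(suite)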