diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py
index 5bac8a419..a219aa787 100644
--- a/auto_round/calib_dataset.py
+++ b/auto_round/calib_dataset.py
@@ -657,7 +657,6 @@ def get_dataloader(
     Returns:
         DataLoader: The DataLoader for the calibrated dataset.
     """
-    dataset_names = dataset_name.split(",")

     def filter_func(example):
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index d6df02bc1..b5a2ff2ac 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -207,35 +207,7 @@ def __init__(
             ...
         }
         """
-        if isinstance(scheme, AutoScheme):
-            if len(scheme.options) <= 0:
-                raise ValueError("options of AutoScheme must not be empty")
-            options = []
-            for option in scheme.options:
-                new_option = self._parse_and_set_scheme(option, kwargs)
-                options.append(new_option)
-            scheme.options = options
-            for opt in options:
-                if isinstance(opt, str) and opt == "BF16":
-                    continue
-                if isinstance(opt, QuantizationScheme):
-                    if opt.bits >= 16 and (opt.act_bits is None or opt.act_bits >= 16):
-                        continue
-                self.scheme = opt  # Choose the first one that not 16 bits
-                break
-
-            # apply scheme to set default bits
-            self._parse_and_set_scheme(self.scheme, kwargs)
-
-            self.is_auto_scheme = True
-
-        else:
-            self.scheme = self._parse_and_set_scheme(scheme, kwargs)
-            self.is_auto_scheme = False
-
-        scheme_keys = [f.name for f in fields(QuantizationScheme)]
-        for key in scheme_keys:
-            kwargs.pop(key, None)
+        self.scheme, self.is_auto_scheme = self._parse_and_set_scheme(scheme, kwargs)

         gguf_scheme_name = get_gguf_scheme(self.scheme)
         # GGUF uses fp32 scale dtype as default
@@ -500,65 +472,105 @@ def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None:

     def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> QuantizationScheme:
         """Parse and set the quantization scheme."""
-        res = ""
-        if isinstance(scheme, QuantizationScheme):
-            scheme = asdict(scheme)
-        elif isinstance(scheme, dict):
-            scheme = scheme
-        elif isinstance(scheme, str):
-            res = scheme  # gguf:q4_k_s and gguf_q4_k_m has the same dict scheme, but the result is different
-            scheme = scheme.upper()
-            scheme = asdict(preset_name_to_scheme(scheme))
-        scheme_keys = [f.name for f in fields(QuantizationScheme)]
-        for key in scheme_keys:
-            if key in kwargs and kwargs[key] is not None:
-                setattr(self, key, kwargs[key])
-            else:
-                setattr(self, key, scheme.get(key, None))
-            # kwargs.pop(key, None)
-        if self.act_dynamic is None:
-            self.act_dynamic = True
-
-        tmp_bits = infer_bits_by_data_type(self.data_type)
-        if tmp_bits is not None and tmp_bits < 16 and tmp_bits != self.bits:
-            logger.warning(f"'data_type' do not match the specified 'bits' setting. Resetting 'bits' to {tmp_bits}.")
-            self.bits = tmp_bits
-        if tmp_bits is not None and tmp_bits < 16:
-            for supported_dtype in SUPPORTED_DTYPES:  # to easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
-                if self.data_type.startswith(supported_dtype):
-                    if supported_dtype + str(tmp_bits) == self.data_type:  # could not replace FP8_e4m3
-                        self.data_type = supported_dtype
-                        break
-        self.act_group_size = self.act_group_size if self.act_group_size is not None else self.group_size
-        self.act_bits = self.act_bits if self.act_bits is not None else 16
-        self.act_sym = self.act_sym if self.act_sym is not None else self.sym
+        def _parse_and_set(scheme, kwargs):
+            res = ""
+            if isinstance(scheme, QuantizationScheme):
+                scheme = asdict(scheme)
+            elif isinstance(scheme, dict):
+                scheme = scheme
+            elif isinstance(scheme, str):
+                # We’d better keep the string scheme instead of the dict config,
+                # since GGUF uses different mixed-bit strategies for q4_k_s and q4_k_m
+                # even though they share the same scheme dict.
+                res = scheme
+                scheme = scheme.upper()
+                scheme = asdict(preset_name_to_scheme(scheme))
+            scheme_keys = [f.name for f in fields(QuantizationScheme)]
+            for key in scheme_keys:
+                if key in kwargs and kwargs[key] is not None:
+                    setattr(self, key, kwargs[key])
+                else:
+                    setattr(self, key, scheme.get(key, None))
+                # kwargs.pop(key, None)
+            if self.act_dynamic is None:
+                self.act_dynamic = True

-        if self.act_data_type is None:
-            if self.data_type in SUPPORTED_DTYPES and self.act_bits < 16:
-                self.act_data_type = self.data_type
-                logger.info(f"activation adopts {self.data_type}")
+            tmp_bits = infer_bits_by_data_type(self.data_type)
+            if tmp_bits is not None and tmp_bits < 16 and tmp_bits != self.bits:
+                logger.warning(
+                    f"'data_type' do not match the specified 'bits' setting. Resetting 'bits' to {tmp_bits}."
+                )
+                self.bits = tmp_bits
+            if tmp_bits is not None and tmp_bits < 16:
+                for (
+                    supported_dtype
+                ) in SUPPORTED_DTYPES:  # to easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
+                    if self.data_type.startswith(supported_dtype):
+                        if supported_dtype + str(tmp_bits) == self.data_type:  # could not replace FP8_e4m3
+                            self.data_type = supported_dtype
+                            break
+
+            self.act_group_size = self.act_group_size if self.act_group_size is not None else self.group_size
+            self.act_bits = self.act_bits if self.act_bits is not None else 16
+            self.act_sym = self.act_sym if self.act_sym is not None else self.sym
+
+            if self.act_data_type is None:
+                if self.data_type in SUPPORTED_DTYPES and self.act_bits < 16:
+                    self.act_data_type = self.data_type
+                    logger.info(f"activation adopts {self.data_type}")
+                else:
+                    self.act_data_type = "float"
+            tmp_act_bits = infer_bits_by_data_type(self.act_data_type)
+            if tmp_act_bits is not None and tmp_act_bits < 16 and tmp_act_bits != self.act_bits:
+                self.act_bits = tmp_act_bits
+                logger.warning(
+                    f"`act_data_type` do not"
+                    f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}."
+                )
+            if tmp_act_bits is not None and tmp_act_bits < 16:
+                for (
+                    supported_dtype
+                ) in SUPPORTED_DTYPES:  # To easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
+                    if self.act_data_type.startswith(supported_dtype):
+                        if supported_dtype + str(tmp_act_bits) == self.act_data_type:  # Could not replace FP8_e4m3
+                            self.act_data_type = supported_dtype
+                            break
+            for key in scheme_keys:
+                scheme[key] = getattr(self, key)
+            if res and QuantizationScheme.from_dict(scheme) == preset_name_to_scheme(res):
+                return res
             else:
-                self.act_data_type = "float"
-        tmp_act_bits = infer_bits_by_data_type(self.act_data_type)
-        if tmp_act_bits is not None and tmp_act_bits < 16 and tmp_act_bits != self.act_bits:
-            self.act_bits = tmp_act_bits
-            logger.warning(
-                f"`act_data_type` do not"
-                f" match the specified 'act_bits' setting. Resetting 'act_bits' to {tmp_act_bits}."
-            )
-        if tmp_act_bits is not None and tmp_act_bits < 16:
-            for supported_dtype in SUPPORTED_DTYPES:  # To easily handle dtype mx_fp4 and layer_config={xxx:{bits:8}}
-                if self.act_data_type.startswith(supported_dtype):
-                    if supported_dtype + str(tmp_act_bits) == self.act_data_type:  # Could not replace FP8_e4m3
-                        self.act_data_type = supported_dtype
-                        break
-        for key in scheme_keys:
-            scheme[key] = getattr(self, key)
-        if res and QuantizationScheme.from_dict(scheme) == preset_name_to_scheme(res):
-            return res
+                return QuantizationScheme.from_dict(scheme)
+
+        if isinstance(scheme, AutoScheme):
+            if len(scheme.options) <= 0:
+                raise ValueError("options of AutoScheme must not be empty")
+            options = []
+            for option in scheme.options:
+                new_option = _parse_and_set(option, kwargs)
+                options.append(new_option)
+            scheme.options = options
+            for opt in options:
+                if isinstance(opt, str) and opt == "BF16":
+                    continue
+                if isinstance(opt, QuantizationScheme):
+                    if opt.bits >= 16 and (opt.act_bits is None or opt.act_bits >= 16):
+                        continue
+                self.scheme = opt  # Choose the first one that not 16 bits
+                break
+            # apply scheme to set default bits
+            scheme = _parse_and_set(self.scheme, kwargs)
+            is_auto_scheme = True
         else:
-            return QuantizationScheme.from_dict(scheme)
+            scheme = _parse_and_set(scheme, kwargs)
+            is_auto_scheme = False
+
+        scheme_keys = [f.name for f in fields(QuantizationScheme)]
+        for key in scheme_keys:
+            kwargs.pop(key, None)
+
+        return scheme, is_auto_scheme

     def _adjust_torch_compile(self, enable_torch_compile: bool) -> None:
         """Sets the torch compile configuration for the tuning."""
diff --git a/auto_round/utils.py b/auto_round/utils.py
index 29d1474ee..afb7b2940 100644
--- a/auto_round/utils.py
+++ b/auto_round/utils.py
@@ -1958,9 +1958,9 @@ def _set_config(config, target_config):
             config_tmp.pop(key, None)
         matched_scheme = get_gguf_scheme(QuantizationScheme.from_dict(config_tmp))  # check matched
         if not matched_scheme:
-            if config.get("super_group_size", None) is not None:
+            if config.get("super_group_size", None) is not None or config.get("super_bits", None) is not None:
                 new_type = new_type[:bits_index] + str(config["bits"]) + "_k"
-            if config.get("super_group_size", None) is None or new_type not in GGUF_INNER_CONFIG:
+            if new_type not in GGUF_INNER_CONFIG:
                 prefix_idx = 0 if config.get("sym", True) else 1
                 new_type = new_type[:bits_index] + str(config["bits"]) + f"_{prefix_idx}"
                 if new_type not in GGUF_INNER_CONFIG:
@@ -1992,7 +1992,8 @@ def _set_config(config, target_config):
             elif new_type != "gguf:q8_0":
                 new_type = "gguf:q6_k"
         elif lm_head_name is not None and layer_name == lm_head_name and tie_word_embeddings:
-            pass
+            # new_type = GGUF_CONFIG[target_gguf_format]["lm_head"]
+            continue
         elif isinstance(layer, torch.nn.Embedding):
             if "embedding" in GGUF_CONFIG[target_gguf_format]:
                 new_type = GGUF_CONFIG[target_gguf_format]["embedding"]
@@ -2914,7 +2915,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
     if hasattr(model, "config") and hasattr(model.config, "tie_word_embeddings"):
         tie_word_embeddings = model.config.tie_word_embeddings

-    if quant_lm_head and tie_word_embeddings:
+    if quant_lm_head and tie_word_embeddings and not gguf_name:
         quant_lm_head = False
         logger.warning(
             "reset `quant_lm_head` to false as quantizing " "lm_head with tied weights has not been supported currently"
@@ -2966,6 +2967,7 @@ def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str
         return layer_config, has_qlayer_outside_block, regex_config

     # embed + lm_head defaults for gguf
+    tie_word_embeddings &= not is_separate_lm_head(model)
     if lm_head_name not in layer_config and not tie_word_embeddings:
         cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["lm_head"]]
         cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype}
@@ -3024,6 +3026,30 @@ def is_diffusion_model(model_or_path: Union[str, object]) -> bool:
     return False


+def is_separate_lm_head(model: torch.nn.Module) -> bool:
+    dir_path = model.name_or_path
+    if not os.path.isdir(dir_path):
+        dir_path = download_hf_model(dir_path)
+    lm_head_name: str = get_lm_head_name(model)
+    lm_head_name += ".weight"
+
+    if "model.safetensors.index.json" in os.listdir(dir_path):
+        with open(os.path.join(dir_path, "model.safetensors.index.json")) as f:
+            index_mapping = json.load(f)
+        if lm_head_name in index_mapping["weight_map"]:
+            return True
+        else:
+            return False
+    else:
+        from safetensors import safe_open
+
+        f = safe_open(os.path.join(dir_path, "model.safetensors"), framework="pt")
+        if lm_head_name in f.keys():
+            return True
+        else:
+            return False
+
+
 def to_standard_regex(pattern: str) -> str:
     """
     Convert a user-specified string into a standardized regex for layer matching.
diff --git a/test/test_cpu/test_gguf_format.py b/test/test_cpu/test_gguf_format.py
index d71920b39..7505db913 100644
--- a/test/test_cpu/test_gguf_format.py
+++ b/test/test_cpu/test_gguf_format.py
@@ -12,6 +12,7 @@


 class LLMDataLoader:
+
     def __init__(self):
         self.batch_size = 1

@@ -21,11 +22,10 @@ def __iter__(self):


 class TestGGUF(unittest.TestCase):
+
     @classmethod
     def setUpClass(self):
         self.model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct"
-        self.model_name = "Qwen/Qwen2.5-0.5B-Instruct"
-
         self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
         self.llm_dataloader = LLMDataLoader()
@@ -55,8 +55,7 @@ def test_basic_usage(self):
     def test_q4_0(self):
         bits, group_size, sym = 4, 32, True
         autoround = AutoRound(
-            self.model,
-            self.tokenizer,
+            self.model_name,
             bits=bits,
             group_size=group_size,
             sym=sym,
@@ -103,8 +102,7 @@ def test_q4_0(self):
     def test_func(self):
         bits, group_size, sym = 4, 128, True
         autoround = AutoRound(
-            self.model,
-            self.tokenizer,
+            self.model_name,
             # bits=bits,
             # group_size=group_size,
             # sym=sym,
@@ -336,6 +334,77 @@ def test_vlm_gguf(self):
         self.assertAlmostEqual(file_size, 892, delta=1.0)
         shutil.rmtree("./saved", ignore_errors=True)

+    def test_qtype_setting(self):
+        # Qwen2.5-0.5B-Instruct no output, token_embed q6_k fallback to q8_0 336M
+        # Qwen3-0.6B output q6_k, token_embed q4_0 448M
+        # Qwen3-8B output q6_k, token_embed q4_0 4.5G
+        # Llama-3.2-1B-Instruct no output, token_embed q6_k 736M
+        from auto_round.export.export_to_gguf.config import ModelType
+        from auto_round.utils import get_layer_config_by_gguf_format, set_layer_config
+
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct"
+        ar = AutoRound(model=model_name, scheme="gguf:q4_0", iters=0)
+        ar.formats = ["gguf:q4_0"]
+        ar.layer_config, _, _ = set_layer_config(
+            ar.model,
+            ar.layer_config,
+            ar.scheme,
+            ar.scale_dtype,
+            ar.supported_types,
+            ar.inner_supported_types,
+            ar.quant_block_list,
+            ar.fp_layers,
+            ar.quant_lm_head,
+            enable_gguf_official_mixed=True,
+            is_mllm=ar.mllm,
+        )
+        self.assertTrue(ar.layer_config["model.embed_tokens"]["bits"] == 8)
+        self.assertTrue("lm_head" not in ar.layer_config)
+
+        model_name = "Qwen/Qwen3-0.6B"
+        ar = AutoRound(model=model_name, scheme="gguf:q4_0", iters=0)
+        ar.formats = ["gguf:q4_0"]
+        ar.layer_config, _, _ = set_layer_config(
+            ar.model,
+            ar.layer_config,
+            ar.scheme,
+            ar.scale_dtype,
+            ar.supported_types,
+            ar.inner_supported_types,
+            ar.quant_block_list,
+            ar.fp_layers,
+            ar.quant_lm_head,
+            enable_gguf_official_mixed=True,
+            is_mllm=ar.mllm,
+        )
+        self.assertTrue(ar.layer_config["model.embed_tokens"]["bits"] == 4)
+        self.assertTrue(ar.layer_config["lm_head"]["bits"] == 6 and ar.layer_config["lm_head"]["super_bits"] == 8)
+
+        layer_config = {
+            "model.embed_tokens": {"bits": 6, "super_bits": 8},
+            "lm_head": {"bits": 4},
+        }
+        ar = AutoRound(model=model_name, scheme="gguf:q4_0", iters=0, layer_config=layer_config)
+        ar.formats = ["gguf:q4_0"]
+        ar.layer_config, _, _ = set_layer_config(
+            ar.model,
+            ar.layer_config,
+            ar.scheme,
+            ar.scale_dtype,
+            ar.supported_types,
+            ar.inner_supported_types,
+            ar.quant_block_list,
+            ar.fp_layers,
+            ar.quant_lm_head,
+            enable_gguf_official_mixed=True,
+            is_mllm=ar.mllm,
+        )
+        self.assertTrue(ar.layer_config["lm_head"]["bits"] == 4)
+        self.assertTrue(
+            ar.layer_config["model.embed_tokens"]["bits"] == 6
+            and ar.layer_config["model.embed_tokens"]["super_bits"] == 8
+        )
+

 if __name__ == "__main__":
     unittest.main()