diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py
index 30fcb2bd6..365523950 100644
--- a/auto_round/export/export_to_autoround/export.py
+++ b/auto_round/export/export_to_autoround/export.py
@@ -18,6 +18,7 @@
 import json
 import os
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import fields
 from enum import Enum
 
 import threadpoolctl as tctl
@@ -26,9 +27,10 @@
 import transformers
 from tqdm import tqdm
 
-from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
+from auto_round.export.export_to_autoround.utils import check_neq_config
 from auto_round.export.utils import save_model
 from auto_round.logger import logger
+from auto_round.schemes import QuantizationScheme
 from auto_round.utils import (
     SUPPORTED_FORMATS,
     SUPPORTED_LAYER_TYPES,
@@ -324,26 +326,20 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex
         for i in range(len(block_name_to_quantize)):
             block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")
 
-    for layer_name in layer_config:
-        if (
-            not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
-        ):  ##lm head ##TODO fix act and so on
-            extra_config[layer_name] = {}
-            extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
-            extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
-            extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
-            extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
-        elif layer_config[layer_name]["in_blocks"] or (
+    scheme_keys = [f.name for f in fields(QuantizationScheme)]
+    for layer_name, cfg in layer_config.items():
+        if not cfg["in_blocks"] and cfg["bits"] <= 8:  # lm head
+            extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
+        elif cfg["in_blocks"] or (
             block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
         ):
-            neq_keys = check_neq_config(
-                layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
-            )
+            neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
             if len(neq_keys) > 0:
                 extra_config[layer_name] = {}
-            for key in neq_keys:
-                if layer_config[layer_name][key] is not None:
-                    extra_config[layer_name][key] = layer_config[layer_name][key]
+                for key in scheme_keys:
+                    if cfg[key] is not None:
+                        extra_config[layer_name][key] = cfg[key]
+
     if len(extra_config) > 0:
         quantization_config["extra_config"] = extra_config
     names = list(layer_config.keys())
diff --git a/auto_round/export/export_to_autoround/export_to_fp8.py b/auto_round/export/export_to_autoround/export_to_fp8.py
index bfc916419..261f1dbbc 100644
--- a/auto_round/export/export_to_autoround/export_to_fp8.py
+++ b/auto_round/export/export_to_autoround/export_to_fp8.py
@@ -16,6 +16,7 @@
 import json
 import os
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import fields
 
 import threadpoolctl as tctl
 import torch
@@ -23,9 +24,10 @@
 from tqdm import tqdm
 
 from auto_round.data_type.utils import reshape_pad_tensor_by_group_size, revert_tensor_by_pad
-from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
+from auto_round.export.export_to_autoround.utils import check_neq_config
 from auto_round.export.utils import save_model
 from auto_round.logger import logger
+from auto_round.schemes import QuantizationScheme
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     _get_packing_device,
@@ -169,26 +171,20 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round",
         for i in range(len(block_name_to_quantize)):
             block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")
 
-    for layer_name in layer_config:
-        if (
-            not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
-        ):  ##lm head ##TODO fix act and so on
-            extra_config[layer_name] = {}
-            extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
-            extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
-            extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
-            extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
-        elif layer_config[layer_name]["in_blocks"] or (
+    scheme_keys = [f.name for f in fields(QuantizationScheme)]
+    for layer_name, cfg in layer_config.items():
+        if not cfg["in_blocks"] and cfg["bits"] <= 8:  # lm head
+            extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
+        elif cfg["in_blocks"] or (
             block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
         ):
-            neq_keys = check_neq_config(
-                layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
-            )
+            neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
             if len(neq_keys) > 0:
                 extra_config[layer_name] = {}
-            for key in neq_keys:
-                if layer_config[layer_name][key] is not None:
-                    extra_config[layer_name][key] = layer_config[layer_name][key]
+                for key in scheme_keys:
+                    if cfg[key] is not None:
+                        extra_config[layer_name][key] = cfg[key]
+
     if len(extra_config) > 0:
         quantization_config["extra_config"] = extra_config
     names = list(layer_config.keys())
diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
index eaf3ad9ae..9e3a73533 100644
--- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
+++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py
@@ -17,6 +17,7 @@
 import json
 import os
 from concurrent.futures import ThreadPoolExecutor
+from dataclasses import fields
 
 import threadpoolctl as tctl
 import torch
@@ -24,9 +25,10 @@
 import transformers
 from tqdm import tqdm
 
-from auto_round.export.export_to_autoround.utils import REQUIRED_CONFIG_KEYS, check_neq_config
+from auto_round.export.export_to_autoround.utils import check_neq_config
 from auto_round.export.utils import save_model
 from auto_round.logger import logger
+from auto_round.schemes import QuantizationScheme
 from auto_round.utils import (
     SUPPORTED_LAYER_TYPES,
     _get_packing_device,
@@ -195,26 +197,20 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
         for i in range(len(block_name_to_quantize)):
             block_name_to_quantize[i] = os.path.commonprefix(block_name_to_quantize[i]).rstrip(".")
 
-    for layer_name in layer_config:
-        if (
-            not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8
-        ):  ##lm head # TODO fix act and so on
-            extra_config[layer_name] = {}
-            extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"]
-            extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"]
-            extra_config[layer_name]["group_size"] = layer_config[layer_name]["group_size"]
-            extra_config[layer_name]["sym"] = layer_config[layer_name]["sym"]
-        elif layer_config[layer_name]["in_blocks"] or (
+    scheme_keys = [f.name for f in fields(QuantizationScheme)]
+    for layer_name, cfg in layer_config.items():
+        if not cfg["in_blocks"] and cfg["bits"] <= 8:  # lm head
+            extra_config[layer_name] = {key: cfg.get(key) for key in scheme_keys}
+        elif cfg["in_blocks"] or (
             block_name_to_quantize is not None and check_start_with_block_name(layer_name, block_name_to_quantize)
         ):
-            neq_keys = check_neq_config(
-                layer_config[layer_name], **{k: quantization_config[k] for k in REQUIRED_CONFIG_KEYS}
-            )
+            neq_keys = check_neq_config(cfg, **{k: quantization_config[k] for k in scheme_keys})
             if len(neq_keys) > 0:
                 extra_config[layer_name] = {}
-            for key in neq_keys:
-                if layer_config[layer_name][key] is not None:
-                    extra_config[layer_name][key] = layer_config[layer_name][key]
+                for key in scheme_keys:
+                    if cfg[key] is not None:
+                        extra_config[layer_name][key] = cfg[key]
+
     if len(extra_config) > 0:
         quantization_config["extra_config"] = extra_config
     names = list(layer_config.keys())
diff --git a/auto_round/export/export_to_autoround/utils.py b/auto_round/export/export_to_autoround/utils.py
index 19bd92f43..ddef22f1e 100644
--- a/auto_round/export/export_to_autoround/utils.py
+++ b/auto_round/export/export_to_autoround/utils.py
@@ -12,36 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-REQUIRED_CONFIG_KEYS = (
-    "data_type",
-    "bits",
-    "group_size",
-    "sym",
-    "act_bits",
-    "act_data_type",
-    "act_group_size",
-    "act_sym",
-    "act_dynamic",
-)
+from dataclasses import fields
+from typing import List
 
+from auto_round.schemes import QuantizationScheme
 
-def check_neq_config(config: dict, **expected) -> dict[str, tuple]:
+
+def check_neq_config(config: dict, **expected) -> List[str]:
     """
     Compare a config dict against expected values.
     Ensures all required keys are present in both config and expected.
 
     Returns:
-        dict[str, tuple]: {key: (actual, expected)} for mismatched values.
+        List[str]: keys whose non-None values differ from the expected values.
     """
+    scheme_keys = [f.name for f in fields(QuantizationScheme)]
     # 1. Check missing from expected
-    missing_expected = [k for k in REQUIRED_CONFIG_KEYS if k not in expected]
+    missing_expected = [k for k in scheme_keys if k not in expected]
     if missing_expected:
         raise ValueError(f"Missing expected values for keys: {missing_expected}")
 
     # 2. Check missing from layer config
-    missing_config = [k for k in REQUIRED_CONFIG_KEYS if k not in config]
+    missing_config = [k for k in scheme_keys if k not in config]
     if missing_config:
         raise ValueError(f"Missing config values for keys: {missing_config}")
 
     # 3. Collect mismatches
-    return {key: (config[key], expected[key]) for key in REQUIRED_CONFIG_KEYS if config[key] != expected[key]}
+    return [key for key in scheme_keys if config[key] != expected[key] and config[key] is not None]
diff --git a/test/test_cpu/test_act_quantization.py b/test/test_cpu/test_act_quantization.py
index c12a9014a..dfc387dee 100644
--- a/test/test_cpu/test_act_quantization.py
+++ b/test/test_cpu/test_act_quantization.py
@@ -22,9 +22,10 @@ def __iter__(self):
 class TestAutoRoundAct(unittest.TestCase):
     @classmethod
     def setUpClass(self):
-        model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", trust_remote_code=True)
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        self.model_name = "/tf_dataset/auto_round/models/facebook/opt-125m"
+        self.save_dir = "./saved"
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto", trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
         self.llm_dataloader = LLMDataLoader()
 
     @classmethod
@@ -139,6 +140,160 @@ def test_wfp8afp8_static(self):
             int(3 * 10 * 768 / 128),
         )
 
+    def test_act_config_MXFP4_saving(self):
+        scheme = "MXFP4"
+        layer_config = {"lm_head": {"act_bits": 8, "bits": 8}, "k_proj": {"act_bits": 8, "bits": 8}}
+        autoround = AutoRound(
+            self.model_name,
+            scheme=scheme,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        quantized_model_path = self.save_dir
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
+        lmhead_config = model.config.quantization_config.extra_config["lm_head"]
+        assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "mx_fp_rceil"
+        assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
+        assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 32
+        assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
+        assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "mx_fp"
+        assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
+        assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 32
+        assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
+        assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
+        assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
+        # check inblock layer config values
+        kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"]
+        assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "mx_fp_rceil"
+        assert "act_bits" in kproj_config.keys() and kproj_config["act_bits"] == 8
+        assert "act_group_size" in kproj_config.keys() and kproj_config["act_group_size"] == 32
+        assert "act_sym" in kproj_config.keys() and kproj_config["act_sym"]
+        assert "data_type" in kproj_config.keys() and kproj_config["data_type"] == "mx_fp"
+        assert "bits" in kproj_config.keys() and kproj_config["bits"] == 8
+        assert "group_size" in kproj_config.keys() and kproj_config["group_size"] == 32
+        assert "sym" in kproj_config.keys() and kproj_config["sym"]
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+    def test_act_config_NVFP4_saving(self):
+        scheme = "NVFP4"
+        layer_config = {"k_proj": {"act_bits": 16, "bits": 16}}
+        autoround = AutoRound(
+            self.model_name,
+            scheme=scheme,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        quantized_model_path = self.save_dir
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
+        kproj_config = model.config.quantization_config.extra_config["model.decoder.layers.1.self_attn.k_proj"]
+        assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "nv_fp4_with_static_gs"
+        assert "act_bits" in kproj_config.keys() and kproj_config["act_bits"] == 16
+        assert "act_group_size" in kproj_config.keys() and kproj_config["act_group_size"] == 16
+        assert "act_sym" in kproj_config.keys() and kproj_config["act_sym"]
+        assert "data_type" in kproj_config.keys() and kproj_config["data_type"] == "nv_fp"
+        assert "bits" in kproj_config.keys() and kproj_config["bits"] == 16
+        assert "group_size" in kproj_config.keys() and kproj_config["group_size"] == 16
+        assert "sym" in kproj_config.keys() and kproj_config["sym"]
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+    def test_WOQ_config_INT_saving(self):
+        scheme = "W4A16"
+        layer_config = {"k_proj": {"bits": 8}}  # "lm_head": {"bits": 4},
+        autoround = AutoRound(
+            self.model_name,
+            scheme=scheme,
+            iters=2,
+            seqlen=2,
+            sym=False,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        quantized_model_path = self.save_dir
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
+        model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="cpu")
+        extra_config = model.config.quantization_config.extra_config
+        # lmhead_config = extra_config["lm_head"]
+        # assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "float"
+        # assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 16
+        # assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 128
+        # assert "act_sym" in lmhead_config.keys() and not lmhead_config["act_sym"]
+        # assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "int"
+        # assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 4
+        # assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == 128
+        # assert "sym" in lmhead_config.keys() and not lmhead_config["sym"]
+        # assert "act_dynamic" in lmhead_config.keys() and lmhead_config["act_dynamic"]
+        # assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
+        # assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
+
+        # check inblock layer config values
+        kproj_config = extra_config["model.decoder.layers.1.self_attn.k_proj"]
+        assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "float"
+        assert "act_bits" in kproj_config.keys() and kproj_config["act_bits"] == 16
+        assert "act_group_size" in kproj_config.keys() and kproj_config["act_group_size"] == 128
+        assert "act_sym" in kproj_config.keys() and not kproj_config["act_sym"]
+        assert "data_type" in kproj_config.keys() and kproj_config["data_type"] == "int"
+        assert "bits" in kproj_config.keys() and kproj_config["bits"] == 8
+        assert "group_size" in kproj_config.keys() and kproj_config["group_size"] == 128
+        assert "sym" in kproj_config.keys() and not kproj_config["sym"]
+        assert "act_dynamic" in kproj_config.keys() and kproj_config["act_dynamic"]
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
+    def test_act_config_FP8_saving(self):
+        scheme = "FP8_STATIC"
+        layer_config = {
+            "lm_head": {"act_bits": 8, "bits": 8},
+            # check fp8 woq config
+            "k_proj": {
+                "bits": 8,
+                "group_size": 0,
+                "data_type": "fp",
+                "act_bits": 16,
+                "act_data_type": "fp",
+            },
+        }
+        autoround = AutoRound(
+            self.model_name,
+            scheme=scheme,
+            iters=2,
+            seqlen=2,
+            dataset=self.llm_dataloader,
+            layer_config=layer_config,
+        )
+        quantized_model_path = self.save_dir
+        autoround.quantize_and_save(output_dir=quantized_model_path, format="auto_round")
+        from transformers import AutoConfig
+
+        extra_config = AutoConfig.from_pretrained(quantized_model_path).quantization_config["extra_config"]
+        lmhead_config = extra_config["lm_head"]
+        assert "act_data_type" in lmhead_config.keys() and lmhead_config["act_data_type"] == "fp"
+        assert "act_bits" in lmhead_config.keys() and lmhead_config["act_bits"] == 8
+        assert "act_group_size" in lmhead_config.keys() and lmhead_config["act_group_size"] == 0
+        assert "act_sym" in lmhead_config.keys() and lmhead_config["act_sym"]
+        assert "data_type" in lmhead_config.keys() and lmhead_config["data_type"] == "fp"
+        assert "bits" in lmhead_config.keys() and lmhead_config["bits"] == 8
+        assert "group_size" in lmhead_config.keys() and lmhead_config["group_size"] == -1
+        assert "sym" in lmhead_config.keys() and lmhead_config["sym"]
+        assert "act_dynamic" in lmhead_config.keys() and not lmhead_config["act_dynamic"]
+        assert "super_bits" in lmhead_config.keys() and lmhead_config["super_bits"] is None
+        assert "super_group_size" in lmhead_config.keys() and lmhead_config["super_group_size"] is None
+        # check inblock layer config values
+        kproj_config = extra_config["model.decoder.layers.0.self_attn.k_proj"]
+        assert "act_data_type" in kproj_config.keys() and kproj_config["act_data_type"] == "fp"
+        assert "act_bits" in kproj_config.keys() and kproj_config["act_bits"] == 16
+        assert "act_group_size" in kproj_config.keys() and kproj_config["act_group_size"] == 0
+        assert "act_sym" in kproj_config.keys() and kproj_config["act_sym"]
+        assert "data_type" in kproj_config.keys() and kproj_config["data_type"] == "fp"
+        assert "bits" in kproj_config.keys() and kproj_config["bits"] == 8
+        assert "group_size" in kproj_config.keys() and kproj_config["group_size"] == 0
+        assert "sym" in kproj_config.keys() and kproj_config["sym"]
+        shutil.rmtree(quantized_model_path, ignore_errors=True)
+
 
 if __name__ == "__main__":
     unittest.main()