From f733cfa9a5bab27e2fa2d6e0428d08d3329bfef4 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 16 Oct 2025 03:25:49 +0000
Subject: [PATCH 1/3] fix mxfp load

Signed-off-by: root
---
 auto_round/inference/backend.py            | 12 ++--
 auto_round/testing_utils.py                |  8 +++
 test/test_cpu/test_mxfp4_save_load.py      | 78 ++++++++++++++++++++++
 test/test_cuda/test_mxfp_and_nvfp_quant.py |  9 +--
 4 files changed, 95 insertions(+), 12 deletions(-)
 create mode 100644 test/test_cpu/test_mxfp4_save_load.py

diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py
index aa74f37e1..b56aa327b 100644
--- a/auto_round/inference/backend.py
+++ b/auto_round/inference/backend.py
@@ -107,6 +107,10 @@ class BackendInfo:
     "act_dynamic",
 ]
 
+MX_TENSOR_DATA_TYPES = [
+    "mx_fp",
+    "mx_fp_rceil",
+]
 def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
     if out_feature_multiplier is None:
         out_feature_multiplier = in_feature_multiplier
@@ -230,13 +234,13 @@ def fp8_static_scheme_checker(
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=["mx_fp", "max_fp_rceil"],
+    data_type=MX_TENSOR_DATA_TYPES,
     group_size=[32],
     bits=[8],
     act_bits=[8],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=["mx_fp_rceil"],
+    act_data_type=MX_TENSOR_DATA_TYPES,
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
@@ -250,13 +254,13 @@ def fp8_static_scheme_checker(
     packing_format=LLM_COMPRESSOR_FORMAT,
     sym=[True],
     compute_dtype=["float32", "float16", "bfloat16"],
-    data_type=["mx_fp"],
+    data_type=MX_TENSOR_DATA_TYPES,
     group_size=[32],
     bits=[4],
     act_bits=[4],
     act_group_size=[32],
     act_sym=[True],
-    act_data_type=["mx_fp_rceil"],
+    act_data_type=MX_TENSOR_DATA_TYPES,
     act_dynamic=[True],
     priority=0,
     checkers=[feature_multiply_checker_32],
diff --git a/auto_round/testing_utils.py b/auto_round/testing_utils.py
index 637c84146..d67329c2d 100644
--- a/auto_round/testing_utils.py
+++ b/auto_round/testing_utils.py
@@ -268,3 +268,11 @@ def decorator(test_func: Callable) -> Callable:
         return unittest.skipUnless(require_package_version(package, version_spec, on_fail="skip"), reason)(test_func)
 
     return decorator
+
+
+def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
+    """Check if the model contains a specific module type."""
+    for _, module in model.named_modules():
+        if isinstance(module, target_module_type):
+            return True
+    return False
diff --git a/test/test_cpu/test_mxfp4_save_load.py b/test/test_cpu/test_mxfp4_save_load.py
new file mode 100644
index 000000000..cde86968f
--- /dev/null
+++ b/test/test_cpu/test_mxfp4_save_load.py
@@ -0,0 +1,78 @@
+import shutil
+import tempfile
+
+import pytest
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
+
+from auto_round import AutoRound
+from auto_round import schemes as ar_schemes
+from auto_round.experimental import qmodules as ar_qmodules
+from auto_round.export.export_to_autoround import AutoRoundFormat
+from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
+from auto_round.inference.backend import MX_TENSOR_DATA_TYPES
+from auto_round.testing_utils import has_module
+
+testing_scheme_name_lst = [
+    AutoRoundFormat.MXFP8.value,
+    AutoRoundFormat.MXFP4.value,
+]
+QMODULE_MAPPING = {
+    AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear,
+    AutoRoundFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear,
+}
+SCHEMES_MAPPING = {
+    AutoRoundFormat.MXFP8.value: ar_schemes.MXFP8,
+    AutoRoundFormat.MXFP4.value: ar_schemes.MXFP4,
+}
+
+
+@pytest.mark.parametrize("scheme_name", testing_scheme_name_lst)
+@pytest.mark.parametrize("weight_data_type", MX_TENSOR_DATA_TYPES)
+@pytest.mark.parametrize("act_data_type", MX_TENSOR_DATA_TYPES)
+@torch.inference_mode()
+def test_e2e_quant_and_load(scheme_name, weight_data_type, act_data_type):
+    # Use a temporary directory for saving the quantized model
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # FIXME: use CI model
+        model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+        config = AutoConfig.from_pretrained(model_name)
+        config.num_hidden_layers = 2  # Use a smaller model for testing
+
+        # Load the tokenizer and model
+        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+        model = Qwen2ForCausalLM(config)
+        # model = AutoModelForCausalLM.from_pretrained(
+        #     # model_name,
+
+        #     device_map="cpu",
+        #     torch_dtype="auto",
+        #     trust_remote_code=True,
+        # )
+
+        scheme = SCHEMES_MAPPING[scheme_name]
+        scheme.data_type = weight_data_type
+        scheme.act_data_type = act_data_type
+        # Initialize AutoRound for quantization
+        autoround = AutoRound(
+            model,
+            tokenizer,
+            scheme=scheme,
+            iters=0,
+            nsamples=2,
+        )
+
+        # Quantize and save the model to the temporary directory
+        quantized_model_path = f"{temp_dir}/tmp_autoround"
+        autoround.quantize_and_save(format="auto_round", output_dir=quantized_model_path)
+
+        # Perform inference with the quantized model
+        model = AutoModelForCausalLM.from_pretrained(
+            quantized_model_path,
+            torch_dtype="auto",
+        )
+        model.eval()
+        assert has_module(
+            model, QMODULE_MAPPING[scheme_name]
+        ), f"Expected {QMODULE_MAPPING[scheme_name].__name__} in the model."
diff --git a/test/test_cuda/test_mxfp_and_nvfp_quant.py b/test/test_cuda/test_mxfp_and_nvfp_quant.py
index d15cde5be..dda3bbeb8 100644
--- a/test/test_cuda/test_mxfp_and_nvfp_quant.py
+++ b/test/test_cuda/test_mxfp_and_nvfp_quant.py
@@ -10,7 +10,7 @@
 from auto_round.experimental import qmodules as ar_qmodules
 from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
-
+from auto_round.testing_utils import has_module
 testing_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value, AutoRoundFormat.NVFP4.value]
 QMODULE_MAPPING = {
     AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear,
@@ -19,13 +19,6 @@
 }
 
 
-def has_module(model: torch.nn.Module, target_module_type: torch.nn.Module) -> bool:
-    """Check if the model contains a specific module type."""
-    for _, module in model.named_modules():
-        if isinstance(module, target_module_type):
-            return True
-    return False
-
 
 @pytest.mark.parametrize("scheme", testing_schemes)
 @torch.inference_mode()

From d33d0e6d598d9e1b511ccc6e42ffcbef87984f76 Mon Sep 17 00:00:00 2001
From: yiliu30
Date: Thu, 16 Oct 2025 22:53:52 -0400
Subject: [PATCH 2/3] extend mxfp data types

Signed-off-by: yiliu30
---
 ...test_mxfp4_save_load.py => test_mxfp_save_load.py} | 11 +----------
 1 file changed, 1 insertion(+), 10 deletions(-)
 rename test/test_cpu/{test_mxfp4_save_load.py => test_mxfp_save_load.py} (90%)

diff --git a/test/test_cpu/test_mxfp4_save_load.py b/test/test_cpu/test_mxfp_save_load.py
similarity index 90%
rename from test/test_cpu/test_mxfp4_save_load.py
rename to test/test_cpu/test_mxfp_save_load.py
index cde86968f..374ccf5ce 100644
--- a/test/test_cpu/test_mxfp4_save_load.py
+++ b/test/test_cpu/test_mxfp_save_load.py
@@ -35,22 +35,13 @@
 def test_e2e_quant_and_load(scheme_name, weight_data_type, act_data_type):
     # Use a temporary directory for saving the quantized model
     with tempfile.TemporaryDirectory() as temp_dir:
-        # FIXME: use CI model
-        model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+        model_name = "/tf_dataset/auto_round/models/Qwen/Qwen2.5-0.5B-Instruct"
         config = AutoConfig.from_pretrained(model_name)
         config.num_hidden_layers = 2  # Use a smaller model for testing
 
         # Load the tokenizer and model
         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
         model = Qwen2ForCausalLM(config)
-        # model = AutoModelForCausalLM.from_pretrained(
-        #     # model_name,
-
-        #     device_map="cpu",
-        #     torch_dtype="auto",
-        #     trust_remote_code=True,
-        # )
-
         scheme = SCHEMES_MAPPING[scheme_name]
         scheme.data_type = weight_data_type
         scheme.act_data_type = act_data_type

From f5d024f66a6c29df3c9b7bf8442f8fa8f802c8d3 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 17 Oct 2025 02:55:09 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 auto_round/inference/backend.py            | 1 +
 test/test_cuda/test_mxfp_and_nvfp_quant.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/auto_round/inference/backend.py b/auto_round/inference/backend.py
index b56aa327b..0bc5b35b3 100644
--- a/auto_round/inference/backend.py
+++ b/auto_round/inference/backend.py
@@ -112,6 +112,7 @@ class BackendInfo:
     "mx_fp_rceil",
 ]
+
 def feature_multiply_checker(in_feature, out_feature, config, in_feature_multiplier, out_feature_multiplier=None):
     if out_feature_multiplier is None:
         out_feature_multiplier = in_feature_multiplier
diff --git a/test/test_cuda/test_mxfp_and_nvfp_quant.py b/test/test_cuda/test_mxfp_and_nvfp_quant.py
index dda3bbeb8..0dc43b093 100644
--- a/test/test_cuda/test_mxfp_and_nvfp_quant.py
+++ b/test/test_cuda/test_mxfp_and_nvfp_quant.py
@@ -11,6 +11,7 @@
 from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp
 from auto_round.testing_utils import has_module
+
 testing_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value, AutoRoundFormat.NVFP4.value]
 QMODULE_MAPPING = {
     AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear,
@@ -19,7 +20,6 @@
 }
 
 
-
 @pytest.mark.parametrize("scheme", testing_schemes)
 @torch.inference_mode()
 def test_e2e_quant_and_infer(scheme):